Skip to main content

amaru_network/
connection.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::BTreeMap, net::SocketAddr, num::NonZeroUsize, sync::Arc, time::Duration};
16
17use amaru_kernel::{NonEmptyBytes, Peer};
18use amaru_observability::{amaru::network, trace_span};
19use amaru_ouroboros::{ConnectionId, ConnectionProvider, ToSocketAddrs};
20use bytes::{Buf, BytesMut};
21use parking_lot::Mutex;
22use pure_stage::BoxFuture;
23use socket2::{Domain, Socket, Type};
24use tokio::{
25    io::{AsyncReadExt, AsyncWriteExt},
26    net::{
27        TcpListener, TcpStream,
28        tcp::{OwnedReadHalf, OwnedWriteHalf},
29    },
30    sync::{Mutex as AsyncMutex, mpsc},
31    task::JoinHandle,
32};
33use tracing::Instrument;
34
35use crate::socket_addr::resolve;
36
/// A single TCP connection, split into independently lockable read and write halves.
pub struct Connection {
    /// Address of the remote peer, captured once when the connection is created.
    peer_addr: SocketAddr,
    /// Read half of the stream, paired with a buffer holding bytes that were read
    /// from the socket but not yet handed out to a caller.
    reader: Arc<AsyncMutex<(OwnedReadHalf, BytesMut)>>,
    /// Write half of the stream.
    writer: Arc<AsyncMutex<OwnedWriteHalf>>,
}
42
43impl Connection {
44    pub fn new(tcp_stream: TcpStream, read_buf_size: usize) -> std::io::Result<Self> {
45        tcp_stream.set_nodelay(true)?;
46        let (reader, writer) = tcp_stream.into_split();
47        let peer_addr = reader.peer_addr()?;
48        Ok(Self {
49            peer_addr,
50            reader: Arc::new(AsyncMutex::new((reader, BytesMut::with_capacity(read_buf_size)))),
51            writer: Arc::new(AsyncMutex::new(writer)),
52        })
53    }
54
55    pub fn peer_addr(&self) -> SocketAddr {
56        self.peer_addr
57    }
58}
59
60struct Connections {
61    connections: BTreeMap<ConnectionId, Connection>,
62}
63
64impl Connections {
65    fn new() -> Self {
66        Self { connections: BTreeMap::new() }
67    }
68
69    fn add_connection(&mut self, connection: Connection) -> ConnectionId {
70        let id = if let Some((&last_id, _)) = self.connections.iter().next_back() {
71            last_id.next()
72        } else {
73            ConnectionId::initial()
74        };
75        self.insert(id, connection);
76        id
77    }
78
79    fn insert(&mut self, id: ConnectionId, connection: Connection) {
80        self.connections.insert(id, connection);
81    }
82
83    fn get(&self, id: &ConnectionId) -> Option<&Connection> {
84        self.connections.get(id)
85    }
86
87    fn remove(&mut self, id: &ConnectionId) -> Option<Connection> {
88        self.connections.remove(id)
89    }
90}
91
/// Tokio-backed connection manager.
///
/// Cloning is cheap: all clones share the same underlying state through an `Arc`.
#[derive(Clone)]
pub struct TokioConnections {
    /// Shared state: open connections, listener tasks and the incoming-connection channel.
    inner: Arc<Inner>,
}
96
/// State shared by all clones of [`TokioConnections`].
struct Inner {
    /// Registry of open connections.
    connections: Mutex<Connections>,
    /// Initial capacity, in bytes, of the read buffer given to each new connection.
    read_buf_size: usize,
    /// Sending side of the incoming-connection channel; cloned into each accept loop.
    incoming_tx: mpsc::Sender<PendingAccept>,
    /// Receiving side of the incoming-connection channel; drained by `accept`.
    incoming_rx: AsyncMutex<mpsc::Receiver<PendingAccept>>,
    /// Accept-loop tasks, keyed by the local address their listener is bound to.
    tasks: Mutex<BTreeMap<SocketAddr, JoinHandle<()>>>,
}
104
105impl Drop for Inner {
106    fn drop(&mut self) {
107        for (_, task) in self.tasks.lock().iter() {
108            task.abort();
109        }
110    }
111}
112
113impl TokioConnections {
114    pub fn new(read_buf_size: usize) -> Self {
115        let (incoming_tx, incoming_rx) = mpsc::channel(128);
116        let inner = Arc::new(Inner {
117            connections: Mutex::new(Connections::new()),
118            read_buf_size,
119            incoming_tx,
120            incoming_rx: AsyncMutex::new(incoming_rx),
121            tasks: Mutex::new(BTreeMap::new()),
122        });
123        Self { inner }
124    }
125}
126
127async fn connect(addr: Vec<SocketAddr>, resource: Arc<Inner>, timeout: Duration) -> std::io::Result<ConnectionId> {
128    let stream = tokio::time::timeout(timeout, TcpStream::connect(&*addr)).await??;
129    tracing::debug!(?addr, "connected");
130    let mut connections = resource.connections.lock();
131    let id = connections.add_connection(Connection::new(stream, resource.read_buf_size)?);
132    Ok(id)
133}
134
135impl ConnectionProvider for TokioConnections {
136    fn listen(&self, addr: SocketAddr) -> BoxFuture<'static, std::io::Result<SocketAddr>> {
137        let inner = self.inner.clone();
138
139        Box::pin(
140            async move {
141                // If a listener already exists for this address, abort and remove it.
142                // This allows supervised restarts to work correctly.
143                let existing_task = inner.tasks.lock().remove(&addr);
144                if let Some(task) = existing_task {
145                    tracing::info!(%addr, "aborting existing listener task for restart");
146                    task.abort();
147                    // Wait for the task to complete so the TcpListener is dropped and the port is released.
148                    let _ = task.await;
149                }
150
151                // Bind the listener with SO_REUSEADDR
152                let listener = bind_address(addr)?;
153                let local = listener.local_addr()?;
154                tracing::debug!(%local, "listening");
155
156                // Accept incoming connections and send them into the channel.
157                let incoming_tx = inner.incoming_tx.clone();
158                let task = tokio::spawn(
159                    // this task contains the listener and the sender, dropping them upon abort()
160                    async move {
161                        while let Ok((stream, peer_addr)) = listener.accept().await {
162                            let Ok(_) = incoming_tx.send(PendingAccept { stream, peer_addr }).await else {
163                                break;
164                            };
165                        }
166                        tracing::info!(%local, "accept loop stopped");
167                    }
168                    .instrument(trace_span!(network::connection::ACCEPT_LOOP, local = %local)),
169                );
170
171                inner.tasks.lock().insert(local, task);
172
173                Ok(local)
174            }
175            .instrument(trace_span!(network::connection::LISTEN, addr = %addr)),
176        )
177    }
178
179    /// NOTE: For now there is only one listener used in the tokio implementation so we don't need
180    /// to use the _listener_addr.
181    fn accept(&self, _listener_addr: SocketAddr) -> BoxFuture<'static, std::io::Result<(Peer, ConnectionId)>> {
182        let inner = self.inner.clone();
183
184        Box::pin(
185            async move {
186                let mut rx = inner.incoming_rx.lock().await;
187
188                #[expect(clippy::expect_used)]
189                let PendingAccept { stream, peer_addr } =
190                    rx.recv().await.expect("sender cannot be dropped since we hold Inner");
191                drop(rx);
192
193                tracing::debug!(%peer_addr, "accepted connection");
194                let id = inner.connections.lock().add_connection(Connection::new(stream, inner.read_buf_size)?);
195
196                Ok((Peer::from_addr(&peer_addr), id))
197            }
198            .instrument(trace_span!(network::connection::ACCEPT)),
199        )
200    }
201
202    fn connect(&self, addr: Vec<SocketAddr>, timeout: Duration) -> BoxFuture<'static, std::io::Result<ConnectionId>> {
203        let addr2 = addr.clone();
204        Box::pin(
205            connect(addr, self.inner.clone(), timeout)
206                .instrument(trace_span!(network::connection::CONNECT, addr = ?addr2)),
207        )
208    }
209
210    fn connect_addrs(
211        &self,
212        addr: ToSocketAddrs,
213        timeout: Duration,
214    ) -> BoxFuture<'static, std::io::Result<ConnectionId>> {
215        let resource = self.inner.clone();
216        let addr2 = addr.clone();
217        Box::pin(
218            async move {
219                let addr = resolve(addr).await?;
220                tracing::debug!(?addr, "resolved addresses");
221                connect(addr, resource, timeout).await
222            }
223            .instrument(trace_span!(network::connection::CONNECT_ADDRS, addr = ?addr2)),
224        )
225    }
226
227    fn send(&self, conn: ConnectionId, data: NonEmptyBytes) -> BoxFuture<'static, std::io::Result<()>> {
228        let resource = self.inner.clone();
229        let len = data.len();
230        Box::pin(
231            async move {
232                let connection = resource
233                    .connections
234                    .lock()
235                    .get(&conn)
236                    .ok_or_else(|| std::io::Error::other(format!("connection {conn} not found for send")))?
237                    .writer
238                    .clone();
239                tokio::time::timeout(Duration::from_secs(100), connection.lock().await.write_all(&data)).await??;
240                Ok(())
241            }
242            .instrument(trace_span!(network::connection::SEND, conn = %conn, len = len)),
243        )
244    }
245
246    fn recv(&self, conn: ConnectionId, bytes: NonZeroUsize) -> BoxFuture<'static, std::io::Result<NonEmptyBytes>> {
247        let resource = self.inner.clone();
248        Box::pin(
249            async move {
250                let connection = resource
251                    .connections
252                    .lock()
253                    .get(&conn)
254                    .ok_or_else(|| std::io::Error::other(format!("connection {conn} not found for recv")))?
255                    .reader
256                    .clone();
257                let mut guard = connection.lock().await;
258                let (reader, buf) = &mut *guard;
259                buf.reserve(bytes.get() - buf.remaining().min(bytes.get()));
260                while buf.remaining() < bytes.get() {
261                    if reader.read_buf(buf).await? == 0 {
262                        return Err(std::io::ErrorKind::UnexpectedEof.into());
263                    };
264                }
265                #[expect(clippy::expect_used)]
266                Ok(buf.copy_to_bytes(bytes.get()).try_into().expect("guaranteed by NonZeroUsize"))
267            }
268            .instrument(trace_span!(network::connection::RECV, conn = %conn, bytes = bytes)),
269        )
270    }
271
272    fn close(&self, conn: ConnectionId) -> BoxFuture<'static, std::io::Result<()>> {
273        let resource = self.inner.clone();
274        Box::pin(
275            async move {
276                let connection = resource.connections.lock().remove(&conn).ok_or_else(|| {
277                    // TODO: figure out how to not raise an error for a connection that has simply been closed already
278                    std::io::Error::other(format!("connection {conn} not found for close"))
279                })?;
280                connection.writer.lock().await.shutdown().await?;
281                Ok(())
282            }
283            .instrument(trace_span!(network::connection::CLOSE, conn = %conn)),
284        )
285    }
286}
287
/// Local struct holding a pending accepted connection
/// until it is picked up by an accept call and added to the list of connections.
struct PendingAccept {
    /// The accepted TCP stream.
    stream: TcpStream,
    /// Address of the peer, as reported by the listener when accepting.
    peer_addr: SocketAddr,
}
294
295/// Binds a TCP listener to the specified address with
296/// `SO_REUSEADDR` enabled.
297fn bind_address(addr: SocketAddr) -> std::io::Result<TcpListener> {
298    let domain = match addr {
299        SocketAddr::V4(_) => Domain::IPV4,
300        SocketAddr::V6(_) => Domain::IPV6,
301    };
302
303    let socket = Socket::new(domain, Type::STREAM, None)?;
304
305    // Allow rebinding to a port that was recently used (e.g., still in TIME_WAIT).
306    socket.set_reuse_address(true)?;
307
308    socket.bind(&addr.into())?;
309    socket.listen(1024)?;
310
311    socket.set_nonblocking(true)?;
312    TcpListener::from_std(socket.into())
313}
314
/// Loopback tests exercising `TokioConnections` both as a client and as a server.
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use tokio::{task::JoinHandle, time::timeout};

    use super::*;

    /// Outgoing path: connect to a plain tokio listener, exchange "ping"/"pong",
    /// then close the connection.
    #[tokio::test]
    async fn connect_to_a_server() -> anyhow::Result<()> {
        // Start a TCP listener that echoes "pong" when it receives "ping".
        let listener = TcpListener::bind(("127.0.0.1", 0)).await?;
        let addr = listener.local_addr()?;
        let server: JoinHandle<std::io::Result<()>> = tokio::spawn(async move {
            let (mut stream, _peer) = listener.accept().await?;

            let mut buf = [0u8; 4];
            stream.read_exact(&mut buf).await?;
            assert_eq!(&buf, b"ping");

            stream.write_all(b"pong").await?;
            Ok(())
        });

        // Use TokioConnections to connect to the listener.
        let connections = TokioConnections::new(1024);
        let connection_id = connections.connect(vec![addr], Duration::from_secs(1)).await?;
        connections.send(connection_id, non_empty(b"ping")).await?;
        let reply = connections.recv(connection_id, const { NonZeroUsize::new(4).unwrap() }).await?;
        assert_eq!(reply.as_ref(), b"pong");

        connections.close(connection_id).await?;
        // Surface both a panic in the server task and any I/O error it returned.
        server.await.expect("server task panicked")?;

        Ok(())
    }

    /// Incoming path: listen on an ephemeral port, accept a plain tokio client,
    /// exchange "hello"/"world", then close the connection.
    #[tokio::test]
    async fn bind_and_accept_a_client_connection() -> anyhow::Result<()> {
        // Create a TokioConnections instance and bind a TCP listener
        // to an ephemeral port.
        let connections = TokioConnections::new(1024);

        let listen_addr = SocketAddr::from(([127, 0, 0, 1], 0));
        // `addr` is the address actually bound, i.e. with the port chosen by the OS.
        let addr = connections.listen(listen_addr).await?;

        // Start a client that connects to the listener and
        // sends "hello", expecting "world" in response.
        let client: JoinHandle<std::io::Result<()>> = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await?;
            stream.write_all(b"hello").await?;

            // Reads until end of stream, so this only completes once the other side
            // shuts down its write half (done by `close` below).
            let mut buf = String::new();
            stream.read_to_string(&mut buf).await?;
            assert_eq!(&buf, "world");

            Ok(())
        });

        // Receive "hello" from the client and respond with "world".
        // The address passed to `accept` is currently ignored by the implementation.
        let connection_id = timeout(Duration::from_secs(1), connections.accept(listen_addr)).await??.1;
        let result = connections.recv(connection_id, const { NonZeroUsize::new(5).unwrap() }).await?;
        assert_eq!(result.as_ref(), b"hello");

        connections.send(connection_id, non_empty(b"world")).await?;
        connections.close(connection_id).await?;

        // Surface both a panic in the client task and any I/O error it returned.
        client.await.expect("client task panicked")?;
        Ok(())
    }

    // HELPERS

    /// Converts a static, non-empty byte string into `NonEmptyBytes`; panics on empty input.
    fn non_empty(data: &'static [u8]) -> NonEmptyBytes {
        Bytes::from_static(data).try_into().expect("test data must be non-empty")
    }
}