1use std::{collections::BTreeMap, net::SocketAddr, num::NonZeroUsize, sync::Arc, time::Duration};
16
17use amaru_kernel::{NonEmptyBytes, Peer};
18use amaru_observability::{amaru::network, trace_span};
19use amaru_ouroboros::{ConnectionId, ConnectionProvider, ToSocketAddrs};
20use bytes::{Buf, BytesMut};
21use parking_lot::Mutex;
22use pure_stage::BoxFuture;
23use socket2::{Domain, Socket, Type};
24use tokio::{
25 io::{AsyncReadExt, AsyncWriteExt},
26 net::{
27 TcpListener, TcpStream,
28 tcp::{OwnedReadHalf, OwnedWriteHalf},
29 },
30 sync::{Mutex as AsyncMutex, mpsc},
31 task::JoinHandle,
32};
33use tracing::Instrument;
34
35use crate::socket_addr::resolve;
36
37pub struct Connection {
38 peer_addr: SocketAddr,
39 reader: Arc<AsyncMutex<(OwnedReadHalf, BytesMut)>>,
40 writer: Arc<AsyncMutex<OwnedWriteHalf>>,
41}
42
43impl Connection {
44 pub fn new(tcp_stream: TcpStream, read_buf_size: usize) -> std::io::Result<Self> {
45 tcp_stream.set_nodelay(true)?;
46 let (reader, writer) = tcp_stream.into_split();
47 let peer_addr = reader.peer_addr()?;
48 Ok(Self {
49 peer_addr,
50 reader: Arc::new(AsyncMutex::new((reader, BytesMut::with_capacity(read_buf_size)))),
51 writer: Arc::new(AsyncMutex::new(writer)),
52 })
53 }
54
55 pub fn peer_addr(&self) -> SocketAddr {
56 self.peer_addr
57 }
58}
59
60struct Connections {
61 connections: BTreeMap<ConnectionId, Connection>,
62}
63
64impl Connections {
65 fn new() -> Self {
66 Self { connections: BTreeMap::new() }
67 }
68
69 fn add_connection(&mut self, connection: Connection) -> ConnectionId {
70 let id = if let Some((&last_id, _)) = self.connections.iter().next_back() {
71 last_id.next()
72 } else {
73 ConnectionId::initial()
74 };
75 self.insert(id, connection);
76 id
77 }
78
79 fn insert(&mut self, id: ConnectionId, connection: Connection) {
80 self.connections.insert(id, connection);
81 }
82
83 fn get(&self, id: &ConnectionId) -> Option<&Connection> {
84 self.connections.get(id)
85 }
86
87 fn remove(&mut self, id: &ConnectionId) -> Option<Connection> {
88 self.connections.remove(id)
89 }
90}
91
92#[derive(Clone)]
93pub struct TokioConnections {
94 inner: Arc<Inner>,
95}
96
97struct Inner {
98 connections: Mutex<Connections>,
99 read_buf_size: usize,
100 incoming_tx: mpsc::Sender<PendingAccept>,
101 incoming_rx: AsyncMutex<mpsc::Receiver<PendingAccept>>,
102 tasks: Mutex<BTreeMap<SocketAddr, JoinHandle<()>>>,
103}
104
105impl Drop for Inner {
106 fn drop(&mut self) {
107 for (_, task) in self.tasks.lock().iter() {
108 task.abort();
109 }
110 }
111}
112
113impl TokioConnections {
114 pub fn new(read_buf_size: usize) -> Self {
115 let (incoming_tx, incoming_rx) = mpsc::channel(128);
116 let inner = Arc::new(Inner {
117 connections: Mutex::new(Connections::new()),
118 read_buf_size,
119 incoming_tx,
120 incoming_rx: AsyncMutex::new(incoming_rx),
121 tasks: Mutex::new(BTreeMap::new()),
122 });
123 Self { inner }
124 }
125}
126
127async fn connect(addr: Vec<SocketAddr>, resource: Arc<Inner>, timeout: Duration) -> std::io::Result<ConnectionId> {
128 let stream = tokio::time::timeout(timeout, TcpStream::connect(&*addr)).await??;
129 tracing::debug!(?addr, "connected");
130 let mut connections = resource.connections.lock();
131 let id = connections.add_connection(Connection::new(stream, resource.read_buf_size)?);
132 Ok(id)
133}
134
135impl ConnectionProvider for TokioConnections {
136 fn listen(&self, addr: SocketAddr) -> BoxFuture<'static, std::io::Result<SocketAddr>> {
137 let inner = self.inner.clone();
138
139 Box::pin(
140 async move {
141 let existing_task = inner.tasks.lock().remove(&addr);
144 if let Some(task) = existing_task {
145 tracing::info!(%addr, "aborting existing listener task for restart");
146 task.abort();
147 let _ = task.await;
149 }
150
151 let listener = bind_address(addr)?;
153 let local = listener.local_addr()?;
154 tracing::debug!(%local, "listening");
155
156 let incoming_tx = inner.incoming_tx.clone();
158 let task = tokio::spawn(
159 async move {
161 while let Ok((stream, peer_addr)) = listener.accept().await {
162 let Ok(_) = incoming_tx.send(PendingAccept { stream, peer_addr }).await else {
163 break;
164 };
165 }
166 tracing::info!(%local, "accept loop stopped");
167 }
168 .instrument(trace_span!(network::connection::ACCEPT_LOOP, local = %local)),
169 );
170
171 inner.tasks.lock().insert(local, task);
172
173 Ok(local)
174 }
175 .instrument(trace_span!(network::connection::LISTEN, addr = %addr)),
176 )
177 }
178
179 fn accept(&self, _listener_addr: SocketAddr) -> BoxFuture<'static, std::io::Result<(Peer, ConnectionId)>> {
182 let inner = self.inner.clone();
183
184 Box::pin(
185 async move {
186 let mut rx = inner.incoming_rx.lock().await;
187
188 #[expect(clippy::expect_used)]
189 let PendingAccept { stream, peer_addr } =
190 rx.recv().await.expect("sender cannot be dropped since we hold Inner");
191 drop(rx);
192
193 tracing::debug!(%peer_addr, "accepted connection");
194 let id = inner.connections.lock().add_connection(Connection::new(stream, inner.read_buf_size)?);
195
196 Ok((Peer::from_addr(&peer_addr), id))
197 }
198 .instrument(trace_span!(network::connection::ACCEPT)),
199 )
200 }
201
202 fn connect(&self, addr: Vec<SocketAddr>, timeout: Duration) -> BoxFuture<'static, std::io::Result<ConnectionId>> {
203 let addr2 = addr.clone();
204 Box::pin(
205 connect(addr, self.inner.clone(), timeout)
206 .instrument(trace_span!(network::connection::CONNECT, addr = ?addr2)),
207 )
208 }
209
210 fn connect_addrs(
211 &self,
212 addr: ToSocketAddrs,
213 timeout: Duration,
214 ) -> BoxFuture<'static, std::io::Result<ConnectionId>> {
215 let resource = self.inner.clone();
216 let addr2 = addr.clone();
217 Box::pin(
218 async move {
219 let addr = resolve(addr).await?;
220 tracing::debug!(?addr, "resolved addresses");
221 connect(addr, resource, timeout).await
222 }
223 .instrument(trace_span!(network::connection::CONNECT_ADDRS, addr = ?addr2)),
224 )
225 }
226
227 fn send(&self, conn: ConnectionId, data: NonEmptyBytes) -> BoxFuture<'static, std::io::Result<()>> {
228 let resource = self.inner.clone();
229 let len = data.len();
230 Box::pin(
231 async move {
232 let connection = resource
233 .connections
234 .lock()
235 .get(&conn)
236 .ok_or_else(|| std::io::Error::other(format!("connection {conn} not found for send")))?
237 .writer
238 .clone();
239 tokio::time::timeout(Duration::from_secs(100), connection.lock().await.write_all(&data)).await??;
240 Ok(())
241 }
242 .instrument(trace_span!(network::connection::SEND, conn = %conn, len = len)),
243 )
244 }
245
246 fn recv(&self, conn: ConnectionId, bytes: NonZeroUsize) -> BoxFuture<'static, std::io::Result<NonEmptyBytes>> {
247 let resource = self.inner.clone();
248 Box::pin(
249 async move {
250 let connection = resource
251 .connections
252 .lock()
253 .get(&conn)
254 .ok_or_else(|| std::io::Error::other(format!("connection {conn} not found for recv")))?
255 .reader
256 .clone();
257 let mut guard = connection.lock().await;
258 let (reader, buf) = &mut *guard;
259 buf.reserve(bytes.get() - buf.remaining().min(bytes.get()));
260 while buf.remaining() < bytes.get() {
261 if reader.read_buf(buf).await? == 0 {
262 return Err(std::io::ErrorKind::UnexpectedEof.into());
263 };
264 }
265 #[expect(clippy::expect_used)]
266 Ok(buf.copy_to_bytes(bytes.get()).try_into().expect("guaranteed by NonZeroUsize"))
267 }
268 .instrument(trace_span!(network::connection::RECV, conn = %conn, bytes = bytes)),
269 )
270 }
271
272 fn close(&self, conn: ConnectionId) -> BoxFuture<'static, std::io::Result<()>> {
273 let resource = self.inner.clone();
274 Box::pin(
275 async move {
276 let connection = resource.connections.lock().remove(&conn).ok_or_else(|| {
277 std::io::Error::other(format!("connection {conn} not found for close"))
279 })?;
280 connection.writer.lock().await.shutdown().await?;
281 Ok(())
282 }
283 .instrument(trace_span!(network::connection::CLOSE, conn = %conn)),
284 )
285 }
286}
287
/// A socket accepted by a listener's accept loop, queued until `accept` claims it.
#[derive(Debug)]
struct PendingAccept {
    /// The accepted socket.
    stream: TcpStream,
    /// Remote address reported by the listener's `accept` for this socket.
    peer_addr: SocketAddr,
}
294
295fn bind_address(addr: SocketAddr) -> std::io::Result<TcpListener> {
298 let domain = match addr {
299 SocketAddr::V4(_) => Domain::IPV4,
300 SocketAddr::V6(_) => Domain::IPV6,
301 };
302
303 let socket = Socket::new(domain, Type::STREAM, None)?;
304
305 socket.set_reuse_address(true)?;
307
308 socket.bind(&addr.into())?;
309 socket.listen(1024)?;
310
311 socket.set_nonblocking(true)?;
312 TcpListener::from_std(socket.into())
313}
314
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use tokio::{task::JoinHandle, time::timeout};

    use super::*;

    #[tokio::test]
    async fn connect_to_a_server() -> anyhow::Result<()> {
        let listener = TcpListener::bind(("127.0.0.1", 0)).await?;
        let addr = listener.local_addr()?;
        let server: JoinHandle<std::io::Result<()>> = tokio::spawn(async move {
            let (mut stream, _peer) = listener.accept().await?;

            let mut buf = [0u8; 4];
            stream.read_exact(&mut buf).await?;
            assert_eq!(&buf, b"ping");

            stream.write_all(b"pong").await?;
            Ok(())
        });

        let connections = TokioConnections::new(1024);
        let connection_id = connections.connect(vec![addr], Duration::from_secs(1)).await?;
        connections.send(connection_id, non_empty(b"ping")).await?;
        let reply = connections.recv(connection_id, const { NonZeroUsize::new(4).unwrap() }).await?;
        assert_eq!(reply.as_ref(), b"pong");

        connections.close(connection_id).await?;
        server.await.expect("server task panicked")?;

        Ok(())
    }

    #[tokio::test]
    async fn bind_and_accept_a_client_connection() -> anyhow::Result<()> {
        let connections = TokioConnections::new(1024);

        let listen_addr = SocketAddr::from(([127, 0, 0, 1], 0));
        let addr = connections.listen(listen_addr).await?;

        let client: JoinHandle<std::io::Result<()>> = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await?;
            stream.write_all(b"hello").await?;

            let mut buf = String::new();
            stream.read_to_string(&mut buf).await?;
            assert_eq!(&buf, "world");

            Ok(())
        });

        let connection_id = timeout(Duration::from_secs(1), connections.accept(listen_addr)).await??.1;
        let result = connections.recv(connection_id, const { NonZeroUsize::new(5).unwrap() }).await?;
        assert_eq!(result.as_ref(), b"hello");

        connections.send(connection_id, non_empty(b"world")).await?;
        connections.close(connection_id).await?;

        client.await.expect("client task panicked")?;
        Ok(())
    }

    #[tokio::test]
    async fn recv_serves_buffered_bytes_then_reports_eof() -> anyhow::Result<()> {
        let listener = TcpListener::bind(("127.0.0.1", 0)).await?;
        let addr = listener.local_addr()?;
        // The server sends one payload and hangs up; the client consumes it in two `recv`s.
        let server: JoinHandle<std::io::Result<()>> = tokio::spawn(async move {
            let (mut stream, _peer) = listener.accept().await?;
            stream.write_all(b"hello world").await?;
            Ok(())
        });

        // A read buffer smaller than the first request forces `recv` to grow it.
        let connections = TokioConnections::new(2);
        let connection_id = connections.connect(vec![addr], Duration::from_secs(1)).await?;

        let first = connections.recv(connection_id, const { NonZeroUsize::new(5).unwrap() }).await?;
        assert_eq!(first.as_ref(), b"hello");
        let second = connections.recv(connection_id, const { NonZeroUsize::new(6).unwrap() }).await?;
        assert_eq!(second.as_ref(), b" world");

        server.await.expect("server task panicked")?;
        let Err(eof) = connections.recv(connection_id, const { NonZeroUsize::new(1).unwrap() }).await else {
            panic!("recv must fail once the server has closed the stream");
        };
        assert_eq!(eof.kind(), std::io::ErrorKind::UnexpectedEof);

        Ok(())
    }

    #[tokio::test]
    async fn operations_on_unknown_connection_fail() {
        let connections = TokioConnections::new(1024);
        let unknown = ConnectionId::initial();

        assert!(connections.send(unknown, non_empty(b"data")).await.is_err());
        assert!(connections.recv(unknown, const { NonZeroUsize::new(1).unwrap() }).await.is_err());
        assert!(connections.close(unknown).await.is_err());
    }

    #[tokio::test]
    async fn listening_again_on_the_same_address_restarts_the_listener() -> anyhow::Result<()> {
        let connections = TokioConnections::new(1024);
        let addr = connections.listen(SocketAddr::from(([127, 0, 0, 1], 0))).await?;

        // The first listener is still running; listening again must replace it, not fail.
        let restarted = connections.listen(addr).await?;
        assert_eq!(restarted, addr);

        let _client = TcpStream::connect(restarted).await?;
        let (_peer, connection_id) = timeout(Duration::from_secs(1), connections.accept(restarted)).await??;
        connections.close(connection_id).await?;

        Ok(())
    }

    fn non_empty(data: &'static [u8]) -> NonEmptyBytes {
        Bytes::from_static(data).try_into().expect("test data must be non-empty")
    }
}