Skip to main content

pallas_network2/
interface.rs

1use std::{
2    collections::HashMap,
3    future::ready,
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll},
7};
8
9use futures::{
10    Stream, StreamExt,
11    stream::{FusedStream, FuturesUnordered},
12};
13
14use tokio::{sync::Mutex, time::Instant};
15
16use crate::{
17    Channel, Interface, InterfaceCommand, InterfaceError, InterfaceEvent, Message, Payload, PeerId,
18    bearer::{Bearer, BearerReadHalf, BearerWriteHalf, Timestamp},
19};
20
/// Outcome of one asynchronous bearer operation.
///
/// Each future queued in the connection pool resolves to one of these
/// variants, which the pool then translates into a public [`InterfaceEvent`].
enum InternalEvent<M: Message> {
    /// A bearer to the peer is ready (outbound connect or inbound accept).
    Connected(PeerId, Bearer),
    /// A disconnect request for the peer has completed.
    Disconnected(PeerId),
    /// The given message was written to the peer's bearer.
    Sent(PeerId, M),
    /// Messages read from the peer, plus the read half and the partial-chunk
    /// buffer so that the next read can be scheduled.
    Recv(PeerId, Vec<M>, BearerReadHalf, ChunkBuffer),
    /// An I/O operation associated with the peer failed.
    Error(PeerId, tokio::io::Error),
}
28
/// Boxed, pinned, `Send` future that resolves to an [`InternalEvent`]; the
/// common type for every operation queued in the connection pool.
type InterfaceFuture<M> = Pin<Box<dyn Future<Output = InternalEvent<M>> + Send>>;
30
31async fn connect<M: Message>(pid: PeerId) -> InternalEvent<M> {
32    let pid = pid.clone();
33
34    tracing::debug!(%pid, "connecting bearer");
35    let bearer = Bearer::connect_tcp((pid.host.clone(), pid.port)).await;
36
37    match bearer {
38        Ok(bearer) => InternalEvent::Connected(pid.clone(), bearer),
39        Err(e) => InternalEvent::Error(pid.clone(), e),
40    }
41}
42
43async fn send<M: Message>(
44    pid: PeerId,
45    writer: SharedWriter,
46    msg: M,
47    ts: Timestamp,
48    mode: u16,
49) -> InternalEvent<M> {
50    let pid = pid.clone();
51    let copy = msg.clone();
52
53    let mut writer = writer.lock().await;
54
55    let result = writer.write_message(msg, ts, mode).await;
56
57    match result {
58        Ok(_) => InternalEvent::Sent(pid.clone(), copy),
59        Err(e) => InternalEvent::Error(pid.clone(), e),
60    }
61}
62
/// Buffer of partial payload chunks keyed by channel, used to reassemble
/// multi-segment messages.
///
/// The buffer is handed to each read and carried over to the next one, so
/// chunks left incomplete by one read are still available to the following.
pub type ChunkBuffer = HashMap<Channel, Payload>;
66
67async fn recv<M: Message>(
68    pid: PeerId,
69    mut reader: BearerReadHalf,
70    mut partial_chunks: ChunkBuffer,
71) -> InternalEvent<M> {
72    let pid = pid.clone();
73
74    let result = reader.read_full_msgs(&mut partial_chunks).await;
75
76    match result {
77        Ok(msgs) => InternalEvent::Recv(pid.clone(), msgs, reader, partial_chunks),
78        Err(e) => InternalEvent::Error(pid.clone(), e),
79    }
80}
81
82async fn disconnect<M: Message>(pid: PeerId, writer: SharedWriter) -> InternalEvent<M> {
83    let pid = pid.clone();
84
85    let mut writer = writer.lock().await;
86
87    writer.shutdown().await.unwrap();
88
89    InternalEvent::Disconnected(pid.clone())
90}
91
/// A thread-safe handle to a shared [`BearerWriteHalf`].
///
/// The write half sits behind an async (`tokio`) mutex, so send and
/// disconnect operations for the same peer take turns using it.
pub type SharedWriter = Arc<Mutex<BearerWriteHalf>>;
94
95// ---------------------------------------------------------------------------
96// TcpConnectionPool — shared connection-management logic
97// ---------------------------------------------------------------------------
98
99struct TcpConnectionPool<M: Message> {
100    futures: FuturesUnordered<InterfaceFuture<M>>,
101    writers: HashMap<PeerId, SharedWriter>,
102    clock: Instant,
103    /// The mode bit to set on outgoing segments (0 for initiator, PROTOCOL_SERVER for responder).
104    mode: u16,
105}
106
107impl<M: Message> TcpConnectionPool<M> {
108    fn new(mode: u16) -> Self {
109        Self {
110            futures: FuturesUnordered::new(),
111            writers: HashMap::new(),
112            clock: Instant::now(),
113            mode,
114        }
115    }
116
117    fn push_future(&mut self, f: InterfaceFuture<M>) {
118        self.futures.push(f);
119    }
120
121    fn take_writer(&mut self, pid: &PeerId) -> Option<SharedWriter> {
122        self.writers.get(pid).cloned()
123    }
124
125    fn on_connected(&mut self, pid: PeerId, bearer: Bearer) -> InterfaceEvent<M> {
126        let (read, write) = bearer.into_split();
127
128        self.writers
129            .insert(pid.clone(), Arc::new(Mutex::new(write)));
130
131        let future = recv(pid.clone(), read, HashMap::new());
132        self.futures.push(Box::pin(future));
133
134        InterfaceEvent::Connected(pid)
135    }
136
137    fn on_disconnected(&mut self, pid: PeerId) -> InterfaceEvent<M> {
138        self.writers.remove(&pid);
139        InterfaceEvent::Disconnected(pid)
140    }
141
142    fn on_sent(&mut self, pid: PeerId, msg: M) -> InterfaceEvent<M> {
143        InterfaceEvent::Sent(pid, msg)
144    }
145
146    fn on_recv(
147        &mut self,
148        pid: PeerId,
149        msgs: Vec<M>,
150        reader: BearerReadHalf,
151        partial_chunks: ChunkBuffer,
152    ) -> InterfaceEvent<M> {
153        let future = recv(pid.clone(), reader, partial_chunks);
154        self.futures.push(Box::pin(future));
155
156        InterfaceEvent::Recv(pid, msgs)
157    }
158
159    fn on_error(&mut self, pid: PeerId, error: tokio::io::Error) -> InterfaceEvent<M> {
160        tracing::error!("error: {:?}", error);
161        InterfaceEvent::Error(pid, InterfaceError::Other(error.to_string()))
162    }
163
164    fn handle_internal_event(&mut self, event: InternalEvent<M>) -> InterfaceEvent<M> {
165        match event {
166            InternalEvent::Connected(pid, stream) => self.on_connected(pid, stream),
167            InternalEvent::Sent(pid, msg) => self.on_sent(pid, msg),
168            InternalEvent::Recv(pid, msgs, stream, buf) => self.on_recv(pid, msgs, stream, buf),
169            InternalEvent::Disconnected(pid) => self.on_disconnected(pid),
170            InternalEvent::Error(pid, error) => self.on_error(pid, error),
171        }
172    }
173
174    fn dispatch_send(&mut self, pid: PeerId, msg: M) {
175        let ts = self.clock.elapsed().as_micros() as u32;
176
177        let Some(writer) = self.take_writer(&pid) else {
178            tracing::error!(%pid, "trying to send to a peer not connected");
179            return;
180        };
181
182        let future = send(pid, writer, msg, ts, self.mode);
183        self.futures.push(Box::pin(future));
184    }
185
186    fn dispatch_disconnect(&mut self, pid: PeerId) {
187        let Some(stream) = self.take_writer(&pid) else {
188            tracing::warn!(%pid, "trying to disconnect a peer not connected");
189            self.futures
190                .push(Box::pin(ready(InternalEvent::Disconnected(pid.clone()))));
191            return;
192        };
193
194        let future = disconnect(pid, stream);
195        self.futures.push(Box::pin(future));
196    }
197
198    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Option<InterfaceEvent<M>>> {
199        let event = self.futures.poll_next_unpin(cx);
200
201        match event {
202            Poll::Ready(Some(event)) => {
203                let event = self.handle_internal_event(event);
204                Poll::Ready(Some(event))
205            }
206            Poll::Ready(None) => Poll::Pending,
207            Poll::Pending => Poll::Pending,
208        }
209    }
210}
211
212// ---------------------------------------------------------------------------
213// TcpInterface — outbound connections
214// ---------------------------------------------------------------------------
215
/// A network interface that initiates outbound TCP connections to peers.
///
/// Implements [`Interface`] by managing a pool of TCP connections and
/// dispatching connect/send/disconnect commands.
pub struct TcpInterface<M: Message> {
    /// Connection pool holding the per-peer writers and in-flight operations.
    pool: TcpConnectionPool<M>,
}
223
impl<M: Message> Default for TcpInterface<M> {
    /// Equivalent to [`TcpInterface::new`].
    fn default() -> Self {
        Self::new()
    }
}
229
230impl<M: Message> TcpInterface<M> {
231    /// Creates a new TCP interface for initiating outbound connections.
232    pub fn new() -> Self {
233        Self {
234            pool: TcpConnectionPool::new(crate::protocol::PROTOCOL_CLIENT),
235        }
236    }
237}
238
239impl<M: Message> Interface<M> for TcpInterface<M> {
240    fn dispatch(&mut self, cmd: InterfaceCommand<M>) {
241        match cmd {
242            InterfaceCommand::Connect(pid) => {
243                let future = connect(pid.clone());
244                self.pool.push_future(Box::pin(future));
245            }
246            InterfaceCommand::Send(pid, msg) => {
247                self.pool.dispatch_send(pid, msg);
248            }
249            InterfaceCommand::Disconnect(pid) => {
250                self.pool.dispatch_disconnect(pid);
251            }
252        }
253    }
254}
255
impl<M: Message> FusedStream for TcpInterface<M> {
    /// Always `false`: this interface never reports itself as terminated.
    fn is_terminated(&self) -> bool {
        false
    }
}
261
impl<M: Message> Stream for TcpInterface<M> {
    type Item = InterfaceEvent<M>;

    /// Yields the next event produced by the connection pool.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // All work lives in the pool; this is a straight delegation.
        self.pool.poll_next_event(cx)
    }
}
269
270// ---------------------------------------------------------------------------
271// TcpListenerInterface — inbound connections via a bound TCP listener
272// ---------------------------------------------------------------------------
273
274async fn accept_tcp<M: Message>(listener: Arc<tokio::net::TcpListener>) -> InternalEvent<M> {
275    match Bearer::accept_tcp(&listener).await {
276        Ok((bearer, addr)) => {
277            let pid = PeerId {
278                host: addr.ip().to_string(),
279                port: addr.port(),
280            };
281            tracing::info!(%pid, "accepted inbound connection");
282            InternalEvent::Connected(pid, bearer)
283        }
284        Err(e) => {
285            tracing::error!("accept error: {:?}", e);
286            // Use a sentinel peer id for accept errors
287            let pid = PeerId {
288                host: "accept-error".to_string(),
289                port: 0,
290            };
291            InternalEvent::Error(pid, e)
292        }
293    }
294}
295
/// A network interface that accepts inbound TCP connections from a bound
/// listener.
///
/// Implements [`Interface`] by continuously accepting new connections and
/// managing the resulting peer sessions. Outbound `Connect` commands are
/// ignored since connections are initiated by remote peers.
pub struct TcpListenerInterface<M: Message> {
    /// Connection pool holding the per-peer writers and in-flight operations.
    pool: TcpConnectionPool<M>,
    /// The bound listener, shared with each accept future.
    listener: Arc<tokio::net::TcpListener>,
    /// The pending accept operation; replaced with a fresh one every time it
    /// completes.
    accept_fut: InterfaceFuture<M>,
}
307
308impl<M: Message> TcpListenerInterface<M> {
309    /// Creates a new listener interface that will accept connections on the
310    /// given [`TcpListener`](tokio::net::TcpListener).
311    pub fn new(listener: tokio::net::TcpListener) -> Self {
312        let listener = Arc::new(listener);
313        let accept_fut = Box::pin(accept_tcp(Arc::clone(&listener)));
314
315        Self {
316            pool: TcpConnectionPool::new(crate::protocol::PROTOCOL_SERVER),
317            listener,
318            accept_fut,
319        }
320    }
321}
322
323impl<M: Message> Interface<M> for TcpListenerInterface<M> {
324    fn dispatch(&mut self, cmd: InterfaceCommand<M>) {
325        match cmd {
326            InterfaceCommand::Connect(pid) => {
327                tracing::warn!(%pid, "TcpListenerInterface does not support outbound Connect, ignoring");
328            }
329            InterfaceCommand::Send(pid, msg) => {
330                self.pool.dispatch_send(pid, msg);
331            }
332            InterfaceCommand::Disconnect(pid) => {
333                self.pool.dispatch_disconnect(pid);
334            }
335        }
336    }
337}
338
impl<M: Message> FusedStream for TcpListenerInterface<M> {
    /// Always `false`: this interface never reports itself as terminated.
    fn is_terminated(&self) -> bool {
        false
    }
}
344
345impl<M: Message> Stream for TcpListenerInterface<M> {
346    type Item = InterfaceEvent<M>;
347
348    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
349        // First, poll the accept future for new inbound connections
350        if let Poll::Ready(event) = self.accept_fut.as_mut().poll(cx) {
351            let ie = self.pool.handle_internal_event(event);
352
353            // Re-arm the accept future for the next connection
354            self.accept_fut = Box::pin(accept_tcp(Arc::clone(&self.listener)));
355
356            return Poll::Ready(Some(ie));
357        }
358
359        // Then poll existing connections
360        self.pool.poll_next_event(cx)
361    }
362}