1use std::{
2 collections::HashMap,
3 future::ready,
4 pin::Pin,
5 sync::Arc,
6 task::{Context, Poll},
7};
8
9use futures::{
10 Stream, StreamExt,
11 stream::{FusedStream, FuturesUnordered},
12};
13
14use tokio::{sync::Mutex, time::Instant};
15
16use crate::{
17 Channel, Interface, InterfaceCommand, InterfaceError, InterfaceEvent, Message, Payload, PeerId,
18 bearer::{Bearer, BearerReadHalf, BearerWriteHalf, Timestamp},
19};
20
21enum InternalEvent<M: Message> {
22 Connected(PeerId, Bearer),
23 Disconnected(PeerId),
24 Sent(PeerId, M),
25 Recv(PeerId, Vec<M>, BearerReadHalf, ChunkBuffer),
26 Error(PeerId, tokio::io::Error),
27}
28
29type InterfaceFuture<M> = Pin<Box<dyn Future<Output = InternalEvent<M>> + Send>>;
30
31async fn connect<M: Message>(pid: PeerId) -> InternalEvent<M> {
32 let pid = pid.clone();
33
34 tracing::debug!(%pid, "connecting bearer");
35 let bearer = Bearer::connect_tcp((pid.host.clone(), pid.port)).await;
36
37 match bearer {
38 Ok(bearer) => InternalEvent::Connected(pid.clone(), bearer),
39 Err(e) => InternalEvent::Error(pid.clone(), e),
40 }
41}
42
43async fn send<M: Message>(
44 pid: PeerId,
45 writer: SharedWriter,
46 msg: M,
47 ts: Timestamp,
48 mode: u16,
49) -> InternalEvent<M> {
50 let pid = pid.clone();
51 let copy = msg.clone();
52
53 let mut writer = writer.lock().await;
54
55 let result = writer.write_message(msg, ts, mode).await;
56
57 match result {
58 Ok(_) => InternalEvent::Sent(pid.clone(), copy),
59 Err(e) => InternalEvent::Error(pid.clone(), e),
60 }
61}
62
63pub type ChunkBuffer = HashMap<Channel, Payload>;
66
67async fn recv<M: Message>(
68 pid: PeerId,
69 mut reader: BearerReadHalf,
70 mut partial_chunks: ChunkBuffer,
71) -> InternalEvent<M> {
72 let pid = pid.clone();
73
74 let result = reader.read_full_msgs(&mut partial_chunks).await;
75
76 match result {
77 Ok(msgs) => InternalEvent::Recv(pid.clone(), msgs, reader, partial_chunks),
78 Err(e) => InternalEvent::Error(pid.clone(), e),
79 }
80}
81
82async fn disconnect<M: Message>(pid: PeerId, writer: SharedWriter) -> InternalEvent<M> {
83 let pid = pid.clone();
84
85 let mut writer = writer.lock().await;
86
87 writer.shutdown().await.unwrap();
88
89 InternalEvent::Disconnected(pid.clone())
90}
91
92pub type SharedWriter = Arc<Mutex<BearerWriteHalf>>;
94
95struct TcpConnectionPool<M: Message> {
100 futures: FuturesUnordered<InterfaceFuture<M>>,
101 writers: HashMap<PeerId, SharedWriter>,
102 clock: Instant,
103 mode: u16,
105}
106
107impl<M: Message> TcpConnectionPool<M> {
108 fn new(mode: u16) -> Self {
109 Self {
110 futures: FuturesUnordered::new(),
111 writers: HashMap::new(),
112 clock: Instant::now(),
113 mode,
114 }
115 }
116
117 fn push_future(&mut self, f: InterfaceFuture<M>) {
118 self.futures.push(f);
119 }
120
121 fn take_writer(&mut self, pid: &PeerId) -> Option<SharedWriter> {
122 self.writers.get(pid).cloned()
123 }
124
125 fn on_connected(&mut self, pid: PeerId, bearer: Bearer) -> InterfaceEvent<M> {
126 let (read, write) = bearer.into_split();
127
128 self.writers
129 .insert(pid.clone(), Arc::new(Mutex::new(write)));
130
131 let future = recv(pid.clone(), read, HashMap::new());
132 self.futures.push(Box::pin(future));
133
134 InterfaceEvent::Connected(pid)
135 }
136
137 fn on_disconnected(&mut self, pid: PeerId) -> InterfaceEvent<M> {
138 self.writers.remove(&pid);
139 InterfaceEvent::Disconnected(pid)
140 }
141
142 fn on_sent(&mut self, pid: PeerId, msg: M) -> InterfaceEvent<M> {
143 InterfaceEvent::Sent(pid, msg)
144 }
145
146 fn on_recv(
147 &mut self,
148 pid: PeerId,
149 msgs: Vec<M>,
150 reader: BearerReadHalf,
151 partial_chunks: ChunkBuffer,
152 ) -> InterfaceEvent<M> {
153 let future = recv(pid.clone(), reader, partial_chunks);
154 self.futures.push(Box::pin(future));
155
156 InterfaceEvent::Recv(pid, msgs)
157 }
158
159 fn on_error(&mut self, pid: PeerId, error: tokio::io::Error) -> InterfaceEvent<M> {
160 tracing::error!("error: {:?}", error);
161 InterfaceEvent::Error(pid, InterfaceError::Other(error.to_string()))
162 }
163
164 fn handle_internal_event(&mut self, event: InternalEvent<M>) -> InterfaceEvent<M> {
165 match event {
166 InternalEvent::Connected(pid, stream) => self.on_connected(pid, stream),
167 InternalEvent::Sent(pid, msg) => self.on_sent(pid, msg),
168 InternalEvent::Recv(pid, msgs, stream, buf) => self.on_recv(pid, msgs, stream, buf),
169 InternalEvent::Disconnected(pid) => self.on_disconnected(pid),
170 InternalEvent::Error(pid, error) => self.on_error(pid, error),
171 }
172 }
173
174 fn dispatch_send(&mut self, pid: PeerId, msg: M) {
175 let ts = self.clock.elapsed().as_micros() as u32;
176
177 let Some(writer) = self.take_writer(&pid) else {
178 tracing::error!(%pid, "trying to send to a peer not connected");
179 return;
180 };
181
182 let future = send(pid, writer, msg, ts, self.mode);
183 self.futures.push(Box::pin(future));
184 }
185
186 fn dispatch_disconnect(&mut self, pid: PeerId) {
187 let Some(stream) = self.take_writer(&pid) else {
188 tracing::warn!(%pid, "trying to disconnect a peer not connected");
189 self.futures
190 .push(Box::pin(ready(InternalEvent::Disconnected(pid.clone()))));
191 return;
192 };
193
194 let future = disconnect(pid, stream);
195 self.futures.push(Box::pin(future));
196 }
197
198 fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Option<InterfaceEvent<M>>> {
199 let event = self.futures.poll_next_unpin(cx);
200
201 match event {
202 Poll::Ready(Some(event)) => {
203 let event = self.handle_internal_event(event);
204 Poll::Ready(Some(event))
205 }
206 Poll::Ready(None) => Poll::Pending,
207 Poll::Pending => Poll::Pending,
208 }
209 }
210}
211
212pub struct TcpInterface<M: Message> {
221 pool: TcpConnectionPool<M>,
222}
223
224impl<M: Message> Default for TcpInterface<M> {
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
230impl<M: Message> TcpInterface<M> {
231 pub fn new() -> Self {
233 Self {
234 pool: TcpConnectionPool::new(crate::protocol::PROTOCOL_CLIENT),
235 }
236 }
237}
238
239impl<M: Message> Interface<M> for TcpInterface<M> {
240 fn dispatch(&mut self, cmd: InterfaceCommand<M>) {
241 match cmd {
242 InterfaceCommand::Connect(pid) => {
243 let future = connect(pid.clone());
244 self.pool.push_future(Box::pin(future));
245 }
246 InterfaceCommand::Send(pid, msg) => {
247 self.pool.dispatch_send(pid, msg);
248 }
249 InterfaceCommand::Disconnect(pid) => {
250 self.pool.dispatch_disconnect(pid);
251 }
252 }
253 }
254}
255
256impl<M: Message> FusedStream for TcpInterface<M> {
257 fn is_terminated(&self) -> bool {
258 false
259 }
260}
261
262impl<M: Message> Stream for TcpInterface<M> {
263 type Item = InterfaceEvent<M>;
264
265 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
266 self.pool.poll_next_event(cx)
267 }
268}
269
270async fn accept_tcp<M: Message>(listener: Arc<tokio::net::TcpListener>) -> InternalEvent<M> {
275 match Bearer::accept_tcp(&listener).await {
276 Ok((bearer, addr)) => {
277 let pid = PeerId {
278 host: addr.ip().to_string(),
279 port: addr.port(),
280 };
281 tracing::info!(%pid, "accepted inbound connection");
282 InternalEvent::Connected(pid, bearer)
283 }
284 Err(e) => {
285 tracing::error!("accept error: {:?}", e);
286 let pid = PeerId {
288 host: "accept-error".to_string(),
289 port: 0,
290 };
291 InternalEvent::Error(pid, e)
292 }
293 }
294}
295
296pub struct TcpListenerInterface<M: Message> {
303 pool: TcpConnectionPool<M>,
304 listener: Arc<tokio::net::TcpListener>,
305 accept_fut: InterfaceFuture<M>,
306}
307
308impl<M: Message> TcpListenerInterface<M> {
309 pub fn new(listener: tokio::net::TcpListener) -> Self {
312 let listener = Arc::new(listener);
313 let accept_fut = Box::pin(accept_tcp(Arc::clone(&listener)));
314
315 Self {
316 pool: TcpConnectionPool::new(crate::protocol::PROTOCOL_SERVER),
317 listener,
318 accept_fut,
319 }
320 }
321}
322
323impl<M: Message> Interface<M> for TcpListenerInterface<M> {
324 fn dispatch(&mut self, cmd: InterfaceCommand<M>) {
325 match cmd {
326 InterfaceCommand::Connect(pid) => {
327 tracing::warn!(%pid, "TcpListenerInterface does not support outbound Connect, ignoring");
328 }
329 InterfaceCommand::Send(pid, msg) => {
330 self.pool.dispatch_send(pid, msg);
331 }
332 InterfaceCommand::Disconnect(pid) => {
333 self.pool.dispatch_disconnect(pid);
334 }
335 }
336 }
337}
338
339impl<M: Message> FusedStream for TcpListenerInterface<M> {
340 fn is_terminated(&self) -> bool {
341 false
342 }
343}
344
345impl<M: Message> Stream for TcpListenerInterface<M> {
346 type Item = InterfaceEvent<M>;
347
348 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
349 if let Poll::Ready(event) = self.accept_fut.as_mut().poll(cx) {
351 let ie = self.pool.handle_internal_event(event);
352
353 self.accept_fut = Box::pin(accept_tcp(Arc::clone(&self.listener)));
355
356 return Poll::Ready(Some(ie));
357 }
358
359 self.pool.poll_next_event(cx)
361 }
362}