// ergot_base/interface_manager/std_tcp_client.rs

// I need an interface manager that can have 0 or 1 interfaces.
// It needs to be able to be const init'd (empty).
// At runtime we can attach the client (and maybe re-attach?).
//
// In normal setups, we'd probably want some way to "announce" we
// are here, but in point-to-point the client instead adopts the
// destination network id carried by frames received from its peer.
7
8use std::sync::Arc;
9
10use crate::{
11    Header, HeaderSeq, NetStack,
12    interface_manager::std_utils::{OwnedFrame, ser_frame},
13};
14
15use super::{
16    ConstInit, InterfaceManager, InterfaceSendError,
17    std_utils::{
18        ReceiverError,
19        acc::{CobsAccumulator, FeedResult},
20        de_frame,
21    },
22};
23use log::{debug, error, info, warn};
24use maitake_sync::WaitQueue;
25use mutex::ScopedRawMutex;
26use tokio::{
27    io::{AsyncReadExt, AsyncWriteExt},
28    net::{
29        TcpStream,
30        tcp::{OwnedReadHalf, OwnedWriteHalf},
31    },
32    select,
33    sync::mpsc::{Receiver, Sender, channel, error::TrySendError},
34};
35
36#[derive(Default)]
37pub struct StdTcpClientIm {
38    inner: Option<StdTcpClientImInner>,
39    seq_no: u16,
40}
41
42struct StdTcpClientImInner {
43    interface: StdTcpTxHdl,
44    net_id: u16,
45    closer: Arc<WaitQueue>,
46}
47
/// Errors returned by `register_interface`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientError {
    /// An interface is already attached to this stack's interface manager;
    /// only one TCP client interface may be active at a time.
    SocketAlreadyActive,
}

impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClientError::SocketAlreadyActive => {
                f.write_str("a TCP client interface is already active")
            }
        }
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>` and
// similar error-erasing types.
impl std::error::Error for ClientError {}
52
53pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
54    stack: &'static NetStack<R, StdTcpClientIm>,
55    skt: OwnedReadHalf,
56    closer: Arc<WaitQueue>,
57}
58
59struct StdTcpTxHdl {
60    skt_tx: Sender<OwnedFrame>,
61}
62
// ---- impls ----
64
65impl StdTcpClientIm {
66    pub const fn new() -> Self {
67        Self {
68            inner: None,
69            seq_no: 0,
70        }
71    }
72}
73
74impl ConstInit for StdTcpClientIm {
75    #[allow(clippy::declare_interior_mutable_const)]
76    const INIT: Self = Self::new();
77}
78
79impl InterfaceManager for StdTcpClientIm {
80    fn send<T: serde::Serialize>(
81        &mut self,
82        hdr: &Header,
83        data: &T,
84    ) -> Result<(), InterfaceSendError> {
85        let Some(intfc) = self.inner.as_mut() else {
86            return Err(InterfaceSendError::NoRouteToDest);
87        };
88        if intfc.net_id == 0 {
89            // No net_id yet, don't allow routing (todo: maybe broadcast?)
90            return Err(InterfaceSendError::NoRouteToDest);
91        }
92        // todo: we could probably keep a routing table of some kind, but for
93        // now, we treat this as a "default" route, all packets go
94
95        // TODO: a LOT of this is copy/pasted from the router, can we make this
96        // shared logic, or handled by the stack somehow?
97        //
98        // TODO: Assumption: "we" are always node_id==2
99        if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
100            return Err(InterfaceSendError::DestinationLocal);
101        }
102
103        // Now that we've filtered out "dest local" checks, see if there is
104        // any TTL left before we send to the next hop
105        let mut hdr = hdr.clone();
106        hdr.decrement_ttl()?;
107
108        // If the source is local, rewrite the source using this interface's
109        // information so responses can find their way back here
110        if hdr.src.net_node_any() {
111            // todo: if we know the destination is EXACTLY this network,
112            // we could leave the network_id local to allow for shorter
113            // addresses
114            hdr.src.network_id = intfc.net_id;
115            hdr.src.node_id = 2;
116        }
117
118        // If this is a broadcast message, update the destination, ignoring
119        // whatever was there before
120        if hdr.dst.port_id == 255 {
121            hdr.dst.network_id = intfc.net_id;
122            hdr.dst.node_id = 1;
123        }
124
125        let seq_no = self.seq_no;
126        self.seq_no = self.seq_no.wrapping_add(1);
127        let res = intfc.interface.skt_tx.try_send(OwnedFrame {
128            hdr: HeaderSeq {
129                src: hdr.src,
130                dst: hdr.dst,
131                seq_no,
132                key: hdr.key,
133                kind: hdr.kind,
134                ttl: hdr.ttl,
135            },
136            body: Ok(postcard::to_stdvec(data).unwrap()),
137        });
138        match res {
139            Ok(()) => Ok(()),
140            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
141            Err(TrySendError::Closed(_)) => {
142                if let Some(i) = self.inner.take() {
143                    i.closer.close();
144                }
145                Err(InterfaceSendError::NoRouteToDest)
146            }
147        }
148    }
149
150    fn send_raw(&mut self, hdr: &Header, data: &[u8]) -> Result<(), InterfaceSendError> {
151        let Some(intfc) = self.inner.as_mut() else {
152            return Err(InterfaceSendError::NoRouteToDest);
153        };
154        if intfc.net_id == 0 {
155            // No net_id yet, don't allow routing (todo: maybe broadcast?)
156            return Err(InterfaceSendError::NoRouteToDest);
157        }
158        // todo: we could probably keep a routing table of some kind, but for
159        // now, we treat this as a "default" route, all packets go
160
161        // TODO: a LOT of this is copy/pasted from the router, can we make this
162        // shared logic, or handled by the stack somehow?
163        //
164        // TODO: Assumption: "we" are always node_id==2
165        if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
166            return Err(InterfaceSendError::DestinationLocal);
167        }
168
169        // Now that we've filtered out "dest local" checks, see if there is
170        // any TTL left before we send to the next hop
171        let mut hdr = hdr.clone();
172        hdr.decrement_ttl()?;
173
174        // If the source is local, rewrite the source using this interface's
175        // information so responses can find their way back here
176        if hdr.src.net_node_any() {
177            // todo: if we know the destination is EXACTLY this network,
178            // we could leave the network_id local to allow for shorter
179            // addresses
180            hdr.src.network_id = intfc.net_id;
181            hdr.src.node_id = 2;
182        }
183
184        // If this is a broadcast message, update the destination, ignoring
185        // whatever was there before
186        if hdr.dst.port_id == 255 {
187            hdr.dst.network_id = intfc.net_id;
188            hdr.dst.node_id = 1;
189        }
190
191        let seq_no = self.seq_no;
192        self.seq_no = self.seq_no.wrapping_add(1);
193        let res = intfc.interface.skt_tx.try_send(OwnedFrame {
194            hdr: HeaderSeq {
195                src: hdr.src,
196                dst: hdr.dst,
197                seq_no,
198                key: hdr.key,
199                kind: hdr.kind,
200                ttl: hdr.ttl,
201            },
202            body: Ok(data.to_vec()),
203        });
204        match res {
205            Ok(()) => Ok(()),
206            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
207            Err(TrySendError::Closed(_)) => {
208                self.inner.take();
209                Err(InterfaceSendError::NoRouteToDest)
210            }
211        }
212    }
213
214    fn send_err(
215        &mut self,
216        hdr: &Header,
217        err: crate::ProtocolError,
218    ) -> Result<(), InterfaceSendError> {
219        let Some(intfc) = self.inner.as_mut() else {
220            return Err(InterfaceSendError::NoRouteToDest);
221        };
222        if intfc.net_id == 0 {
223            // No net_id yet, don't allow routing (todo: maybe broadcast?)
224            return Err(InterfaceSendError::NoRouteToDest);
225        }
226        // todo: we could probably keep a routing table of some kind, but for
227        // now, we treat this as a "default" route, all packets go
228
229        // TODO: a LOT of this is copy/pasted from the router, can we make this
230        // shared logic, or handled by the stack somehow?
231        //
232        // TODO: Assumption: "we" are always node_id==2
233        if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
234            return Err(InterfaceSendError::DestinationLocal);
235        }
236
237        // Now that we've filtered out "dest local" checks, see if there is
238        // any TTL left before we send to the next hop
239        let mut hdr = hdr.clone();
240        hdr.decrement_ttl()?;
241
242        // If the source is local, rewrite the source using this interface's
243        // information so responses can find their way back here
244        if hdr.src.net_node_any() {
245            // todo: if we know the destination is EXACTLY this network,
246            // we could leave the network_id local to allow for shorter
247            // addresses
248            hdr.src.network_id = intfc.net_id;
249            hdr.src.node_id = 2;
250        }
251
252        // If this is a broadcast message, update the destination, ignoring
253        // whatever was there before
254        if hdr.dst.port_id == 255 {
255            hdr.dst.network_id = intfc.net_id;
256            hdr.dst.node_id = 1;
257        }
258
259        let seq_no = self.seq_no;
260        self.seq_no = self.seq_no.wrapping_add(1);
261        let res = intfc.interface.skt_tx.try_send(OwnedFrame {
262            hdr: HeaderSeq {
263                src: hdr.src,
264                dst: hdr.dst,
265                seq_no,
266                key: hdr.key,
267                kind: hdr.kind,
268                ttl: hdr.ttl,
269            },
270            body: Err(err),
271        });
272        match res {
273            Ok(()) => Ok(()),
274            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
275            Err(TrySendError::Closed(_)) => {
276                if let Some(i) = self.inner.take() {
277                    i.closer.close();
278                }
279                Err(InterfaceSendError::NoRouteToDest)
280            }
281        }
282    }
283}
284
285impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
286    pub async fn run(mut self) -> Result<(), ReceiverError> {
287        let res = self.run_inner().await;
288        // todo: this could live somewhere else?
289        self.stack.with_interface_manager(|im| {
290            _ = im.inner.take();
291        });
292        res
293    }
294
295    pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
296        let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
297        let mut raw_buf = [0u8; 4096];
298        let mut net_id = None;
299
300        loop {
301            let rd = self.skt.read(&mut raw_buf);
302            let close = self.closer.wait();
303
304            let ct = select! {
305                r = rd => {
306                    match r {
307                        Ok(0) | Err(_) => {
308                            warn!("recv run closed");
309                            return Err(ReceiverError::SocketClosed)
310                        },
311                        Ok(ct) => ct,
312                    }
313                }
314                _c = close => {
315                    return Err(ReceiverError::SocketClosed);
316                }
317            };
318
319            let buf = &raw_buf[..ct];
320            let mut window = buf;
321
322            'cobs: while !window.is_empty() {
323                window = match cobs_buf.feed_raw(window) {
324                    FeedResult::Consumed => break 'cobs,
325                    FeedResult::OverFull(new_wind) => new_wind,
326                    FeedResult::DeserError(new_wind) => new_wind,
327                    FeedResult::Success { data, remaining } => {
328                        // Successfully de-cobs'd a packet, now we need to
329                        // do something with it.
330                        if let Some(mut frame) = de_frame(data) {
331                            debug!("Got Frame!");
332                            let take_net = net_id.is_none()
333                                || net_id.is_some_and(|n| {
334                                    frame.hdr.dst.network_id != 0 && n != frame.hdr.dst.network_id
335                                });
336                            if take_net {
337                                self.stack.with_interface_manager(|im| {
338                                    if let Some(i) = im.inner.as_mut() {
339                                        // i am, whoever you say i am
340                                        i.net_id = frame.hdr.dst.network_id;
341                                    }
342                                    // else: uhhhhhh
343                                });
344                                net_id = Some(frame.hdr.dst.network_id);
345                            }
346
347                            // If the message comes in and has a src net_id of zero,
348                            // we should rewrite it so it isn't later understood as a
349                            // local packet.
350                            //
351                            // TODO: accept any packet if we don't have a net_id yet?
352                            if let Some(net) = net_id.as_ref() {
353                                if frame.hdr.src.network_id == 0 {
354                                    assert_ne!(
355                                        frame.hdr.src.node_id, 0,
356                                        "we got a local packet remotely?"
357                                    );
358                                    assert_ne!(
359                                        frame.hdr.src.node_id, 2,
360                                        "someone is pretending to be us?"
361                                    );
362
363                                    frame.hdr.src.network_id = *net;
364                                }
365                            }
366
367                            // TODO: if the destination IS self.net_id, we could rewrite the
368                            // dest net_id as zero to avoid a pass through the interface manager.
369                            //
370                            // If the dest is 0, should we rewrite the dest as self.net_id? This
371                            // is the opposite as above, but I dunno how that will work with responses
372                            let hdr = frame.hdr.clone();
373                            let hdr: Header = hdr.into();
374                            let res = match frame.body {
375                                Ok(body) => self.stack.send_raw(&hdr, &body),
376                                Err(e) => self.stack.send_err(&hdr, e),
377                            };
378                            match res {
379                                Ok(()) => {}
380                                Err(e) => {
381                                    // TODO: match on error, potentially try to send NAK?
382                                    panic!("recv->send error: {e:?}");
383                                }
384                            }
385                        } else {
386                            warn!(
387                                "Decode error! Ignoring frame on net_id {}",
388                                net_id.unwrap_or(0)
389                            );
390                        }
391
392                        remaining
393                    }
394                };
395            }
396        }
397    }
398}
399
// Helper functions
401
402pub fn register_interface<R: ScopedRawMutex>(
403    stack: &'static NetStack<R, StdTcpClientIm>,
404    socket: TcpStream,
405) -> Result<StdTcpRecvHdl<R>, ClientError> {
406    let (rx, tx) = socket.into_split();
407    let (ctx, crx) = channel(64);
408    let closer = Arc::new(WaitQueue::new());
409    stack.with_interface_manager(|im| {
410        if im.inner.is_some() {
411            return Err(ClientError::SocketAlreadyActive);
412        }
413
414        im.inner = Some(StdTcpClientImInner {
415            interface: StdTcpTxHdl { skt_tx: ctx },
416            net_id: 0,
417            closer: closer.clone(),
418        });
419        // TODO: spawning in a non-async context!
420        tokio::task::spawn(tx_worker(tx, crx, closer.clone()));
421        Ok(())
422    })?;
423    Ok(StdTcpRecvHdl {
424        stack,
425        skt: rx,
426        closer,
427    })
428}
429
430async fn tx_worker(mut tx: OwnedWriteHalf, mut rx: Receiver<OwnedFrame>, closer: Arc<WaitQueue>) {
431    info!("Started tx_worker");
432    loop {
433        let rxf = rx.recv();
434        let clf = closer.wait();
435
436        let frame = select! {
437            r = rxf => {
438                if let Some(frame) = r {
439                    frame
440                } else {
441                    warn!("tx_workerrx closed!");
442                    closer.close();
443                    break;
444                }
445            }
446            _c = clf => {
447                break;
448            }
449        };
450
451        let msg = ser_frame(frame);
452        info!("sending pkt len:{}", msg.len());
453        let res = tx.write_all(&msg).await;
454        if let Err(e) = res {
455            error!("Err: {e:?}");
456            break;
457        }
458    }
459    // TODO: GC waker?
460    warn!("Closing interface");
461}