ergot_base/interface_manager/std_tcp_client.rs

// I need an interface manager that can have 0 or 1 interfaces.
// It needs to be able to be const init'd (empty).
// At runtime we can attach the client (and maybe re-attach?).
//
// In normal setups, we'd probably want some way to "announce" that we
// are here, but in point-to-point setups (TODO: this thought was left
// unfinished).
7
8use std::sync::Arc;
9
10use crate::{
11    Header, NetStack,
12    interface_manager::{
13        ConstInit, InterfaceManager, InterfaceSendError, cobs_stream,
14        std_utils::{
15            ReceiverError, StdQueue,
16            acc::{CobsAccumulator, FeedResult},
17        },
18    },
19    wire_frames::{CommonHeader, de_frame},
20};
21use bbq2::{prod_cons::stream::StreamConsumer, traits::storage::BoxedSlice};
22use log::{debug, error, info, warn};
23use maitake_sync::WaitQueue;
24use mutex::ScopedRawMutex;
25use tokio::{
26    io::{AsyncReadExt, AsyncWriteExt},
27    net::{
28        TcpStream,
29        tcp::{OwnedReadHalf, OwnedWriteHalf},
30    },
31    select,
32};
33
34#[derive(Default)]
35pub struct StdTcpClientIm {
36    inner: Option<StdTcpClientImInner>,
37    seq_no: u16,
38}
39
40struct StdTcpClientImInner {
41    interface: StdTcpTxHdl,
42    net_id: u16,
43    closer: Arc<WaitQueue>,
44}
45
/// Errors returned by `register_interface`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientError {
    /// A connection is already attached to this interface manager; only one
    /// may be active at a time.
    SocketAlreadyActive,
}

impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClientError::SocketAlreadyActive => {
                f.write_str("a socket is already attached to this interface")
            }
        }
    }
}

// Lets callers use `?` into `Box<dyn Error>` / error-reporting crates.
impl std::error::Error for ClientError {}
50
51pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
52    stack: &'static NetStack<R, StdTcpClientIm>,
53    skt: OwnedReadHalf,
54    closer: Arc<WaitQueue>,
55}
56
57struct StdTcpTxHdl {
58    skt_tx: cobs_stream::Interface<StdQueue>,
59}
60
// ---- impls ----
62
63impl StdTcpClientIm {
64    pub const fn new() -> Self {
65        Self {
66            inner: None,
67            seq_no: 0,
68        }
69    }
70}
71
72impl ConstInit for StdTcpClientIm {
73    #[allow(clippy::declare_interior_mutable_const)]
74    const INIT: Self = Self::new();
75}
76
77impl StdTcpClientIm {
78    fn common_send<'a, 'b>(
79        &'b mut self,
80        ihdr: &'a Header,
81    ) -> Result<(&'b mut StdTcpClientImInner, CommonHeader), InterfaceSendError> {
82        let intfc = match self.inner.take() {
83            None => return Err(InterfaceSendError::NoRouteToDest),
84            Some(intfc) if intfc.closer.is_closed() => {
85                drop(intfc);
86                return Err(InterfaceSendError::NoRouteToDest);
87            }
88            Some(intfc) => self.inner.insert(intfc),
89        };
90
91        if intfc.net_id == 0 {
92            // No net_id yet, don't allow routing (todo: maybe broadcast?)
93            return Err(InterfaceSendError::NoRouteToDest);
94        }
95        // todo: we could probably keep a routing table of some kind, but for
96        // now, we treat this as a "default" route, all packets go
97
98        // TODO: a LOT of this is copy/pasted from the router, can we make this
99        // shared logic, or handled by the stack somehow?
100        //
101        // TODO: Assumption: "we" are always node_id==2
102        if ihdr.dst.network_id == intfc.net_id && ihdr.dst.node_id == 2 {
103            return Err(InterfaceSendError::DestinationLocal);
104        }
105
106        // Now that we've filtered out "dest local" checks, see if there is
107        // any TTL left before we send to the next hop
108        let mut hdr = ihdr.clone();
109        hdr.decrement_ttl()?;
110
111        // If the source is local, rewrite the source using this interface's
112        // information so responses can find their way back here
113        if hdr.src.net_node_any() {
114            // todo: if we know the destination is EXACTLY this network,
115            // we could leave the network_id local to allow for shorter
116            // addresses
117            hdr.src.network_id = intfc.net_id;
118            hdr.src.node_id = 2;
119        }
120
121        // If this is a broadcast message, update the destination, ignoring
122        // whatever was there before
123        if hdr.dst.port_id == 255 {
124            hdr.dst.network_id = intfc.net_id;
125            hdr.dst.node_id = 1;
126        }
127
128        let seq_no = self.seq_no;
129        self.seq_no = self.seq_no.wrapping_add(1);
130
131        let header = CommonHeader {
132            src: hdr.src,
133            dst: hdr.dst,
134            seq_no,
135            kind: hdr.kind,
136            ttl: hdr.ttl,
137        };
138        if [0, 255].contains(&hdr.dst.port_id) {
139            if ihdr.any_all.is_none() {
140                return Err(InterfaceSendError::AnyPortMissingKey);
141            }
142        }
143
144        Ok((intfc, header))
145    }
146}
147
148impl InterfaceManager for StdTcpClientIm {
149    fn send<T: serde::Serialize>(
150        &mut self,
151        hdr: &Header,
152        data: &T,
153    ) -> Result<(), InterfaceSendError> {
154        let (intfc, header) = self.common_send(hdr)?;
155        let res = intfc
156            .interface
157            .skt_tx
158            .send_ty(&header, hdr.any_all.as_ref(), data);
159
160        match res {
161            Ok(()) => Ok(()),
162            Err(()) => Err(InterfaceSendError::InterfaceFull),
163        }
164    }
165
166    fn send_raw(
167        &mut self,
168        hdr: &Header,
169        raw_hdr: &[u8],
170        data: &[u8],
171    ) -> Result<(), InterfaceSendError> {
172        let (intfc, header) = self.common_send(hdr)?;
173        let res = intfc.interface.skt_tx.send_raw(&header, raw_hdr, data);
174
175        match res {
176            Ok(()) => Ok(()),
177            Err(()) => Err(InterfaceSendError::InterfaceFull),
178        }
179    }
180
181    fn send_err(
182        &mut self,
183        hdr: &Header,
184        err: crate::ProtocolError,
185    ) -> Result<(), InterfaceSendError> {
186        let (intfc, header) = self.common_send(hdr)?;
187        let res = intfc.interface.skt_tx.send_err(&header, err);
188
189        match res {
190            Ok(()) => Ok(()),
191            Err(()) => Err(InterfaceSendError::InterfaceFull),
192        }
193    }
194}
195
196impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
197    pub async fn run(mut self) -> Result<(), ReceiverError> {
198        let res = self.run_inner().await;
199        // todo: this could live somewhere else?
200        self.stack.with_interface_manager(|im| {
201            _ = im.inner.take();
202        });
203        res
204    }
205
206    pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
207        let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
208        let mut raw_buf = [0u8; 4096];
209        let mut net_id = None;
210
211        loop {
212            let rd = self.skt.read(&mut raw_buf);
213            let close = self.closer.wait();
214
215            let ct = select! {
216                r = rd => {
217                    match r {
218                        Ok(0) | Err(_) => {
219                            warn!("recv run closed");
220                            return Err(ReceiverError::SocketClosed)
221                        },
222                        Ok(ct) => ct,
223                    }
224                }
225                _c = close => {
226                    return Err(ReceiverError::SocketClosed);
227                }
228            };
229
230            let buf = &raw_buf[..ct];
231            let mut window = buf;
232
233            'cobs: while !window.is_empty() {
234                window = match cobs_buf.feed_raw(window) {
235                    FeedResult::Consumed => break 'cobs,
236                    FeedResult::OverFull(new_wind) => new_wind,
237                    FeedResult::DeserError(new_wind) => new_wind,
238                    FeedResult::Success { data, remaining } => {
239                        // Successfully de-cobs'd a packet, now we need to
240                        // do something with it.
241                        if let Some(mut frame) = de_frame(data) {
242                            debug!("Got Frame!");
243                            let take_net = net_id.is_none()
244                                || net_id.is_some_and(|n| {
245                                    frame.hdr.dst.network_id != 0 && n != frame.hdr.dst.network_id
246                                });
247                            if take_net {
248                                self.stack.with_interface_manager(|im| {
249                                    if let Some(i) = im.inner.as_mut() {
250                                        // i am, whoever you say i am
251                                        i.net_id = frame.hdr.dst.network_id;
252                                    }
253                                    // else: uhhhhhh
254                                });
255                                net_id = Some(frame.hdr.dst.network_id);
256                            }
257
258                            // If the message comes in and has a src net_id of zero,
259                            // we should rewrite it so it isn't later understood as a
260                            // local packet.
261                            //
262                            // TODO: accept any packet if we don't have a net_id yet?
263                            if let Some(net) = net_id.as_ref() {
264                                if frame.hdr.src.network_id == 0 {
265                                    assert_ne!(
266                                        frame.hdr.src.node_id, 0,
267                                        "we got a local packet remotely?"
268                                    );
269                                    assert_ne!(
270                                        frame.hdr.src.node_id, 2,
271                                        "someone is pretending to be us?"
272                                    );
273
274                                    frame.hdr.src.network_id = *net;
275                                }
276                            }
277
278                            // TODO: if the destination IS self.net_id, we could rewrite the
279                            // dest net_id as zero to avoid a pass through the interface manager.
280                            //
281                            // If the dest is 0, should we rewrite the dest as self.net_id? This
282                            // is the opposite as above, but I dunno how that will work with responses
283                            let hdr = frame.hdr.clone();
284                            let hdr: Header = hdr.into();
285                            let res = match frame.body {
286                                Ok(body) => self.stack.send_raw(&hdr, frame.hdr_raw, body),
287                                Err(e) => self.stack.send_err(&hdr, e),
288                            };
289                            match res {
290                                Ok(()) => {}
291                                Err(e) => {
292                                    // TODO: match on error, potentially try to send NAK?
293                                    panic!("recv->send error: {e:?}");
294                                }
295                            }
296                        } else {
297                            warn!(
298                                "Decode error! Ignoring frame on net_id {}",
299                                net_id.unwrap_or(0)
300                            );
301                        }
302
303                        remaining
304                    }
305                };
306            }
307        }
308    }
309}
310
311// Helper functions
312
313pub fn register_interface<R: ScopedRawMutex>(
314    stack: &'static NetStack<R, StdTcpClientIm>,
315    socket: TcpStream,
316) -> Result<StdTcpRecvHdl<R>, ClientError> {
317    let (rx, tx) = socket.into_split();
318    let closer = Arc::new(WaitQueue::new());
319    stack.with_interface_manager(|im| {
320        if im.inner.is_some() {
321            return Err(ClientError::SocketAlreadyActive);
322        }
323
324        let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
325        let ctx = q.stream_producer();
326        let crx = q.stream_consumer();
327
328        im.inner = Some(StdTcpClientImInner {
329            interface: StdTcpTxHdl {
330                skt_tx: cobs_stream::Interface {
331                    mtu: 1024,
332                    prod: ctx,
333                },
334            },
335            net_id: 0,
336            closer: closer.clone(),
337        });
338        // TODO: spawning in a non-async context!
339        tokio::task::spawn(tx_worker(tx, crx, closer.clone()));
340        Ok(())
341    })?;
342    Ok(StdTcpRecvHdl {
343        stack,
344        skt: rx,
345        closer,
346    })
347}
348
349async fn tx_worker(mut tx: OwnedWriteHalf, rx: StreamConsumer<StdQueue>, closer: Arc<WaitQueue>) {
350    info!("Started tx_worker");
351    loop {
352        let rxf = rx.wait_read();
353        let clf = closer.wait();
354
355        let frame = select! {
356            r = rxf => r,
357            _c = clf => {
358                break;
359            }
360        };
361
362        let len = frame.len();
363        info!("sending pkt len:{}", len);
364        let res = tx.write_all(&frame).await;
365        frame.release(len);
366        if let Err(e) = res {
367            error!("Err: {e:?}");
368            break;
369        }
370    }
371    // TODO: GC waker?
372    warn!("Closing interface");
373}