ergot_base/interface_manager/
std_tcp_client.rs

// An interface manager that holds either zero or one interface.
// It must be const-initializable (starting out empty); at runtime we
// attach the client (and possibly re-attach it later?).
//
// In normal setups we'd probably want some way to "announce" that we
// are here, but in a point-to-point link we simply adopt whatever
// network ID the peer addresses us with.
7
8use std::sync::Arc;
9
10use crate::{
11    Header, Key, NetStack,
12    interface_manager::{
13        ConstInit, InterfaceManager, InterfaceSendError, cobs_stream,
14        std_utils::{
15            ReceiverError, StdQueue,
16            acc::{CobsAccumulator, FeedResult},
17        },
18        wire_frames::{CommonHeader, de_frame},
19    },
20};
21use bbq2::{prod_cons::stream::StreamConsumer, traits::storage::BoxedSlice};
22use log::{debug, error, info, warn};
23use maitake_sync::WaitQueue;
24use mutex::ScopedRawMutex;
25use tokio::{
26    io::{AsyncReadExt, AsyncWriteExt},
27    net::{
28        TcpStream,
29        tcp::{OwnedReadHalf, OwnedWriteHalf},
30    },
31    select,
32};
33
34#[derive(Default)]
35pub struct StdTcpClientIm {
36    inner: Option<StdTcpClientImInner>,
37    seq_no: u16,
38}
39
40struct StdTcpClientImInner {
41    interface: StdTcpTxHdl,
42    net_id: u16,
43    closer: Arc<WaitQueue>,
44}
45
/// Errors returned by [`register_interface`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientError {
    /// A connection is already attached to this interface manager; only one
    /// connection is supported at a time.
    SocketAlreadyActive,
}

impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClientError::SocketAlreadyActive => {
                f.write_str("a socket is already attached to this interface manager")
            }
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for ClientError {}
50
51pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
52    stack: &'static NetStack<R, StdTcpClientIm>,
53    skt: OwnedReadHalf,
54    closer: Arc<WaitQueue>,
55}
56
57struct StdTcpTxHdl {
58    skt_tx: cobs_stream::Interface<StdQueue>,
59}
60
61// ---- impls ----
62
63impl StdTcpClientIm {
64    pub const fn new() -> Self {
65        Self {
66            inner: None,
67            seq_no: 0,
68        }
69    }
70}
71
72impl ConstInit for StdTcpClientIm {
73    #[allow(clippy::declare_interior_mutable_const)]
74    const INIT: Self = Self::new();
75}
76
77impl StdTcpClientIm {
78    fn common_send<'a, 'b>(
79        &'b mut self,
80        ihdr: &'a Header,
81    ) -> Result<(&'b mut StdTcpClientImInner, CommonHeader, Option<&'a Key>), InterfaceSendError>
82    {
83        let intfc = match self.inner.take() {
84            None => return Err(InterfaceSendError::NoRouteToDest),
85            Some(intfc) if intfc.closer.is_closed() => {
86                drop(intfc);
87                return Err(InterfaceSendError::NoRouteToDest);
88            }
89            Some(intfc) => self.inner.insert(intfc),
90        };
91
92        if intfc.net_id == 0 {
93            // No net_id yet, don't allow routing (todo: maybe broadcast?)
94            return Err(InterfaceSendError::NoRouteToDest);
95        }
96        // todo: we could probably keep a routing table of some kind, but for
97        // now, we treat this as a "default" route, all packets go
98
99        // TODO: a LOT of this is copy/pasted from the router, can we make this
100        // shared logic, or handled by the stack somehow?
101        //
102        // TODO: Assumption: "we" are always node_id==2
103        if ihdr.dst.network_id == intfc.net_id && ihdr.dst.node_id == 2 {
104            return Err(InterfaceSendError::DestinationLocal);
105        }
106
107        // Now that we've filtered out "dest local" checks, see if there is
108        // any TTL left before we send to the next hop
109        let mut hdr = ihdr.clone();
110        hdr.decrement_ttl()?;
111
112        // If the source is local, rewrite the source using this interface's
113        // information so responses can find their way back here
114        if hdr.src.net_node_any() {
115            // todo: if we know the destination is EXACTLY this network,
116            // we could leave the network_id local to allow for shorter
117            // addresses
118            hdr.src.network_id = intfc.net_id;
119            hdr.src.node_id = 2;
120        }
121
122        // If this is a broadcast message, update the destination, ignoring
123        // whatever was there before
124        if hdr.dst.port_id == 255 {
125            hdr.dst.network_id = intfc.net_id;
126            hdr.dst.node_id = 1;
127        }
128
129        let seq_no = self.seq_no;
130        self.seq_no = self.seq_no.wrapping_add(1);
131
132        let header = CommonHeader {
133            src: hdr.src.as_u32(),
134            dst: hdr.dst.as_u32(),
135            seq_no,
136            kind: hdr.kind.0,
137            ttl: hdr.ttl,
138        };
139        let key = if [0, 255].contains(&hdr.dst.port_id) {
140            Some(ihdr.key.as_ref().unwrap())
141        } else {
142            None
143        };
144
145        Ok((intfc, header, key))
146    }
147}
148
149impl InterfaceManager for StdTcpClientIm {
150    fn send<T: serde::Serialize>(
151        &mut self,
152        hdr: &Header,
153        data: &T,
154    ) -> Result<(), InterfaceSendError> {
155        let (intfc, header, key) = self.common_send(hdr)?;
156        let res = intfc.interface.skt_tx.send_ty(&header, key, data);
157
158        match res {
159            Ok(()) => Ok(()),
160            Err(()) => Err(InterfaceSendError::InterfaceFull),
161        }
162    }
163
164    fn send_raw(&mut self, hdr: &Header, data: &[u8]) -> Result<(), InterfaceSendError> {
165        let (intfc, header, key) = self.common_send(hdr)?;
166        let res = intfc.interface.skt_tx.send_raw(&header, key, data);
167
168        match res {
169            Ok(()) => Ok(()),
170            Err(()) => Err(InterfaceSendError::InterfaceFull),
171        }
172    }
173
174    fn send_err(
175        &mut self,
176        hdr: &Header,
177        err: crate::ProtocolError,
178    ) -> Result<(), InterfaceSendError> {
179        let (intfc, header, _key) = self.common_send(hdr)?;
180        let res = intfc.interface.skt_tx.send_err(&header, err);
181
182        match res {
183            Ok(()) => Ok(()),
184            Err(()) => Err(InterfaceSendError::InterfaceFull),
185        }
186    }
187}
188
189impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
190    pub async fn run(mut self) -> Result<(), ReceiverError> {
191        let res = self.run_inner().await;
192        // todo: this could live somewhere else?
193        self.stack.with_interface_manager(|im| {
194            _ = im.inner.take();
195        });
196        res
197    }
198
199    pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
200        let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
201        let mut raw_buf = [0u8; 4096];
202        let mut net_id = None;
203
204        loop {
205            let rd = self.skt.read(&mut raw_buf);
206            let close = self.closer.wait();
207
208            let ct = select! {
209                r = rd => {
210                    match r {
211                        Ok(0) | Err(_) => {
212                            warn!("recv run closed");
213                            return Err(ReceiverError::SocketClosed)
214                        },
215                        Ok(ct) => ct,
216                    }
217                }
218                _c = close => {
219                    return Err(ReceiverError::SocketClosed);
220                }
221            };
222
223            let buf = &raw_buf[..ct];
224            let mut window = buf;
225
226            'cobs: while !window.is_empty() {
227                window = match cobs_buf.feed_raw(window) {
228                    FeedResult::Consumed => break 'cobs,
229                    FeedResult::OverFull(new_wind) => new_wind,
230                    FeedResult::DeserError(new_wind) => new_wind,
231                    FeedResult::Success { data, remaining } => {
232                        // Successfully de-cobs'd a packet, now we need to
233                        // do something with it.
234                        if let Some(mut frame) = de_frame(data) {
235                            debug!("Got Frame!");
236                            let take_net = net_id.is_none()
237                                || net_id.is_some_and(|n| {
238                                    frame.hdr.dst.network_id != 0 && n != frame.hdr.dst.network_id
239                                });
240                            if take_net {
241                                self.stack.with_interface_manager(|im| {
242                                    if let Some(i) = im.inner.as_mut() {
243                                        // i am, whoever you say i am
244                                        i.net_id = frame.hdr.dst.network_id;
245                                    }
246                                    // else: uhhhhhh
247                                });
248                                net_id = Some(frame.hdr.dst.network_id);
249                            }
250
251                            // If the message comes in and has a src net_id of zero,
252                            // we should rewrite it so it isn't later understood as a
253                            // local packet.
254                            //
255                            // TODO: accept any packet if we don't have a net_id yet?
256                            if let Some(net) = net_id.as_ref() {
257                                if frame.hdr.src.network_id == 0 {
258                                    assert_ne!(
259                                        frame.hdr.src.node_id, 0,
260                                        "we got a local packet remotely?"
261                                    );
262                                    assert_ne!(
263                                        frame.hdr.src.node_id, 2,
264                                        "someone is pretending to be us?"
265                                    );
266
267                                    frame.hdr.src.network_id = *net;
268                                }
269                            }
270
271                            // TODO: if the destination IS self.net_id, we could rewrite the
272                            // dest net_id as zero to avoid a pass through the interface manager.
273                            //
274                            // If the dest is 0, should we rewrite the dest as self.net_id? This
275                            // is the opposite as above, but I dunno how that will work with responses
276                            let hdr = frame.hdr.clone();
277                            let hdr: Header = hdr.into();
278                            let res = match frame.body {
279                                Ok(body) => self.stack.send_raw(&hdr, body),
280                                Err(e) => self.stack.send_err(&hdr, e),
281                            };
282                            match res {
283                                Ok(()) => {}
284                                Err(e) => {
285                                    // TODO: match on error, potentially try to send NAK?
286                                    panic!("recv->send error: {e:?}");
287                                }
288                            }
289                        } else {
290                            warn!(
291                                "Decode error! Ignoring frame on net_id {}",
292                                net_id.unwrap_or(0)
293                            );
294                        }
295
296                        remaining
297                    }
298                };
299            }
300        }
301    }
302}
303
304// Helper functions
305
306pub fn register_interface<R: ScopedRawMutex>(
307    stack: &'static NetStack<R, StdTcpClientIm>,
308    socket: TcpStream,
309) -> Result<StdTcpRecvHdl<R>, ClientError> {
310    let (rx, tx) = socket.into_split();
311    let closer = Arc::new(WaitQueue::new());
312    stack.with_interface_manager(|im| {
313        if im.inner.is_some() {
314            return Err(ClientError::SocketAlreadyActive);
315        }
316
317        let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
318        let ctx = q.stream_producer();
319        let crx = q.stream_consumer();
320
321        im.inner = Some(StdTcpClientImInner {
322            interface: StdTcpTxHdl {
323                skt_tx: cobs_stream::Interface {
324                    mtu: 1024,
325                    prod: ctx,
326                },
327            },
328            net_id: 0,
329            closer: closer.clone(),
330        });
331        // TODO: spawning in a non-async context!
332        tokio::task::spawn(tx_worker(tx, crx, closer.clone()));
333        Ok(())
334    })?;
335    Ok(StdTcpRecvHdl {
336        stack,
337        skt: rx,
338        closer,
339    })
340}
341
342async fn tx_worker(mut tx: OwnedWriteHalf, rx: StreamConsumer<StdQueue>, closer: Arc<WaitQueue>) {
343    info!("Started tx_worker");
344    loop {
345        let rxf = rx.wait_read();
346        let clf = closer.wait();
347
348        let frame = select! {
349            r = rxf => r,
350            _c = clf => {
351                break;
352            }
353        };
354
355        let len = frame.len();
356        info!("sending pkt len:{}", len);
357        let res = tx.write_all(&frame).await;
358        frame.release(len);
359        if let Err(e) = res {
360            error!("Err: {e:?}");
361            break;
362        }
363    }
364    // TODO: GC waker?
365    warn!("Closing interface");
366}