// ergot_base/interface_manager/std_tcp_client.rs

// I need an interface manager that can have 0 or 1 interfaces.
// It needs to be const-initializable (empty).
// At runtime we can attach the client (and maybe re-attach?).
//
// In normal setups we'd probably want some way to "announce" that we
// are here, but in a point-to-point setup the client simply adopts
// whatever network id the remote peer addresses it with.
7
8use std::sync::Arc;
9
10use crate::{
11    Header, HeaderSeq, NetStack,
12    interface_manager::std_utils::{OwnedFrame, ser_frame},
13};
14
15use super::{
16    ConstInit, InterfaceManager, InterfaceSendError,
17    std_utils::{
18        ReceiverError,
19        acc::{CobsAccumulator, FeedResult},
20        de_frame,
21    },
22};
23use maitake_sync::WaitQueue;
24use mutex::ScopedRawMutex;
25use tokio::{
26    io::{AsyncReadExt, AsyncWriteExt},
27    net::{
28        TcpStream,
29        tcp::{OwnedReadHalf, OwnedWriteHalf},
30    },
31    select,
32    sync::mpsc::{Receiver, Sender, channel, error::TrySendError},
33};
34
35#[derive(Default)]
36pub struct StdTcpClientIm {
37    inner: Option<StdTcpClientImInner>,
38    seq_no: u16,
39}
40
41struct StdTcpClientImInner {
42    interface: StdTcpTxHdl,
43    net_id: u16,
44    closer: Arc<WaitQueue>,
45}
46
/// Errors returned by [`register_interface`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientError {
    /// A connection is already attached to this interface manager; only one
    /// TCP client connection may be active at a time.
    SocketAlreadyActive,
}

impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClientError::SocketAlreadyActive => {
                f.write_str("a socket is already active on this interface manager")
            }
        }
    }
}

impl std::error::Error for ClientError {}
51
52pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
53    stack: &'static NetStack<R, StdTcpClientIm>,
54    skt: OwnedReadHalf,
55    closer: Arc<WaitQueue>,
56}
57
58struct StdTcpTxHdl {
59    skt_tx: Sender<OwnedFrame>,
60}
61
// ---- impls ----
63
64impl StdTcpClientIm {
65    pub const fn new() -> Self {
66        Self {
67            inner: None,
68            seq_no: 0,
69        }
70    }
71}
72
73impl ConstInit for StdTcpClientIm {
74    #[allow(clippy::declare_interior_mutable_const)]
75    const INIT: Self = Self::new();
76}
77
78impl InterfaceManager for StdTcpClientIm {
79    fn send<T: serde::Serialize>(
80        &mut self,
81        mut hdr: Header,
82        data: &T,
83    ) -> Result<(), InterfaceSendError> {
84        let Some(intfc) = self.inner.as_mut() else {
85            return Err(InterfaceSendError::NoRouteToDest);
86        };
87        if intfc.net_id == 0 {
88            // No net_id yet, don't allow routing (todo: maybe broadcast?)
89            return Err(InterfaceSendError::NoRouteToDest);
90        }
91        // todo: we could probably keep a routing table of some kind, but for
92        // now, we treat this as a "default" route, all packets go
93
94        // TODO: a LOT of this is copy/pasted from the router, can we make this
95        // shared logic, or handled by the stack somehow?
96        //
97        // TODO: Assumption: "we" are always node_id==2
98        if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
99            return Err(InterfaceSendError::DestinationLocal);
100        }
101        // If the source is local, rewrite the source using this interface's
102        // information so responses can find their way back here
103        if hdr.src.net_node_any() {
104            // todo: if we know the destination is EXACTLY this network,
105            // we could leave the network_id local to allow for shorter
106            // addresses
107            hdr.src.network_id = intfc.net_id;
108            hdr.src.node_id = 2;
109        }
110
111        let seq_no = self.seq_no;
112        self.seq_no = self.seq_no.wrapping_add(1);
113        let res = intfc.interface.skt_tx.try_send(OwnedFrame {
114            hdr: HeaderSeq {
115                src: hdr.src,
116                dst: hdr.dst,
117                seq_no,
118                key: hdr.key,
119                kind: hdr.kind,
120            },
121            body: postcard::to_stdvec(data).unwrap(),
122        });
123        match res {
124            Ok(()) => Ok(()),
125            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
126            Err(TrySendError::Closed(_)) => {
127                if let Some(i) = self.inner.take() {
128                    i.closer.close();
129                }
130                Err(InterfaceSendError::NoRouteToDest)
131            }
132        }
133    }
134
135    fn send_raw(&mut self, mut hdr: Header, data: &[u8]) -> Result<(), InterfaceSendError> {
136        let Some(intfc) = self.inner.as_mut() else {
137            return Err(InterfaceSendError::NoRouteToDest);
138        };
139        if intfc.net_id == 0 {
140            // No net_id yet, don't allow routing (todo: maybe broadcast?)
141            return Err(InterfaceSendError::NoRouteToDest);
142        }
143        // todo: we could probably keep a routing table of some kind, but for
144        // now, we treat this as a "default" route, all packets go
145
146        // TODO: a LOT of this is copy/pasted from the router, can we make this
147        // shared logic, or handled by the stack somehow?
148        //
149        // TODO: Assumption: "we" are always node_id==2
150        if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
151            return Err(InterfaceSendError::DestinationLocal);
152        }
153        // If the source is local, rewrite the source using this interface's
154        // information so responses can find their way back here
155        if hdr.src.net_node_any() {
156            // todo: if we know the destination is EXACTLY this network,
157            // we could leave the network_id local to allow for shorter
158            // addresses
159            hdr.src.network_id = intfc.net_id;
160            hdr.src.node_id = 2;
161        }
162
163        let seq_no = self.seq_no;
164        self.seq_no = self.seq_no.wrapping_add(1);
165        let res = intfc.interface.skt_tx.try_send(OwnedFrame {
166            hdr: HeaderSeq {
167                src: hdr.src,
168                dst: hdr.dst,
169                seq_no,
170                key: hdr.key,
171                kind: hdr.kind,
172            },
173            body: data.to_vec(),
174        });
175        match res {
176            Ok(()) => Ok(()),
177            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
178            Err(TrySendError::Closed(_)) => {
179                self.inner.take();
180                Err(InterfaceSendError::NoRouteToDest)
181            }
182        }
183    }
184}
185
186impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
187    pub async fn run(mut self) -> Result<(), ReceiverError> {
188        let res = self.run_inner().await;
189        // todo: this could live somewhere else?
190        self.stack.with_interface_manager(|im| {
191            _ = im.inner.take();
192        });
193        res
194    }
195
196    pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
197        let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
198        let mut raw_buf = [0u8; 4096];
199        let mut net_id = None;
200
201        loop {
202            let rd = self.skt.read(&mut raw_buf);
203            let close = self.closer.wait();
204
205            let ct = select! {
206                r = rd => {
207                    match r {
208                        Ok(0) | Err(_) => {
209                            println!("recv run closed");
210                            return Err(ReceiverError::SocketClosed)
211                        },
212                        Ok(ct) => ct,
213                    }
214                }
215                _c = close => {
216                    return Err(ReceiverError::SocketClosed);
217                }
218            };
219
220            let buf = &raw_buf[..ct];
221            let mut window = buf;
222
223            'cobs: while !window.is_empty() {
224                window = match cobs_buf.feed_raw(window) {
225                    FeedResult::Consumed => break 'cobs,
226                    FeedResult::OverFull(new_wind) => new_wind,
227                    FeedResult::DeserError(new_wind) => new_wind,
228                    FeedResult::Success { data, remaining } => {
229                        // Successfully de-cobs'd a packet, now we need to
230                        // do something with it.
231                        if let Some(mut frame) = de_frame(data) {
232                            println!("Got Frame!");
233                            let take_net = net_id.is_none()
234                                || net_id.is_some_and(|n| {
235                                    frame.hdr.dst.network_id != 0 && n != frame.hdr.dst.network_id
236                                });
237                            if take_net {
238                                self.stack.with_interface_manager(|im| {
239                                    if let Some(i) = im.inner.as_mut() {
240                                        // i am, whoever you say i am
241                                        i.net_id = frame.hdr.dst.network_id;
242                                    }
243                                    // else: uhhhhhh
244                                });
245                                net_id = Some(frame.hdr.dst.network_id);
246                            }
247
248                            // If the message comes in and has a src net_id of zero,
249                            // we should rewrite it so it isn't later understood as a
250                            // local packet.
251                            //
252                            // TODO: accept any packet if we don't have a net_id yet?
253                            if let Some(net) = net_id.as_ref() {
254                                if frame.hdr.src.network_id == 0 {
255                                    assert_ne!(
256                                        frame.hdr.src.node_id, 0,
257                                        "we got a local packet remotely?"
258                                    );
259                                    assert_ne!(
260                                        frame.hdr.src.node_id, 2,
261                                        "someone is pretending to be us?"
262                                    );
263
264                                    frame.hdr.src.network_id = *net;
265                                }
266                            }
267
268                            // TODO: if the destination IS self.net_id, we could rewrite the
269                            // dest net_id as zero to avoid a pass through the interface manager.
270                            //
271                            // If the dest is 0, should we rewrite the dest as self.net_id? This
272                            // is the opposite as above, but I dunno how that will work with responses
273                            let res = self.stack.send_raw(frame.hdr.into(), &frame.body);
274                            match res {
275                                Ok(()) => {}
276                                Err(e) => {
277                                    // TODO: match on error, potentially try to send NAK?
278                                    panic!("recv->send error: {e:?}");
279                                }
280                            }
281                        } else {
282                            println!(
283                                "Decode error! Ignoring frame on net_id {}",
284                                net_id.unwrap_or(0)
285                            );
286                        }
287
288                        remaining
289                    }
290                };
291            }
292        }
293    }
294}
295
// Helper functions
297
298pub fn register_interface<R: ScopedRawMutex>(
299    stack: &'static NetStack<R, StdTcpClientIm>,
300    socket: TcpStream,
301) -> Result<StdTcpRecvHdl<R>, ClientError> {
302    let (rx, tx) = socket.into_split();
303    let (ctx, crx) = channel(64);
304    let closer = Arc::new(WaitQueue::new());
305    stack.with_interface_manager(|im| {
306        if im.inner.is_some() {
307            return Err(ClientError::SocketAlreadyActive);
308        }
309
310        im.inner = Some(StdTcpClientImInner {
311            interface: StdTcpTxHdl { skt_tx: ctx },
312            net_id: 0,
313            closer: closer.clone(),
314        });
315        // TODO: spawning in a non-async context!
316        tokio::task::spawn(tx_worker(tx, crx, closer.clone()));
317        Ok(())
318    })?;
319    Ok(StdTcpRecvHdl {
320        stack,
321        skt: rx,
322        closer,
323    })
324}
325
326async fn tx_worker(mut tx: OwnedWriteHalf, mut rx: Receiver<OwnedFrame>, closer: Arc<WaitQueue>) {
327    println!("Started tx_worker");
328    loop {
329        let rxf = rx.recv();
330        let clf = closer.wait();
331
332        let frame = select! {
333            r = rxf => {
334                if let Some(frame) = r {
335                    frame
336                } else {
337                    println!("tx_workerrx closed!");
338                    closer.close();
339                    break;
340                }
341            }
342            _c = clf => {
343                break;
344            }
345        };
346
347        let msg = ser_frame(frame);
348        println!("sending pkt len:{}", msg.len());
349        let res = tx.write_all(&msg).await;
350        if let Err(e) = res {
351            println!("Err: {e:?}");
352            break;
353        }
354    }
355    // TODO: GC waker?
356    println!("Closing interface");
357}