ergot_base/interface_manager/profiles/direct_router/
std_tcp.rs

//! A TCP-based `DirectRouter` profile.
//!
//! This implementation can be used to connect to any number of directly-connected
//! edge devices over TCP.
4
5use bbq2::{prod_cons::stream::StreamConsumer, traits::bbqhdl::BbqHandle};
6use cobs::max_encoding_overhead;
7use log::{debug, error, info, warn};
8use maitake_sync::WaitQueue;
9use std::sync::Arc;
10use tokio::{
11    io::{AsyncReadExt, AsyncWriteExt},
12    net::{
13        TcpStream,
14        tcp::{OwnedReadHalf, OwnedWriteHalf},
15    },
16    select,
17};
18
19use crate::{
20    Header,
21    interface_manager::{
22        InterfaceState, Profile,
23        interface_impls::std_tcp::StdTcpInterface,
24        utils::{
25            cobs_stream::Sink,
26            std::{
27                ReceiverError, StdQueue,
28                acc::{CobsAccumulator, FeedResult},
29                new_std_queue,
30            },
31        },
32    },
33    net_stack::NetStackHandle,
34    wire_frames::de_frame,
35};
36
37use super::DirectRouter;
38
/// Errors returned by [`register_interface`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The profile refused to register the interface, or did not report it as
    /// active with an assigned network ID.
    OutOfNetIds,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::OutOfNetIds => f.write_str("no network ID available for new interface"),
        }
    }
}

impl std::error::Error for Error {}
43
/// State for the task that drains one TCP interface's outgoing queue and
/// writes the bytes to the socket.
struct TxWorker {
    /// Network ID assigned to this interface; used in log messages.
    net_id: u16,
    /// Write half of the TCP socket.
    tx: OwnedWriteHalf,
    /// Consumer side of the outgoing queue whose producer side backs this
    /// interface's `Sink`.
    rx: StreamConsumer<StdQueue>,
    /// Shared with the paired `RxWorker`; closing it signals shutdown to both.
    closer: Arc<WaitQueue>,
}
50
/// State for the task that reads bytes from one TCP interface, COBS-decodes
/// them into frames, and forwards those frames into the net stack.
struct RxWorker<N>
where
    N: NetStackHandle<Profile = DirectRouter<StdTcpInterface>>,
    N: Send + 'static,
{
    /// Identifier returned by the profile's `register_interface`; used to
    /// deregister the interface when this worker exits.
    interface_id: u64,
    /// Network ID assigned to this interface; written into received frames
    /// whose source network ID is 0.
    net_id: u16,
    /// Handle to the net stack that received frames are forwarded into.
    nsh: N,
    /// Read half of the TCP socket.
    skt: OwnedReadHalf,
    /// Shared with the paired `TxWorker`; closing it signals shutdown to both.
    closer: Arc<WaitQueue>,
    /// Maximum ergot packet size; sizes the COBS accumulation buffer.
    mtu: u16,
}
63
64impl TxWorker {
65    async fn run(mut self) {
66        self.run_inner().await;
67        warn!("Closing interface {}", self.net_id);
68        self.closer.close();
69    }
70
71    async fn run_inner(&mut self) {
72        info!("Started tx_worker for net_id {}", self.net_id);
73        loop {
74            let rxf = self.rx.wait_read();
75            let clf = self.closer.wait();
76
77            let frame = select! {
78                r = rxf => r,
79                _c = clf => {
80                    break;
81                }
82            };
83
84            let len = frame.len();
85            debug!("sending pkt len:{} on net_id {}", len, self.net_id);
86            let res = self.tx.write_all(&frame).await;
87            frame.release(len);
88            if let Err(e) = res {
89                error!("Err: {e:?}");
90                break;
91            }
92        }
93    }
94}
95
96impl<N> RxWorker<N>
97where
98    N: NetStackHandle<Profile = DirectRouter<StdTcpInterface>>,
99    N: Send + 'static,
100{
101    async fn run(mut self) {
102        let close = self.closer.clone();
103
104        // Wait for the receiver to encounter an error, or wait for
105        // the transmitter to signal that it observed an error
106        select! {
107            run = self.run_inner() => {
108                // Halt the TX worker
109                self.closer.close();
110                error!("Receive Error: {run:?}");
111            },
112            _clf = close.wait() => {},
113        }
114
115        // Remove this interface from the list
116        self.nsh.stack().manage_profile(|im| {
117            _ = im.deregister_interface(self.interface_id);
118        });
119    }
120
121    pub async fn run_inner(&mut self) -> ReceiverError {
122        let overhead = max_encoding_overhead(self.mtu as usize);
123        let mut cobs_buf = CobsAccumulator::new(self.mtu as usize + overhead);
124        let mut raw_buf = vec![0u8; 4096].into_boxed_slice();
125
126        loop {
127            let rd = self.skt.read(&mut raw_buf);
128            let close = self.closer.wait();
129
130            let ct = select! {
131                r = rd => {
132                    match r {
133                        Ok(0) | Err(_) => {
134                            warn!("recv run {} closed", self.net_id);
135                            return ReceiverError::SocketClosed
136                        },
137                        Ok(ct) => ct,
138                    }
139                }
140                _c = close => {
141                    return ReceiverError::SocketClosed;
142                }
143            };
144
145            let buf = &mut raw_buf[..ct];
146            let mut window = buf;
147
148            'cobs: while !window.is_empty() {
149                window = match cobs_buf.feed_raw(window) {
150                    FeedResult::Consumed => break 'cobs,
151                    FeedResult::OverFull(new_wind) => new_wind,
152                    FeedResult::DecodeError(new_wind) => new_wind,
153                    FeedResult::Success { data, remaining }
154                    | FeedResult::SuccessInput { data, remaining } => {
155                        // Successfully de-cobs'd a packet, now we need to
156                        // do something with it.
157                        if let Some(mut frame) = de_frame(data) {
158                            // If the message comes in and has a src net_id of zero,
159                            // we should rewrite it so it isn't later understood as a
160                            // local packet.
161                            if frame.hdr.src.network_id == 0 {
162                                assert_ne!(
163                                    frame.hdr.src.node_id, 0,
164                                    "we got a local packet remotely?"
165                                );
166                                assert_ne!(
167                                    frame.hdr.src.node_id, 1,
168                                    "someone is pretending to be us?"
169                                );
170
171                                frame.hdr.src.network_id = self.net_id;
172                            }
173                            // TODO: if the destination IS self.net_id, we could rewrite the
174                            // dest net_id as zero to avoid a pass through the interface manager.
175                            //
176                            // If the dest is 0, should we rewrite the dest as self.net_id? This
177                            // is the opposite as above, but I dunno how that will work with responses
178                            let hdr = frame.hdr.clone();
179                            let hdr: Header = hdr.into();
180
181                            let res = match frame.body {
182                                Ok(body) => self.nsh.stack().send_raw(&hdr, frame.hdr_raw, body),
183                                Err(e) => self.nsh.stack().send_err(&hdr, e),
184                            };
185                            match res {
186                                Ok(()) => {}
187                                Err(e) => {
188                                    // TODO: match on error, potentially try to send NAK?
189                                    warn!("recv->send error: {e:?}");
190                                }
191                            }
192                        } else {
193                            warn!("Decode error! Ignoring frame on net_id {}", self.net_id);
194                        }
195
196                        remaining
197                    }
198                };
199            }
200        }
201    }
202}
203
204pub async fn register_interface<N>(
205    stack: N,
206    socket: TcpStream,
207    max_ergot_packet_size: u16,
208    outgoing_buffer_size: usize,
209) -> Result<u64, Error>
210where
211    N: NetStackHandle<Profile = DirectRouter<StdTcpInterface>>,
212    N: Send + 'static,
213{
214    let (rx, tx) = socket.into_split();
215    let q: StdQueue = new_std_queue(outgoing_buffer_size);
216    let res = stack.stack().manage_profile(|im| {
217        let ident =
218            im.register_interface(Sink::new_from_handle(q.clone(), max_ergot_packet_size))?;
219        let state = im.interface_state(ident)?;
220        match state {
221            InterfaceState::Active { net_id, node_id: _ } => Some((ident, net_id)),
222            _ => {
223                _ = im.deregister_interface(ident);
224                None
225            }
226        }
227    });
228    let Some((ident, net_id)) = res else {
229        return Err(Error::OutOfNetIds);
230    };
231    let closer = Arc::new(WaitQueue::new());
232    let rx_worker = RxWorker {
233        nsh: stack.clone(),
234        skt: rx,
235        closer: closer.clone(),
236        mtu: max_ergot_packet_size,
237        interface_id: ident,
238        net_id,
239    };
240    let tx_worker = TxWorker {
241        net_id,
242        tx,
243        rx: <StdQueue as BbqHandle>::stream_consumer(&q),
244        closer,
245    };
246
247    tokio::task::spawn(rx_worker.run());
248    tokio::task::spawn(tx_worker.run());
249
250    Ok(ident)
251}