ergot_base/interface_manager/profiles/direct_edge/
std_tcp.rs

1//! A std+tcp edge device profile
2//!
3//! This is useful for std based devices/applications that can directly connect to a DirectRouter
4//! using a tcp connection.
5
6use std::sync::Arc;
7
8use crate::{
9    Header,
10    interface_manager::{
11        InterfaceState, Profile,
12        interface_impls::std_tcp::StdTcpInterface,
13        profiles::direct_edge::{DirectEdge, EDGE_NODE_ID},
14        utils::std::{
15            ReceiverError, StdQueue,
16            acc::{CobsAccumulator, FeedResult},
17        },
18    },
19    net_stack::NetStackHandle,
20    wire_frames::de_frame,
21};
22
23use bbq2::{prod_cons::stream::StreamConsumer, traits::bbqhdl::BbqHandle};
24use log::{debug, error, info, warn};
25use maitake_sync::WaitQueue;
26use tokio::{
27    io::{AsyncReadExt, AsyncWriteExt},
28    net::{
29        TcpStream,
30        tcp::{OwnedReadHalf, OwnedWriteHalf},
31    },
32    select,
33};
34
35pub type StdTcpClientIm = DirectEdge<StdTcpInterface>;
36
37pub struct RxWorker<N: NetStackHandle> {
38    stack: N,
39    skt: OwnedReadHalf,
40    closer: Arc<WaitQueue>,
41}
42
43// ---- impls ----
44
45impl<N> RxWorker<N>
46where
47    N: NetStackHandle<Profile = DirectEdge<StdTcpInterface>>,
48{
49    pub async fn run(mut self) -> Result<(), ReceiverError> {
50        let res = self.run_inner().await;
51        // todo: this could live somewhere else?
52        self.stack.stack().manage_profile(|im| {
53            _ = im.set_interface_state((), InterfaceState::Down);
54        });
55        res
56    }
57
58    pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
59        let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
60        let mut raw_buf = [0u8; 4096];
61        let mut net_id = None;
62
63        loop {
64            let rd = self.skt.read(&mut raw_buf);
65            let close = self.closer.wait();
66
67            let ct = select! {
68                r = rd => {
69                    match r {
70                        Ok(0) | Err(_) => {
71                            warn!("recv run closed");
72                            return Err(ReceiverError::SocketClosed)
73                        },
74                        Ok(ct) => ct,
75                    }
76                }
77                _c = close => {
78                    return Err(ReceiverError::SocketClosed);
79                }
80            };
81
82            let buf = &mut raw_buf[..ct];
83            let mut window = buf;
84
85            'cobs: while !window.is_empty() {
86                window = match cobs_buf.feed_raw(window) {
87                    FeedResult::Consumed => break 'cobs,
88                    FeedResult::OverFull(new_wind) => new_wind,
89                    FeedResult::DecodeError(new_wind) => new_wind,
90                    FeedResult::Success { data, remaining }
91                    | FeedResult::SuccessInput { data, remaining } => {
92                        // Successfully de-cobs'd a packet, now we need to
93                        // do something with it.
94                        if let Some(mut frame) = de_frame(data) {
95                            debug!("Got Frame!");
96                            let take_net = net_id.is_none()
97                                || net_id.is_some_and(|n| {
98                                    frame.hdr.dst.network_id != 0 && n != frame.hdr.dst.network_id
99                                });
100                            if take_net {
101                                self.stack.stack().manage_profile(|im| {
102                                    im.set_interface_state(
103                                        (),
104                                        InterfaceState::Active {
105                                            net_id: frame.hdr.dst.network_id,
106                                            node_id: EDGE_NODE_ID,
107                                        },
108                                    )
109                                    .unwrap();
110                                });
111                                net_id = Some(frame.hdr.dst.network_id);
112                            }
113
114                            // If the message comes in and has a src net_id of zero,
115                            // we should rewrite it so it isn't later understood as a
116                            // local packet.
117                            //
118                            // TODO: accept any packet if we don't have a net_id yet?
119                            if let Some(net) = net_id.as_ref() {
120                                if frame.hdr.src.network_id == 0 {
121                                    assert_ne!(
122                                        frame.hdr.src.node_id, 0,
123                                        "we got a local packet remotely?"
124                                    );
125                                    assert_ne!(
126                                        frame.hdr.src.node_id, 2,
127                                        "someone is pretending to be us?"
128                                    );
129
130                                    frame.hdr.src.network_id = *net;
131                                }
132                            }
133
134                            // TODO: if the destination IS self.net_id, we could rewrite the
135                            // dest net_id as zero to avoid a pass through the interface manager.
136                            //
137                            // If the dest is 0, should we rewrite the dest as self.net_id? This
138                            // is the opposite as above, but I dunno how that will work with responses
139                            let hdr = frame.hdr.clone();
140                            let hdr: Header = hdr.into();
141                            let res = match frame.body {
142                                Ok(body) => self.stack.stack().send_raw(&hdr, frame.hdr_raw, body),
143                                Err(e) => self.stack.stack().send_err(&hdr, e),
144                            };
145                            match res {
146                                Ok(()) => {}
147                                Err(e) => {
148                                    // TODO: match on error, potentially try to send NAK?
149                                    panic!("recv->send error: {e:?}");
150                                }
151                            }
152                        } else {
153                            warn!(
154                                "Decode error! Ignoring frame on net_id {}",
155                                net_id.unwrap_or(0)
156                            );
157                        }
158
159                        remaining
160                    }
161                };
162            }
163        }
164    }
165}
166
/// Error returned by `register_target_interface` when the interface cannot be
/// claimed for a new socket: it is already `Inactive`, `ActiveLocal` or
/// `Active`, or switching it to `Inactive` failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SocketAlreadyActive;

impl std::fmt::Display for SocketAlreadyActive {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("interface already has an active socket")
    }
}

impl std::error::Error for SocketAlreadyActive {}
169
170// Helper functions
171
172pub async fn register_target_interface<N>(
173    stack: N,
174    socket: TcpStream,
175    queue: StdQueue,
176) -> Result<(), SocketAlreadyActive>
177where
178    N: NetStackHandle<Profile = DirectEdge<StdTcpInterface>>,
179    N: Send + 'static,
180{
181    let (rx, tx) = socket.into_split();
182    let closer = Arc::new(WaitQueue::new());
183    stack.stack().manage_profile(|im| {
184        match im.interface_state(()) {
185            Some(InterfaceState::Down) => {}
186            Some(InterfaceState::Inactive) => return Err(SocketAlreadyActive),
187            Some(InterfaceState::ActiveLocal { .. }) => return Err(SocketAlreadyActive),
188            Some(InterfaceState::Active { .. }) => return Err(SocketAlreadyActive),
189            None => {}
190        }
191
192        im.set_interface_state((), InterfaceState::Inactive)
193            .map_err(|_| SocketAlreadyActive)?;
194
195        Ok(())
196    })?;
197    let rx_worker = RxWorker {
198        stack,
199        skt: rx,
200        closer: closer.clone(),
201    };
202    // TODO: spawning in a non-async context!
203    tokio::task::spawn(tx_worker(tx, queue.stream_consumer(), closer.clone()));
204    tokio::task::spawn(rx_worker.run());
205    Ok(())
206}
207
208async fn tx_worker(mut tx: OwnedWriteHalf, rx: StreamConsumer<StdQueue>, closer: Arc<WaitQueue>) {
209    info!("Started tx_worker");
210    loop {
211        let rxf = rx.wait_read();
212        let clf = closer.wait();
213
214        let frame = select! {
215            r = rxf => r,
216            _c = clf => {
217                break;
218            }
219        };
220
221        let len = frame.len();
222        info!("sending pkt len:{}", len);
223        let res = tx.write_all(&frame).await;
224        frame.release(len);
225        if let Err(e) = res {
226            error!("Err: {e:?}");
227            break;
228        }
229    }
230    // TODO: GC waker?
231    warn!("Closing interface");
232}