ergot_base/interface_manager/profiles/direct_router/
std_tcp.rs1use bbq2::{prod_cons::stream::StreamConsumer, traits::bbqhdl::BbqHandle};
6use cobs::max_encoding_overhead;
7use log::{debug, error, info, warn};
8use maitake_sync::WaitQueue;
9use std::sync::Arc;
10use tokio::{
11 io::{AsyncReadExt, AsyncWriteExt},
12 net::{
13 TcpStream,
14 tcp::{OwnedReadHalf, OwnedWriteHalf},
15 },
16 select,
17};
18
19use crate::{
20 Header,
21 interface_manager::{
22 InterfaceState, Profile,
23 interface_impls::std_tcp::StdTcpInterface,
24 utils::{
25 cobs_stream::Sink,
26 std::{
27 ReceiverError, StdQueue,
28 acc::{CobsAccumulator, FeedResult},
29 new_std_queue,
30 },
31 },
32 },
33 net_stack::NetStackHandle,
34 wire_frames::de_frame,
35};
36
37use super::DirectRouter;
38
/// Errors returned by [`register_interface`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The interface could not be registered with the profile, or it was
    /// registered but did not end up in the active state.
    OutOfNetIds,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::OutOfNetIds => f.write_str("no network id available for a new interface"),
        }
    }
}

// Lets callers propagate this error through `Box<dyn std::error::Error>` / `?`.
impl std::error::Error for Error {}
43
44struct TxWorker {
45 net_id: u16,
46 tx: OwnedWriteHalf,
47 rx: StreamConsumer<StdQueue>,
48 closer: Arc<WaitQueue>,
49}
50
51struct RxWorker<N>
52where
53 N: NetStackHandle<Profile = DirectRouter<StdTcpInterface>>,
54 N: Send + 'static,
55{
56 interface_id: u64,
57 net_id: u16,
58 nsh: N,
59 skt: OwnedReadHalf,
60 closer: Arc<WaitQueue>,
61 mtu: u16,
62}
63
64impl TxWorker {
65 async fn run(mut self) {
66 self.run_inner().await;
67 warn!("Closing interface {}", self.net_id);
68 self.closer.close();
69 }
70
71 async fn run_inner(&mut self) {
72 info!("Started tx_worker for net_id {}", self.net_id);
73 loop {
74 let rxf = self.rx.wait_read();
75 let clf = self.closer.wait();
76
77 let frame = select! {
78 r = rxf => r,
79 _c = clf => {
80 break;
81 }
82 };
83
84 let len = frame.len();
85 debug!("sending pkt len:{} on net_id {}", len, self.net_id);
86 let res = self.tx.write_all(&frame).await;
87 frame.release(len);
88 if let Err(e) = res {
89 error!("Err: {e:?}");
90 break;
91 }
92 }
93 }
94}
95
96impl<N> RxWorker<N>
97where
98 N: NetStackHandle<Profile = DirectRouter<StdTcpInterface>>,
99 N: Send + 'static,
100{
101 async fn run(mut self) {
102 let close = self.closer.clone();
103
104 select! {
107 run = self.run_inner() => {
108 self.closer.close();
110 error!("Receive Error: {run:?}");
111 },
112 _clf = close.wait() => {},
113 }
114
115 self.nsh.stack().manage_profile(|im| {
117 _ = im.deregister_interface(self.interface_id);
118 });
119 }
120
121 pub async fn run_inner(&mut self) -> ReceiverError {
122 let overhead = max_encoding_overhead(self.mtu as usize);
123 let mut cobs_buf = CobsAccumulator::new(self.mtu as usize + overhead);
124 let mut raw_buf = vec![0u8; 4096].into_boxed_slice();
125
126 loop {
127 let rd = self.skt.read(&mut raw_buf);
128 let close = self.closer.wait();
129
130 let ct = select! {
131 r = rd => {
132 match r {
133 Ok(0) | Err(_) => {
134 warn!("recv run {} closed", self.net_id);
135 return ReceiverError::SocketClosed
136 },
137 Ok(ct) => ct,
138 }
139 }
140 _c = close => {
141 return ReceiverError::SocketClosed;
142 }
143 };
144
145 let buf = &mut raw_buf[..ct];
146 let mut window = buf;
147
148 'cobs: while !window.is_empty() {
149 window = match cobs_buf.feed_raw(window) {
150 FeedResult::Consumed => break 'cobs,
151 FeedResult::OverFull(new_wind) => new_wind,
152 FeedResult::DecodeError(new_wind) => new_wind,
153 FeedResult::Success { data, remaining }
154 | FeedResult::SuccessInput { data, remaining } => {
155 if let Some(mut frame) = de_frame(data) {
158 if frame.hdr.src.network_id == 0 {
162 assert_ne!(
163 frame.hdr.src.node_id, 0,
164 "we got a local packet remotely?"
165 );
166 assert_ne!(
167 frame.hdr.src.node_id, 1,
168 "someone is pretending to be us?"
169 );
170
171 frame.hdr.src.network_id = self.net_id;
172 }
173 let hdr = frame.hdr.clone();
179 let hdr: Header = hdr.into();
180
181 let res = match frame.body {
182 Ok(body) => self.nsh.stack().send_raw(&hdr, frame.hdr_raw, body),
183 Err(e) => self.nsh.stack().send_err(&hdr, e),
184 };
185 match res {
186 Ok(()) => {}
187 Err(e) => {
188 warn!("recv->send error: {e:?}");
190 }
191 }
192 } else {
193 warn!("Decode error! Ignoring frame on net_id {}", self.net_id);
194 }
195
196 remaining
197 }
198 };
199 }
200 }
201 }
202}
203
204pub async fn register_interface<N>(
205 stack: N,
206 socket: TcpStream,
207 max_ergot_packet_size: u16,
208 outgoing_buffer_size: usize,
209) -> Result<u64, Error>
210where
211 N: NetStackHandle<Profile = DirectRouter<StdTcpInterface>>,
212 N: Send + 'static,
213{
214 let (rx, tx) = socket.into_split();
215 let q: StdQueue = new_std_queue(outgoing_buffer_size);
216 let res = stack.stack().manage_profile(|im| {
217 let ident =
218 im.register_interface(Sink::new_from_handle(q.clone(), max_ergot_packet_size))?;
219 let state = im.interface_state(ident)?;
220 match state {
221 InterfaceState::Active { net_id, node_id: _ } => Some((ident, net_id)),
222 _ => {
223 _ = im.deregister_interface(ident);
224 None
225 }
226 }
227 });
228 let Some((ident, net_id)) = res else {
229 return Err(Error::OutOfNetIds);
230 };
231 let closer = Arc::new(WaitQueue::new());
232 let rx_worker = RxWorker {
233 nsh: stack.clone(),
234 skt: rx,
235 closer: closer.clone(),
236 mtu: max_ergot_packet_size,
237 interface_id: ident,
238 net_id,
239 };
240 let tx_worker = TxWorker {
241 net_id,
242 tx,
243 rx: <StdQueue as BbqHandle>::stream_consumer(&q),
244 closer,
245 };
246
247 tokio::task::spawn(rx_worker.run());
248 tokio::task::spawn(tx_worker.run());
249
250 Ok(ident)
251}