ergot_base/interface_manager/
std_tcp_client.rs1use std::sync::Arc;
9
10use crate::{
11 Header, Key, NetStack,
12 interface_manager::{
13 ConstInit, InterfaceManager, InterfaceSendError, cobs_stream,
14 std_utils::{
15 ReceiverError, StdQueue,
16 acc::{CobsAccumulator, FeedResult},
17 },
18 wire_frames::{CommonHeader, de_frame},
19 },
20};
21use bbq2::{prod_cons::stream::StreamConsumer, traits::storage::BoxedSlice};
22use log::{debug, error, info, warn};
23use maitake_sync::WaitQueue;
24use mutex::ScopedRawMutex;
25use tokio::{
26 io::{AsyncReadExt, AsyncWriteExt},
27 net::{
28 TcpStream,
29 tcp::{OwnedReadHalf, OwnedWriteHalf},
30 },
31 select,
32};
33
34#[derive(Default)]
35pub struct StdTcpClientIm {
36 inner: Option<StdTcpClientImInner>,
37 seq_no: u16,
38}
39
40struct StdTcpClientImInner {
41 interface: StdTcpTxHdl,
42 net_id: u16,
43 closer: Arc<WaitQueue>,
44}
45
/// Errors returned by [`register_interface`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientError {
    /// An interface is already registered with the stack's interface manager.
    SocketAlreadyActive,
}

impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClientError::SocketAlreadyActive => {
                f.write_str("a TCP client interface is already active")
            }
        }
    }
}

impl std::error::Error for ClientError {}
50
51pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
52 stack: &'static NetStack<R, StdTcpClientIm>,
53 skt: OwnedReadHalf,
54 closer: Arc<WaitQueue>,
55}
56
57struct StdTcpTxHdl {
58 skt_tx: cobs_stream::Interface<StdQueue>,
59}
60
61impl StdTcpClientIm {
64 pub const fn new() -> Self {
65 Self {
66 inner: None,
67 seq_no: 0,
68 }
69 }
70}
71
72impl ConstInit for StdTcpClientIm {
73 #[allow(clippy::declare_interior_mutable_const)]
74 const INIT: Self = Self::new();
75}
76
77impl StdTcpClientIm {
78 fn common_send<'a, 'b>(
79 &'b mut self,
80 ihdr: &'a Header,
81 ) -> Result<(&'b mut StdTcpClientImInner, CommonHeader, Option<&'a Key>), InterfaceSendError>
82 {
83 let intfc = match self.inner.take() {
84 None => return Err(InterfaceSendError::NoRouteToDest),
85 Some(intfc) if intfc.closer.is_closed() => {
86 drop(intfc);
87 return Err(InterfaceSendError::NoRouteToDest);
88 }
89 Some(intfc) => self.inner.insert(intfc),
90 };
91
92 if intfc.net_id == 0 {
93 return Err(InterfaceSendError::NoRouteToDest);
95 }
96 if ihdr.dst.network_id == intfc.net_id && ihdr.dst.node_id == 2 {
104 return Err(InterfaceSendError::DestinationLocal);
105 }
106
107 let mut hdr = ihdr.clone();
110 hdr.decrement_ttl()?;
111
112 if hdr.src.net_node_any() {
115 hdr.src.network_id = intfc.net_id;
119 hdr.src.node_id = 2;
120 }
121
122 if hdr.dst.port_id == 255 {
125 hdr.dst.network_id = intfc.net_id;
126 hdr.dst.node_id = 1;
127 }
128
129 let seq_no = self.seq_no;
130 self.seq_no = self.seq_no.wrapping_add(1);
131
132 let header = CommonHeader {
133 src: hdr.src.as_u32(),
134 dst: hdr.dst.as_u32(),
135 seq_no,
136 kind: hdr.kind.0,
137 ttl: hdr.ttl,
138 };
139 let key = if [0, 255].contains(&hdr.dst.port_id) {
140 Some(ihdr.key.as_ref().unwrap())
141 } else {
142 None
143 };
144
145 Ok((intfc, header, key))
146 }
147}
148
149impl InterfaceManager for StdTcpClientIm {
150 fn send<T: serde::Serialize>(
151 &mut self,
152 hdr: &Header,
153 data: &T,
154 ) -> Result<(), InterfaceSendError> {
155 let (intfc, header, key) = self.common_send(hdr)?;
156 let res = intfc.interface.skt_tx.send_ty(&header, key, data);
157
158 match res {
159 Ok(()) => Ok(()),
160 Err(()) => Err(InterfaceSendError::InterfaceFull),
161 }
162 }
163
164 fn send_raw(&mut self, hdr: &Header, data: &[u8]) -> Result<(), InterfaceSendError> {
165 let (intfc, header, key) = self.common_send(hdr)?;
166 let res = intfc.interface.skt_tx.send_raw(&header, key, data);
167
168 match res {
169 Ok(()) => Ok(()),
170 Err(()) => Err(InterfaceSendError::InterfaceFull),
171 }
172 }
173
174 fn send_err(
175 &mut self,
176 hdr: &Header,
177 err: crate::ProtocolError,
178 ) -> Result<(), InterfaceSendError> {
179 let (intfc, header, _key) = self.common_send(hdr)?;
180 let res = intfc.interface.skt_tx.send_err(&header, err);
181
182 match res {
183 Ok(()) => Ok(()),
184 Err(()) => Err(InterfaceSendError::InterfaceFull),
185 }
186 }
187}
188
189impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
190 pub async fn run(mut self) -> Result<(), ReceiverError> {
191 let res = self.run_inner().await;
192 self.stack.with_interface_manager(|im| {
194 _ = im.inner.take();
195 });
196 res
197 }
198
199 pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
200 let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
201 let mut raw_buf = [0u8; 4096];
202 let mut net_id = None;
203
204 loop {
205 let rd = self.skt.read(&mut raw_buf);
206 let close = self.closer.wait();
207
208 let ct = select! {
209 r = rd => {
210 match r {
211 Ok(0) | Err(_) => {
212 warn!("recv run closed");
213 return Err(ReceiverError::SocketClosed)
214 },
215 Ok(ct) => ct,
216 }
217 }
218 _c = close => {
219 return Err(ReceiverError::SocketClosed);
220 }
221 };
222
223 let buf = &raw_buf[..ct];
224 let mut window = buf;
225
226 'cobs: while !window.is_empty() {
227 window = match cobs_buf.feed_raw(window) {
228 FeedResult::Consumed => break 'cobs,
229 FeedResult::OverFull(new_wind) => new_wind,
230 FeedResult::DeserError(new_wind) => new_wind,
231 FeedResult::Success { data, remaining } => {
232 if let Some(mut frame) = de_frame(data) {
235 debug!("Got Frame!");
236 let take_net = net_id.is_none()
237 || net_id.is_some_and(|n| {
238 frame.hdr.dst.network_id != 0 && n != frame.hdr.dst.network_id
239 });
240 if take_net {
241 self.stack.with_interface_manager(|im| {
242 if let Some(i) = im.inner.as_mut() {
243 i.net_id = frame.hdr.dst.network_id;
245 }
246 });
248 net_id = Some(frame.hdr.dst.network_id);
249 }
250
251 if let Some(net) = net_id.as_ref() {
257 if frame.hdr.src.network_id == 0 {
258 assert_ne!(
259 frame.hdr.src.node_id, 0,
260 "we got a local packet remotely?"
261 );
262 assert_ne!(
263 frame.hdr.src.node_id, 2,
264 "someone is pretending to be us?"
265 );
266
267 frame.hdr.src.network_id = *net;
268 }
269 }
270
271 let hdr = frame.hdr.clone();
277 let hdr: Header = hdr.into();
278 let res = match frame.body {
279 Ok(body) => self.stack.send_raw(&hdr, body),
280 Err(e) => self.stack.send_err(&hdr, e),
281 };
282 match res {
283 Ok(()) => {}
284 Err(e) => {
285 panic!("recv->send error: {e:?}");
287 }
288 }
289 } else {
290 warn!(
291 "Decode error! Ignoring frame on net_id {}",
292 net_id.unwrap_or(0)
293 );
294 }
295
296 remaining
297 }
298 };
299 }
300 }
301 }
302}
303
304pub fn register_interface<R: ScopedRawMutex>(
307 stack: &'static NetStack<R, StdTcpClientIm>,
308 socket: TcpStream,
309) -> Result<StdTcpRecvHdl<R>, ClientError> {
310 let (rx, tx) = socket.into_split();
311 let closer = Arc::new(WaitQueue::new());
312 stack.with_interface_manager(|im| {
313 if im.inner.is_some() {
314 return Err(ClientError::SocketAlreadyActive);
315 }
316
317 let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
318 let ctx = q.stream_producer();
319 let crx = q.stream_consumer();
320
321 im.inner = Some(StdTcpClientImInner {
322 interface: StdTcpTxHdl {
323 skt_tx: cobs_stream::Interface {
324 mtu: 1024,
325 prod: ctx,
326 },
327 },
328 net_id: 0,
329 closer: closer.clone(),
330 });
331 tokio::task::spawn(tx_worker(tx, crx, closer.clone()));
333 Ok(())
334 })?;
335 Ok(StdTcpRecvHdl {
336 stack,
337 skt: rx,
338 closer,
339 })
340}
341
342async fn tx_worker(mut tx: OwnedWriteHalf, rx: StreamConsumer<StdQueue>, closer: Arc<WaitQueue>) {
343 info!("Started tx_worker");
344 loop {
345 let rxf = rx.wait_read();
346 let clf = closer.wait();
347
348 let frame = select! {
349 r = rxf => r,
350 _c = clf => {
351 break;
352 }
353 };
354
355 let len = frame.len();
356 info!("sending pkt len:{}", len);
357 let res = tx.write_all(&frame).await;
358 frame.release(len);
359 if let Err(e) = res {
360 error!("Err: {e:?}");
361 break;
362 }
363 }
364 warn!("Closing interface");
366}