ergot_base/interface_manager/
std_tcp_client.rs1use std::sync::Arc;
9
10use crate::{
11 Header, HeaderSeq, NetStack,
12 interface_manager::std_utils::{OwnedFrame, ser_frame},
13};
14
15use super::{
16 ConstInit, InterfaceManager, InterfaceSendError,
17 std_utils::{
18 ReceiverError,
19 acc::{CobsAccumulator, FeedResult},
20 de_frame,
21 },
22};
23use maitake_sync::WaitQueue;
24use mutex::ScopedRawMutex;
25use tokio::{
26 io::{AsyncReadExt, AsyncWriteExt},
27 net::{
28 TcpStream,
29 tcp::{OwnedReadHalf, OwnedWriteHalf},
30 },
31 select,
32 sync::mpsc::{Receiver, Sender, channel, error::TrySendError},
33};
34
35#[derive(Default)]
36pub struct StdTcpClientIm {
37 inner: Option<StdTcpClientImInner>,
38 seq_no: u16,
39}
40
41struct StdTcpClientImInner {
42 interface: StdTcpTxHdl,
43 net_id: u16,
44 closer: Arc<WaitQueue>,
45}
46
/// Errors returned when attaching a TCP socket to a [`StdTcpClientIm`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientError {
    /// A socket is already attached to this interface manager.
    SocketAlreadyActive,
}

impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClientError::SocketAlreadyActive => {
                f.write_str("a TCP socket is already attached to this interface manager")
            }
        }
    }
}

impl std::error::Error for ClientError {}
51
52pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
53 stack: &'static NetStack<R, StdTcpClientIm>,
54 skt: OwnedReadHalf,
55 closer: Arc<WaitQueue>,
56}
57
58struct StdTcpTxHdl {
59 skt_tx: Sender<OwnedFrame>,
60}
61
62impl StdTcpClientIm {
65 pub const fn new() -> Self {
66 Self {
67 inner: None,
68 seq_no: 0,
69 }
70 }
71}
72
73impl ConstInit for StdTcpClientIm {
74 #[allow(clippy::declare_interior_mutable_const)]
75 const INIT: Self = Self::new();
76}
77
78impl InterfaceManager for StdTcpClientIm {
79 fn send<T: serde::Serialize>(
80 &mut self,
81 mut hdr: Header,
82 data: &T,
83 ) -> Result<(), InterfaceSendError> {
84 let Some(intfc) = self.inner.as_mut() else {
85 return Err(InterfaceSendError::NoRouteToDest);
86 };
87 if intfc.net_id == 0 {
88 return Err(InterfaceSendError::NoRouteToDest);
90 }
91 if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
99 return Err(InterfaceSendError::DestinationLocal);
100 }
101 if hdr.src.net_node_any() {
104 hdr.src.network_id = intfc.net_id;
108 hdr.src.node_id = 2;
109 }
110
111 let seq_no = self.seq_no;
112 self.seq_no = self.seq_no.wrapping_add(1);
113 let res = intfc.interface.skt_tx.try_send(OwnedFrame {
114 hdr: HeaderSeq {
115 src: hdr.src,
116 dst: hdr.dst,
117 seq_no,
118 key: hdr.key,
119 kind: hdr.kind,
120 },
121 body: postcard::to_stdvec(data).unwrap(),
122 });
123 match res {
124 Ok(()) => Ok(()),
125 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
126 Err(TrySendError::Closed(_)) => {
127 if let Some(i) = self.inner.take() {
128 i.closer.close();
129 }
130 Err(InterfaceSendError::NoRouteToDest)
131 }
132 }
133 }
134
135 fn send_raw(&mut self, mut hdr: Header, data: &[u8]) -> Result<(), InterfaceSendError> {
136 let Some(intfc) = self.inner.as_mut() else {
137 return Err(InterfaceSendError::NoRouteToDest);
138 };
139 if intfc.net_id == 0 {
140 return Err(InterfaceSendError::NoRouteToDest);
142 }
143 if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
151 return Err(InterfaceSendError::DestinationLocal);
152 }
153 if hdr.src.net_node_any() {
156 hdr.src.network_id = intfc.net_id;
160 hdr.src.node_id = 2;
161 }
162
163 let seq_no = self.seq_no;
164 self.seq_no = self.seq_no.wrapping_add(1);
165 let res = intfc.interface.skt_tx.try_send(OwnedFrame {
166 hdr: HeaderSeq {
167 src: hdr.src,
168 dst: hdr.dst,
169 seq_no,
170 key: hdr.key,
171 kind: hdr.kind,
172 },
173 body: data.to_vec(),
174 });
175 match res {
176 Ok(()) => Ok(()),
177 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
178 Err(TrySendError::Closed(_)) => {
179 self.inner.take();
180 Err(InterfaceSendError::NoRouteToDest)
181 }
182 }
183 }
184}
185
186impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
187 pub async fn run(mut self) -> Result<(), ReceiverError> {
188 let res = self.run_inner().await;
189 self.stack.with_interface_manager(|im| {
191 _ = im.inner.take();
192 });
193 res
194 }
195
196 pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
197 let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
198 let mut raw_buf = [0u8; 4096];
199 let mut net_id = None;
200
201 loop {
202 let rd = self.skt.read(&mut raw_buf);
203 let close = self.closer.wait();
204
205 let ct = select! {
206 r = rd => {
207 match r {
208 Ok(0) | Err(_) => {
209 println!("recv run closed");
210 return Err(ReceiverError::SocketClosed)
211 },
212 Ok(ct) => ct,
213 }
214 }
215 _c = close => {
216 return Err(ReceiverError::SocketClosed);
217 }
218 };
219
220 let buf = &raw_buf[..ct];
221 let mut window = buf;
222
223 'cobs: while !window.is_empty() {
224 window = match cobs_buf.feed_raw(window) {
225 FeedResult::Consumed => break 'cobs,
226 FeedResult::OverFull(new_wind) => new_wind,
227 FeedResult::DeserError(new_wind) => new_wind,
228 FeedResult::Success { data, remaining } => {
229 if let Some(mut frame) = de_frame(data) {
232 println!("Got Frame!");
233 let take_net = net_id.is_none()
234 || net_id.is_some_and(|n| {
235 frame.hdr.dst.network_id != 0 && n != frame.hdr.dst.network_id
236 });
237 if take_net {
238 self.stack.with_interface_manager(|im| {
239 if let Some(i) = im.inner.as_mut() {
240 i.net_id = frame.hdr.dst.network_id;
242 }
243 });
245 net_id = Some(frame.hdr.dst.network_id);
246 }
247
248 if let Some(net) = net_id.as_ref() {
254 if frame.hdr.src.network_id == 0 {
255 assert_ne!(
256 frame.hdr.src.node_id, 0,
257 "we got a local packet remotely?"
258 );
259 assert_ne!(
260 frame.hdr.src.node_id, 2,
261 "someone is pretending to be us?"
262 );
263
264 frame.hdr.src.network_id = *net;
265 }
266 }
267
268 let res = self.stack.send_raw(frame.hdr.into(), &frame.body);
274 match res {
275 Ok(()) => {}
276 Err(e) => {
277 panic!("recv->send error: {e:?}");
279 }
280 }
281 } else {
282 println!(
283 "Decode error! Ignoring frame on net_id {}",
284 net_id.unwrap_or(0)
285 );
286 }
287
288 remaining
289 }
290 };
291 }
292 }
293 }
294}
295
296pub fn register_interface<R: ScopedRawMutex>(
299 stack: &'static NetStack<R, StdTcpClientIm>,
300 socket: TcpStream,
301) -> Result<StdTcpRecvHdl<R>, ClientError> {
302 let (rx, tx) = socket.into_split();
303 let (ctx, crx) = channel(64);
304 let closer = Arc::new(WaitQueue::new());
305 stack.with_interface_manager(|im| {
306 if im.inner.is_some() {
307 return Err(ClientError::SocketAlreadyActive);
308 }
309
310 im.inner = Some(StdTcpClientImInner {
311 interface: StdTcpTxHdl { skt_tx: ctx },
312 net_id: 0,
313 closer: closer.clone(),
314 });
315 tokio::task::spawn(tx_worker(tx, crx, closer.clone()));
317 Ok(())
318 })?;
319 Ok(StdTcpRecvHdl {
320 stack,
321 skt: rx,
322 closer,
323 })
324}
325
326async fn tx_worker(mut tx: OwnedWriteHalf, mut rx: Receiver<OwnedFrame>, closer: Arc<WaitQueue>) {
327 println!("Started tx_worker");
328 loop {
329 let rxf = rx.recv();
330 let clf = closer.wait();
331
332 let frame = select! {
333 r = rxf => {
334 if let Some(frame) = r {
335 frame
336 } else {
337 println!("tx_workerrx closed!");
338 closer.close();
339 break;
340 }
341 }
342 _c = clf => {
343 break;
344 }
345 };
346
347 let msg = ser_frame(frame);
348 println!("sending pkt len:{}", msg.len());
349 let res = tx.write_all(&msg).await;
350 if let Err(e) = res {
351 println!("Err: {e:?}");
352 break;
353 }
354 }
355 println!("Closing interface");
357}