ergot_base/net_stack.rs
1//! The Ergot NetStack
2//!
3//! The [`NetStack`] is the core of Ergot. It is intended to be placed
4//! in a `static` variable for the duration of your application.
5//!
//! The [`NetStack`] is used directly for three main responsibilities:
7//!
8//! 1. Sending a message, either from user code, or to deliver/forward messages
9//! received from an interface
10//! 2. Attaching a socket, allowing the NetStack to route messages to it
11//! 3. Interacting with the [interface manager], in order to add/remove
12//! interfaces, or obtain other information
13//!
14//! [interface manager]: crate::interface_manager
15//!
16//! In general, interacting with anything contained by the [`NetStack`] requires
17//! locking of the [`BlockingMutex`] which protects the inner contents. This
//! is used both to allow sharing of the inner contents and to allow
//! `Drop` impls to remove themselves from the stack in a blocking manner.
20
21use core::{any::TypeId, mem::ManuallyDrop, ptr::NonNull};
22
23use cordyceps::List;
24use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
25use serde::Serialize;
26
27use crate::{
28 Header,
29 interface_manager::{self, InterfaceManager, InterfaceSendError},
30 socket::{SocketHeader, SocketSendError, SocketVTable},
31};
32
33/// The Ergot Netstack
34pub struct NetStack<R: ScopedRawMutex, M: InterfaceManager> {
35 inner: BlockingMutex<R, NetStackInner<M>>,
36}
37
38pub(crate) struct NetStackInner<M: InterfaceManager> {
39 sockets: List<SocketHeader>,
40 manager: M,
41 pcache_bits: u32,
42 pcache_start: u8,
43 seq_no: u16,
44}
45
46/// An error from calling a [`NetStack`] "send" method
47#[derive(Debug, PartialEq, Eq)]
48#[non_exhaustive]
49pub enum NetStackSendError {
50 SocketSend(SocketSendError),
51 InterfaceSend(InterfaceSendError),
52 NoRoute,
53 AnyPortMissingKey,
54 WrongPortKind,
55}
56
57// ---- impl NetStack ----
58
59impl<R, M> NetStack<R, M>
60where
61 R: ScopedRawMutex + ConstInit,
62 M: InterfaceManager + interface_manager::ConstInit,
63{
64 /// Create a new, uninitialized [`NetStack`].
65 ///
66 /// Requires that the [`ScopedRawMutex`] implements the [`mutex::ConstInit`]
67 /// trait, and the [`InterfaceManager`] implements the
68 /// [`interface_manager::ConstInit`] trait.
69 ///
70 /// ## Example
71 ///
72 /// ```rust
73 /// use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
74 /// use ergot_base::NetStack;
75 /// use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
76 ///
77 /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
78 /// ```
79 pub const fn new() -> Self {
80 Self {
81 inner: BlockingMutex::new(NetStackInner::new()),
82 }
83 }
84}
85
86impl<R, M> NetStack<R, M>
87where
88 R: ScopedRawMutex,
89 M: InterfaceManager,
90{
91 /// Manually create a new, uninitialized [`NetStack`].
92 ///
93 /// This method is useful if your [`ScopedRawMutex`] or [`InterfaceManager`]
94 /// do not implement their corresponding `ConstInit` trait.
95 ///
96 /// In general, this is most often only needed for `loom` testing, and
97 /// [`NetStack::new()`] should be used when possible.
98 pub const fn const_new(r: R, m: M) -> Self {
99 Self {
100 inner: BlockingMutex::const_new(
101 r,
102 NetStackInner {
103 sockets: List::new(),
104 manager: m,
105 seq_no: 0,
106 pcache_start: 0,
107 pcache_bits: 0,
108 },
109 ),
110 }
111 }
112
113 /// Access the contained [`InterfaceManager`].
114 ///
115 /// Access to the [`InterfaceManager`] is made via the provided closure.
116 /// The [`BlockingMutex`] is locked for the duration of this access,
117 /// inhibiting all other usage of this [`NetStack`].
118 ///
119 /// This can be used to add new interfaces, obtain metadata, or other
120 /// actions supported by the chosen [`InterfaceManager`].
121 ///
122 /// ## Example
123 ///
124 /// ```rust
125 /// # use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
126 /// # use ergot_base::NetStack;
127 /// # use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
128 /// #
129 /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
130 ///
131 /// let res = STACK.with_interface_manager(|im| {
132 /// // The mutex is locked for the full duration of this closure.
133 /// # _ = im;
134 /// // We can return whatever we want from this context, though not
135 /// // anything borrowed from `im`.
136 /// 42
137 /// });
138 /// assert_eq!(res, 42);
139 /// ```
140 pub fn with_interface_manager<F: FnOnce(&mut M) -> U, U>(&'static self, f: F) -> U {
141 self.inner.with_lock(|inner| f(&mut inner.manager))
142 }
143
144 /// Send a raw (pre-serialized) message.
145 ///
146 /// This interface should almost never be used by end-users, and is instead
147 /// typically used by interfaces to feed received messages into the
148 /// [`NetStack`].
149 pub fn send_raw(&'static self, hdr: Header, body: &[u8]) -> Result<(), NetStackSendError> {
150 if hdr.dst.port_id == 0 && hdr.key.is_none() {
151 return Err(NetStackSendError::AnyPortMissingKey);
152 }
153 let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
154
155 self.inner
156 .with_lock(|inner| inner.send_raw(local_bypass, hdr, body))
157 }
158
159 /// Send a typed message
160 pub fn send_ty<T: 'static + Serialize>(
161 &'static self,
162 hdr: Header,
163 t: T,
164 ) -> Result<(), NetStackSendError> {
165 // Can we assume the destination is local?
166 let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
167
168 self.inner
169 .with_lock(|inner| inner.send_ty(local_bypass, hdr, t))
170 }
171
172 pub(crate) unsafe fn try_attach_socket(
173 &'static self,
174 mut node: NonNull<SocketHeader>,
175 ) -> Option<u8> {
176 self.inner.with_lock(|inner| {
177 let new_port = inner.alloc_port()?;
178 unsafe {
179 node.as_mut().port = new_port;
180 }
181
182 inner.sockets.push_front(node);
183 Some(new_port)
184 })
185 }
186
187 pub(crate) unsafe fn attach_socket(&'static self, node: NonNull<SocketHeader>) -> u8 {
188 let res = unsafe { self.try_attach_socket(node) };
189 let Some(new_port) = res else {
190 panic!("exhausted all addrs");
191 };
192 new_port
193 }
194
195 pub(crate) unsafe fn detach_socket(&'static self, node: NonNull<SocketHeader>) {
196 self.inner.with_lock(|inner| unsafe {
197 let port = node.as_ref().port;
198 inner.free_port(port);
199 inner.sockets.remove(node)
200 });
201 }
202
203 pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&'static self, f: F) -> U {
204 self.inner.with_lock(|_inner| f())
205 }
206}
207
208impl<R, M> Default for NetStack<R, M>
209where
210 R: ScopedRawMutex + ConstInit,
211 M: InterfaceManager + interface_manager::ConstInit,
212{
213 fn default() -> Self {
214 Self::new()
215 }
216}
217
218// ---- impl NetStackInner ----
219
220impl<M> NetStackInner<M>
221where
222 M: InterfaceManager,
223 M: interface_manager::ConstInit,
224{
225 pub const fn new() -> Self {
226 Self {
227 sockets: List::new(),
228 manager: M::INIT,
229 seq_no: 0,
230 pcache_bits: 0,
231 pcache_start: 0,
232 }
233 }
234}
235
236impl<M> NetStackInner<M>
237where
238 M: InterfaceManager,
239{
240 fn send_raw(
241 &mut self,
242 local_bypass: bool,
243 hdr: Header,
244 body: &[u8],
245 ) -> Result<(), NetStackSendError> {
246 let res = if !local_bypass {
247 self.manager.send_raw(hdr.clone(), body)
248 } else {
249 Err(InterfaceSendError::DestinationLocal)
250 };
251
252 match res {
253 Ok(()) => return Ok(()),
254 Err(InterfaceSendError::DestinationLocal) => {}
255 Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
256 }
257 // It was a destination local error, try to honor that
258 for socket in self.sockets.iter_raw() {
259 let skt_ref = unsafe { socket.as_ref() };
260 if hdr.kind != skt_ref.kind {
261 if hdr.dst.port_id != 0 && hdr.dst.port_id == skt_ref.port {
262 // If kind mismatch and not wildcard: report error
263 return Err(NetStackSendError::WrongPortKind);
264 } else {
265 continue;
266 }
267 }
268 // TODO: only allow port_id == 0 if there is only one matching port
269 // with this key.
270 if (skt_ref.port == hdr.dst.port_id)
271 || (hdr.dst.port_id == 0 && hdr.key.is_some_and(|k| k == skt_ref.key))
272 {
273 let res = {
274 let f = skt_ref.vtable.send_raw;
275
276 // SAFETY: skt_ref is now dead to us!
277
278 let this: NonNull<SocketHeader> = socket;
279 let this: NonNull<()> = this.cast();
280 let hdr = hdr.to_headerseq_or_with_seq(|| {
281 let seq = self.seq_no;
282 self.seq_no = self.seq_no.wrapping_add(1);
283 seq
284 });
285
286 (f)(this, body, hdr).map_err(NetStackSendError::SocketSend)
287 };
288 return res;
289 }
290 }
291 Err(NetStackSendError::NoRoute)
292 }
293
294 fn send_ty<T: 'static + Serialize>(
295 &mut self,
296 local_bypass: bool,
297 hdr: Header,
298 t: T,
299 ) -> Result<(), NetStackSendError> {
300 let res = if !local_bypass {
301 // Not local: offer to the interface manager to send
302 self.manager.send(hdr.clone(), &t)
303 } else {
304 // just skip to local sending
305 Err(InterfaceSendError::DestinationLocal)
306 };
307
308 match res {
309 Ok(()) => return Ok(()),
310 Err(InterfaceSendError::DestinationLocal) => {}
311 Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
312 }
313
314 // It was a destination local error, try to honor that
315 //
316 // Sending to a local interface means a potential move. Create a
317 // manuallydrop, if a send succeeds, then we have "moved from" here
318 // into the destination. If no send succeeds (e.g. no socket match
319 // or sending to the socket failed) then we will need to drop the
320 // value ourselves.
321 let mut t = ManuallyDrop::new(t);
322
323 // Check each socket to see if we want to send it there...
324 for socket in self.sockets.iter_raw() {
325 let skt_ref = unsafe { socket.as_ref() };
326
327 if hdr.kind != skt_ref.kind {
328 if hdr.dst.port_id != 0 && hdr.dst.port_id == skt_ref.port {
329 // If kind mismatch and not wildcard: report error
330 return Err(NetStackSendError::WrongPortKind);
331 } else {
332 continue;
333 }
334 }
335
336 // TODO: only allow port_id == 0 if there is only one matching port
337 // with this key.
338 if (skt_ref.port == hdr.dst.port_id || hdr.dst.port_id == 0)
339 && hdr.key.unwrap() == skt_ref.key
340 {
341 let vtable: &'static SocketVTable = skt_ref.vtable;
342 // SAFETY: skt_ref is now dead to us!
343
344 let res = if let Some(f) = vtable.send_owned {
345 let this: NonNull<SocketHeader> = socket;
346 let this: NonNull<()> = this.cast();
347 let that: NonNull<ManuallyDrop<T>> = NonNull::from(&mut t);
348 let that: NonNull<()> = that.cast();
349 let hdr = hdr.to_headerseq_or_with_seq(|| {
350 let seq = self.seq_no;
351 self.seq_no = self.seq_no.wrapping_add(1);
352 seq
353 });
354 (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
355 } else if let Some(_f) = vtable.send_bor {
356 // TODO: if we support send borrowed, then we need to
357 // drop the manuallydrop here, success or failure.
358 todo!()
359 } else {
360 // todo: keep going? If we found the "right" destination and
361 // sending fails, then there's not much we can do. Probably: there
362 // is no case where a socket has NEITHER send_owned NOR send_bor,
363 // can we make this state impossible instead?
364 Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
365 };
366
367 // If sending failed, we did NOT move the T, which means it's on us
368 // to drop it.
369 if res.is_err() {
370 unsafe {
371 ManuallyDrop::drop(&mut t);
372 }
373 }
374 return res;
375 }
376 }
377
378 // We reached the end of sockets. We need to drop this item.
379 unsafe {
380 ManuallyDrop::drop(&mut t);
381 }
382 Err(NetStackSendError::NoRoute)
383 }
384}
385
386impl<M> NetStackInner<M>
387where
388 M: InterfaceManager,
389{
390 /// Cache-based allocator inspired by littlefs2 ID allocator
391 ///
392 /// We remember 32 ports at a time, from the current base, which is always
393 /// a multiple of 32. Allocating from this range does not require moving thru
394 /// the socket lists.
395 ///
396 /// If the current 32 ports are all taken, we will start over from a base port
397 /// of 0, and attempt to
398 fn alloc_port(&mut self) -> Option<u8> {
399 // ports 0 is always taken (could be clear on first alloc)
400 self.pcache_bits |= (self.pcache_start == 0) as u32;
401
402 if self.pcache_bits != u32::MAX {
403 // We can allocate from the current slot
404 let ldg = self.pcache_bits.trailing_ones();
405 debug_assert!(ldg < 32);
406 self.pcache_bits |= 1 << ldg;
407 return Some(self.pcache_start + (ldg as u8));
408 }
409
410 // Nope, cache is all taken. try to find a base with available items.
411 // We always start from the bottom to keep ports small, but if we know
412 // we just exhausted a range, don't waste time checking that
413 let old_start = self.pcache_start;
414 for base in 0..8 {
415 let start = base * 32;
416 if start == old_start {
417 continue;
418 }
419 // Clear/reset cache
420 self.pcache_start = start;
421 self.pcache_bits = 0;
422 // port 0 is not allowed
423 self.pcache_bits |= (self.pcache_start == 0) as u32;
424 // port 255 is not allowed
425 self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;
426
427 // TODO: If we trust that sockets are always sorted, we could early-return
428 // when we reach a `pupper > self.pcache_start`. We could also maybe be smart
429 // and iterate forwards for 0..4 and backwards for 4..8 (and switch the early
430 // return check to < instead). NOTE: We currently do NOT guarantee sockets are
431 // sorted!
432 self.sockets.iter().for_each(|s| {
433 // The upper 3 bits of the port
434 let pupper = s.port & !(32 - 1);
435 // The lower 5 bits of the port
436 let plower = s.port & (32 - 1);
437
438 if pupper == self.pcache_start {
439 self.pcache_bits |= 1 << plower;
440 }
441 });
442
443 if self.pcache_bits != u32::MAX {
444 // We can allocate from the current slot
445 let ldg = self.pcache_bits.trailing_ones();
446 debug_assert!(ldg < 32);
447 self.pcache_bits |= 1 << ldg;
448 return Some(self.pcache_start + (ldg as u8));
449 }
450 }
451
452 // Nope, nothing found
453 None
454 }
455
456 fn free_port(&mut self, port: u8) {
457 // The upper 3 bits of the port
458 let pupper = port & !(32 - 1);
459 // The lower 5 bits of the port
460 let plower = port & (32 - 1);
461
462 // TODO: If the freed port is in the 0..32 range, or just less than
463 // the current start range, maybe do an opportunistic re-look?
464 if pupper == self.pcache_start {
465 self.pcache_bits &= !(1 << plower);
466 }
467 }
468}
469
470#[cfg(test)]
471mod test {
472 use core::pin::pin;
473 use mutex::raw_impls::cs::CriticalSectionRawMutex;
474 use std::thread::JoinHandle;
475 use tokio::sync::oneshot;
476
477 use crate::{
478 FrameKind, Key, NetStack, interface_manager::null::NullInterfaceManager,
479 socket::owned::OwnedSocket,
480 };
481
482 #[test]
483 fn port_alloc() {
484 static STACK: NetStack<CriticalSectionRawMutex, NullInterfaceManager> = NetStack::new();
485
486 let mut v = vec![];
487
488 fn spawn_skt(id: u8) -> (u8, JoinHandle<()>, oneshot::Sender<()>) {
489 let (txdone, rxdone) = oneshot::channel();
490 let (txwait, rxwait) = oneshot::channel();
491 let hdl = std::thread::spawn(move || {
492 let skt = OwnedSocket::<u64, _, _>::new(
493 &STACK,
494 Key(*b"TEST1234"),
495 FrameKind::ENDPOINT_REQ,
496 );
497 let skt = pin!(skt);
498 let hdl = skt.attach();
499 assert_eq!(hdl.port(), id);
500 txwait.send(()).unwrap();
501 let _: () = rxdone.blocking_recv().unwrap();
502 });
503 let _ = rxwait.blocking_recv();
504 (id, hdl, txdone)
505 }
506
507 // make sockets 1..32
508 for i in 1..32 {
509 v.push(spawn_skt(i));
510 }
511
512 // make sockets 32..40
513 for i in 32..40 {
514 v.push(spawn_skt(i));
515 }
516
517 // drop socket 35
518 let pos = v.iter().position(|(i, _, _)| *i == 35).unwrap();
519 let (_i, hdl, tx) = v.remove(pos);
520 tx.send(()).unwrap();
521 hdl.join().unwrap();
522
523 // make a new socket, it should be 35
524 v.push(spawn_skt(35));
525
526 // drop socket 4
527 let pos = v.iter().position(|(i, _, _)| *i == 4).unwrap();
528 let (_i, hdl, tx) = v.remove(pos);
529 tx.send(()).unwrap();
530 hdl.join().unwrap();
531
532 // make a new socket, it should be 40
533 v.push(spawn_skt(40));
534
535 // make sockets 41..64
536 for i in 41..64 {
537 v.push(spawn_skt(i));
538 }
539
540 // make a new socket, it should be 4
541 v.push(spawn_skt(4));
542
543 // make sockets 64..255
544 for i in 64..255 {
545 v.push(spawn_skt(i));
546 }
547
548 // drop socket 212
549 let pos = v.iter().position(|(i, _, _)| *i == 212).unwrap();
550 let (_i, hdl, tx) = v.remove(pos);
551 tx.send(()).unwrap();
552 hdl.join().unwrap();
553
554 // make a new socket, it should be 212
555 v.push(spawn_skt(212));
556
557 // Sockets exhausted (we never see 255)
558 let hdl = std::thread::spawn(move || {
559 let skt =
560 OwnedSocket::<u64, _, _>::new(&STACK, Key(*b"TEST1234"), FrameKind::ENDPOINT_REQ);
561 let skt = pin!(skt);
562 let hdl = skt.attach();
563 println!("{}", hdl.port());
564 });
565 assert!(hdl.join().is_err());
566
567 for (_i, hdl, tx) in v.drain(..) {
568 tx.send(()).unwrap();
569 hdl.join().unwrap();
570 }
571 }
572}