Skip to main content

ts_netstack_smoltcp/
pipe.rs

1use alloc::sync::Arc;
2use core::{
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use bytes::{Bytes, BytesMut};
8use futures_util::task::AtomicWaker;
9use netcore::{
10    Pipe, flume, smoltcp,
11    smoltcp::{
12        phy::{ChecksumCapabilities, DeviceCapabilities, Medium},
13        time::Instant,
14    },
15};
16
17/// Bidirectional pipe carrying byte buffer payloads.
18///
19/// This is like [`netcore::Pipe`], except that it also implements
20/// [`AsyncWakeDevice`][netcore::AsyncWakeDevice], which needs a bit of fiddling to adapt.
21pub struct WakingPipe {
22    /// The send side of the pipe.
23    pub rx: WakingPipeReceiver,
24    /// The transmit side of the pipe.
25    pub tx: WakingPipeSender,
26}
27
28/// A [`flume::Receiver`] wrapped to support [`AsyncWakeDevice`][netcore::AsyncWakeDevice].
29///
30/// It wakes the remote [`WakingPipeSender`] when a message is received.
31pub struct WakingPipeReceiver {
32    rx: flume::Receiver<Bytes>,
33    /// [`flume::Receiver`] doesn't expose a way to poll until a value is ready without
34    /// consuming it. This holds the consumed value.
35    buffered_rx: Option<Bytes>,
36
37    /// The waker that this end of the pipe polls on in `poll_rx`.
38    ///
39    /// It is woken by the remote (tx) end of the pipe when a packet is sent, i.e. the
40    /// readiness state of `poll_rx` changes.
41    self_waker: Arc<AtomicWaker>,
42
43    /// The waker for the remote (tx) end of the pipe.
44    ///
45    /// We wake this when we receive a packet (i.e. make room in the pipe). That only
46    /// matters if `rx` is a bounded channel.
47    remote_waker: Arc<AtomicWaker>,
48}
49
50/// A [`flume::Sender`] that wakes a remote [`WakingPipeReceiver`] when a message is sent.
51#[derive(Clone)]
52pub struct WakingPipeSender {
53    tx: flume::Sender<Bytes>,
54
55    /// The waker this end of the pipe polls on in `poll_tx`.
56    ///
57    /// It is woken by the remote (rx) end of the pipe when a packet is received, i.e. the
58    /// readiness state of `poll_tx` changes.
59    ///
60    /// This only matters if `self.tx` is a bounded channel, otherwise in the unbounded case
61    /// we're always ready to send.
62    self_waker: Arc<AtomicWaker>,
63
64    /// The waker for the remote (rx) end of the pipe.
65    ///
66    /// We wake this when we send a packet.
67    remote_waker: Arc<AtomicWaker>,
68}
69
70impl WakingPipe {
71    /// Construct a new pipe with the given optional capacity `limit`.
72    pub fn new(limit: Option<usize>) -> (Self, Self) {
73        if let Some(limit) = limit {
74            Self::bounded(limit)
75        } else {
76            Self::unbounded()
77        }
78    }
79
80    /// Construct a new unbounded pipe.
81    pub fn unbounded() -> (Self, Self) {
82        let (pipe1, pipe2) = Pipe::unbounded();
83
84        Self::_new(pipe1, pipe2)
85    }
86
87    /// Construct a new pipe that can carry at most `limit` packets.
88    pub fn bounded(limit: usize) -> (Self, Self) {
89        let (pipe1, pipe2) = Pipe::bounded(limit);
90
91        Self::_new(pipe1, pipe2)
92    }
93
94    fn _new(pipe1: Pipe, pipe2: Pipe) -> (Self, Self) {
95        let pipe1_rx_waker = Arc::new(AtomicWaker::new());
96        let pipe2_rx_waker = Arc::new(AtomicWaker::new());
97
98        let pipe1_tx_waker = Arc::new(AtomicWaker::new());
99        let pipe2_tx_waker = Arc::new(AtomicWaker::new());
100
101        (
102            Self {
103                rx: WakingPipeReceiver {
104                    rx: pipe1.rx,
105                    buffered_rx: None,
106                    self_waker: pipe1_rx_waker.clone(),
107                    remote_waker: pipe2_tx_waker.clone(),
108                },
109                tx: WakingPipeSender {
110                    tx: pipe1.tx,
111                    remote_waker: pipe2_rx_waker.clone(),
112                    self_waker: pipe1_tx_waker.clone(),
113                },
114            },
115            Self {
116                rx: WakingPipeReceiver {
117                    rx: pipe2.rx,
118                    buffered_rx: None,
119                    self_waker: pipe2_rx_waker,
120                    remote_waker: pipe1_tx_waker,
121                },
122                tx: WakingPipeSender {
123                    tx: pipe2.tx,
124                    remote_waker: pipe1_rx_waker,
125                    self_waker: pipe2_tx_waker,
126                },
127            },
128        )
129    }
130}
131
132impl WakingPipeReceiver {
133    /// Receive a packet.
134    pub fn recv(&mut self) -> Option<Bytes> {
135        if let Some(buf) = self.buffered_rx.take() {
136            return Some(buf);
137        }
138
139        let ret = self.rx.recv().ok();
140        self.remote_waker.wake();
141
142        ret
143    }
144
145    /// Receive a packet asynchronously.
146    pub async fn recv_async(&mut self) -> Option<Bytes> {
147        if let Some(buf) = self.buffered_rx.take() {
148            return Some(buf);
149        }
150
151        let ret = self.rx.recv_async().await.ok();
152        self.remote_waker.wake();
153
154        ret
155    }
156
157    /// Receive a packet if it's possible to do so without blocking.
158    pub fn try_recv(&mut self) -> Option<Bytes> {
159        if let Some(buf) = self.buffered_rx.take() {
160            return Some(buf);
161        }
162
163        let ret = self.rx.recv().ok();
164        self.remote_waker.wake();
165
166        ret
167    }
168
169    /// Report whether there is a packet ready to be received.
170    pub fn rx_ready(&self) -> bool {
171        self.buffered_rx.is_some() || !self.rx.is_empty()
172    }
173}
174
175impl WakingPipeSender {
176    /// Send a packet, blocking until complete.
177    pub fn send(&self, buf: &[u8]) {
178        if let Err(_e) = self.tx.send(Bytes::copy_from_slice(buf)) {
179            tracing::warn!("send dropped: remote end of pipe is gone");
180            return;
181        }
182
183        self.remote_waker.wake();
184    }
185
186    /// Send a packet asynchronously.
187    pub async fn send_async(&self, buf: &[u8]) {
188        if let Err(_e) = self.tx.send_async(Bytes::copy_from_slice(buf)).await {
189            tracing::warn!("send dropped: remote end of pipe is gone");
190            return;
191        }
192
193        self.remote_waker.wake();
194    }
195
196    /// Send a packet if it's possible to do so without blocking.
197    ///
198    /// Returns whether the packet was actually sent.
199    pub fn try_send(&self, buf: &[u8]) -> bool {
200        match self.tx.try_send(Bytes::copy_from_slice(buf)) {
201            Ok(()) => {
202                self.remote_waker.wake();
203                true
204            }
205            Err(flume::TrySendError::Full(..)) => false,
206            Err(flume::TrySendError::Disconnected(..)) => {
207                tracing::warn!("send dropped: remote end of pipe is gone");
208
209                // Semantically, that the remote end was dropped can be thought of as deciding to
210                // ignore all of our messages
211                true
212            }
213        }
214    }
215
216    /// Report whether we can currently transmit.
217    pub fn tx_ready(&self) -> bool {
218        !self.tx.is_full()
219    }
220}
221
222impl netcore::AsyncWakeDevice for WakingPipeDev {
223    #[tracing::instrument(name = "WakingPipeDev::poll_tx", skip_all, level = "trace", ret)]
224    fn poll_tx<'cx>(self: Pin<&mut Self>, cx: &mut Context<'cx>) -> Poll<()> {
225        self.pipe.tx.self_waker.register(cx.waker());
226
227        if self.pipe.tx.tx_ready() {
228            return Poll::Ready(());
229        }
230
231        Poll::Pending
232    }
233
234    #[tracing::instrument(name = "WakingPipeDev::poll_rx", skip_all, level = "trace", ret)]
235    fn poll_rx<'cx>(mut self: Pin<&mut Self>, cx: &mut Context<'cx>) -> Poll<()> {
236        self.pipe.rx.self_waker.register(cx.waker());
237
238        if self.pipe.rx.rx_ready() {
239            // Check tx readiness so that we return Poll::Ready when Device::receive is actually
240            // ready, which only occurs when both TxToken and RxToken can be constructed.
241            core::task::ready!(self.as_mut().poll_tx(cx));
242
243            return Poll::Ready(());
244        }
245
246        Poll::Pending
247    }
248}
249
250impl smoltcp::phy::TxToken for WakingPipeSender {
251    #[tracing::instrument(
252        name = "WakingPipeSender::consume",
253        skip_all,
254        fields(len),
255        level = "trace"
256    )]
257    fn consume<R, F>(self, len: usize, f: F) -> R
258    where
259        F: FnOnce(&mut [u8]) -> R,
260    {
261        let mut b = BytesMut::zeroed(len);
262
263        let ret = f(&mut b);
264        if self.tx.send(b.freeze()).is_err() {
265            tracing::warn!("remote end of dropped on send");
266        }
267
268        self.remote_waker.wake();
269
270        ret
271    }
272}
273
274pub struct RxToken(Bytes);
275
276impl smoltcp::phy::RxToken for RxToken {
277    #[tracing::instrument(name = "WakingPipeRx::consume", skip_all, level = "trace")]
278    fn consume<R, F>(self, f: F) -> R
279    where
280        F: FnOnce(&[u8]) -> R,
281    {
282        f(&self.0)
283    }
284}
285
286/// Wrapper around [`WakingPipe`] to implement [`smoltcp::phy::Device`].
287///
288/// Like [`netcore::PipeDev`] except that it implements
289/// [`AsyncWakeDevice`][netcore::AsyncWakeDevice].
290pub struct WakingPipeDev {
291    /// End of a pipe that will be directly connected to the netstack, receiving packets
292    /// to be sent and supplying packets to be received.
293    pub pipe: WakingPipe,
294
295    /// The type of network frame the pipe will carry.
296    ///
297    /// For our purposes, this will typically be [`Medium::Ip`].
298    pub medium: Medium,
299    /// The maximum packet size to be transmitted through the pipe.
300    ///
301    /// The implementation does not check or limit the actual size of packets flowing
302    /// through it, this field is just informational for
303    /// [`smoltcp::phy::Device::capabilities`].
304    pub mtu: usize,
305}
306
307impl smoltcp::phy::Device for WakingPipeDev {
308    type RxToken<'a>
309        = RxToken
310    where
311        Self: 'a;
312
313    type TxToken<'a>
314        = WakingPipeSender
315    where
316        Self: 'a;
317
318    #[tracing::instrument(skip(self), level = "trace")]
319    fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
320        let tx = self.transmit(timestamp)?;
321
322        let b = if let Some(buf) = self.pipe.rx.buffered_rx.take() {
323            buf
324        } else {
325            self.pipe.rx.rx.try_recv().ok()?
326        };
327
328        Some((RxToken(b), tx))
329    }
330
331    #[tracing::instrument(skip(self), level = "trace")]
332    fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
333        if self.pipe.tx.tx.is_disconnected() {
334            return None;
335        }
336
337        Some(self.pipe.tx.clone())
338    }
339
340    fn capabilities(&self) -> DeviceCapabilities {
341        let mut caps = DeviceCapabilities::default();
342
343        caps.max_transmission_unit = self.mtu;
344        caps.medium = self.medium;
345        caps.checksum = ChecksumCapabilities::ignored();
346
347        caps
348    }
349}