// embedded_mqttc/network/fake/stream.rs

1use core::{cell::RefCell, cmp::min, future::Future, pin::Pin, task::{Context, Poll}};
2
3use embytes_buffer::{new_stack_buffer, Buffer, BufferReader, BufferWriter, ReadWrite};
4use embassy_sync::{blocking_mutex::{raw::CriticalSectionRawMutex, Mutex}, waitqueue::WakerRegistration};
5use embedded_io_async::{ErrorKind, ErrorType, Read, Write};
6use crate::network::{TryRead, TryWrite};
7pub struct BufferedStream<const N: usize> {
8    pub(crate) inner: Mutex<CriticalSectionRawMutex, RefCell<BufferedStreamInner<N>>>
9}
10
11impl <const N: usize> BufferedStream<N> {
12    pub fn new() -> Self {
13        Self {
14            inner: Mutex::new(RefCell::new(BufferedStreamInner::new()))
15        }
16    }
17
18    pub fn read_async<'a, 'b>(&'b self, buf: &'a mut [u8]) -> ReadFuture<'a, 'b, N> {
19        ReadFuture { 
20            connection: self, 
21            buf 
22        }
23    }
24
25    pub fn try_read_sync(&self, buf: &mut [u8]) -> Result<usize, ErrorKind> {
26        self.inner.lock(|inner| {
27            inner.borrow_mut().try_read_sync(buf)
28        })
29    }
30
31    pub fn write_async<'a, 'b>(&'b self, buf: &'a [u8]) -> WriteFuture<'a, 'b, N> {
32        WriteFuture { 
33            connection: self, 
34            buf
35        }
36    }
37
38    pub fn try_write_sync(&self, buf: &[u8]) -> Result<usize, ErrorKind> {
39        self.inner.lock(|inner| {
40            inner.borrow_mut().try_write_sync(buf)
41        })
42    }
43
44    pub fn with_reader<F, R>(&self, f: F) -> R where F: FnOnce(&dyn BufferReader) -> R {
45        self.inner.lock(|inner|{
46            let mut inner = inner.borrow_mut();
47            let reader = inner.buffer.create_reader();
48            f(&reader)
49        })
50    }
51}
52
53pub(crate) struct BufferedStreamInner<const N: usize> {
54    pub(crate) buffer: Buffer<[u8; N]>,
55    pub(crate) read_waker: WakerRegistration,
56    pub(crate) write_waker: WakerRegistration
57}
58
59impl <const N: usize> BufferedStreamInner<N> {
60    fn new() -> Self {
61        Self {
62            buffer: new_stack_buffer(),
63            read_waker: WakerRegistration::new(),
64            write_waker: WakerRegistration::new()
65        }
66    }
67
68    fn try_read_sync(&mut self, buf: &mut [u8]) -> Result<usize, ErrorKind> {
69
70
71        if buf.len() == 0 {
72            return Ok(0);
73        }
74
75        if ! self.buffer.has_remaining_len() {
76            return Ok(0);
77        }
78
79        let n = min(
80            self.buffer.remaining_len(),
81            buf.len()
82        );
83
84        let reader = self.buffer.create_reader();
85        let buf = &mut buf[0..n];
86        buf.copy_from_slice(&reader[..n]);
87        reader.add_bytes_read(n);
88
89        // Sigals wakers that bytes were read
90        if n > 0 {
91            self.write_waker.wake();
92        }
93
94        Ok(n)
95    }
96
97    fn try_write_sync(&mut self, buf: &[u8]) -> Result<usize, ErrorKind> {
98
99        if buf.len() == 0 {
100            return Ok(0);
101        }
102
103        let mut writer = self.buffer.create_writer();
104
105        if writer.remaining_capacity() == 0 {
106            return Ok(0);
107        }
108
109        let n = min(
110            buf.len(),
111            writer.remaining_capacity()
112        );
113
114        let target = &mut writer[..n];
115        target.copy_from_slice(&buf[..n]);
116        writer.commit(n).unwrap();
117
118        // Sigals wakers that bytes were written
119        if n > 0 {
120            self.read_waker.wake();
121        }
122
123        Ok(n)
124    }
125
126    fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<Result<usize, ErrorKind>>  {
127        if buf.len() == 0 {
128            return Poll::Ready(Ok(0))
129        }
130
131        let result = self.try_read_sync(buf);
132
133        match result {
134            Ok(n) => {
135                if n == 0 {
136                    self.read_waker.register(cx.waker());
137                    Poll::Pending
138                } else {
139                    Poll::Ready(Ok(n))
140                }
141            },
142            Err(e) => Poll::Ready(Err(e)),
143        }
144    }
145
146    fn poll_write(&mut self, buf: &[u8], cx: &mut Context<'_>) -> Poll<Result<usize, ErrorKind>> {
147        if buf.len() == 0 {
148            return Poll::Ready(Ok(0));
149        }
150        
151        let result = self.try_write_sync(buf);
152
153        match result {
154            Ok(n) => {
155                if n == 0 {
156                        self.write_waker.register(cx.waker());
157                    Poll::Pending
158                } else {
159                    Poll::Ready(Ok(n))
160                }
161            },
162            Err(e) => Poll::Ready(Err(e)),
163        }
164    }
165}
166
167impl <const N: usize> ErrorType for BufferedStream<N> {
168    type Error = ErrorKind;
169}
170
171impl <const N: usize> TryRead for BufferedStream<N> {
172    async fn try_read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
173        self.try_read_sync(buf)
174    }
175}
176
177impl <const N: usize> TryWrite for BufferedStream<N> {
178    async fn try_write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
179        self.try_write_sync(buf)
180    }
181}
182
183pub struct ReadFuture<'a, 'b, const N: usize> {
184    connection: &'b BufferedStream<N>,
185    buf: &'a mut [u8]
186}
187
188impl <'a, 'b, const N: usize> Future for ReadFuture<'a, 'b, N> {
189    type Output = Result<usize, ErrorKind>;
190
191    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
192        self.connection.inner.lock(|inner|{
193            inner.borrow_mut().poll_read(self.buf, cx)
194        })
195    }
196}
197
198impl <'a, 'b, const N: usize> Unpin for ReadFuture<'a, 'b, N>{}
199
200impl <const N: usize> Read for BufferedStream<N> {
201    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
202        let future = ReadFuture{
203            connection: self,
204            buf
205        };
206
207        future.await
208    }
209}
210
211pub struct WriteFuture<'a, 'b, const N: usize> {
212    connection: &'b BufferedStream<N>,
213    buf: &'a [u8]
214}
215
216impl <'a, 'b, const N: usize> Future for WriteFuture<'a, 'b, N>  {
217    type Output = Result<usize, ErrorKind>;
218
219    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220        self.connection.inner.lock(|inner| {
221            inner.borrow_mut().poll_write(self.buf, cx)
222        })
223    }
224}
225
226impl <const N: usize> Write for BufferedStream<N> {
227    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
228        let future = WriteFuture{
229            connection: self,
230            buf
231        };
232
233        future.await
234    }
235}
236
#[cfg(test)]
mod stream_tests {
    use super::BufferedStream;

    /// Pushes 128 single-byte writes through a 4-byte stream while a concurrent
    /// reader drains it, and checks every byte arrives exactly once, in order.
    #[tokio::test]
    async fn test_stream() {
        let stream = BufferedStream::<4>::new();

        let write_future = async {
            for i in 0..128u8 {
                stream.write_async(&[i]).await.unwrap();
            }
        };

        let read_future = async {
            let mut received = Vec::new();

            while received.len() < 128 {
                let mut buf = [0; 8];
                let n = stream.read_async(&mut buf).await.unwrap();
                received.extend_from_slice(&buf[..n]);
            }

            // Checking the whole sequence (not just a couple of indices) catches
            // lost, duplicated or reordered bytes.
            let expected: Vec<u8> = (0..128).collect();
            assert_eq!(received, expected);
        };

        tokio::join!(read_future, write_future);
    }

    /// The sync helpers never wait: writes are truncated to the free capacity
    /// and reads return 0 once the buffer is drained.
    #[test]
    fn test_try_sync_bounded_by_capacity() {
        let stream = BufferedStream::<4>::new();

        assert_eq!(stream.try_write_sync(&[]).unwrap(), 0);
        assert_eq!(stream.try_write_sync(&[1, 2, 3, 4, 5]).unwrap(), 4);
        assert_eq!(stream.try_write_sync(&[6]).unwrap(), 0);

        let mut buf = [0u8; 8];
        assert_eq!(stream.try_read_sync(&mut buf).unwrap(), 4);
        assert_eq!(&buf[..4], &[1, 2, 3, 4]);
        assert_eq!(stream.try_read_sync(&mut buf).unwrap(), 0);
    }
}