embedded_mqttc/network/fake/
stream.rs1use core::{cell::RefCell, cmp::min, future::Future, pin::Pin, task::{Context, Poll}};
2
3use embytes_buffer::{new_stack_buffer, Buffer, BufferReader, BufferWriter, ReadWrite};
4use embassy_sync::{blocking_mutex::{raw::CriticalSectionRawMutex, Mutex}, waitqueue::WakerRegistration};
5use embedded_io_async::{ErrorKind, ErrorType, Read, Write};
6use crate::network::{TryRead, TryWrite};
7pub struct BufferedStream<const N: usize> {
8 pub(crate) inner: Mutex<CriticalSectionRawMutex, RefCell<BufferedStreamInner<N>>>
9}
10
11impl <const N: usize> BufferedStream<N> {
12 pub fn new() -> Self {
13 Self {
14 inner: Mutex::new(RefCell::new(BufferedStreamInner::new()))
15 }
16 }
17
18 pub fn read_async<'a, 'b>(&'b self, buf: &'a mut [u8]) -> ReadFuture<'a, 'b, N> {
19 ReadFuture {
20 connection: self,
21 buf
22 }
23 }
24
25 pub fn try_read_sync(&self, buf: &mut [u8]) -> Result<usize, ErrorKind> {
26 self.inner.lock(|inner| {
27 inner.borrow_mut().try_read_sync(buf)
28 })
29 }
30
31 pub fn write_async<'a, 'b>(&'b self, buf: &'a [u8]) -> WriteFuture<'a, 'b, N> {
32 WriteFuture {
33 connection: self,
34 buf
35 }
36 }
37
38 pub fn try_write_sync(&self, buf: &[u8]) -> Result<usize, ErrorKind> {
39 self.inner.lock(|inner| {
40 inner.borrow_mut().try_write_sync(buf)
41 })
42 }
43
44 pub fn with_reader<F, R>(&self, f: F) -> R where F: FnOnce(&dyn BufferReader) -> R {
45 self.inner.lock(|inner|{
46 let mut inner = inner.borrow_mut();
47 let reader = inner.buffer.create_reader();
48 f(&reader)
49 })
50 }
51}
52
53pub(crate) struct BufferedStreamInner<const N: usize> {
54 pub(crate) buffer: Buffer<[u8; N]>,
55 pub(crate) read_waker: WakerRegistration,
56 pub(crate) write_waker: WakerRegistration
57}
58
59impl <const N: usize> BufferedStreamInner<N> {
60 fn new() -> Self {
61 Self {
62 buffer: new_stack_buffer(),
63 read_waker: WakerRegistration::new(),
64 write_waker: WakerRegistration::new()
65 }
66 }
67
68 fn try_read_sync(&mut self, buf: &mut [u8]) -> Result<usize, ErrorKind> {
69
70
71 if buf.len() == 0 {
72 return Ok(0);
73 }
74
75 if ! self.buffer.has_remaining_len() {
76 return Ok(0);
77 }
78
79 let n = min(
80 self.buffer.remaining_len(),
81 buf.len()
82 );
83
84 let reader = self.buffer.create_reader();
85 let buf = &mut buf[0..n];
86 buf.copy_from_slice(&reader[..n]);
87 reader.add_bytes_read(n);
88
89 if n > 0 {
91 self.write_waker.wake();
92 }
93
94 Ok(n)
95 }
96
97 fn try_write_sync(&mut self, buf: &[u8]) -> Result<usize, ErrorKind> {
98
99 if buf.len() == 0 {
100 return Ok(0);
101 }
102
103 let mut writer = self.buffer.create_writer();
104
105 if writer.remaining_capacity() == 0 {
106 return Ok(0);
107 }
108
109 let n = min(
110 buf.len(),
111 writer.remaining_capacity()
112 );
113
114 let target = &mut writer[..n];
115 target.copy_from_slice(&buf[..n]);
116 writer.commit(n).unwrap();
117
118 if n > 0 {
120 self.read_waker.wake();
121 }
122
123 Ok(n)
124 }
125
126 fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<Result<usize, ErrorKind>> {
127 if buf.len() == 0 {
128 return Poll::Ready(Ok(0))
129 }
130
131 let result = self.try_read_sync(buf);
132
133 match result {
134 Ok(n) => {
135 if n == 0 {
136 self.read_waker.register(cx.waker());
137 Poll::Pending
138 } else {
139 Poll::Ready(Ok(n))
140 }
141 },
142 Err(e) => Poll::Ready(Err(e)),
143 }
144 }
145
146 fn poll_write(&mut self, buf: &[u8], cx: &mut Context<'_>) -> Poll<Result<usize, ErrorKind>> {
147 if buf.len() == 0 {
148 return Poll::Ready(Ok(0));
149 }
150
151 let result = self.try_write_sync(buf);
152
153 match result {
154 Ok(n) => {
155 if n == 0 {
156 self.write_waker.register(cx.waker());
157 Poll::Pending
158 } else {
159 Poll::Ready(Ok(n))
160 }
161 },
162 Err(e) => Poll::Ready(Err(e)),
163 }
164 }
165}
166
167impl <const N: usize> ErrorType for BufferedStream<N> {
168 type Error = ErrorKind;
169}
170
171impl <const N: usize> TryRead for BufferedStream<N> {
172 async fn try_read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
173 self.try_read_sync(buf)
174 }
175}
176
177impl <const N: usize> TryWrite for BufferedStream<N> {
178 async fn try_write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
179 self.try_write_sync(buf)
180 }
181}
182
183pub struct ReadFuture<'a, 'b, const N: usize> {
184 connection: &'b BufferedStream<N>,
185 buf: &'a mut [u8]
186}
187
188impl <'a, 'b, const N: usize> Future for ReadFuture<'a, 'b, N> {
189 type Output = Result<usize, ErrorKind>;
190
191 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
192 self.connection.inner.lock(|inner|{
193 inner.borrow_mut().poll_read(self.buf, cx)
194 })
195 }
196}
197
198impl <'a, 'b, const N: usize> Unpin for ReadFuture<'a, 'b, N>{}
199
200impl <const N: usize> Read for BufferedStream<N> {
201 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
202 let future = ReadFuture{
203 connection: self,
204 buf
205 };
206
207 future.await
208 }
209}
210
211pub struct WriteFuture<'a, 'b, const N: usize> {
212 connection: &'b BufferedStream<N>,
213 buf: &'a [u8]
214}
215
216impl <'a, 'b, const N: usize> Future for WriteFuture<'a, 'b, N> {
217 type Output = Result<usize, ErrorKind>;
218
219 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220 self.connection.inner.lock(|inner| {
221 inner.borrow_mut().poll_write(self.buf, cx)
222 })
223 }
224}
225
226impl <const N: usize> Write for BufferedStream<N> {
227 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
228 let future = WriteFuture{
229 connection: self,
230 buf
231 };
232
233 future.await
234 }
235}
236
#[cfg(test)]
mod stream_tests {
    use super::BufferedStream;

    /// Pushes 128 single bytes through a 4-byte stream while a concurrent reader drains it,
    /// then checks that every byte arrived exactly once and in order.
    #[tokio::test]
    async fn test_stream() {
        const TOTAL: u8 = 128;

        let stream = BufferedStream::<4>::new();

        let write_future = async {
            for byte in 0..TOTAL {
                let written = stream.write_async(&[byte]).await.unwrap();
                // A one-byte write that completes must have stored exactly that byte.
                assert_eq!(written, 1);
            }
        };

        let read_future = async {
            let expected_len = usize::from(TOTAL);
            let mut received = Vec::with_capacity(expected_len);

            while received.len() < expected_len {
                let mut chunk = [0u8; 8];
                let n = stream.read_async(&mut chunk).await.unwrap();
                received.extend_from_slice(&chunk[..n]);
            }

            // Compare the whole payload instead of spot-checking single positions.
            let expected: Vec<u8> = (0..TOTAL).collect();
            assert_eq!(received, expected);
        };

        tokio::join!(read_future, write_future);
    }
}