// embedded_mqttc/network/fake/connection.rs

1use core::{future::Future, pin::Pin, task::{Context, Poll}};
2
3use embytes_buffer::{BufferReader, ReadWrite};
4use embedded_io_async::{ErrorKind, ErrorType, Read, Write};
5use mqttrs2::Packet;
6use crate::network::{mqtt::MqttPacketError, mqtt::ReadMqttPacket, NetworkConnection, NetworkError, TryRead, TryWrite};
7
8use super::BufferedStream;
9
10pub struct ServerConnection<'a, const N: usize> {
11    pub(crate) out_stream: &'a BufferedStream<N>,
12    pub(crate) in_stream: &'a BufferedStream<N>
13}
14
15impl <'a, const N: usize> ServerConnection <'a, N>{
16    pub fn with_reader<F, R>(&self, f: F) -> R where F: FnOnce(&dyn BufferReader) -> R {
17        self.in_stream.with_reader(f)
18    }
19}
20
21impl <'a, const N: usize> ErrorType for ServerConnection<'a, N>  {
22    type Error = ErrorKind;
23}
24
25impl <'a, const N: usize> Write for ServerConnection<'a, N>  {
26    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
27        self.out_stream.write_async(buf).await
28    }
29}
30
31impl <'a, const N: usize> Read for ServerConnection<'a, N>  {
32    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
33        self.in_stream.read_async(buf).await
34    }
35}
36
37impl <'a, const N: usize> TryRead for ServerConnection<'a, N>  {
38    async fn try_read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
39        self.in_stream.try_read_sync(buf)
40    }
41}
42
43impl <'a, const N: usize> TryWrite for ServerConnection<'a, N>  {
44    async fn try_write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
45        self.out_stream.try_write_sync(buf)
46    }
47}
48
49pub struct ClientConnection<'a, const N: usize>{
50    out_stream: &'a BufferedStream<N>,
51    in_stream: &'a BufferedStream<N>
52}
53
54impl <'a, const N: usize> ErrorType for ClientConnection<'a, N>  {
55    type Error = ErrorKind;
56}
57
58impl <'a, const N: usize> Unpin for ClientConnection<'a, N> {}
59
60impl <'a, const N: usize> Write for ClientConnection<'a, N>  {
61    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
62        self.out_stream.write_async(buf).await
63    }
64}
65
66impl <'a, const N: usize> Read for ClientConnection<'a, N>  {
67    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
68        self.in_stream.read_async(buf).await
69    }
70}
71
72impl <'a, const N: usize> TryRead for ClientConnection<'a, N>  {
73    async fn try_read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
74        self.in_stream.try_read_sync(buf)
75    }
76}
77
78impl <'a, const N: usize> TryWrite for ClientConnection<'a, N>  {
79    async fn try_write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
80        self.out_stream.try_write_sync(buf)
81    }
82}
83
84impl <'a, const N: usize> NetworkConnection for ClientConnection<'a, N> {
85    async fn connect(&mut self) -> Result<(), NetworkError> {
86        Ok(())
87    }
88}
89
90pub struct ConnectionRessources<const N: usize> {
91    client_to_server: BufferedStream<N>,
92    server_to_client: BufferedStream<N>
93}
94
95impl <const N: usize> ConnectionRessources<N> {
96    pub fn new() -> Self {
97        Self {
98            client_to_server: BufferedStream::new(),
99            server_to_client: BufferedStream::new()
100        }
101    }
102}
103
104pub fn new_connection<'a, const N: usize>(resources: &'a ConnectionRessources<N>) 
105    -> (ClientConnection<'a, N>, ServerConnection<'a, N>) {
106    
107    let client = ClientConnection{
108        out_stream: &resources.client_to_server,
109        in_stream: &resources.server_to_client
110    };
111
112    let server = ServerConnection{
113        out_stream: &resources.server_to_client,
114        in_stream: &resources.client_to_server
115    };
116
117    (client, server)
118
119}
120
121
122pub trait ReadAtomic: ErrorType {
123
124    fn read_atomic<T, F>(&self, f: F) -> impl Future<Output = Result<T, MqttPacketError>>
125        where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError>;
126
127    fn read_mqtt_packet<O, R>(&self, o: O) -> impl Future<Output = Result<R, MqttPacketError>>
128        where O: Fn(&Packet<'_>) -> R {
129        async move {
130            self.read_atomic(|reader|{
131                let packet = reader.read_packet()?;
132    
133                if let Some(packet) = packet {
134                    let result = o(&packet);
135                    Ok(Some(result))
136    
137                } else {
138                    Ok(None)
139                }
140            })
141            .await
142        }
143    }
144
145}
146
147pub struct ReadAtomicFuture<'b, T, F, const N: usize> where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError> {
148    f: F,
149    stream: &'b BufferedStream<N>
150}
151
152impl <'b, T, F, const N: usize> Future for ReadAtomicFuture<'b, T, F, N> 
153    where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError> {
154    
155    type Output = Result<T, MqttPacketError>;
156
157    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158        
159        self.stream.inner.lock(|inner| {
160            let mut inner = inner.borrow_mut();
161            let reader = inner.buffer.create_reader();
162
163            let result = (self.f)(&reader);
164            drop(reader);
165
166            match result {
167                Ok(op) => {
168                    if let Some(atomic) = op {
169                        Poll::Ready(Ok(atomic))
170                    } else {
171                        inner.read_waker.register(cx.waker());
172                        Poll::Pending
173                    }
174                },
175                Err(e) => Poll::Ready(Err(e)),
176            }
177        })
178    }
179}
180
181impl <'b, T, F, const N: usize> Unpin for ReadAtomicFuture<'b, T, F, N> where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError> {}
182
183impl <const N: usize> ReadAtomic for BufferedStream<N> {
184    fn read_atomic<T, F>(&self, f: F) -> impl Future<Output = Result<T, MqttPacketError>>
185        where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError> {
186            ReadAtomicFuture{
187                f,
188                stream: self
189            }
190    }
191}
192
193impl <'a, const N: usize> ReadAtomic for ServerConnection<'a, N> {
194    async fn read_atomic<T, F>(&self, f: F) -> Result<T, MqttPacketError> where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError>{
195            self.in_stream.read_atomic(f).await
196    }
197}
198
#[cfg(all(test, feature = "std"))]
mod connection_tests {
    use core::time::Duration;

    use embedded_io_async::{Read, Write};

    use super::{new_connection, ConnectionRessources};

    /// Bytes written by the client arrive at the server intact and in order.
    ///
    /// Covers both a single chunk that exactly fills the buffer and a series
    /// of one-byte writes.
    #[tokio::test]
    async fn test_connection() {
        // 4-byte buffers: the first write exactly fills the client→server stream.
        let resources = ConnectionRessources::<4>::new();
        let (mut client, mut server) = new_connection(&resources);

        let client_future = async {
            let n = client.write(&[0, 1, 2, 3]).await.unwrap();
            assert_eq!(n, 4);

            tokio::time::sleep(Duration::from_millis(100)).await;

            for i in 4..8 {
                let n = client.write(&[i]).await.unwrap();
                assert_eq!(n, 1);
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        };

        let server_future = async {
            let mut results = Vec::with_capacity(8);
            for _ in 0..8 {
                let mut buf = [0; 1];
                let n = server.read(&mut buf).await.unwrap();
                assert_eq!(n, 1);
                results.push(buf[0]);
            }
            results
        };

        // Previously the collected bytes were never checked, so lost or
        // reordered data went unnoticed.
        let ((), received) = tokio::join!(client_future, server_future);
        assert_eq!(received, [0, 1, 2, 3, 4, 5, 6, 7]);
    }
}
244