embedded_mqttc/network/fake/
connection.rs1use core::{future::Future, pin::Pin, task::{Context, Poll}};
2
3use embytes_buffer::{BufferReader, ReadWrite};
4use embedded_io_async::{ErrorKind, ErrorType, Read, Write};
5use mqttrs2::Packet;
6use crate::network::{mqtt::MqttPacketError, mqtt::ReadMqttPacket, NetworkConnection, NetworkError, TryRead, TryWrite};
7
8use super::BufferedStream;
9
10pub struct ServerConnection<'a, const N: usize> {
11 pub(crate) out_stream: &'a BufferedStream<N>,
12 pub(crate) in_stream: &'a BufferedStream<N>
13}
14
15impl <'a, const N: usize> ServerConnection <'a, N>{
16 pub fn with_reader<F, R>(&self, f: F) -> R where F: FnOnce(&dyn BufferReader) -> R {
17 self.in_stream.with_reader(f)
18 }
19}
20
21impl <'a, const N: usize> ErrorType for ServerConnection<'a, N> {
22 type Error = ErrorKind;
23}
24
25impl <'a, const N: usize> Write for ServerConnection<'a, N> {
26 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
27 self.out_stream.write_async(buf).await
28 }
29}
30
31impl <'a, const N: usize> Read for ServerConnection<'a, N> {
32 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
33 self.in_stream.read_async(buf).await
34 }
35}
36
37impl <'a, const N: usize> TryRead for ServerConnection<'a, N> {
38 async fn try_read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
39 self.in_stream.try_read_sync(buf)
40 }
41}
42
43impl <'a, const N: usize> TryWrite for ServerConnection<'a, N> {
44 async fn try_write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
45 self.out_stream.try_write_sync(buf)
46 }
47}
48
49pub struct ClientConnection<'a, const N: usize>{
50 out_stream: &'a BufferedStream<N>,
51 in_stream: &'a BufferedStream<N>
52}
53
54impl <'a, const N: usize> ErrorType for ClientConnection<'a, N> {
55 type Error = ErrorKind;
56}
57
58impl <'a, const N: usize> Unpin for ClientConnection<'a, N> {}
59
60impl <'a, const N: usize> Write for ClientConnection<'a, N> {
61 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
62 self.out_stream.write_async(buf).await
63 }
64}
65
66impl <'a, const N: usize> Read for ClientConnection<'a, N> {
67 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
68 self.in_stream.read_async(buf).await
69 }
70}
71
72impl <'a, const N: usize> TryRead for ClientConnection<'a, N> {
73 async fn try_read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
74 self.in_stream.try_read_sync(buf)
75 }
76}
77
78impl <'a, const N: usize> TryWrite for ClientConnection<'a, N> {
79 async fn try_write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
80 self.out_stream.try_write_sync(buf)
81 }
82}
83
84impl <'a, const N: usize> NetworkConnection for ClientConnection<'a, N> {
85 async fn connect(&mut self) -> Result<(), NetworkError> {
86 Ok(())
87 }
88}
89
90pub struct ConnectionRessources<const N: usize> {
91 client_to_server: BufferedStream<N>,
92 server_to_client: BufferedStream<N>
93}
94
95impl <const N: usize> ConnectionRessources<N> {
96 pub fn new() -> Self {
97 Self {
98 client_to_server: BufferedStream::new(),
99 server_to_client: BufferedStream::new()
100 }
101 }
102}
103
104pub fn new_connection<'a, const N: usize>(resources: &'a ConnectionRessources<N>)
105 -> (ClientConnection<'a, N>, ServerConnection<'a, N>) {
106
107 let client = ClientConnection{
108 out_stream: &resources.client_to_server,
109 in_stream: &resources.server_to_client
110 };
111
112 let server = ServerConnection{
113 out_stream: &resources.server_to_client,
114 in_stream: &resources.client_to_server
115 };
116
117 (client, server)
118
119}
120
121
122pub trait ReadAtomic: ErrorType {
123
124 fn read_atomic<T, F>(&self, f: F) -> impl Future<Output = Result<T, MqttPacketError>>
125 where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError>;
126
127 fn read_mqtt_packet<O, R>(&self, o: O) -> impl Future<Output = Result<R, MqttPacketError>>
128 where O: Fn(&Packet<'_>) -> R {
129 async move {
130 self.read_atomic(|reader|{
131 let packet = reader.read_packet()?;
132
133 if let Some(packet) = packet {
134 let result = o(&packet);
135 Ok(Some(result))
136
137 } else {
138 Ok(None)
139 }
140 })
141 .await
142 }
143 }
144
145}
146
147pub struct ReadAtomicFuture<'b, T, F, const N: usize> where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError> {
148 f: F,
149 stream: &'b BufferedStream<N>
150}
151
152impl <'b, T, F, const N: usize> Future for ReadAtomicFuture<'b, T, F, N>
153 where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError> {
154
155 type Output = Result<T, MqttPacketError>;
156
157 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158
159 self.stream.inner.lock(|inner| {
160 let mut inner = inner.borrow_mut();
161 let reader = inner.buffer.create_reader();
162
163 let result = (self.f)(&reader);
164 drop(reader);
165
166 match result {
167 Ok(op) => {
168 if let Some(atomic) = op {
169 Poll::Ready(Ok(atomic))
170 } else {
171 inner.read_waker.register(cx.waker());
172 Poll::Pending
173 }
174 },
175 Err(e) => Poll::Ready(Err(e)),
176 }
177 })
178 }
179}
180
181impl <'b, T, F, const N: usize> Unpin for ReadAtomicFuture<'b, T, F, N> where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError> {}
182
183impl <const N: usize> ReadAtomic for BufferedStream<N> {
184 fn read_atomic<T, F>(&self, f: F) -> impl Future<Output = Result<T, MqttPacketError>>
185 where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError> {
186 ReadAtomicFuture{
187 f,
188 stream: self
189 }
190 }
191}
192
193impl <'a, const N: usize> ReadAtomic for ServerConnection<'a, N> {
194 async fn read_atomic<T, F>(&self, f: F) -> Result<T, MqttPacketError> where F: Fn(&dyn BufferReader) -> Result<Option<T>, MqttPacketError>{
195 self.in_stream.read_atomic(f).await
196 }
197}
198
#[cfg(test)]
mod connection_tests {
    use core::time::Duration;

    use embedded_io_async::{Read, Write};

    use super::{new_connection, ConnectionRessources};

    /// Bytes written by the client must reach the server unchanged and in
    /// order, both for an initial four-byte burst and for later single bytes.
    #[tokio::test]
    async fn test_connection() {
        let resources = ConnectionRessources::<4>::new();
        let (mut client, mut server) = new_connection(&resources);

        let client_future = async {
            let n = client.write(&[0, 1, 2, 3]).await.unwrap();
            assert_eq!(n, 4);

            tokio::time::sleep(Duration::from_millis(100)).await;

            for i in 4..8 {
                let n = client.write(&[i]).await.unwrap();
                assert_eq!(n, 1);
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        };

        let server_future = async {
            let mut results = Vec::new();

            for _ in 0..8 {
                let mut buf = [0; 1];
                let n = server.read(&mut buf).await.unwrap();
                // A zero-length read would leave `buf` untouched and record a stale 0.
                assert_eq!(n, 1);
                results.push(buf[0]);
            }

            results
        };

        let ((), received) = tokio::join!(client_future, server_future);

        // Check the actual payload; otherwise the test passes as long as
        // nothing panics, whatever bytes arrive.
        assert_eq!(received, [0u8, 1, 2, 3, 4, 5, 6, 7]);
    }
}
244