atomr_remote/
reader_writer.rs1use std::sync::Arc;
19
20use tokio::sync::mpsc;
21use tokio::task::JoinHandle;
22
23#[async_trait::async_trait]
28pub trait RawTransport: Send + Sync + 'static {
29 type Frame: Send + 'static;
31 type OutFrame: Send + 'static;
33 type Error: Send + 'static + std::fmt::Debug;
35
36 async fn recv(&self) -> Result<Option<Self::Frame>, Self::Error>;
37 async fn send(&self, frame: Self::OutFrame) -> Result<(), Self::Error>;
38}
39
40pub struct ReaderWriterHandle<F, O> {
44 pub outbound: mpsc::UnboundedSender<O>,
45 pub inbound: mpsc::UnboundedReceiver<F>,
46 pub reader: JoinHandle<()>,
47 pub writer: JoinHandle<()>,
48}
49
50pub fn spawn_reader_writer<T>(
58 transport: Arc<T>,
59 outbound_capacity: usize,
60) -> ReaderWriterHandle<T::Frame, T::OutFrame>
61where
62 T: RawTransport,
63{
64 let outbound_capacity = outbound_capacity.max(1);
65 let (out_tx, mut out_rx) = mpsc::unbounded_channel::<T::OutFrame>();
66 let (in_tx, in_rx) = mpsc::unbounded_channel::<T::Frame>();
67
68 let _ = outbound_capacity;
73
74 let r_transport = transport.clone();
75 let r_in_tx = in_tx.clone();
76 let reader = tokio::spawn(async move {
77 loop {
78 match r_transport.recv().await {
79 Ok(Some(frame)) => {
80 if r_in_tx.send(frame).is_err() {
81 return; }
83 }
84 Ok(None) => return, Err(_e) => return, }
87 }
88 });
89
90 let w_transport = transport;
91 let writer = tokio::spawn(async move {
92 while let Some(frame) = out_rx.recv().await {
93 if w_transport.send(frame).await.is_err() {
94 return;
95 }
96 }
97 });
98
99 ReaderWriterHandle { outbound: out_tx, inbound: in_rx, reader, writer }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::Mutex;

    /// In-memory transport used by the tests below.
    struct TestTransport {
        /// Frames still to be handed out by `recv`; popped from the back.
        recv_q: Mutex<Vec<i32>>,
        /// Every frame passed to `send`, in call order.
        sent: Mutex<Vec<i32>>,
        /// Number of times `recv` has been invoked.
        recv_calls: AtomicU32,
    }

    impl TestTransport {
        /// Builds a shared transport whose receive queue starts as `pending`.
        fn shared(pending: Vec<i32>) -> Arc<Self> {
            Arc::new(Self {
                recv_q: Mutex::new(pending),
                sent: Mutex::new(Vec::new()),
                recv_calls: AtomicU32::new(0),
            })
        }
    }

    #[async_trait::async_trait]
    impl RawTransport for TestTransport {
        type Frame = i32;
        type OutFrame = i32;
        type Error = ();

        async fn recv(&self) -> Result<Option<i32>, ()> {
            self.recv_calls.fetch_add(1, Ordering::SeqCst);
            // An empty queue yields `None`, which the reader treats as EOF.
            let next = self.recv_q.lock().await.pop();
            Ok(next)
        }

        async fn send(&self, frame: i32) -> Result<(), ()> {
            let mut log = self.sent.lock().await;
            log.push(frame);
            Ok(())
        }
    }

    #[tokio::test]
    async fn reader_pumps_until_eof() {
        let transport = TestTransport::shared(vec![3, 2, 1]);
        let mut handle = spawn_reader_writer(Arc::clone(&transport), 8);

        let mut received = Vec::new();
        while received.len() < 3 {
            received.push(handle.inbound.recv().await.unwrap());
        }
        let _ = handle.reader.await;

        // `pop` takes from the back, so the frames arrive reversed.
        assert_eq!(received, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn writer_drains_outbound_channel() {
        let transport = TestTransport::shared(Vec::new());
        let handle = spawn_reader_writer(Arc::clone(&transport), 8);

        (0..5).for_each(|i| handle.outbound.send(i).unwrap());
        // Closing the only sender lets the writer finish once it has drained.
        drop(handle.outbound);
        let _ = handle.writer.await;

        let delivered: Vec<i32> = transport.sent.lock().await.clone();
        assert_eq!(delivered, vec![0, 1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn reader_and_writer_run_concurrently() {
        let transport = TestTransport::shared(vec![20, 10]);
        let mut handle = spawn_reader_writer(Arc::clone(&transport), 4);

        // Interleave inbound receives with outbound sends.
        let first_in = handle.inbound.recv().await.unwrap();
        handle.outbound.send(100).unwrap();
        let second_in = handle.inbound.recv().await.unwrap();
        handle.outbound.send(200).unwrap();

        drop(handle.outbound);
        let _ = handle.reader.await;
        let _ = handle.writer.await;

        assert_eq!(first_in, 10);
        assert_eq!(second_in, 20);
        let delivered: Vec<i32> = transport.sent.lock().await.clone();
        assert_eq!(delivered, vec![100, 200]);
    }
}