atomr_remote/
reader_writer.rs1use std::sync::Arc;
17
18use tokio::sync::mpsc;
19use tokio::task::JoinHandle;
20
21#[async_trait::async_trait]
26pub trait RawTransport: Send + Sync + 'static {
27 type Frame: Send + 'static;
29 type OutFrame: Send + 'static;
31 type Error: Send + 'static + std::fmt::Debug;
33
34 async fn recv(&self) -> Result<Option<Self::Frame>, Self::Error>;
35 async fn send(&self, frame: Self::OutFrame) -> Result<(), Self::Error>;
36}
37
38pub struct ReaderWriterHandle<F, O> {
42 pub outbound: mpsc::UnboundedSender<O>,
43 pub inbound: mpsc::UnboundedReceiver<F>,
44 pub reader: JoinHandle<()>,
45 pub writer: JoinHandle<()>,
46}
47
48pub fn spawn_reader_writer<T>(
56 transport: Arc<T>,
57 outbound_capacity: usize,
58) -> ReaderWriterHandle<T::Frame, T::OutFrame>
59where
60 T: RawTransport,
61{
62 let outbound_capacity = outbound_capacity.max(1);
63 let (out_tx, mut out_rx) = mpsc::unbounded_channel::<T::OutFrame>();
64 let (in_tx, in_rx) = mpsc::unbounded_channel::<T::Frame>();
65
66 let _ = outbound_capacity;
71
72 let r_transport = transport.clone();
73 let r_in_tx = in_tx.clone();
74 let reader = tokio::spawn(async move {
75 loop {
76 match r_transport.recv().await {
77 Ok(Some(frame)) => {
78 if r_in_tx.send(frame).is_err() {
79 return; }
81 }
82 Ok(None) => return, Err(_e) => return, }
85 }
86 });
87
88 let w_transport = transport;
89 let writer = tokio::spawn(async move {
90 while let Some(frame) = out_rx.recv().await {
91 if w_transport.send(frame).await.is_err() {
92 return;
93 }
94 }
95 });
96
97 ReaderWriterHandle { outbound: out_tx, inbound: in_rx, reader, writer }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::Mutex;

    /// In-memory [`RawTransport`] used to drive the reader/writer tasks.
    ///
    /// `recv` pops from the *end* of `recv_q`, so frames arrive in reverse
    /// insertion order, and it returns `Ok(None)` once the queue is empty.
    /// `send` appends to `sent`. Neither operation ever fails.
    struct TestTransport {
        /// Frames still to be delivered by `recv`, consumed from the back.
        recv_q: Mutex<Vec<i32>>,
        /// Frames passed to `send`, in call order.
        sent: Mutex<Vec<i32>>,
        /// Number of `recv` calls, including the one that returned `None`.
        recv_calls: AtomicU32,
    }

    impl TestTransport {
        /// Builds a shared transport whose `recv` pops from the back of `frames`.
        fn shared(frames: Vec<i32>) -> Arc<Self> {
            Arc::new(Self {
                recv_q: Mutex::new(frames),
                sent: Mutex::new(Vec::new()),
                recv_calls: AtomicU32::new(0),
            })
        }

        /// Current value of the `recv` call counter.
        fn recv_call_count(&self) -> u32 {
            self.recv_calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl RawTransport for TestTransport {
        type Frame = i32;
        type OutFrame = i32;
        type Error = ();

        async fn recv(&self) -> Result<Option<i32>, ()> {
            self.recv_calls.fetch_add(1, Ordering::SeqCst);
            let mut q = self.recv_q.lock().await;
            Ok(q.pop())
        }

        async fn send(&self, frame: i32) -> Result<(), ()> {
            self.sent.lock().await.push(frame);
            Ok(())
        }
    }

    #[tokio::test]
    async fn reader_pumps_until_eof() {
        let t = TestTransport::shared(vec![3, 2, 1]);
        let mut handle = spawn_reader_writer(Arc::clone(&t), 8);

        let mut got = Vec::new();
        for _ in 0..3 {
            got.push(handle.inbound.recv().await.expect("reader closed early"));
        }
        handle.reader.await.expect("reader task panicked");

        assert_eq!(got, vec![1, 2, 3]);
        // Three frames plus the terminating `Ok(None)`.
        assert_eq!(t.recv_call_count(), 4);
    }

    #[tokio::test]
    async fn writer_drains_outbound_channel() {
        let t = TestTransport::shared(Vec::new());
        let handle = spawn_reader_writer(Arc::clone(&t), 8);

        for i in 0..5 {
            handle.outbound.send(i).expect("writer stopped early");
        }
        // Dropping the only sender lets the writer finish once drained.
        drop(handle.outbound);
        handle.writer.await.expect("writer task panicked");

        assert_eq!(*t.sent.lock().await, vec![0, 1, 2, 3, 4]);

        // Empty queue: the reader hits end-of-stream on its first poll.
        handle.reader.await.expect("reader task panicked");
        assert_eq!(t.recv_call_count(), 1);
    }

    #[tokio::test]
    async fn reader_and_writer_run_concurrently() {
        let t = TestTransport::shared(vec![20, 10]);
        let mut handle = spawn_reader_writer(Arc::clone(&t), 4);

        // Interleave inbound reads with outbound writes.
        let in_a = handle.inbound.recv().await.expect("missing first frame");
        handle.outbound.send(100).expect("writer stopped early");
        let in_b = handle.inbound.recv().await.expect("missing second frame");
        handle.outbound.send(200).expect("writer stopped early");

        drop(handle.outbound);
        handle.reader.await.expect("reader task panicked");
        handle.writer.await.expect("writer task panicked");

        assert_eq!((in_a, in_b), (10, 20));
        assert_eq!(*t.sent.lock().await, vec![100, 200]);
        // Two frames plus the terminating `Ok(None)`.
        assert_eq!(t.recv_call_count(), 3);
    }
}