async_pipes/pipeline/
io.rs1use std::sync::Arc;
2
3use crate::pipeline::sync::Synchronizer;
4
/// Generates the `VarSender<T>` / `VarReceiver<T>` enums, which put several
/// channel flavours behind a single type each.
///
/// Input: a generic parameter in angle brackets, then a comma-separated list of
/// `Variant(SenderType, ReceiverType)` entries (a trailing comma is accepted).
///
/// For every entry the macro emits:
/// - a `Variant(SenderType)` variant on `VarSender` and a `Variant(ReceiverType)`
///   variant on `VarReceiver`;
/// - `From<SenderType> for VarSender` and `From<ReceiverType> for VarReceiver`.
///
/// It also emits a `Clone` impl for `VarSender` and a private async
/// `VarReceiver::recv` that forwards to the wrapped receiver's own `recv`.
macro_rules! variable_channels {
    (<$item:ident> $($variant:ident($sender:ty, $receiver:ty)),+ $(,)?) => {
        #[derive(Debug)]
        pub enum VarSender<$item> {
            $($variant($sender),)*
        }

        #[derive(Debug)]
        pub enum VarReceiver<$item> {
            $($variant($receiver),)*
        }

        $(
            impl<$item> From<$sender> for VarSender<$item> {
                fn from(inner: $sender) -> Self {
                    VarSender::$variant(inner)
                }
            }

            impl<$item> From<$receiver> for VarReceiver<$item> {
                fn from(inner: $receiver) -> Self {
                    VarReceiver::$variant(inner)
                }
            }
        )*

        // Written by hand: `#[derive(Clone)]` would add an `$item: Clone` bound,
        // but cloning a sender only clones the channel handle.
        impl<$item> Clone for VarSender<$item> {
            fn clone(&self) -> Self {
                match self {
                    $(VarSender::$variant(inner) => VarSender::$variant(inner.clone()),)*
                }
            }
        }

        impl<$item> VarReceiver<$item> {
            /// Awaits the next value from whichever receiver is wrapped.
            async fn recv(&mut self) -> Option<$item> {
                match self {
                    $(VarReceiver::$variant(inner) => inner.recv().await,)*
                }
            }
        }
    };
}
48
49variable_channels! {
50 <T>
51 MpscBounded(tokio::sync::mpsc::Sender<T>, tokio::sync::mpsc::Receiver<T>),
52 MpscUnbounded(tokio::sync::mpsc::UnboundedSender<T>, tokio::sync::mpsc::UnboundedReceiver<T>),
53}
54
55impl<T> VarSender<T> {
56 async fn send(&self, t: T) -> Result<(), tokio::sync::mpsc::error::SendError<T>> {
58 match self {
59 Self::MpscBounded(tx) => tx.send(t).await,
60 Self::MpscUnbounded(tx) => tx.send(t),
61 }
62 }
63}
64
65pub struct ConsumeOnDrop {
66 id: String,
67 sync: Arc<Synchronizer>,
68}
69
70impl Drop for ConsumeOnDrop {
71 fn drop(&mut self) {
72 self.sync.ended(&self.id)
73 }
74}
75
76#[derive(Debug)]
78pub struct PipeReader<T> {
79 pipe_id: String,
80 synchronizer: Arc<Synchronizer>,
81 rx: VarReceiver<T>,
82}
83
84impl<T> PipeReader<T> {
85 pub fn new(
86 pipe_id: String,
87 synchronizer: Arc<Synchronizer>,
88 rx: impl Into<VarReceiver<T>>,
89 ) -> Self {
90 Self {
91 pipe_id,
92 synchronizer,
93 rx: rx.into(),
94 }
95 }
96
97 #[allow(dead_code)]
98 pub fn get_pipe_id(&self) -> &str {
99 &self.pipe_id
100 }
101
102 pub async fn read(&mut self) -> Option<(T, ConsumeOnDrop)> {
104 self.rx.recv().await.map(|v| {
105 let cod = ConsumeOnDrop {
106 id: self.pipe_id.clone(),
107 sync: self.synchronizer.clone(),
108 };
109
110 (v, cod)
111 })
112 }
113}
114
115#[derive(Debug)]
117pub struct PipeWriter<T> {
118 pipe_id: String,
119 synchronizer: Arc<Synchronizer>,
120 tx: VarSender<T>,
121}
122
123impl<T> Clone for PipeWriter<T> {
126 fn clone(&self) -> Self {
127 Self {
128 pipe_id: self.pipe_id.clone(),
129 synchronizer: self.synchronizer.clone(),
130 tx: self.tx.clone(),
131 }
132 }
133}
134
135impl<T> PipeWriter<T> {
136 pub fn new(
137 pipe_id: String,
138 synchronizer: Arc<Synchronizer>,
139 tx: impl Into<VarSender<T>>,
140 ) -> Self {
141 Self {
142 pipe_id,
143 synchronizer,
144 tx: tx.into(),
145 }
146 }
147
148 #[allow(dead_code)]
149 pub fn get_pipe_id(&self) -> &str {
150 &self.pipe_id
151 }
152
153 pub async fn write(&self, value: T) {
155 self.synchronizer.started(&self.pipe_id);
156 self.tx
157 .send(value)
158 .await
159 .expect("failed to send input over channel");
160 }
161}
162
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::mpsc;

    use super::*;

    /// Reading a value leaves the counter untouched; dropping the returned guard
    /// lowers it by one.
    #[tokio::test]
    async fn test_read_consumed_updates_sync_on_drop() {
        const PIPE: &str = "pipe-id";

        let mut synchronizer = Synchronizer::default();
        synchronizer.register(PIPE);
        synchronizer.started_many(PIPE, 4);
        let synchronizer = Arc::new(synchronizer);

        let (sender, receiver) = mpsc::channel::<()>(1);
        let mut reader = PipeReader::new(PIPE.to_string(), Arc::clone(&synchronizer), receiver);

        sender.send(()).await.unwrap();

        let (_, guard) = reader.read().await.unwrap();
        assert_eq!(synchronizer.get(PIPE), 4);

        drop(guard);
        assert_eq!(synchronizer.get(PIPE), 3);
    }
}