async_pipes/pipeline/
io.rs

1use std::sync::Arc;
2
3use crate::pipeline::sync::Synchronizer;
4
/// Generates the `VarSender<T>` / `VarReceiver<T>` enums, which hide one of several
/// channel flavours behind a single type.
///
/// Invocation shape: `<T>` followed by one or more `Variant(SenderType, ReceiverType)`
/// entries (trailing comma optional). Each entry becomes a variant of both enums, plus
/// `From` conversions from the concrete sender/receiver types. The sender enum also gets
/// a `Clone` impl and the receiver enum an async `recv`; both simply forward to the
/// wrapped value.
macro_rules! variable_channels {
    (<$item:ident> $($variant:ident($sender:ty, $receiver:ty)),+ $(,)?) => {
        /// Sending half of a pipe, wrapping one of the supported channel kinds.
        #[derive(Debug)]
        pub enum VarSender<$item> {
            $( $variant($sender) ),*
        }

        /// Receiving half of a pipe, wrapping one of the supported channel kinds.
        #[derive(Debug)]
        pub enum VarReceiver<$item> {
            $( $variant($receiver) ),*
        }

        $(
            impl<$item> From<$sender> for VarSender<$item> {
                fn from(inner: $sender) -> Self {
                    VarSender::$variant(inner)
                }
            }

            impl<$item> From<$receiver> for VarReceiver<$item> {
                fn from(inner: $receiver) -> Self {
                    VarReceiver::$variant(inner)
                }
            }
        )*

        impl<$item> Clone for VarSender<$item> {
            fn clone(&self) -> Self {
                match self {
                    $( VarSender::$variant(inner) => VarSender::$variant(inner.clone()), )*
                }
            }
        }

        impl<$item> VarReceiver<$item> {
            /// Forwards to the wrapped receiver's own `recv`.
            async fn recv(&mut self) -> Option<$item> {
                match self {
                    $( VarReceiver::$variant(inner) => inner.recv().await, )*
                }
            }
        }
    };
}
48
49variable_channels! {
50    <T>
51    MpscBounded(tokio::sync::mpsc::Sender<T>, tokio::sync::mpsc::Receiver<T>),
52    MpscUnbounded(tokio::sync::mpsc::UnboundedSender<T>, tokio::sync::mpsc::UnboundedReceiver<T>),
53}
54
55impl<T> VarSender<T> {
56    /// Implement send outside of macro because of variations in sender interfaces.
57    async fn send(&self, t: T) -> Result<(), tokio::sync::mpsc::error::SendError<T>> {
58        match self {
59            Self::MpscBounded(tx) => tx.send(t).await,
60            Self::MpscUnbounded(tx) => tx.send(t),
61        }
62    }
63}
64
65pub struct ConsumeOnDrop {
66    id: String,
67    sync: Arc<Synchronizer>,
68}
69
70impl Drop for ConsumeOnDrop {
71    fn drop(&mut self) {
72        self.sync.ended(&self.id)
73    }
74}
75
76/// Defines an end to a pipe that allows data to be received from.
77#[derive(Debug)]
78pub struct PipeReader<T> {
79    pipe_id: String,
80    synchronizer: Arc<Synchronizer>,
81    rx: VarReceiver<T>,
82}
83
84impl<T> PipeReader<T> {
85    pub fn new(
86        pipe_id: String,
87        synchronizer: Arc<Synchronizer>,
88        rx: impl Into<VarReceiver<T>>,
89    ) -> Self {
90        Self {
91            pipe_id,
92            synchronizer,
93            rx: rx.into(),
94        }
95    }
96
97    #[allow(dead_code)]
98    pub fn get_pipe_id(&self) -> &str {
99        &self.pipe_id
100    }
101
102    /// Receive the next value from the inner receiver.
103    pub async fn read(&mut self) -> Option<(T, ConsumeOnDrop)> {
104        self.rx.recv().await.map(|v| {
105            let cod = ConsumeOnDrop {
106                id: self.pipe_id.clone(),
107                sync: self.synchronizer.clone(),
108            };
109
110            (v, cod)
111        })
112    }
113}
114
115/// Defines an end to a pipe that allows data to be sent through.
116#[derive(Debug)]
117pub struct PipeWriter<T> {
118    pipe_id: String,
119    synchronizer: Arc<Synchronizer>,
120    tx: VarSender<T>,
121}
122
123/// Manually implement [Clone] for [PipeWriter] over `T`, as deriving [Clone]
124/// does not implement it over generic parameter.
125impl<T> Clone for PipeWriter<T> {
126    fn clone(&self) -> Self {
127        Self {
128            pipe_id: self.pipe_id.clone(),
129            synchronizer: self.synchronizer.clone(),
130            tx: self.tx.clone(),
131        }
132    }
133}
134
135impl<T> PipeWriter<T> {
136    pub fn new(
137        pipe_id: String,
138        synchronizer: Arc<Synchronizer>,
139        tx: impl Into<VarSender<T>>,
140    ) -> Self {
141        Self {
142            pipe_id,
143            synchronizer,
144            tx: tx.into(),
145        }
146    }
147
148    #[allow(dead_code)]
149    pub fn get_pipe_id(&self) -> &str {
150        &self.pipe_id
151    }
152
153    /// Increment the task count for this pipe and then send the value through the channel.
154    pub async fn write(&self, value: T) {
155        self.synchronizer.started(&self.pipe_id);
156        self.tx
157            .send(value)
158            .await
159            .expect("failed to send input over channel");
160    }
161}
162
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::mpsc::channel;

    use super::*;

    /// Holding the guard returned by `read` must leave the pipe's count untouched;
    /// dropping it must decrement the count by one.
    #[tokio::test]
    async fn test_read_consumed_updates_sync_on_drop() {
        const PIPE: &str = "pipe-id";

        let mut synchronizer = Synchronizer::default();
        synchronizer.register(PIPE);
        synchronizer.started_many(PIPE, 4);
        let synchronizer = Arc::new(synchronizer);

        let (sender, receiver) = channel::<()>(1);
        let mut reader = PipeReader::new(PIPE.to_string(), Arc::clone(&synchronizer), receiver);

        sender.send(()).await.unwrap();

        let (_, guard) = reader.read().await.unwrap();
        assert_eq!(synchronizer.get(PIPE), 4);
        drop(guard);
        assert_eq!(synchronizer.get(PIPE), 3);
    }
}