Skip to main content

dag_scheduler/pipeline/
mod.rs

1use crossbeam::channel::{Receiver, Sender};
2use std::{
3    cell::Cell,
4    marker::PhantomData,
5    thread::{self, JoinHandle},
6};
7
/// Type-state marker for a freshly created pipeline that has no stages yet.
///
/// Only `Pipeline<Source>` offers `new` and the `add_source_*` methods. The raw-pointer
/// `PhantomData` makes this marker neither `Send` nor `Sync`, so `Pipeline<Source>` does not
/// satisfy the `T: Send` bound of the generic stage-adding `impl`.
pub struct Source(PhantomData<*const ()>);
/// Type-state marker for a pipeline that has been terminated by a sink stage.
///
/// Like [`Source`] it is `!Send`, so no further stage can be attached to a `Pipeline<Sink>`;
/// dropping that pipeline joins all of its worker threads.
pub struct Sink(PhantomData<*const ()>);
10
/// A unit of work driving a source stage.
///
/// Each source worker thread owns its own clone of the implementation and keeps calling
/// [`TSourceWork::process`], forwarding every `Some` value downstream until it gets `None`.
/// State that must be shared between workers therefore needs shared ownership
/// (e.g. `Arc<Mutex<_>>`).
pub trait TSourceWork
where
    Self: Clone + Send + 'static,
{
    /// Type of the items this source emits.
    type SendType;

    /// Produces the next item, or `None` when this worker should stop.
    fn process(&self) -> Option<Self::SendType>;
}
16
/// A unit of work driving a sink stage.
///
/// Each sink worker owns its own clone of the implementation and calls
/// [`TSinkWork::process`] once per item received from the upstream stage; nothing is
/// forwarded further.
pub trait TSinkWork
where
    Self: Clone + Send + 'static,
{
    /// Type of the items consumed from the upstream stage.
    type RecvType;

    /// Consumes a single item.
    fn process(&self, recv: Self::RecvType);
}
21
/// A unit of work driving an intermediate stage.
///
/// Each worker owns its own clone of the implementation and maps every upstream item to
/// exactly one downstream item via [`TIntermediateWork::process`].
pub trait TIntermediateWork
where
    Self: Clone + Send + 'static,
{
    /// Type of the items received from the upstream stage.
    type RecvType;
    /// Type of the items sent to the downstream stage.
    type SendType;

    /// Transforms one upstream item into one downstream item.
    fn process(&self, recv: Self::RecvType) -> Self::SendType;
}
27
28pub struct Pipeline<T> {
29    receiver: Option<Receiver<T>>, // receiver of last worker
30    join_handlers: Cell<Vec<JoinHandle<()>>>,
31}
32
33impl Pipeline<Source> {
34    pub fn new() -> Self {
35        Self {
36            receiver: None,
37            join_handlers: Cell::new(vec![]),
38        }
39    }
40
41    pub fn add_source_stage<S>(
42        self,
43        name: &str,
44        threads: usize,
45        handler: impl FnOnce(Sender<S>, usize) + Clone + Send + 'static,
46        cap: usize,
47        next_send_recv: Option<(Sender<S>, Receiver<S>)>,
48    ) -> Pipeline<S>
49    where
50        S: Send + 'static,
51    {
52        let mut handlers = self.join_handlers.take();
53
54        let (next_send, next_recv) =
55            next_send_recv.unwrap_or_else(|| crossbeam::channel::bounded(cap));
56        for idx in 0..threads {
57            let handler_ = handler.clone();
58            let send_ = next_send.clone();
59            let tname = format!("{}_{}", name, idx);
60            // println!("{}", tname);
61            let join_handler = thread::Builder::new()
62                .name(tname)
63                .spawn(move || {
64                    handler_(send_, idx);
65                })
66                .unwrap();
67            handlers.push(join_handler);
68        }
69
70        Pipeline {
71            receiver: Some(next_recv),
72            join_handlers: Cell::new(handlers),
73        }
74    }
75    pub fn add_source_work_stage<S, H>(
76        self,
77        name: &str,
78        threads: usize,
79        handler: H,
80        cap: usize,
81        next_send_recv: Option<(Sender<S>, Receiver<S>)>,
82    ) -> Pipeline<S>
83    where
84        H: TSourceWork<SendType = S>,
85        S: Send + 'static,
86    {
87        let mut handlers = self.join_handlers.take();
88
89        let (next_send, next_recv) =
90            next_send_recv.unwrap_or_else(|| crossbeam::channel::bounded(cap));
91        for idx in 0..threads {
92            let send_ = next_send.clone();
93            let tname = format!("{}_{}", name, idx);
94            let handler_ = handler.clone();
95            // println!("{}", tname);
96            let join_handler = thread::Builder::new()
97                .name(tname)
98                .spawn(move || {
99                    while let Some(v) = handler_.process() {
100                        send_.send(v).unwrap();
101                    }
102                })
103                .unwrap();
104            handlers.push(join_handler);
105        }
106
107        Pipeline {
108            receiver: Some(next_recv),
109            join_handlers: Cell::new(handlers),
110        }
111    }
112}
113
114impl<T> Pipeline<T>
115where
116    T: Send + 'static,
117{
118    pub fn add_stage<S>(
119        self,
120        name: &str,
121        threads: usize,
122        handler: impl FnOnce(Receiver<T>, Sender<S>, usize) + Clone + Send + 'static,
123        cap: usize,
124        next_send_recv: Option<(Sender<S>, Receiver<S>)>,
125    ) -> Pipeline<S>
126    where
127        S: Send + 'static,
128    {
129        let mut handlers = self.join_handlers.take();
130
131        let (next_send, next_recv) =
132            next_send_recv.unwrap_or_else(|| crossbeam::channel::bounded(cap));
133
134        for idx in 0..threads {
135            let handler_ = handler.clone();
136            let recv_ = self.receiver.clone();
137            let send_ = next_send.clone();
138            let tname = format!("{}_{}", name, idx);
139            // println!("{}", tname);
140            let join_handler = thread::Builder::new()
141                .name(tname)
142                .spawn(move || {
143                    handler_(recv_.unwrap(), send_, idx);
144                })
145                .unwrap();
146            handlers.push(join_handler);
147        }
148
149        Pipeline {
150            receiver: Some(next_recv),
151            join_handlers: Cell::new(handlers),
152        }
153    }
154
155    pub fn add_work_stage<S, H>(
156        self,
157        name: &str,
158        threads: usize,
159        handler: H,
160        cap: usize,
161        next_send_recv: Option<(Sender<S>, Receiver<S>)>,
162    ) -> Pipeline<S>
163    where
164        H: TIntermediateWork<RecvType = T, SendType = S>,
165        S: Send + 'static,
166    {
167        let mut handlers = self.join_handlers.take();
168
169        let (next_send, next_recv) =
170            next_send_recv.unwrap_or_else(|| crossbeam::channel::bounded(cap));
171
172        for idx in 0..threads {
173            let handler_ = handler.clone();
174            let recv_ = self.receiver.clone();
175            let send_ = next_send.clone();
176            let tname = format!("{}_{}", name, idx);
177            // println!("{}", tname);
178            let join_handler = thread::Builder::new()
179                .name(tname)
180                .spawn(move || {
181                    recv_.unwrap().iter().for_each(|v| {
182                        let send_v = handler_.process(v);
183                        send_.send(send_v).unwrap();
184                    });
185                })
186                .unwrap();
187            handlers.push(join_handler);
188        }
189
190        Pipeline {
191            receiver: Some(next_recv),
192            join_handlers: Cell::new(handlers),
193        }
194    }
195
196    pub fn add_sink_stage(
197        self,
198        name: &str,
199        threads: usize,
200        handler: impl FnOnce(Receiver<T>, usize) + Clone + Send + 'static,
201    ) -> Pipeline<Sink> {
202        let mut handlers = self.join_handlers.take();
203
204        for idx in 0..threads {
205            let handler_ = handler.clone();
206            let recv_ = self.receiver.clone();
207            let tname = format!("{}_{}", name, idx);
208            // println!("{}", tname);
209            let join_handler = thread::Builder::new()
210                .name(tname)
211                .spawn(move || {
212                    handler_(recv_.unwrap(), idx);
213                })
214                .unwrap();
215            handlers.push(join_handler);
216        }
217
218        Pipeline {
219            receiver: None,
220            join_handlers: Cell::new(handlers),
221        }
222    }
223
224    pub fn add_sink_work_stage<H>(self, name: &str, threads: usize, handler: H) -> Pipeline<Sink>
225    where
226        H: TSinkWork<RecvType = T>,
227    {
228        let mut handlers = self.join_handlers.take();
229
230        for idx in 0..threads {
231            let handler_ = handler.clone();
232            let recv_ = self.receiver.clone();
233            let tname = format!("{}_{}", name, idx);
234            // println!("{}", tname);
235            let join_handler = thread::Builder::new()
236                .name(tname)
237                .spawn(move || {
238                    recv_.unwrap().iter().for_each(|v| {
239                        handler_.process(v);
240                    });
241                })
242                .unwrap();
243            handlers.push(join_handler);
244        }
245
246        Pipeline {
247            receiver: None,
248            join_handlers: Cell::new(handlers),
249        }
250    }
251
252    pub fn take_receiver(&mut self) -> Option<Receiver<T>> {
253        self.receiver.take()
254    }
255}
256
257impl<T> Drop for Pipeline<T>
258where
259    T: Sized,
260{
261    fn drop(&mut self) {
262        self.join_handlers.take().into_iter().for_each(|handler| {
263            handler.join().unwrap();
264        });
265    }
266}
267
#[cfg(test)]
mod test {

    use std::sync::{Arc, Mutex};

    use crossbeam::channel::{Receiver, Sender};

    use super::{Pipeline, TIntermediateWork, TSinkWork, TSourceWork};

    /// Closure-driven stages: 4 sources each emit 100, the middle stage multiplies by 10,
    /// and a single sink sums everything it receives.
    #[test]
    fn test_pipeline() {
        let total = Arc::new(Mutex::new(0));
        let sink_total = Arc::clone(&total);

        let ppl = Pipeline::new()
            .add_source_stage(
                "source",
                4,
                move |s: Sender<i32>, _thread_idx: usize| {
                    s.send(100).unwrap();
                },
                10,
                None,
            )
            .add_stage(
                "Multiply",
                2,
                move |r: Receiver<i32>, s: Sender<i32>, _thread_idx: usize| {
                    for v in r {
                        let _ = s.send(v * 10);
                    }
                },
                10,
                None,
            )
            .add_sink_stage("sink", 1, move |r: Receiver<i32>, _thread_idx: usize| {
                let sum: i32 = r.iter().sum();
                println!("sum_result: {}", sum);
                *sink_total.lock().unwrap() += sum;
            });

        // Dropping the pipeline joins every worker, so the sink has finished here.
        drop(ppl);
        assert_eq!(*total.lock().unwrap(), 4 * 100 * 10);
    }

    /// Source work that pops values from a shared stack until it is empty.
    #[derive(Clone)]
    pub struct SourceWork(pub Arc<Mutex<Vec<i32>>>);
    impl TSourceWork for SourceWork {
        type SendType = i32;

        fn process(&self) -> Option<Self::SendType> {
            // One lock covers both the emptiness check and the pop, so workers cannot race
            // between them.
            self.0.lock().unwrap().pop()
        }
    }

    /// Intermediate work that multiplies every value by 10.
    #[derive(Debug, Clone)]
    pub struct IntermediateWork;
    impl TIntermediateWork for IntermediateWork {
        type RecvType = i32;
        type SendType = i32;

        fn process(&self, r: i32) -> Self::SendType {
            r * 10
        }
    }

    /// Sink work that records every received value so the test can inspect them.
    #[derive(Debug, Clone, Default)]
    pub struct SinkWork(pub Arc<Mutex<Vec<i32>>>);
    impl TSinkWork for SinkWork {
        type RecvType = i32;

        fn process(&self, r: i32) {
            println!("sink: {}", r);
            self.0.lock().unwrap().push(r);
        }
    }

    #[test]
    fn test_pipeline_with_work() {
        let sink = SinkWork::default();
        let ppl = Pipeline::new()
            .add_source_work_stage(
                "source",
                4,
                SourceWork(Arc::new(Mutex::new(vec![1, 2, 3, 4]))),
                10,
                None,
            )
            .add_work_stage("Multiply", 2, IntermediateWork, 10, None)
            .add_sink_work_stage("sink", 1, sink.clone());

        // Dropping the pipeline joins every worker, so all values have been recorded.
        drop(ppl);
        let mut received = sink.0.lock().unwrap().clone();
        received.sort_unstable();
        assert_eq!(received, vec![10, 20, 30, 40]);
    }
}