fugue_mptp/
lib.rs

1use std::collections::BTreeMap;
2use std::process::exit;
3
4use easy_parallel::Parallel;
5use ipc_channel::ipc::{
6    self, IpcOneShotServer, IpcReceiver, IpcReceiverSet, IpcSelectionResult, IpcSender,
7};
8use nix::unistd::{fork, ForkResult};
9use rand::RngCore;
10use thiserror::Error;
11pub use uuid::Uuid;
12
13pub mod sources;
14
15#[derive(Debug, Error)]
16pub enum Error {
17    #[error(transparent)]
18    ChannelRecv(#[from] flume::RecvError),
19    #[error("cannot send data on closed channel")]
20    ChannelSend,
21    #[error(transparent)]
22    Decode(#[from] rmp_serde::decode::Error),
23    #[error("fork failed: {0}")]
24    Fork(nix::errno::Errno),
25    #[error(transparent)]
26    Io(#[from] std::io::Error),
27    #[error(transparent)]
28    Ipc(#[from] ipc_channel::Error),
29    #[error(transparent)]
30    TaskSink(anyhow::Error),
31    #[error(transparent)]
32    TaskSource(anyhow::Error),
33    #[error("one or more errors encountered during task processing")]
34    Process(Vec<Self>),
35}
36
37impl Error {
38    pub fn source<E>(e: E) -> Self
39    where
40        E: std::error::Error + Send + Sync + 'static,
41    {
42        Self::TaskSource(anyhow::Error::from(e))
43    }
44
45    pub fn sink<E>(e: E) -> Self
46    where
47        E: std::error::Error + Send + Sync + 'static,
48    {
49        Self::TaskSink(anyhow::Error::from(e))
50    }
51}
52
53pub trait TaskData: serde::Serialize + serde::de::DeserializeOwned + Send {}
54impl<T> TaskData for T where T: serde::Serialize + serde::de::DeserializeOwned + Send {}
55
56#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
57pub struct TaskItem<T> {
58    id: Uuid,
59    payload: T,
60}
61
62#[derive(Debug, serde::Deserialize, serde::Serialize)]
63pub struct TaskResult<T, E> {
64    id: Uuid,
65    result: Result<T, E>,
66}
67
68pub trait TaskSource: Send {
69    type TaskInput: Clone + TaskData;
70    type Error: std::error::Error + Send + Sync + 'static;
71
72    fn next_task(&mut self, id: Uuid) -> Result<Option<Self::TaskInput>, Self::Error>;
73}
74
75pub trait TaskProcessor: Send {
76    type TaskInput: Clone + TaskData;
77    type TaskOutput: TaskData;
78    type TaskError: TaskData;
79
80    fn process_task(
81        &mut self,
82        id: Uuid,
83        input: Self::TaskInput,
84    ) -> Result<Self::TaskOutput, Self::TaskError>;
85}
86
87pub trait TaskSink: Send {
88    type TaskOutput: TaskData;
89    type TaskError: TaskData;
90
91    type Error: std::error::Error + Send + Sync + 'static;
92
93    fn process_task_result(
94        &mut self,
95        id: Uuid,
96        result: Result<Self::TaskOutput, Self::TaskError>,
97    ) -> Result<(), Self::Error>;
98}
99
100pub fn run<T, U, E>(
101    source: &mut impl TaskSource<TaskInput = T>,
102    processor: &mut impl TaskProcessor<TaskInput = T, TaskOutput = U, TaskError = E>,
103    sink: &mut impl TaskSink<TaskOutput = U, TaskError = E>,
104) -> Result<(), Error>
105where
106    T: Clone + TaskData,
107    U: TaskData,
108    E: TaskData,
109{
110    let (result_tx, result_rx) = flume::unbounded::<TaskResult<U, E>>();
111    let (child_tx, child_rx) = flume::unbounded::<IpcSender<Option<TaskItem<T>>>>();
112
113    let mut rxs = IpcReceiverSet::new()?;
114    let mut txs = BTreeMap::new();
115
116    for _ in 0..num_cpus::get() {
117        // NOTE: this is a workaround for a bug in ipc-channel on MacOS. As far as I can tell, if
118        // we do not do this, then when a child process constructs a `IpcOneShotServer`, we will
119        // have a name clash, since the name is based on the thread_rng which is duplicated and not
120        // reseeded across the fork. In theory, this should be handled by retrying with a new name,
121        // but for some reason the error returned in the child is not `BOOTSTRAP_NAME_IN_USE`, and
122        // rather some other error, which means that the child process bails due to failing to
123        // create a server.
124        //
125        rand::thread_rng().next_u64();
126
127        let (psrv, pn) = IpcOneShotServer::<(String, IpcReceiver<Option<U>>)>::new()?;
128
129        match unsafe { fork() }.map_err(Error::Fork)? {
130            ForkResult::Child => {
131                // NOTE: we first set-up bi-directional IPC to obtain work and return
132                // results.
133                //
134                let (csrv, cn) = IpcOneShotServer::<IpcReceiver<Option<TaskItem<T>>>>::new()
135                    .map_err(Error::Io)?;
136
137                let results = IpcSender::connect(pn)?;
138
139                let (results_tx, results_rx) = ipc::channel::<Option<Vec<u8>>>()?;
140
141                results.send((cn, results_rx))?;
142
143                let (_, task_rx) = csrv.accept()?;
144
145                // NOTE: now we can begin to handle tasks.
146                //
147                while let Ok(Some(task)) = task_rx.recv() {
148                    let results = TaskResult {
149                        id: task.id,
150                        result: processor.process_task(task.id, task.payload),
151                    };
152                    results_tx.send(rmp_serde::to_vec(&results).ok())?;
153                }
154
155                exit(0);
156            }
157            ForkResult::Parent { .. } => {
158                let (_, (cn, rx)) = psrv.accept()?;
159
160                let tx = IpcSender::<IpcReceiver<Option<TaskItem<T>>>>::connect(cn)?;
161
162                let (ctx, crx) = ipc::channel()?;
163
164                tx.send(crx)?;
165
166                if !child_tx.send(ctx.clone()).is_err() {
167                    let child = rxs.add(rx)?;
168                    txs.insert(child, Some(ctx));
169                }
170            }
171        }
172    }
173
174    let (task_tx, task_rx) = flume::bounded::<TaskItem<T>>(num_cpus::get());
175
176    let (rs, r) = Parallel::new()
177        .add(move || {
178            let mut id = Uuid::now_v7();
179            while let Some(payload) = source.next_task(id).map_err(Error::source)? {
180                task_tx
181                    .send(TaskItem { id, payload })
182                    .map_err(|_| Error::ChannelSend)?;
183                id = Uuid::now_v7();
184            }
185
186            drop(task_tx);
187
188            Ok::<(), Error>(())
189        })
190        .add(move || {
191            while let Ok(TaskResult { id, result }) = result_rx.recv() {
192                sink.process_task_result(id, result).map_err(Error::sink)?;
193            }
194
195            Ok::<(), Error>(())
196        })
197        .add(move || {
198            while let Ok(task) = task_rx.recv() {
199                loop {
200                    let child = child_rx.recv()?;
201
202                    if let Err(_) = child.send(Some(task.clone())) {
203                        // cannot send task to child
204                        continue;
205                    }
206
207                    break;
208                }
209            }
210
211            for r in child_rx.drain() {
212                r.send(None).ok();
213                drop(r);
214            }
215
216            drop(child_rx);
217            drop(task_rx);
218
219            Ok(())
220        })
221        .finish(move || {
222            while let Ok(events) = rxs.select() {
223                for event in events {
224                    match event {
225                        IpcSelectionResult::ChannelClosed(ref child) => {
226                            txs.remove(child);
227                        }
228                        IpcSelectionResult::MessageReceived(ref child, msg) => {
229                            let result = msg
230                                .to::<Option<Vec<u8>>>()?
231                                .map(|v| rmp_serde::from_slice::<TaskResult<U, E>>(&v))
232                                .transpose()?;
233
234                            if let Some(ref child) = txs[child] {
235                                if child_tx.send(child.clone()).is_err() {
236                                    // NOTE: after the first failure, we will not process more
237                                    // work, so we clear txs and wait for results.
238                                    txs.values_mut().for_each(|ch| {
239                                        if let Some(ch) = ch.take() {
240                                            ch.send(None).ok();
241                                        }
242                                    });
243                                }
244                            }
245
246                            if let Some(result) = result {
247                                result_tx.send(result).map_err(|_| Error::ChannelSend)?;
248                            }
249                        }
250                    }
251                }
252
253                if txs.is_empty() {
254                    // NOTE: all children are complete; we're done!
255                    break;
256                }
257            }
258
259            drop(child_tx);
260            drop(txs);
261            drop(rxs);
262
263            Ok::<(), Error>(())
264        });
265
266    let mut errs = Vec::new();
267
268    if let Err(e) = r {
269        errs.push(e);
270    }
271
272    errs.extend(rs.into_iter().filter_map(Result::err));
273
274    if errs.is_empty() {
275        Ok(())
276    } else {
277        Err(Error::Process(errs))
278    }
279}
280
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use super::*;

    /// End-to-end smoke test: 1,000 string tasks flow through `run` and every one of
    /// them must reach the sink as a success.
    #[test]
    fn test_simple() -> Result<(), Error> {
        /// Result payload produced by the processor.
        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        struct Payload {
            summary: String,
            count: usize,
        }

        /// Emits `limit` tasks, numbering them from 1.
        struct Source {
            counter: usize,
            limit: usize,
        }

        impl TaskSource for Source {
            type TaskInput = String;

            type Error = Infallible;

            fn next_task(&mut self, id: Uuid) -> Result<Option<String>, Self::Error> {
                if self.counter < self.limit {
                    self.counter += 1;
                    Ok(Some(format!("{}:{id}", self.counter)))
                } else {
                    Ok(None)
                }
            }
        }

        /// Counts the tasks it handles; the processor is used from forked workers, so the
        /// count is local to whichever process handled the task.
        struct Process {
            counter: usize,
        }

        impl TaskProcessor for Process {
            type TaskInput = String;
            type TaskOutput = Payload;
            type TaskError = String;

            fn process_task(&mut self, _id: Uuid, payload: String) -> Result<Payload, String> {
                self.counter += 1;

                let summary = format!("hello, task:{payload}");
                let count = self.counter;

                Ok(Payload { summary, count })
            }
        }

        /// Counts successful results and prints a line per task.
        struct Sink {
            counter: usize,
        }

        impl TaskSink for Sink {
            type TaskOutput = Payload;
            type TaskError = String;

            type Error = Infallible;

            fn process_task_result(
                &mut self,
                id: Uuid,
                result: Result<Payload, String>,
            ) -> Result<(), Self::Error> {
                let Payload { summary, count } = match result {
                    Ok(payload) => payload,
                    Err(v) => {
                        println!("task {id} failed (reason: {v})");
                        return Ok(());
                    }
                };

                self.counter += 1;
                println!("task {id} processed successfully (payload: {summary} / {count})");

                Ok(())
            }
        }

        let mut process = Process { counter: 0 };
        let mut sink = Sink { counter: 0 };
        let mut source = Source {
            counter: 0,
            limit: 1_000,
        };

        run(&mut source, &mut process, &mut sink)?;

        assert_eq!(sink.counter, 1_000);

        Ok(())
    }
}