1use std::collections::BTreeMap;
2use std::process::exit;
3
4use easy_parallel::Parallel;
5use ipc_channel::ipc::{
6 self, IpcOneShotServer, IpcReceiver, IpcReceiverSet, IpcSelectionResult, IpcSender,
7};
8use nix::unistd::{fork, ForkResult};
9use rand::RngCore;
10use thiserror::Error;
11pub use uuid::Uuid;
12
13pub mod sources;
14
15#[derive(Debug, Error)]
16pub enum Error {
17 #[error(transparent)]
18 ChannelRecv(#[from] flume::RecvError),
19 #[error("cannot send data on closed channel")]
20 ChannelSend,
21 #[error(transparent)]
22 Decode(#[from] rmp_serde::decode::Error),
23 #[error("fork failed: {0}")]
24 Fork(nix::errno::Errno),
25 #[error(transparent)]
26 Io(#[from] std::io::Error),
27 #[error(transparent)]
28 Ipc(#[from] ipc_channel::Error),
29 #[error(transparent)]
30 TaskSink(anyhow::Error),
31 #[error(transparent)]
32 TaskSource(anyhow::Error),
33 #[error("one or more errors encountered during task processing")]
34 Process(Vec<Self>),
35}
36
37impl Error {
38 pub fn source<E>(e: E) -> Self
39 where
40 E: std::error::Error + Send + Sync + 'static,
41 {
42 Self::TaskSource(anyhow::Error::from(e))
43 }
44
45 pub fn sink<E>(e: E) -> Self
46 where
47 E: std::error::Error + Send + Sync + 'static,
48 {
49 Self::TaskSink(anyhow::Error::from(e))
50 }
51}
52
53pub trait TaskData: serde::Serialize + serde::de::DeserializeOwned + Send {}
54impl<T> TaskData for T where T: serde::Serialize + serde::de::DeserializeOwned + Send {}
55
56#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
57pub struct TaskItem<T> {
58 id: Uuid,
59 payload: T,
60}
61
62#[derive(Debug, serde::Deserialize, serde::Serialize)]
63pub struct TaskResult<T, E> {
64 id: Uuid,
65 result: Result<T, E>,
66}
67
68pub trait TaskSource: Send {
69 type TaskInput: Clone + TaskData;
70 type Error: std::error::Error + Send + Sync + 'static;
71
72 fn next_task(&mut self, id: Uuid) -> Result<Option<Self::TaskInput>, Self::Error>;
73}
74
75pub trait TaskProcessor: Send {
76 type TaskInput: Clone + TaskData;
77 type TaskOutput: TaskData;
78 type TaskError: TaskData;
79
80 fn process_task(
81 &mut self,
82 id: Uuid,
83 input: Self::TaskInput,
84 ) -> Result<Self::TaskOutput, Self::TaskError>;
85}
86
87pub trait TaskSink: Send {
88 type TaskOutput: TaskData;
89 type TaskError: TaskData;
90
91 type Error: std::error::Error + Send + Sync + 'static;
92
93 fn process_task_result(
94 &mut self,
95 id: Uuid,
96 result: Result<Self::TaskOutput, Self::TaskError>,
97 ) -> Result<(), Self::Error>;
98}
99
100pub fn run<T, U, E>(
101 source: &mut impl TaskSource<TaskInput = T>,
102 processor: &mut impl TaskProcessor<TaskInput = T, TaskOutput = U, TaskError = E>,
103 sink: &mut impl TaskSink<TaskOutput = U, TaskError = E>,
104) -> Result<(), Error>
105where
106 T: Clone + TaskData,
107 U: TaskData,
108 E: TaskData,
109{
110 let (result_tx, result_rx) = flume::unbounded::<TaskResult<U, E>>();
111 let (child_tx, child_rx) = flume::unbounded::<IpcSender<Option<TaskItem<T>>>>();
112
113 let mut rxs = IpcReceiverSet::new()?;
114 let mut txs = BTreeMap::new();
115
116 for _ in 0..num_cpus::get() {
117 rand::thread_rng().next_u64();
126
127 let (psrv, pn) = IpcOneShotServer::<(String, IpcReceiver<Option<U>>)>::new()?;
128
129 match unsafe { fork() }.map_err(Error::Fork)? {
130 ForkResult::Child => {
131 let (csrv, cn) = IpcOneShotServer::<IpcReceiver<Option<TaskItem<T>>>>::new()
135 .map_err(Error::Io)?;
136
137 let results = IpcSender::connect(pn)?;
138
139 let (results_tx, results_rx) = ipc::channel::<Option<Vec<u8>>>()?;
140
141 results.send((cn, results_rx))?;
142
143 let (_, task_rx) = csrv.accept()?;
144
145 while let Ok(Some(task)) = task_rx.recv() {
148 let results = TaskResult {
149 id: task.id,
150 result: processor.process_task(task.id, task.payload),
151 };
152 results_tx.send(rmp_serde::to_vec(&results).ok())?;
153 }
154
155 exit(0);
156 }
157 ForkResult::Parent { .. } => {
158 let (_, (cn, rx)) = psrv.accept()?;
159
160 let tx = IpcSender::<IpcReceiver<Option<TaskItem<T>>>>::connect(cn)?;
161
162 let (ctx, crx) = ipc::channel()?;
163
164 tx.send(crx)?;
165
166 if !child_tx.send(ctx.clone()).is_err() {
167 let child = rxs.add(rx)?;
168 txs.insert(child, Some(ctx));
169 }
170 }
171 }
172 }
173
174 let (task_tx, task_rx) = flume::bounded::<TaskItem<T>>(num_cpus::get());
175
176 let (rs, r) = Parallel::new()
177 .add(move || {
178 let mut id = Uuid::now_v7();
179 while let Some(payload) = source.next_task(id).map_err(Error::source)? {
180 task_tx
181 .send(TaskItem { id, payload })
182 .map_err(|_| Error::ChannelSend)?;
183 id = Uuid::now_v7();
184 }
185
186 drop(task_tx);
187
188 Ok::<(), Error>(())
189 })
190 .add(move || {
191 while let Ok(TaskResult { id, result }) = result_rx.recv() {
192 sink.process_task_result(id, result).map_err(Error::sink)?;
193 }
194
195 Ok::<(), Error>(())
196 })
197 .add(move || {
198 while let Ok(task) = task_rx.recv() {
199 loop {
200 let child = child_rx.recv()?;
201
202 if let Err(_) = child.send(Some(task.clone())) {
203 continue;
205 }
206
207 break;
208 }
209 }
210
211 for r in child_rx.drain() {
212 r.send(None).ok();
213 drop(r);
214 }
215
216 drop(child_rx);
217 drop(task_rx);
218
219 Ok(())
220 })
221 .finish(move || {
222 while let Ok(events) = rxs.select() {
223 for event in events {
224 match event {
225 IpcSelectionResult::ChannelClosed(ref child) => {
226 txs.remove(child);
227 }
228 IpcSelectionResult::MessageReceived(ref child, msg) => {
229 let result = msg
230 .to::<Option<Vec<u8>>>()?
231 .map(|v| rmp_serde::from_slice::<TaskResult<U, E>>(&v))
232 .transpose()?;
233
234 if let Some(ref child) = txs[child] {
235 if child_tx.send(child.clone()).is_err() {
236 txs.values_mut().for_each(|ch| {
239 if let Some(ch) = ch.take() {
240 ch.send(None).ok();
241 }
242 });
243 }
244 }
245
246 if let Some(result) = result {
247 result_tx.send(result).map_err(|_| Error::ChannelSend)?;
248 }
249 }
250 }
251 }
252
253 if txs.is_empty() {
254 break;
256 }
257 }
258
259 drop(child_tx);
260 drop(txs);
261 drop(rxs);
262
263 Ok::<(), Error>(())
264 });
265
266 let mut errs = Vec::new();
267
268 if let Err(e) = r {
269 errs.push(e);
270 }
271
272 errs.extend(rs.into_iter().filter_map(Result::err));
273
274 if errs.is_empty() {
275 Ok(())
276 } else {
277 Err(Error::Process(errs))
278 }
279}
280
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use super::*;

    /// End-to-end smoke test: pushes 1 000 string tasks through `run` and
    /// checks that the sink observed 1 000 successful results.
    #[test]
    fn test_simple() -> Result<(), Error> {
        /// Output produced by the processor for every task.
        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        struct Payload {
            summary: String,
            count: usize,
        }

        /// Emits `limit` tasks, then reports exhaustion.
        struct Source {
            counter: usize,
            limit: usize,
        }

        impl TaskSource for Source {
            type TaskInput = String;

            type Error = Infallible;

            fn next_task(&mut self, id: Uuid) -> Result<Option<String>, Self::Error> {
                if self.counter < self.limit {
                    self.counter += 1;
                    Ok(Some(format!("{}:{id}", self.counter)))
                } else {
                    Ok(None)
                }
            }
        }

        /// Wraps the input in a greeting and counts the calls it has seen.
        struct Process {
            counter: usize,
        }

        impl TaskProcessor for Process {
            type TaskInput = String;
            type TaskOutput = Payload;
            type TaskError = String;

            fn process_task(&mut self, _id: Uuid, payload: String) -> Result<Payload, String> {
                self.counter += 1;

                let summary = format!("hello, task:{payload}");
                Ok(Payload {
                    summary,
                    count: self.counter,
                })
            }
        }

        /// Counts successful results and prints every outcome.
        struct Sink {
            counter: usize,
        }

        impl TaskSink for Sink {
            type TaskOutput = Payload;
            type TaskError = String;

            type Error = Infallible;

            fn process_task_result(
                &mut self,
                id: Uuid,
                result: Result<Payload, String>,
            ) -> Result<(), Self::Error> {
                match result {
                    Err(v) => {
                        println!("task {id} failed (reason: {v})")
                    }
                    Ok(Payload { summary, count }) => {
                        self.counter += 1;
                        println!("task {id} processed successfully (payload: {summary} / {count})")
                    }
                }

                Ok(())
            }
        }

        let mut process = Process { counter: 0 };
        let mut sink = Sink { counter: 0 };
        let mut source = Source {
            counter: 0,
            limit: 1_000,
        };

        run(&mut source, &mut process, &mut sink)?;

        assert_eq!(sink.counter, 1_000);

        Ok(())
    }
}