//! `par-each` — nu_command/filters/par_each.rs
1use super::utils::chain_error_with_input;
2use nu_engine::{ClosureEval, ClosureEvalOnce, command_prelude::*};
3use nu_protocol::{Signals, engine::Closure, shell_error::generic::GenericError};
4use rayon::prelude::*;
5use std::{
6    sync::mpsc::{self, RecvTimeoutError},
7    time::Duration,
8};
9
// Capacity of the bounded channel between the rayon workers and the consumer;
// workers block on send once this many results are waiting to be read.
const STREAM_BUFFER_SIZE: usize = 64;
// How long a blocked channel receive waits before re-checking for Ctrl+C.
const CTRL_C_CHECK_INTERVAL: Duration = Duration::from_millis(100);
12
/// The `par-each` filter command: runs a closure over each input row in parallel.
#[derive(Clone)]
pub struct ParEach;
15
16impl Command for ParEach {
17    fn name(&self) -> &str {
18        "par-each"
19    }
20
21    fn description(&self) -> &str {
22        "Run a closure on each row of the input list in parallel, creating a new list with the results."
23    }
24
25    fn signature(&self) -> nu_protocol::Signature {
26        Signature::build("par-each")
27            .input_output_types(vec![
28                (
29                    Type::List(Box::new(Type::Any)),
30                    Type::List(Box::new(Type::Any)),
31                ),
32                (Type::table(), Type::List(Box::new(Type::Any))),
33                (Type::Any, Type::Any),
34            ])
35            .named(
36                "threads",
37                SyntaxShape::Int,
38                "The number of threads to use.",
39                Some('t'),
40            )
41            .switch(
42                "keep-order",
43                "Keep sequence of output same as the order of input.",
44                Some('k'),
45            )
46            .required(
47                "closure",
48                SyntaxShape::Closure(Some(vec![SyntaxShape::Any])),
49                "The closure to run.",
50            )
51            .allow_variants_without_examples(true)
52            .category(Category::Filters)
53    }
54
55    fn examples(&self) -> Vec<Example<'_>> {
56        vec![
57            Example {
58                example: "[1 2 3] | par-each {|e| $e * 2 }",
59                description: "Multiplies each number. Note that the list will become arbitrarily disordered.",
60                result: None,
61            },
62            Example {
63                example: "[1 2 3] | par-each --keep-order {|e| $e * 2 }",
64                description: "Multiplies each number, keeping an original order.",
65                result: Some(Value::test_list(vec![
66                    Value::test_int(2),
67                    Value::test_int(4),
68                    Value::test_int(6),
69                ])),
70            },
71            Example {
72                example: "1..3 | enumerate | par-each {|p| update item ($p.item * 2)} | sort-by item | get item",
73                description: "Enumerate and sort-by can be used to reconstruct the original order.",
74                result: Some(Value::test_list(vec![
75                    Value::test_int(2),
76                    Value::test_int(4),
77                    Value::test_int(6),
78                ])),
79            },
80            Example {
81                example: "[foo bar baz] | par-each {|e| $e + '!' } | sort",
82                description: "Output can still be sorted afterward.",
83                result: Some(Value::test_list(vec![
84                    Value::test_string("bar!"),
85                    Value::test_string("baz!"),
86                    Value::test_string("foo!"),
87                ])),
88            },
89            Example {
90                example: r#"[1 2 3] | enumerate | par-each { |e| if $e.item == 2 { $"found 2 at ($e.index)!"} }"#,
91                description: "Iterate over each element, producing a list showing indexes of any 2s.",
92                result: Some(Value::test_list(vec![Value::test_string("found 2 at 1!")])),
93            },
94        ]
95    }
96
97    fn run(
98        &self,
99        engine_state: &EngineState,
100        stack: &mut Stack,
101        call: &Call,
102        input: PipelineData,
103    ) -> Result<PipelineData, ShellError> {
104        fn create_pool(num_threads: usize, head: Span) -> Result<rayon::ThreadPool, ShellError> {
105            match rayon::ThreadPoolBuilder::new()
106                .num_threads(num_threads)
107                .build()
108            {
109                Err(e) => Err(e).map_err(|e| {
110                    ShellError::Generic(GenericError::new(
111                        "Error creating thread pool",
112                        e.to_string(),
113                        head,
114                    ))
115                }),
116                Ok(pool) => Ok(pool),
117            }
118        }
119
120        let head = call.head;
121        let closure: Closure = call.req(engine_state, stack, 0)?;
122        let threads: Option<usize> = call.get_flag(engine_state, stack, "threads")?;
123        let max_threads = threads.unwrap_or(0);
124        let keep_order = call.has_flag(engine_state, stack, "keep-order")?;
125        let signals = engine_state.signals().clone();
126
127        let mut input = input.into_stream_or_original(engine_state);
128        let metadata = input.take_metadata();
129
130        // A helper function sorts the output if needed
131        let apply_order = |mut vec: Vec<(usize, Value)>| {
132            if keep_order {
133                // It runs inside the rayon's thread pool so parallel sorting can be used.
134                // There are no identical indexes, so unstable sorting can be used.
135                vec.par_sort_unstable_by_key(|(index, _)| *index);
136            }
137
138            vec.into_iter().map(|(_, val)| val)
139        };
140
141        match input {
142            PipelineData::Empty => Ok(PipelineData::empty()),
143            PipelineData::Value(value, ..) => {
144                let span = value.span();
145                match value {
146                    Value::List { vals, .. } => {
147                        let pool = create_pool(max_threads, head)?;
148                        if keep_order {
149                            Ok(pool.install(|| {
150                                let par_iter = vals.into_par_iter().enumerate();
151                                let mapped =
152                                    parallel_closure_map(engine_state, stack, &closure, par_iter);
153                                apply_order(mapped.collect())
154                                    .into_pipeline_data(span, signals.clone())
155                            }))
156                        } else {
157                            let par_iter = vals.into_par_iter();
158                            Ok(stream_parallel_values(
159                                engine_state,
160                                stack,
161                                closure.clone(),
162                                pool,
163                                span,
164                                signals.clone(),
165                                par_iter,
166                            ))
167                        }
168                    }
169                    Value::Range { val, .. } => {
170                        let pool = create_pool(max_threads, head)?;
171                        if keep_order {
172                            Ok(pool.install(|| {
173                                let par_iter = val
174                                    .into_range_iter(span, signals.clone())
175                                    .enumerate()
176                                    .par_bridge();
177                                let mapped =
178                                    parallel_closure_map(engine_state, stack, &closure, par_iter);
179                                apply_order(mapped.collect())
180                                    .into_pipeline_data(span, signals.clone())
181                            }))
182                        } else {
183                            let par_iter = val.into_range_iter(span, signals.clone()).par_bridge();
184                            Ok(stream_parallel_values(
185                                engine_state,
186                                stack,
187                                closure.clone(),
188                                pool,
189                                span,
190                                signals.clone(),
191                                par_iter,
192                            ))
193                        }
194                    }
195                    // This match allows non-iterables to be accepted,
196                    // which is currently considered undesirable (Nov 2022).
197                    value => {
198                        ClosureEvalOnce::new(engine_state, stack, closure).run_with_value(value)
199                    }
200                }
201            }
202            PipelineData::ListStream(stream, ..) => {
203                let pool = create_pool(max_threads, head)?;
204                if keep_order {
205                    Ok(pool.install(|| {
206                        let par_iter = stream.into_iter().enumerate().par_bridge();
207                        let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
208
209                        apply_order(mapped.collect()).into_pipeline_data(head, signals.clone())
210                    }))
211                } else {
212                    let par_iter = stream.into_iter().par_bridge();
213                    Ok(stream_parallel_values(
214                        engine_state,
215                        stack,
216                        closure.clone(),
217                        pool,
218                        head,
219                        signals.clone(),
220                        par_iter,
221                    ))
222                }
223            }
224            PipelineData::ByteStream(stream, ..) => {
225                if let Some(chunks) = stream.chunks() {
226                    let pool = create_pool(max_threads, head)?;
227                    if keep_order {
228                        Ok(pool.install(|| {
229                            let par_iter = chunks
230                                .enumerate()
231                                .map(move |(idx, val)| {
232                                    (idx, val.unwrap_or_else(|err| Value::error(err, head)))
233                                })
234                                .par_bridge();
235                            let mapped =
236                                parallel_closure_map(engine_state, stack, &closure, par_iter);
237                            apply_order(mapped.collect()).into_pipeline_data(head, signals.clone())
238                        }))
239                    } else {
240                        let par_iter = chunks
241                            .map(move |val| val.unwrap_or_else(|err| Value::error(err, head)))
242                            .par_bridge();
243                        Ok(stream_parallel_values(
244                            engine_state,
245                            stack,
246                            closure.clone(),
247                            pool,
248                            head,
249                            signals.clone(),
250                            par_iter,
251                        ))
252                    }
253                } else {
254                    Ok(PipelineData::empty())
255                }
256            }
257        }
258        .and_then(|x| x.filter(|v| !v.is_nothing(), engine_state.signals()))
259        .map(|data| data.set_metadata(metadata))
260    }
261}
262
/// Evaluates `closure` over `input` on `pool`, streaming each result to the
/// returned `PipelineData` as soon as a worker produces it. Output order is
/// therefore nondeterministic (no `--keep-order` handling here).
///
/// The rayon work runs on a detached `spawn`, so this function returns
/// immediately with a stream backed by the channel receiver.
fn stream_parallel_values(
    engine_state: &EngineState,
    stack: &Stack,
    closure: Closure,
    pool: rayon::ThreadPool,
    span: Span,
    signals: Signals,
    input: impl ParallelIterator<Item = Value> + 'static,
) -> PipelineData {
    // Bounded channel applies backpressure: workers block on send once the
    // consumer falls STREAM_BUFFER_SIZE values behind.
    let (tx, rx) = mpsc::sync_channel(STREAM_BUFFER_SIZE);
    // Owned clones so the spawned job can outlive this stack frame.
    let worker_engine_state = engine_state.clone();
    let worker_stack = stack.clone();
    let worker_signals = signals.clone();

    pool.install(|| {
        // `spawn` detaches the job onto the pool, so `install` returns
        // right away instead of blocking on the whole computation.
        rayon::spawn(move || {
            let map_signals = worker_signals.clone();
            let send_signals = worker_signals.clone();

            // The Result is discarded: Err(()) only signals early shutdown
            // (interrupt, or the receiver was dropped downstream).
            let _ = input
                .map_init(
                    // One ClosureEval per rayon worker thread.
                    move || ClosureEval::new(&worker_engine_state, &worker_stack, closure.clone()),
                    move |closure_eval, value| {
                        // Check before and after evaluation so an interrupt
                        // stops workers promptly.
                        if map_signals.interrupted() {
                            return Err(());
                        }

                        let value = run_closure_on_value(closure_eval, value);

                        if map_signals.interrupted() {
                            Err(())
                        } else {
                            Ok(value)
                        }
                    },
                )
                .try_for_each(move |value| match value {
                    Ok(value) => {
                        if send_signals.interrupted() {
                            Err(())
                        } else {
                            // send fails only when the receiver is gone;
                            // treat that as cancellation.
                            tx.send(value).map_err(|_| ())
                        }
                    }
                    Err(()) => Err(()),
                });
        });
    });

    // ReceiverIter polls `signals` itself, so the pipeline wrapper is given
    // Signals::empty() to avoid double interrupt checks.
    ReceiverIter::new(rx, signals).into_pipeline_data(span, Signals::empty())
}
314
// Polls channel reads so Ctrl+C can stop blocked receives promptly.
struct ReceiverIter {
    // Receiving end of the bounded channel fed by the rayon workers.
    receiver: mpsc::Receiver<Value>,
    // Checked on every poll so an interrupt ends iteration early.
    signals: Signals,
}
320
321impl ReceiverIter {
322    fn new(receiver: mpsc::Receiver<Value>, signals: Signals) -> Self {
323        Self { receiver, signals }
324    }
325}
326
327impl Iterator for ReceiverIter {
328    type Item = Value;
329
330    fn next(&mut self) -> Option<Self::Item> {
331        loop {
332            if self.signals.interrupted() {
333                return None;
334            }
335
336            match self.receiver.recv_timeout(CTRL_C_CHECK_INTERVAL) {
337                Ok(value) => return Some(value),
338                Err(RecvTimeoutError::Timeout) => {}
339                Err(RecvTimeoutError::Disconnected) => return None,
340            }
341        }
342    }
343}
344
345fn run_closure_on_value(closure_eval: &mut ClosureEval, value: Value) -> Value {
346    let span = value.span();
347    let is_error = value.is_error();
348
349    closure_eval
350        .run_with_value(value)
351        .and_then(|data| data.into_value(span))
352        .unwrap_or_else(|err| Value::error(chain_error_with_input(err, is_error, span), span))
353}
354
/// Maps `closure` over an index-tagged parallel iterator, keeping each result
/// paired with its original index so the caller can restore input order.
fn parallel_closure_map(
    engine_state: &EngineState,
    stack: &mut Stack,
    closure: &Closure,
    input: impl ParallelIterator<Item = (usize, Value)>,
) -> impl ParallelIterator<Item = (usize, Value)> {
    input.map_init(
        // One ClosureEval per rayon worker thread.
        move || ClosureEval::new(engine_state, stack, closure.clone()),
        |closure_eval, (index, value)| {
            // Errors are folded into the Value by run_closure_on_value, so a
            // single failing row never short-circuits the whole map.
            let value = run_closure_on_value(closure_eval, value);

            (index, value)
        },
    )
}
370
#[cfg(test)]
mod test {
    use super::*;

    // Verifies the command's documented examples via the shared test harness.
    #[test]
    fn test_examples() -> nu_test_support::Result {
        nu_test_support::test().examples(ParEach)
    }
}