Skip to main content

nu_command/filters/
par_each.rs

1use super::utils::chain_error_with_input;
2use nu_engine::{ClosureEval, ClosureEvalOnce, command_prelude::*};
3use nu_protocol::{Signals, engine::Closure};
4use rayon::prelude::*;
5
/// The `par-each` command: runs a closure over every input element on a rayon
/// thread pool and gathers the results into a list.
#[derive(Clone)]
pub struct ParEach;
8
9impl Command for ParEach {
10    fn name(&self) -> &str {
11        "par-each"
12    }
13
14    fn description(&self) -> &str {
15        "Run a closure on each row of the input list in parallel, creating a new list with the results."
16    }
17
18    fn signature(&self) -> nu_protocol::Signature {
19        Signature::build("par-each")
20            .input_output_types(vec![
21                (
22                    Type::List(Box::new(Type::Any)),
23                    Type::List(Box::new(Type::Any)),
24                ),
25                (Type::table(), Type::List(Box::new(Type::Any))),
26                (Type::Any, Type::Any),
27            ])
28            .named(
29                "threads",
30                SyntaxShape::Int,
31                "The number of threads to use.",
32                Some('t'),
33            )
34            .switch(
35                "keep-order",
36                "Keep sequence of output same as the order of input.",
37                Some('k'),
38            )
39            .required(
40                "closure",
41                SyntaxShape::Closure(Some(vec![SyntaxShape::Any])),
42                "The closure to run.",
43            )
44            .allow_variants_without_examples(true)
45            .category(Category::Filters)
46    }
47
48    fn examples(&self) -> Vec<Example<'_>> {
49        vec![
50            Example {
51                example: "[1 2 3] | par-each {|e| $e * 2 }",
52                description: "Multiplies each number. Note that the list will become arbitrarily disordered.",
53                result: None,
54            },
55            Example {
56                example: r#"[1 2 3] | par-each --keep-order {|e| $e * 2 }"#,
57                description: "Multiplies each number, keeping an original order.",
58                result: Some(Value::test_list(vec![
59                    Value::test_int(2),
60                    Value::test_int(4),
61                    Value::test_int(6),
62                ])),
63            },
64            Example {
65                example: r#"1..3 | enumerate | par-each {|p| update item ($p.item * 2)} | sort-by item | get item"#,
66                description: "Enumerate and sort-by can be used to reconstruct the original order.",
67                result: Some(Value::test_list(vec![
68                    Value::test_int(2),
69                    Value::test_int(4),
70                    Value::test_int(6),
71                ])),
72            },
73            Example {
74                example: r#"[foo bar baz] | par-each {|e| $e + '!' } | sort"#,
75                description: "Output can still be sorted afterward.",
76                result: Some(Value::test_list(vec![
77                    Value::test_string("bar!"),
78                    Value::test_string("baz!"),
79                    Value::test_string("foo!"),
80                ])),
81            },
82            Example {
83                example: r#"[1 2 3] | enumerate | par-each { |e| if $e.item == 2 { $"found 2 at ($e.index)!"} }"#,
84                description: "Iterate over each element, producing a list showing indexes of any 2s.",
85                result: Some(Value::test_list(vec![Value::test_string("found 2 at 1!")])),
86            },
87        ]
88    }
89
90    fn run(
91        &self,
92        engine_state: &EngineState,
93        stack: &mut Stack,
94        call: &Call,
95        input: PipelineData,
96    ) -> Result<PipelineData, ShellError> {
97        fn create_pool(num_threads: usize, head: Span) -> Result<rayon::ThreadPool, ShellError> {
98            match rayon::ThreadPoolBuilder::new()
99                .num_threads(num_threads)
100                .build()
101            {
102                Err(e) => Err(e).map_err(|e| ShellError::GenericError {
103                    error: "Error creating thread pool".into(),
104                    msg: e.to_string(),
105                    span: Some(head),
106                    help: None,
107                    inner: vec![],
108                }),
109                Ok(pool) => Ok(pool),
110            }
111        }
112
113        let head = call.head;
114        let closure: Closure = call.req(engine_state, stack, 0)?;
115        let threads: Option<usize> = call.get_flag(engine_state, stack, "threads")?;
116        let max_threads = threads.unwrap_or(0);
117        let keep_order = call.has_flag(engine_state, stack, "keep-order")?;
118
119        let input = match input.try_into_stream(engine_state) {
120            Ok(input) | Err(input) => input,
121        };
122        let metadata = input.metadata();
123
124        // A helper function sorts the output if needed
125        let apply_order = |mut vec: Vec<(usize, Value)>| {
126            if keep_order {
127                // It runs inside the rayon's thread pool so parallel sorting can be used.
128                // There are no identical indexes, so unstable sorting can be used.
129                vec.par_sort_unstable_by_key(|(index, _)| *index);
130            }
131
132            vec.into_iter().map(|(_, val)| val)
133        };
134
135        match input {
136            PipelineData::Empty => Ok(PipelineData::empty()),
137            PipelineData::Value(value, ..) => {
138                let span = value.span();
139                match value {
140                    Value::List { vals, .. } => Ok(create_pool(max_threads, head)?.install(|| {
141                        let par_iter = vals.into_par_iter().enumerate();
142                        let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
143                        apply_order(mapped.collect())
144                            .into_pipeline_data(span, engine_state.signals().clone())
145                    })),
146                    Value::Range { val, .. } => Ok(create_pool(max_threads, head)?.install(|| {
147                        let par_iter = val
148                            .into_range_iter(span, Signals::empty())
149                            .enumerate()
150                            .par_bridge();
151                        let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
152                        apply_order(mapped.collect())
153                            .into_pipeline_data(span, engine_state.signals().clone())
154                    })),
155                    // This match allows non-iterables to be accepted,
156                    // which is currently considered undesirable (Nov 2022).
157                    value => {
158                        ClosureEvalOnce::new(engine_state, stack, closure).run_with_value(value)
159                    }
160                }
161            }
162            PipelineData::ListStream(stream, ..) => {
163                Ok(create_pool(max_threads, head)?.install(|| {
164                    let par_iter = stream.into_iter().enumerate().par_bridge();
165                    let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
166
167                    apply_order(mapped.collect())
168                        .into_pipeline_data(head, engine_state.signals().clone())
169                }))
170            }
171            PipelineData::ByteStream(stream, ..) => {
172                if let Some(chunks) = stream.chunks() {
173                    Ok(create_pool(max_threads, head)?.install(|| {
174                        let par_iter = chunks
175                            .enumerate()
176                            .map(move |(idx, val)| {
177                                (idx, val.unwrap_or_else(|err| Value::error(err, head)))
178                            })
179                            .par_bridge();
180                        let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
181                        apply_order(mapped.collect())
182                            .into_pipeline_data(head, engine_state.signals().clone())
183                    }))
184                } else {
185                    Ok(PipelineData::empty())
186                }
187            }
188        }
189        .and_then(|x| x.filter(|v| !v.is_nothing(), engine_state.signals()))
190        .map(|data| data.set_metadata(metadata))
191    }
192}
193
194fn parallel_closure_map(
195    engine_state: &EngineState,
196    stack: &mut Stack,
197    closure: &Closure,
198    input: impl ParallelIterator<Item = (usize, Value)>,
199) -> impl ParallelIterator<Item = (usize, Value)> {
200    input.map_init(
201        move || ClosureEval::new(engine_state, stack, closure.clone()),
202        |closure_eval, (index, value)| {
203            let span = value.span();
204            let is_error = value.is_error();
205            let value = closure_eval
206                .run_with_value(value)
207                .and_then(|data| data.into_value(span))
208                .unwrap_or_else(|err| {
209                    Value::error(chain_error_with_input(err, is_error, span), span)
210                });
211
212            (index, value)
213        },
214    )
215}
216
#[cfg(test)]
mod test {
    use super::*;

    /// Runs every example declared by `ParEach::examples` through the shared
    /// example harness and checks the declared results.
    #[test]
    fn test_examples() {
        // Imported locally: a module-level import would collide with this fn's name.
        use crate::test_examples;

        let command = ParEach;
        test_examples(command)
    }
}