nu_command/filters/
par_each.rs

1use super::utils::chain_error_with_input;
2use nu_engine::{ClosureEvalOnce, command_prelude::*};
3use nu_protocol::{Signals, engine::Closure};
4use rayon::prelude::*;
5
/// Unit type that implements the `par-each` command (see its `Command` impl in this file).
#[derive(Clone)]
pub struct ParEach;
8
9impl Command for ParEach {
10    fn name(&self) -> &str {
11        "par-each"
12    }
13
14    fn description(&self) -> &str {
15        "Run a closure on each row of the input list in parallel, creating a new list with the results."
16    }
17
18    fn signature(&self) -> nu_protocol::Signature {
19        Signature::build("par-each")
20            .input_output_types(vec![
21                (
22                    Type::List(Box::new(Type::Any)),
23                    Type::List(Box::new(Type::Any)),
24                ),
25                (Type::table(), Type::List(Box::new(Type::Any))),
26                (Type::Any, Type::Any),
27            ])
28            .named(
29                "threads",
30                SyntaxShape::Int,
31                "the number of threads to use",
32                Some('t'),
33            )
34            .switch(
35                "keep-order",
36                "keep sequence of output same as the order of input",
37                Some('k'),
38            )
39            .required(
40                "closure",
41                SyntaxShape::Closure(Some(vec![SyntaxShape::Any])),
42                "The closure to run.",
43            )
44            .allow_variants_without_examples(true)
45            .category(Category::Filters)
46    }
47
48    fn examples(&self) -> Vec<Example<'_>> {
49        vec![
50            Example {
51                example: "[1 2 3] | par-each {|e| $e * 2 }",
52                description: "Multiplies each number. Note that the list will become arbitrarily disordered.",
53                result: None,
54            },
55            Example {
56                example: r#"[1 2 3] | par-each --keep-order {|e| $e * 2 }"#,
57                description: "Multiplies each number, keeping an original order",
58                result: Some(Value::test_list(vec![
59                    Value::test_int(2),
60                    Value::test_int(4),
61                    Value::test_int(6),
62                ])),
63            },
64            Example {
65                example: r#"1..3 | enumerate | par-each {|p| update item ($p.item * 2)} | sort-by item | get item"#,
66                description: "Enumerate and sort-by can be used to reconstruct the original order",
67                result: Some(Value::test_list(vec![
68                    Value::test_int(2),
69                    Value::test_int(4),
70                    Value::test_int(6),
71                ])),
72            },
73            Example {
74                example: r#"[foo bar baz] | par-each {|e| $e + '!' } | sort"#,
75                description: "Output can still be sorted afterward",
76                result: Some(Value::test_list(vec![
77                    Value::test_string("bar!"),
78                    Value::test_string("baz!"),
79                    Value::test_string("foo!"),
80                ])),
81            },
82            Example {
83                example: r#"[1 2 3] | enumerate | par-each { |e| if $e.item == 2 { $"found 2 at ($e.index)!"} }"#,
84                description: "Iterate over each element, producing a list showing indexes of any 2s",
85                result: Some(Value::test_list(vec![Value::test_string("found 2 at 1!")])),
86            },
87        ]
88    }
89
90    fn run(
91        &self,
92        engine_state: &EngineState,
93        stack: &mut Stack,
94        call: &Call,
95        input: PipelineData,
96    ) -> Result<PipelineData, ShellError> {
97        fn create_pool(num_threads: usize) -> Result<rayon::ThreadPool, ShellError> {
98            match rayon::ThreadPoolBuilder::new()
99                .num_threads(num_threads)
100                .build()
101            {
102                Err(e) => Err(e).map_err(|e| ShellError::GenericError {
103                    error: "Error creating thread pool".into(),
104                    msg: e.to_string(),
105                    span: Some(Span::unknown()),
106                    help: None,
107                    inner: vec![],
108                }),
109                Ok(pool) => Ok(pool),
110            }
111        }
112
113        let head = call.head;
114        let closure: Closure = call.req(engine_state, stack, 0)?;
115        let threads: Option<usize> = call.get_flag(engine_state, stack, "threads")?;
116        let max_threads = threads.unwrap_or(0);
117        let keep_order = call.has_flag(engine_state, stack, "keep-order")?;
118
119        let metadata = input.metadata();
120
121        // A helper function sorts the output if needed
122        let apply_order = |mut vec: Vec<(usize, Value)>| {
123            if keep_order {
124                // It runs inside the rayon's thread pool so parallel sorting can be used.
125                // There are no identical indexes, so unstable sorting can be used.
126                vec.par_sort_unstable_by_key(|(index, _)| *index);
127            }
128
129            vec.into_iter().map(|(_, val)| val)
130        };
131
132        match input {
133            PipelineData::Empty => Ok(PipelineData::empty()),
134            PipelineData::Value(value, ..) => {
135                let span = value.span();
136                match value {
137                    Value::List { vals, .. } => Ok(create_pool(max_threads)?.install(|| {
138                        let vec = vals
139                            .into_par_iter()
140                            .enumerate()
141                            .map(move |(index, value)| {
142                                let span = value.span();
143                                let is_error = value.is_error();
144                                let value =
145                                    ClosureEvalOnce::new(engine_state, stack, closure.clone())
146                                        .run_with_value(value)
147                                        .and_then(|data| data.into_value(span))
148                                        .unwrap_or_else(|err| {
149                                            Value::error(
150                                                chain_error_with_input(err, is_error, span),
151                                                span,
152                                            )
153                                        });
154
155                                (index, value)
156                            })
157                            .collect::<Vec<_>>();
158
159                        apply_order(vec).into_pipeline_data(span, engine_state.signals().clone())
160                    })),
161                    Value::Range { val, .. } => Ok(create_pool(max_threads)?.install(|| {
162                        let vec = val
163                            .into_range_iter(span, Signals::empty())
164                            .enumerate()
165                            .par_bridge()
166                            .map(move |(index, value)| {
167                                let span = value.span();
168                                let is_error = value.is_error();
169                                let value =
170                                    ClosureEvalOnce::new(engine_state, stack, closure.clone())
171                                        .run_with_value(value)
172                                        .and_then(|data| data.into_value(span))
173                                        .unwrap_or_else(|err| {
174                                            Value::error(
175                                                chain_error_with_input(err, is_error, span),
176                                                span,
177                                            )
178                                        });
179
180                                (index, value)
181                            })
182                            .collect::<Vec<_>>();
183
184                        apply_order(vec).into_pipeline_data(span, engine_state.signals().clone())
185                    })),
186                    // This match allows non-iterables to be accepted,
187                    // which is currently considered undesirable (Nov 2022).
188                    value => {
189                        ClosureEvalOnce::new(engine_state, stack, closure).run_with_value(value)
190                    }
191                }
192            }
193            PipelineData::ListStream(stream, ..) => Ok(create_pool(max_threads)?.install(|| {
194                let vec = stream
195                    .into_iter()
196                    .enumerate()
197                    .par_bridge()
198                    .map(move |(index, value)| {
199                        let span = value.span();
200                        let is_error = value.is_error();
201                        let value = ClosureEvalOnce::new(engine_state, stack, closure.clone())
202                            .run_with_value(value)
203                            .and_then(|data| data.into_value(head))
204                            .unwrap_or_else(|err| {
205                                Value::error(chain_error_with_input(err, is_error, span), span)
206                            });
207
208                        (index, value)
209                    })
210                    .collect::<Vec<_>>();
211
212                apply_order(vec).into_pipeline_data(head, engine_state.signals().clone())
213            })),
214            PipelineData::ByteStream(stream, ..) => {
215                if let Some(chunks) = stream.chunks() {
216                    Ok(create_pool(max_threads)?.install(|| {
217                        let vec = chunks
218                            .enumerate()
219                            .par_bridge()
220                            .map(move |(index, value)| {
221                                let value = match value {
222                                    Ok(value) => value,
223                                    Err(err) => return (index, Value::error(err, head)),
224                                };
225
226                                let value =
227                                    ClosureEvalOnce::new(engine_state, stack, closure.clone())
228                                        .run_with_value(value)
229                                        .and_then(|data| data.into_value(head))
230                                        .unwrap_or_else(|err| Value::error(err, head));
231
232                                (index, value)
233                            })
234                            .collect::<Vec<_>>();
235
236                        apply_order(vec).into_pipeline_data(head, engine_state.signals().clone())
237                    }))
238                } else {
239                    Ok(PipelineData::empty())
240                }
241            }
242        }
243        .and_then(|x| x.filter(|v| !v.is_nothing(), engine_state.signals()))
244        .map(|data| data.set_metadata(metadata))
245    }
246}
247
#[cfg(test)]
mod test {
    use super::ParEach;

    /// Passes the command to the crate's `test_examples` helper.
    #[test]
    fn test_examples() {
        // Imported locally: a module-level import would clash with this function's name.
        use crate::test_examples;

        test_examples(ParEach);
    }
}
258}