1use super::utils::chain_error_with_input;
2use nu_engine::{ClosureEval, ClosureEvalOnce, command_prelude::*};
3use nu_protocol::{Signals, engine::Closure};
4use rayon::prelude::*;
5
/// Marker type for the `par-each` command.
///
/// It holds no state; the command's behavior is supplied by its `Command` implementation.
#[derive(Clone)]
pub struct ParEach;
8
9impl Command for ParEach {
10 fn name(&self) -> &str {
11 "par-each"
12 }
13
14 fn description(&self) -> &str {
15 "Run a closure on each row of the input list in parallel, creating a new list with the results."
16 }
17
18 fn signature(&self) -> nu_protocol::Signature {
19 Signature::build("par-each")
20 .input_output_types(vec![
21 (
22 Type::List(Box::new(Type::Any)),
23 Type::List(Box::new(Type::Any)),
24 ),
25 (Type::table(), Type::List(Box::new(Type::Any))),
26 (Type::Any, Type::Any),
27 ])
28 .named(
29 "threads",
30 SyntaxShape::Int,
31 "The number of threads to use.",
32 Some('t'),
33 )
34 .switch(
35 "keep-order",
36 "Keep sequence of output same as the order of input.",
37 Some('k'),
38 )
39 .required(
40 "closure",
41 SyntaxShape::Closure(Some(vec![SyntaxShape::Any])),
42 "The closure to run.",
43 )
44 .allow_variants_without_examples(true)
45 .category(Category::Filters)
46 }
47
48 fn examples(&self) -> Vec<Example<'_>> {
49 vec![
50 Example {
51 example: "[1 2 3] | par-each {|e| $e * 2 }",
52 description: "Multiplies each number. Note that the list will become arbitrarily disordered.",
53 result: None,
54 },
55 Example {
56 example: r#"[1 2 3] | par-each --keep-order {|e| $e * 2 }"#,
57 description: "Multiplies each number, keeping an original order.",
58 result: Some(Value::test_list(vec![
59 Value::test_int(2),
60 Value::test_int(4),
61 Value::test_int(6),
62 ])),
63 },
64 Example {
65 example: r#"1..3 | enumerate | par-each {|p| update item ($p.item * 2)} | sort-by item | get item"#,
66 description: "Enumerate and sort-by can be used to reconstruct the original order.",
67 result: Some(Value::test_list(vec![
68 Value::test_int(2),
69 Value::test_int(4),
70 Value::test_int(6),
71 ])),
72 },
73 Example {
74 example: r#"[foo bar baz] | par-each {|e| $e + '!' } | sort"#,
75 description: "Output can still be sorted afterward.",
76 result: Some(Value::test_list(vec![
77 Value::test_string("bar!"),
78 Value::test_string("baz!"),
79 Value::test_string("foo!"),
80 ])),
81 },
82 Example {
83 example: r#"[1 2 3] | enumerate | par-each { |e| if $e.item == 2 { $"found 2 at ($e.index)!"} }"#,
84 description: "Iterate over each element, producing a list showing indexes of any 2s.",
85 result: Some(Value::test_list(vec![Value::test_string("found 2 at 1!")])),
86 },
87 ]
88 }
89
90 fn run(
91 &self,
92 engine_state: &EngineState,
93 stack: &mut Stack,
94 call: &Call,
95 input: PipelineData,
96 ) -> Result<PipelineData, ShellError> {
97 fn create_pool(num_threads: usize, head: Span) -> Result<rayon::ThreadPool, ShellError> {
98 match rayon::ThreadPoolBuilder::new()
99 .num_threads(num_threads)
100 .build()
101 {
102 Err(e) => Err(e).map_err(|e| ShellError::GenericError {
103 error: "Error creating thread pool".into(),
104 msg: e.to_string(),
105 span: Some(head),
106 help: None,
107 inner: vec![],
108 }),
109 Ok(pool) => Ok(pool),
110 }
111 }
112
113 let head = call.head;
114 let closure: Closure = call.req(engine_state, stack, 0)?;
115 let threads: Option<usize> = call.get_flag(engine_state, stack, "threads")?;
116 let max_threads = threads.unwrap_or(0);
117 let keep_order = call.has_flag(engine_state, stack, "keep-order")?;
118
119 let input = match input.try_into_stream(engine_state) {
120 Ok(input) | Err(input) => input,
121 };
122 let metadata = input.metadata();
123
124 let apply_order = |mut vec: Vec<(usize, Value)>| {
126 if keep_order {
127 vec.par_sort_unstable_by_key(|(index, _)| *index);
130 }
131
132 vec.into_iter().map(|(_, val)| val)
133 };
134
135 match input {
136 PipelineData::Empty => Ok(PipelineData::empty()),
137 PipelineData::Value(value, ..) => {
138 let span = value.span();
139 match value {
140 Value::List { vals, .. } => Ok(create_pool(max_threads, head)?.install(|| {
141 let par_iter = vals.into_par_iter().enumerate();
142 let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
143 apply_order(mapped.collect())
144 .into_pipeline_data(span, engine_state.signals().clone())
145 })),
146 Value::Range { val, .. } => Ok(create_pool(max_threads, head)?.install(|| {
147 let par_iter = val
148 .into_range_iter(span, Signals::empty())
149 .enumerate()
150 .par_bridge();
151 let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
152 apply_order(mapped.collect())
153 .into_pipeline_data(span, engine_state.signals().clone())
154 })),
155 value => {
158 ClosureEvalOnce::new(engine_state, stack, closure).run_with_value(value)
159 }
160 }
161 }
162 PipelineData::ListStream(stream, ..) => {
163 Ok(create_pool(max_threads, head)?.install(|| {
164 let par_iter = stream.into_iter().enumerate().par_bridge();
165 let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
166
167 apply_order(mapped.collect())
168 .into_pipeline_data(head, engine_state.signals().clone())
169 }))
170 }
171 PipelineData::ByteStream(stream, ..) => {
172 if let Some(chunks) = stream.chunks() {
173 Ok(create_pool(max_threads, head)?.install(|| {
174 let par_iter = chunks
175 .enumerate()
176 .map(move |(idx, val)| {
177 (idx, val.unwrap_or_else(|err| Value::error(err, head)))
178 })
179 .par_bridge();
180 let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
181 apply_order(mapped.collect())
182 .into_pipeline_data(head, engine_state.signals().clone())
183 }))
184 } else {
185 Ok(PipelineData::empty())
186 }
187 }
188 }
189 .and_then(|x| x.filter(|v| !v.is_nothing(), engine_state.signals()))
190 .map(|data| data.set_metadata(metadata))
191 }
192}
193
194fn parallel_closure_map(
195 engine_state: &EngineState,
196 stack: &mut Stack,
197 closure: &Closure,
198 input: impl ParallelIterator<Item = (usize, Value)>,
199) -> impl ParallelIterator<Item = (usize, Value)> {
200 input.map_init(
201 move || ClosureEval::new(engine_state, stack, closure.clone()),
202 |closure_eval, (index, value)| {
203 let span = value.span();
204 let is_error = value.is_error();
205 let value = closure_eval
206 .run_with_value(value)
207 .and_then(|data| data.into_value(span))
208 .unwrap_or_else(|err| {
209 Value::error(chain_error_with_input(err, is_error, span), span)
210 });
211
212 (index, value)
213 },
214 )
215}
216
#[cfg(test)]
mod test {
    use super::*;

    /// Hands `ParEach` to the crate-level `test_examples` helper
    /// (presumably checking the examples declared by the command — confirm against the helper).
    #[test]
    fn test_examples() {
        crate::test_examples(ParEach);
    }
}