1use super::utils::chain_error_with_input;
2use nu_engine::{ClosureEvalOnce, command_prelude::*};
3use nu_protocol::{Signals, engine::Closure};
4use rayon::prelude::*;
5
/// The `par-each` command: evaluates a closure on each element of its input in parallel
/// (on a rayon thread pool) and collects the results into a new list.
#[derive(Clone)]
pub struct ParEach;
8
9impl Command for ParEach {
10 fn name(&self) -> &str {
11 "par-each"
12 }
13
14 fn description(&self) -> &str {
15 "Run a closure on each row of the input list in parallel, creating a new list with the results."
16 }
17
18 fn signature(&self) -> nu_protocol::Signature {
19 Signature::build("par-each")
20 .input_output_types(vec![
21 (
22 Type::List(Box::new(Type::Any)),
23 Type::List(Box::new(Type::Any)),
24 ),
25 (Type::table(), Type::List(Box::new(Type::Any))),
26 (Type::Any, Type::Any),
27 ])
28 .named(
29 "threads",
30 SyntaxShape::Int,
31 "the number of threads to use",
32 Some('t'),
33 )
34 .switch(
35 "keep-order",
36 "keep sequence of output same as the order of input",
37 Some('k'),
38 )
39 .required(
40 "closure",
41 SyntaxShape::Closure(Some(vec![SyntaxShape::Any])),
42 "The closure to run.",
43 )
44 .allow_variants_without_examples(true)
45 .category(Category::Filters)
46 }
47
48 fn examples(&self) -> Vec<Example<'_>> {
49 vec![
50 Example {
51 example: "[1 2 3] | par-each {|e| $e * 2 }",
52 description: "Multiplies each number. Note that the list will become arbitrarily disordered.",
53 result: None,
54 },
55 Example {
56 example: r#"[1 2 3] | par-each --keep-order {|e| $e * 2 }"#,
57 description: "Multiplies each number, keeping an original order",
58 result: Some(Value::test_list(vec![
59 Value::test_int(2),
60 Value::test_int(4),
61 Value::test_int(6),
62 ])),
63 },
64 Example {
65 example: r#"1..3 | enumerate | par-each {|p| update item ($p.item * 2)} | sort-by item | get item"#,
66 description: "Enumerate and sort-by can be used to reconstruct the original order",
67 result: Some(Value::test_list(vec![
68 Value::test_int(2),
69 Value::test_int(4),
70 Value::test_int(6),
71 ])),
72 },
73 Example {
74 example: r#"[foo bar baz] | par-each {|e| $e + '!' } | sort"#,
75 description: "Output can still be sorted afterward",
76 result: Some(Value::test_list(vec![
77 Value::test_string("bar!"),
78 Value::test_string("baz!"),
79 Value::test_string("foo!"),
80 ])),
81 },
82 Example {
83 example: r#"[1 2 3] | enumerate | par-each { |e| if $e.item == 2 { $"found 2 at ($e.index)!"} }"#,
84 description: "Iterate over each element, producing a list showing indexes of any 2s",
85 result: Some(Value::test_list(vec![Value::test_string("found 2 at 1!")])),
86 },
87 ]
88 }
89
90 fn run(
91 &self,
92 engine_state: &EngineState,
93 stack: &mut Stack,
94 call: &Call,
95 input: PipelineData,
96 ) -> Result<PipelineData, ShellError> {
97 fn create_pool(num_threads: usize) -> Result<rayon::ThreadPool, ShellError> {
98 match rayon::ThreadPoolBuilder::new()
99 .num_threads(num_threads)
100 .build()
101 {
102 Err(e) => Err(e).map_err(|e| ShellError::GenericError {
103 error: "Error creating thread pool".into(),
104 msg: e.to_string(),
105 span: Some(Span::unknown()),
106 help: None,
107 inner: vec![],
108 }),
109 Ok(pool) => Ok(pool),
110 }
111 }
112
113 let head = call.head;
114 let closure: Closure = call.req(engine_state, stack, 0)?;
115 let threads: Option<usize> = call.get_flag(engine_state, stack, "threads")?;
116 let max_threads = threads.unwrap_or(0);
117 let keep_order = call.has_flag(engine_state, stack, "keep-order")?;
118
119 let metadata = input.metadata();
120
121 let apply_order = |mut vec: Vec<(usize, Value)>| {
123 if keep_order {
124 vec.par_sort_unstable_by_key(|(index, _)| *index);
127 }
128
129 vec.into_iter().map(|(_, val)| val)
130 };
131
132 match input {
133 PipelineData::Empty => Ok(PipelineData::empty()),
134 PipelineData::Value(value, ..) => {
135 let span = value.span();
136 match value {
137 Value::List { vals, .. } => Ok(create_pool(max_threads)?.install(|| {
138 let vec = vals
139 .into_par_iter()
140 .enumerate()
141 .map(move |(index, value)| {
142 let span = value.span();
143 let is_error = value.is_error();
144 let value =
145 ClosureEvalOnce::new(engine_state, stack, closure.clone())
146 .run_with_value(value)
147 .and_then(|data| data.into_value(span))
148 .unwrap_or_else(|err| {
149 Value::error(
150 chain_error_with_input(err, is_error, span),
151 span,
152 )
153 });
154
155 (index, value)
156 })
157 .collect::<Vec<_>>();
158
159 apply_order(vec).into_pipeline_data(span, engine_state.signals().clone())
160 })),
161 Value::Range { val, .. } => Ok(create_pool(max_threads)?.install(|| {
162 let vec = val
163 .into_range_iter(span, Signals::empty())
164 .enumerate()
165 .par_bridge()
166 .map(move |(index, value)| {
167 let span = value.span();
168 let is_error = value.is_error();
169 let value =
170 ClosureEvalOnce::new(engine_state, stack, closure.clone())
171 .run_with_value(value)
172 .and_then(|data| data.into_value(span))
173 .unwrap_or_else(|err| {
174 Value::error(
175 chain_error_with_input(err, is_error, span),
176 span,
177 )
178 });
179
180 (index, value)
181 })
182 .collect::<Vec<_>>();
183
184 apply_order(vec).into_pipeline_data(span, engine_state.signals().clone())
185 })),
186 value => {
189 ClosureEvalOnce::new(engine_state, stack, closure).run_with_value(value)
190 }
191 }
192 }
193 PipelineData::ListStream(stream, ..) => Ok(create_pool(max_threads)?.install(|| {
194 let vec = stream
195 .into_iter()
196 .enumerate()
197 .par_bridge()
198 .map(move |(index, value)| {
199 let span = value.span();
200 let is_error = value.is_error();
201 let value = ClosureEvalOnce::new(engine_state, stack, closure.clone())
202 .run_with_value(value)
203 .and_then(|data| data.into_value(head))
204 .unwrap_or_else(|err| {
205 Value::error(chain_error_with_input(err, is_error, span), span)
206 });
207
208 (index, value)
209 })
210 .collect::<Vec<_>>();
211
212 apply_order(vec).into_pipeline_data(head, engine_state.signals().clone())
213 })),
214 PipelineData::ByteStream(stream, ..) => {
215 if let Some(chunks) = stream.chunks() {
216 Ok(create_pool(max_threads)?.install(|| {
217 let vec = chunks
218 .enumerate()
219 .par_bridge()
220 .map(move |(index, value)| {
221 let value = match value {
222 Ok(value) => value,
223 Err(err) => return (index, Value::error(err, head)),
224 };
225
226 let value =
227 ClosureEvalOnce::new(engine_state, stack, closure.clone())
228 .run_with_value(value)
229 .and_then(|data| data.into_value(head))
230 .unwrap_or_else(|err| Value::error(err, head));
231
232 (index, value)
233 })
234 .collect::<Vec<_>>();
235
236 apply_order(vec).into_pipeline_data(head, engine_state.signals().clone())
237 }))
238 } else {
239 Ok(PipelineData::empty())
240 }
241 }
242 }
243 .and_then(|x| x.filter(|v| !v.is_nothing(), engine_state.signals()))
244 .map(|data| data.set_metadata(metadata))
245 }
246}
247
#[cfg(test)]
mod test {
    use super::*;

    /// Hands `ParEach` to the crate's shared `test_examples` helper. Presumably the helper
    /// evaluates each entry of `examples()` and checks it against that entry's `result`.
    #[test]
    fn test_examples() {
        crate::test_examples(ParEach)
    }
}