1use super::utils::chain_error_with_input;
2use nu_engine::{ClosureEval, ClosureEvalOnce, command_prelude::*};
3use nu_protocol::{Signals, engine::Closure, shell_error::generic::GenericError};
4use rayon::prelude::*;
5use std::{
6 sync::mpsc::{self, RecvTimeoutError},
7 time::Duration,
8};
9
// Upper bound on results buffered between the worker pool and the consumer;
// producers block once the channel holds this many unread values, capping memory.
const STREAM_BUFFER_SIZE: usize = 64;
// How long a blocked receive waits before re-checking interrupt (Ctrl-C) signals.
const CTRL_C_CHECK_INTERVAL: Duration = Duration::from_millis(100);
12
/// The `par-each` filter command: runs a closure over each input row in
/// parallel on a dedicated rayon thread pool.
#[derive(Clone)]
pub struct ParEach;
15
16impl Command for ParEach {
17 fn name(&self) -> &str {
18 "par-each"
19 }
20
21 fn description(&self) -> &str {
22 "Run a closure on each row of the input list in parallel, creating a new list with the results."
23 }
24
25 fn signature(&self) -> nu_protocol::Signature {
26 Signature::build("par-each")
27 .input_output_types(vec![
28 (
29 Type::List(Box::new(Type::Any)),
30 Type::List(Box::new(Type::Any)),
31 ),
32 (Type::table(), Type::List(Box::new(Type::Any))),
33 (Type::Any, Type::Any),
34 ])
35 .named(
36 "threads",
37 SyntaxShape::Int,
38 "The number of threads to use.",
39 Some('t'),
40 )
41 .switch(
42 "keep-order",
43 "Keep sequence of output same as the order of input.",
44 Some('k'),
45 )
46 .required(
47 "closure",
48 SyntaxShape::Closure(Some(vec![SyntaxShape::Any])),
49 "The closure to run.",
50 )
51 .allow_variants_without_examples(true)
52 .category(Category::Filters)
53 }
54
55 fn examples(&self) -> Vec<Example<'_>> {
56 vec![
57 Example {
58 example: "[1 2 3] | par-each {|e| $e * 2 }",
59 description: "Multiplies each number. Note that the list will become arbitrarily disordered.",
60 result: None,
61 },
62 Example {
63 example: "[1 2 3] | par-each --keep-order {|e| $e * 2 }",
64 description: "Multiplies each number, keeping an original order.",
65 result: Some(Value::test_list(vec![
66 Value::test_int(2),
67 Value::test_int(4),
68 Value::test_int(6),
69 ])),
70 },
71 Example {
72 example: "1..3 | enumerate | par-each {|p| update item ($p.item * 2)} | sort-by item | get item",
73 description: "Enumerate and sort-by can be used to reconstruct the original order.",
74 result: Some(Value::test_list(vec![
75 Value::test_int(2),
76 Value::test_int(4),
77 Value::test_int(6),
78 ])),
79 },
80 Example {
81 example: "[foo bar baz] | par-each {|e| $e + '!' } | sort",
82 description: "Output can still be sorted afterward.",
83 result: Some(Value::test_list(vec![
84 Value::test_string("bar!"),
85 Value::test_string("baz!"),
86 Value::test_string("foo!"),
87 ])),
88 },
89 Example {
90 example: r#"[1 2 3] | enumerate | par-each { |e| if $e.item == 2 { $"found 2 at ($e.index)!"} }"#,
91 description: "Iterate over each element, producing a list showing indexes of any 2s.",
92 result: Some(Value::test_list(vec![Value::test_string("found 2 at 1!")])),
93 },
94 ]
95 }
96
97 fn run(
98 &self,
99 engine_state: &EngineState,
100 stack: &mut Stack,
101 call: &Call,
102 input: PipelineData,
103 ) -> Result<PipelineData, ShellError> {
104 fn create_pool(num_threads: usize, head: Span) -> Result<rayon::ThreadPool, ShellError> {
105 match rayon::ThreadPoolBuilder::new()
106 .num_threads(num_threads)
107 .build()
108 {
109 Err(e) => Err(e).map_err(|e| {
110 ShellError::Generic(GenericError::new(
111 "Error creating thread pool",
112 e.to_string(),
113 head,
114 ))
115 }),
116 Ok(pool) => Ok(pool),
117 }
118 }
119
120 let head = call.head;
121 let closure: Closure = call.req(engine_state, stack, 0)?;
122 let threads: Option<usize> = call.get_flag(engine_state, stack, "threads")?;
123 let max_threads = threads.unwrap_or(0);
124 let keep_order = call.has_flag(engine_state, stack, "keep-order")?;
125 let signals = engine_state.signals().clone();
126
127 let mut input = input.into_stream_or_original(engine_state);
128 let metadata = input.take_metadata();
129
130 let apply_order = |mut vec: Vec<(usize, Value)>| {
132 if keep_order {
133 vec.par_sort_unstable_by_key(|(index, _)| *index);
136 }
137
138 vec.into_iter().map(|(_, val)| val)
139 };
140
141 match input {
142 PipelineData::Empty => Ok(PipelineData::empty()),
143 PipelineData::Value(value, ..) => {
144 let span = value.span();
145 match value {
146 Value::List { vals, .. } => {
147 let pool = create_pool(max_threads, head)?;
148 if keep_order {
149 Ok(pool.install(|| {
150 let par_iter = vals.into_par_iter().enumerate();
151 let mapped =
152 parallel_closure_map(engine_state, stack, &closure, par_iter);
153 apply_order(mapped.collect())
154 .into_pipeline_data(span, signals.clone())
155 }))
156 } else {
157 let par_iter = vals.into_par_iter();
158 Ok(stream_parallel_values(
159 engine_state,
160 stack,
161 closure.clone(),
162 pool,
163 span,
164 signals.clone(),
165 par_iter,
166 ))
167 }
168 }
169 Value::Range { val, .. } => {
170 let pool = create_pool(max_threads, head)?;
171 if keep_order {
172 Ok(pool.install(|| {
173 let par_iter = val
174 .into_range_iter(span, signals.clone())
175 .enumerate()
176 .par_bridge();
177 let mapped =
178 parallel_closure_map(engine_state, stack, &closure, par_iter);
179 apply_order(mapped.collect())
180 .into_pipeline_data(span, signals.clone())
181 }))
182 } else {
183 let par_iter = val.into_range_iter(span, signals.clone()).par_bridge();
184 Ok(stream_parallel_values(
185 engine_state,
186 stack,
187 closure.clone(),
188 pool,
189 span,
190 signals.clone(),
191 par_iter,
192 ))
193 }
194 }
195 value => {
198 ClosureEvalOnce::new(engine_state, stack, closure).run_with_value(value)
199 }
200 }
201 }
202 PipelineData::ListStream(stream, ..) => {
203 let pool = create_pool(max_threads, head)?;
204 if keep_order {
205 Ok(pool.install(|| {
206 let par_iter = stream.into_iter().enumerate().par_bridge();
207 let mapped = parallel_closure_map(engine_state, stack, &closure, par_iter);
208
209 apply_order(mapped.collect()).into_pipeline_data(head, signals.clone())
210 }))
211 } else {
212 let par_iter = stream.into_iter().par_bridge();
213 Ok(stream_parallel_values(
214 engine_state,
215 stack,
216 closure.clone(),
217 pool,
218 head,
219 signals.clone(),
220 par_iter,
221 ))
222 }
223 }
224 PipelineData::ByteStream(stream, ..) => {
225 if let Some(chunks) = stream.chunks() {
226 let pool = create_pool(max_threads, head)?;
227 if keep_order {
228 Ok(pool.install(|| {
229 let par_iter = chunks
230 .enumerate()
231 .map(move |(idx, val)| {
232 (idx, val.unwrap_or_else(|err| Value::error(err, head)))
233 })
234 .par_bridge();
235 let mapped =
236 parallel_closure_map(engine_state, stack, &closure, par_iter);
237 apply_order(mapped.collect()).into_pipeline_data(head, signals.clone())
238 }))
239 } else {
240 let par_iter = chunks
241 .map(move |val| val.unwrap_or_else(|err| Value::error(err, head)))
242 .par_bridge();
243 Ok(stream_parallel_values(
244 engine_state,
245 stack,
246 closure.clone(),
247 pool,
248 head,
249 signals.clone(),
250 par_iter,
251 ))
252 }
253 } else {
254 Ok(PipelineData::empty())
255 }
256 }
257 }
258 .and_then(|x| x.filter(|v| !v.is_nothing(), engine_state.signals()))
259 .map(|data| data.set_metadata(metadata))
260 }
261}
262
/// Evaluate `closure` over `input` on the given thread pool, streaming results
/// back through a bounded channel instead of collecting them all first.
///
/// Output order is whatever order the workers finish in; callers that need
/// input order use the collect-and-sort path instead. The channel is bounded
/// by `STREAM_BUFFER_SIZE`, so producers block when the consumer falls behind,
/// capping memory use. Interrupt signals are checked before and after each
/// closure evaluation and again before each send; on interrupt the workers
/// bail out with `Err(())` and the output stream simply ends.
fn stream_parallel_values(
    engine_state: &EngineState,
    stack: &Stack,
    closure: Closure,
    pool: rayon::ThreadPool,
    span: Span,
    signals: Signals,
    input: impl ParallelIterator<Item = Value> + 'static,
) -> PipelineData {
    // Bounded channel: workers park once the buffer is full.
    let (tx, rx) = mpsc::sync_channel(STREAM_BUFFER_SIZE);
    // Owned copies that can move into the detached worker task.
    let worker_engine_state = engine_state.clone();
    let worker_stack = stack.clone();
    let worker_signals = signals.clone();

    // `install` makes `rayon::spawn` target this dedicated pool; the spawned
    // task runs detached while we return the receiving end immediately.
    pool.install(|| {
        rayon::spawn(move || {
            let map_signals = worker_signals.clone();
            let send_signals = worker_signals.clone();

            // Errors here are only `Err(())` interrupt/disconnect markers,
            // so the overall result is deliberately discarded.
            let _ = input
                .map_init(
                    // One `ClosureEval` per rayon worker, built lazily.
                    move || ClosureEval::new(&worker_engine_state, &worker_stack, closure.clone()),
                    move |closure_eval, value| {
                        if map_signals.interrupted() {
                            return Err(());
                        }

                        let value = run_closure_on_value(closure_eval, value);

                        // Re-check after the (possibly long-running) closure.
                        if map_signals.interrupted() {
                            Err(())
                        } else {
                            Ok(value)
                        }
                    },
                )
                .try_for_each(move |value| match value {
                    Ok(value) => {
                        if send_signals.interrupted() {
                            Err(())
                        } else {
                            // Send fails only when the receiver was dropped;
                            // treat that as a stop signal too.
                            tx.send(value).map_err(|_| ())
                        }
                    }
                    Err(()) => Err(()),
                });
        });
    });

    // The iterator polls `signals` itself, so the resulting stream gets
    // `Signals::empty()` to avoid double-checking.
    ReceiverIter::new(rx, signals).into_pipeline_data(span, Signals::empty())
}
314
/// Iterator over values arriving from the parallel workers, polling for
/// interrupts while blocked on the channel.
struct ReceiverIter {
    // Receiving end of the bounded channel fed by the worker task.
    receiver: mpsc::Receiver<Value>,
    // Interrupt state checked between receive timeouts.
    signals: Signals,
}
320
321impl ReceiverIter {
322 fn new(receiver: mpsc::Receiver<Value>, signals: Signals) -> Self {
323 Self { receiver, signals }
324 }
325}
326
327impl Iterator for ReceiverIter {
328 type Item = Value;
329
330 fn next(&mut self) -> Option<Self::Item> {
331 loop {
332 if self.signals.interrupted() {
333 return None;
334 }
335
336 match self.receiver.recv_timeout(CTRL_C_CHECK_INTERVAL) {
337 Ok(value) => return Some(value),
338 Err(RecvTimeoutError::Timeout) => {}
339 Err(RecvTimeoutError::Disconnected) => return None,
340 }
341 }
342 }
343}
344
345fn run_closure_on_value(closure_eval: &mut ClosureEval, value: Value) -> Value {
346 let span = value.span();
347 let is_error = value.is_error();
348
349 closure_eval
350 .run_with_value(value)
351 .and_then(|data| data.into_value(span))
352 .unwrap_or_else(|err| Value::error(chain_error_with_input(err, is_error, span), span))
353}
354
355fn parallel_closure_map(
356 engine_state: &EngineState,
357 stack: &mut Stack,
358 closure: &Closure,
359 input: impl ParallelIterator<Item = (usize, Value)>,
360) -> impl ParallelIterator<Item = (usize, Value)> {
361 input.map_init(
362 move || ClosureEval::new(engine_state, stack, closure.clone()),
363 |closure_eval, (index, value)| {
364 let value = run_closure_on_value(closure_eval, value);
365
366 (index, value)
367 },
368 )
369}
370
#[cfg(test)]
mod test {
    use super::*;

    // Run the command's declared examples through the standard example
    // checker, which evaluates each and compares against `result`.
    #[test]
    fn test_examples() -> nu_test_support::Result {
        nu_test_support::test().examples(ParEach)
    }
}