Skip to main content

shape_runtime/stdlib/
parallel.rs

//! Native `parallel` module for data-parallel operations.
//!
//! Exports: parallel.map, parallel.filter, parallel.for_each, parallel.chunks,
//! parallel.reduce, parallel.sort, parallel.num_threads
//!
//! Shape closures (`ValueWord`) are not `Send`, so every function that takes a
//! callback (`map`, `filter`, `for_each`, `reduce`, and `sort` with a
//! comparator) invokes it sequentially on the calling thread via
//! `invoke_callable`.
//!
//! Rayon is used only where no callback is involved: `parallel.sort` without a
//! comparator sorts arrays of 1024 or more elements with Rayon's parallel sort,
//! and `parallel.num_threads` reports Rayon's thread count. `parallel.chunks`
//! is a plain sequential utility.
14
15use crate::module_exports::{ModuleContext, ModuleExports, ModuleFunction, ModuleParam};
16use shape_value::ValueWord;
17use std::sync::Arc;
18
19/// parallel.map(array, fn) -> Array
20///
21/// Maps a function over each element of the array. The callback is invoked
22/// sequentially via `invoke_callable` (Shape closures are not Send), but
23/// the result array is pre-allocated for efficiency.
24fn parallel_map(args: &[ValueWord], ctx: &ModuleContext) -> Result<ValueWord, String> {
25    let arr = args
26        .first()
27        .and_then(|a| a.as_any_array())
28        .ok_or_else(|| "parallel.map() requires an array as first argument".to_string())?;
29    let callback = args
30        .get(1)
31        .ok_or_else(|| "parallel.map() requires a callback as second argument".to_string())?;
32    let invoke = ctx.invoke_callable.ok_or_else(|| {
33        "parallel.map() requires invoke_callable (not available in this context)".to_string()
34    })?;
35
36    let items = arr.to_generic();
37    let mut results = Vec::with_capacity(items.len());
38    for item in items.iter() {
39        let result = invoke(callback, &[item.clone()])?;
40        results.push(result);
41    }
42    Ok(ValueWord::from_array(Arc::new(results)))
43}
44
45/// parallel.filter(array, fn) -> Array
46///
47/// Filters array elements using a predicate callback.
48fn parallel_filter(args: &[ValueWord], ctx: &ModuleContext) -> Result<ValueWord, String> {
49    let arr = args
50        .first()
51        .and_then(|a| a.as_any_array())
52        .ok_or_else(|| "parallel.filter() requires an array as first argument".to_string())?;
53    let callback = args
54        .get(1)
55        .ok_or_else(|| "parallel.filter() requires a callback as second argument".to_string())?;
56    let invoke = ctx.invoke_callable.ok_or_else(|| {
57        "parallel.filter() requires invoke_callable (not available in this context)".to_string()
58    })?;
59
60    let items = arr.to_generic();
61    let mut results = Vec::new();
62    for item in items.iter() {
63        let keep = invoke(callback, &[item.clone()])?;
64        if keep.as_bool().unwrap_or(false) {
65            results.push(item.clone());
66        }
67    }
68    Ok(ValueWord::from_array(Arc::new(results)))
69}
70
71/// parallel.for_each(array, fn) -> null
72///
73/// Applies a function to each element for side effects.
74fn parallel_for_each(args: &[ValueWord], ctx: &ModuleContext) -> Result<ValueWord, String> {
75    let arr = args
76        .first()
77        .and_then(|a| a.as_any_array())
78        .ok_or_else(|| "parallel.for_each() requires an array as first argument".to_string())?;
79    let callback = args
80        .get(1)
81        .ok_or_else(|| "parallel.for_each() requires a callback as second argument".to_string())?;
82    let invoke = ctx.invoke_callable.ok_or_else(|| {
83        "parallel.for_each() requires invoke_callable (not available in this context)".to_string()
84    })?;
85
86    let items = arr.to_generic();
87    for item in items.iter() {
88        invoke(callback, &[item.clone()])?;
89    }
90    Ok(ValueWord::none())
91}
92
93/// parallel.chunks(array, size) -> Array<Array>
94///
95/// Split an array into chunks of the given size. Pure utility, no parallelism.
96/// The last chunk may be smaller if the array length is not evenly divisible.
97fn parallel_chunks(args: &[ValueWord], _ctx: &ModuleContext) -> Result<ValueWord, String> {
98    let arr = args
99        .first()
100        .and_then(|a| a.as_any_array())
101        .ok_or_else(|| "parallel.chunks() requires an array as first argument".to_string())?;
102    let size = args
103        .get(1)
104        .and_then(|a| a.as_i64().or_else(|| a.as_f64().map(|n| n as i64)))
105        .ok_or_else(|| "parallel.chunks() requires a chunk size as second argument".to_string())?;
106
107    if size <= 0 {
108        return Err("parallel.chunks() chunk size must be positive".to_string());
109    }
110    let size = size as usize;
111
112    let items = arr.to_generic();
113    let chunks: Vec<ValueWord> = items
114        .chunks(size)
115        .map(|chunk| ValueWord::from_array(Arc::new(chunk.to_vec())))
116        .collect();
117    Ok(ValueWord::from_array(Arc::new(chunks)))
118}
119
120/// parallel.reduce(array, fn, initial) -> any
121///
122/// Reduces an array to a single value using a callback and initial accumulator.
123fn parallel_reduce(args: &[ValueWord], ctx: &ModuleContext) -> Result<ValueWord, String> {
124    let arr = args
125        .first()
126        .and_then(|a| a.as_any_array())
127        .ok_or_else(|| "parallel.reduce() requires an array as first argument".to_string())?;
128    let callback = args
129        .get(1)
130        .ok_or_else(|| "parallel.reduce() requires a callback as second argument".to_string())?;
131    let initial = args.get(2).ok_or_else(|| {
132        "parallel.reduce() requires an initial value as third argument".to_string()
133    })?;
134    let invoke = ctx.invoke_callable.ok_or_else(|| {
135        "parallel.reduce() requires invoke_callable (not available in this context)".to_string()
136    })?;
137
138    let items = arr.to_generic();
139    let mut acc = initial.clone();
140    for item in items.iter() {
141        acc = invoke(callback, &[acc, item.clone()])?;
142    }
143    Ok(acc)
144}
145
146/// parallel.num_threads() -> int
147///
148/// Returns the number of threads in the Rayon thread pool.
149fn parallel_num_threads(_args: &[ValueWord], _ctx: &ModuleContext) -> Result<ValueWord, String> {
150    Ok(ValueWord::from_i64(rayon::current_num_threads() as i64))
151}
152
153/// parallel.sort(array, fn?) -> Array
154///
155/// Sort an array. If a comparator is provided, uses it; otherwise sorts by
156/// natural ordering (numbers first, then strings).
157/// Uses Rayon's par_sort for arrays larger than 1024 elements.
158fn parallel_sort(args: &[ValueWord], ctx: &ModuleContext) -> Result<ValueWord, String> {
159    let arr = args
160        .first()
161        .and_then(|a| a.as_any_array())
162        .ok_or_else(|| "parallel.sort() requires an array as first argument".to_string())?;
163
164    let items = arr.to_generic();
165    let mut sorted: Vec<ValueWord> = (*items).clone();
166
167    if let Some(callback) = args.get(1) {
168        // Custom comparator: callback(a, b) -> number (negative, zero, positive)
169        let invoke = ctx.invoke_callable.ok_or_else(|| {
170            "parallel.sort() with comparator requires invoke_callable".to_string()
171        })?;
172
173        // Sort with the comparator (sequential — closure not Send)
174        let mut last_err: Option<String> = None;
175        sorted.sort_by(|a, b| {
176            if last_err.is_some() {
177                return std::cmp::Ordering::Equal;
178            }
179            match invoke(callback, &[a.clone(), b.clone()]) {
180                Ok(result) => {
181                    let n = result.as_number_coerce().unwrap_or(0.0);
182                    if n < 0.0 {
183                        std::cmp::Ordering::Less
184                    } else if n > 0.0 {
185                        std::cmp::Ordering::Greater
186                    } else {
187                        std::cmp::Ordering::Equal
188                    }
189                }
190                Err(e) => {
191                    last_err = Some(e);
192                    std::cmp::Ordering::Equal
193                }
194            }
195        });
196        if let Some(err) = last_err {
197            return Err(format!("parallel.sort() comparator error: {}", err));
198        }
199    } else {
200        // Natural ordering using Rayon for large arrays
201        use rayon::prelude::*;
202
203        if sorted.len() >= 1024 {
204            sorted.par_sort_by(|a, b| compare_values_natural(a, b));
205        } else {
206            sorted.sort_by(|a, b| compare_values_natural(a, b));
207        }
208    }
209    Ok(ValueWord::from_array(Arc::new(sorted)))
210}
211
212/// Natural ordering comparator for ValueWord.
213fn compare_values_natural(a: &ValueWord, b: &ValueWord) -> std::cmp::Ordering {
214    match (a.as_number_coerce(), b.as_number_coerce()) {
215        (Some(na), Some(nb)) => na.partial_cmp(&nb).unwrap_or(std::cmp::Ordering::Equal),
216        _ => match (a.as_str(), b.as_str()) {
217            (Some(sa), Some(sb)) => sa.cmp(sb),
218            _ => std::cmp::Ordering::Equal,
219        },
220    }
221}
222
223/// Create the `parallel` module.
224pub fn create_parallel_module() -> ModuleExports {
225    let mut module = ModuleExports::new("std::core::parallel");
226    module.description = "Data-parallel operations using Rayon thread pool".to_string();
227
228    module.add_function_with_schema(
229        "map",
230        parallel_map,
231        ModuleFunction {
232            description: "Map a function over array elements".to_string(),
233            params: vec![
234                ModuleParam {
235                    name: "array".to_string(),
236                    type_name: "Array<any>".to_string(),
237                    required: true,
238                    description: "Array to map over".to_string(),
239                    ..Default::default()
240                },
241                ModuleParam {
242                    name: "fn".to_string(),
243                    type_name: "function".to_string(),
244                    required: true,
245                    description: "Callback function applied to each element".to_string(),
246                    ..Default::default()
247                },
248            ],
249            return_type: Some("Array<any>".to_string()),
250        },
251    );
252
253    module.add_function_with_schema(
254        "filter",
255        parallel_filter,
256        ModuleFunction {
257            description: "Filter array elements using a predicate".to_string(),
258            params: vec![
259                ModuleParam {
260                    name: "array".to_string(),
261                    type_name: "Array<any>".to_string(),
262                    required: true,
263                    description: "Array to filter".to_string(),
264                    ..Default::default()
265                },
266                ModuleParam {
267                    name: "fn".to_string(),
268                    type_name: "function".to_string(),
269                    required: true,
270                    description: "Predicate function returning bool".to_string(),
271                    ..Default::default()
272                },
273            ],
274            return_type: Some("Array<any>".to_string()),
275        },
276    );
277
278    module.add_function_with_schema(
279        "for_each",
280        parallel_for_each,
281        ModuleFunction {
282            description: "Apply a function to each element for side effects".to_string(),
283            params: vec![
284                ModuleParam {
285                    name: "array".to_string(),
286                    type_name: "Array<any>".to_string(),
287                    required: true,
288                    description: "Array to iterate".to_string(),
289                    ..Default::default()
290                },
291                ModuleParam {
292                    name: "fn".to_string(),
293                    type_name: "function".to_string(),
294                    required: true,
295                    description: "Callback function applied to each element".to_string(),
296                    ..Default::default()
297                },
298            ],
299            return_type: Some("null".to_string()),
300        },
301    );
302
303    module.add_function_with_schema(
304        "chunks",
305        parallel_chunks,
306        ModuleFunction {
307            description: "Split an array into chunks of a given size".to_string(),
308            params: vec![
309                ModuleParam {
310                    name: "array".to_string(),
311                    type_name: "Array<any>".to_string(),
312                    required: true,
313                    description: "Array to chunk".to_string(),
314                    ..Default::default()
315                },
316                ModuleParam {
317                    name: "size".to_string(),
318                    type_name: "int".to_string(),
319                    required: true,
320                    description: "Size of each chunk".to_string(),
321                    ..Default::default()
322                },
323            ],
324            return_type: Some("Array<Array<any>>".to_string()),
325        },
326    );
327
328    module.add_function_with_schema(
329        "reduce",
330        parallel_reduce,
331        ModuleFunction {
332            description: "Reduce an array to a single value".to_string(),
333            params: vec![
334                ModuleParam {
335                    name: "array".to_string(),
336                    type_name: "Array<any>".to_string(),
337                    required: true,
338                    description: "Array to reduce".to_string(),
339                    ..Default::default()
340                },
341                ModuleParam {
342                    name: "fn".to_string(),
343                    type_name: "function".to_string(),
344                    required: true,
345                    description: "Reducer function (accumulator, element) -> accumulator"
346                        .to_string(),
347                    ..Default::default()
348                },
349                ModuleParam {
350                    name: "initial".to_string(),
351                    type_name: "any".to_string(),
352                    required: true,
353                    description: "Initial accumulator value".to_string(),
354                    ..Default::default()
355                },
356            ],
357            return_type: Some("any".to_string()),
358        },
359    );
360
361    module.add_function_with_schema(
362        "sort",
363        parallel_sort,
364        ModuleFunction {
365            description:
366                "Sort an array, optionally with a comparator. Uses parallel sort for large arrays."
367                    .to_string(),
368            params: vec![
369                ModuleParam {
370                    name: "array".to_string(),
371                    type_name: "Array<any>".to_string(),
372                    required: true,
373                    description: "Array to sort".to_string(),
374                    ..Default::default()
375                },
376                ModuleParam {
377                    name: "comparator".to_string(),
378                    type_name: "function".to_string(),
379                    required: false,
380                    description: "Comparator function (a, b) -> number".to_string(),
381                    ..Default::default()
382                },
383            ],
384            return_type: Some("Array<any>".to_string()),
385        },
386    );
387
388    module.add_function_with_schema(
389        "num_threads",
390        parallel_num_threads,
391        ModuleFunction {
392            description: "Return the number of threads in the Rayon thread pool".to_string(),
393            params: vec![],
394            return_type: Some("int".to_string()),
395        },
396    );
397
398    module
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ModuleContext` with no callable invoker, VM state, or
    /// permissions. The schema registry is leaked to obtain a `'static`
    /// borrow, which is acceptable in tests.
    fn test_ctx() -> crate::module_exports::ModuleContext<'static> {
        let registry = Box::leak(Box::new(crate::type_schema::TypeSchemaRegistry::new()));
        crate::module_exports::ModuleContext {
            schemas: registry,
            invoke_callable: None,
            raw_invoker: None,
            function_hashes: None,
            vm_state: None,
            granted_permissions: None,
            scope_constraints: None,
            set_pending_resume: None,
            set_pending_frame_resume: None,
        }
    }

    /// Wraps integers in a Shape array value.
    fn int_array(values: &[i64]) -> ValueWord {
        let items: Vec<ValueWord> = values.iter().map(|&v| ValueWord::from_i64(v)).collect();
        ValueWord::from_array(Arc::new(items))
    }

    /// Wraps a string slice in a Shape string value.
    fn string_value(text: &str) -> ValueWord {
        ValueWord::from_string(Arc::new(text.to_string()))
    }

    /// Reads an array value back as the `as_i64` view of each element.
    fn as_ints(value: &ValueWord) -> Vec<Option<i64>> {
        let elements = value.as_any_array().unwrap().to_generic();
        elements.iter().map(|v| v.as_i64()).collect()
    }

    /// Runs `parallel_chunks` on an integer array and returns the chunk count.
    fn chunk_count(values: &[i64], size: i64) -> usize {
        let ctx = test_ctx();
        let args = [int_array(values), ValueWord::from_i64(size)];
        let result = parallel_chunks(&args, &ctx).unwrap();
        let chunks = result.as_any_array().unwrap().to_generic();
        chunks.len()
    }

    #[test]
    fn test_parallel_module_creation() {
        let module = create_parallel_module();
        assert_eq!(module.name, "std::core::parallel");
        for export in ["map", "filter", "for_each", "chunks", "reduce", "sort", "num_threads"] {
            assert!(module.has_export(export), "missing export {}", export);
        }
    }

    #[test]
    fn test_parallel_module_schemas() {
        let module = create_parallel_module();
        let expectations = [
            ("map", 2, "Array<any>"),
            ("chunks", 2, "Array<Array<any>>"),
            ("reduce", 3, "any"),
            ("num_threads", 0, "int"),
        ];
        for (name, param_count, return_type) in expectations {
            let schema = module.get_schema(name).unwrap();
            assert_eq!(schema.params.len(), param_count, "param count of {}", name);
            assert_eq!(schema.return_type.as_deref(), Some(return_type));
        }
    }

    #[test]
    fn test_parallel_chunks_basic() {
        let ctx = test_ctx();
        let args = [int_array(&[1, 2, 3, 4, 5]), ValueWord::from_i64(2)];
        let result = parallel_chunks(&args, &ctx).unwrap();
        let chunks = result.as_any_array().unwrap().to_generic();
        assert_eq!(chunks.len(), 3); // [1, 2], [3, 4], [5]
        assert_eq!(as_ints(&chunks[0]), vec![Some(1), Some(2)]);
        assert_eq!(as_ints(&chunks[2]), vec![Some(5)]);
    }

    #[test]
    fn test_parallel_chunks_exact_division() {
        assert_eq!(chunk_count(&[1, 2, 3, 4], 2), 2);
    }

    #[test]
    fn test_parallel_chunks_size_larger_than_array() {
        assert_eq!(chunk_count(&[1, 2], 10), 1);
    }

    #[test]
    fn test_parallel_chunks_invalid_size() {
        let ctx = test_ctx();
        let args = [int_array(&[1]), ValueWord::from_i64(0)];
        assert!(parallel_chunks(&args, &ctx).is_err());
    }

    #[test]
    fn test_parallel_chunks_empty_array() {
        assert_eq!(chunk_count(&[], 3), 0);
    }

    #[test]
    fn test_parallel_num_threads() {
        let ctx = test_ctx();
        let n = parallel_num_threads(&[], &ctx).unwrap().as_i64().unwrap();
        assert!(n > 0, "num_threads should be positive, got {}", n);
    }

    #[test]
    fn test_parallel_sort_natural() {
        let ctx = test_ctx();
        let result = parallel_sort(&[int_array(&[3, 1, 4, 1, 5])], &ctx).unwrap();
        assert_eq!(as_ints(&result), vec![Some(1), Some(1), Some(3), Some(4), Some(5)]);
    }

    #[test]
    fn test_parallel_sort_strings() {
        let ctx = test_ctx();
        let items: Vec<ValueWord> = ["banana", "apple", "cherry"]
            .iter()
            .map(|s| string_value(s))
            .collect();
        let arr = ValueWord::from_array(Arc::new(items));
        let result = parallel_sort(&[arr], &ctx).unwrap();
        let sorted = result.as_any_array().unwrap().to_generic();
        for (index, expected) in ["apple", "banana", "cherry"].iter().enumerate() {
            assert_eq!(sorted[index].as_str(), Some(*expected));
        }
    }

    #[test]
    fn test_parallel_map_requires_callback() {
        let ctx = test_ctx();
        assert!(parallel_map(&[int_array(&[1])], &ctx).is_err());
    }

    #[test]
    fn test_parallel_map_requires_invoke_callable() {
        let ctx = test_ctx();
        let placeholder_callback = ValueWord::none();
        let err = parallel_map(&[int_array(&[1]), placeholder_callback], &ctx).unwrap_err();
        assert!(err.contains("invoke_callable"));
    }

    #[test]
    fn test_parallel_export_count() {
        assert_eq!(create_parallel_module().export_names().len(), 7);
    }
}