jax_rs/ops/
manipulation.rs

1//! Array manipulation operations.
2//!
3//! Operations for reshaping, concatenating, and manipulating arrays.
4
5use crate::{Array, DType, Shape};
6
7impl Array {
8    /// Concatenate arrays along an existing axis.
9    ///
10    /// # Examples
11    ///
12    /// ```
13    /// # use jax_rs::{Array, Shape};
14    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
15    /// let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
16    /// let c = Array::concatenate(&[a, b], 0);
17    /// assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
18    /// ```
19    pub fn concatenate(arrays: &[Array], axis: usize) -> Array {
20        assert!(!arrays.is_empty(), "Need at least one array to concatenate");
21        assert_eq!(
22            arrays[0].dtype(),
23            DType::Float32,
24            "Only Float32 supported"
25        );
26
27        // Validate all arrays have compatible shapes
28        let first_shape = arrays[0].shape();
29        let ndim = first_shape.ndim();
30        assert!(axis < ndim, "Axis out of bounds");
31
32        for arr in arrays.iter().skip(1) {
33            assert_eq!(
34                arr.ndim(),
35                ndim,
36                "All arrays must have same number of dimensions"
37            );
38            for (i, (&dim1, &dim2)) in first_shape
39                .as_slice()
40                .iter()
41                .zip(arr.shape().as_slice().iter())
42                .enumerate()
43            {
44                if i != axis {
45                    assert_eq!(
46                        dim1, dim2,
47                        "Dimensions must match except on concatenation axis"
48                    );
49                }
50            }
51        }
52
53        // Compute result shape
54        let mut result_dims = first_shape.as_slice().to_vec();
55        result_dims[axis] =
56            arrays.iter().map(|a| a.shape().as_slice()[axis]).sum();
57        let result_shape = Shape::new(result_dims.clone());
58
59        // Simple implementation for axis 0
60        if axis == 0 {
61            let mut data = Vec::new();
62            for arr in arrays {
63                data.extend(arr.to_vec());
64            }
65            Array::from_vec(data, result_shape)
66        } else {
67            // General implementation for any axis
68            // Compute strides for the result array
69            let total_size: usize = result_dims.iter().product();
70            let mut result = vec![0.0f32; total_size];
71
72            // Compute the size of chunks before and after the concatenation axis
73            let outer_size: usize = result_dims[..axis].iter().product();
74            let inner_size: usize = result_dims[axis + 1..].iter().product();
75
76            let mut offset_along_axis = 0;
77            for arr in arrays {
78                let arr_data = arr.to_vec();
79                let arr_shape = arr.shape().as_slice();
80                let arr_axis_size = arr_shape[axis];
81
82                for outer in 0..outer_size {
83                    for ax in 0..arr_axis_size {
84                        for inner in 0..inner_size {
85                            let src_idx = outer * arr_axis_size * inner_size + ax * inner_size + inner;
86                            let dst_ax = offset_along_axis + ax;
87                            let dst_idx = outer * result_dims[axis] * inner_size + dst_ax * inner_size + inner;
88                            result[dst_idx] = arr_data[src_idx];
89                        }
90                    }
91                }
92                offset_along_axis += arr_axis_size;
93            }
94
95            Array::from_vec(result, result_shape)
96        }
97    }
98
99    /// Stack arrays along a new axis.
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// # use jax_rs::{Array, Shape};
105    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
106    /// let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
107    /// let c = Array::stack(&[a, b], 0);
108    /// assert_eq!(c.shape().as_slice(), &[2, 2]);
109    /// assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
110    /// ```
111    pub fn stack(arrays: &[Array], axis: usize) -> Array {
112        assert!(!arrays.is_empty(), "Need at least one array to stack");
113
114        // Validate all arrays have the same shape
115        let first_shape = arrays[0].shape();
116        for arr in arrays.iter().skip(1) {
117            assert_eq!(
118                arr.shape(),
119                first_shape,
120                "All arrays must have the same shape for stacking"
121            );
122        }
123
124        let ndim = first_shape.ndim();
125        assert!(axis <= ndim, "Axis out of bounds for stacking");
126
127        // Compute result shape
128        let mut result_dims = Vec::new();
129        for (i, &dim) in first_shape.as_slice().iter().enumerate() {
130            if i == axis {
131                result_dims.push(arrays.len());
132            }
133            result_dims.push(dim);
134        }
135        if axis == ndim {
136            result_dims.push(arrays.len());
137        }
138
139        // Simple implementation for axis 0
140        if axis == 0 {
141            let mut data = Vec::new();
142            for arr in arrays {
143                data.extend(arr.to_vec());
144            }
145            let mut shape_dims = vec![arrays.len()];
146            shape_dims.extend_from_slice(first_shape.as_slice());
147            Array::from_vec(data, Shape::new(shape_dims))
148        } else {
149            panic!("stack only supports axis=0 for now");
150        }
151    }
152
153    /// Split an array into multiple sub-arrays along a specified axis.
154    ///
155    /// # Arguments
156    ///
157    /// * `array` - The array to split
158    /// * `num_sections` - Number of equal sections to split into
159    /// * `axis` - The axis along which to split
160    ///
161    /// # Examples
162    ///
163    /// ```
164    /// # use jax_rs::{Array, Shape};
165    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![6]));
166    /// let parts = Array::split(&a, 3, 0);
167    /// assert_eq!(parts.len(), 3);
168    /// assert_eq!(parts[0].to_vec(), vec![1.0, 2.0]);
169    /// assert_eq!(parts[1].to_vec(), vec![3.0, 4.0]);
170    /// assert_eq!(parts[2].to_vec(), vec![5.0, 6.0]);
171    /// ```
172    pub fn split(array: &Array, num_sections: usize, axis: usize) -> Vec<Array> {
173        assert_eq!(array.dtype(), DType::Float32, "Only Float32 supported");
174        assert!(num_sections > 0, "Number of sections must be positive");
175
176        let shape = array.shape().as_slice();
177        assert!(axis < shape.len(), "Axis out of bounds");
178        let axis_size = shape[axis];
179        assert_eq!(
180            axis_size % num_sections,
181            0,
182            "Array size along axis must be divisible by number of sections"
183        );
184
185        let section_size = axis_size / num_sections;
186        let data = array.to_vec();
187
188        let mut result = Vec::with_capacity(num_sections);
189
190        if axis == 0 {
191            // Split along axis 0 - simple case
192            let elements_per_section = data.len() / num_sections;
193
194            for i in 0..num_sections {
195                let start = i * elements_per_section;
196                let end = start + elements_per_section;
197                let section_data = data[start..end].to_vec();
198
199                let mut section_shape = shape.to_vec();
200                section_shape[axis] = section_size;
201
202                result.push(Array::from_vec(section_data, Shape::new(section_shape)));
203            }
204        } else {
205            // General case for any axis
206            let outer_size: usize = shape[..axis].iter().product();
207            let inner_size: usize = shape[axis + 1..].iter().product();
208
209            for section_idx in 0..num_sections {
210                let mut section_data = Vec::with_capacity(outer_size * section_size * inner_size);
211
212                for outer in 0..outer_size {
213                    for ax in 0..section_size {
214                        let src_ax = section_idx * section_size + ax;
215                        for inner in 0..inner_size {
216                            let src_idx = outer * axis_size * inner_size + src_ax * inner_size + inner;
217                            section_data.push(data[src_idx]);
218                        }
219                    }
220                }
221
222                let mut section_shape = shape.to_vec();
223                section_shape[axis] = section_size;
224
225                result.push(Array::from_vec(section_data, Shape::new(section_shape)));
226            }
227        }
228
229        result
230    }
231
232    /// Select elements from array based on condition with broadcasting support.
233    ///
234    /// Returns elements from `x` where `condition` is true (non-zero), otherwise from `y`.
235    /// All three arrays are broadcast to a common shape.
236    ///
237    /// # Arguments
238    ///
239    /// * `condition` - Boolean array (non-zero = true, zero = false)
240    /// * `x` - Array of values to select when condition is true
241    /// * `y` - Array of values to select when condition is false
242    ///
243    /// # Examples
244    ///
245    /// ```
246    /// # use jax_rs::{Array, Shape};
247    /// let condition = Array::from_vec(vec![1.0, 0.0, 1.0], Shape::new(vec![3]));
248    /// let x = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
249    /// let y = Array::from_vec(vec![5.0, 5.0, 5.0], Shape::new(vec![3]));
250    /// let result = Array::where_cond(&condition, &x, &y);
251    /// assert_eq!(result.to_vec(), vec![10.0, 5.0, 30.0]);
252    /// ```
253    pub fn where_cond(condition: &Array, x: &Array, y: &Array) -> Array {
254        assert_eq!(condition.dtype(), DType::Float32, "Only Float32 supported");
255        assert_eq!(x.dtype(), DType::Float32, "Only Float32 supported");
256        assert_eq!(y.dtype(), DType::Float32, "Only Float32 supported");
257
258        // Compute common broadcast shape for all three arrays
259        let shape1 = condition
260            .shape()
261            .broadcast_with(x.shape())
262            .expect("Condition and x shapes are not broadcast-compatible");
263        let result_shape = shape1
264            .broadcast_with(y.shape())
265            .expect("Cannot broadcast all three arrays to common shape");
266
267        let cond_data = condition.to_vec();
268        let x_data = x.to_vec();
269        let y_data = y.to_vec();
270
271        // Fast path: all arrays have the same shape (no broadcasting needed)
272        if condition.shape() == x.shape()
273            && x.shape() == y.shape()
274            && condition.shape() == &result_shape
275        {
276            let result_data: Vec<f32> = cond_data
277                .iter()
278                .zip(x_data.iter().zip(y_data.iter()))
279                .map(|(&c, (&x_val, &y_val))| if c != 0.0 { x_val } else { y_val })
280                .collect();
281            return Array::from_vec(result_data, result_shape);
282        }
283
284        // Slow path: need broadcasting
285        let size = result_shape.size();
286        let result_data: Vec<f32> = (0..size)
287            .map(|i| {
288                let cond_idx =
289                    crate::ops::binary::broadcast_index(i, &result_shape, condition.shape());
290                let x_idx = crate::ops::binary::broadcast_index(i, &result_shape, x.shape());
291                let y_idx = crate::ops::binary::broadcast_index(i, &result_shape, y.shape());
292
293                if cond_data[cond_idx] != 0.0 {
294                    x_data[x_idx]
295                } else {
296                    y_data[y_idx]
297                }
298            })
299            .collect();
300
301        Array::from_vec(result_data, result_shape)
302    }
303
304    /// Select values from multiple choice arrays based on index array.
305    ///
306    /// For each element in `indices`, selects the corresponding element from
307    /// the choice array at that index. Similar to a multi-way switch statement.
308    ///
309    /// # Arguments
310    ///
311    /// * `indices` - Array of integer indices (as f32) specifying which choice to pick
312    /// * `choices` - Slice of arrays to choose from
313    ///
314    /// # Examples
315    ///
316    /// ```
317    /// # use jax_rs::{Array, Shape};
318    /// let indices = Array::from_vec(vec![0.0, 1.0, 2.0, 1.0], Shape::new(vec![4]));
319    /// let choice0 = Array::from_vec(vec![10.0, 10.0, 10.0, 10.0], Shape::new(vec![4]));
320    /// let choice1 = Array::from_vec(vec![20.0, 20.0, 20.0, 20.0], Shape::new(vec![4]));
321    /// let choice2 = Array::from_vec(vec![30.0, 30.0, 30.0, 30.0], Shape::new(vec![4]));
322    /// let result = Array::select(&indices, &[choice0, choice1, choice2]);
323    /// assert_eq!(result.to_vec(), vec![10.0, 20.0, 30.0, 20.0]);
324    /// ```
325    pub fn select(indices: &Array, choices: &[Array]) -> Array {
326        assert_eq!(indices.dtype(), DType::Float32, "Only Float32 supported");
327        assert!(!choices.is_empty(), "Must provide at least one choice");
328
329        // All choices must have the same shape
330        let choice_shape = choices[0].shape();
331        for choice in choices.iter().skip(1) {
332            assert_eq!(
333                choice.dtype(),
334                DType::Float32,
335                "Only Float32 supported for choices"
336            );
337            assert_eq!(
338                choice.shape(),
339                choice_shape,
340                "All choices must have the same shape"
341            );
342        }
343
344        // Indices and choices must have compatible shapes
345        assert_eq!(
346            indices.shape(),
347            choice_shape,
348            "Indices and choices must have the same shape"
349        );
350
351        let indices_data = indices.to_vec();
352        let choice_data: Vec<Vec<f32>> = choices.iter().map(|c| c.to_vec()).collect();
353
354        let result_data: Vec<f32> = indices_data
355            .iter()
356            .enumerate()
357            .map(|(i, &idx)| {
358                let idx_int = idx as usize;
359                assert!(
360                    idx_int < choices.len(),
361                    "Index {} out of bounds for {} choices",
362                    idx_int,
363                    choices.len()
364                );
365                choice_data[idx_int][i]
366            })
367            .collect();
368
369        Array::from_vec(result_data, choice_shape.clone())
370    }
371
372    /// Clip (limit) values in an array.
373    ///
374    /// # Examples
375    ///
376    /// ```
377    /// # use jax_rs::{Array, Shape};
378    /// let a = Array::from_vec(vec![1.0, 5.0, 10.0, 15.0], Shape::new(vec![4]));
379    /// let clipped = a.clip(3.0, 12.0);
380    /// assert_eq!(clipped.to_vec(), vec![3.0, 5.0, 10.0, 12.0]);
381    /// ```
382    pub fn clip(&self, min: f32, max: f32) -> Array {
383        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
384
385        let data = self.to_vec();
386        let result_data: Vec<f32> =
387            data.iter().map(|&x| x.clamp(min, max)).collect();
388
389        Array::from_vec(result_data, self.shape().clone())
390    }
391
392    /// Flip array along specified axis.
393    ///
394    /// # Examples
395    ///
396    /// ```
397    /// # use jax_rs::{Array, Shape};
398    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
399    /// let flipped = a.flip(0);
400    /// assert_eq!(flipped.to_vec(), vec![3.0, 2.0, 1.0]);
401    /// ```
402    pub fn flip(&self, axis: usize) -> Array {
403        assert!(axis < self.ndim(), "Axis out of bounds");
404        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
405
406        let shape = self.shape();
407        let dims = shape.as_slice();
408        let data = self.to_vec();
409
410        // Simple implementation for 1D arrays
411        if self.ndim() == 1 {
412            let mut result: Vec<f32> = data.clone();
413            result.reverse();
414            return Array::from_vec(result, shape.clone());
415        }
416
417        // For multi-dimensional, only support flipping axis 0 for now
418        if axis == 0 {
419            let slice_size = data.len() / dims[0];
420            let mut result = Vec::with_capacity(data.len());
421
422            for i in (0..dims[0]).rev() {
423                let start = i * slice_size;
424                let end = start + slice_size;
425                result.extend_from_slice(&data[start..end]);
426            }
427
428            Array::from_vec(result, shape.clone())
429        } else {
430            panic!("flip only supports axis=0 for multi-dimensional arrays");
431        }
432    }
433
434    /// Pad array with a constant value.
435    ///
436    /// # Arguments
437    ///
438    /// * `pad_width` - Number of values to pad on each side: [(before, after), ...]
439    /// * `constant_value` - Value to use for padding
440    ///
441    /// # Examples
442    ///
443    /// ```
444    /// # use jax_rs::{Array, Shape};
445    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
446    /// let padded = a.pad(&[(1, 1)], 0.0);
447    /// assert_eq!(padded.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 0.0]);
448    /// ```
449    pub fn pad(&self, pad_width: &[(usize, usize)], constant_value: f32) -> Array {
450        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
451        assert_eq!(
452            pad_width.len(),
453            self.ndim(),
454            "pad_width must match number of dimensions"
455        );
456
457        let shape = self.shape().as_slice();
458        let data = self.to_vec();
459
460        // Compute output shape
461        let mut out_shape = Vec::with_capacity(shape.len());
462        for (i, &dim) in shape.iter().enumerate() {
463            out_shape.push(pad_width[i].0 + dim + pad_width[i].1);
464        }
465
466        // For 1D case
467        if self.ndim() == 1 {
468            let (before, after) = pad_width[0];
469            let mut result = vec![constant_value; before];
470            result.extend_from_slice(&data);
471            result.extend(vec![constant_value; after]);
472            return Array::from_vec(result, Shape::new(out_shape));
473        }
474
475        // For 2D case
476        if self.ndim() == 2 {
477            let (h, w) = (shape[0], shape[1]);
478            let (h_before, _) = pad_width[0];
479            let (w_before, _) = pad_width[1];
480
481            let out_h = out_shape[0];
482            let out_w = out_shape[1];
483            let mut result = vec![constant_value; out_h * out_w];
484
485            for i in 0..h {
486                for j in 0..w {
487                    let out_i = i + h_before;
488                    let out_j = j + w_before;
489                    result[out_i * out_w + out_j] = data[i * w + j];
490                }
491            }
492
493            return Array::from_vec(result, Shape::new(out_shape));
494        }
495
496        panic!("pad only supports 1D and 2D arrays for now");
497    }
498
499    /// Pad array with edge values (repeat border elements).
500    ///
501    /// # Examples
502    ///
503    /// ```
504    /// # use jax_rs::{Array, Shape};
505    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
506    /// let padded = a.pad_edge(&[(1, 1)]);
507    /// assert_eq!(padded.to_vec(), vec![1.0, 1.0, 2.0, 3.0, 3.0]);
508    /// ```
509    pub fn pad_edge(&self, pad_width: &[(usize, usize)]) -> Array {
510        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
511        assert_eq!(
512            pad_width.len(),
513            self.ndim(),
514            "pad_width must match number of dimensions"
515        );
516
517        let shape = self.shape().as_slice();
518        let data = self.to_vec();
519
520        // For 1D case
521        if self.ndim() == 1 {
522            let (before, after) = pad_width[0];
523            let mut result = vec![data[0]; before];
524            result.extend_from_slice(&data);
525            result.extend(vec![data[data.len() - 1]; after]);
526
527            let out_len = before + shape[0] + after;
528            return Array::from_vec(result, Shape::new(vec![out_len]));
529        }
530
531        // For 2D case
532        if self.ndim() == 2 {
533            let (h, w) = (shape[0], shape[1]);
534            let (h_before, h_after) = pad_width[0];
535            let (w_before, w_after) = pad_width[1];
536
537            let out_h = h_before + h + h_after;
538            let out_w = w_before + w + w_after;
539            let mut result = vec![0.0; out_h * out_w];
540
541            for out_i in 0..out_h {
542                for out_j in 0..out_w {
543                    // Map output indices to input indices
544                    let in_i = if out_i < h_before {
545                        0
546                    } else if out_i >= h_before + h {
547                        h - 1
548                    } else {
549                        out_i - h_before
550                    };
551
552                    let in_j = if out_j < w_before {
553                        0
554                    } else if out_j >= w_before + w {
555                        w - 1
556                    } else {
557                        out_j - w_before
558                    };
559
560                    result[out_i * out_w + out_j] = data[in_i * w + in_j];
561                }
562            }
563
564            return Array::from_vec(result, Shape::new(vec![out_h, out_w]));
565        }
566
567        panic!("pad_edge only supports 1D and 2D arrays for now");
568    }
569
570    /// Pad array with reflected values (mirror border elements).
571    ///
572    /// # Examples
573    ///
574    /// ```
575    /// # use jax_rs::{Array, Shape};
576    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
577    /// let padded = a.pad_reflect(&[(1, 1)]);
578    /// assert_eq!(padded.to_vec(), vec![2.0, 1.0, 2.0, 3.0, 2.0]);
579    /// ```
580    pub fn pad_reflect(&self, pad_width: &[(usize, usize)]) -> Array {
581        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
582        assert_eq!(
583            pad_width.len(),
584            self.ndim(),
585            "pad_width must match number of dimensions"
586        );
587
588        let shape = self.shape().as_slice();
589        let data = self.to_vec();
590
591        // For 1D case
592        if self.ndim() == 1 {
593            let len = shape[0];
594            let (before, after) = pad_width[0];
595
596            assert!(
597                before < len && after < len,
598                "Padding width must be less than array size for reflect mode"
599            );
600
601            let mut result = Vec::with_capacity(before + len + after);
602
603            // Left padding (reflect)
604            for i in 0..before {
605                result.push(data[before - i]);
606            }
607
608            // Original data
609            result.extend_from_slice(&data);
610
611            // Right padding (reflect)
612            for i in 0..after {
613                result.push(data[len - 2 - i]);
614            }
615
616            let out_len = before + len + after;
617            return Array::from_vec(result, Shape::new(vec![out_len]));
618        }
619
620        // For 2D case
621        if self.ndim() == 2 {
622            let (h, w) = (shape[0], shape[1]);
623            let (h_before, h_after) = pad_width[0];
624            let (w_before, w_after) = pad_width[1];
625
626            assert!(
627                h_before < h && h_after < h && w_before < w && w_after < w,
628                "Padding width must be less than array size for reflect mode"
629            );
630
631            let out_h = h_before + h + h_after;
632            let out_w = w_before + w + w_after;
633            let mut result = vec![0.0; out_h * out_w];
634
635            for out_i in 0..out_h {
636                for out_j in 0..out_w {
637                    // Map output indices to input indices with reflection
638                    let in_i = if out_i < h_before {
639                        h_before - out_i
640                    } else if out_i >= h_before + h {
641                        h - 2 - (out_i - h_before - h)
642                    } else {
643                        out_i - h_before
644                    };
645
646                    let in_j = if out_j < w_before {
647                        w_before - out_j
648                    } else if out_j >= w_before + w {
649                        w - 2 - (out_j - w_before - w)
650                    } else {
651                        out_j - w_before
652                    };
653
654                    result[out_i * out_w + out_j] = data[in_i * w + in_j];
655                }
656            }
657
658            return Array::from_vec(result, Shape::new(vec![out_h, out_w]));
659        }
660
661        panic!("pad_reflect only supports 1D and 2D arrays for now");
662    }
663
664    /// Replace NaN and infinity values with specified numbers.
665    ///
666    /// # Arguments
667    ///
668    /// * `nan` - Value to replace NaN with (default 0.0)
669    /// * `posinf` - Value to replace positive infinity with (default large positive value)
670    /// * `neginf` - Value to replace negative infinity with (default large negative value)
671    ///
672    /// # Examples
673    ///
674    /// ```
675    /// # use jax_rs::{Array, Shape};
676    /// let a = Array::from_vec(vec![1.0, f32::NAN, f32::INFINITY, -f32::INFINITY], Shape::new(vec![4]));
677    /// let result = a.nan_to_num(0.0, 1e10, -1e10);
678    /// assert_eq!(result.to_vec()[0], 1.0);
679    /// assert_eq!(result.to_vec()[1], 0.0);
680    /// assert_eq!(result.to_vec()[2], 1e10);
681    /// assert_eq!(result.to_vec()[3], -1e10);
682    /// ```
683    pub fn nan_to_num(&self, nan: f32, posinf: f32, neginf: f32) -> Array {
684        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
685        let data = self.to_vec();
686        let result: Vec<f32> = data
687            .iter()
688            .map(|&x| {
689                if x.is_nan() {
690                    nan
691                } else if x.is_infinite() && x > 0.0 {
692                    posinf
693                } else if x.is_infinite() && x < 0.0 {
694                    neginf
695                } else {
696                    x
697                }
698            })
699            .collect();
700        Array::from_vec(result, self.shape().clone())
701    }
702
703    /// Check for NaN values element-wise.
704    ///
705    /// Returns an array with 1.0 where NaN, 0.0 otherwise.
706    ///
707    /// # Examples
708    ///
709    /// ```
710    /// # use jax_rs::{Array, Shape};
711    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, f32::NAN], Shape::new(vec![4]));
712    /// let result = a.isnan();
713    /// assert_eq!(result.to_vec(), vec![0.0, 1.0, 0.0, 1.0]);
714    /// ```
715    pub fn isnan(&self) -> Array {
716        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
717        let data = self.to_vec();
718        let result: Vec<f32> = data
719            .iter()
720            .map(|&x| if x.is_nan() { 1.0 } else { 0.0 })
721            .collect();
722        Array::from_vec(result, self.shape().clone())
723    }
724
725    /// Check for infinity values element-wise.
726    ///
727    /// Returns an array with 1.0 where infinity (positive or negative), 0.0 otherwise.
728    ///
729    /// # Examples
730    ///
731    /// ```
732    /// # use jax_rs::{Array, Shape};
733    /// let a = Array::from_vec(vec![1.0, f32::INFINITY, -f32::INFINITY, 3.0], Shape::new(vec![4]));
734    /// let result = a.isinf();
735    /// assert_eq!(result.to_vec(), vec![0.0, 1.0, 1.0, 0.0]);
736    /// ```
737    pub fn isinf(&self) -> Array {
738        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
739        let data = self.to_vec();
740        let result: Vec<f32> = data
741            .iter()
742            .map(|&x| if x.is_infinite() { 1.0 } else { 0.0 })
743            .collect();
744        Array::from_vec(result, self.shape().clone())
745    }
746
747    /// Check for finite values element-wise.
748    ///
749    /// Returns an array with 1.0 where finite, 0.0 otherwise (NaN or infinity).
750    ///
751    /// # Examples
752    ///
753    /// ```
754    /// # use jax_rs::{Array, Shape};
755    /// let a = Array::from_vec(vec![1.0, f32::NAN, f32::INFINITY, 3.0], Shape::new(vec![4]));
756    /// let result = a.isfinite();
757    /// assert_eq!(result.to_vec(), vec![1.0, 0.0, 0.0, 1.0]);
758    /// ```
759    pub fn isfinite(&self) -> Array {
760        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
761        let data = self.to_vec();
762        let result: Vec<f32> = data
763            .iter()
764            .map(|&x| if x.is_finite() { 1.0 } else { 0.0 })
765            .collect();
766        Array::from_vec(result, self.shape().clone())
767    }
768
769    /// Clip array values by L2 norm.
770    ///
771    /// If the L2 norm exceeds max_norm, scales the array down to have that norm.
772    /// Useful for gradient clipping in neural networks.
773    ///
774    /// # Examples
775    ///
776    /// ```
777    /// # use jax_rs::{Array, Shape};
778    /// let a = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
779    /// let clipped = a.clip_by_norm(2.0);
780    /// // Original norm is 5.0, should be scaled to 2.0
781    /// let result = clipped.to_vec();
782    /// assert!((result[0] - 1.2).abs() < 1e-5);
783    /// assert!((result[1] - 1.6).abs() < 1e-5);
784    /// ```
785    pub fn clip_by_norm(&self, max_norm: f32) -> Array {
786        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
787        let data = self.to_vec();
788
789        // Compute L2 norm
790        let norm: f32 = data.iter().map(|&x| x * x).sum::<f32>().sqrt();
791
792        if norm <= max_norm {
793            return self.clone();
794        }
795
796        // Scale down to max_norm
797        let scale = max_norm / norm;
798        let result: Vec<f32> = data.iter().map(|&x| x * scale).collect();
799        Array::from_vec(result, self.shape().clone())
800    }
801
802    /// Flatten array to 1D.
803    ///
804    /// Returns a 1D array containing all elements in row-major order.
805    ///
806    /// # Examples
807    ///
808    /// ```
809    /// # use jax_rs::{Array, Shape};
810    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
811    /// let flat = a.ravel();
812    /// assert_eq!(flat.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
813    /// assert_eq!(flat.shape().as_slice(), &[4]);
814    /// ```
815    pub fn ravel(&self) -> Array {
816        let size = self.size();
817        Array::from_vec(self.to_vec(), Shape::new(vec![size]))
818    }
819
820    /// Flatten array to 1D (alias for ravel).
821    ///
822    /// Returns a 1D array containing all elements in row-major order.
823    ///
824    /// # Examples
825    ///
826    /// ```
827    /// # use jax_rs::{Array, Shape};
828    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
829    /// let flat = a.flatten();
830    /// assert_eq!(flat.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
831    /// ```
832    pub fn flatten(&self) -> Array {
833        self.ravel()
834    }
835
836    /// View array with at least 1D.
837    ///
838    /// Scalar (0D) arrays are converted to 1D arrays with shape [1].
839    ///
840    /// # Examples
841    ///
842    /// ```
843    /// # use jax_rs::{Array, Shape};
844    /// let a = Array::from_vec(vec![5.0], Shape::new(vec![]));
845    /// let b = a.atleast_1d();
846    /// assert_eq!(b.shape().as_slice(), &[1]);
847    /// ```
848    pub fn atleast_1d(&self) -> Array {
849        if self.shape().ndim() == 0 {
850            Array::from_vec(self.to_vec(), Shape::new(vec![1]))
851        } else {
852            self.clone()
853        }
854    }
855
856    /// View array with at least 2D.
857    ///
858    /// Arrays with fewer than 2 dimensions are expanded.
859    ///
860    /// # Examples
861    ///
862    /// ```
863    /// # use jax_rs::{Array, Shape};
864    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
865    /// let b = a.atleast_2d();
866    /// assert_eq!(b.shape().as_slice(), &[1, 3]);
867    /// ```
868    pub fn atleast_2d(&self) -> Array {
869        match self.shape().ndim() {
870            0 => Array::from_vec(self.to_vec(), Shape::new(vec![1, 1])),
871            1 => {
872                let n = self.shape().as_slice()[0];
873                Array::from_vec(self.to_vec(), Shape::new(vec![1, n]))
874            }
875            _ => self.clone(),
876        }
877    }
878
879    /// View array with at least 3D.
880    ///
881    /// Arrays with fewer than 3 dimensions are expanded.
882    ///
883    /// # Examples
884    ///
885    /// ```
886    /// # use jax_rs::{Array, Shape};
887    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
888    /// let b = a.atleast_3d();
889    /// assert_eq!(b.shape().as_slice(), &[1, 2, 1]);
890    /// ```
891    pub fn atleast_3d(&self) -> Array {
892        match self.shape().ndim() {
893            0 => Array::from_vec(self.to_vec(), Shape::new(vec![1, 1, 1])),
894            1 => {
895                let n = self.shape().as_slice()[0];
896                Array::from_vec(self.to_vec(), Shape::new(vec![1, n, 1]))
897            }
898            2 => {
899                let dims = self.shape().as_slice();
900                Array::from_vec(self.to_vec(), Shape::new(vec![dims[0], dims[1], 1]))
901            }
902            _ => self.clone(),
903        }
904    }
905
906    /// Broadcast array to a new shape.
907    ///
908    /// The new shape must be broadcast-compatible with the current shape.
909    ///
910    /// # Examples
911    ///
912    /// ```
913    /// # use jax_rs::{Array, Shape};
914    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
915    /// let b = a.broadcast_to(Shape::new(vec![2, 3]));
916    /// assert_eq!(b.shape().as_slice(), &[2, 3]);
917    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
918    /// ```
919    pub fn broadcast_to(&self, new_shape: Shape) -> Array {
920        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
921
922        // Check broadcast compatibility with better error message
923        let _result_shape = self
924            .shape()
925            .broadcast_with(&new_shape)
926            .unwrap_or_else(|| {
927                panic!(
928                    "Cannot broadcast array of shape {:?} to shape {:?}. \
929                     Broadcasting requires dimensions to be equal or one of them to be 1.",
930                    self.shape().as_slice(),
931                    new_shape.as_slice()
932                )
933            });
934
935        let data = self.to_vec();
936        let size = new_shape.size();
937        let mut result = Vec::with_capacity(size);
938
939        for i in 0..size {
940            let src_idx =
941                crate::ops::binary::broadcast_index(i, &new_shape, self.shape());
942            result.push(data[src_idx]);
943        }
944
945        Array::from_vec(result, new_shape)
946    }
947
948    /// Broadcast multiple arrays to a common shape.
949    ///
950    /// All arrays are broadcast to a shape that is compatible with all inputs.
951    /// The result shape is determined by the broadcast rules applied successively.
952    ///
953    /// # Arguments
954    ///
955    /// * `arrays` - Slice of arrays to broadcast
956    ///
957    /// # Returns
958    ///
959    /// Vector of arrays, all broadcast to the same shape
960    ///
961    /// # Examples
962    ///
963    /// ```
964    /// # use jax_rs::{Array, Shape};
965    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
966    /// let b = Array::from_vec(vec![10.0, 20.0], Shape::new(vec![2, 1]));
967    /// let broadcasted = Array::broadcast_arrays(&[a, b]);
968    ///
969    /// // Both should have shape [2, 3]
970    /// assert_eq!(broadcasted[0].shape().as_slice(), &[2, 3]);
971    /// assert_eq!(broadcasted[1].shape().as_slice(), &[2, 3]);
972    /// ```
973    pub fn broadcast_arrays(arrays: &[Array]) -> Vec<Array> {
974        if arrays.is_empty() {
975            return vec![];
976        }
977
978        if arrays.len() == 1 {
979            return vec![arrays[0].clone()];
980        }
981
982        // Find the common broadcast shape
983        let mut common_shape = arrays[0].shape().clone();
984
985        for array in &arrays[1..] {
986            common_shape = common_shape
987                .broadcast_with(array.shape())
988                .unwrap_or_else(|| {
989                    panic!(
990                        "Cannot broadcast arrays with shapes {:?} and {:?}. \
991                         Broadcasting requires dimensions to be equal or one of them to be 1.",
992                        common_shape.as_slice(),
993                        array.shape().as_slice()
994                    )
995                });
996        }
997
998        // Broadcast all arrays to the common shape
999        arrays
1000            .iter()
1001            .map(|arr| arr.broadcast_to(common_shape.clone()))
1002            .collect()
1003    }
1004
1005    /// Take elements from array along an axis at specified indices.
1006    ///
1007    /// # Examples
1008    ///
1009    /// ```
1010    /// # use jax_rs::{Array, Shape};
1011    /// let a = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0], Shape::new(vec![4]));
1012    /// let indices = vec![0, 2, 3];
1013    /// let result = a.take(&indices);
1014    /// assert_eq!(result.to_vec(), vec![10.0, 30.0, 40.0]);
1015    /// ```
1016    pub fn take(&self, indices: &[usize]) -> Array {
1017        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1018        let data = self.to_vec();
1019
1020        let result: Vec<f32> = indices
1021            .iter()
1022            .map(|&idx| {
1023                assert!(idx < data.len(), "Index {} out of bounds", idx);
1024                data[idx]
1025            })
1026            .collect();
1027
1028        let len = result.len();
1029        Array::from_vec(result, Shape::new(vec![len]))
1030    }
1031
1032    /// Put values into an array at specified indices.
1033    ///
1034    /// Replaces elements at the given indices with the provided values.
1035    /// Returns a new array with the modifications.
1036    ///
1037    /// # Arguments
1038    ///
1039    /// * `indices` - Flat indices where values should be placed
1040    /// * `values` - Values to place at those indices
1041    ///
1042    /// # Examples
1043    ///
1044    /// ```
1045    /// # use jax_rs::{Array, Shape};
1046    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
1047    /// let result = a.put(&[0, 2, 4], &[10.0, 30.0, 50.0]);
1048    /// assert_eq!(result.to_vec(), vec![10.0, 2.0, 30.0, 4.0, 50.0]);
1049    /// ```
1050    pub fn put(&self, indices: &[usize], values: &[f32]) -> Array {
1051        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1052        assert_eq!(
1053            indices.len(),
1054            values.len(),
1055            "Number of indices must match number of values"
1056        );
1057
1058        let mut data = self.to_vec();
1059
1060        for (i, &idx) in indices.iter().enumerate() {
1061            assert!(idx < data.len(), "Index {} out of bounds", idx);
1062            data[idx] = values[i];
1063        }
1064
1065        Array::from_vec(data, self.shape().clone())
1066    }
1067
1068    /// Scatter update values into an array at specified indices.
1069    ///
1070    /// Returns a new array with values from `updates` placed at positions
1071    /// specified by `indices`. This is equivalent to `put()` but follows
1072    /// the JAX/NumPy scatter naming convention.
1073    ///
1074    /// # Arguments
1075    ///
1076    /// * `indices` - Flattened indices where updates should be placed
1077    /// * `updates` - Values to place at the specified indices
1078    ///
1079    /// # Examples
1080    ///
1081    /// ```
1082    /// # use jax_rs::{Array, Shape};
1083    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
1084    /// let result = a.scatter(&[0, 2, 4], &[10.0, 30.0, 50.0]);
1085    /// assert_eq!(result.to_vec(), vec![10.0, 2.0, 30.0, 4.0, 50.0]);
1086    /// ```
1087    pub fn scatter(&self, indices: &[usize], updates: &[f32]) -> Array {
1088        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1089        assert_eq!(
1090            indices.len(),
1091            updates.len(),
1092            "Number of indices must match number of updates"
1093        );
1094
1095        let mut data = self.to_vec();
1096
1097        for (i, &idx) in indices.iter().enumerate() {
1098            assert!(idx < data.len(), "Index {} out of bounds", idx);
1099            data[idx] = updates[i];
1100        }
1101
1102        Array::from_vec(data, self.shape().clone())
1103    }
1104
1105    /// Scatter-add values into an array at specified indices.
1106    ///
1107    /// Returns a new array with values from `updates` added to the values
1108    /// at positions specified by `indices`. If the same index appears multiple
1109    /// times, updates are accumulated.
1110    ///
1111    /// # Arguments
1112    ///
1113    /// * `indices` - Flattened indices where updates should be added
1114    /// * `updates` - Values to add at the specified indices
1115    ///
1116    /// # Examples
1117    ///
1118    /// ```
1119    /// # use jax_rs::{Array, Shape};
1120    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
1121    /// let result = a.scatter_add(&[0, 2, 4], &[10.0, 30.0, 50.0]);
1122    /// assert_eq!(result.to_vec(), vec![11.0, 2.0, 33.0, 4.0, 55.0]);
1123    /// ```
1124    pub fn scatter_add(&self, indices: &[usize], updates: &[f32]) -> Array {
1125        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1126        assert_eq!(
1127            indices.len(),
1128            updates.len(),
1129            "Number of indices must match number of updates"
1130        );
1131
1132        let mut data = self.to_vec();
1133
1134        for (i, &idx) in indices.iter().enumerate() {
1135            assert!(idx < data.len(), "Index {} out of bounds", idx);
1136            data[idx] += updates[i];
1137        }
1138
1139        Array::from_vec(data, self.shape().clone())
1140    }
1141
1142    /// Scatter-min values into an array at specified indices.
1143    ///
1144    /// Returns a new array where each position specified by `indices` contains
1145    /// the minimum of the original value and the corresponding update value.
1146    ///
1147    /// # Arguments
1148    ///
1149    /// * `indices` - Flattened indices where min operation should be applied
1150    /// * `updates` - Values to compare with current values
1151    ///
1152    /// # Examples
1153    ///
1154    /// ```
1155    /// # use jax_rs::{Array, Shape};
1156    /// let a = Array::from_vec(vec![5.0, 10.0, 15.0, 20.0, 25.0], Shape::new(vec![5]));
1157    /// let result = a.scatter_min(&[1, 2, 3], &[8.0, 20.0, 15.0]);
1158    /// assert_eq!(result.to_vec(), vec![5.0, 8.0, 15.0, 15.0, 25.0]);
1159    /// ```
1160    pub fn scatter_min(&self, indices: &[usize], updates: &[f32]) -> Array {
1161        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1162        assert_eq!(
1163            indices.len(),
1164            updates.len(),
1165            "Number of indices must match number of updates"
1166        );
1167
1168        let mut data = self.to_vec();
1169
1170        for (i, &idx) in indices.iter().enumerate() {
1171            assert!(idx < data.len(), "Index {} out of bounds", idx);
1172            data[idx] = data[idx].min(updates[i]);
1173        }
1174
1175        Array::from_vec(data, self.shape().clone())
1176    }
1177
1178    /// Scatter-max values into an array at specified indices.
1179    ///
1180    /// Returns a new array where each position specified by `indices` contains
1181    /// the maximum of the original value and the corresponding update value.
1182    ///
1183    /// # Arguments
1184    ///
1185    /// * `indices` - Flattened indices where max operation should be applied
1186    /// * `updates` - Values to compare with current values
1187    ///
1188    /// # Examples
1189    ///
1190    /// ```
1191    /// # use jax_rs::{Array, Shape};
1192    /// let a = Array::from_vec(vec![5.0, 10.0, 15.0, 20.0, 25.0], Shape::new(vec![5]));
1193    /// let result = a.scatter_max(&[1, 2, 3], &[12.0, 10.0, 25.0]);
1194    /// assert_eq!(result.to_vec(), vec![5.0, 12.0, 15.0, 25.0, 25.0]);
1195    /// ```
1196    pub fn scatter_max(&self, indices: &[usize], updates: &[f32]) -> Array {
1197        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1198        assert_eq!(
1199            indices.len(),
1200            updates.len(),
1201            "Number of indices must match number of updates"
1202        );
1203
1204        let mut data = self.to_vec();
1205
1206        for (i, &idx) in indices.iter().enumerate() {
1207            assert!(idx < data.len(), "Index {} out of bounds", idx);
1208            data[idx] = data[idx].max(updates[i]);
1209        }
1210
1211        Array::from_vec(data, self.shape().clone())
1212    }
1213
1214    /// Scatter updates to specified indices using multiplication.
1215    ///
1216    /// For each index, multiplies the existing value with the update value.
1217    /// When multiple updates target the same index, they are accumulated.
1218    ///
1219    /// # Arguments
1220    ///
1221    /// * `indices` - Indices where updates should be applied
1222    /// * `updates` - Values to multiply at the corresponding indices
1223    ///
1224    /// # Examples
1225    ///
1226    /// ```
1227    /// # use jax_rs::{Array, Shape};
1228    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
1229    /// let result = a.scatter_mul(&[1, 2, 3], &[2.0, 3.0, 0.5]);
1230    /// assert_eq!(result.to_vec(), vec![1.0, 4.0, 9.0, 2.0, 5.0]);
1231    /// ```
1232    pub fn scatter_mul(&self, indices: &[usize], updates: &[f32]) -> Array {
1233        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1234        assert_eq!(
1235            indices.len(),
1236            updates.len(),
1237            "Number of indices must match number of updates"
1238        );
1239
1240        let mut data = self.to_vec();
1241
1242        for (i, &idx) in indices.iter().enumerate() {
1243            assert!(idx < data.len(), "Index {} out of bounds", idx);
1244            data[idx] *= updates[i];
1245        }
1246
1247        Array::from_vec(data, self.shape().clone())
1248    }
1249
1250    /// Take values from an array along an axis using indices.
1251    ///
1252    /// This is similar to gather operations in other frameworks.
1253    /// For each position, it selects the element specified by the index array.
1254    ///
1255    /// # Arguments
1256    ///
1257    /// * `indices` - Array of indices to take along the axis
1258    /// * `axis` - Axis along which to take values
1259    ///
1260    /// # Examples
1261    ///
1262    /// ```
1263    /// # use jax_rs::{Array, Shape};
1264    /// // For a 2D array, select different columns for each row
1265    /// let a = Array::from_vec(
1266    ///     vec![10.0, 20.0, 30.0,
1267    ///          40.0, 50.0, 60.0],
1268    ///     Shape::new(vec![2, 3])
1269    /// );
1270    /// let indices = Array::from_vec(vec![0.0, 2.0], Shape::new(vec![2]));
1271    /// let result = a.take_along_axis(&indices, 1);
1272    /// // Takes column 0 from row 0 (10.0) and column 2 from row 1 (60.0)
1273    /// assert_eq!(result.to_vec(), vec![10.0, 60.0]);
1274    /// ```
1275    pub fn take_along_axis(&self, indices: &Array, axis: usize) -> Array {
1276        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1277        assert_eq!(indices.dtype(), DType::Float32, "Indices must be Float32");
1278        assert!(axis < self.ndim(), "Axis {} out of bounds", axis);
1279
1280        let data = self.to_vec();
1281        let idx_data = indices.to_vec();
1282        let shape = self.shape().as_slice();
1283
1284        // Handle 1D indices for 2D array (backward compatible simplified API)
1285        if indices.ndim() == 1 && self.ndim() == 2 {
1286            if axis == 1 {
1287                // Take along columns: for each row, pick the column from indices
1288                let rows = shape[0];
1289                let cols = shape[1];
1290                assert_eq!(
1291                    indices.size(),
1292                    rows,
1293                    "Indices size must match number of rows"
1294                );
1295                let result: Vec<f32> = idx_data
1296                    .iter()
1297                    .enumerate()
1298                    .map(|(row, &idx)| {
1299                        let col = idx as usize;
1300                        assert!(col < cols, "Column index {} out of bounds", col);
1301                        data[row * cols + col]
1302                    })
1303                    .collect();
1304                return Array::from_vec(result, Shape::new(vec![rows]));
1305            } else {
1306                // axis == 0: Take along rows: for each column, pick the row from indices
1307                let rows = shape[0];
1308                let cols = shape[1];
1309                assert_eq!(
1310                    indices.size(),
1311                    cols,
1312                    "Indices size must match number of columns"
1313                );
1314                let result: Vec<f32> = idx_data
1315                    .iter()
1316                    .enumerate()
1317                    .map(|(col, &idx)| {
1318                        let row = idx as usize;
1319                        assert!(row < rows, "Row index {} out of bounds", row);
1320                        data[row * cols + col]
1321                    })
1322                    .collect();
1323                return Array::from_vec(result, Shape::new(vec![cols]));
1324            }
1325        }
1326
1327        // Handle 1D case
1328        if self.ndim() == 1 {
1329            let result: Vec<f32> = idx_data
1330                .iter()
1331                .map(|&idx| {
1332                    let i = idx as usize;
1333                    assert!(i < data.len(), "Index {} out of bounds", i);
1334                    data[i]
1335                })
1336                .collect();
1337            return Array::from_vec(result, indices.shape().clone());
1338        }
1339
1340        // General N-dimensional implementation (requires matching dimensions)
1341        let idx_shape = indices.shape().as_slice();
1342        assert_eq!(
1343            self.ndim(),
1344            indices.ndim(),
1345            "For N-dimensional take_along_axis, array and indices must have same number of dimensions"
1346        );
1347        for (i, (&s, &is)) in shape.iter().zip(idx_shape.iter()).enumerate() {
1348            if i != axis {
1349                assert_eq!(
1350                    s, is,
1351                    "Dimension {} must match: array has {}, indices has {}",
1352                    i, s, is
1353                );
1354            }
1355        }
1356
1357        let ndim = self.ndim();
1358        let out_size = indices.size();
1359        let mut result = vec![0.0f32; out_size];
1360
1361        // Compute strides for input array
1362        let mut strides = vec![1usize; ndim];
1363        for i in (0..ndim - 1).rev() {
1364            strides[i] = strides[i + 1] * shape[i + 1];
1365        }
1366
1367        // Compute strides for indices array
1368        let mut idx_strides = vec![1usize; ndim];
1369        for i in (0..ndim - 1).rev() {
1370            idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
1371        }
1372
1373        // Iterate over all output positions
1374        for out_flat in 0..out_size {
1375            // Convert flat index to multi-dimensional index in output/indices space
1376            let mut multi_idx = vec![0usize; ndim];
1377            let mut remaining = out_flat;
1378            for i in 0..ndim {
1379                multi_idx[i] = remaining / idx_strides[i];
1380                remaining %= idx_strides[i];
1381            }
1382
1383            // Get the index value at this position
1384            let idx_val = idx_data[out_flat] as usize;
1385            assert!(
1386                idx_val < shape[axis],
1387                "Index {} out of bounds for axis {} with size {}",
1388                idx_val,
1389                axis,
1390                shape[axis]
1391            );
1392
1393            // Build input multi-dimensional index: same as output except at axis
1394            let mut input_idx = multi_idx.clone();
1395            input_idx[axis] = idx_val;
1396
1397            // Convert to flat index in input array
1398            let in_flat: usize = input_idx
1399                .iter()
1400                .zip(strides.iter())
1401                .map(|(i, s)| i * s)
1402                .sum();
1403
1404            result[out_flat] = data[in_flat];
1405        }
1406
1407        Array::from_vec(result, indices.shape().clone())
1408    }
1409
1410    /// Return indices of non-zero elements.
1411    ///
1412    /// Returns indices where elements are non-zero (not equal to 0.0).
1413    ///
1414    /// # Examples
1415    ///
1416    /// ```
1417    /// # use jax_rs::{Array, Shape};
1418    /// let a = Array::from_vec(vec![0.0, 1.0, 0.0, 3.0, 0.0], Shape::new(vec![5]));
1419    /// let indices = a.nonzero();
1420    /// assert_eq!(indices, vec![1, 3]);
1421    /// ```
1422    pub fn nonzero(&self) -> Vec<usize> {
1423        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1424        let data = self.to_vec();
1425
1426        data.iter()
1427            .enumerate()
1428            .filter(|(_, &val)| val != 0.0)
1429            .map(|(idx, _)| idx)
1430            .collect()
1431    }
1432
1433    /// Return indices where condition is true (non-zero).
1434    ///
1435    /// Similar to nonzero but returns 2D array of indices for multi-dimensional arrays.
1436    /// For 1D arrays, returns a list of indices.
1437    ///
1438    /// # Examples
1439    ///
1440    /// ```
1441    /// # use jax_rs::{Array, Shape};
1442    /// let a = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0], Shape::new(vec![4]));
1443    /// let indices = a.argwhere();
1444    /// assert_eq!(indices, vec![1, 3]);
1445    /// ```
1446    pub fn argwhere(&self) -> Vec<usize> {
1447        self.nonzero()
1448    }
1449
1450    /// Select elements from array based on condition mask.
1451    ///
1452    /// Returns a 1D array of elements where the condition is true (non-zero).
1453    ///
1454    /// # Examples
1455    ///
1456    /// ```
1457    /// # use jax_rs::{Array, Shape};
1458    /// let a = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0], Shape::new(vec![4]));
1459    /// let condition = Array::from_vec(vec![1.0, 0.0, 1.0, 0.0], Shape::new(vec![4]));
1460    /// let result = a.compress(&condition);
1461    /// assert_eq!(result.to_vec(), vec![10.0, 30.0]);
1462    /// ```
1463    pub fn compress(&self, condition: &Array) -> Array {
1464        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1465        assert_eq!(
1466            condition.dtype(),
1467            DType::Float32,
1468            "Only Float32 supported"
1469        );
1470        assert_eq!(
1471            self.size(),
1472            condition.size(),
1473            "Array and condition must have same size"
1474        );
1475
1476        let data = self.to_vec();
1477        let cond_data = condition.to_vec();
1478
1479        let result: Vec<f32> = data
1480            .iter()
1481            .zip(cond_data.iter())
1482            .filter(|(_, &c)| c != 0.0)
1483            .map(|(&val, _)| val)
1484            .collect();
1485
1486        let len = result.len();
1487        Array::from_vec(result, Shape::new(vec![len]))
1488    }
1489
1490    /// Choose elements from arrays based on index array.
1491    ///
1492    /// For each element in the index array, select from the corresponding choice array.
1493    ///
1494    /// # Examples
1495    ///
1496    /// ```
1497    /// # use jax_rs::{Array, Shape};
1498    /// let choices = vec![
1499    ///     Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3])),
1500    ///     Array::from_vec(vec![100.0, 200.0, 300.0], Shape::new(vec![3])),
1501    /// ];
1502    /// let indices = vec![0, 1, 0];
1503    /// let result = Array::choose(&indices, &choices);
1504    /// assert_eq!(result.to_vec(), vec![10.0, 200.0, 30.0]);
1505    /// ```
1506    pub fn choose(indices: &[usize], choices: &[Array]) -> Array {
1507        assert!(!choices.is_empty(), "Must provide at least one choice");
1508        let size = choices[0].size();
1509
1510        for choice in choices.iter() {
1511            assert_eq!(
1512                choice.size(),
1513                size,
1514                "All choices must have the same size"
1515            );
1516        }
1517
1518        assert_eq!(
1519            indices.len(),
1520            size,
1521            "Indices must have same length as choices"
1522        );
1523
1524        let choice_data: Vec<Vec<f32>> =
1525            choices.iter().map(|c| c.to_vec()).collect();
1526
1527        let result: Vec<f32> = (0..size)
1528            .map(|i| {
1529                let choice_idx = indices[i];
1530                assert!(
1531                    choice_idx < choices.len(),
1532                    "Index {} out of bounds",
1533                    choice_idx
1534                );
1535                choice_data[choice_idx][i]
1536            })
1537            .collect();
1538
1539        Array::from_vec(result, choices[0].shape().clone())
1540    }
1541
1542    /// Extract elements from array where condition is true.
1543    ///
1544    /// Similar to compress, but condition can be a boolean-like array.
1545    ///
1546    /// # Examples
1547    ///
1548    /// ```
1549    /// # use jax_rs::{Array, Shape};
1550    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
1551    /// let condition = Array::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0], Shape::new(vec![5]));
1552    /// let result = a.extract(&condition);
1553    /// assert_eq!(result.to_vec(), vec![1.0, 3.0, 5.0]);
1554    /// ```
1555    pub fn extract(&self, condition: &Array) -> Array {
1556        self.compress(condition)
1557    }
1558
1559    /// Roll array elements along a given axis.
1560    ///
1561    /// Elements that roll beyond the last position are re-introduced at the first.
1562    ///
1563    /// # Examples
1564    ///
1565    /// ```
1566    /// # use jax_rs::{Array, Shape};
1567    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
1568    /// let rolled = a.roll(2);
1569    /// assert_eq!(rolled.to_vec(), vec![4.0, 5.0, 1.0, 2.0, 3.0]);
1570    /// ```
1571    pub fn roll(&self, shift: isize) -> Array {
1572        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1573        let data = self.to_vec();
1574        let len = data.len();
1575
1576        if len == 0 {
1577            return self.clone();
1578        }
1579
1580        // Normalize shift to be within [0, len)
1581        let shift = ((shift % len as isize) + len as isize) as usize % len;
1582
1583        let mut result = vec![0.0; len];
1584        for i in 0..len {
1585            result[(i + shift) % len] = data[i];
1586        }
1587
1588        Array::from_vec(result, self.shape().clone())
1589    }
1590
1591    /// Rotate array by 90 degrees in the plane specified by axes.
1592    ///
1593    /// For 2D arrays, rotates counterclockwise.
1594    ///
1595    /// # Examples
1596    ///
1597    /// ```
1598    /// # use jax_rs::{Array, Shape};
1599    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
1600    /// let rotated = a.rot90(1);
1601    /// assert_eq!(rotated.to_vec(), vec![2.0, 4.0, 1.0, 3.0]);
1602    /// ```
1603    pub fn rot90(&self, k: isize) -> Array {
1604        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1605        assert_eq!(self.shape().ndim(), 2, "Only 2D arrays supported");
1606
1607        let shape = self.shape().as_slice();
1608        let (h, w) = (shape[0], shape[1]);
1609        let data = self.to_vec();
1610
1611        // Normalize k to be within [0, 4)
1612        let k = k.rem_euclid(4);
1613
1614        match k {
1615            0 => self.clone(),
1616            1 => {
1617                // Rotate 90 degrees counterclockwise
1618                let mut result = vec![0.0; h * w];
1619                for i in 0..h {
1620                    for j in 0..w {
1621                        let new_i = w - 1 - j;
1622                        let new_j = i;
1623                        result[new_i * h + new_j] = data[i * w + j];
1624                    }
1625                }
1626                Array::from_vec(result, Shape::new(vec![w, h]))
1627            }
1628            2 => {
1629                // Rotate 180 degrees
1630                let mut result = vec![0.0; h * w];
1631                for i in 0..h {
1632                    for j in 0..w {
1633                        let new_i = h - 1 - i;
1634                        let new_j = w - 1 - j;
1635                        result[new_i * w + new_j] = data[i * w + j];
1636                    }
1637                }
1638                Array::from_vec(result, Shape::new(vec![h, w]))
1639            }
1640            3 => {
1641                // Rotate 270 degrees counterclockwise (90 clockwise)
1642                let mut result = vec![0.0; h * w];
1643                for i in 0..h {
1644                    for j in 0..w {
1645                        let new_i = j;
1646                        let new_j = h - 1 - i;
1647                        result[new_i * h + new_j] = data[i * w + j];
1648                    }
1649                }
1650                Array::from_vec(result, Shape::new(vec![w, h]))
1651            }
1652            _ => unreachable!(),
1653        }
1654    }
1655
1656    /// Interchange two axes of an array.
1657    ///
1658    /// # Examples
1659    ///
1660    /// ```
1661    /// # use jax_rs::{Array, Shape};
1662    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
1663    /// let swapped = a.swapaxes(0, 1);
1664    /// assert_eq!(swapped.shape().as_slice(), &[3, 2]);
1665    /// ```
1666    pub fn swapaxes(&self, axis1: usize, axis2: usize) -> Array {
1667        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1668        let ndim = self.shape().ndim();
1669        assert!(axis1 < ndim, "axis1 out of bounds");
1670        assert!(axis2 < ndim, "axis2 out of bounds");
1671
1672        if axis1 == axis2 {
1673            return self.clone();
1674        }
1675
1676        // For 2D arrays, this is just transpose
1677        if ndim == 2 && ((axis1 == 0 && axis2 == 1) || (axis1 == 1 && axis2 == 0)) {
1678            return self.transpose();
1679        }
1680
1681        // General case for higher dimensions
1682        let old_shape = self.shape().as_slice();
1683        let mut new_shape = old_shape.to_vec();
1684        new_shape.swap(axis1, axis2);
1685
1686        let data = self.to_vec();
1687        let size = self.size();
1688        let mut result = vec![0.0; size];
1689
1690        // Compute strides for old and new shapes
1691        let old_strides = self.shape().default_strides();
1692        let mut new_strides = vec![1; ndim];
1693        for i in (0..ndim - 1).rev() {
1694            new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
1695        }
1696
1697        // Copy elements with swapped axes
1698        for i in 0..size {
1699            let mut old_indices = vec![0; ndim];
1700            let mut temp = i;
1701            for j in 0..ndim {
1702                old_indices[j] = temp / old_strides[j];
1703                temp %= old_strides[j];
1704            }
1705
1706            // Swap the indices
1707            old_indices.swap(axis1, axis2);
1708
1709            // Compute new flat index
1710            let mut new_idx = 0;
1711            for j in 0..ndim {
1712                new_idx += old_indices[j] * new_strides[j];
1713            }
1714
1715            result[new_idx] = data[i];
1716        }
1717
1718        Array::from_vec(result, Shape::new(new_shape))
1719    }
1720
1721    /// Move axes of an array to new positions.
1722    ///
1723    /// Simplified version that only supports moving a single axis.
1724    ///
1725    /// # Examples
1726    ///
1727    /// ```
1728    /// # use jax_rs::{Array, Shape};
1729    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![1, 2, 3]));
1730    /// let moved = a.moveaxis(2, 0);
1731    /// assert_eq!(moved.shape().as_slice(), &[3, 1, 2]);
1732    /// ```
1733    pub fn moveaxis(&self, source: usize, destination: usize) -> Array {
1734        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1735        let ndim = self.shape().ndim();
1736        assert!(source < ndim, "source axis out of bounds");
1737        assert!(destination < ndim, "destination axis out of bounds");
1738
1739        if source == destination {
1740            return self.clone();
1741        }
1742
1743        let old_shape = self.shape().as_slice();
1744        let mut new_shape = Vec::new();
1745
1746        // Build new shape by removing source axis and inserting at destination
1747        for (i, &dim) in old_shape.iter().enumerate() {
1748            if i != source {
1749                new_shape.push(dim);
1750            }
1751        }
1752        new_shape.insert(destination, old_shape[source]);
1753
1754        // For simple cases, use swapaxes
1755        if ndim == 2 {
1756            return self.swapaxes(source, destination);
1757        }
1758
1759        // For 3D, we can implement a specific case
1760        if ndim == 3 {
1761            let data = self.to_vec();
1762            let size = self.size();
1763            let mut result = vec![0.0; size];
1764
1765            let old_strides = self.shape().default_strides();
1766            let mut new_strides = vec![1; ndim];
1767            for i in (0..ndim - 1).rev() {
1768                new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
1769            }
1770
1771            for i in 0..size {
1772                let mut old_indices = vec![0; ndim];
1773                let mut temp = i;
1774                for j in 0..ndim {
1775                    old_indices[j] = temp / old_strides[j];
1776                    temp %= old_strides[j];
1777                }
1778
1779                // Reorder indices
1780                let moved_val = old_indices[source];
1781                old_indices.remove(source);
1782                old_indices.insert(destination, moved_val);
1783
1784                // Compute new flat index
1785                let mut new_idx = 0;
1786                for j in 0..ndim {
1787                    new_idx += old_indices[j] * new_strides[j];
1788                }
1789
1790                result[new_idx] = data[i];
1791            }
1792
1793            return Array::from_vec(result, Shape::new(new_shape));
1794        }
1795
1796        // For higher dimensions, fall back to swapaxes
1797        self.swapaxes(source, destination)
1798    }
1799
1800    /// One-dimensional linear interpolation.
1801    ///
1802    /// Returns interpolated values at specified points using linear interpolation.
1803    ///
1804    /// # Arguments
1805    ///
1806    /// * `x` - x-coordinates at which to evaluate the interpolated values
1807    /// * `xp` - x-coordinates of the data points (must be increasing)
1808    /// * `fp` - y-coordinates of the data points
1809    ///
1810    /// # Examples
1811    ///
1812    /// ```
1813    /// # use jax_rs::{Array, Shape};
1814    /// let xp = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
1815    /// let fp = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
1816    /// let x = Array::from_vec(vec![1.5, 2.5], Shape::new(vec![2]));
1817    /// let result = Array::interp(&x, &xp, &fp);
1818    /// assert_eq!(result.to_vec(), vec![15.0, 25.0]);
1819    /// ```
1820    pub fn interp(x: &Array, xp: &Array, fp: &Array) -> Array {
1821        assert_eq!(x.dtype(), DType::Float32, "Only Float32 supported");
1822        assert_eq!(xp.dtype(), DType::Float32, "Only Float32 supported");
1823        assert_eq!(fp.dtype(), DType::Float32, "Only Float32 supported");
1824        assert_eq!(
1825            xp.size(),
1826            fp.size(),
1827            "xp and fp must have the same size"
1828        );
1829
1830        let x_data = x.to_vec();
1831        let xp_data = xp.to_vec();
1832        let fp_data = fp.to_vec();
1833
1834        let result: Vec<f32> = x_data
1835            .iter()
1836            .map(|&xi| {
1837                // Handle edge cases
1838                if xi <= xp_data[0] {
1839                    return fp_data[0];
1840                }
1841                if xi >= xp_data[xp_data.len() - 1] {
1842                    return fp_data[fp_data.len() - 1];
1843                }
1844
1845                // Find the interval containing xi
1846                for i in 0..xp_data.len() - 1 {
1847                    if xi >= xp_data[i] && xi <= xp_data[i + 1] {
1848                        // Linear interpolation
1849                        let t = (xi - xp_data[i]) / (xp_data[i + 1] - xp_data[i]);
1850                        return fp_data[i] + t * (fp_data[i + 1] - fp_data[i]);
1851                    }
1852                }
1853
1854                fp_data[fp_data.len() - 1]
1855            })
1856            .collect();
1857
1858        Array::from_vec(result, x.shape().clone())
1859    }
1860
1861    /// Linear interpolation between two arrays.
1862    ///
1863    /// Returns a + weight * (b - a)
1864    ///
1865    /// # Examples
1866    ///
1867    /// ```
1868    /// # use jax_rs::{Array, Shape};
1869    /// let a = Array::from_vec(vec![0.0, 10.0, 20.0], Shape::new(vec![3]));
1870    /// let b = Array::from_vec(vec![100.0, 110.0, 120.0], Shape::new(vec![3]));
1871    /// let result = a.lerp(&b, 0.5);
1872    /// assert_eq!(result.to_vec(), vec![50.0, 60.0, 70.0]);
1873    /// ```
1874    pub fn lerp(&self, other: &Array, weight: f32) -> Array {
1875        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1876        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
1877        assert_eq!(
1878            self.shape(),
1879            other.shape(),
1880            "Arrays must have the same shape"
1881        );
1882
1883        let self_data = self.to_vec();
1884        let other_data = other.to_vec();
1885
1886        let result: Vec<f32> = self_data
1887            .iter()
1888            .zip(other_data.iter())
1889            .map(|(&a, &b)| a + weight * (b - a))
1890            .collect();
1891
1892        Array::from_vec(result, self.shape().clone())
1893    }
1894
1895    /// Linearly interpolate between two arrays element-wise with array weights.
1896    ///
1897    /// Returns a + weight * (b - a) where weight is an array.
1898    ///
1899    /// # Examples
1900    ///
1901    /// ```
1902    /// # use jax_rs::{Array, Shape};
1903    /// let a = Array::from_vec(vec![0.0, 10.0, 20.0], Shape::new(vec![3]));
1904    /// let b = Array::from_vec(vec![100.0, 110.0, 120.0], Shape::new(vec![3]));
1905    /// let weights = Array::from_vec(vec![0.0, 0.5, 1.0], Shape::new(vec![3]));
1906    /// let result = a.lerp_array(&b, &weights);
1907    /// assert_eq!(result.to_vec(), vec![0.0, 60.0, 120.0]);
1908    /// ```
1909    pub fn lerp_array(&self, other: &Array, weights: &Array) -> Array {
1910        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1911        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
1912        assert_eq!(weights.dtype(), DType::Float32, "Only Float32 supported");
1913        assert_eq!(
1914            self.shape(),
1915            other.shape(),
1916            "Arrays must have the same shape"
1917        );
1918        assert_eq!(
1919            self.shape(),
1920            weights.shape(),
1921            "Arrays and weights must have the same shape"
1922        );
1923
1924        let self_data = self.to_vec();
1925        let other_data = other.to_vec();
1926        let weight_data = weights.to_vec();
1927
1928        let result: Vec<f32> = self_data
1929            .iter()
1930            .zip(other_data.iter())
1931            .zip(weight_data.iter())
1932            .map(|((&a, &b), &w)| a + w * (b - a))
1933            .collect();
1934
1935        Array::from_vec(result, self.shape().clone())
1936    }
1937
1938    /// Compute the discrete 1D convolution of two arrays.
1939    ///
1940    /// Returns the discrete linear convolution of the input array with a kernel.
1941    /// Uses 'valid' mode (only overlapping parts).
1942    ///
1943    /// # Examples
1944    ///
1945    /// ```
1946    /// # use jax_rs::{Array, Shape};
1947    /// let signal = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
1948    /// let kernel = Array::from_vec(vec![1.0, 0.0, -1.0], Shape::new(vec![3]));
1949    /// let conv = signal.convolve(&kernel);
1950    /// // Convolution flips the kernel: [-1, 0, 1]
1951    /// // [1*(-1) + 2*0 + 3*1, 2*(-1) + 3*0 + 4*1, 3*(-1) + 4*0 + 5*1]
1952    /// // = [-1+0+3, -2+0+4, -3+0+5] = [2, 2, 2]
1953    /// assert_eq!(conv.to_vec(), vec![2.0, 2.0, 2.0]);
1954    /// ```
1955    pub fn convolve(&self, kernel: &Array) -> Array {
1956        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1957        assert_eq!(kernel.dtype(), DType::Float32, "Only Float32 supported");
1958        assert_eq!(self.ndim(), 1, "Convolve only supports 1D arrays");
1959        assert_eq!(kernel.ndim(), 1, "Kernel must be 1D");
1960
1961        let signal = self.to_vec();
1962        let mut k = kernel.to_vec();
1963        let n = signal.len();
1964        let m = k.len();
1965
1966        if m > n {
1967            // Kernel longer than signal - return empty array
1968            return Array::zeros(Shape::new(vec![0]), DType::Float32);
1969        }
1970
1971        // Flip the kernel for convolution
1972        k.reverse();
1973
1974        // Valid mode: output size = n - m + 1
1975        let out_size = n - m + 1;
1976        let mut result = Vec::with_capacity(out_size);
1977
1978        for i in 0..out_size {
1979            let mut sum = 0.0;
1980            for j in 0..m {
1981                sum += signal[i + j] * k[j];
1982            }
1983            result.push(sum);
1984        }
1985
1986        Array::from_vec(result, Shape::new(vec![out_size]))
1987    }
1988
1989    /// Compute the cross-correlation of two 1D arrays.
1990    ///
1991    /// Cross-correlation is similar to convolution but without flipping the kernel.
1992    /// Uses 'valid' mode (only overlapping parts).
1993    ///
1994    /// # Examples
1995    ///
1996    /// ```
1997    /// # use jax_rs::{Array, Shape};
1998    /// let signal = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
1999    /// let template = Array::from_vec(vec![1.0, 2.0, 1.0], Shape::new(vec![3]));
2000    /// let corr = signal.correlate(&template);
2001    /// // [1*1 + 2*2 + 3*1, 2*1 + 3*2 + 4*1, 3*1 + 4*2 + 5*1]
2002    /// // = [1+4+3, 2+6+4, 3+8+5] = [8, 12, 16]
2003    /// assert_eq!(corr.to_vec(), vec![8.0, 12.0, 16.0]);
2004    /// ```
2005    pub fn correlate(&self, template: &Array) -> Array {
2006        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2007        assert_eq!(template.dtype(), DType::Float32, "Only Float32 supported");
2008        assert_eq!(self.ndim(), 1, "Correlate only supports 1D arrays");
2009        assert_eq!(template.ndim(), 1, "Template must be 1D");
2010
2011        let signal = self.to_vec();
2012        let t = template.to_vec();
2013        let n = signal.len();
2014        let m = t.len();
2015
2016        if m > n {
2017            // Template longer than signal - return empty array
2018            return Array::zeros(Shape::new(vec![0]), DType::Float32);
2019        }
2020
2021        // Valid mode: output size = n - m + 1
2022        let out_size = n - m + 1;
2023        let mut result = Vec::with_capacity(out_size);
2024
2025        for i in 0..out_size {
2026            let mut sum = 0.0;
2027            for j in 0..m {
2028                // Note: no kernel flip, unlike convolution
2029                sum += signal[i + j] * t[j];
2030            }
2031            result.push(sum);
2032        }
2033
2034        Array::from_vec(result, Shape::new(vec![out_size]))
2035    }
2036
2037    /// Stack arrays vertically (row-wise).
2038    ///
2039    /// Equivalent to concatenation along axis 0 after promoting 1D arrays to 2D.
2040    ///
2041    /// # Examples
2042    ///
2043    /// ```
2044    /// # use jax_rs::{Array, Shape};
2045    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
2046    /// let b = Array::from_vec(vec![4.0, 5.0, 6.0], Shape::new(vec![3]));
2047    /// let stacked = a.vstack(&b);
2048    /// assert_eq!(stacked.shape().as_slice(), &[2, 3]);
2049    /// assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2050    /// ```
2051    pub fn vstack(&self, other: &Array) -> Array {
2052        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2053        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2054
2055        let self_shape = self.shape().as_slice();
2056        let other_shape = other.shape().as_slice();
2057
2058        // Promote 1D arrays to 2D (1, N)
2059        let self_2d = if self_shape.len() == 1 {
2060            self.reshape(Shape::new(vec![1, self_shape[0]]))
2061        } else {
2062            self.clone()
2063        };
2064
2065        let other_2d = if other_shape.len() == 1 {
2066            other.reshape(Shape::new(vec![1, other_shape[0]]))
2067        } else {
2068            other.clone()
2069        };
2070
2071        // Concatenate along axis 0
2072        Array::concatenate(&[self_2d, other_2d], 0)
2073    }
2074
2075    /// Stack arrays horizontally (column-wise).
2076    ///
2077    /// Equivalent to concatenation along axis 1.
2078    ///
2079    /// # Examples
2080    ///
2081    /// ```
2082    /// # use jax_rs::{Array, Shape};
2083    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2, 1]));
2084    /// let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2, 1]));
2085    /// let stacked = a.hstack(&b);
2086    /// assert_eq!(stacked.shape().as_slice(), &[2, 2]);
2087    /// assert_eq!(stacked.to_vec(), vec![1.0, 3.0, 2.0, 4.0]);
2088    /// ```
2089    pub fn hstack(&self, other: &Array) -> Array {
2090        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2091        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2092
2093        let self_shape = self.shape().as_slice();
2094        let other_shape = other.shape().as_slice();
2095
2096        // For 1D arrays, concatenate directly
2097        if self_shape.len() == 1 && other_shape.len() == 1 {
2098            return Array::concatenate(&[self.clone(), other.clone()], 0);
2099        }
2100
2101        // For 2D+ arrays, concatenate along axis 1
2102        Array::concatenate(&[self.clone(), other.clone()], 1)
2103    }
2104
2105    /// Split array into multiple sub-arrays vertically (row-wise).
2106    ///
2107    /// # Examples
2108    ///
2109    /// ```
2110    /// # use jax_rs::{Array, Shape};
2111    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![6]));
2112    /// let parts = a.vsplit(2);
2113    /// assert_eq!(parts.len(), 2);
2114    /// assert_eq!(parts[0].to_vec(), vec![1.0, 2.0, 3.0]);
2115    /// assert_eq!(parts[1].to_vec(), vec![4.0, 5.0, 6.0]);
2116    /// ```
2117    pub fn vsplit(&self, num_sections: usize) -> Vec<Array> {
2118        let shape = self.shape().as_slice();
2119
2120        if shape.len() == 1 {
2121            // For 1D arrays, split along axis 0
2122            return Array::split(self, num_sections, 0);
2123        }
2124
2125        // For 2D+ arrays, split along axis 0
2126        Array::split(self, num_sections, 0)
2127    }
2128
2129    /// Split array into multiple sub-arrays horizontally (column-wise).
2130    ///
2131    /// # Examples
2132    ///
2133    /// ```
2134    /// # use jax_rs::{Array, Shape};
2135    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
2136    /// let parts = a.hsplit(2);
2137    /// assert_eq!(parts.len(), 2);
2138    /// assert_eq!(parts[0].shape().as_slice(), &[2, 1]);
2139    /// assert_eq!(parts[1].shape().as_slice(), &[2, 1]);
2140    /// ```
2141    pub fn hsplit(&self, num_sections: usize) -> Vec<Array> {
2142        let shape = self.shape().as_slice();
2143        assert!(!shape.is_empty(), "hsplit requires at least 1D array");
2144
2145        if shape.len() == 1 {
2146            // For 1D arrays, split along axis 0
2147            return Array::split(self, num_sections, 0);
2148        }
2149
2150        // For 2D+ arrays, split along axis 1
2151        Array::split(self, num_sections, 1)
2152    }
2153
2154    /// Append values to the end of an array.
2155    ///
2156    /// # Examples
2157    ///
2158    /// ```
2159    /// # use jax_rs::{Array, Shape};
2160    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
2161    /// let b = Array::from_vec(vec![4.0, 5.0], Shape::new(vec![2]));
2162    /// let result = a.append(&b);
2163    /// assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
2164    /// ```
2165    pub fn append(&self, values: &Array) -> Array {
2166        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2167        assert_eq!(values.dtype(), DType::Float32, "Only Float32 supported");
2168
2169        let mut data = self.to_vec();
2170        data.extend(values.to_vec());
2171
2172        let new_size = data.len();
2173        Array::from_vec(data, Shape::new(vec![new_size]))
2174    }
2175
2176    /// Insert values at the given index.
2177    ///
2178    /// # Examples
2179    ///
2180    /// ```
2181    /// # use jax_rs::{Array, Shape};
2182    /// let a = Array::from_vec(vec![1.0, 2.0, 5.0, 6.0], Shape::new(vec![4]));
2183    /// let values = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
2184    /// let result = a.insert(2, &values);
2185    /// assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2186    /// ```
2187    pub fn insert(&self, index: usize, values: &Array) -> Array {
2188        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2189        assert_eq!(values.dtype(), DType::Float32, "Only Float32 supported");
2190
2191        let mut data = self.to_vec();
2192        let values_data = values.to_vec();
2193
2194        assert!(index <= data.len(), "Index out of bounds");
2195
2196        // Insert values at the specified index
2197        for (i, &val) in values_data.iter().enumerate() {
2198            data.insert(index + i, val);
2199        }
2200
2201        let new_size = data.len();
2202        Array::from_vec(data, Shape::new(vec![new_size]))
2203    }
2204
2205    /// Delete elements at specified indices.
2206    ///
2207    /// # Examples
2208    ///
2209    /// ```
2210    /// # use jax_rs::{Array, Shape};
2211    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
2212    /// let result = a.delete(&[1, 3]);
2213    /// assert_eq!(result.to_vec(), vec![1.0, 3.0, 5.0]);
2214    /// ```
2215    pub fn delete(&self, indices: &[usize]) -> Array {
2216        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2217
2218        let data = self.to_vec();
2219        let mut result = Vec::new();
2220
2221        for (i, &val) in data.iter().enumerate() {
2222            if !indices.contains(&i) {
2223                result.push(val);
2224            }
2225        }
2226
2227        let new_size = result.len();
2228        Array::from_vec(result, Shape::new(vec![new_size]))
2229    }
2230
2231    /// Trim leading and trailing zeros.
2232    ///
2233    /// # Examples
2234    ///
2235    /// ```
2236    /// # use jax_rs::{Array, Shape};
2237    /// let a = Array::from_vec(vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0], Shape::new(vec![6]));
2238    /// let trimmed = a.trim_zeros();
2239    /// assert_eq!(trimmed.to_vec(), vec![1.0, 2.0, 3.0]);
2240    /// ```
2241    pub fn trim_zeros(&self) -> Array {
2242        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2243
2244        let data = self.to_vec();
2245
2246        // Find first non-zero
2247        let start = data.iter().position(|&x| x.abs() > 1e-10).unwrap_or(data.len());
2248
2249        // Find last non-zero
2250        let end = data.iter().rposition(|&x| x.abs() > 1e-10).map(|i| i + 1).unwrap_or(0);
2251
2252        if start >= end {
2253            return Array::zeros(Shape::new(vec![0]), DType::Float32);
2254        }
2255
2256        let result = data[start..end].to_vec();
2257        let new_size = result.len();
2258        Array::from_vec(result, Shape::new(vec![new_size]))
2259    }
2260
2261    /// Repeat each element along axis.
2262    ///
2263    /// # Examples
2264    ///
2265    /// ```
2266    /// # use jax_rs::{Array, Shape};
2267    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
2268    /// let repeated = a.repeat_elements(2);
2269    /// assert_eq!(repeated.to_vec(), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
2270    /// ```
2271    pub fn repeat_elements(&self, repeats: usize) -> Array {
2272        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2273
2274        let data = self.to_vec();
2275        let mut result = Vec::with_capacity(data.len() * repeats);
2276
2277        for &val in data.iter() {
2278            for _ in 0..repeats {
2279                result.push(val);
2280            }
2281        }
2282
2283        let new_size = result.len();
2284        Array::from_vec(result, Shape::new(vec![new_size]))
2285    }
2286
2287    /// Resize array to new shape, repeating or truncating as needed.
2288    ///
2289    /// # Examples
2290    ///
2291    /// ```
2292    /// # use jax_rs::{Array, Shape};
2293    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
2294    /// let resized = a.resize(5);
2295    /// assert_eq!(resized.to_vec(), vec![1.0, 2.0, 3.0, 1.0, 2.0]);
2296    /// ```
2297    pub fn resize(&self, new_size: usize) -> Array {
2298        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2299
2300        let data = self.to_vec();
2301        let mut result = Vec::with_capacity(new_size);
2302
2303        for i in 0..new_size {
2304            result.push(data[i % data.len()]);
2305        }
2306
2307        Array::from_vec(result, Shape::new(vec![new_size]))
2308    }
2309
2310    /// Compute correlation coefficient between two 1D arrays.
2311    ///
2312    /// # Examples
2313    ///
2314    /// ```
2315    /// # use jax_rs::{Array, Shape};
2316    /// let x = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
2317    /// let y = Array::from_vec(vec![2.0, 4.0, 6.0, 8.0], Shape::new(vec![4]));
2318    /// let corr = x.corrcoef(&y);
2319    /// assert!((corr - 1.0).abs() < 1e-5); // Perfect correlation
2320    /// ```
2321    pub fn corrcoef(&self, other: &Array) -> f32 {
2322        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2323        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2324        assert_eq!(self.size(), other.size(), "Arrays must have same size");
2325
2326        let x = self.to_vec();
2327        let y = other.to_vec();
2328        let n = x.len() as f32;
2329
2330        // Compute means
2331        let x_mean: f32 = x.iter().sum::<f32>() / n;
2332        let y_mean: f32 = y.iter().sum::<f32>() / n;
2333
2334        // Compute covariance and standard deviations
2335        let mut cov = 0.0;
2336        let mut x_var = 0.0;
2337        let mut y_var = 0.0;
2338
2339        for (x_val, y_val) in x.iter().zip(y.iter()) {
2340            let x_diff = x_val - x_mean;
2341            let y_diff = y_val - y_mean;
2342            cov += x_diff * y_diff;
2343            x_var += x_diff * x_diff;
2344            y_var += y_diff * y_diff;
2345        }
2346
2347        // Correlation coefficient
2348        if x_var.abs() < 1e-10 || y_var.abs() < 1e-10 {
2349            return 0.0;
2350        }
2351
2352        cov / (x_var * y_var).sqrt()
2353    }
2354
2355    /// Return indices of non-zero elements in a flattened array.
2356    ///
2357    /// # Examples
2358    ///
2359    /// ```
2360    /// # use jax_rs::{Array, Shape};
2361    /// let a = Array::from_vec(vec![0.0, 1.0, 0.0, 3.0, 0.0, 5.0], Shape::new(vec![6]));
2362    /// let indices = a.flatnonzero();
2363    /// assert_eq!(indices, vec![1, 3, 5]);
2364    /// ```
2365    pub fn flatnonzero(&self) -> Vec<usize> {
2366        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2367
2368        let data = self.to_vec();
2369        data.iter()
2370            .enumerate()
2371            .filter_map(|(i, &val)| if val.abs() > 1e-10 { Some(i) } else { None })
2372            .collect()
2373    }
2374
2375    /// Tile the array by repeating it along each dimension.
2376    ///
2377    /// # Examples
2378    ///
2379    /// ```
2380    /// # use jax_rs::{Array, Shape};
2381    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
2382    /// let b = a.tile_1d(3);
2383    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
2384    /// ```
2385    pub fn tile_1d(&self, reps: usize) -> Array {
2386        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2387        let data = self.to_vec();
2388
2389        let mut result = Vec::with_capacity(data.len() * reps);
2390        for _ in 0..reps {
2391            result.extend_from_slice(&data);
2392        }
2393
2394        let new_size = result.len();
2395        Array::from_vec(result, Shape::new(vec![new_size]))
2396    }
2397
2398    /// Stack 1-D arrays as columns into a 2-D array.
2399    ///
2400    /// # Examples
2401    ///
2402    /// ```
2403    /// # use jax_rs::{Array, Shape};
2404    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
2405    /// let b = Array::from_vec(vec![4.0, 5.0, 6.0], Shape::new(vec![3]));
2406    /// let c = Array::column_stack(&[a, b]);
2407    /// assert_eq!(c.shape().as_slice(), &[3, 2]);
2408    /// // [[1, 4], [2, 5], [3, 6]]
2409    /// ```
2410    pub fn column_stack(arrays: &[Array]) -> Array {
2411        assert!(!arrays.is_empty(), "Need at least one array");
2412        assert_eq!(arrays[0].dtype(), DType::Float32, "Only Float32 supported");
2413
2414        let n_rows = arrays[0].size();
2415        let n_cols = arrays.len();
2416
2417        let mut result = Vec::with_capacity(n_rows * n_cols);
2418        for row_idx in 0..n_rows {
2419            for arr in arrays {
2420                let data = arr.to_vec();
2421                result.push(data[row_idx]);
2422            }
2423        }
2424
2425        Array::from_vec(result, Shape::new(vec![n_rows, n_cols]))
2426    }
2427
2428    /// Stack arrays in sequence vertically (row wise).
2429    ///
2430    /// Alias for vstack.
2431    ///
2432    /// # Examples
2433    ///
2434    /// ```
2435    /// # use jax_rs::{Array, Shape};
2436    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
2437    /// let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
2438    /// let c = Array::row_stack(&[a, b]);
2439    /// assert_eq!(c.shape().as_slice(), &[2, 2]);
2440    /// // [[1, 2], [3, 4]]
2441    /// ```
2442    pub fn row_stack(arrays: &[Array]) -> Array {
2443        assert!(!arrays.is_empty(), "Need at least one array");
2444
2445        // Convert 1D arrays to 2D if needed
2446        let arrays_2d: Vec<Array> = arrays.iter().map(|arr| {
2447            if arr.shape().as_slice().len() == 1 {
2448                let size = arr.size();
2449                let data = arr.to_vec();
2450                Array::from_vec(data, Shape::new(vec![1, size]))
2451            } else {
2452                arr.clone()
2453            }
2454        }).collect();
2455
2456        Array::concatenate(&arrays_2d, 0)
2457    }
2458
2459    /// Stack arrays in sequence depth wise (along third axis).
2460    ///
2461    /// # Examples
2462    ///
2463    /// ```
2464    /// # use jax_rs::{Array, Shape};
2465    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
2466    /// let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
2467    /// let c = Array::dstack(&[a, b]);
2468    /// assert_eq!(c.shape().as_slice(), &[1, 2, 2]);
2469    /// ```
2470    pub fn dstack(arrays: &[Array]) -> Array {
2471        assert!(!arrays.is_empty(), "Need at least one array");
2472
2473        // Convert to at least 3D
2474        let arrays_3d: Vec<Array> = arrays.iter().map(|arr| {
2475            let shape = arr.shape().as_slice();
2476            match shape.len() {
2477                1 => {
2478                    let size = arr.size();
2479                    let data = arr.to_vec();
2480                    Array::from_vec(data, Shape::new(vec![1, size, 1]))
2481                }
2482                2 => {
2483                    let data = arr.to_vec();
2484                    Array::from_vec(data, Shape::new(vec![shape[0], shape[1], 1]))
2485                }
2486                _ => arr.clone(),
2487            }
2488        }).collect();
2489
2490        Array::concatenate(&arrays_3d, 2)
2491    }
2492
2493    /// Compute the absolute value and return as a new array (alias for abs).
2494    ///
2495    /// # Examples
2496    ///
2497    /// ```
2498    /// # use jax_rs::{Array, Shape};
2499    /// let a = Array::from_vec(vec![1.0, -2.0, 3.0], Shape::new(vec![3]));
2500    /// let b = a.absolute();
2501    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0]);
2502    /// ```
2503    pub fn absolute(&self) -> Array {
2504        self.abs()
2505    }
2506
2507    /// Clamp values to a specified range (alias for clip).
2508    ///
2509    /// # Examples
2510    ///
2511    /// ```
2512    /// # use jax_rs::{Array, Shape};
2513    /// let a = Array::from_vec(vec![1.0, 5.0, 10.0], Shape::new(vec![3]));
2514    /// let b = a.clamp(2.0, 8.0);
2515    /// assert_eq!(b.to_vec(), vec![2.0, 5.0, 8.0]);
2516    /// ```
2517    pub fn clamp(&self, min: f32, max: f32) -> Array {
2518        self.clip(min, max)
2519    }
2520
2521    /// Fill the diagonal of a 2D array with a scalar value.
2522    ///
2523    /// # Examples
2524    ///
2525    /// ```
2526    /// # use jax_rs::{Array, Shape, DType};
2527    /// let a = Array::zeros(Shape::new(vec![3, 3]), DType::Float32);
2528    /// let filled = a.fill_diagonal(5.0);
2529    /// assert_eq!(filled.to_vec(), vec![5.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 5.0]);
2530    /// ```
2531    pub fn fill_diagonal(&self, value: f32) -> Array {
2532        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2533
2534        let shape = self.shape();
2535        let dims = shape.as_slice();
2536        assert_eq!(dims.len(), 2, "fill_diagonal only supports 2D arrays");
2537
2538        let (rows, cols) = (dims[0], dims[1]);
2539        let data = self.to_vec();
2540        let mut result = data.clone();
2541
2542        let min_dim = rows.min(cols);
2543        for i in 0..min_dim {
2544            result[i * cols + i] = value;
2545        }
2546
2547        Array::from_vec(result, shape.clone())
2548    }
2549
2550    /// Evaluate a polynomial at specific values.
2551    /// Polynomial coefficients are in decreasing order (highest degree first).
2552    ///
2553    /// # Examples
2554    ///
2555    /// ```
2556    /// # use jax_rs::{Array, Shape};
2557    /// // Evaluate p(x) = 2x^2 + 3x + 1 at x = [1, 2, 3]
2558    /// let x = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
2559    /// let coeffs = Array::from_vec(vec![2.0, 3.0, 1.0], Shape::new(vec![3]));
2560    /// let result = x.polyval(&coeffs);
2561    /// // At x=1: 2(1)^2 + 3(1) + 1 = 6
2562    /// // At x=2: 2(4) + 3(2) + 1 = 15
2563    /// // At x=3: 2(9) + 3(3) + 1 = 28
2564    /// assert_eq!(result.to_vec(), vec![6.0, 15.0, 28.0]);
2565    /// ```
2566    pub fn polyval(&self, coeffs: &Array) -> Array {
2567        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2568        assert_eq!(coeffs.dtype(), DType::Float32, "Only Float32 supported");
2569        assert_eq!(coeffs.ndim(), 1, "Coefficients must be 1D");
2570
2571        let x_data = self.to_vec();
2572        let c_data = coeffs.to_vec();
2573
2574        let result_data: Vec<f32> = x_data
2575            .iter()
2576            .map(|&x| {
2577                // Horner's method for polynomial evaluation
2578                let mut result = 0.0;
2579                for &coeff in &c_data {
2580                    result = result * x + coeff;
2581                }
2582                result
2583            })
2584            .collect();
2585
2586        Array::from_vec(result_data, self.shape().clone())
2587    }
2588
2589    /// Add two polynomials.
2590    /// Polynomial coefficients are in decreasing order (highest degree first).
2591    ///
2592    /// # Examples
2593    ///
2594    /// ```
2595    /// # use jax_rs::{Array, Shape};
2596    /// // Add p(x) = 2x^2 + 3x + 1 and q(x) = x^2 + 2x + 3
2597    /// let p = Array::from_vec(vec![2.0, 3.0, 1.0], Shape::new(vec![3]));
2598    /// let q = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
2599    /// let sum = p.polyadd(&q);
2600    /// assert_eq!(sum.to_vec(), vec![3.0, 5.0, 4.0]);
2601    /// ```
2602    pub fn polyadd(&self, other: &Array) -> Array {
2603        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2604        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2605        assert_eq!(self.ndim(), 1, "Polynomials must be 1D");
2606        assert_eq!(other.ndim(), 1, "Polynomials must be 1D");
2607
2608        let p_data = self.to_vec();
2609        let q_data = other.to_vec();
2610
2611        let max_len = p_data.len().max(q_data.len());
2612        let mut result = vec![0.0; max_len];
2613
2614        // Align from the right (lowest degree)
2615        let p_offset = max_len - p_data.len();
2616        let q_offset = max_len - q_data.len();
2617
2618        for (i, &val) in p_data.iter().enumerate() {
2619            result[p_offset + i] += val;
2620        }
2621
2622        for (i, &val) in q_data.iter().enumerate() {
2623            result[q_offset + i] += val;
2624        }
2625
2626        Array::from_vec(result, Shape::new(vec![max_len]))
2627    }
2628
2629    /// Multiply two polynomials.
2630    /// Polynomial coefficients are in decreasing order (highest degree first).
2631    ///
2632    /// # Examples
2633    ///
2634    /// ```
2635    /// # use jax_rs::{Array, Shape};
2636    /// // Multiply (x + 1) * (x + 2) = x^2 + 3x + 2
2637    /// let p = Array::from_vec(vec![1.0, 1.0], Shape::new(vec![2]));
2638    /// let q = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
2639    /// let prod = p.polymul(&q);
2640    /// assert_eq!(prod.to_vec(), vec![1.0, 3.0, 2.0]);
2641    /// ```
2642    pub fn polymul(&self, other: &Array) -> Array {
2643        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2644        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2645        assert_eq!(self.ndim(), 1, "Polynomials must be 1D");
2646        assert_eq!(other.ndim(), 1, "Polynomials must be 1D");
2647
2648        let p = self.to_vec();
2649        let q = other.to_vec();
2650        let result_len = p.len() + q.len() - 1;
2651        let mut result = vec![0.0; result_len];
2652
2653        for (i, &pi) in p.iter().enumerate() {
2654            for (j, &qj) in q.iter().enumerate() {
2655                result[i + j] += pi * qj;
2656            }
2657        }
2658
2659        Array::from_vec(result, Shape::new(vec![result_len]))
2660    }
2661
2662    /// Differentiate a polynomial.
2663    /// Returns the polynomial representing the derivative.
2664    ///
2665    /// # Examples
2666    ///
2667    /// ```
2668    /// # use jax_rs::{Array, Shape};
2669    /// // d/dx (2x^2 + 3x + 1) = 4x + 3
2670    /// let p = Array::from_vec(vec![2.0, 3.0, 1.0], Shape::new(vec![3]));
2671    /// let dp = p.polyder();
2672    /// assert_eq!(dp.to_vec(), vec![4.0, 3.0]);
2673    /// ```
2674    pub fn polyder(&self) -> Array {
2675        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2676        assert_eq!(self.ndim(), 1, "Polynomial must be 1D");
2677
2678        let coeffs = self.to_vec();
2679        if coeffs.len() <= 1 {
2680            return Array::from_vec(vec![0.0], Shape::new(vec![1]));
2681        }
2682
2683        let n = coeffs.len() - 1;
2684        let mut result = Vec::with_capacity(n);
2685
2686        for (i, &c) in coeffs.iter().take(n).enumerate() {
2687            let degree = (n - i) as f32;
2688            result.push(c * degree);
2689        }
2690
2691        Array::from_vec(result, Shape::new(vec![n]))
2692    }
2693
2694    /// Subtract two polynomials.
2695    /// Polynomial coefficients are in decreasing order (highest degree first).
2696    ///
2697    /// # Examples
2698    ///
2699    /// ```
2700    /// # use jax_rs::{Array, Shape};
2701    /// let p = Array::from_vec(vec![3.0, 5.0, 4.0], Shape::new(vec![3]));
2702    /// let q = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
2703    /// let diff = p.polysub(&q);
2704    /// assert_eq!(diff.to_vec(), vec![2.0, 3.0, 1.0]);
2705    /// ```
2706    pub fn polysub(&self, other: &Array) -> Array {
2707        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2708        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
2709        assert_eq!(self.ndim(), 1, "Polynomials must be 1D");
2710        assert_eq!(other.ndim(), 1, "Polynomials must be 1D");
2711
2712        let p_data = self.to_vec();
2713        let q_data = other.to_vec();
2714
2715        let max_len = p_data.len().max(q_data.len());
2716        let mut result = vec![0.0; max_len];
2717
2718        let p_offset = max_len - p_data.len();
2719        let q_offset = max_len - q_data.len();
2720
2721        for (i, &val) in p_data.iter().enumerate() {
2722            result[p_offset + i] += val;
2723        }
2724
2725        for (i, &val) in q_data.iter().enumerate() {
2726            result[q_offset + i] -= val;
2727        }
2728
2729        Array::from_vec(result, Shape::new(vec![max_len]))
2730    }
2731
2732    /// Evaluate a piecewise-defined function.
2733    ///
2734    /// Applies different functions based on conditions. For each element,
2735    /// the first true condition determines which function to apply.
2736    ///
2737    /// # Arguments
2738    ///
2739    /// * `conditions` - Vector of condition arrays (booleans as 0.0/1.0)
2740    /// * `functions` - Vector of function output arrays corresponding to conditions
2741    ///
2742    /// # Examples
2743    ///
2744    /// ```
2745    /// # use jax_rs::{Array, Shape};
2746    /// let x = Array::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], Shape::new(vec![5]));
2747    /// // Condition: x < 0
2748    /// let cond1 = Array::from_vec(vec![1.0, 1.0, 0.0, 0.0, 0.0], Shape::new(vec![5]));
2749    /// // Condition: x >= 0
2750    /// let cond2 = Array::from_vec(vec![0.0, 0.0, 1.0, 1.0, 1.0], Shape::new(vec![5]));
2751    /// // Function outputs (pre-computed)
2752    /// let func1 = Array::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], Shape::new(vec![5])); // identity
2753    /// let func2 = Array::from_vec(vec![4.0, 1.0, 0.0, 1.0, 4.0], Shape::new(vec![5])); // x^2
2754    /// let result = x.piecewise(&[cond1, cond2], &[func1, func2]);
2755    /// // For x<0: use identity, for x>=0: use x^2
2756    /// assert_eq!(result.to_vec(), vec![-2.0, -1.0, 0.0, 1.0, 4.0]);
2757    /// ```
2758    pub fn piecewise(&self, conditions: &[Array], functions: &[Array]) -> Array {
2759        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2760        assert_eq!(conditions.len(), functions.len(), "Number of conditions must match number of functions");
2761        assert!(!conditions.is_empty(), "At least one condition required");
2762
2763        let n = self.size();
2764        for cond in conditions {
2765            assert_eq!(cond.size(), n, "Condition size must match array size");
2766        }
2767        for func in functions {
2768            assert_eq!(func.size(), n, "Function output size must match array size");
2769        }
2770
2771        let mut result = vec![0.0; n];
2772        let mut assigned = vec![false; n];
2773
2774        for (cond, func) in conditions.iter().zip(functions.iter()) {
2775            let cond_data = cond.to_vec();
2776            let func_data = func.to_vec();
2777
2778            for i in 0..n {
2779                if !assigned[i] && cond_data[i] != 0.0 {
2780                    result[i] = func_data[i];
2781                    assigned[i] = true;
2782                }
2783            }
2784        }
2785
2786        Array::from_vec(result, self.shape().clone())
2787    }
2788
2789    /// Place values into array at specified indices.
2790    ///
2791    /// Returns a new array with values inserted at the specified indices.
2792    ///
2793    /// # Examples
2794    ///
2795    /// ```
2796    /// # use jax_rs::{Array, Shape};
2797    /// let a = Array::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0], Shape::new(vec![5]));
2798    /// let mask = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0], Shape::new(vec![5]));
2799    /// let values = vec![10.0, 20.0];
2800    /// let result = a.place(&mask, &values);
2801    /// assert_eq!(result.to_vec(), vec![0.0, 10.0, 0.0, 20.0, 0.0]);
2802    /// ```
2803    pub fn place(&self, mask: &Array, values: &[f32]) -> Array {
2804        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2805        assert_eq!(mask.dtype(), DType::Float32, "Only Float32 supported");
2806        assert_eq!(self.size(), mask.size(), "Array and mask must have same size");
2807
2808        let mut data = self.to_vec();
2809        let mask_data = mask.to_vec();
2810
2811        let mut value_idx = 0;
2812        for (i, &m) in mask_data.iter().enumerate() {
2813            if m != 0.0 && value_idx < values.len() {
2814                data[i] = values[value_idx];
2815                value_idx += 1;
2816            }
2817        }
2818
2819        Array::from_vec(data, self.shape().clone())
2820    }
2821
2822    /// Copy values from source to destination array.
2823    ///
2824    /// Returns a new array with values from source copied to corresponding positions.
2825    ///
2826    /// # Examples
2827    ///
2828    /// ```
2829    /// # use jax_rs::{Array, Shape};
2830    /// let dst = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
2831    /// let src = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], Shape::new(vec![5]));
2832    /// let mask = Array::from_vec(vec![0.0, 1.0, 1.0, 0.0, 1.0], Shape::new(vec![5]));
2833    /// let result = dst.copyto(&src, &mask);
2834    /// assert_eq!(result.to_vec(), vec![1.0, 20.0, 30.0, 4.0, 50.0]);
2835    /// ```
2836    pub fn copyto(&self, src: &Array, mask: &Array) -> Array {
2837        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2838        assert_eq!(src.dtype(), DType::Float32, "Only Float32 supported");
2839        assert_eq!(mask.dtype(), DType::Float32, "Only Float32 supported");
2840        assert_eq!(self.size(), src.size(), "Arrays must have same size");
2841        assert_eq!(self.size(), mask.size(), "Array and mask must have same size");
2842
2843        let mut data = self.to_vec();
2844        let src_data = src.to_vec();
2845        let mask_data = mask.to_vec();
2846
2847        for i in 0..data.len() {
2848            if mask_data[i] != 0.0 {
2849                data[i] = src_data[i];
2850            }
2851        }
2852
2853        Array::from_vec(data, self.shape().clone())
2854    }
2855
2856    /// Return the index of the maximum element along an axis and the max value.
2857    ///
2858    /// # Examples
2859    ///
2860    /// ```
2861    /// # use jax_rs::{Array, Shape};
2862    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], Shape::new(vec![5]));
2863    /// let (idx, val) = a.argmax_with_value();
2864    /// assert_eq!(idx, 4);
2865    /// assert!((val - 5.0).abs() < 1e-6);
2866    /// ```
2867    pub fn argmax_with_value(&self) -> (usize, f32) {
2868        let data = self.to_vec();
2869        let mut max_idx = 0;
2870        let mut max_val = f32::NEG_INFINITY;
2871
2872        for (i, &x) in data.iter().enumerate() {
2873            if x > max_val {
2874                max_val = x;
2875                max_idx = i;
2876            }
2877        }
2878
2879        (max_idx, max_val)
2880    }
2881
2882    /// Return the index of the minimum element along an axis and the min value.
2883    ///
2884    /// # Examples
2885    ///
2886    /// ```
2887    /// # use jax_rs::{Array, Shape};
2888    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], Shape::new(vec![5]));
2889    /// let (idx, val) = a.argmin_with_value();
2890    /// assert_eq!(idx, 1);
2891    /// assert!((val - 1.0).abs() < 1e-6);
2892    /// ```
2893    pub fn argmin_with_value(&self) -> (usize, f32) {
2894        let data = self.to_vec();
2895        let mut min_idx = 0;
2896        let mut min_val = f32::INFINITY;
2897
2898        for (i, &x) in data.iter().enumerate() {
2899            if x < min_val {
2900                min_val = x;
2901                min_idx = i;
2902            }
2903        }
2904
2905        (min_idx, min_val)
2906    }
2907
2908    /// Return an array with axes transposed to the given permutation.
2909    ///
2910    /// # Examples
2911    ///
2912    /// ```
2913    /// # use jax_rs::{Array, Shape};
2914    /// let a = Array::from_vec(
2915    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
2916    ///     Shape::new(vec![2, 3])
2917    /// );
2918    /// let b = a.permute(&[1, 0]);
2919    /// assert_eq!(b.shape().as_slice(), &[3, 2]);
2920    /// ```
2921    pub fn permute(&self, axes: &[usize]) -> Array {
2922        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
2923        let shape = self.shape().as_slice();
2924        assert_eq!(axes.len(), shape.len(), "axes must match number of dimensions");
2925
2926        // Build new shape
2927        let new_shape: Vec<usize> = axes.iter().map(|&a| shape[a]).collect();
2928
2929        // For 2D transpose
2930        if shape.len() == 2 && axes == [1, 0] {
2931            return self.transpose();
2932        }
2933
2934        // General case - use strides calculation
2935        let data = self.to_vec();
2936        let mut result = vec![0.0; data.len()];
2937
2938        // Calculate old strides
2939        let mut old_strides = vec![1usize; shape.len()];
2940        for i in (0..shape.len() - 1).rev() {
2941            old_strides[i] = old_strides[i + 1] * shape[i + 1];
2942        }
2943
2944        // Calculate new strides
2945        let mut new_strides = vec![1usize; new_shape.len()];
2946        for i in (0..new_shape.len() - 1).rev() {
2947            new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
2948        }
2949
2950        // Permute strides
2951        let permuted_old_strides: Vec<usize> = axes.iter().map(|&a| old_strides[a]).collect();
2952
2953        // Copy data with permutation
2954        for new_idx in 0..data.len() {
2955            let mut old_idx = 0;
2956            let mut remainder = new_idx;
2957            for (d, &new_stride) in new_strides.iter().enumerate() {
2958                let coord = remainder / new_stride;
2959                remainder %= new_stride;
2960                old_idx += coord * permuted_old_strides[d];
2961            }
2962            result[new_idx] = data[old_idx];
2963        }
2964
2965        Array::from_vec(result, Shape::new(new_shape))
2966    }
2967
2968    /// Gather values along an axis using indices.
2969    ///
2970    /// This is a generalized form of indexing that allows selecting arbitrary
2971    /// indices along a specified axis.
2972    ///
2973    /// # Examples
2974    ///
2975    /// ```
2976    /// # use jax_rs::{Array, Shape};
2977    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
2978    /// let indices = Array::from_vec(vec![0.0, 2.0], Shape::new(vec![2]));
2979    /// let result = a.gather(&indices, 1);
2980    /// // Gathers columns 0 and 2 from each row
2981    /// ```
2982    pub fn gather(&self, indices: &Array, axis: usize) -> Array {
2983        let shape = self.shape().as_slice();
2984        assert!(axis < shape.len(), "Axis out of bounds");
2985
2986        let indices_data: Vec<usize> = indices.to_vec().iter().map(|&x| x as usize).collect();
2987        let data = self.to_vec();
2988
2989        if axis == 0 && shape.len() == 1 {
2990            // Simple 1D case
2991            let result: Vec<f32> = indices_data.iter().map(|&i| data[i]).collect();
2992            return Array::from_vec(result, Shape::new(vec![indices_data.len()]));
2993        }
2994
2995        // For multi-dimensional arrays
2996        let mut result = Vec::new();
2997        let mut new_shape = shape.to_vec();
2998        new_shape[axis] = indices_data.len();
2999
3000        // Compute strides
3001        let mut strides: Vec<usize> = Vec::with_capacity(shape.len());
3002        let mut stride = 1;
3003        for &dim in shape.iter().rev() {
3004            strides.push(stride);
3005            stride *= dim;
3006        }
3007        strides.reverse();
3008
3009        let total_size: usize = new_shape.iter().product();
3010        result.reserve(total_size);
3011
3012        // Iterate through all output positions
3013        for out_idx in 0..total_size {
3014            // Compute output coordinates
3015            let mut coords = Vec::with_capacity(shape.len());
3016            let mut remainder = out_idx;
3017            for &dim in &new_shape {
3018                coords.push(remainder % dim);
3019                remainder /= dim;
3020            }
3021            coords.reverse();
3022
3023            // The coordinate at axis is an index into indices array
3024            let idx_in_indices = coords[axis];
3025            coords[axis] = indices_data[idx_in_indices];
3026
3027            // Compute input index
3028            let mut in_idx = 0;
3029            for (d, &coord) in coords.iter().enumerate() {
3030                in_idx += coord * strides[d];
3031            }
3032
3033            result.push(data[in_idx]);
3034        }
3035
3036        Array::from_vec(result, Shape::new(new_shape))
3037    }
3038
3039    /// Gather values using n-dimensional indices.
3040    ///
3041    /// # Examples
3042    ///
3043    /// ```
3044    /// # use jax_rs::{Array, Shape};
3045    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
3046    /// let indices = vec![(0, 1), (1, 2)]; // (row, col) pairs
3047    /// let result = a.gather_nd(&indices);
3048    /// // Returns values at [0,1] and [1,2]
3049    /// ```
3050    pub fn gather_nd(&self, indices: &[(usize, usize)]) -> Array {
3051        let data = self.to_vec();
3052        let shape = self.shape().as_slice();
3053        assert_eq!(shape.len(), 2, "gather_nd only supports 2D arrays for now");
3054
3055        let cols = shape[1];
3056        let result: Vec<f32> = indices
3057            .iter()
3058            .map(|&(r, c)| data[r * cols + c])
3059            .collect();
3060
3061        Array::from_vec(result, Shape::new(vec![indices.len()]))
3062    }
3063
3064    /// Segment sum - sum elements by segment ID.
3065    ///
3066    /// # Examples
3067    ///
3068    /// ```
3069    /// # use jax_rs::{Array, Shape};
3070    /// let data = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
3071    /// let segment_ids = Array::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0], Shape::new(vec![5]));
3072    /// let result = data.segment_sum(&segment_ids, 3);
3073    /// // segment 0: 1+2=3, segment 1: 3+4=7, segment 2: 5
3074    /// ```
3075    pub fn segment_sum(&self, segment_ids: &Array, num_segments: usize) -> Array {
3076        assert_eq!(self.size(), segment_ids.size(), "Data and segment_ids must have same size");
3077
3078        let data = self.to_vec();
3079        let ids: Vec<usize> = segment_ids.to_vec().iter().map(|&x| x as usize).collect();
3080
3081        let mut result = vec![0.0; num_segments];
3082        for (val, &seg_id) in data.iter().zip(ids.iter()) {
3083            if seg_id < num_segments {
3084                result[seg_id] += val;
3085            }
3086        }
3087
3088        Array::from_vec(result, Shape::new(vec![num_segments]))
3089    }
3090
3091    /// Segment mean - compute mean of elements by segment ID.
3092    pub fn segment_mean(&self, segment_ids: &Array, num_segments: usize) -> Array {
3093        assert_eq!(self.size(), segment_ids.size(), "Data and segment_ids must have same size");
3094
3095        let data = self.to_vec();
3096        let ids: Vec<usize> = segment_ids.to_vec().iter().map(|&x| x as usize).collect();
3097
3098        let mut sums = vec![0.0; num_segments];
3099        let mut counts = vec![0usize; num_segments];
3100
3101        for (val, &seg_id) in data.iter().zip(ids.iter()) {
3102            if seg_id < num_segments {
3103                sums[seg_id] += val;
3104                counts[seg_id] += 1;
3105            }
3106        }
3107
3108        let result: Vec<f32> = sums
3109            .iter()
3110            .zip(counts.iter())
3111            .map(|(&sum, &count)| if count > 0 { sum / count as f32 } else { 0.0 })
3112            .collect();
3113
3114        Array::from_vec(result, Shape::new(vec![num_segments]))
3115    }
3116
3117    /// Segment max - compute max of elements by segment ID.
3118    pub fn segment_max(&self, segment_ids: &Array, num_segments: usize) -> Array {
3119        assert_eq!(self.size(), segment_ids.size(), "Data and segment_ids must have same size");
3120
3121        let data = self.to_vec();
3122        let ids: Vec<usize> = segment_ids.to_vec().iter().map(|&x| x as usize).collect();
3123
3124        let mut result = vec![f32::NEG_INFINITY; num_segments];
3125
3126        for (val, &seg_id) in data.iter().zip(ids.iter()) {
3127            if seg_id < num_segments && *val > result[seg_id] {
3128                result[seg_id] = *val;
3129            }
3130        }
3131
3132        Array::from_vec(result, Shape::new(vec![num_segments]))
3133    }
3134
3135    /// Segment min - compute min of elements by segment ID.
3136    pub fn segment_min(&self, segment_ids: &Array, num_segments: usize) -> Array {
3137        assert_eq!(self.size(), segment_ids.size(), "Data and segment_ids must have same size");
3138
3139        let data = self.to_vec();
3140        let ids: Vec<usize> = segment_ids.to_vec().iter().map(|&x| x as usize).collect();
3141
3142        let mut result = vec![f32::INFINITY; num_segments];
3143
3144        for (val, &seg_id) in data.iter().zip(ids.iter()) {
3145            if seg_id < num_segments && *val < result[seg_id] {
3146                result[seg_id] = *val;
3147            }
3148        }
3149
3150        Array::from_vec(result, Shape::new(vec![num_segments]))
3151    }
3152
3153    /// Flip array along multiple axes.
3154    pub fn flip_axes(&self, axes: &[usize]) -> Array {
3155        let mut result = self.clone();
3156        for &axis in axes {
3157            result = result.flip(axis);
3158        }
3159        result
3160    }
3161
3162    /// Move multiple axes to new positions.
3163    pub fn moveaxis_multiple(&self, sources: &[usize], destinations: &[usize]) -> Array {
3164        assert_eq!(sources.len(), destinations.len(), "sources and destinations must have same length");
3165
3166        let mut result = self.clone();
3167        for (&src, &dst) in sources.iter().zip(destinations.iter()) {
3168            result = result.moveaxis(src, dst);
3169        }
3170        result
3171    }
3172
3173    /// Expand dimensions at multiple positions.
3174    pub fn expand_dims_multiple(&self, axes: &[usize]) -> Array {
3175        let mut sorted_axes = axes.to_vec();
3176        sorted_axes.sort();
3177
3178        let mut result = self.clone();
3179        for (i, &axis) in sorted_axes.iter().enumerate() {
3180            result = result.expand_dims(axis + i);
3181        }
3182        result
3183    }
3184
3185    /// Squeeze all axes with size 1.
3186    pub fn squeeze_all(&self) -> Array {
3187        let shape = self.shape().as_slice();
3188        let new_shape: Vec<usize> = shape.iter().cloned().filter(|&d| d != 1).collect();
3189
3190        if new_shape.is_empty() {
3191            // Result is scalar
3192            return Array::from_vec(self.to_vec(), Shape::new(vec![1]));
3193        }
3194
3195        self.reshape(Shape::new(new_shape))
3196    }
3197
3198    /// Unflatten array - reshape the first axis into multiple dimensions.
3199    pub fn unflatten(&self, dim: usize, sizes: &[usize]) -> Array {
3200        let shape = self.shape().as_slice();
3201        assert!(dim < shape.len(), "dim out of bounds");
3202        assert_eq!(
3203            sizes.iter().product::<usize>(),
3204            shape[dim],
3205            "sizes must multiply to the dimension size"
3206        );
3207
3208        let mut new_shape = Vec::with_capacity(shape.len() - 1 + sizes.len());
3209        new_shape.extend(&shape[..dim]);
3210        new_shape.extend(sizes);
3211        new_shape.extend(&shape[dim + 1..]);
3212
3213        self.reshape(Shape::new(new_shape))
3214    }
3215
3216    /// Repeat array elements along each axis.
3217    pub fn repeat_axis(&self, repeats: usize, axis: usize) -> Array {
3218        let shape = self.shape().as_slice();
3219        assert!(axis < shape.len(), "axis out of bounds");
3220
3221        if axis == 0 {
3222            // Repeat along first axis
3223            let data = self.to_vec();
3224            let chunk_size = self.size() / shape[0];
3225            let mut result = Vec::with_capacity(self.size() * repeats);
3226
3227            for chunk in data.chunks(chunk_size) {
3228                for _ in 0..repeats {
3229                    result.extend(chunk);
3230                }
3231            }
3232
3233            let mut new_shape = shape.to_vec();
3234            new_shape[axis] *= repeats;
3235
3236            Array::from_vec(result, Shape::new(new_shape))
3237        } else {
3238            // For other axes, transpose, repeat, transpose back
3239            // Simplified implementation
3240            let mut new_shape = shape.to_vec();
3241            new_shape[axis] *= repeats;
3242
3243            let data = self.to_vec();
3244            let mut result = Vec::with_capacity(new_shape.iter().product());
3245
3246            // Compute strides
3247            let inner_size: usize = shape[axis + 1..].iter().product();
3248            let outer_size: usize = shape[..axis].iter().product();
3249            let axis_size = shape[axis];
3250
3251            for outer in 0..outer_size {
3252                for ax in 0..axis_size {
3253                    for _ in 0..repeats {
3254                        let start = outer * axis_size * inner_size + ax * inner_size;
3255                        result.extend(&data[start..start + inner_size]);
3256                    }
3257                }
3258            }
3259
3260            Array::from_vec(result, Shape::new(new_shape))
3261        }
3262    }
3263
3264    /// N-dimensional tile - repeat array along each axis.
3265    pub fn tile_nd(&self, reps: &[usize]) -> Array {
3266        assert_eq!(reps.len(), self.ndim(), "reps must have same length as ndim");
3267
3268        let mut result = self.clone();
3269        for (axis, &rep) in reps.iter().enumerate() {
3270            if rep > 1 {
3271                result = result.repeat_axis(rep, axis);
3272            }
3273        }
3274        result
3275    }
3276}
3277
3278#[cfg(test)]
3279mod tests {
3280    use super::*;
3281
3282    #[test]
3283    fn test_concatenate_1d() {
3284        let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
3285        let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
3286        let c = Array::from_vec(vec![5.0, 6.0], Shape::new(vec![2]));
3287
3288        let result = Array::concatenate(&[a, b, c], 0);
3289        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
3290    }
3291
3292    #[test]
3293    fn test_stack() {
3294        let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
3295        let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
3296
3297        let result = Array::stack(&[a, b], 0);
3298        assert_eq!(result.shape().as_slice(), &[2, 2]);
3299        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
3300    }
3301
3302    #[test]
3303    fn test_split() {
3304        // Test splitting a 1D array
3305        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![6]));
3306        let parts = Array::split(&a, 3, 0);
3307
3308        assert_eq!(parts.len(), 3);
3309        assert_eq!(parts[0].to_vec(), vec![1.0, 2.0]);
3310        assert_eq!(parts[1].to_vec(), vec![3.0, 4.0]);
3311        assert_eq!(parts[2].to_vec(), vec![5.0, 6.0]);
3312    }
3313
3314    #[test]
3315    fn test_split_2d() {
3316        // Test splitting a 2D array
3317        let a = Array::from_vec(
3318            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
3319            Shape::new(vec![4, 2]),
3320        );
3321        let parts = Array::split(&a, 2, 0);
3322
3323        assert_eq!(parts.len(), 2);
3324        assert_eq!(parts[0].shape().as_slice(), &[2, 2]);
3325        assert_eq!(parts[0].to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
3326        assert_eq!(parts[1].to_vec(), vec![5.0, 6.0, 7.0, 8.0]);
3327    }
3328
3329    #[test]
3330    fn test_where_cond() {
3331        let cond =
3332            Array::from_vec(vec![1.0, 0.0, 1.0, 0.0], Shape::new(vec![4]));
3333        let x =
3334            Array::from_vec(vec![10.0, 20.0, 30.0, 40.0], Shape::new(vec![4]));
3335        let y = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
3336
3337        let result = Array::where_cond(&cond, &x, &y);
3338        assert_eq!(result.to_vec(), vec![10.0, 2.0, 30.0, 4.0]);
3339    }
3340
3341    #[test]
3342    fn test_where_cond_broadcast_scalar_condition() {
3343        let condition = Array::from_vec(vec![1.0], Shape::new(vec![1]));
3344        let x = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
3345        let y = Array::from_vec(vec![100.0, 200.0, 300.0], Shape::new(vec![3]));
3346        let result = Array::where_cond(&condition, &x, &y);
3347        assert_eq!(result.to_vec(), vec![10.0, 20.0, 30.0]);
3348    }
3349
3350    #[test]
3351    fn test_where_cond_broadcast_scalar_x() {
3352        let condition = Array::from_vec(vec![1.0, 0.0, 1.0], Shape::new(vec![3]));
3353        let x = Array::from_vec(vec![42.0], Shape::new(vec![1]));
3354        let y = Array::from_vec(vec![100.0, 200.0, 300.0], Shape::new(vec![3]));
3355        let result = Array::where_cond(&condition, &x, &y);
3356        assert_eq!(result.to_vec(), vec![42.0, 200.0, 42.0]);
3357    }
3358
3359    #[test]
3360    fn test_where_cond_broadcast_scalar_y() {
3361        let condition = Array::from_vec(vec![1.0, 0.0, 1.0], Shape::new(vec![3]));
3362        let x = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
3363        let y = Array::from_vec(vec![99.0], Shape::new(vec![1]));
3364        let result = Array::where_cond(&condition, &x, &y);
3365        assert_eq!(result.to_vec(), vec![10.0, 99.0, 30.0]);
3366    }
3367
3368    #[test]
3369    fn test_where_cond_2d() {
3370        let condition = Array::from_vec(
3371            vec![1.0, 0.0, 0.0, 1.0],
3372            Shape::new(vec![2, 2])
3373        );
3374        let x = Array::from_vec(
3375            vec![1.0, 2.0, 3.0, 4.0],
3376            Shape::new(vec![2, 2])
3377        );
3378        let y = Array::from_vec(
3379            vec![10.0, 20.0, 30.0, 40.0],
3380            Shape::new(vec![2, 2])
3381        );
3382        let result = Array::where_cond(&condition, &x, &y);
3383        assert_eq!(result.to_vec(), vec![1.0, 20.0, 30.0, 4.0]);
3384    }
3385
3386    #[test]
3387    fn test_where_cond_broadcast_2d() {
3388        let condition = Array::from_vec(vec![1.0, 0.0], Shape::new(vec![2]));
3389        let x = Array::from_vec(
3390            vec![1.0, 2.0, 3.0, 4.0],
3391            Shape::new(vec![2, 2])
3392        );
3393        let y = Array::from_vec(
3394            vec![10.0, 20.0, 30.0, 40.0],
3395            Shape::new(vec![2, 2])
3396        );
3397        let result = Array::where_cond(&condition, &x, &y);
3398        assert_eq!(result.to_vec(), vec![1.0, 20.0, 3.0, 40.0]);
3399    }
3400
3401    #[test]
3402    fn test_where_cond_negative_values() {
3403        let condition = Array::from_vec(vec![-5.0, 0.0, 3.14], Shape::new(vec![3]));
3404        let x = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
3405        let y = Array::from_vec(vec![100.0, 200.0, 300.0], Shape::new(vec![3]));
3406        let result = Array::where_cond(&condition, &x, &y);
3407        // -5.0 is non-zero (true), 0.0 is zero (false), 3.14 is non-zero (true)
3408        assert_eq!(result.to_vec(), vec![10.0, 200.0, 30.0]);
3409    }
3410
3411    #[test]
3412    fn test_select_basic() {
3413        let indices = Array::from_vec(vec![0.0, 1.0, 2.0, 1.0], Shape::new(vec![4]));
3414        let choice0 = Array::from_vec(vec![10.0, 10.0, 10.0, 10.0], Shape::new(vec![4]));
3415        let choice1 = Array::from_vec(vec![20.0, 20.0, 20.0, 20.0], Shape::new(vec![4]));
3416        let choice2 = Array::from_vec(vec![30.0, 30.0, 30.0, 30.0], Shape::new(vec![4]));
3417        let result = Array::select(&indices, &[choice0, choice1, choice2]);
3418        assert_eq!(result.to_vec(), vec![10.0, 20.0, 30.0, 20.0]);
3419    }
3420
3421    #[test]
3422    fn test_select_varying_values() {
3423        let indices = Array::from_vec(vec![0.0, 1.0, 0.0], Shape::new(vec![3]));
3424        let choice0 = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
3425        let choice1 = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
3426        let result = Array::select(&indices, &[choice0, choice1]);
3427        // Index 0 selects from choice0: [1.0, _, 3.0]
3428        // Index 1 selects from choice1: [_, 20.0, _]
3429        assert_eq!(result.to_vec(), vec![1.0, 20.0, 3.0]);
3430    }
3431
3432    #[test]
3433    fn test_select_2d() {
3434        let indices = Array::from_vec(vec![0.0, 1.0, 1.0, 0.0], Shape::new(vec![2, 2]));
3435        let choice0 = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
3436        let choice1 = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0], Shape::new(vec![2, 2]));
3437        let result = Array::select(&indices, &[choice0, choice1]);
3438        assert_eq!(result.to_vec(), vec![1.0, 20.0, 30.0, 4.0]);
3439    }
3440
3441    #[test]
3442    fn test_clip() {
3443        let a = Array::from_vec(
3444            vec![-5.0, 0.0, 5.0, 10.0, 15.0],
3445            Shape::new(vec![5]),
3446        );
3447        let clipped = a.clip(0.0, 10.0);
3448        assert_eq!(clipped.to_vec(), vec![0.0, 0.0, 5.0, 10.0, 10.0]);
3449    }
3450
3451    #[test]
3452    fn test_flip_1d() {
3453        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
3454        let flipped = a.flip(0);
3455        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
3456    }
3457
3458    #[test]
3459    fn test_flip_2d() {
3460        let a = Array::from_vec(
3461            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3462            Shape::new(vec![3, 2]),
3463        );
3464        let flipped = a.flip(0);
3465        assert_eq!(flipped.to_vec(), vec![5.0, 6.0, 3.0, 4.0, 1.0, 2.0]);
3466    }
3467
3468    #[test]
3469    fn test_nan_to_num() {
3470        let a = Array::from_vec(
3471            vec![1.0, f32::NAN, f32::INFINITY, -f32::INFINITY, 5.0],
3472            Shape::new(vec![5]),
3473        );
3474        let result = a.nan_to_num(0.0, 1e10, -1e10);
3475        assert_eq!(result.to_vec()[0], 1.0);
3476        assert_eq!(result.to_vec()[1], 0.0);
3477        assert_eq!(result.to_vec()[2], 1e10);
3478        assert_eq!(result.to_vec()[3], -1e10);
3479        assert_eq!(result.to_vec()[4], 5.0);
3480    }
3481
3482    #[test]
3483    fn test_isnan() {
3484        let a = Array::from_vec(
3485            vec![1.0, f32::NAN, 3.0, f32::NAN, 5.0],
3486            Shape::new(vec![5]),
3487        );
3488        let result = a.isnan();
3489        assert_eq!(result.to_vec(), vec![0.0, 1.0, 0.0, 1.0, 0.0]);
3490    }
3491
3492    #[test]
3493    fn test_isinf() {
3494        let a = Array::from_vec(
3495            vec![1.0, f32::INFINITY, -f32::INFINITY, 3.0],
3496            Shape::new(vec![4]),
3497        );
3498        let result = a.isinf();
3499        assert_eq!(result.to_vec(), vec![0.0, 1.0, 1.0, 0.0]);
3500    }
3501
3502    #[test]
3503    fn test_isfinite() {
3504        let a = Array::from_vec(
3505            vec![1.0, f32::NAN, f32::INFINITY, 3.0],
3506            Shape::new(vec![4]),
3507        );
3508        let result = a.isfinite();
3509        assert_eq!(result.to_vec(), vec![1.0, 0.0, 0.0, 1.0]);
3510    }
3511
3512    #[test]
3513    fn test_clip_by_norm() {
3514        // Test case 1: norm exceeds max_norm
3515        let a = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
3516        let clipped = a.clip_by_norm(2.0);
3517        let result = clipped.to_vec();
3518        // Original norm is 5.0, should be scaled to 2.0
3519        // scale = 2.0 / 5.0 = 0.4
3520        assert!((result[0] - 1.2).abs() < 1e-5);
3521        assert!((result[1] - 1.6).abs() < 1e-5);
3522
3523        // Test case 2: norm is already below max_norm
3524        let b = Array::from_vec(vec![1.0, 1.0], Shape::new(vec![2]));
3525        let clipped2 = b.clip_by_norm(5.0);
3526        assert_eq!(clipped2.to_vec(), vec![1.0, 1.0]);
3527    }
3528
3529    #[test]
3530    fn test_ravel() {
3531        let a = Array::from_vec(
3532            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3533            Shape::new(vec![2, 3]),
3534        );
3535        let flat = a.ravel();
3536        assert_eq!(flat.shape().as_slice(), &[6]);
3537        assert_eq!(flat.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
3538    }
3539
3540    #[test]
3541    fn test_flatten() {
3542        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
3543        let flat = a.flatten();
3544        assert_eq!(flat.shape().as_slice(), &[4]);
3545        assert_eq!(flat.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
3546    }
3547
3548    #[test]
3549    fn test_atleast_1d() {
3550        // Scalar to 1D
3551        let a = Array::from_vec(vec![5.0], Shape::new(vec![]));
3552        let b = a.atleast_1d();
3553        assert_eq!(b.shape().as_slice(), &[1]);
3554
3555        // Already 1D
3556        let c = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
3557        let d = c.atleast_1d();
3558        assert_eq!(d.shape().as_slice(), &[2]);
3559    }
3560
3561    #[test]
3562    fn test_atleast_2d() {
3563        // 1D to 2D
3564        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
3565        let b = a.atleast_2d();
3566        assert_eq!(b.shape().as_slice(), &[1, 3]);
3567
3568        // Already 2D
3569        let c = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![1, 2]));
3570        let d = c.atleast_2d();
3571        assert_eq!(d.shape().as_slice(), &[1, 2]);
3572    }
3573
3574    #[test]
3575    fn test_atleast_3d() {
3576        // 1D to 3D
3577        let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
3578        let b = a.atleast_3d();
3579        assert_eq!(b.shape().as_slice(), &[1, 2, 1]);
3580
3581        // 2D to 3D
3582        let c = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![1, 2]));
3583        let d = c.atleast_3d();
3584        assert_eq!(d.shape().as_slice(), &[1, 2, 1]);
3585    }
3586
3587    #[test]
3588    fn test_broadcast_to() {
3589        // Broadcast 1D to 2D
3590        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
3591        let b = a.broadcast_to(Shape::new(vec![2, 3]));
3592        assert_eq!(b.shape().as_slice(), &[2, 3]);
3593        assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
3594    }
3595
3596    #[test]
3597    fn test_take() {
3598        let a = Array::from_vec(
3599            vec![10.0, 20.0, 30.0, 40.0, 50.0],
3600            Shape::new(vec![5]),
3601        );
3602        let indices = vec![0, 2, 4];
3603        let result = a.take(&indices);
3604        assert_eq!(result.to_vec(), vec![10.0, 30.0, 50.0]);
3605    }
3606
3607    #[test]
3608    fn test_nonzero() {
3609        let a = Array::from_vec(
3610            vec![0.0, 1.0, 0.0, 3.0, 0.0, 5.0],
3611            Shape::new(vec![6]),
3612        );
3613        let indices = a.nonzero();
3614        assert_eq!(indices, vec![1, 3, 5]);
3615    }
3616
3617    #[test]
3618    fn test_argwhere() {
3619        let a = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0], Shape::new(vec![4]));
3620        let indices = a.argwhere();
3621        assert_eq!(indices, vec![1, 3]);
3622    }
3623
3624    #[test]
3625    fn test_compress() {
3626        let a = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0], Shape::new(vec![4]));
3627        let condition =
3628            Array::from_vec(vec![1.0, 0.0, 1.0, 0.0], Shape::new(vec![4]));
3629        let result = a.compress(&condition);
3630        assert_eq!(result.to_vec(), vec![10.0, 30.0]);
3631    }
3632
3633    #[test]
3634    fn test_choose() {
3635        let choices = vec![
3636            Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3])),
3637            Array::from_vec(vec![100.0, 200.0, 300.0], Shape::new(vec![3])),
3638        ];
3639        let indices = vec![0, 1, 0];
3640        let result = Array::choose(&indices, &choices);
3641        assert_eq!(result.to_vec(), vec![10.0, 200.0, 30.0]);
3642    }
3643
3644    #[test]
3645    fn test_extract() {
3646        let a = Array::from_vec(
3647            vec![1.0, 2.0, 3.0, 4.0, 5.0],
3648            Shape::new(vec![5]),
3649        );
3650        let condition = Array::from_vec(
3651            vec![1.0, 0.0, 1.0, 0.0, 1.0],
3652            Shape::new(vec![5]),
3653        );
3654        let result = a.extract(&condition);
3655        assert_eq!(result.to_vec(), vec![1.0, 3.0, 5.0]);
3656    }
3657
3658    #[test]
3659    fn test_roll() {
3660        let a = Array::from_vec(
3661            vec![1.0, 2.0, 3.0, 4.0, 5.0],
3662            Shape::new(vec![5]),
3663        );
3664        let rolled = a.roll(2);
3665        assert_eq!(rolled.to_vec(), vec![4.0, 5.0, 1.0, 2.0, 3.0]);
3666
3667        // Test negative roll
3668        let rolled_neg = a.roll(-1);
3669        assert_eq!(rolled_neg.to_vec(), vec![2.0, 3.0, 4.0, 5.0, 1.0]);
3670    }
3671
3672    #[test]
3673    fn test_rot90() {
3674        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
3675
3676        // Rotate 90 degrees
3677        let rot1 = a.rot90(1);
3678        assert_eq!(rot1.to_vec(), vec![2.0, 4.0, 1.0, 3.0]);
3679
3680        // Rotate 180 degrees
3681        let rot2 = a.rot90(2);
3682        assert_eq!(rot2.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
3683
3684        // Rotate 270 degrees
3685        let rot3 = a.rot90(3);
3686        assert_eq!(rot3.to_vec(), vec![3.0, 1.0, 4.0, 2.0]);
3687    }
3688
3689    #[test]
3690    fn test_swapaxes() {
3691        let a = Array::from_vec(
3692            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3693            Shape::new(vec![2, 3]),
3694        );
3695        let swapped = a.swapaxes(0, 1);
3696        assert_eq!(swapped.shape().as_slice(), &[3, 2]);
3697        assert_eq!(swapped.to_vec(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
3698    }
3699
3700    #[test]
3701    fn test_moveaxis() {
3702        let a = Array::from_vec(
3703            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3704            Shape::new(vec![1, 2, 3]),
3705        );
3706        let moved = a.moveaxis(2, 0);
3707        assert_eq!(moved.shape().as_slice(), &[3, 1, 2]);
3708    }
3709
3710    #[test]
3711    fn test_interp() {
3712        let xp = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
3713        let fp = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
3714        let x = Array::from_vec(vec![1.5, 2.5], Shape::new(vec![2]));
3715        let result = Array::interp(&x, &xp, &fp);
3716        assert_eq!(result.to_vec(), vec![15.0, 25.0]);
3717
3718        // Test edge cases
3719        let x_edge = Array::from_vec(vec![0.5, 3.5], Shape::new(vec![2]));
3720        let result_edge = Array::interp(&x_edge, &xp, &fp);
3721        assert_eq!(result_edge.to_vec(), vec![10.0, 30.0]);
3722    }
3723
3724    #[test]
3725    fn test_lerp() {
3726        let a = Array::from_vec(vec![0.0, 10.0, 20.0], Shape::new(vec![3]));
3727        let b = Array::from_vec(vec![100.0, 110.0, 120.0], Shape::new(vec![3]));
3728        let result = a.lerp(&b, 0.5);
3729        assert_eq!(result.to_vec(), vec![50.0, 60.0, 70.0]);
3730
3731        // Test with weight = 0.0 (should return a)
3732        let result_0 = a.lerp(&b, 0.0);
3733        assert_eq!(result_0.to_vec(), a.to_vec());
3734
3735        // Test with weight = 1.0 (should return b)
3736        let result_1 = a.lerp(&b, 1.0);
3737        assert_eq!(result_1.to_vec(), b.to_vec());
3738    }
3739
3740    #[test]
3741    fn test_lerp_array() {
3742        let a = Array::from_vec(vec![0.0, 10.0, 20.0], Shape::new(vec![3]));
3743        let b = Array::from_vec(vec![100.0, 110.0, 120.0], Shape::new(vec![3]));
3744        let weights = Array::from_vec(vec![0.0, 0.5, 1.0], Shape::new(vec![3]));
3745        let result = a.lerp_array(&b, &weights);
3746        assert_eq!(result.to_vec(), vec![0.0, 60.0, 120.0]);
3747    }
3748
3749    #[test]
3750    fn test_put() {
3751        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
3752        let result = a.put(&[0, 2, 4], &[10.0, 30.0, 50.0]);
3753        assert_eq!(result.to_vec(), vec![10.0, 2.0, 30.0, 4.0, 50.0]);
3754        assert_eq!(result.shape().as_slice(), &[5]);
3755
3756        // Test with 2D array
3757        let a2d = Array::from_vec(
3758            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3759            Shape::new(vec![2, 3]),
3760        );
3761        let result2d = a2d.put(&[0, 5], &[100.0, 600.0]);
3762        assert_eq!(result2d.to_vec(), vec![100.0, 2.0, 3.0, 4.0, 5.0, 600.0]);
3763    }
3764
3765    #[test]
3766    fn test_scatter() {
3767        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
3768        let result = a.scatter(&[0, 2, 4], &[10.0, 30.0, 50.0]);
3769        assert_eq!(result.to_vec(), vec![10.0, 2.0, 30.0, 4.0, 50.0]);
3770        assert_eq!(result.shape().as_slice(), &[5]);
3771
3772        // Test with 2D array (flattened indexing)
3773        let a2d = Array::from_vec(
3774            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3775            Shape::new(vec![2, 3]),
3776        );
3777        let result2d = a2d.scatter(&[0, 5], &[100.0, 600.0]);
3778        assert_eq!(result2d.to_vec(), vec![100.0, 2.0, 3.0, 4.0, 5.0, 600.0]);
3779    }
3780
3781    #[test]
3782    fn test_scatter_add() {
3783        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
3784        let result = a.scatter_add(&[0, 2, 4], &[10.0, 30.0, 50.0]);
3785        assert_eq!(result.to_vec(), vec![11.0, 2.0, 33.0, 4.0, 55.0]);
3786
3787        // Test with duplicate indices (accumulates)
3788        let a2 = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
3789        let result2 = a2.scatter_add(&[0, 0, 1], &[5.0, 3.0, 10.0]);
3790        assert_eq!(result2.to_vec(), vec![9.0, 12.0, 3.0]); // 1+5+3=9, 2+10=12, 3
3791    }
3792
3793    #[test]
3794    fn test_scatter_min() {
3795        let a = Array::from_vec(vec![5.0, 10.0, 15.0, 20.0, 25.0], Shape::new(vec![5]));
3796        let result = a.scatter_min(&[1, 2, 3], &[8.0, 20.0, 15.0]);
3797        assert_eq!(result.to_vec(), vec![5.0, 8.0, 15.0, 15.0, 25.0]);
3798
3799        // Test where update is larger (no change)
3800        let a2 = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
3801        let result2 = a2.scatter_min(&[0, 1, 2], &[5.0, 10.0, 15.0]);
3802        assert_eq!(result2.to_vec(), vec![1.0, 2.0, 3.0]);
3803    }
3804
3805    #[test]
3806    fn test_scatter_max() {
3807        let a = Array::from_vec(vec![5.0, 10.0, 15.0, 20.0, 25.0], Shape::new(vec![5]));
3808        let result = a.scatter_max(&[1, 2, 3], &[12.0, 10.0, 25.0]);
3809        assert_eq!(result.to_vec(), vec![5.0, 12.0, 15.0, 25.0, 25.0]);
3810
3811        // Test where update is smaller (no change)
3812        let a2 = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
3813        let result2 = a2.scatter_max(&[0, 1, 2], &[5.0, 10.0, 15.0]);
3814        assert_eq!(result2.to_vec(), vec![10.0, 20.0, 30.0]);
3815    }
3816
3817    #[test]
3818    fn test_scatter_mul() {
3819        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
3820        let result = a.scatter_mul(&[1, 2, 3], &[2.0, 3.0, 0.5]);
3821        assert_eq!(result.to_vec(), vec![1.0, 4.0, 9.0, 2.0, 5.0]);
3822
3823        // Test with duplicate indices (accumulates multiplication)
3824        let a2 = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
3825        let result2 = a2.scatter_mul(&[0, 0, 1], &[2.0, 3.0, 5.0]);
3826        assert_eq!(result2.to_vec(), vec![6.0, 10.0, 3.0]); // 1*2*3=6, 2*5=10, 3
3827    }
3828
3829    #[test]
3830    fn test_scatter_duplicate_indices() {
3831        // scatter: last update wins
3832        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
3833        let result = a.scatter(&[0, 0], &[10.0, 20.0]);
3834        assert_eq!(result.to_vec(), vec![20.0, 2.0, 3.0]); // Last value wins
3835
3836        // scatter_add: accumulates
3837        let result2 = a.scatter_add(&[0, 0], &[10.0, 20.0]);
3838        assert_eq!(result2.to_vec(), vec![31.0, 2.0, 3.0]); // 1 + 10 + 20 = 31
3839    }
3840
3841    #[test]
3842    fn test_take_along_axis_1d() {
3843        let a = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], Shape::new(vec![5]));
3844        let indices = Array::from_vec(vec![0.0, 2.0, 4.0], Shape::new(vec![3]));
3845        let result = a.take_along_axis(&indices, 0);
3846        assert_eq!(result.to_vec(), vec![10.0, 30.0, 50.0]);
3847    }
3848
3849    #[test]
3850    fn test_take_along_axis_2d_axis1() {
3851        let a = Array::from_vec(
3852            vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
3853            Shape::new(vec![2, 3]),
3854        );
3855        // Take column 0 from row 0 and column 2 from row 1
3856        let indices = Array::from_vec(vec![0.0, 2.0], Shape::new(vec![2]));
3857        let result = a.take_along_axis(&indices, 1);
3858        assert_eq!(result.to_vec(), vec![10.0, 60.0]);
3859    }
3860
3861    #[test]
3862    fn test_take_along_axis_2d_axis0() {
3863        let a = Array::from_vec(
3864            vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
3865            Shape::new(vec![2, 3]),
3866        );
3867        // Take row 1 from column 0, row 0 from column 1, row 1 from column 2
3868        let indices = Array::from_vec(vec![1.0, 0.0, 1.0], Shape::new(vec![3]));
3869        let result = a.take_along_axis(&indices, 0);
3870        assert_eq!(result.to_vec(), vec![40.0, 20.0, 60.0]);
3871    }
3872
3873    #[test]
3874    fn test_take_along_axis_3d() {
3875        // 3D array: [2, 3, 4] - 2 batches, 3 rows, 4 cols
3876        let data: Vec<f32> = (0..24).map(|x| x as f32).collect();
3877        let a = Array::from_vec(data, Shape::new(vec![2, 3, 4]));
3878
3879        // Create 3D indices with same shape except axis dimension
3880        // For axis=2 (last dimension), we select which column to pick at each position
3881        // Shape of indices: [2, 3, 2] - pick 2 values per row
3882        let indices = Array::from_vec(
3883            vec![
3884                0.0, 3.0, // batch 0, row 0: pick cols 0, 3
3885                1.0, 2.0, // batch 0, row 1: pick cols 1, 2
3886                0.0, 1.0, // batch 0, row 2: pick cols 0, 1
3887                3.0, 0.0, // batch 1, row 0: pick cols 3, 0
3888                2.0, 1.0, // batch 1, row 1: pick cols 2, 1
3889                1.0, 3.0, // batch 1, row 2: pick cols 1, 3
3890            ],
3891            Shape::new(vec![2, 3, 2]),
3892        );
3893
3894        let result = a.take_along_axis(&indices, 2);
3895        assert_eq!(result.shape().as_slice(), &[2, 3, 2]);
3896
3897        // Verify values
3898        // batch 0: [[0,1,2,3], [4,5,6,7], [8,9,10,11]]
3899        // batch 1: [[12,13,14,15], [16,17,18,19], [20,21,22,23]]
3900        // Expected results:
3901        // batch 0, row 0: [0, 3]
3902        // batch 0, row 1: [5, 6]
3903        // batch 0, row 2: [8, 9]
3904        // batch 1, row 0: [15, 12]
3905        // batch 1, row 1: [18, 17]
3906        // batch 1, row 2: [21, 23]
3907        assert_eq!(
3908            result.to_vec(),
3909            vec![0.0, 3.0, 5.0, 6.0, 8.0, 9.0, 15.0, 12.0, 18.0, 17.0, 21.0, 23.0]
3910        );
3911    }
3912
3913    #[test]
3914    fn test_take_along_axis_3d_middle_axis() {
3915        // 3D array: [2, 3, 2]
3916        let data: Vec<f32> = (0..12).map(|x| x as f32).collect();
3917        let a = Array::from_vec(data, Shape::new(vec![2, 3, 2]));
3918
3919        // For axis=1 (middle dimension), select which row at each position
3920        // Shape of indices: [2, 2, 2] - pick 2 rows per batch
3921        let indices = Array::from_vec(
3922            vec![
3923                0.0, 2.0, // batch 0, position 0: pick rows 0, 2
3924                1.0, 0.0, // batch 0, position 1: pick rows 1, 0
3925                2.0, 1.0, // batch 1, position 0: pick rows 2, 1
3926                0.0, 2.0, // batch 1, position 1: pick rows 0, 2
3927            ],
3928            Shape::new(vec![2, 2, 2]),
3929        );
3930
3931        let result = a.take_along_axis(&indices, 1);
3932        assert_eq!(result.shape().as_slice(), &[2, 2, 2]);
3933
3934        // Array layout:
3935        // batch 0: [[0,1], [2,3], [4,5]]
3936        // batch 1: [[6,7], [8,9], [10,11]]
3937        // Expected (for each position, pick the row specified by index):
3938        // [0][0][0]: row=0, col=0 => 0
3939        // [0][0][1]: row=2, col=1 => 5
3940        // [0][1][0]: row=1, col=0 => 2
3941        // [0][1][1]: row=0, col=1 => 1
3942        // [1][0][0]: row=2, col=0 => 10
3943        // [1][0][1]: row=1, col=1 => 9
3944        // [1][1][0]: row=0, col=0 => 6
3945        // [1][1][1]: row=2, col=1 => 11
3946        assert_eq!(
3947            result.to_vec(),
3948            vec![0.0, 5.0, 2.0, 1.0, 10.0, 9.0, 6.0, 11.0]
3949        );
3950    }
3951
3952    #[test]
3953    fn test_broadcast_arrays_compatible() {
3954        // Test broadcasting arrays with compatible shapes
3955        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
3956        let b = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
3957        let c = Array::from_vec(vec![100.0], Shape::new(vec![1]));
3958
3959        let broadcasted = Array::broadcast_arrays(&[a, b, c]);
3960
3961        assert_eq!(broadcasted.len(), 3);
3962        assert_eq!(broadcasted[0].shape().as_slice(), &[3]);
3963        assert_eq!(broadcasted[1].shape().as_slice(), &[3]);
3964        assert_eq!(broadcasted[2].shape().as_slice(), &[3]);
3965
3966        // Verify values
3967        assert_eq!(broadcasted[0].to_vec(), vec![1.0, 2.0, 3.0]);
3968        assert_eq!(broadcasted[1].to_vec(), vec![10.0, 20.0, 30.0]);
3969        assert_eq!(broadcasted[2].to_vec(), vec![100.0, 100.0, 100.0]);
3970    }
3971
3972    #[test]
3973    fn test_broadcast_arrays_2d() {
3974        // Test broadcasting with 2D arrays
3975        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![1, 3]));
3976        let b = Array::from_vec(vec![10.0, 20.0], Shape::new(vec![2, 1]));
3977
3978        let broadcasted = Array::broadcast_arrays(&[a, b]);
3979
3980        assert_eq!(broadcasted.len(), 2);
3981        assert_eq!(broadcasted[0].shape().as_slice(), &[2, 3]);
3982        assert_eq!(broadcasted[1].shape().as_slice(), &[2, 3]);
3983
3984        // Verify broadcasting worked correctly
3985        assert_eq!(
3986            broadcasted[0].to_vec(),
3987            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
3988        );
3989        assert_eq!(
3990            broadcasted[1].to_vec(),
3991            vec![10.0, 10.0, 10.0, 20.0, 20.0, 20.0]
3992        );
3993    }
3994
3995    #[test]
3996    fn test_broadcast_arrays_empty() {
3997        // Test with empty array list
3998        let broadcasted = Array::broadcast_arrays(&[]);
3999        assert_eq!(broadcasted.len(), 0);
4000    }
4001
4002    #[test]
4003    fn test_broadcast_arrays_single() {
4004        // Test with single array
4005        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
4006        let broadcasted = Array::broadcast_arrays(&[a.clone()]);
4007
4008        assert_eq!(broadcasted.len(), 1);
4009        assert_eq!(broadcasted[0].shape().as_slice(), &[3]);
4010        assert_eq!(broadcasted[0].to_vec(), vec![1.0, 2.0, 3.0]);
4011    }
4012
4013    #[test]
4014    #[should_panic(expected = "Cannot broadcast arrays with shapes")]
4015    fn test_broadcast_arrays_incompatible() {
4016        // Test incompatible shapes - should panic with improved error message
4017        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
4018        let b = Array::from_vec(vec![10.0, 20.0], Shape::new(vec![2]));
4019
4020        Array::broadcast_arrays(&[a, b]);
4021    }
4022
4023    #[test]
4024    #[should_panic(expected = "Cannot broadcast array of shape")]
4025    fn test_broadcast_to_error_message() {
4026        // Test improved error message in broadcast_to
4027        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
4028        a.broadcast_to(Shape::new(vec![2]));
4029    }
4030
4031    #[test]
4032    fn test_convolve() {
4033        let signal = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
4034        let kernel = Array::from_vec(vec![1.0, 0.0, -1.0], Shape::new(vec![3]));
4035        let conv = signal.convolve(&kernel);
4036        // After flipping kernel to [-1, 0, 1]:
4037        // [1*(-1) + 2*0 + 3*1, 2*(-1) + 3*0 + 4*1, 3*(-1) + 4*0 + 5*1]
4038        // = [-1+0+3, -2+0+4, -3+0+5] = [2, 2, 2]
4039        assert_eq!(conv.to_vec(), vec![2.0, 2.0, 2.0]);
4040    }
4041
4042    #[test]
4043    fn test_convolve_averaging() {
4044        let signal = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
4045        let kernel = Array::from_vec(vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], Shape::new(vec![3]));
4046        let conv = signal.convolve(&kernel);
4047        // Moving average: [(1+2+3)/3, (2+3+4)/3, (3+4+5)/3] = [2, 3, 4]
4048        assert_eq!(conv.to_vec(), vec![2.0, 3.0, 4.0]);
4049    }
4050
4051    #[test]
4052    fn test_convolve_identity() {
4053        let signal = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
4054        let kernel = Array::from_vec(vec![1.0], Shape::new(vec![1]));
4055        let conv = signal.convolve(&kernel);
4056        // Identity kernel should return the signal
4057        assert_eq!(conv.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
4058    }
4059
4060    #[test]
4061    fn test_correlate() {
4062        let signal = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
4063        let template = Array::from_vec(vec![1.0, 2.0, 1.0], Shape::new(vec![3]));
4064        let corr = signal.correlate(&template);
4065        // [1*1 + 2*2 + 3*1, 2*1 + 3*2 + 4*1, 3*1 + 4*2 + 5*1]
4066        assert_eq!(corr.to_vec(), vec![8.0, 12.0, 16.0]);
4067    }
4068
4069    #[test]
4070    fn test_correlate_pattern_detection() {
4071        // Test finding a pattern in a signal
4072        let signal = Array::from_vec(vec![0.0, 0.0, 1.0, 2.0, 1.0, 0.0, 0.0], Shape::new(vec![7]));
4073        let pattern = Array::from_vec(vec![1.0, 2.0, 1.0], Shape::new(vec![3]));
4074        let corr = signal.correlate(&pattern);
4075        // Should have peak at position where pattern matches
4076        let max_idx = corr
4077            .to_vec()
4078            .iter()
4079            .enumerate()
4080            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
4081            .unwrap()
4082            .0;
4083        assert_eq!(max_idx, 2); // Pattern starts at index 2 in signal
4084    }
4085
4086    #[test]
4087    fn test_convolve_correlate_difference() {
4088        // Show the difference between convolution and correlation
4089        let signal = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
4090        let kernel = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
4091
4092        let conv = signal.convolve(&kernel);
4093        let corr = signal.correlate(&kernel);
4094
4095        // They should give different results for asymmetric kernels
4096        assert_ne!(conv.to_vec(), corr.to_vec());
4097    }
4098}