jax_rs/ops/
creation.rs

1//! Array creation functions.
2
3use crate::{buffer::Buffer, Array, DType, Shape};
4
5impl Array {
6    /// Create an array with evenly spaced values within a given interval.
7    ///
8    /// Equivalent to `arange(start, stop, step)` in NumPy/JAX.
9    ///
10    /// # Arguments
11    ///
12    /// * `start` - Start of interval (inclusive)
13    /// * `stop` - End of interval (exclusive)
14    /// * `step` - Spacing between values
15    ///
16    /// # Examples
17    ///
18    /// ```
19    /// # use jax_rs::{Array, DType};
20    /// let a = Array::arange(0.0, 10.0, 2.0, DType::Float32);
21    /// assert_eq!(a.to_vec(), vec![0.0, 2.0, 4.0, 6.0, 8.0]);
22    /// ```
23    pub fn arange(start: f32, stop: f32, step: f32, dtype: DType) -> Self {
24        assert_ne!(step, 0.0, "Step must be non-zero");
25        assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
26
27        let size = ((stop - start) / step).ceil().max(0.0) as usize;
28        if size == 0 {
29            return Array::zeros(Shape::new(vec![0]), dtype);
30        }
31
32        let data: Vec<f32> =
33            (0..size).map(|i| start + (i as f32) * step).collect();
34
35        let device = crate::default_device();
36        let buffer = Buffer::from_f32(data, device);
37        Array::from_buffer(buffer, Shape::new(vec![size]))
38    }
39
40    /// Return evenly spaced numbers over a specified interval.
41    ///
42    /// Returns `num` evenly spaced samples, calculated over the interval `[start, stop]`.
43    ///
44    /// # Arguments
45    ///
46    /// * `start` - Starting value
47    /// * `stop` - End value
48    /// * `num` - Number of samples to generate
49    /// * `endpoint` - If true, `stop` is the last sample. Otherwise excluded.
50    ///
51    /// # Examples
52    ///
53    /// ```
54    /// # use jax_rs::{Array, DType};
55    /// let a = Array::linspace(0.0, 1.0, 5, true, DType::Float32);
56    /// assert_eq!(a.to_vec(), vec![0.0, 0.25, 0.5, 0.75, 1.0]);
57    /// ```
58    pub fn linspace(
59        start: f32,
60        stop: f32,
61        num: usize,
62        endpoint: bool,
63        dtype: DType,
64    ) -> Self {
65        assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
66
67        if num == 0 {
68            return Array::zeros(Shape::new(vec![0]), dtype);
69        }
70
71        if num == 1 {
72            return Array::full(start, Shape::new(vec![1]), dtype);
73        }
74
75        if start == stop {
76            return Array::full(start, Shape::new(vec![num]), dtype);
77        }
78
79        let delta = stop - start;
80        let denom = if endpoint { num - 1 } else { num } as f32;
81
82        let data: Vec<f32> =
83            (0..num).map(|i| start + (i as f32) * delta / denom).collect();
84
85        let device = crate::default_device();
86        let buffer = Buffer::from_f32(data, device);
87        Array::from_buffer(buffer, Shape::new(vec![num]))
88    }
89
90    /// Return a 2-D array with ones on the diagonal and zeros elsewhere.
91    ///
92    /// # Arguments
93    ///
94    /// * `n` - Number of rows
95    /// * `m` - Number of columns (defaults to n if None)
96    ///
97    /// # Examples
98    ///
99    /// ```
100    /// # use jax_rs::{Array, DType};
101    /// let i = Array::eye(3, None, DType::Float32);
102    /// // [[1, 0, 0],
103    /// //  [0, 1, 0],
104    /// //  [0, 0, 1]]
105    /// ```
106    pub fn eye(n: usize, m: Option<usize>, dtype: DType) -> Self {
107        assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
108
109        let m = m.unwrap_or(n);
110        let size = n * m;
111        let mut data = vec![0.0; size];
112
113        // Set diagonal elements to 1
114        for i in 0..n.min(m) {
115            data[i * m + i] = 1.0;
116        }
117
118        let device = crate::default_device();
119        let buffer = Buffer::from_f32(data, device);
120        Array::from_buffer(buffer, Shape::new(vec![n, m]))
121    }
122
123    /// Return the identity matrix (square matrix with ones on diagonal).
124    ///
125    /// # Examples
126    ///
127    /// ```
128    /// # use jax_rs::{Array, DType};
129    /// let i = Array::identity(3, DType::Float32);
130    /// ```
131    pub fn identity(n: usize, dtype: DType) -> Self {
132        Self::eye(n, None, dtype)
133    }
134
135    /// Extract diagonal or construct diagonal array.
136    ///
137    /// If input is 1-D, constructs a 2-D array with the input on the diagonal.
138    /// If input is 2-D, extracts the diagonal.
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// # use jax_rs::{Array, Shape};
144    /// // Construct diagonal matrix from 1-D array
145    /// let v = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
146    /// let d = Array::diag(&v, 0);
147    /// assert_eq!(d.shape().as_slice(), &[3, 3]);
148    /// assert_eq!(d.to_vec(), vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
149    /// ```
150    pub fn diag(v: &Self, k: i32) -> Self {
151        assert_eq!(v.dtype(), DType::Float32, "Only Float32 supported");
152
153        if v.ndim() == 1 {
154            // Construct diagonal matrix
155            let n = v.size();
156            let offset = k.unsigned_abs() as usize;
157            let matrix_size = n + offset;
158
159            let mut data = vec![0.0; matrix_size * matrix_size];
160            let v_data = v.to_vec();
161
162            for (i, &val) in v_data.iter().enumerate() {
163                let (row, col) =
164                    if k >= 0 { (i, i + offset) } else { (i + offset, i) };
165                data[row * matrix_size + col] = val;
166            }
167
168            Self::from_vec(data, Shape::new(vec![matrix_size, matrix_size]))
169        } else if v.ndim() == 2 {
170            // Extract diagonal
171            let shape = v.shape().as_slice();
172            let (rows, cols) = (shape[0], shape[1]);
173            let data = v.to_vec();
174
175            let diag_len = if k >= 0 {
176                (cols as i32 - k).min(rows as i32).max(0) as usize
177            } else {
178                (rows as i32 + k).min(cols as i32).max(0) as usize
179            };
180
181            let mut diag_data = Vec::with_capacity(diag_len);
182
183            for i in 0..diag_len {
184                let (row, col) = if k >= 0 {
185                    (i, i + k as usize)
186                } else {
187                    (i + (-k) as usize, i)
188                };
189                diag_data.push(data[row * cols + col]);
190            }
191
192            Self::from_vec(diag_data, Shape::new(vec![diag_len]))
193        } else {
194            panic!("diag only supports 1-D and 2-D arrays");
195        }
196    }
197
198    /// Lower triangle of an array.
199    ///
200    /// # Examples
201    ///
202    /// ```
203    /// # use jax_rs::{Array, Shape};
204    /// let m = Array::from_vec(
205    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
206    ///     Shape::new(vec![3, 3])
207    /// );
208    /// let lower = m.tril(0);
209    /// assert_eq!(lower.to_vec(), vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0]);
210    /// ```
211    pub fn tril(&self, k: i32) -> Self {
212        assert_eq!(self.ndim(), 2, "tril only supports 2-D arrays");
213        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
214
215        let shape = self.shape().as_slice();
216        let (rows, cols) = (shape[0], shape[1]);
217        let data = self.to_vec();
218
219        let mut result = Vec::with_capacity(data.len());
220
221        for i in 0..rows {
222            for j in 0..cols {
223                let val = if (j as i32) <= (i as i32 + k) {
224                    data[i * cols + j]
225                } else {
226                    0.0
227                };
228                result.push(val);
229            }
230        }
231
232        Self::from_vec(result, self.shape().clone())
233    }
234
235    /// Upper triangle of an array.
236    ///
237    /// # Examples
238    ///
239    /// ```
240    /// # use jax_rs::{Array, Shape};
241    /// let m = Array::from_vec(
242    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
243    ///     Shape::new(vec![3, 3])
244    /// );
245    /// let upper = m.triu(0);
246    /// assert_eq!(upper.to_vec(), vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]);
247    /// ```
248    pub fn triu(&self, k: i32) -> Self {
249        assert_eq!(self.ndim(), 2, "triu only supports 2-D arrays");
250        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
251
252        let shape = self.shape().as_slice();
253        let (rows, cols) = (shape[0], shape[1]);
254        let data = self.to_vec();
255
256        let mut result = Vec::with_capacity(data.len());
257
258        for i in 0..rows {
259            for j in 0..cols {
260                let val = if (j as i32) >= (i as i32 + k) {
261                    data[i * cols + j]
262                } else {
263                    0.0
264                };
265                result.push(val);
266            }
267        }
268
269        Self::from_vec(result, self.shape().clone())
270    }
271
272    /// Lower triangular matrix with ones on the diagonal and below.
273    ///
274    /// # Examples
275    ///
276    /// ```
277    /// # use jax_rs::{Array, DType, Shape};
278    /// let tri = Array::tri(3, None, 0, DType::Float32);
279    /// assert_eq!(tri.to_vec(), vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]);
280    /// ```
281    pub fn tri(n: usize, m: Option<usize>, k: i32, dtype: DType) -> Self {
282        assert_eq!(dtype, DType::Float32, "Only Float32 supported");
283
284        let cols = m.unwrap_or(n);
285        let mut data = Vec::with_capacity(n * cols);
286
287        for i in 0..n {
288            for j in 0..cols {
289                let val = if (j as i32) <= (i as i32 + k) { 1.0 } else { 0.0 };
290                data.push(val);
291            }
292        }
293
294        Self::from_vec(data, Shape::new(vec![n, cols]))
295    }
296
297    /// Create array with same shape as another, filled with zeros.
298    ///
299    /// # Examples
300    ///
301    /// ```
302    /// # use jax_rs::{Array, Shape, DType};
303    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
304    /// let b = Array::zeros_like(&a);
305    /// assert_eq!(b.to_vec(), vec![0.0, 0.0, 0.0]);
306    /// ```
307    pub fn zeros_like(other: &Array) -> Array {
308        Array::zeros(other.shape().clone(), other.dtype())
309    }
310
311    /// Create array with same shape as another, filled with ones.
312    ///
313    /// # Examples
314    ///
315    /// ```
316    /// # use jax_rs::{Array, Shape, DType};
317    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
318    /// let b = Array::ones_like(&a);
319    /// assert_eq!(b.to_vec(), vec![1.0, 1.0, 1.0]);
320    /// ```
321    pub fn ones_like(other: &Array) -> Array {
322        Array::ones(other.shape().clone(), other.dtype())
323    }
324
325    /// Create array with same shape as another, filled with a constant value.
326    ///
327    /// # Examples
328    ///
329    /// ```
330    /// # use jax_rs::{Array, Shape, DType};
331    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
332    /// let b = Array::full_like(&a, 42.0);
333    /// assert_eq!(b.to_vec(), vec![42.0, 42.0, 42.0]);
334    /// ```
335    pub fn full_like(other: &Array, value: f32) -> Array {
336        Array::full(value, other.shape().clone(), other.dtype())
337    }
338
339    /// Repeat array along specified axis.
340    ///
341    /// # Examples
342    ///
343    /// ```
344    /// # use jax_rs::{Array, Shape};
345    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
346    /// let b = a.repeat(3, 0);
347    /// assert_eq!(b.to_vec(), vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
348    /// ```
349    pub fn repeat(&self, repeats: usize, axis: usize) -> Array {
350        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
351        assert!(axis < self.ndim(), "Axis out of bounds");
352
353        let shape = self.shape().as_slice();
354        let data = self.to_vec();
355
356        // For 1D case
357        if self.ndim() == 1 {
358            let mut result = Vec::with_capacity(data.len() * repeats);
359            for &val in data.iter() {
360                for _ in 0..repeats {
361                    result.push(val);
362                }
363            }
364            return Array::from_vec(result, Shape::new(vec![shape[0] * repeats]));
365        }
366
367        // For higher dimensions, only support axis 0 for now
368        assert_eq!(axis, 0, "repeat only supports axis=0 for multi-dimensional arrays");
369
370        let slice_size = data.len() / shape[0];
371        let mut result = Vec::with_capacity(data.len() * repeats);
372
373        for i in 0..shape[0] {
374            let start = i * slice_size;
375            let end = start + slice_size;
376            for _ in 0..repeats {
377                result.extend_from_slice(&data[start..end]);
378            }
379        }
380
381        let mut result_shape = shape.to_vec();
382        result_shape[axis] *= repeats;
383        Array::from_vec(result, Shape::new(result_shape))
384    }
385
386    /// Tile array by repeating it multiple times.
387    ///
388    /// # Examples
389    ///
390    /// ```
391    /// # use jax_rs::{Array, Shape};
392    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
393    /// let b = a.tile(3);
394    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
395    /// ```
396    pub fn tile(&self, reps: usize) -> Array {
397        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
398
399        let data = self.to_vec();
400        let mut result = Vec::with_capacity(data.len() * reps);
401
402        for _ in 0..reps {
403            result.extend_from_slice(&data);
404        }
405
406        let shape = self.shape().as_slice();
407        let mut result_shape = shape.to_vec();
408        result_shape[0] *= reps;
409
410        Array::from_vec(result, Shape::new(result_shape))
411    }
412
413    /// Create coordinate matrices from coordinate vectors.
414    ///
415    /// # Examples
416    ///
417    /// ```
418    /// # use jax_rs::{Array, Shape};
419    /// let x = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
420    /// let y = Array::from_vec(vec![3.0, 4.0, 5.0], Shape::new(vec![3]));
421    /// let (xx, yy) = Array::meshgrid(&x, &y);
422    /// assert_eq!(xx.shape().as_slice(), &[3, 2]);
423    /// assert_eq!(yy.shape().as_slice(), &[3, 2]);
424    /// ```
425    pub fn meshgrid(x: &Array, y: &Array) -> (Array, Array) {
426        assert_eq!(x.ndim(), 1, "meshgrid requires 1D arrays");
427        assert_eq!(y.ndim(), 1, "meshgrid requires 1D arrays");
428
429        let x_data = x.to_vec();
430        let y_data = y.to_vec();
431        let nx = x_data.len();
432        let ny = y_data.len();
433
434        // Create XX: repeat x along rows
435        let mut xx_data = Vec::with_capacity(nx * ny);
436        for _ in 0..ny {
437            xx_data.extend_from_slice(&x_data);
438        }
439
440        // Create YY: repeat each y value nx times
441        let mut yy_data = Vec::with_capacity(nx * ny);
442        for &y_val in y_data.iter() {
443            for _ in 0..nx {
444                yy_data.push(y_val);
445            }
446        }
447
448        let xx = Array::from_vec(xx_data, Shape::new(vec![ny, nx]));
449        let yy = Array::from_vec(yy_data, Shape::new(vec![ny, nx]));
450
451        (xx, yy)
452    }
453
454    /// Generate arrays of indices for each dimension.
455    ///
456    /// Returns a vector of arrays, one for each dimension, containing the indices.
457    ///
458    /// # Examples
459    ///
460    /// ```
461    /// # use jax_rs::{Array, Shape};
462    /// let indices = Array::indices(&[2, 3]);
463    /// assert_eq!(indices[0].shape().as_slice(), &[2, 3]);
464    /// assert_eq!(indices[1].shape().as_slice(), &[2, 3]);
465    /// ```
466    pub fn indices(dimensions: &[usize]) -> Vec<Array> {
467        let total_size: usize = dimensions.iter().product();
468        let mut result = Vec::with_capacity(dimensions.len());
469
470        for (dim_idx, &dim_size) in dimensions.iter().enumerate() {
471            let mut data = Vec::with_capacity(total_size);
472
473            // Calculate stride for this dimension
474            let stride: usize = dimensions.iter().skip(dim_idx + 1).product();
475
476            for i in 0..total_size {
477                let idx = (i / stride) % dim_size;
478                data.push(idx as f32);
479            }
480
481            result.push(Array::from_vec(data, Shape::new(dimensions.to_vec())));
482        }
483
484        result
485    }
486
487    /// Convert a flat index to multi-dimensional coordinates.
488    ///
489    /// # Examples
490    ///
491    /// ```
492    /// # use jax_rs::{Array, Shape};
493    /// let shape = Shape::new(vec![3, 4]);
494    /// let coords = Array::unravel_index(5, &shape);
495    /// assert_eq!(coords, vec![1, 1]);
496    /// ```
497    pub fn unravel_index(index: usize, shape: &Shape) -> Vec<usize> {
498        let dims = shape.as_slice();
499        let mut coords = vec![0; dims.len()];
500        let mut idx = index;
501
502        for i in (0..dims.len()).rev() {
503            coords[i] = idx % dims[i];
504            idx /= dims[i];
505        }
506
507        coords
508    }
509
510    /// Convert multi-dimensional coordinates to a flat index.
511    ///
512    /// # Examples
513    ///
514    /// ```
515    /// # use jax_rs::{Array, Shape};
516    /// let shape = Shape::new(vec![3, 4]);
517    /// let index = Array::ravel_multi_index(&[1, 2], &shape);
518    /// assert_eq!(index, 6);
519    /// ```
520    pub fn ravel_multi_index(multi_index: &[usize], shape: &Shape) -> usize {
521        let dims = shape.as_slice();
522        assert_eq!(
523            multi_index.len(),
524            dims.len(),
525            "Index dimensions must match shape"
526        );
527
528        let mut index = 0;
529        let mut stride = 1;
530
531        for i in (0..dims.len()).rev() {
532            assert!(
533                multi_index[i] < dims[i],
534                "Index out of bounds at dimension {}", i
535            );
536            index += multi_index[i] * stride;
537            stride *= dims[i];
538        }
539
540        index
541    }
542
543    /// Return indices for the main diagonal of an n-by-n array.
544    ///
545    /// # Examples
546    ///
547    /// ```
548    /// # use jax_rs::Array;
549    /// let (rows, cols) = Array::diag_indices(3);
550    /// assert_eq!(rows, vec![0, 1, 2]);
551    /// assert_eq!(cols, vec![0, 1, 2]);
552    /// ```
553    pub fn diag_indices(n: usize) -> (Vec<usize>, Vec<usize>) {
554        let indices: Vec<usize> = (0..n).collect();
555        (indices.clone(), indices)
556    }
557
558    /// Return indices for the lower triangle of an n-by-n array.
559    ///
560    /// # Arguments
561    ///
562    /// * `n` - Size of the arrays for which the indices are returned
563    /// * `k` - Diagonal offset (0 for main diagonal, positive for above, negative for below)
564    ///
565    /// # Examples
566    ///
567    /// ```
568    /// # use jax_rs::Array;
569    /// let (rows, cols) = Array::tril_indices(3, 0);
570    /// assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
571    /// assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);
572    /// ```
573    pub fn tril_indices(n: usize, k: isize) -> (Vec<usize>, Vec<usize>) {
574        let mut rows = Vec::new();
575        let mut cols = Vec::new();
576
577        for i in 0..n {
578            for j in 0..n {
579                if (j as isize) <= (i as isize + k) {
580                    rows.push(i);
581                    cols.push(j);
582                }
583            }
584        }
585
586        (rows, cols)
587    }
588
589    /// Return indices for the upper triangle of an n-by-n array.
590    ///
591    /// # Arguments
592    ///
593    /// * `n` - Size of the arrays for which the indices are returned
594    /// * `k` - Diagonal offset (0 for main diagonal, positive for above, negative for below)
595    ///
596    /// # Examples
597    ///
598    /// ```
599    /// # use jax_rs::Array;
600    /// let (rows, cols) = Array::triu_indices(3, 0);
601    /// assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
602    /// assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);
603    /// ```
604    pub fn triu_indices(n: usize, k: isize) -> (Vec<usize>, Vec<usize>) {
605        let mut rows = Vec::new();
606        let mut cols = Vec::new();
607
608        for i in 0..n {
609            for j in 0..n {
610                if (j as isize) >= (i as isize + k) {
611                    rows.push(i);
612                    cols.push(j);
613                }
614            }
615        }
616
617        (rows, cols)
618    }
619
620    /// Return numbers spaced evenly on a log scale (geometric progression).
621    ///
622    /// # Examples
623    ///
624    /// ```
625    /// # use jax_rs::{Array, DType};
626    /// let a = Array::geomspace(1.0, 1000.0, 4, DType::Float32);
627    /// // Result: [1.0, 10.0, 100.0, 1000.0]
628    /// ```
629    pub fn geomspace(start: f32, stop: f32, num: usize, dtype: DType) -> Self {
630        assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
631        assert!(num > 0, "Number of samples must be positive");
632        assert!(start > 0.0 && stop > 0.0, "Start and stop must be positive for geomspace");
633
634        if num == 1 {
635            return Array::from_vec(vec![start], Shape::new(vec![1]));
636        }
637
638        let log_start = start.ln();
639        let log_stop = stop.ln();
640        let step = (log_stop - log_start) / (num - 1) as f32;
641
642        let mut data = Vec::with_capacity(num);
643        for i in 0..num {
644            data.push((log_start + step * i as f32).exp());
645        }
646
647        let device = crate::default_device();
648        let buffer = Buffer::from_f32(data, device);
649        Array::from_buffer(buffer, Shape::new(vec![num]))
650    }
651
652    /// Return numbers spaced evenly on a log scale.
653    ///
654    /// # Examples
655    ///
656    /// ```
657    /// # use jax_rs::{Array, DType};
658    /// let a = Array::logspace(0.0, 3.0, 4, DType::Float32);
659    /// // Result: [1.0, 10.0, 100.0, 1000.0] (10^0 to 10^3)
660    /// ```
661    pub fn logspace(start: f32, stop: f32, num: usize, dtype: DType) -> Self {
662        assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
663        assert!(num > 0, "Number of samples must be positive");
664
665        if num == 1 {
666            return Array::from_vec(vec![10_f32.powf(start)], Shape::new(vec![1]));
667        }
668
669        let step = (stop - start) / (num - 1) as f32;
670        let mut data = Vec::with_capacity(num);
671        for i in 0..num {
672            data.push(10_f32.powf(start + step * i as f32));
673        }
674
675        let device = crate::default_device();
676        let buffer = Buffer::from_f32(data, device);
677        Array::from_buffer(buffer, Shape::new(vec![num]))
678    }
679
680    /// Create empty array with same shape (uninitialized memory).
681    /// Note: In this implementation, we return zeros.
682    ///
683    /// # Examples
684    ///
685    /// ```
686    /// # use jax_rs::{Array, Shape};
687    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
688    /// let b = a.empty_like();
689    /// assert_eq!(b.shape().as_slice(), &[3]);
690    /// ```
691    pub fn empty_like(&self) -> Array {
692        // In practice, we return zeros for safety
693        Array::zeros(self.shape().clone(), self.dtype())
694    }
695
696    /// Check if array is C-contiguous (row-major order).
697    ///
698    /// # Examples
699    ///
700    /// ```
701    /// # use jax_rs::{Array, Shape};
702    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
703    /// assert!(a.is_contiguous());
704    /// ```
705    pub fn is_contiguous(&self) -> bool {
706        // Our arrays are always contiguous
707        true
708    }
709
710    /// Check if array is Fortran-contiguous (column-major order).
711    /// Note: Our arrays are always C-contiguous.
712    pub fn is_fortran_contiguous(&self) -> bool {
713        // Single-dimension arrays are both C and Fortran contiguous
714        self.ndim() <= 1
715    }
716
717    /// Return a contiguous array in memory (C order).
718    ///
719    /// # Examples
720    ///
721    /// ```
722    /// # use jax_rs::{Array, Shape};
723    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
724    /// let b = a.ascontiguousarray();
725    /// assert!(b.is_contiguous());
726    /// ```
727    pub fn ascontiguousarray(&self) -> Array {
728        // Already contiguous, just clone
729        self.clone()
730    }
731
732    /// Create a Hamming window of given length.
733    ///
734    /// # Examples
735    ///
736    /// ```
737    /// # use jax_rs::Array;
738    /// let w = Array::hamming(5);
739    /// assert_eq!(w.shape().as_slice(), &[5]);
740    /// ```
741    pub fn hamming(n: usize) -> Array {
742        let data: Vec<f32> = (0..n)
743            .map(|i| {
744                0.54 - 0.46 * (2.0 * std::f32::consts::PI * i as f32 / (n - 1) as f32).cos()
745            })
746            .collect();
747        let buffer = Buffer::from_f32(data, crate::default_device());
748        Array::from_buffer(buffer, Shape::new(vec![n]))
749    }
750
751    /// Create a Hanning window of given length.
752    ///
753    /// # Examples
754    ///
755    /// ```
756    /// # use jax_rs::Array;
757    /// let w = Array::hanning(5);
758    /// assert_eq!(w.shape().as_slice(), &[5]);
759    /// ```
760    pub fn hanning(n: usize) -> Array {
761        let data: Vec<f32> = (0..n)
762            .map(|i| {
763                0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / (n - 1) as f32).cos())
764            })
765            .collect();
766        let buffer = Buffer::from_f32(data, crate::default_device());
767        Array::from_buffer(buffer, Shape::new(vec![n]))
768    }
769
770    /// Create a Blackman window of given length.
771    ///
772    /// # Examples
773    ///
774    /// ```
775    /// # use jax_rs::Array;
776    /// let w = Array::blackman(5);
777    /// assert_eq!(w.shape().as_slice(), &[5]);
778    /// ```
779    pub fn blackman(n: usize) -> Array {
780        let data: Vec<f32> = (0..n)
781            .map(|i| {
782                let x = i as f32 / (n - 1) as f32;
783                0.42 - 0.5 * (2.0 * std::f32::consts::PI * x).cos()
784                    + 0.08 * (4.0 * std::f32::consts::PI * x).cos()
785            })
786            .collect();
787        let buffer = Buffer::from_f32(data, crate::default_device());
788        Array::from_buffer(buffer, Shape::new(vec![n]))
789    }
790
791    /// Create a Kaiser window of given length and beta parameter.
792    ///
793    /// # Examples
794    ///
795    /// ```
796    /// # use jax_rs::Array;
797    /// let w = Array::kaiser(5, 5.0);
798    /// assert_eq!(w.shape().as_slice(), &[5]);
799    /// ```
800    pub fn kaiser(n: usize, beta: f32) -> Array {
801        // Approximate Bessel I0 function
802        fn i0(x: f32) -> f32 {
803            let ax = x.abs();
804            if ax < 3.75 {
805                let y = (x / 3.75).powi(2);
806                1.0 + y * (3.5156229 + y * (3.0899424 + y * (1.2067492
807                    + y * (0.2659732 + y * (0.0360768 + y * 0.0045813)))))
808            } else {
809                let y = 3.75 / ax;
810                (ax.exp() / ax.sqrt()) * (0.398_942_3 + y * (0.01328592
811                    + y * (0.00225319 + y * (-0.00157565 + y * (0.00916281
812                    + y * (-0.02057706 + y * (0.02635537 + y * (-0.01647633
813                    + y * 0.00392377))))))))
814            }
815        }
816
817        let data: Vec<f32> = (0..n)
818            .map(|i| {
819                let x = 2.0 * i as f32 / (n - 1) as f32 - 1.0;
820                i0(beta * (1.0 - x * x).sqrt()) / i0(beta)
821            })
822            .collect();
823        let buffer = Buffer::from_f32(data, crate::default_device());
824        Array::from_buffer(buffer, Shape::new(vec![n]))
825    }
826
827    /// Create a Bartlett (triangular) window of given length.
828    ///
829    /// # Examples
830    ///
831    /// ```
832    /// # use jax_rs::Array;
833    /// let w = Array::bartlett(5);
834    /// assert_eq!(w.shape().as_slice(), &[5]);
835    /// ```
836    pub fn bartlett(n: usize) -> Array {
837        let data: Vec<f32> = (0..n)
838            .map(|i| {
839                let x = i as f32;
840                let half = (n - 1) as f32 / 2.0;
841                1.0 - ((x - half) / half).abs()
842            })
843            .collect();
844        let buffer = Buffer::from_f32(data, crate::default_device());
845        Array::from_buffer(buffer, Shape::new(vec![n]))
846    }
847
848    /// Create a flat top window of given length.
849    ///
850    /// # Examples
851    ///
852    /// ```
853    /// # use jax_rs::Array;
854    /// let w = Array::flattop(5);
855    /// assert_eq!(w.shape().as_slice(), &[5]);
856    /// ```
857    pub fn flattop(n: usize) -> Array {
858        let a0 = 0.21557895;
859        let a1 = 0.41663158;
860        let a2 = 0.277_263_16;
861        let a3 = 0.083578947;
862        let a4 = 0.006947368;
863
864        let data: Vec<f32> = (0..n)
865            .map(|i| {
866                let x = 2.0 * std::f32::consts::PI * i as f32 / (n - 1) as f32;
867                a0 - a1 * x.cos() + a2 * (2.0 * x).cos()
868                   - a3 * (3.0 * x).cos() + a4 * (4.0 * x).cos()
869            })
870            .collect();
871        let buffer = Buffer::from_f32(data, crate::default_device());
872        Array::from_buffer(buffer, Shape::new(vec![n]))
873    }
874
875    /// Create a triangular window of given length.
876    ///
877    /// # Examples
878    ///
879    /// ```
880    /// # use jax_rs::Array;
881    /// let w = Array::triang(5);
882    /// assert_eq!(w.shape().as_slice(), &[5]);
883    /// assert!((w.to_vec()[2] - 1.0).abs() < 1e-6); // Peak at center
884    /// ```
885    pub fn triang(n: usize) -> Array {
886        let data: Vec<f32> = (0..n)
887            .map(|i| {
888                let half = (n as f32 + 1.0) / 2.0;
889                if i as f32 + 1.0 <= half {
890                    2.0 * (i as f32 + 1.0) / (n as f32 + 1.0)
891                } else {
892                    2.0 - 2.0 * (i as f32 + 1.0) / (n as f32 + 1.0)
893                }
894            })
895            .collect();
896        let buffer = Buffer::from_f32(data, crate::default_device());
897        Array::from_buffer(buffer, Shape::new(vec![n]))
898    }
899}
900
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Assert that `got` matches `expected` element-wise within 1e-6.
    ///
    /// Lengths are compared first, so a too-short or too-long result fails
    /// instead of being silently truncated by `zip`.
    fn assert_all_close(got: &[f32], expected: &[f32]) {
        assert_eq!(
            got.len(),
            expected.len(),
            "length mismatch: {:?} vs {:?}",
            got,
            expected
        );
        for (&g, &e) in got.iter().zip(expected) {
            assert_abs_diff_eq!(g, e, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_arange() {
        let a = Array::arange(0.0, 10.0, 2.0, DType::Float32);
        assert_eq!(a.to_vec(), vec![0.0, 2.0, 4.0, 6.0, 8.0]);

        let b = Array::arange(0.0, 5.0, 1.0, DType::Float32);
        assert_eq!(b.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let c = Array::arange(1.0, 2.0, 0.25, DType::Float32);
        assert_eq!(c.to_vec(), vec![1.0, 1.25, 1.5, 1.75]);
    }

    #[test]
    fn test_arange_negative_step() {
        let a = Array::arange(10.0, 0.0, -2.0, DType::Float32);
        assert_eq!(a.to_vec(), vec![10.0, 8.0, 6.0, 4.0, 2.0]);
    }

    #[test]
    fn test_arange_empty() {
        let a = Array::arange(0.0, 0.0, 1.0, DType::Float32);
        assert_eq!(a.size(), 0);
    }

    #[test]
    #[should_panic(expected = "Step must be non-zero")]
    fn test_arange_zero_step() {
        let _a = Array::arange(0.0, 10.0, 0.0, DType::Float32);
    }

    #[test]
    fn test_linspace() {
        let a = Array::linspace(0.0, 1.0, 5, true, DType::Float32);
        assert_all_close(&a.to_vec(), &[0.0, 0.25, 0.5, 0.75, 1.0]);
    }

    #[test]
    fn test_linspace_no_endpoint() {
        let a = Array::linspace(0.0, 1.0, 5, false, DType::Float32);
        assert_all_close(&a.to_vec(), &[0.0, 0.2, 0.4, 0.6, 0.8]);
    }

    #[test]
    fn test_linspace_single() {
        let a = Array::linspace(5.0, 10.0, 1, true, DType::Float32);
        assert_eq!(a.to_vec(), vec![5.0]);
    }

    #[test]
    fn test_linspace_same_start_stop() {
        let a = Array::linspace(5.0, 5.0, 10, true, DType::Float32);
        // Compare whole vectors: `.all()` alone would also pass on an empty result.
        assert_eq!(a.to_vec(), vec![5.0; 10]);
    }

    #[test]
    fn test_eye() {
        let i = Array::eye(3, None, DType::Float32);
        assert_eq!(i.shape().as_slice(), &[3, 3]);
        assert_eq!(
            i.to_vec(),
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
        );
    }

    #[test]
    fn test_eye_rectangular() {
        let i = Array::eye(2, Some(4), DType::Float32);
        assert_eq!(i.shape().as_slice(), &[2, 4]);
        assert_eq!(i.to_vec(), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_identity() {
        let i = Array::identity(4, DType::Float32);
        assert_eq!(i.shape().as_slice(), &[4, 4]);
        let data = i.to_vec();
        assert_eq!(data.len(), 16);
        // Check every entry: ones on the diagonal, zeros everywhere else.
        for row in 0..4 {
            for col in 0..4 {
                let want = if row == col { 1.0 } else { 0.0 };
                assert_eq!(data[row * 4 + col], want, "mismatch at ({}, {})", row, col);
            }
        }
    }

    #[test]
    fn test_triang_odd() {
        let w = Array::triang(5);
        assert_eq!(w.shape().as_slice(), &[5]);
        assert_all_close(
            &w.to_vec(),
            &[1.0 / 3.0, 2.0 / 3.0, 1.0, 2.0 / 3.0, 1.0 / 3.0],
        );
    }

    #[test]
    fn test_triang_single() {
        assert_eq!(Array::triang(1).to_vec(), vec![1.0]);
    }

    #[test]
    fn test_indices() {
        let indices = Array::indices(&[2, 3]);
        assert_eq!(indices.len(), 2);
        assert_eq!(indices[0].shape().as_slice(), &[2, 3]);
        assert_eq!(indices[1].shape().as_slice(), &[2, 3]);
        // First dimension varies along rows
        assert_eq!(indices[0].to_vec(), vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
        // Second dimension varies along columns
        assert_eq!(indices[1].to_vec(), vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_unravel_index() {
        let shape = Shape::new(vec![3, 4]);
        assert_eq!(Array::unravel_index(0, &shape), vec![0, 0]);
        assert_eq!(Array::unravel_index(5, &shape), vec![1, 1]);
        assert_eq!(Array::unravel_index(11, &shape), vec![2, 3]);
    }

    #[test]
    fn test_ravel_multi_index() {
        let shape = Shape::new(vec![3, 4]);
        assert_eq!(Array::ravel_multi_index(&[0, 0], &shape), 0);
        assert_eq!(Array::ravel_multi_index(&[1, 2], &shape), 6);
        assert_eq!(Array::ravel_multi_index(&[2, 3], &shape), 11);
    }

    #[test]
    fn test_diag_indices() {
        let (rows, cols) = Array::diag_indices(3);
        assert_eq!(rows, vec![0, 1, 2]);
        assert_eq!(cols, vec![0, 1, 2]);
    }

    #[test]
    fn test_tril_indices() {
        let (rows, cols) = Array::tril_indices(3, 0);
        assert_eq!(rows, vec![0, 1, 1, 2, 2, 2]);
        assert_eq!(cols, vec![0, 0, 1, 0, 1, 2]);

        // Test with offset
        let (rows2, cols2) = Array::tril_indices(3, 1);
        assert_eq!(rows2, vec![0, 0, 1, 1, 1, 2, 2, 2]);
        assert_eq!(cols2, vec![0, 1, 0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn test_triu_indices() {
        let (rows, cols) = Array::triu_indices(3, 0);
        assert_eq!(rows, vec![0, 0, 0, 1, 1, 2]);
        assert_eq!(cols, vec![0, 1, 2, 1, 2, 2]);

        // Test with offset k=-1 (includes first subdiagonal)
        let (rows2, cols2) = Array::triu_indices(3, -1);
        assert_eq!(rows2, vec![0, 0, 0, 1, 1, 1, 2, 2]);
        assert_eq!(cols2, vec![0, 1, 2, 0, 1, 2, 1, 2]);
    }
}