Skip to main content

ferray_core/creation/
mod.rs

1// ferray-core: Array creation functions (REQ-16, REQ-17, REQ-18, REQ-19)
2//
3// Mirrors numpy's array creation routines: zeros, ones, full, empty,
4// arange, linspace, logspace, geomspace, eye, identity, diag, etc.
5
6use std::mem::MaybeUninit;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, Ix2, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13// ============================================================================
14// REQ-16: Basic creation functions
15// ============================================================================
16
17/// Create an array from a flat vector and a shape (C-order).
18///
19/// This is the primary "array constructor" — analogous to `numpy.array()` when
20/// given a flat sequence plus a shape.
21///
22/// # Errors
23/// Returns `FerrayError::ShapeMismatch` if `data.len()` does not equal the
24/// product of the shape dimensions.
25pub fn array<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
26    Array::from_vec(dim, data)
27}
28
29/// Interpret existing data as an array without copying (if possible).
30///
31/// This is equivalent to `numpy.asarray()`. Since Rust ownership rules
32/// require moving the data, this creates an owned array from the vector.
33///
34/// # Errors
35/// Returns `FerrayError::ShapeMismatch` if lengths don't match.
36pub fn asarray<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
37    Array::from_vec(dim, data)
38}
39
40/// Create an array from a byte buffer, interpreting bytes as elements of type `T`.
41///
42/// Analogous to `numpy.frombuffer()`.
43///
44/// # Errors
45/// Returns `FerrayError::InvalidValue` if the buffer length is not a multiple
46/// of `size_of::<T>()`, or if the resulting length does not match the shape.
47pub fn frombuffer<T: Element, D: Dimension>(dim: D, buf: &[u8]) -> FerrayResult<Array<T, D>> {
48    let elem_size = std::mem::size_of::<T>();
49    if elem_size == 0 {
50        return Err(FerrayError::invalid_value("zero-sized type"));
51    }
52    if buf.len() % elem_size != 0 {
53        return Err(FerrayError::invalid_value(format!(
54            "buffer length {} is not a multiple of element size {}",
55            buf.len(),
56            elem_size,
57        )));
58    }
59    let n_elems = buf.len() / elem_size;
60    let expected = dim.size();
61    if n_elems != expected {
62        return Err(FerrayError::shape_mismatch(format!(
63            "buffer contains {} elements but shape {:?} requires {}",
64            n_elems,
65            dim.as_slice(),
66            expected,
67        )));
68    }
69    // Copy bytes element-by-element via from_ne_bytes equivalent
70    let mut data = Vec::with_capacity(n_elems);
71    for i in 0..n_elems {
72        let start = i * elem_size;
73        let end = start + elem_size;
74        let slice = &buf[start..end];
75        // SAFETY: We're reading elem_size bytes and interpreting as T.
76        // T: Element implies T: Clone + 'static, and we're copying from
77        // a properly-sized byte buffer.
78        let val = unsafe {
79            let mut val = MaybeUninit::<T>::uninit();
80            std::ptr::copy_nonoverlapping(slice.as_ptr(), val.as_mut_ptr() as *mut u8, elem_size);
81            val.assume_init()
82        };
83        data.push(val);
84    }
85    Array::from_vec(dim, data)
86}
87
88/// Create a 1-D array from an iterator.
89///
90/// Analogous to `numpy.fromiter()`.
91///
92/// # Errors
93/// This function always succeeds (returns `Ok`).
94pub fn fromiter<T: Element>(iter: impl IntoIterator<Item = T>) -> FerrayResult<Array<T, Ix1>> {
95    Array::from_iter_1d(iter)
96}
97
98/// Create an array filled with zeros.
99///
100/// Analogous to `numpy.zeros()`.
101pub fn zeros<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
102    Array::zeros(dim)
103}
104
105/// Create an array filled with ones.
106///
107/// Analogous to `numpy.ones()`.
108pub fn ones<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
109    Array::ones(dim)
110}
111
112/// Create an array filled with a given value.
113///
114/// Analogous to `numpy.full()`.
115pub fn full<T: Element, D: Dimension>(dim: D, fill_value: T) -> FerrayResult<Array<T, D>> {
116    Array::from_elem(dim, fill_value)
117}
118
119/// Create an array with the same shape as `other`, filled with zeros.
120///
121/// Analogous to `numpy.zeros_like()`.
122pub fn zeros_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
123    Array::zeros(other.dim().clone())
124}
125
126/// Create an array with the same shape as `other`, filled with ones.
127///
128/// Analogous to `numpy.ones_like()`.
129pub fn ones_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
130    Array::ones(other.dim().clone())
131}
132
133/// Create an array with the same shape as `other`, filled with `fill_value`.
134///
135/// Analogous to `numpy.full_like()`.
136pub fn full_like<T: Element, D: Dimension>(
137    other: &Array<T, D>,
138    fill_value: T,
139) -> FerrayResult<Array<T, D>> {
140    Array::from_elem(other.dim().clone(), fill_value)
141}
142
143// ============================================================================
144// REQ-17: empty() returning MaybeUninit
145// ============================================================================
146
147/// An array whose elements have not been initialized.
148///
149/// The caller must call [`assume_init`](UninitArray::assume_init) after
150/// filling all elements.
151pub struct UninitArray<T: Element, D: Dimension> {
152    data: Vec<MaybeUninit<T>>,
153    dim: D,
154}
155
156impl<T: Element, D: Dimension> UninitArray<T, D> {
157    /// Shape as a slice.
158    #[inline]
159    pub fn shape(&self) -> &[usize] {
160        self.dim.as_slice()
161    }
162
163    /// Total number of elements.
164    #[inline]
165    pub fn size(&self) -> usize {
166        self.data.len()
167    }
168
169    /// Number of dimensions.
170    #[inline]
171    pub fn ndim(&self) -> usize {
172        self.dim.ndim()
173    }
174
175    /// Get a mutable raw pointer to the underlying data.
176    ///
177    /// Use this to fill the array element-by-element before calling
178    /// `assume_init()`.
179    #[inline]
180    pub fn as_mut_ptr(&mut self) -> *mut MaybeUninit<T> {
181        self.data.as_mut_ptr()
182    }
183
184    /// Write a value at a flat index.
185    ///
186    /// # Errors
187    /// Returns `FerrayError::IndexOutOfBounds` if `flat_index >= size()`.
188    pub fn write_at(&mut self, flat_index: usize, value: T) -> FerrayResult<()> {
189        let size = self.size();
190        if flat_index >= size {
191            return Err(FerrayError::IndexOutOfBounds {
192                index: flat_index as isize,
193                axis: 0,
194                size,
195            });
196        }
197        self.data[flat_index] = MaybeUninit::new(value);
198        Ok(())
199    }
200
201    /// Convert to an initialized `Array<T, D>`.
202    ///
203    /// # Safety
204    /// The caller must ensure that **all** elements have been initialized
205    /// (e.g., via `write_at` or raw pointer writes). Reading uninitialized
206    /// memory is undefined behavior.
207    pub unsafe fn assume_init(self) -> Array<T, D> {
208        let nd_dim = self.dim.to_ndarray_dim();
209        let len = self.data.len();
210
211        // Transmute Vec<MaybeUninit<T>> to Vec<T>.
212        // SAFETY: MaybeUninit<T> has the same layout as T, and the caller
213        // guarantees all elements are initialized.
214        let mut raw_vec = std::mem::ManuallyDrop::new(self.data);
215        let data: Vec<T> =
216            unsafe { Vec::from_raw_parts(raw_vec.as_mut_ptr() as *mut T, len, raw_vec.capacity()) };
217
218        let inner = ndarray::Array::from_shape_vec(nd_dim, data)
219            .expect("UninitArray assume_init: shape/data mismatch (this is a bug)");
220        Array::from_ndarray(inner)
221    }
222}
223
224/// Create an uninitialized array.
225///
226/// Analogous to `numpy.empty()`, but returns a [`UninitArray`] that must
227/// be explicitly initialized via [`UninitArray::assume_init`].
228///
229/// This prevents accidentally reading uninitialized memory — a key safety
230/// improvement over NumPy's `empty()`.
231pub fn empty<T: Element, D: Dimension>(dim: D) -> UninitArray<T, D> {
232    let size = dim.size();
233    let mut data = Vec::with_capacity(size);
234    // SAFETY: MaybeUninit does not require initialization.
235    // We set the length to match the capacity; each element is MaybeUninit.
236    unsafe {
237        data.set_len(size);
238    }
239    UninitArray { data, dim }
240}
241
242// ============================================================================
243// REQ-18: Range functions
244// ============================================================================
245
246/// Trait for types usable with `arange` — numeric types that support
247/// stepping and comparison.
248pub trait ArangeNum: Element + PartialOrd {
249    /// Convert from f64 for step calculations.
250    fn from_f64(v: f64) -> Self;
251    /// Convert to f64 for step calculations.
252    fn to_f64(self) -> f64;
253}
254
255macro_rules! impl_arange_int {
256    ($($ty:ty),*) => {
257        $(
258            impl ArangeNum for $ty {
259                #[inline]
260                fn from_f64(v: f64) -> Self { v as Self }
261                #[inline]
262                fn to_f64(self) -> f64 { self as f64 }
263            }
264        )*
265    };
266}
267
268macro_rules! impl_arange_float {
269    ($($ty:ty),*) => {
270        $(
271            impl ArangeNum for $ty {
272                #[inline]
273                fn from_f64(v: f64) -> Self { v as Self }
274                #[inline]
275                fn to_f64(self) -> f64 { self as f64 }
276            }
277        )*
278    };
279}
280
281impl_arange_int!(u8, u16, u32, u64, i8, i16, i32, i64);
282impl_arange_float!(f32, f64);
283
284/// Create a 1-D array with evenly spaced values within a given interval.
285///
286/// Analogous to `numpy.arange(start, stop, step)`.
287///
288/// # Errors
289/// Returns `FerrayError::InvalidValue` if `step` is zero.
290pub fn arange<T: ArangeNum>(start: T, stop: T, step: T) -> FerrayResult<Array<T, Ix1>> {
291    let step_f = step.to_f64();
292    if step_f == 0.0 {
293        return Err(FerrayError::invalid_value("step cannot be zero"));
294    }
295    let start_f = start.to_f64();
296    let stop_f = stop.to_f64();
297    let n = ((stop_f - start_f) / step_f).ceil();
298    let n = if n < 0.0 { 0 } else { n as usize };
299
300    let mut data = Vec::with_capacity(n);
301    for i in 0..n {
302        data.push(T::from_f64(start_f + (i as f64) * step_f));
303    }
304    let dim = Ix1::new([data.len()]);
305    Array::from_vec(dim, data)
306}
307
308/// Trait for float-like types used in linspace/logspace/geomspace.
309pub trait LinspaceNum: Element + PartialOrd {
310    /// Convert from f64.
311    fn from_f64(v: f64) -> Self;
312    /// Convert to f64.
313    fn to_f64(self) -> f64;
314}
315
316impl LinspaceNum for f32 {
317    #[inline]
318    fn from_f64(v: f64) -> Self {
319        v as f32
320    }
321    #[inline]
322    fn to_f64(self) -> f64 {
323        self as f64
324    }
325}
326
327impl LinspaceNum for f64 {
328    #[inline]
329    fn from_f64(v: f64) -> Self {
330        v
331    }
332    #[inline]
333    fn to_f64(self) -> f64 {
334        self
335    }
336}
337
338/// Create a 1-D array with `num` evenly spaced values between `start` and `stop`.
339///
340/// If `endpoint` is true (the default in NumPy), `stop` is the last sample.
341/// Otherwise, it is not included.
342///
343/// Analogous to `numpy.linspace()`.
344///
345/// # Errors
346/// Returns `FerrayError::InvalidValue` if `num` is 0 and `endpoint` is true
347/// (cannot produce an empty array with an endpoint).
348pub fn linspace<T: LinspaceNum>(
349    start: T,
350    stop: T,
351    num: usize,
352    endpoint: bool,
353) -> FerrayResult<Array<T, Ix1>> {
354    if num == 0 {
355        return Array::from_vec(Ix1::new([0]), vec![]);
356    }
357    if num == 1 {
358        return Array::from_vec(Ix1::new([1]), vec![start]);
359    }
360    let start_f = start.to_f64();
361    let stop_f = stop.to_f64();
362    let divisor = if endpoint {
363        (num - 1) as f64
364    } else {
365        num as f64
366    };
367    let step = (stop_f - start_f) / divisor;
368    let mut data = Vec::with_capacity(num);
369    for i in 0..num {
370        data.push(T::from_f64(start_f + (i as f64) * step));
371    }
372    Array::from_vec(Ix1::new([num]), data)
373}
374
375/// Create a 1-D array with values spaced evenly on a log scale.
376///
377/// Returns `base ** linspace(start, stop, num)`.
378///
379/// Analogous to `numpy.logspace()`.
380///
381/// # Errors
382/// Propagates errors from `linspace`.
383pub fn logspace<T: LinspaceNum>(
384    start: T,
385    stop: T,
386    num: usize,
387    endpoint: bool,
388    base: f64,
389) -> FerrayResult<Array<T, Ix1>> {
390    let lin = linspace(start, stop, num, endpoint)?;
391    let data: Vec<T> = lin
392        .iter()
393        .map(|v| T::from_f64(base.powf(v.clone().to_f64())))
394        .collect();
395    Array::from_vec(Ix1::new([num]), data)
396}
397
398/// Create a 1-D array with values spaced evenly on a geometric (log) scale.
399///
400/// The values are `start * (stop/start) ** linspace(0, 1, num)`.
401///
402/// Analogous to `numpy.geomspace()`.
403///
404/// # Errors
405/// Returns `FerrayError::InvalidValue` if `start` or `stop` is zero or
406/// if they have different signs.
407pub fn geomspace<T: LinspaceNum>(
408    start: T,
409    stop: T,
410    num: usize,
411    endpoint: bool,
412) -> FerrayResult<Array<T, Ix1>> {
413    let start_f = start.clone().to_f64();
414    let stop_f = stop.clone().to_f64();
415    if start_f == 0.0 || stop_f == 0.0 {
416        return Err(FerrayError::invalid_value(
417            "geomspace: start and stop must be non-zero",
418        ));
419    }
420    if (start_f < 0.0) != (stop_f < 0.0) {
421        return Err(FerrayError::invalid_value(
422            "geomspace: start and stop must have the same sign",
423        ));
424    }
425    if num == 0 {
426        return Array::from_vec(Ix1::new([0]), vec![]);
427    }
428    if num == 1 {
429        return Array::from_vec(Ix1::new([1]), vec![start]);
430    }
431    let log_start = start_f.abs().ln();
432    let log_stop = stop_f.abs().ln();
433    let sign = if start_f < 0.0 { -1.0 } else { 1.0 };
434    let divisor = if endpoint {
435        (num - 1) as f64
436    } else {
437        num as f64
438    };
439    let step = (log_stop - log_start) / divisor;
440    let mut data = Vec::with_capacity(num);
441    for i in 0..num {
442        let log_val = log_start + (i as f64) * step;
443        data.push(T::from_f64(sign * log_val.exp()));
444    }
445    Array::from_vec(Ix1::new([num]), data)
446}
447
448/// Return coordinate arrays from coordinate vectors.
449///
450/// Analogous to `numpy.meshgrid(*xi, indexing='xy')`.
451///
452/// Given N 1-D arrays, returns N N-D arrays, where each output array
453/// has the shape `(len(x1), len(x2), ..., len(xN))` for 'xy' indexing
454/// or `(len(x1), ..., len(xN))` transposed for 'ij' indexing.
455///
456/// `indexing` should be `"xy"` (default Cartesian) or `"ij"` (matrix).
457///
458/// # Errors
459/// Returns `FerrayError::InvalidValue` if `indexing` is not `"xy"` or `"ij"`,
460/// or if there are fewer than 2 input arrays.
461pub fn meshgrid(
462    arrays: &[Array<f64, Ix1>],
463    indexing: &str,
464) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
465    if indexing != "xy" && indexing != "ij" {
466        return Err(FerrayError::invalid_value(
467            "meshgrid: indexing must be 'xy' or 'ij'",
468        ));
469    }
470    let ndim = arrays.len();
471    if ndim == 0 {
472        return Ok(vec![]);
473    }
474
475    let mut shapes: Vec<usize> = arrays.iter().map(|a| a.shape()[0]).collect();
476    if indexing == "xy" && ndim >= 2 {
477        shapes.swap(0, 1);
478    }
479
480    let total: usize = shapes.iter().product();
481    let mut results = Vec::with_capacity(ndim);
482
483    for (k, arr) in arrays.iter().enumerate() {
484        let src_data: Vec<f64> = arr.iter().copied().collect();
485        let mut data = Vec::with_capacity(total);
486        // For 'xy' indexing, the first two dimensions are swapped
487        let effective_k = if indexing == "xy" && ndim >= 2 {
488            match k {
489                0 => 1,
490                1 => 0,
491                other => other,
492            }
493        } else {
494            k
495        };
496
497        // Build the output by iterating over all indices in the output shape
498        for flat in 0..total {
499            // Compute the index along dimension effective_k
500            let mut rem = flat;
501            let mut idx_k = 0;
502            for (d, &s) in shapes.iter().enumerate().rev() {
503                if d == effective_k {
504                    idx_k = rem % s;
505                }
506                rem /= s;
507            }
508            data.push(src_data[idx_k]);
509        }
510
511        let dim = IxDyn::new(&shapes);
512        results.push(Array::from_vec(dim, data)?);
513    }
514    Ok(results)
515}
516
517/// Create a dense multi-dimensional "meshgrid" with matrix ('ij') indexing.
518///
519/// Analogous to `numpy.mgrid[start:stop:step, ...]`.
520///
521/// Takes a slice of `(start, stop, step)` tuples, one per dimension.
522/// Returns a vector of arrays, one per dimension.
523///
524/// # Errors
525/// Returns `FerrayError::InvalidValue` if any step is zero.
526pub fn mgrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
527    let mut arrs: Vec<Array<f64, Ix1>> = Vec::with_capacity(ranges.len());
528    for &(start, stop, step) in ranges {
529        arrs.push(arange(start, stop, step)?);
530    }
531    meshgrid(&arrs, "ij")
532}
533
534/// Create a sparse (open) multi-dimensional "meshgrid" with 'ij' indexing.
535///
536/// Analogous to `numpy.ogrid[start:stop:step, ...]`.
537///
538/// Returns arrays that are broadcastable to the full grid shape.
539/// Each returned array has shape 1 in all dimensions except its own.
540///
541/// # Errors
542/// Returns `FerrayError::InvalidValue` if any step is zero.
543pub fn ogrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
544    let ndim = ranges.len();
545    let mut results = Vec::with_capacity(ndim);
546    for (i, &(start, stop, step)) in ranges.iter().enumerate() {
547        let arr1d = arange(start, stop, step)?;
548        let n = arr1d.shape()[0];
549        let data: Vec<f64> = arr1d.iter().copied().collect();
550        // Build shape: all ones except dimension i = n
551        let mut shape = vec![1usize; ndim];
552        shape[i] = n;
553        let dim = IxDyn::new(&shape);
554        results.push(Array::from_vec(dim, data)?);
555    }
556    Ok(results)
557}
558
559// ============================================================================
560// REQ-19: Identity/diagonal functions
561// ============================================================================
562
563/// Create a 2-D identity matrix of size `n x n`.
564///
565/// Analogous to `numpy.identity()`.
566pub fn identity<T: Element>(n: usize) -> FerrayResult<Array<T, Ix2>> {
567    eye(n, n, 0)
568}
569
570/// Create a 2-D array with ones on the diagonal and zeros elsewhere.
571///
572/// `k` is the diagonal offset: 0 = main diagonal, positive = above, negative = below.
573///
574/// Analogous to `numpy.eye(N, M, k)`.
575pub fn eye<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
576    let mut data = vec![T::zero(); n * m];
577    for i in 0..n {
578        let j = i as isize + k;
579        if j >= 0 && (j as usize) < m {
580            data[i * m + j as usize] = T::one();
581        }
582    }
583    Array::from_vec(Ix2::new([n, m]), data)
584}
585
586/// Extract a diagonal or construct a diagonal array.
587///
588/// If `a` is 2-D, extract the `k`-th diagonal as a 1-D array.
589/// If `a` is 1-D, construct a 2-D array with `a` on the `k`-th diagonal.
590///
591/// Analogous to `numpy.diag()`.
592///
593/// # Errors
594/// Returns `FerrayError::InvalidValue` if `a` is not 1-D or 2-D.
595pub fn diag<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
596    let shape = a.shape();
597    match shape.len() {
598        1 => {
599            // Construct a 2-D diagonal array
600            let n = shape[0];
601            let size = n + k.unsigned_abs();
602            let mut data = vec![T::zero(); size * size];
603            let src: Vec<T> = a.iter().cloned().collect();
604            for (i, val) in src.into_iter().enumerate() {
605                let row = if k >= 0 { i } else { i + k.unsigned_abs() };
606                let col = if k >= 0 { i + k as usize } else { i };
607                data[row * size + col] = val;
608            }
609            Array::from_vec(IxDyn::new(&[size, size]), data)
610        }
611        2 => {
612            // Extract the k-th diagonal
613            let (n, m) = (shape[0], shape[1]);
614            let src: Vec<T> = a.iter().cloned().collect();
615            let mut diag_vals = Vec::new();
616            for i in 0..n {
617                let j = i as isize + k;
618                if j >= 0 && (j as usize) < m {
619                    diag_vals.push(src[i * m + j as usize].clone());
620                }
621            }
622            let len = diag_vals.len();
623            Array::from_vec(IxDyn::new(&[len]), diag_vals)
624        }
625        _ => Err(FerrayError::invalid_value("diag: input must be 1-D or 2-D")),
626    }
627}
628
629/// Create a 2-D array with the flattened input as a diagonal.
630///
631/// Analogous to `numpy.diagflat()`.
632///
633/// # Errors
634/// Propagates errors from the underlying construction.
635pub fn diagflat<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
636    // Flatten a to 1-D, then call diag
637    let flat: Vec<T> = a.iter().cloned().collect();
638    let n = flat.len();
639    let arr1d = Array::from_vec(IxDyn::new(&[n]), flat)?;
640    diag(&arr1d, k)
641}
642
643/// Create a lower-triangular matrix of ones.
644///
645/// Returns an `n x m` array where `a[i, j] = 1` if `i >= j - k`, else `0`.
646///
647/// Analogous to `numpy.tri(N, M, k)`.
648pub fn tri<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
649    let mut data = vec![T::zero(); n * m];
650    for i in 0..n {
651        for j in 0..m {
652            if (i as isize) >= (j as isize) - k {
653                data[i * m + j] = T::one();
654            }
655        }
656    }
657    Array::from_vec(Ix2::new([n, m]), data)
658}
659
660/// Return the lower triangle of a 2-D array.
661///
662/// `k` is the diagonal above which to zero elements. 0 = main diagonal.
663///
664/// Analogous to `numpy.tril()`.
665///
666/// # Errors
667/// Returns `FerrayError::InvalidValue` if input is not 2-D.
668pub fn tril<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
669    let shape = a.shape();
670    if shape.len() != 2 {
671        return Err(FerrayError::invalid_value("tril: input must be 2-D"));
672    }
673    let (n, m) = (shape[0], shape[1]);
674    let src: Vec<T> = a.iter().cloned().collect();
675    let mut data = vec![T::zero(); n * m];
676    for i in 0..n {
677        for j in 0..m {
678            if (i as isize) >= (j as isize) - k {
679                data[i * m + j] = src[i * m + j].clone();
680            }
681        }
682    }
683    Array::from_vec(IxDyn::new(&[n, m]), data)
684}
685
686/// Return the upper triangle of a 2-D array.
687///
688/// `k` is the diagonal below which to zero elements. 0 = main diagonal.
689///
690/// Analogous to `numpy.triu()`.
691///
692/// # Errors
693/// Returns `FerrayError::InvalidValue` if input is not 2-D.
694pub fn triu<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
695    let shape = a.shape();
696    if shape.len() != 2 {
697        return Err(FerrayError::invalid_value("triu: input must be 2-D"));
698    }
699    let (n, m) = (shape[0], shape[1]);
700    let src: Vec<T> = a.iter().cloned().collect();
701    let mut data = vec![T::zero(); n * m];
702    for i in 0..n {
703        for j in 0..m {
704            if (i as isize) <= (j as isize) - k {
705                data[i * m + j] = src[i * m + j].clone();
706            }
707        }
708    }
709    Array::from_vec(IxDyn::new(&[n, m]), data)
710}
711
712// ============================================================================
713// Tests
714// ============================================================================
715
716#[cfg(test)]
717mod tests {
718    use super::*;
719    use crate::dimension::{Ix1, Ix2, IxDyn};
720
721    // -- REQ-16 tests --
722
723    #[test]
724    fn test_array_creation() {
725        let a = array(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
726        assert_eq!(a.shape(), &[2, 3]);
727        assert_eq!(a.size(), 6);
728    }
729
730    #[test]
731    fn test_asarray() {
732        let a = asarray(Ix1::new([3]), vec![1, 2, 3]).unwrap();
733        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
734    }
735
736    #[test]
737    fn test_frombuffer() {
738        let bytes: Vec<u8> = vec![1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0];
739        let a = frombuffer::<i32, Ix1>(Ix1::new([3]), &bytes).unwrap();
740        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
741    }
742
743    #[test]
744    fn test_frombuffer_bad_length() {
745        let bytes: Vec<u8> = vec![1, 0, 0];
746        assert!(frombuffer::<i32, Ix1>(Ix1::new([1]), &bytes).is_err());
747    }
748
749    #[test]
750    fn test_fromiter() {
751        let a = fromiter((0..5).map(|x| x as f64)).unwrap();
752        assert_eq!(a.shape(), &[5]);
753        assert_eq!(a.as_slice().unwrap(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
754    }
755
756    #[test]
757    fn test_zeros() {
758        let a = zeros::<f64, Ix2>(Ix2::new([3, 4])).unwrap();
759        assert_eq!(a.shape(), &[3, 4]);
760        assert!(a.iter().all(|&v| v == 0.0));
761    }
762
763    #[test]
764    fn test_ones() {
765        let a = ones::<f64, Ix1>(Ix1::new([5])).unwrap();
766        assert!(a.iter().all(|&v| v == 1.0));
767    }
768
769    #[test]
770    fn test_full() {
771        let a = full(Ix1::new([4]), 42i32).unwrap();
772        assert!(a.iter().all(|&v| v == 42));
773    }
774
775    #[test]
776    fn test_zeros_like() {
777        let a = ones::<f64, Ix2>(Ix2::new([2, 3])).unwrap();
778        let b = zeros_like(&a).unwrap();
779        assert_eq!(b.shape(), &[2, 3]);
780        assert!(b.iter().all(|&v| v == 0.0));
781    }
782
783    #[test]
784    fn test_ones_like() {
785        let a = zeros::<f64, Ix1>(Ix1::new([4])).unwrap();
786        let b = ones_like(&a).unwrap();
787        assert!(b.iter().all(|&v| v == 1.0));
788    }
789
790    #[test]
791    fn test_full_like() {
792        let a = zeros::<i32, Ix1>(Ix1::new([3])).unwrap();
793        let b = full_like(&a, 7).unwrap();
794        assert!(b.iter().all(|&v| v == 7));
795    }
796
797    // -- REQ-17 tests --
798
799    #[test]
800    fn test_empty_and_init() {
801        let mut u = empty::<f64, Ix1>(Ix1::new([3]));
802        assert_eq!(u.shape(), &[3]);
803        u.write_at(0, 1.0).unwrap();
804        u.write_at(1, 2.0).unwrap();
805        u.write_at(2, 3.0).unwrap();
806        // SAFETY: all elements initialized
807        let a = unsafe { u.assume_init() };
808        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
809    }
810
811    #[test]
812    fn test_empty_write_oob() {
813        let mut u = empty::<f64, Ix1>(Ix1::new([2]));
814        assert!(u.write_at(5, 1.0).is_err());
815    }
816
817    // -- REQ-18 tests --
818
819    #[test]
820    fn test_arange_int() {
821        let a = arange(0i32, 5, 1).unwrap();
822        assert_eq!(a.as_slice().unwrap(), &[0, 1, 2, 3, 4]);
823    }
824
825    #[test]
826    fn test_arange_float() {
827        let a = arange(0.0_f64, 1.0, 0.25).unwrap();
828        assert_eq!(a.shape(), &[4]);
829        let data = a.as_slice().unwrap();
830        assert!((data[0] - 0.0).abs() < 1e-10);
831        assert!((data[1] - 0.25).abs() < 1e-10);
832        assert!((data[2] - 0.5).abs() < 1e-10);
833        assert!((data[3] - 0.75).abs() < 1e-10);
834    }
835
836    #[test]
837    fn test_arange_negative_step() {
838        let a = arange(5.0_f64, 0.0, -1.0).unwrap();
839        assert_eq!(a.shape(), &[5]);
840    }
841
842    #[test]
843    fn test_arange_zero_step() {
844        assert!(arange(0.0_f64, 1.0, 0.0).is_err());
845    }
846
847    #[test]
848    fn test_arange_empty() {
849        let a = arange(5i32, 0, 1).unwrap();
850        assert_eq!(a.shape(), &[0]);
851    }
852
853    #[test]
854    fn test_linspace() {
855        let a = linspace(0.0_f64, 1.0, 5, true).unwrap();
856        assert_eq!(a.shape(), &[5]);
857        let data = a.as_slice().unwrap();
858        assert!((data[0] - 0.0).abs() < 1e-10);
859        assert!((data[4] - 1.0).abs() < 1e-10);
860        assert!((data[2] - 0.5).abs() < 1e-10);
861    }
862
863    #[test]
864    fn test_linspace_no_endpoint() {
865        let a = linspace(0.0_f64, 1.0, 4, false).unwrap();
866        assert_eq!(a.shape(), &[4]);
867        let data = a.as_slice().unwrap();
868        assert!((data[0] - 0.0).abs() < 1e-10);
869        assert!((data[1] - 0.25).abs() < 1e-10);
870    }
871
872    #[test]
873    fn test_linspace_single() {
874        let a = linspace(5.0_f64, 10.0, 1, true).unwrap();
875        assert_eq!(a.as_slice().unwrap(), &[5.0]);
876    }
877
878    #[test]
879    fn test_linspace_empty() {
880        let a = linspace(0.0_f64, 1.0, 0, true).unwrap();
881        assert_eq!(a.shape(), &[0]);
882    }
883
884    #[test]
885    fn test_logspace() {
886        let a = logspace(0.0_f64, 2.0, 3, true, 10.0).unwrap();
887        let data = a.as_slice().unwrap();
888        assert!((data[0] - 1.0).abs() < 1e-10); // 10^0
889        assert!((data[1] - 10.0).abs() < 1e-10); // 10^1
890        assert!((data[2] - 100.0).abs() < 1e-10); // 10^2
891    }
892
893    #[test]
894    fn test_geomspace() {
895        let a = geomspace(1.0_f64, 1000.0, 4, true).unwrap();
896        let data = a.as_slice().unwrap();
897        assert!((data[0] - 1.0).abs() < 1e-10);
898        assert!((data[1] - 10.0).abs() < 1e-8);
899        assert!((data[2] - 100.0).abs() < 1e-6);
900        assert!((data[3] - 1000.0).abs() < 1e-4);
901    }
902
903    #[test]
904    fn test_geomspace_zero_start() {
905        assert!(geomspace(0.0_f64, 1.0, 5, true).is_err());
906    }
907
908    #[test]
909    fn test_geomspace_different_signs() {
910        assert!(geomspace(-1.0_f64, 1.0, 5, true).is_err());
911    }
912
913    #[test]
914    fn test_meshgrid_xy() {
915        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
916        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
917        let grids = meshgrid(&[x, y], "xy").unwrap();
918        assert_eq!(grids.len(), 2);
919        assert_eq!(grids[0].shape(), &[2, 3]);
920        assert_eq!(grids[1].shape(), &[2, 3]);
921        // X grid: rows are [1,2,3], [1,2,3]
922        let xdata: Vec<f64> = grids[0].iter().copied().collect();
923        assert_eq!(xdata, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
924        // Y grid: rows are [4,4,4], [5,5,5]
925        let ydata: Vec<f64> = grids[1].iter().copied().collect();
926        assert_eq!(ydata, vec![4.0, 4.0, 4.0, 5.0, 5.0, 5.0]);
927    }
928
929    #[test]
930    fn test_meshgrid_ij() {
931        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
932        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
933        let grids = meshgrid(&[x, y], "ij").unwrap();
934        assert_eq!(grids.len(), 2);
935        assert_eq!(grids[0].shape(), &[3, 2]);
936        assert_eq!(grids[1].shape(), &[3, 2]);
937    }
938
939    #[test]
940    fn test_meshgrid_bad_indexing() {
941        assert!(meshgrid(&[], "zz").is_err());
942    }
943
944    #[test]
945    fn test_mgrid() {
946        let grids = mgrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
947        assert_eq!(grids.len(), 2);
948        assert_eq!(grids[0].shape(), &[3, 2]);
949    }
950
951    #[test]
952    fn test_ogrid() {
953        let grids = ogrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
954        assert_eq!(grids.len(), 2);
955        assert_eq!(grids[0].shape(), &[3, 1]);
956        assert_eq!(grids[1].shape(), &[1, 2]);
957    }
958
959    // -- REQ-19 tests --
960
961    #[test]
962    fn test_identity() {
963        let a = identity::<f64>(3).unwrap();
964        assert_eq!(a.shape(), &[3, 3]);
965        let data = a.as_slice().unwrap();
966        assert_eq!(data, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
967    }
968
969    #[test]
970    fn test_eye() {
971        let a = eye::<f64>(3, 4, 0).unwrap();
972        assert_eq!(a.shape(), &[3, 4]);
973        let data = a.as_slice().unwrap();
974        assert_eq!(
975            data,
976            &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
977        );
978    }
979
980    #[test]
981    fn test_eye_positive_k() {
982        let a = eye::<f64>(3, 3, 1).unwrap();
983        let data = a.as_slice().unwrap();
984        assert_eq!(data, &[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]);
985    }
986
987    #[test]
988    fn test_eye_negative_k() {
989        let a = eye::<f64>(3, 3, -1).unwrap();
990        let data = a.as_slice().unwrap();
991        assert_eq!(data, &[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
992    }
993
994    #[test]
995    fn test_diag_from_1d() {
996        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
997        let d = diag(&a, 0).unwrap();
998        assert_eq!(d.shape(), &[3, 3]);
999        let data: Vec<f64> = d.iter().copied().collect();
1000        assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
1001    }
1002
1003    #[test]
1004    fn test_diag_from_2d() {
1005        let a = Array::from_vec(
1006            IxDyn::new(&[3, 3]),
1007            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0],
1008        )
1009        .unwrap();
1010        let d = diag(&a, 0).unwrap();
1011        assert_eq!(d.shape(), &[3]);
1012        let data: Vec<f64> = d.iter().copied().collect();
1013        assert_eq!(data, vec![1.0, 2.0, 3.0]);
1014    }
1015
1016    #[test]
1017    fn test_diag_k_positive() {
1018        let a = Array::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
1019        let d = diag(&a, 1).unwrap();
1020        assert_eq!(d.shape(), &[3, 3]);
1021        let data: Vec<f64> = d.iter().copied().collect();
1022        assert_eq!(data, vec![0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]);
1023    }
1024
1025    #[test]
1026    fn test_diagflat() {
1027        let a = Array::from_vec(IxDyn::new(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1028        let d = diagflat(&a, 0).unwrap();
1029        assert_eq!(d.shape(), &[4, 4]);
1030        // Diagonal should be [1, 2, 3, 4]
1031        let extracted = diag(&d, 0).unwrap();
1032        let data: Vec<f64> = extracted.iter().copied().collect();
1033        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
1034    }
1035
1036    #[test]
1037    fn test_tri() {
1038        let a = tri::<f64>(3, 3, 0).unwrap();
1039        let data = a.as_slice().unwrap();
1040        assert_eq!(data, &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]);
1041    }
1042
1043    #[test]
1044    fn test_tril() {
1045        let a = Array::from_vec(
1046            IxDyn::new(&[3, 3]),
1047            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
1048        )
1049        .unwrap();
1050        let t = tril(&a, 0).unwrap();
1051        let data: Vec<f64> = t.iter().copied().collect();
1052        assert_eq!(data, vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0]);
1053    }
1054
1055    #[test]
1056    fn test_triu() {
1057        let a = Array::from_vec(
1058            IxDyn::new(&[3, 3]),
1059            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
1060        )
1061        .unwrap();
1062        let t = triu(&a, 0).unwrap();
1063        let data: Vec<f64> = t.iter().copied().collect();
1064        assert_eq!(data, vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]);
1065    }
1066
1067    #[test]
1068    fn test_tril_not_2d() {
1069        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
1070        assert!(tril(&a, 0).is_err());
1071    }
1072
1073    #[test]
1074    fn test_triu_not_2d() {
1075        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
1076        assert!(triu(&a, 0).is_err());
1077    }
1078}