Skip to main content

ferray_core/creation/
mod.rs

1// ferray-core: Array creation functions (REQ-16, REQ-17, REQ-18, REQ-19)
2//
3// Mirrors numpy's array creation routines: zeros, ones, full, empty,
4// arange, linspace, logspace, geomspace, eye, identity, diag, etc.
5
6use std::mem::MaybeUninit;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, Ix2, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13// ============================================================================
14// REQ-16: Basic creation functions
15// ============================================================================
16
17/// Create an array from a flat vector and a shape (C-order).
18///
19/// This is the primary "array constructor" — analogous to `numpy.array()` when
20/// given a flat sequence plus a shape.
21///
22/// # Errors
23/// Returns `FerrayError::ShapeMismatch` if `data.len()` does not equal the
24/// product of the shape dimensions.
25pub fn array<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
26    Array::from_vec(dim, data)
27}
28
29/// Interpret existing data as an array without copying (if possible).
30///
31/// This is equivalent to `numpy.asarray()`. Since Rust ownership rules
32/// require moving the data, this creates an owned array from the vector.
33///
34/// # Errors
35/// Returns `FerrayError::ShapeMismatch` if lengths don't match.
36pub fn asarray<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
37    Array::from_vec(dim, data)
38}
39
40/// Create an array from a byte buffer, interpreting bytes as elements of type `T`.
41///
42/// Analogous to `numpy.frombuffer()`.
43///
44/// # Errors
45/// Returns `FerrayError::InvalidValue` if the buffer length is not a multiple
46/// of `size_of::<T>()`, or if the resulting length does not match the shape.
47pub fn frombuffer<T: Element, D: Dimension>(dim: D, buf: &[u8]) -> FerrayResult<Array<T, D>> {
48    let elem_size = std::mem::size_of::<T>();
49    if elem_size == 0 {
50        return Err(FerrayError::invalid_value("zero-sized type"));
51    }
52    if buf.len() % elem_size != 0 {
53        return Err(FerrayError::invalid_value(format!(
54            "buffer length {} is not a multiple of element size {}",
55            buf.len(),
56            elem_size,
57        )));
58    }
59    let n_elems = buf.len() / elem_size;
60    let expected = dim.size();
61    if n_elems != expected {
62        return Err(FerrayError::shape_mismatch(format!(
63            "buffer contains {} elements but shape {:?} requires {}",
64            n_elems,
65            dim.as_slice(),
66            expected,
67        )));
68    }
69    // Validate bytes for types where not all bit patterns are valid.
70    // bool only permits 0x00 and 0x01.
71    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<bool>() {
72        for &byte in buf {
73            if byte > 1 {
74                return Err(FerrayError::invalid_value(format!(
75                    "invalid byte {byte:#04x} for bool (must be 0x00 or 0x01)"
76                )));
77            }
78        }
79    }
80
81    // Copy bytes element-by-element via from_ne_bytes equivalent
82    let mut data = Vec::with_capacity(n_elems);
83    for i in 0..n_elems {
84        let start = i * elem_size;
85        let end = start + elem_size;
86        let slice = &buf[start..end];
87        // SAFETY: We're reading elem_size bytes and interpreting as T.
88        // For bool, we validated above that all bytes are 0 or 1.
89        // For numeric types, all bit patterns are valid.
90        let val = unsafe {
91            let mut val = MaybeUninit::<T>::uninit();
92            std::ptr::copy_nonoverlapping(slice.as_ptr(), val.as_mut_ptr() as *mut u8, elem_size);
93            val.assume_init()
94        };
95        data.push(val);
96    }
97    Array::from_vec(dim, data)
98}
99
100/// Create a 1-D array from an iterator.
101///
102/// Analogous to `numpy.fromiter()`.
103///
104/// # Errors
105/// This function always succeeds (returns `Ok`).
106pub fn fromiter<T: Element>(iter: impl IntoIterator<Item = T>) -> FerrayResult<Array<T, Ix1>> {
107    Array::from_iter_1d(iter)
108}
109
110/// Create an array filled with zeros.
111///
112/// Analogous to `numpy.zeros()`.
113pub fn zeros<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
114    Array::zeros(dim)
115}
116
117/// Create an array filled with ones.
118///
119/// Analogous to `numpy.ones()`.
120pub fn ones<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
121    Array::ones(dim)
122}
123
124/// Create an array filled with a given value.
125///
126/// Analogous to `numpy.full()`.
127pub fn full<T: Element, D: Dimension>(dim: D, fill_value: T) -> FerrayResult<Array<T, D>> {
128    Array::from_elem(dim, fill_value)
129}
130
131/// Create an array with the same shape as `other`, filled with zeros.
132///
133/// Analogous to `numpy.zeros_like()`.
134pub fn zeros_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
135    Array::zeros(other.dim().clone())
136}
137
138/// Create an array with the same shape as `other`, filled with ones.
139///
140/// Analogous to `numpy.ones_like()`.
141pub fn ones_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
142    Array::ones(other.dim().clone())
143}
144
145/// Create an array with the same shape as `other`, filled with `fill_value`.
146///
147/// Analogous to `numpy.full_like()`.
148pub fn full_like<T: Element, D: Dimension>(
149    other: &Array<T, D>,
150    fill_value: T,
151) -> FerrayResult<Array<T, D>> {
152    Array::from_elem(other.dim().clone(), fill_value)
153}
154
155// ============================================================================
156// REQ-17: empty() returning MaybeUninit
157// ============================================================================
158
159/// An array whose elements have not been initialized.
160///
161/// The caller must call [`assume_init`](UninitArray::assume_init) after
162/// filling all elements.
163pub struct UninitArray<T: Element, D: Dimension> {
164    data: Vec<MaybeUninit<T>>,
165    dim: D,
166}
167
168impl<T: Element, D: Dimension> UninitArray<T, D> {
169    /// Shape as a slice.
170    #[inline]
171    pub fn shape(&self) -> &[usize] {
172        self.dim.as_slice()
173    }
174
175    /// Total number of elements.
176    #[inline]
177    pub fn size(&self) -> usize {
178        self.data.len()
179    }
180
181    /// Number of dimensions.
182    #[inline]
183    pub fn ndim(&self) -> usize {
184        self.dim.ndim()
185    }
186
187    /// Get a mutable raw pointer to the underlying data.
188    ///
189    /// Use this to fill the array element-by-element before calling
190    /// `assume_init()`.
191    #[inline]
192    pub fn as_mut_ptr(&mut self) -> *mut MaybeUninit<T> {
193        self.data.as_mut_ptr()
194    }
195
196    /// Write a value at a flat index.
197    ///
198    /// # Errors
199    /// Returns `FerrayError::IndexOutOfBounds` if `flat_index >= size()`.
200    pub fn write_at(&mut self, flat_index: usize, value: T) -> FerrayResult<()> {
201        let size = self.size();
202        if flat_index >= size {
203            return Err(FerrayError::IndexOutOfBounds {
204                index: flat_index as isize,
205                axis: 0,
206                size,
207            });
208        }
209        self.data[flat_index] = MaybeUninit::new(value);
210        Ok(())
211    }
212
213    /// Convert to an initialized `Array<T, D>`.
214    ///
215    /// # Safety
216    /// The caller must ensure that **all** elements have been initialized
217    /// (e.g., via `write_at` or raw pointer writes). Reading uninitialized
218    /// memory is undefined behavior.
219    pub unsafe fn assume_init(self) -> Array<T, D> {
220        let nd_dim = self.dim.to_ndarray_dim();
221        let len = self.data.len();
222
223        // Transmute Vec<MaybeUninit<T>> to Vec<T>.
224        // SAFETY: MaybeUninit<T> has the same layout as T, and the caller
225        // guarantees all elements are initialized.
226        let mut raw_vec = std::mem::ManuallyDrop::new(self.data);
227        let data: Vec<T> =
228            unsafe { Vec::from_raw_parts(raw_vec.as_mut_ptr() as *mut T, len, raw_vec.capacity()) };
229
230        let inner = ndarray::Array::from_shape_vec(nd_dim, data)
231            .expect("UninitArray assume_init: shape/data mismatch (this is a bug)");
232        Array::from_ndarray(inner)
233    }
234}
235
236/// Create an uninitialized array.
237///
238/// Analogous to `numpy.empty()`, but returns a [`UninitArray`] that must
239/// be explicitly initialized via [`UninitArray::assume_init`].
240///
241/// This prevents accidentally reading uninitialized memory — a key safety
242/// improvement over NumPy's `empty()`.
243pub fn empty<T: Element, D: Dimension>(dim: D) -> UninitArray<T, D> {
244    let size = dim.size();
245    let mut data = Vec::with_capacity(size);
246    // SAFETY: MaybeUninit does not require initialization.
247    // We set the length to match the capacity; each element is MaybeUninit.
248    unsafe {
249        data.set_len(size);
250    }
251    UninitArray { data, dim }
252}
253
254// ============================================================================
255// REQ-18: Range functions
256// ============================================================================
257
258/// Trait for types usable with `arange` — numeric types that support
259/// stepping and comparison.
260pub trait ArangeNum: Element + PartialOrd {
261    /// Convert from f64 for step calculations.
262    fn from_f64(v: f64) -> Self;
263    /// Convert to f64 for step calculations.
264    fn to_f64(self) -> f64;
265}
266
267macro_rules! impl_arange_int {
268    ($($ty:ty),*) => {
269        $(
270            impl ArangeNum for $ty {
271                #[inline]
272                fn from_f64(v: f64) -> Self { v as Self }
273                #[inline]
274                fn to_f64(self) -> f64 { self as f64 }
275            }
276        )*
277    };
278}
279
280macro_rules! impl_arange_float {
281    ($($ty:ty),*) => {
282        $(
283            impl ArangeNum for $ty {
284                #[inline]
285                fn from_f64(v: f64) -> Self { v as Self }
286                #[inline]
287                fn to_f64(self) -> f64 { self as f64 }
288            }
289        )*
290    };
291}
292
293impl_arange_int!(u8, u16, u32, u64, i8, i16, i32, i64);
294impl_arange_float!(f32, f64);
295
296/// Create a 1-D array with evenly spaced values within a given interval.
297///
298/// Analogous to `numpy.arange(start, stop, step)`.
299///
300/// # Errors
301/// Returns `FerrayError::InvalidValue` if `step` is zero.
302pub fn arange<T: ArangeNum>(start: T, stop: T, step: T) -> FerrayResult<Array<T, Ix1>> {
303    let step_f = step.to_f64();
304    if step_f == 0.0 {
305        return Err(FerrayError::invalid_value("step cannot be zero"));
306    }
307    let start_f = start.to_f64();
308    let stop_f = stop.to_f64();
309    let n = ((stop_f - start_f) / step_f).ceil();
310    let n = if n < 0.0 { 0 } else { n as usize };
311
312    let mut data = Vec::with_capacity(n);
313    for i in 0..n {
314        data.push(T::from_f64(start_f + (i as f64) * step_f));
315    }
316    let dim = Ix1::new([data.len()]);
317    Array::from_vec(dim, data)
318}
319
320/// Trait for float-like types used in linspace/logspace/geomspace.
321pub trait LinspaceNum: Element + PartialOrd {
322    /// Convert from f64.
323    fn from_f64(v: f64) -> Self;
324    /// Convert to f64.
325    fn to_f64(self) -> f64;
326}
327
328impl LinspaceNum for f32 {
329    #[inline]
330    fn from_f64(v: f64) -> Self {
331        v as f32
332    }
333    #[inline]
334    fn to_f64(self) -> f64 {
335        self as f64
336    }
337}
338
339impl LinspaceNum for f64 {
340    #[inline]
341    fn from_f64(v: f64) -> Self {
342        v
343    }
344    #[inline]
345    fn to_f64(self) -> f64 {
346        self
347    }
348}
349
350/// Create a 1-D array with `num` evenly spaced values between `start` and `stop`.
351///
352/// If `endpoint` is true (the default in NumPy), `stop` is the last sample.
353/// Otherwise, it is not included.
354///
355/// Analogous to `numpy.linspace()`.
356///
357/// # Errors
358/// Returns `FerrayError::InvalidValue` if `num` is 0 and `endpoint` is true
359/// (cannot produce an empty array with an endpoint).
360pub fn linspace<T: LinspaceNum>(
361    start: T,
362    stop: T,
363    num: usize,
364    endpoint: bool,
365) -> FerrayResult<Array<T, Ix1>> {
366    if num == 0 {
367        return Array::from_vec(Ix1::new([0]), vec![]);
368    }
369    if num == 1 {
370        return Array::from_vec(Ix1::new([1]), vec![start]);
371    }
372    let start_f = start.to_f64();
373    let stop_f = stop.to_f64();
374    let divisor = if endpoint {
375        (num - 1) as f64
376    } else {
377        num as f64
378    };
379    let step = (stop_f - start_f) / divisor;
380    let mut data = Vec::with_capacity(num);
381    for i in 0..num {
382        data.push(T::from_f64(start_f + (i as f64) * step));
383    }
384    Array::from_vec(Ix1::new([num]), data)
385}
386
387/// Create a 1-D array with values spaced evenly on a log scale.
388///
389/// Returns `base ** linspace(start, stop, num)`.
390///
391/// Analogous to `numpy.logspace()`.
392///
393/// # Errors
394/// Propagates errors from `linspace`.
395pub fn logspace<T: LinspaceNum>(
396    start: T,
397    stop: T,
398    num: usize,
399    endpoint: bool,
400    base: f64,
401) -> FerrayResult<Array<T, Ix1>> {
402    let lin = linspace(start, stop, num, endpoint)?;
403    let data: Vec<T> = lin
404        .iter()
405        .map(|v| T::from_f64(base.powf(v.clone().to_f64())))
406        .collect();
407    Array::from_vec(Ix1::new([num]), data)
408}
409
410/// Create a 1-D array with values spaced evenly on a geometric (log) scale.
411///
412/// The values are `start * (stop/start) ** linspace(0, 1, num)`.
413///
414/// Analogous to `numpy.geomspace()`.
415///
416/// # Errors
417/// Returns `FerrayError::InvalidValue` if `start` or `stop` is zero or
418/// if they have different signs.
419pub fn geomspace<T: LinspaceNum>(
420    start: T,
421    stop: T,
422    num: usize,
423    endpoint: bool,
424) -> FerrayResult<Array<T, Ix1>> {
425    let start_f = start.clone().to_f64();
426    let stop_f = stop.clone().to_f64();
427    if start_f == 0.0 || stop_f == 0.0 {
428        return Err(FerrayError::invalid_value(
429            "geomspace: start and stop must be non-zero",
430        ));
431    }
432    if (start_f < 0.0) != (stop_f < 0.0) {
433        return Err(FerrayError::invalid_value(
434            "geomspace: start and stop must have the same sign",
435        ));
436    }
437    if num == 0 {
438        return Array::from_vec(Ix1::new([0]), vec![]);
439    }
440    if num == 1 {
441        return Array::from_vec(Ix1::new([1]), vec![start]);
442    }
443    let log_start = start_f.abs().ln();
444    let log_stop = stop_f.abs().ln();
445    let sign = if start_f < 0.0 { -1.0 } else { 1.0 };
446    let divisor = if endpoint {
447        (num - 1) as f64
448    } else {
449        num as f64
450    };
451    let step = (log_stop - log_start) / divisor;
452    let mut data = Vec::with_capacity(num);
453    for i in 0..num {
454        let log_val = log_start + (i as f64) * step;
455        data.push(T::from_f64(sign * log_val.exp()));
456    }
457    Array::from_vec(Ix1::new([num]), data)
458}
459
460/// Return coordinate arrays from coordinate vectors.
461///
462/// Analogous to `numpy.meshgrid(*xi, indexing='xy')`.
463///
464/// Given N 1-D arrays, returns N N-D arrays, where each output array
465/// has the shape `(len(x1), len(x2), ..., len(xN))` for 'xy' indexing
466/// or `(len(x1), ..., len(xN))` transposed for 'ij' indexing.
467///
468/// `indexing` should be `"xy"` (default Cartesian) or `"ij"` (matrix).
469///
470/// # Errors
471/// Returns `FerrayError::InvalidValue` if `indexing` is not `"xy"` or `"ij"`,
472/// or if there are fewer than 2 input arrays.
473pub fn meshgrid(
474    arrays: &[Array<f64, Ix1>],
475    indexing: &str,
476) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
477    if indexing != "xy" && indexing != "ij" {
478        return Err(FerrayError::invalid_value(
479            "meshgrid: indexing must be 'xy' or 'ij'",
480        ));
481    }
482    let ndim = arrays.len();
483    if ndim == 0 {
484        return Ok(vec![]);
485    }
486
487    let mut shapes: Vec<usize> = arrays.iter().map(|a| a.shape()[0]).collect();
488    if indexing == "xy" && ndim >= 2 {
489        shapes.swap(0, 1);
490    }
491
492    let total: usize = shapes.iter().product();
493    let mut results = Vec::with_capacity(ndim);
494
495    for (k, arr) in arrays.iter().enumerate() {
496        let src_data: Vec<f64> = arr.iter().copied().collect();
497        let mut data = Vec::with_capacity(total);
498        // For 'xy' indexing, the first two dimensions are swapped
499        let effective_k = if indexing == "xy" && ndim >= 2 {
500            match k {
501                0 => 1,
502                1 => 0,
503                other => other,
504            }
505        } else {
506            k
507        };
508
509        // Build the output by iterating over all indices in the output shape
510        for flat in 0..total {
511            // Compute the index along dimension effective_k
512            let mut rem = flat;
513            let mut idx_k = 0;
514            for (d, &s) in shapes.iter().enumerate().rev() {
515                if d == effective_k {
516                    idx_k = rem % s;
517                }
518                rem /= s;
519            }
520            data.push(src_data[idx_k]);
521        }
522
523        let dim = IxDyn::new(&shapes);
524        results.push(Array::from_vec(dim, data)?);
525    }
526    Ok(results)
527}
528
529/// Create a dense multi-dimensional "meshgrid" with matrix ('ij') indexing.
530///
531/// Analogous to `numpy.mgrid[start:stop:step, ...]`.
532///
533/// Takes a slice of `(start, stop, step)` tuples, one per dimension.
534/// Returns a vector of arrays, one per dimension.
535///
536/// # Errors
537/// Returns `FerrayError::InvalidValue` if any step is zero.
538pub fn mgrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
539    let mut arrs: Vec<Array<f64, Ix1>> = Vec::with_capacity(ranges.len());
540    for &(start, stop, step) in ranges {
541        arrs.push(arange(start, stop, step)?);
542    }
543    meshgrid(&arrs, "ij")
544}
545
546/// Create a sparse (open) multi-dimensional "meshgrid" with 'ij' indexing.
547///
548/// Analogous to `numpy.ogrid[start:stop:step, ...]`.
549///
550/// Returns arrays that are broadcastable to the full grid shape.
551/// Each returned array has shape 1 in all dimensions except its own.
552///
553/// # Errors
554/// Returns `FerrayError::InvalidValue` if any step is zero.
555pub fn ogrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
556    let ndim = ranges.len();
557    let mut results = Vec::with_capacity(ndim);
558    for (i, &(start, stop, step)) in ranges.iter().enumerate() {
559        let arr1d = arange(start, stop, step)?;
560        let n = arr1d.shape()[0];
561        let data: Vec<f64> = arr1d.iter().copied().collect();
562        // Build shape: all ones except dimension i = n
563        let mut shape = vec![1usize; ndim];
564        shape[i] = n;
565        let dim = IxDyn::new(&shape);
566        results.push(Array::from_vec(dim, data)?);
567    }
568    Ok(results)
569}
570
571// ============================================================================
572// REQ-19: Identity/diagonal functions
573// ============================================================================
574
575/// Create a 2-D identity matrix of size `n x n`.
576///
577/// Analogous to `numpy.identity()`.
578pub fn identity<T: Element>(n: usize) -> FerrayResult<Array<T, Ix2>> {
579    eye(n, n, 0)
580}
581
582/// Create a 2-D array with ones on the diagonal and zeros elsewhere.
583///
584/// `k` is the diagonal offset: 0 = main diagonal, positive = above, negative = below.
585///
586/// Analogous to `numpy.eye(N, M, k)`.
587pub fn eye<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
588    let mut data = vec![T::zero(); n * m];
589    for i in 0..n {
590        let j = i as isize + k;
591        if j >= 0 && (j as usize) < m {
592            data[i * m + j as usize] = T::one();
593        }
594    }
595    Array::from_vec(Ix2::new([n, m]), data)
596}
597
598/// Extract a diagonal or construct a diagonal array.
599///
600/// If `a` is 2-D, extract the `k`-th diagonal as a 1-D array.
601/// If `a` is 1-D, construct a 2-D array with `a` on the `k`-th diagonal.
602///
603/// Analogous to `numpy.diag()`.
604///
605/// # Errors
606/// Returns `FerrayError::InvalidValue` if `a` is not 1-D or 2-D.
607pub fn diag<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
608    let shape = a.shape();
609    match shape.len() {
610        1 => {
611            // Construct a 2-D diagonal array
612            let n = shape[0];
613            let size = n + k.unsigned_abs();
614            let mut data = vec![T::zero(); size * size];
615            let src: Vec<T> = a.iter().cloned().collect();
616            for (i, val) in src.into_iter().enumerate() {
617                let row = if k >= 0 { i } else { i + k.unsigned_abs() };
618                let col = if k >= 0 { i + k as usize } else { i };
619                data[row * size + col] = val;
620            }
621            Array::from_vec(IxDyn::new(&[size, size]), data)
622        }
623        2 => {
624            // Extract the k-th diagonal
625            let (n, m) = (shape[0], shape[1]);
626            let src: Vec<T> = a.iter().cloned().collect();
627            let mut diag_vals = Vec::new();
628            for i in 0..n {
629                let j = i as isize + k;
630                if j >= 0 && (j as usize) < m {
631                    diag_vals.push(src[i * m + j as usize].clone());
632                }
633            }
634            let len = diag_vals.len();
635            Array::from_vec(IxDyn::new(&[len]), diag_vals)
636        }
637        _ => Err(FerrayError::invalid_value("diag: input must be 1-D or 2-D")),
638    }
639}
640
641/// Create a 2-D array with the flattened input as a diagonal.
642///
643/// Analogous to `numpy.diagflat()`.
644///
645/// # Errors
646/// Propagates errors from the underlying construction.
647pub fn diagflat<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
648    // Flatten a to 1-D, then call diag
649    let flat: Vec<T> = a.iter().cloned().collect();
650    let n = flat.len();
651    let arr1d = Array::from_vec(IxDyn::new(&[n]), flat)?;
652    diag(&arr1d, k)
653}
654
655/// Create a lower-triangular matrix of ones.
656///
657/// Returns an `n x m` array where `a[i, j] = 1` if `i >= j - k`, else `0`.
658///
659/// Analogous to `numpy.tri(N, M, k)`.
660pub fn tri<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
661    let mut data = vec![T::zero(); n * m];
662    for i in 0..n {
663        for j in 0..m {
664            if (i as isize) >= (j as isize) - k {
665                data[i * m + j] = T::one();
666            }
667        }
668    }
669    Array::from_vec(Ix2::new([n, m]), data)
670}
671
672/// Return the lower triangle of a 2-D array.
673///
674/// `k` is the diagonal above which to zero elements. 0 = main diagonal.
675///
676/// Analogous to `numpy.tril()`.
677///
678/// # Errors
679/// Returns `FerrayError::InvalidValue` if input is not 2-D.
680pub fn tril<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
681    let shape = a.shape();
682    if shape.len() != 2 {
683        return Err(FerrayError::invalid_value("tril: input must be 2-D"));
684    }
685    let (n, m) = (shape[0], shape[1]);
686    let src: Vec<T> = a.iter().cloned().collect();
687    let mut data = vec![T::zero(); n * m];
688    for i in 0..n {
689        for j in 0..m {
690            if (i as isize) >= (j as isize) - k {
691                data[i * m + j] = src[i * m + j].clone();
692            }
693        }
694    }
695    Array::from_vec(IxDyn::new(&[n, m]), data)
696}
697
698/// Return the upper triangle of a 2-D array.
699///
700/// `k` is the diagonal below which to zero elements. 0 = main diagonal.
701///
702/// Analogous to `numpy.triu()`.
703///
704/// # Errors
705/// Returns `FerrayError::InvalidValue` if input is not 2-D.
706pub fn triu<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
707    let shape = a.shape();
708    if shape.len() != 2 {
709        return Err(FerrayError::invalid_value("triu: input must be 2-D"));
710    }
711    let (n, m) = (shape[0], shape[1]);
712    let src: Vec<T> = a.iter().cloned().collect();
713    let mut data = vec![T::zero(); n * m];
714    for i in 0..n {
715        for j in 0..m {
716            if (i as isize) <= (j as isize) - k {
717                data[i * m + j] = src[i * m + j].clone();
718            }
719        }
720    }
721    Array::from_vec(IxDyn::new(&[n, m]), data)
722}
723
724// ============================================================================
725// Tests
726// ============================================================================
727
728#[cfg(test)]
729mod tests {
730    use super::*;
731    use crate::dimension::{Ix1, Ix2, IxDyn};
732
733    // -- REQ-16 tests --
734
735    #[test]
736    fn test_array_creation() {
737        let a = array(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
738        assert_eq!(a.shape(), &[2, 3]);
739        assert_eq!(a.size(), 6);
740    }
741
742    #[test]
743    fn test_asarray() {
744        let a = asarray(Ix1::new([3]), vec![1, 2, 3]).unwrap();
745        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
746    }
747
748    #[test]
749    fn test_frombuffer() {
750        let bytes: Vec<u8> = vec![1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0];
751        let a = frombuffer::<i32, Ix1>(Ix1::new([3]), &bytes).unwrap();
752        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
753    }
754
755    #[test]
756    fn test_frombuffer_bad_length() {
757        let bytes: Vec<u8> = vec![1, 0, 0];
758        assert!(frombuffer::<i32, Ix1>(Ix1::new([1]), &bytes).is_err());
759    }
760
761    #[test]
762    fn test_fromiter() {
763        let a = fromiter((0..5).map(|x| x as f64)).unwrap();
764        assert_eq!(a.shape(), &[5]);
765        assert_eq!(a.as_slice().unwrap(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
766    }
767
768    #[test]
769    fn test_zeros() {
770        let a = zeros::<f64, Ix2>(Ix2::new([3, 4])).unwrap();
771        assert_eq!(a.shape(), &[3, 4]);
772        assert!(a.iter().all(|&v| v == 0.0));
773    }
774
775    #[test]
776    fn test_ones() {
777        let a = ones::<f64, Ix1>(Ix1::new([5])).unwrap();
778        assert!(a.iter().all(|&v| v == 1.0));
779    }
780
781    #[test]
782    fn test_full() {
783        let a = full(Ix1::new([4]), 42i32).unwrap();
784        assert!(a.iter().all(|&v| v == 42));
785    }
786
787    #[test]
788    fn test_zeros_like() {
789        let a = ones::<f64, Ix2>(Ix2::new([2, 3])).unwrap();
790        let b = zeros_like(&a).unwrap();
791        assert_eq!(b.shape(), &[2, 3]);
792        assert!(b.iter().all(|&v| v == 0.0));
793    }
794
795    #[test]
796    fn test_ones_like() {
797        let a = zeros::<f64, Ix1>(Ix1::new([4])).unwrap();
798        let b = ones_like(&a).unwrap();
799        assert!(b.iter().all(|&v| v == 1.0));
800    }
801
802    #[test]
803    fn test_full_like() {
804        let a = zeros::<i32, Ix1>(Ix1::new([3])).unwrap();
805        let b = full_like(&a, 7).unwrap();
806        assert!(b.iter().all(|&v| v == 7));
807    }
808
809    // -- REQ-17 tests --
810
811    #[test]
812    fn test_empty_and_init() {
813        let mut u = empty::<f64, Ix1>(Ix1::new([3]));
814        assert_eq!(u.shape(), &[3]);
815        u.write_at(0, 1.0).unwrap();
816        u.write_at(1, 2.0).unwrap();
817        u.write_at(2, 3.0).unwrap();
818        // SAFETY: all elements initialized
819        let a = unsafe { u.assume_init() };
820        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
821    }
822
823    #[test]
824    fn test_empty_write_oob() {
825        let mut u = empty::<f64, Ix1>(Ix1::new([2]));
826        assert!(u.write_at(5, 1.0).is_err());
827    }
828
829    // -- REQ-18 tests --
830
831    #[test]
832    fn test_arange_int() {
833        let a = arange(0i32, 5, 1).unwrap();
834        assert_eq!(a.as_slice().unwrap(), &[0, 1, 2, 3, 4]);
835    }
836
837    #[test]
838    fn test_arange_float() {
839        let a = arange(0.0_f64, 1.0, 0.25).unwrap();
840        assert_eq!(a.shape(), &[4]);
841        let data = a.as_slice().unwrap();
842        assert!((data[0] - 0.0).abs() < 1e-10);
843        assert!((data[1] - 0.25).abs() < 1e-10);
844        assert!((data[2] - 0.5).abs() < 1e-10);
845        assert!((data[3] - 0.75).abs() < 1e-10);
846    }
847
848    #[test]
849    fn test_arange_negative_step() {
850        let a = arange(5.0_f64, 0.0, -1.0).unwrap();
851        assert_eq!(a.shape(), &[5]);
852    }
853
854    #[test]
855    fn test_arange_zero_step() {
856        assert!(arange(0.0_f64, 1.0, 0.0).is_err());
857    }
858
859    #[test]
860    fn test_arange_empty() {
861        let a = arange(5i32, 0, 1).unwrap();
862        assert_eq!(a.shape(), &[0]);
863    }
864
865    #[test]
866    fn test_linspace() {
867        let a = linspace(0.0_f64, 1.0, 5, true).unwrap();
868        assert_eq!(a.shape(), &[5]);
869        let data = a.as_slice().unwrap();
870        assert!((data[0] - 0.0).abs() < 1e-10);
871        assert!((data[4] - 1.0).abs() < 1e-10);
872        assert!((data[2] - 0.5).abs() < 1e-10);
873    }
874
875    #[test]
876    fn test_linspace_no_endpoint() {
877        let a = linspace(0.0_f64, 1.0, 4, false).unwrap();
878        assert_eq!(a.shape(), &[4]);
879        let data = a.as_slice().unwrap();
880        assert!((data[0] - 0.0).abs() < 1e-10);
881        assert!((data[1] - 0.25).abs() < 1e-10);
882    }
883
884    #[test]
885    fn test_linspace_single() {
886        let a = linspace(5.0_f64, 10.0, 1, true).unwrap();
887        assert_eq!(a.as_slice().unwrap(), &[5.0]);
888    }
889
890    #[test]
891    fn test_linspace_empty() {
892        let a = linspace(0.0_f64, 1.0, 0, true).unwrap();
893        assert_eq!(a.shape(), &[0]);
894    }
895
896    #[test]
897    fn test_logspace() {
898        let a = logspace(0.0_f64, 2.0, 3, true, 10.0).unwrap();
899        let data = a.as_slice().unwrap();
900        assert!((data[0] - 1.0).abs() < 1e-10); // 10^0
901        assert!((data[1] - 10.0).abs() < 1e-10); // 10^1
902        assert!((data[2] - 100.0).abs() < 1e-10); // 10^2
903    }
904
905    #[test]
906    fn test_geomspace() {
907        let a = geomspace(1.0_f64, 1000.0, 4, true).unwrap();
908        let data = a.as_slice().unwrap();
909        assert!((data[0] - 1.0).abs() < 1e-10);
910        assert!((data[1] - 10.0).abs() < 1e-8);
911        assert!((data[2] - 100.0).abs() < 1e-6);
912        assert!((data[3] - 1000.0).abs() < 1e-4);
913    }
914
915    #[test]
916    fn test_geomspace_zero_start() {
917        assert!(geomspace(0.0_f64, 1.0, 5, true).is_err());
918    }
919
920    #[test]
921    fn test_geomspace_different_signs() {
922        assert!(geomspace(-1.0_f64, 1.0, 5, true).is_err());
923    }
924
925    #[test]
926    fn test_meshgrid_xy() {
927        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
928        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
929        let grids = meshgrid(&[x, y], "xy").unwrap();
930        assert_eq!(grids.len(), 2);
931        assert_eq!(grids[0].shape(), &[2, 3]);
932        assert_eq!(grids[1].shape(), &[2, 3]);
933        // X grid: rows are [1,2,3], [1,2,3]
934        let xdata: Vec<f64> = grids[0].iter().copied().collect();
935        assert_eq!(xdata, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
936        // Y grid: rows are [4,4,4], [5,5,5]
937        let ydata: Vec<f64> = grids[1].iter().copied().collect();
938        assert_eq!(ydata, vec![4.0, 4.0, 4.0, 5.0, 5.0, 5.0]);
939    }
940
941    #[test]
942    fn test_meshgrid_ij() {
943        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
944        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
945        let grids = meshgrid(&[x, y], "ij").unwrap();
946        assert_eq!(grids.len(), 2);
947        assert_eq!(grids[0].shape(), &[3, 2]);
948        assert_eq!(grids[1].shape(), &[3, 2]);
949    }
950
951    #[test]
952    fn test_meshgrid_bad_indexing() {
953        assert!(meshgrid(&[], "zz").is_err());
954    }
955
956    #[test]
957    fn test_mgrid() {
958        let grids = mgrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
959        assert_eq!(grids.len(), 2);
960        assert_eq!(grids[0].shape(), &[3, 2]);
961    }
962
963    #[test]
964    fn test_ogrid() {
965        let grids = ogrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
966        assert_eq!(grids.len(), 2);
967        assert_eq!(grids[0].shape(), &[3, 1]);
968        assert_eq!(grids[1].shape(), &[1, 2]);
969    }
970
971    // -- REQ-19 tests --
972
973    #[test]
974    fn test_identity() {
975        let a = identity::<f64>(3).unwrap();
976        assert_eq!(a.shape(), &[3, 3]);
977        let data = a.as_slice().unwrap();
978        assert_eq!(data, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
979    }
980
981    #[test]
982    fn test_eye() {
983        let a = eye::<f64>(3, 4, 0).unwrap();
984        assert_eq!(a.shape(), &[3, 4]);
985        let data = a.as_slice().unwrap();
986        assert_eq!(
987            data,
988            &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
989        );
990    }
991
992    #[test]
993    fn test_eye_positive_k() {
994        let a = eye::<f64>(3, 3, 1).unwrap();
995        let data = a.as_slice().unwrap();
996        assert_eq!(data, &[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]);
997    }
998
999    #[test]
1000    fn test_eye_negative_k() {
1001        let a = eye::<f64>(3, 3, -1).unwrap();
1002        let data = a.as_slice().unwrap();
1003        assert_eq!(data, &[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
1004    }
1005
1006    #[test]
1007    fn test_diag_from_1d() {
1008        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
1009        let d = diag(&a, 0).unwrap();
1010        assert_eq!(d.shape(), &[3, 3]);
1011        let data: Vec<f64> = d.iter().copied().collect();
1012        assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
1013    }
1014
1015    #[test]
1016    fn test_diag_from_2d() {
1017        let a = Array::from_vec(
1018            IxDyn::new(&[3, 3]),
1019            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0],
1020        )
1021        .unwrap();
1022        let d = diag(&a, 0).unwrap();
1023        assert_eq!(d.shape(), &[3]);
1024        let data: Vec<f64> = d.iter().copied().collect();
1025        assert_eq!(data, vec![1.0, 2.0, 3.0]);
1026    }
1027
1028    #[test]
1029    fn test_diag_k_positive() {
1030        let a = Array::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
1031        let d = diag(&a, 1).unwrap();
1032        assert_eq!(d.shape(), &[3, 3]);
1033        let data: Vec<f64> = d.iter().copied().collect();
1034        assert_eq!(data, vec![0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]);
1035    }
1036
1037    #[test]
1038    fn test_diagflat() {
1039        let a = Array::from_vec(IxDyn::new(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1040        let d = diagflat(&a, 0).unwrap();
1041        assert_eq!(d.shape(), &[4, 4]);
1042        // Diagonal should be [1, 2, 3, 4]
1043        let extracted = diag(&d, 0).unwrap();
1044        let data: Vec<f64> = extracted.iter().copied().collect();
1045        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
1046    }
1047
1048    #[test]
1049    fn test_tri() {
1050        let a = tri::<f64>(3, 3, 0).unwrap();
1051        let data = a.as_slice().unwrap();
1052        assert_eq!(data, &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]);
1053    }
1054
1055    #[test]
1056    fn test_tril() {
1057        let a = Array::from_vec(
1058            IxDyn::new(&[3, 3]),
1059            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
1060        )
1061        .unwrap();
1062        let t = tril(&a, 0).unwrap();
1063        let data: Vec<f64> = t.iter().copied().collect();
1064        assert_eq!(data, vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0]);
1065    }
1066
1067    #[test]
1068    fn test_triu() {
1069        let a = Array::from_vec(
1070            IxDyn::new(&[3, 3]),
1071            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
1072        )
1073        .unwrap();
1074        let t = triu(&a, 0).unwrap();
1075        let data: Vec<f64> = t.iter().copied().collect();
1076        assert_eq!(data, vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]);
1077    }
1078
1079    #[test]
1080    fn test_tril_not_2d() {
1081        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
1082        assert!(tril(&a, 0).is_err());
1083    }
1084
1085    #[test]
1086    fn test_triu_not_2d() {
1087        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
1088        assert!(triu(&a, 0).is_err());
1089    }
1090}