Skip to main content

ferray_core/creation/
mod.rs

1// ferray-core: Array creation functions (REQ-16, REQ-17, REQ-18, REQ-19)
2//
3// Mirrors numpy's array creation routines: zeros, ones, full, empty,
4// arange, linspace, logspace, geomspace, eye, identity, diag, etc.
5
6use std::mem::MaybeUninit;
7
8use crate::array::owned::Array;
9use crate::array::view::ArrayView;
10use crate::dimension::{Dimension, Ix1, Ix2, IxDyn};
11use crate::dtype::Element;
12use crate::error::{FerrayError, FerrayResult};
13
14// ============================================================================
15// REQ-16: Basic creation functions
16// ============================================================================
17
18/// Create an array from a flat vector and a shape (C-order).
19///
20/// This is the primary "array constructor" — analogous to `numpy.array()` when
21/// given a flat sequence plus a shape.
22///
23/// # Errors
24/// Returns `FerrayError::ShapeMismatch` if `data.len()` does not equal the
25/// product of the shape dimensions.
26pub fn array<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
27    Array::from_vec(dim, data)
28}
29
30/// Interpret existing data as an array without copying (if possible).
31///
32/// This is equivalent to `numpy.asarray()`. Since Rust ownership rules
33/// require moving the data, this creates an owned array from the vector.
34///
35/// # Errors
36/// Returns `FerrayError::ShapeMismatch` if lengths don't match.
37pub fn asarray<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
38    Array::from_vec(dim, data)
39}
40
41/// Create an array from a byte buffer, interpreting bytes as elements of type `T`.
42///
43/// Analogous to `numpy.frombuffer()`.
44///
45/// # Errors
46/// Returns `FerrayError::InvalidValue` if the buffer length is not a multiple
47/// of `size_of::<T>()`, or if the resulting length does not match the shape.
48pub fn frombuffer<T: Element, D: Dimension>(dim: D, buf: &[u8]) -> FerrayResult<Array<T, D>> {
49    let elem_size = std::mem::size_of::<T>();
50    if elem_size == 0 {
51        return Err(FerrayError::invalid_value("zero-sized type"));
52    }
53    if buf.len() % elem_size != 0 {
54        return Err(FerrayError::invalid_value(format!(
55            "buffer length {} is not a multiple of element size {}",
56            buf.len(),
57            elem_size,
58        )));
59    }
60    let n_elems = buf.len() / elem_size;
61    let expected = dim.size();
62    if n_elems != expected {
63        return Err(FerrayError::shape_mismatch(format!(
64            "buffer contains {} elements but shape {:?} requires {}",
65            n_elems,
66            dim.as_slice(),
67            expected,
68        )));
69    }
70    // Validate bytes for types where not all bit patterns are valid.
71    // bool only permits 0x00 and 0x01.
72    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<bool>() {
73        for &byte in buf {
74            if byte > 1 {
75                return Err(FerrayError::invalid_value(format!(
76                    "invalid byte {byte:#04x} for bool (must be 0x00 or 0x01)"
77                )));
78            }
79        }
80    }
81
82    // Copy bytes element-by-element via from_ne_bytes equivalent
83    let mut data = Vec::with_capacity(n_elems);
84    for i in 0..n_elems {
85        let start = i * elem_size;
86        let end = start + elem_size;
87        let slice = &buf[start..end];
88        // SAFETY: We're reading elem_size bytes and interpreting as T.
89        // For bool, we validated above that all bytes are 0 or 1.
90        // For numeric types, all bit patterns are valid.
91        let val = unsafe {
92            let mut val = MaybeUninit::<T>::uninit();
93            std::ptr::copy_nonoverlapping(slice.as_ptr(), val.as_mut_ptr() as *mut u8, elem_size);
94            val.assume_init()
95        };
96        data.push(val);
97    }
98    Array::from_vec(dim, data)
99}
100
101/// Create a zero-copy [`ArrayView`] over an existing byte buffer (#364).
102///
103/// Unlike [`frombuffer`], which copies bytes into a freshly allocated
104/// `Array`, this function returns a view whose lifetime is tied to the
105/// input slice. This is the equivalent of NumPy's `np.frombuffer()` with
106/// a memoryview source — the primary building block for zero-copy
107/// interop with mmap, shared memory, network buffers, and FFI.
108///
109/// # Errors
110/// - `InvalidValue` if `T` is a ZST.
111/// - `InvalidValue` if `buf.len()` is not a multiple of `size_of::<T>()`.
112/// - `ShapeMismatch` if the element count doesn't match `dim.size()`.
113/// - `InvalidValue` if `buf.as_ptr()` is not aligned to `align_of::<T>()`
114///   (views require proper alignment — use the copying [`frombuffer`]
115///   instead if alignment cannot be guaranteed).
116/// - `InvalidValue` if `T` is `bool` and any byte is outside `{0x00, 0x01}`.
117pub fn frombuffer_view<'a, T: Element, D: Dimension>(
118    dim: D,
119    buf: &'a [u8],
120) -> FerrayResult<ArrayView<'a, T, D>> {
121    let elem_size = std::mem::size_of::<T>();
122    if elem_size == 0 {
123        return Err(FerrayError::invalid_value("zero-sized type"));
124    }
125    if buf.len() % elem_size != 0 {
126        return Err(FerrayError::invalid_value(format!(
127            "buffer length {} is not a multiple of element size {}",
128            buf.len(),
129            elem_size,
130        )));
131    }
132    let n_elems = buf.len() / elem_size;
133    let expected = dim.size();
134    if n_elems != expected {
135        return Err(FerrayError::shape_mismatch(format!(
136            "buffer contains {} elements but shape {:?} requires {}",
137            n_elems,
138            dim.as_slice(),
139            expected,
140        )));
141    }
142
143    // Alignment: a view interprets the bytes in place, so the buffer must
144    // already be aligned for T. A misaligned read of f32/f64/etc. is UB.
145    let align = std::mem::align_of::<T>();
146    let addr = buf.as_ptr() as usize;
147    if addr % align != 0 {
148        return Err(FerrayError::invalid_value(format!(
149            "buffer address 0x{addr:x} is not aligned to {align} bytes required by the element type; \
150             use `frombuffer` for misaligned input"
151        )));
152    }
153
154    // bool has the same size/alignment as u8 but restricts the valid bit
155    // patterns; validate exhaustively, matching the copy-based frombuffer.
156    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<bool>() {
157        for &byte in buf {
158            if byte > 1 {
159                return Err(FerrayError::invalid_value(format!(
160                    "invalid byte {byte:#04x} for bool (must be 0x00 or 0x01)"
161                )));
162            }
163        }
164    }
165
166    // SAFETY:
167    // - The pointer comes from a valid `&[u8]` slice with length
168    //   `n_elems * elem_size`, so the region is valid for reads of
169    //   `n_elems` `T` values.
170    // - Alignment was checked above.
171    // - For bool, bit patterns were validated above. For all other
172    //   `Element` types, every bit pattern is a valid value.
173    // - The returned view's lifetime is bound to `'a = &'a [u8]`, which
174    //   tracks the borrow back to `buf`, so the memory cannot be freed
175    //   or mutated while the view lives.
176    let ptr = buf.as_ptr() as *const T;
177    let nd_dim = dim.to_ndarray_dim();
178    let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_dim, ptr) };
179    Ok(ArrayView::from_ndarray(nd_view))
180}
181
182/// Create a 1-D array from an iterator.
183///
184/// Analogous to `numpy.fromiter()`.
185///
186/// # Errors
187/// This function always succeeds (returns `Ok`).
188pub fn fromiter<T: Element>(iter: impl IntoIterator<Item = T>) -> FerrayResult<Array<T, Ix1>> {
189    Array::from_iter_1d(iter)
190}
191
192/// Create an array filled with zeros.
193///
194/// Analogous to `numpy.zeros()`.
195pub fn zeros<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
196    Array::zeros(dim)
197}
198
199/// Create an array filled with ones.
200///
201/// Analogous to `numpy.ones()`.
202pub fn ones<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
203    Array::ones(dim)
204}
205
206/// Create an array filled with a given value.
207///
208/// Analogous to `numpy.full()`.
209pub fn full<T: Element, D: Dimension>(dim: D, fill_value: T) -> FerrayResult<Array<T, D>> {
210    Array::from_elem(dim, fill_value)
211}
212
213/// Create an array with the same shape as `other`, filled with zeros.
214///
215/// Analogous to `numpy.zeros_like()`.
216pub fn zeros_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
217    Array::zeros(other.dim().clone())
218}
219
220/// Create an array with the same shape as `other`, filled with ones.
221///
222/// Analogous to `numpy.ones_like()`.
223pub fn ones_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
224    Array::ones(other.dim().clone())
225}
226
227/// Create an array with the same shape as `other`, filled with `fill_value`.
228///
229/// Analogous to `numpy.full_like()`.
230pub fn full_like<T: Element, D: Dimension>(
231    other: &Array<T, D>,
232    fill_value: T,
233) -> FerrayResult<Array<T, D>> {
234    Array::from_elem(other.dim().clone(), fill_value)
235}
236
237// ============================================================================
238// REQ-17: empty() returning MaybeUninit
239// ============================================================================
240
241/// An array whose elements have not been initialized.
242///
243/// The caller must call [`assume_init`](UninitArray::assume_init) after
244/// filling all elements.
245pub struct UninitArray<T: Element, D: Dimension> {
246    data: Vec<MaybeUninit<T>>,
247    dim: D,
248}
249
250impl<T: Element, D: Dimension> UninitArray<T, D> {
251    /// Shape as a slice.
252    #[inline]
253    pub fn shape(&self) -> &[usize] {
254        self.dim.as_slice()
255    }
256
257    /// Total number of elements.
258    #[inline]
259    pub fn size(&self) -> usize {
260        self.data.len()
261    }
262
263    /// Number of dimensions.
264    #[inline]
265    pub fn ndim(&self) -> usize {
266        self.dim.ndim()
267    }
268
269    /// Get a mutable raw pointer to the underlying data.
270    ///
271    /// Use this to fill the array element-by-element before calling
272    /// `assume_init()`.
273    #[inline]
274    pub fn as_mut_ptr(&mut self) -> *mut MaybeUninit<T> {
275        self.data.as_mut_ptr()
276    }
277
278    /// Write a value at a flat index.
279    ///
280    /// # Errors
281    /// Returns `FerrayError::IndexOutOfBounds` if `flat_index >= size()`.
282    pub fn write_at(&mut self, flat_index: usize, value: T) -> FerrayResult<()> {
283        let size = self.size();
284        if flat_index >= size {
285            return Err(FerrayError::IndexOutOfBounds {
286                index: flat_index as isize,
287                axis: 0,
288                size,
289            });
290        }
291        self.data[flat_index] = MaybeUninit::new(value);
292        Ok(())
293    }
294
295    /// Convert to an initialized `Array<T, D>`.
296    ///
297    /// # Safety
298    /// The caller must ensure that **all** elements have been initialized
299    /// (e.g., via `write_at` or raw pointer writes). Reading uninitialized
300    /// memory is undefined behavior.
301    pub unsafe fn assume_init(self) -> Array<T, D> {
302        let nd_dim = self.dim.to_ndarray_dim();
303        let len = self.data.len();
304
305        // Transmute Vec<MaybeUninit<T>> to Vec<T>.
306        // SAFETY: MaybeUninit<T> has the same layout as T, and the caller
307        // guarantees all elements are initialized.
308        let mut raw_vec = std::mem::ManuallyDrop::new(self.data);
309        let data: Vec<T> =
310            unsafe { Vec::from_raw_parts(raw_vec.as_mut_ptr() as *mut T, len, raw_vec.capacity()) };
311
312        let inner = ndarray::Array::from_shape_vec(nd_dim, data)
313            .expect("UninitArray assume_init: shape/data mismatch (this is a bug)");
314        Array::from_ndarray(inner)
315    }
316}
317
318/// Create an uninitialized array.
319///
320/// Analogous to `numpy.empty()`, but returns a [`UninitArray`] that must
321/// be explicitly initialized via [`UninitArray::assume_init`].
322///
323/// This prevents accidentally reading uninitialized memory — a key safety
324/// improvement over NumPy's `empty()`.
325pub fn empty<T: Element, D: Dimension>(dim: D) -> UninitArray<T, D> {
326    let size = dim.size();
327    let mut data = Vec::with_capacity(size);
328    // SAFETY: MaybeUninit does not require initialization.
329    // We set the length to match the capacity; each element is MaybeUninit.
330    unsafe {
331        data.set_len(size);
332    }
333    UninitArray { data, dim }
334}
335
336/// Create an uninitialized array with the same shape (and element type)
337/// as `other`.
338///
339/// Analogous to `numpy.empty_like()`. Returns a [`UninitArray`] that the
340/// caller must fully initialize before calling
341/// [`UninitArray::assume_init`]. Avoids the memset that `zeros_like` /
342/// `full_like` incur when the caller is about to overwrite every element
343/// anyway.
344pub fn empty_like<T: Element, D: Dimension>(other: &Array<T, D>) -> UninitArray<T, D> {
345    empty(other.dim().clone())
346}
347
348// ============================================================================
349// REQ-18: Range functions
350// ============================================================================
351
352/// Trait for types usable with `arange` — numeric types that support
353/// stepping and comparison.
354pub trait ArangeNum: Element + PartialOrd {
355    /// Convert from f64 for step calculations.
356    fn from_f64(v: f64) -> Self;
357    /// Convert to f64 for step calculations.
358    fn to_f64(self) -> f64;
359}
360
361macro_rules! impl_arange_int {
362    ($($ty:ty),*) => {
363        $(
364            impl ArangeNum for $ty {
365                #[inline]
366                fn from_f64(v: f64) -> Self { v as Self }
367                #[inline]
368                fn to_f64(self) -> f64 { self as f64 }
369            }
370        )*
371    };
372}
373
374macro_rules! impl_arange_float {
375    ($($ty:ty),*) => {
376        $(
377            impl ArangeNum for $ty {
378                #[inline]
379                fn from_f64(v: f64) -> Self { v as Self }
380                #[inline]
381                fn to_f64(self) -> f64 { self as f64 }
382            }
383        )*
384    };
385}
386
387impl_arange_int!(u8, u16, u32, u64, i8, i16, i32, i64);
388impl_arange_float!(f32, f64);
389
390/// Create a 1-D array with evenly spaced values within a given interval.
391///
392/// Analogous to `numpy.arange(start, stop, step)`.
393///
394/// # Errors
395/// Returns `FerrayError::InvalidValue` if `step` is zero.
396pub fn arange<T: ArangeNum>(start: T, stop: T, step: T) -> FerrayResult<Array<T, Ix1>> {
397    let step_f = step.to_f64();
398    if step_f == 0.0 {
399        return Err(FerrayError::invalid_value("step cannot be zero"));
400    }
401    let start_f = start.to_f64();
402    let stop_f = stop.to_f64();
403    let n = ((stop_f - start_f) / step_f).ceil();
404    let n = if n < 0.0 { 0 } else { n as usize };
405
406    let mut data = Vec::with_capacity(n);
407    for i in 0..n {
408        data.push(T::from_f64(start_f + (i as f64) * step_f));
409    }
410    let dim = Ix1::new([data.len()]);
411    Array::from_vec(dim, data)
412}
413
414/// Trait for float-like types used in linspace/logspace/geomspace.
415pub trait LinspaceNum: Element + PartialOrd {
416    /// Convert from f64.
417    fn from_f64(v: f64) -> Self;
418    /// Convert to f64.
419    fn to_f64(self) -> f64;
420}
421
422impl LinspaceNum for f32 {
423    #[inline]
424    fn from_f64(v: f64) -> Self {
425        v as f32
426    }
427    #[inline]
428    fn to_f64(self) -> f64 {
429        self as f64
430    }
431}
432
433impl LinspaceNum for f64 {
434    #[inline]
435    fn from_f64(v: f64) -> Self {
436        v
437    }
438    #[inline]
439    fn to_f64(self) -> f64 {
440        self
441    }
442}
443
444/// Create a 1-D array with `num` evenly spaced values between `start` and `stop`.
445///
446/// If `endpoint` is true (the default in NumPy), `stop` is the last sample.
447/// Otherwise, it is not included.
448///
449/// Analogous to `numpy.linspace()`.
450///
451/// # Errors
452/// Returns `FerrayError::InvalidValue` if `num` is 0 and `endpoint` is true
453/// (cannot produce an empty array with an endpoint).
454pub fn linspace<T: LinspaceNum>(
455    start: T,
456    stop: T,
457    num: usize,
458    endpoint: bool,
459) -> FerrayResult<Array<T, Ix1>> {
460    if num == 0 {
461        return Array::from_vec(Ix1::new([0]), vec![]);
462    }
463    if num == 1 {
464        return Array::from_vec(Ix1::new([1]), vec![start]);
465    }
466    let start_f = start.to_f64();
467    let stop_f = stop.to_f64();
468    let divisor = if endpoint {
469        (num - 1) as f64
470    } else {
471        num as f64
472    };
473    let step = (stop_f - start_f) / divisor;
474    let mut data = Vec::with_capacity(num);
475    for i in 0..num {
476        data.push(T::from_f64(start_f + (i as f64) * step));
477    }
478    Array::from_vec(Ix1::new([num]), data)
479}
480
481/// Create a 1-D array with values spaced evenly on a log scale.
482///
483/// Returns `base ** linspace(start, stop, num)`.
484///
485/// Analogous to `numpy.logspace()`.
486///
487/// # Errors
488/// Propagates errors from `linspace`.
489pub fn logspace<T: LinspaceNum>(
490    start: T,
491    stop: T,
492    num: usize,
493    endpoint: bool,
494    base: f64,
495) -> FerrayResult<Array<T, Ix1>> {
496    let lin = linspace(start, stop, num, endpoint)?;
497    let data: Vec<T> = lin
498        .iter()
499        .map(|v| T::from_f64(base.powf(v.clone().to_f64())))
500        .collect();
501    Array::from_vec(Ix1::new([num]), data)
502}
503
504/// Create a 1-D array with values spaced evenly on a geometric (log) scale.
505///
506/// The values are `start * (stop/start) ** linspace(0, 1, num)`.
507///
508/// Analogous to `numpy.geomspace()`.
509///
510/// # Errors
511/// Returns `FerrayError::InvalidValue` if `start` or `stop` is zero or
512/// if they have different signs.
513pub fn geomspace<T: LinspaceNum>(
514    start: T,
515    stop: T,
516    num: usize,
517    endpoint: bool,
518) -> FerrayResult<Array<T, Ix1>> {
519    let start_f = start.clone().to_f64();
520    let stop_f = stop.clone().to_f64();
521    if start_f == 0.0 || stop_f == 0.0 {
522        return Err(FerrayError::invalid_value(
523            "geomspace: start and stop must be non-zero",
524        ));
525    }
526    if (start_f < 0.0) != (stop_f < 0.0) {
527        return Err(FerrayError::invalid_value(
528            "geomspace: start and stop must have the same sign",
529        ));
530    }
531    if num == 0 {
532        return Array::from_vec(Ix1::new([0]), vec![]);
533    }
534    if num == 1 {
535        return Array::from_vec(Ix1::new([1]), vec![start]);
536    }
537    let log_start = start_f.abs().ln();
538    let log_stop = stop_f.abs().ln();
539    let sign = if start_f < 0.0 { -1.0 } else { 1.0 };
540    let divisor = if endpoint {
541        (num - 1) as f64
542    } else {
543        num as f64
544    };
545    let step = (log_stop - log_start) / divisor;
546    let mut data = Vec::with_capacity(num);
547    for i in 0..num {
548        let log_val = log_start + (i as f64) * step;
549        data.push(T::from_f64(sign * log_val.exp()));
550    }
551    Array::from_vec(Ix1::new([num]), data)
552}
553
554/// Return coordinate arrays from coordinate vectors.
555///
556/// Analogous to `numpy.meshgrid(*xi, indexing='xy')`.
557///
558/// Given N 1-D arrays, returns N N-D arrays, where each output array
559/// has the shape `(len(x1), len(x2), ..., len(xN))` for 'xy' indexing
560/// or `(len(x1), ..., len(xN))` transposed for 'ij' indexing.
561///
562/// `indexing` should be `"xy"` (default Cartesian) or `"ij"` (matrix).
563///
564/// # Errors
565/// Returns `FerrayError::InvalidValue` if `indexing` is not `"xy"` or `"ij"`,
566/// or if there are fewer than 2 input arrays.
567pub fn meshgrid(
568    arrays: &[Array<f64, Ix1>],
569    indexing: &str,
570) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
571    if indexing != "xy" && indexing != "ij" {
572        return Err(FerrayError::invalid_value(
573            "meshgrid: indexing must be 'xy' or 'ij'",
574        ));
575    }
576    let ndim = arrays.len();
577    if ndim == 0 {
578        return Ok(vec![]);
579    }
580
581    let mut shapes: Vec<usize> = arrays.iter().map(|a| a.shape()[0]).collect();
582    if indexing == "xy" && ndim >= 2 {
583        shapes.swap(0, 1);
584    }
585
586    let total: usize = shapes.iter().product();
587    let mut results = Vec::with_capacity(ndim);
588
589    for (k, arr) in arrays.iter().enumerate() {
590        let src_data: Vec<f64> = arr.iter().copied().collect();
591        let mut data = Vec::with_capacity(total);
592        // For 'xy' indexing, the first two dimensions are swapped
593        let effective_k = if indexing == "xy" && ndim >= 2 {
594            match k {
595                0 => 1,
596                1 => 0,
597                other => other,
598            }
599        } else {
600            k
601        };
602
603        // Build the output by iterating over all indices in the output shape
604        for flat in 0..total {
605            // Compute the index along dimension effective_k
606            let mut rem = flat;
607            let mut idx_k = 0;
608            for (d, &s) in shapes.iter().enumerate().rev() {
609                if d == effective_k {
610                    idx_k = rem % s;
611                }
612                rem /= s;
613            }
614            data.push(src_data[idx_k]);
615        }
616
617        let dim = IxDyn::new(&shapes);
618        results.push(Array::from_vec(dim, data)?);
619    }
620    Ok(results)
621}
622
623/// Create a dense multi-dimensional "meshgrid" with matrix ('ij') indexing.
624///
625/// Analogous to `numpy.mgrid[start:stop:step, ...]`.
626///
627/// Takes a slice of `(start, stop, step)` tuples, one per dimension.
628/// Returns a vector of arrays, one per dimension.
629///
630/// # Errors
631/// Returns `FerrayError::InvalidValue` if any step is zero.
632pub fn mgrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
633    let mut arrs: Vec<Array<f64, Ix1>> = Vec::with_capacity(ranges.len());
634    for &(start, stop, step) in ranges {
635        arrs.push(arange(start, stop, step)?);
636    }
637    meshgrid(&arrs, "ij")
638}
639
640/// Create a sparse (open) multi-dimensional "meshgrid" with 'ij' indexing.
641///
642/// Analogous to `numpy.ogrid[start:stop:step, ...]`.
643///
644/// Returns arrays that are broadcastable to the full grid shape.
645/// Each returned array has shape 1 in all dimensions except its own.
646///
647/// # Errors
648/// Returns `FerrayError::InvalidValue` if any step is zero.
649pub fn ogrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
650    let ndim = ranges.len();
651    let mut results = Vec::with_capacity(ndim);
652    for (i, &(start, stop, step)) in ranges.iter().enumerate() {
653        let arr1d = arange(start, stop, step)?;
654        let n = arr1d.shape()[0];
655        let data: Vec<f64> = arr1d.iter().copied().collect();
656        // Build shape: all ones except dimension i = n
657        let mut shape = vec![1usize; ndim];
658        shape[i] = n;
659        let dim = IxDyn::new(&shape);
660        results.push(Array::from_vec(dim, data)?);
661    }
662    Ok(results)
663}
664
665// ============================================================================
666// REQ-19: Identity/diagonal functions
667// ============================================================================
668
669/// Create a 2-D identity matrix of size `n x n`.
670///
671/// Analogous to `numpy.identity()`.
672pub fn identity<T: Element>(n: usize) -> FerrayResult<Array<T, Ix2>> {
673    eye(n, n, 0)
674}
675
676/// Create a 2-D array with ones on the diagonal and zeros elsewhere.
677///
678/// `k` is the diagonal offset: 0 = main diagonal, positive = above, negative = below.
679///
680/// Analogous to `numpy.eye(N, M, k)`.
681pub fn eye<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
682    let mut data = vec![T::zero(); n * m];
683    for i in 0..n {
684        let j = i as isize + k;
685        if j >= 0 && (j as usize) < m {
686            data[i * m + j as usize] = T::one();
687        }
688    }
689    Array::from_vec(Ix2::new([n, m]), data)
690}
691
692/// Extract a diagonal or construct a diagonal array.
693///
694/// If `a` is 2-D, extract the `k`-th diagonal as a 1-D array.
695/// If `a` is 1-D, construct a 2-D array with `a` on the `k`-th diagonal.
696///
697/// Analogous to `numpy.diag()`.
698///
699/// # Errors
700/// Returns `FerrayError::InvalidValue` if `a` is not 1-D or 2-D.
701pub fn diag<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
702    let shape = a.shape();
703    match shape.len() {
704        1 => {
705            // Construct a 2-D diagonal array
706            let n = shape[0];
707            let size = n + k.unsigned_abs();
708            let mut data = vec![T::zero(); size * size];
709            let src: Vec<T> = a.iter().cloned().collect();
710            for (i, val) in src.into_iter().enumerate() {
711                let row = if k >= 0 { i } else { i + k.unsigned_abs() };
712                let col = if k >= 0 { i + k as usize } else { i };
713                data[row * size + col] = val;
714            }
715            Array::from_vec(IxDyn::new(&[size, size]), data)
716        }
717        2 => {
718            // Extract the k-th diagonal
719            let (n, m) = (shape[0], shape[1]);
720            let src: Vec<T> = a.iter().cloned().collect();
721            let mut diag_vals = Vec::new();
722            for i in 0..n {
723                let j = i as isize + k;
724                if j >= 0 && (j as usize) < m {
725                    diag_vals.push(src[i * m + j as usize].clone());
726                }
727            }
728            let len = diag_vals.len();
729            Array::from_vec(IxDyn::new(&[len]), diag_vals)
730        }
731        _ => Err(FerrayError::invalid_value("diag: input must be 1-D or 2-D")),
732    }
733}
734
735/// Create a 2-D array with the flattened input as a diagonal.
736///
737/// Analogous to `numpy.diagflat()`.
738///
739/// # Errors
740/// Propagates errors from the underlying construction.
741pub fn diagflat<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
742    // Flatten a to 1-D, then call diag
743    let flat: Vec<T> = a.iter().cloned().collect();
744    let n = flat.len();
745    let arr1d = Array::from_vec(IxDyn::new(&[n]), flat)?;
746    diag(&arr1d, k)
747}
748
749/// Create a lower-triangular matrix of ones.
750///
751/// Returns an `n x m` array where `a[i, j] = 1` if `i >= j - k`, else `0`.
752///
753/// Analogous to `numpy.tri(N, M, k)`.
754pub fn tri<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
755    let mut data = vec![T::zero(); n * m];
756    for i in 0..n {
757        for j in 0..m {
758            if (i as isize) >= (j as isize) - k {
759                data[i * m + j] = T::one();
760            }
761        }
762    }
763    Array::from_vec(Ix2::new([n, m]), data)
764}
765
766/// Return the lower triangle of a 2-D array.
767///
768/// `k` is the diagonal above which to zero elements. 0 = main diagonal.
769///
770/// Analogous to `numpy.tril()`.
771///
772/// # Errors
773/// Returns `FerrayError::InvalidValue` if input is not 2-D.
774pub fn tril<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
775    let shape = a.shape();
776    if shape.len() != 2 {
777        return Err(FerrayError::invalid_value("tril: input must be 2-D"));
778    }
779    let (n, m) = (shape[0], shape[1]);
780    let src: Vec<T> = a.iter().cloned().collect();
781    let mut data = vec![T::zero(); n * m];
782    for i in 0..n {
783        for j in 0..m {
784            if (i as isize) >= (j as isize) - k {
785                data[i * m + j] = src[i * m + j].clone();
786            }
787        }
788    }
789    Array::from_vec(IxDyn::new(&[n, m]), data)
790}
791
792/// Return the upper triangle of a 2-D array.
793///
794/// `k` is the diagonal below which to zero elements. 0 = main diagonal.
795///
796/// Analogous to `numpy.triu()`.
797///
798/// # Errors
799/// Returns `FerrayError::InvalidValue` if input is not 2-D.
800pub fn triu<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
801    let shape = a.shape();
802    if shape.len() != 2 {
803        return Err(FerrayError::invalid_value("triu: input must be 2-D"));
804    }
805    let (n, m) = (shape[0], shape[1]);
806    let src: Vec<T> = a.iter().cloned().collect();
807    let mut data = vec![T::zero(); n * m];
808    for i in 0..n {
809        for j in 0..m {
810            if (i as isize) <= (j as isize) - k {
811                data[i * m + j] = src[i * m + j].clone();
812            }
813        }
814    }
815    Array::from_vec(IxDyn::new(&[n, m]), data)
816}
817
818// ============================================================================
819// Tests
820// ============================================================================
821
822#[cfg(test)]
823mod tests {
824    use super::*;
825    use crate::dimension::{Ix1, Ix2, IxDyn};
826
827    // -- REQ-16 tests --
828
829    #[test]
830    fn test_array_creation() {
831        let a = array(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
832        assert_eq!(a.shape(), &[2, 3]);
833        assert_eq!(a.size(), 6);
834    }
835
836    #[test]
837    fn test_asarray() {
838        let a = asarray(Ix1::new([3]), vec![1, 2, 3]).unwrap();
839        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
840    }
841
842    #[test]
843    fn test_frombuffer() {
844        let bytes: Vec<u8> = vec![1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0];
845        let a = frombuffer::<i32, Ix1>(Ix1::new([3]), &bytes).unwrap();
846        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
847    }
848
849    #[test]
850    fn test_frombuffer_bad_length() {
851        let bytes: Vec<u8> = vec![1, 0, 0];
852        assert!(frombuffer::<i32, Ix1>(Ix1::new([1]), &bytes).is_err());
853    }
854
855    #[test]
856    fn test_frombuffer_bool() {
857        // Issue #135: bool round-trips through frombuffer must
858        // preserve the discriminating byte (0 -> false, nonzero
859        // -> true, although our raw-buffer contract is that each
860        // byte is a valid bool per `Element`).
861        let bytes: Vec<u8> = vec![0, 1, 0, 1, 1];
862        let a = frombuffer::<bool, Ix1>(Ix1::new([5]), &bytes).unwrap();
863        assert_eq!(a.as_slice().unwrap(), &[false, true, false, true, true]);
864    }
865
866    #[test]
867    fn test_frombuffer_bool_wrong_length() {
868        // For bool (1 byte each), the buffer length must equal the
869        // requested element count.
870        let bytes: Vec<u8> = vec![0, 1];
871        assert!(frombuffer::<bool, Ix1>(Ix1::new([3]), &bytes).is_err());
872    }
873
874    // #364: frombuffer_view — zero-copy view over an existing byte buffer.
875
876    /// Build an aligned byte buffer of `nbytes` from a typed slice so we
877    /// can exercise the zero-copy path without fighting the allocator.
878    fn aligned_bytes<T: Copy>(src: &[T]) -> Vec<u8> {
879        let n = std::mem::size_of_val(src);
880        let mut out = vec![0u8; n];
881        // SAFETY: src is &[T], out is a byte buffer of exactly n bytes.
882        unsafe {
883            std::ptr::copy_nonoverlapping(src.as_ptr() as *const u8, out.as_mut_ptr(), n);
884        }
885        out
886    }
887
888    #[test]
889    fn test_frombuffer_view_i32_is_zero_copy() {
890        // Build an aligned byte buffer that represents three i32s.
891        let source: Vec<i32> = vec![10, 20, 30];
892        let bytes = aligned_bytes(&source);
893        let view = frombuffer_view::<i32, Ix1>(Ix1::new([3]), &bytes).unwrap();
894        assert_eq!(view.shape(), &[3]);
895        let values: Vec<i32> = view.iter().copied().collect();
896        assert_eq!(values, vec![10, 20, 30]);
897        // Pointer must alias the source buffer — that's the zero-copy
898        // contract distinguishing this from the copying frombuffer.
899        assert_eq!(view.as_ptr() as *const u8, bytes.as_ptr());
900    }
901
902    #[test]
903    fn test_frombuffer_view_f64_2d() {
904        let source: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
905        let bytes = aligned_bytes(&source);
906        let view = frombuffer_view::<f64, Ix2>(Ix2::new([2, 3]), &bytes).unwrap();
907        assert_eq!(view.shape(), &[2, 3]);
908        let values: Vec<f64> = view.iter().copied().collect();
909        assert_eq!(values, source);
910    }
911
912    #[test]
913    fn test_frombuffer_view_bool_valid() {
914        let bytes: Vec<u8> = vec![0, 1, 0, 1];
915        let view = frombuffer_view::<bool, Ix1>(Ix1::new([4]), &bytes).unwrap();
916        let values: Vec<bool> = view.iter().copied().collect();
917        assert_eq!(values, vec![false, true, false, true]);
918    }
919
920    #[test]
921    fn test_frombuffer_view_bool_rejects_invalid_byte() {
922        let bytes: Vec<u8> = vec![0, 1, 42]; // 42 is not a valid bool.
923        assert!(frombuffer_view::<bool, Ix1>(Ix1::new([3]), &bytes).is_err());
924    }
925
926    #[test]
927    fn test_frombuffer_view_rejects_wrong_length() {
928        // 13 bytes is not a multiple of size_of::<i32>() = 4.
929        let bytes = vec![0u8; 13];
930        assert!(frombuffer_view::<i32, Ix1>(Ix1::new([3]), &bytes).is_err());
931        // 8 bytes is 2 i32s, but the caller asked for shape [3].
932        let bytes = vec![0u8; 8];
933        assert!(frombuffer_view::<i32, Ix1>(Ix1::new([3]), &bytes).is_err());
934    }
935
936    #[test]
937    fn test_frombuffer_view_rejects_misalignment() {
938        // Force a misaligned slice by advancing the byte pointer by 1.
939        let mut backing: Vec<u8> = vec![0u8; 1 + 4 * 3];
940        for (i, chunk) in backing[1..].chunks_exact_mut(4).enumerate() {
941            chunk.copy_from_slice(&(i as i32).to_ne_bytes());
942        }
943        let misaligned = &backing[1..];
944        // The slice address is off by one from a 4-byte boundary, so
945        // alignment for i32 cannot be satisfied.
946        assert!((misaligned.as_ptr() as usize) % 4 != 0);
947        assert!(frombuffer_view::<i32, Ix1>(Ix1::new([3]), misaligned).is_err());
948    }
949
950    #[test]
951    fn test_fromiter() {
952        let a = fromiter((0..5).map(|x| x as f64)).unwrap();
953        assert_eq!(a.shape(), &[5]);
954        assert_eq!(a.as_slice().unwrap(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
955    }
956
957    #[test]
958    fn test_zeros() {
959        let a = zeros::<f64, Ix2>(Ix2::new([3, 4])).unwrap();
960        assert_eq!(a.shape(), &[3, 4]);
961        assert!(a.iter().all(|&v| v == 0.0));
962    }
963
964    #[test]
965    fn test_ones() {
966        let a = ones::<f64, Ix1>(Ix1::new([5])).unwrap();
967        assert!(a.iter().all(|&v| v == 1.0));
968    }
969
970    #[test]
971    fn test_full() {
972        let a = full(Ix1::new([4]), 42i32).unwrap();
973        assert!(a.iter().all(|&v| v == 42));
974    }
975
976    #[test]
977    fn test_zeros_like() {
978        let a = ones::<f64, Ix2>(Ix2::new([2, 3])).unwrap();
979        let b = zeros_like(&a).unwrap();
980        assert_eq!(b.shape(), &[2, 3]);
981        assert!(b.iter().all(|&v| v == 0.0));
982    }
983
984    #[test]
985    fn test_ones_like() {
986        let a = zeros::<f64, Ix1>(Ix1::new([4])).unwrap();
987        let b = ones_like(&a).unwrap();
988        assert!(b.iter().all(|&v| v == 1.0));
989    }
990
991    #[test]
992    fn test_full_like() {
993        let a = zeros::<i32, Ix1>(Ix1::new([3])).unwrap();
994        let b = full_like(&a, 7).unwrap();
995        assert!(b.iter().all(|&v| v == 7));
996    }
997
998    // -- REQ-17 tests --
999
1000    #[test]
1001    fn test_empty_and_init() {
1002        let mut u = empty::<f64, Ix1>(Ix1::new([3]));
1003        assert_eq!(u.shape(), &[3]);
1004        u.write_at(0, 1.0).unwrap();
1005        u.write_at(1, 2.0).unwrap();
1006        u.write_at(2, 3.0).unwrap();
1007        // SAFETY: all elements initialized
1008        let a = unsafe { u.assume_init() };
1009        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
1010    }
1011
1012    #[test]
1013    fn test_empty_write_oob() {
1014        let mut u = empty::<f64, Ix1>(Ix1::new([2]));
1015        assert!(u.write_at(5, 1.0).is_err());
1016    }
1017
1018    // #363: empty_like matches source shape, contents independent.
1019    #[test]
1020    fn test_empty_like_matches_shape_2d() {
1021        use crate::dimension::Ix2;
1022        let src = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1023            .unwrap();
1024        let mut u = empty_like(&src);
1025        assert_eq!(u.shape(), &[2, 3]);
1026        assert_eq!(u.size(), 6);
1027        assert_eq!(u.ndim(), 2);
1028
1029        // Fill and init — the resulting array is independent of `src`.
1030        for i in 0..6 {
1031            u.write_at(i, -(i as f64)).unwrap();
1032        }
1033        // SAFETY: every slot just written.
1034        let out = unsafe { u.assume_init() };
1035        assert_eq!(out.shape(), &[2, 3]);
1036        assert_eq!(
1037            out.as_slice().unwrap(),
1038            &[0.0, -1.0, -2.0, -3.0, -4.0, -5.0]
1039        );
1040        // Source is unchanged.
1041        assert_eq!(src.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1042    }
1043
1044    #[test]
1045    fn test_empty_like_zero_sized() {
1046        let src = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
1047        let u = empty_like(&src);
1048        assert_eq!(u.shape(), &[0]);
1049        assert_eq!(u.size(), 0);
1050        // SAFETY: size is zero — nothing to initialize.
1051        let out = unsafe { u.assume_init() };
1052        assert_eq!(out.size(), 0);
1053    }
1054
1055    // -- REQ-18 tests --
1056
1057    #[test]
1058    fn test_arange_int() {
1059        let a = arange(0i32, 5, 1).unwrap();
1060        assert_eq!(a.as_slice().unwrap(), &[0, 1, 2, 3, 4]);
1061    }
1062
1063    #[test]
1064    fn test_arange_float() {
1065        let a = arange(0.0_f64, 1.0, 0.25).unwrap();
1066        assert_eq!(a.shape(), &[4]);
1067        let data = a.as_slice().unwrap();
1068        assert!((data[0] - 0.0).abs() < 1e-10);
1069        assert!((data[1] - 0.25).abs() < 1e-10);
1070        assert!((data[2] - 0.5).abs() < 1e-10);
1071        assert!((data[3] - 0.75).abs() < 1e-10);
1072    }
1073
1074    #[test]
1075    fn test_arange_negative_step() {
1076        let a = arange(5.0_f64, 0.0, -1.0).unwrap();
1077        assert_eq!(a.shape(), &[5]);
1078    }
1079
1080    #[test]
1081    fn test_arange_zero_step() {
1082        assert!(arange(0.0_f64, 1.0, 0.0).is_err());
1083    }
1084
1085    #[test]
1086    fn test_arange_empty() {
1087        let a = arange(5i32, 0, 1).unwrap();
1088        assert_eq!(a.shape(), &[0]);
1089    }
1090
1091    #[test]
1092    fn test_linspace() {
1093        let a = linspace(0.0_f64, 1.0, 5, true).unwrap();
1094        assert_eq!(a.shape(), &[5]);
1095        let data = a.as_slice().unwrap();
1096        assert!((data[0] - 0.0).abs() < 1e-10);
1097        assert!((data[4] - 1.0).abs() < 1e-10);
1098        assert!((data[2] - 0.5).abs() < 1e-10);
1099    }
1100
1101    #[test]
1102    fn test_linspace_no_endpoint() {
1103        let a = linspace(0.0_f64, 1.0, 4, false).unwrap();
1104        assert_eq!(a.shape(), &[4]);
1105        let data = a.as_slice().unwrap();
1106        assert!((data[0] - 0.0).abs() < 1e-10);
1107        assert!((data[1] - 0.25).abs() < 1e-10);
1108    }
1109
1110    #[test]
1111    fn test_linspace_single() {
1112        let a = linspace(5.0_f64, 10.0, 1, true).unwrap();
1113        assert_eq!(a.as_slice().unwrap(), &[5.0]);
1114    }
1115
1116    #[test]
1117    fn test_linspace_empty() {
1118        let a = linspace(0.0_f64, 1.0, 0, true).unwrap();
1119        assert_eq!(a.shape(), &[0]);
1120    }
1121
1122    #[test]
1123    fn test_logspace() {
1124        let a = logspace(0.0_f64, 2.0, 3, true, 10.0).unwrap();
1125        let data = a.as_slice().unwrap();
1126        assert!((data[0] - 1.0).abs() < 1e-10); // 10^0
1127        assert!((data[1] - 10.0).abs() < 1e-10); // 10^1
1128        assert!((data[2] - 100.0).abs() < 1e-10); // 10^2
1129    }
1130
1131    #[test]
1132    fn test_geomspace() {
1133        let a = geomspace(1.0_f64, 1000.0, 4, true).unwrap();
1134        let data = a.as_slice().unwrap();
1135        assert!((data[0] - 1.0).abs() < 1e-10);
1136        assert!((data[1] - 10.0).abs() < 1e-8);
1137        assert!((data[2] - 100.0).abs() < 1e-6);
1138        assert!((data[3] - 1000.0).abs() < 1e-4);
1139    }
1140
1141    #[test]
1142    fn test_geomspace_zero_start() {
1143        assert!(geomspace(0.0_f64, 1.0, 5, true).is_err());
1144    }
1145
1146    #[test]
1147    fn test_geomspace_different_signs() {
1148        assert!(geomspace(-1.0_f64, 1.0, 5, true).is_err());
1149    }
1150
1151    #[test]
1152    fn test_meshgrid_xy() {
1153        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
1154        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
1155        let grids = meshgrid(&[x, y], "xy").unwrap();
1156        assert_eq!(grids.len(), 2);
1157        assert_eq!(grids[0].shape(), &[2, 3]);
1158        assert_eq!(grids[1].shape(), &[2, 3]);
1159        // X grid: rows are [1,2,3], [1,2,3]
1160        let xdata: Vec<f64> = grids[0].iter().copied().collect();
1161        assert_eq!(xdata, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
1162        // Y grid: rows are [4,4,4], [5,5,5]
1163        let ydata: Vec<f64> = grids[1].iter().copied().collect();
1164        assert_eq!(ydata, vec![4.0, 4.0, 4.0, 5.0, 5.0, 5.0]);
1165    }
1166
1167    #[test]
1168    fn test_meshgrid_ij() {
1169        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
1170        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
1171        let grids = meshgrid(&[x, y], "ij").unwrap();
1172        assert_eq!(grids.len(), 2);
1173        assert_eq!(grids[0].shape(), &[3, 2]);
1174        assert_eq!(grids[1].shape(), &[3, 2]);
1175    }
1176
1177    #[test]
1178    fn test_meshgrid_bad_indexing() {
1179        assert!(meshgrid(&[], "zz").is_err());
1180    }
1181
1182    #[test]
1183    fn test_mgrid() {
1184        let grids = mgrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
1185        assert_eq!(grids.len(), 2);
1186        assert_eq!(grids[0].shape(), &[3, 2]);
1187    }
1188
1189    #[test]
1190    fn test_ogrid() {
1191        let grids = ogrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
1192        assert_eq!(grids.len(), 2);
1193        assert_eq!(grids[0].shape(), &[3, 1]);
1194        assert_eq!(grids[1].shape(), &[1, 2]);
1195    }
1196
1197    // -- REQ-19 tests --
1198
1199    #[test]
1200    fn test_identity() {
1201        let a = identity::<f64>(3).unwrap();
1202        assert_eq!(a.shape(), &[3, 3]);
1203        let data = a.as_slice().unwrap();
1204        assert_eq!(data, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
1205    }
1206
1207    #[test]
1208    fn test_eye() {
1209        let a = eye::<f64>(3, 4, 0).unwrap();
1210        assert_eq!(a.shape(), &[3, 4]);
1211        let data = a.as_slice().unwrap();
1212        assert_eq!(
1213            data,
1214            &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
1215        );
1216    }
1217
1218    #[test]
1219    fn test_eye_positive_k() {
1220        let a = eye::<f64>(3, 3, 1).unwrap();
1221        let data = a.as_slice().unwrap();
1222        assert_eq!(data, &[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]);
1223    }
1224
1225    #[test]
1226    fn test_eye_negative_k() {
1227        let a = eye::<f64>(3, 3, -1).unwrap();
1228        let data = a.as_slice().unwrap();
1229        assert_eq!(data, &[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
1230    }
1231
1232    #[test]
1233    fn test_diag_from_1d() {
1234        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
1235        let d = diag(&a, 0).unwrap();
1236        assert_eq!(d.shape(), &[3, 3]);
1237        let data: Vec<f64> = d.iter().copied().collect();
1238        assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
1239    }
1240
1241    #[test]
1242    fn test_diag_from_2d() {
1243        let a = Array::from_vec(
1244            IxDyn::new(&[3, 3]),
1245            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0],
1246        )
1247        .unwrap();
1248        let d = diag(&a, 0).unwrap();
1249        assert_eq!(d.shape(), &[3]);
1250        let data: Vec<f64> = d.iter().copied().collect();
1251        assert_eq!(data, vec![1.0, 2.0, 3.0]);
1252    }
1253
1254    #[test]
1255    fn test_diag_k_positive() {
1256        let a = Array::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
1257        let d = diag(&a, 1).unwrap();
1258        assert_eq!(d.shape(), &[3, 3]);
1259        let data: Vec<f64> = d.iter().copied().collect();
1260        assert_eq!(data, vec![0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]);
1261    }
1262
1263    #[test]
1264    fn test_diagflat() {
1265        let a = Array::from_vec(IxDyn::new(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
1266        let d = diagflat(&a, 0).unwrap();
1267        assert_eq!(d.shape(), &[4, 4]);
1268        // Diagonal should be [1, 2, 3, 4]
1269        let extracted = diag(&d, 0).unwrap();
1270        let data: Vec<f64> = extracted.iter().copied().collect();
1271        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
1272    }
1273
1274    #[test]
1275    fn test_tri() {
1276        let a = tri::<f64>(3, 3, 0).unwrap();
1277        let data = a.as_slice().unwrap();
1278        assert_eq!(data, &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]);
1279    }
1280
1281    #[test]
1282    fn test_tril() {
1283        let a = Array::from_vec(
1284            IxDyn::new(&[3, 3]),
1285            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
1286        )
1287        .unwrap();
1288        let t = tril(&a, 0).unwrap();
1289        let data: Vec<f64> = t.iter().copied().collect();
1290        assert_eq!(data, vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0]);
1291    }
1292
1293    #[test]
1294    fn test_triu() {
1295        let a = Array::from_vec(
1296            IxDyn::new(&[3, 3]),
1297            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
1298        )
1299        .unwrap();
1300        let t = triu(&a, 0).unwrap();
1301        let data: Vec<f64> = t.iter().copied().collect();
1302        assert_eq!(data, vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]);
1303    }
1304
1305    #[test]
1306    fn test_tril_not_2d() {
1307        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
1308        assert!(tril(&a, 0).is_err());
1309    }
1310
1311    #[test]
1312    fn test_triu_not_2d() {
1313        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
1314        assert!(triu(&a, 0).is_err());
1315    }
1316}