1use std::mem::MaybeUninit;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, Ix2, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13pub fn array<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
26 Array::from_vec(dim, data)
27}
28
29pub fn asarray<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
37 Array::from_vec(dim, data)
38}
39
40pub fn frombuffer<T: Element, D: Dimension>(dim: D, buf: &[u8]) -> FerrayResult<Array<T, D>> {
48 let elem_size = std::mem::size_of::<T>();
49 if elem_size == 0 {
50 return Err(FerrayError::invalid_value("zero-sized type"));
51 }
52 if buf.len() % elem_size != 0 {
53 return Err(FerrayError::invalid_value(format!(
54 "buffer length {} is not a multiple of element size {}",
55 buf.len(),
56 elem_size,
57 )));
58 }
59 let n_elems = buf.len() / elem_size;
60 let expected = dim.size();
61 if n_elems != expected {
62 return Err(FerrayError::shape_mismatch(format!(
63 "buffer contains {} elements but shape {:?} requires {}",
64 n_elems,
65 dim.as_slice(),
66 expected,
67 )));
68 }
69 let mut data = Vec::with_capacity(n_elems);
71 for i in 0..n_elems {
72 let start = i * elem_size;
73 let end = start + elem_size;
74 let slice = &buf[start..end];
75 let val = unsafe {
79 let mut val = MaybeUninit::<T>::uninit();
80 std::ptr::copy_nonoverlapping(slice.as_ptr(), val.as_mut_ptr() as *mut u8, elem_size);
81 val.assume_init()
82 };
83 data.push(val);
84 }
85 Array::from_vec(dim, data)
86}
87
88pub fn fromiter<T: Element>(iter: impl IntoIterator<Item = T>) -> FerrayResult<Array<T, Ix1>> {
95 Array::from_iter_1d(iter)
96}
97
98pub fn zeros<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
102 Array::zeros(dim)
103}
104
105pub fn ones<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
109 Array::ones(dim)
110}
111
112pub fn full<T: Element, D: Dimension>(dim: D, fill_value: T) -> FerrayResult<Array<T, D>> {
116 Array::from_elem(dim, fill_value)
117}
118
119pub fn zeros_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
123 Array::zeros(other.dim().clone())
124}
125
126pub fn ones_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
130 Array::ones(other.dim().clone())
131}
132
133pub fn full_like<T: Element, D: Dimension>(
137 other: &Array<T, D>,
138 fill_value: T,
139) -> FerrayResult<Array<T, D>> {
140 Array::from_elem(other.dim().clone(), fill_value)
141}
142
143pub struct UninitArray<T: Element, D: Dimension> {
152 data: Vec<MaybeUninit<T>>,
153 dim: D,
154}
155
156impl<T: Element, D: Dimension> UninitArray<T, D> {
157 #[inline]
159 pub fn shape(&self) -> &[usize] {
160 self.dim.as_slice()
161 }
162
163 #[inline]
165 pub fn size(&self) -> usize {
166 self.data.len()
167 }
168
169 #[inline]
171 pub fn ndim(&self) -> usize {
172 self.dim.ndim()
173 }
174
175 #[inline]
180 pub fn as_mut_ptr(&mut self) -> *mut MaybeUninit<T> {
181 self.data.as_mut_ptr()
182 }
183
184 pub fn write_at(&mut self, flat_index: usize, value: T) -> FerrayResult<()> {
189 let size = self.size();
190 if flat_index >= size {
191 return Err(FerrayError::IndexOutOfBounds {
192 index: flat_index as isize,
193 axis: 0,
194 size,
195 });
196 }
197 self.data[flat_index] = MaybeUninit::new(value);
198 Ok(())
199 }
200
201 pub unsafe fn assume_init(self) -> Array<T, D> {
208 let nd_dim = self.dim.to_ndarray_dim();
209 let len = self.data.len();
210
211 let mut raw_vec = std::mem::ManuallyDrop::new(self.data);
215 let data: Vec<T> =
216 unsafe { Vec::from_raw_parts(raw_vec.as_mut_ptr() as *mut T, len, raw_vec.capacity()) };
217
218 let inner = ndarray::Array::from_shape_vec(nd_dim, data)
219 .expect("UninitArray assume_init: shape/data mismatch (this is a bug)");
220 Array::from_ndarray(inner)
221 }
222}
223
224pub fn empty<T: Element, D: Dimension>(dim: D) -> UninitArray<T, D> {
232 let size = dim.size();
233 let mut data = Vec::with_capacity(size);
234 unsafe {
237 data.set_len(size);
238 }
239 UninitArray { data, dim }
240}
241
242pub trait ArangeNum: Element + PartialOrd {
249 fn from_f64(v: f64) -> Self;
251 fn to_f64(self) -> f64;
253}
254
/// Implements [`ArangeNum`] for each listed integer type using `as` casts.
///
/// `from_f64` truncates toward zero and saturates at the type's bounds; NaN
/// becomes 0. `to_f64` rounds to the nearest `f64`.
macro_rules! impl_arange_int {
    ($($int:ty),*) => {
        $(
            impl ArangeNum for $int {
                #[inline]
                fn to_f64(self) -> f64 {
                    self as f64
                }

                #[inline]
                fn from_f64(v: f64) -> Self {
                    v as $int
                }
            }
        )*
    };
}
267
/// Implements [`ArangeNum`] for each listed floating-point type using `as` casts.
macro_rules! impl_arange_float {
    ($($float:ty),*) => {
        $(
            impl ArangeNum for $float {
                #[inline]
                fn to_f64(self) -> f64 {
                    self as f64
                }

                #[inline]
                fn from_f64(v: f64) -> Self {
                    v as $float
                }
            }
        )*
    };
}
280
281impl_arange_int!(u8, u16, u32, u64, i8, i16, i32, i64);
282impl_arange_float!(f32, f64);
283
284pub fn arange<T: ArangeNum>(start: T, stop: T, step: T) -> FerrayResult<Array<T, Ix1>> {
291 let step_f = step.to_f64();
292 if step_f == 0.0 {
293 return Err(FerrayError::invalid_value("step cannot be zero"));
294 }
295 let start_f = start.to_f64();
296 let stop_f = stop.to_f64();
297 let n = ((stop_f - start_f) / step_f).ceil();
298 let n = if n < 0.0 { 0 } else { n as usize };
299
300 let mut data = Vec::with_capacity(n);
301 for i in 0..n {
302 data.push(T::from_f64(start_f + (i as f64) * step_f));
303 }
304 let dim = Ix1::new([data.len()]);
305 Array::from_vec(dim, data)
306}
307
308pub trait LinspaceNum: Element + PartialOrd {
310 fn from_f64(v: f64) -> Self;
312 fn to_f64(self) -> f64;
314}
315
316impl LinspaceNum for f32 {
317 #[inline]
318 fn from_f64(v: f64) -> Self {
319 v as f32
320 }
321 #[inline]
322 fn to_f64(self) -> f64 {
323 self as f64
324 }
325}
326
327impl LinspaceNum for f64 {
328 #[inline]
329 fn from_f64(v: f64) -> Self {
330 v
331 }
332 #[inline]
333 fn to_f64(self) -> f64 {
334 self
335 }
336}
337
338pub fn linspace<T: LinspaceNum>(
349 start: T,
350 stop: T,
351 num: usize,
352 endpoint: bool,
353) -> FerrayResult<Array<T, Ix1>> {
354 if num == 0 {
355 return Array::from_vec(Ix1::new([0]), vec![]);
356 }
357 if num == 1 {
358 return Array::from_vec(Ix1::new([1]), vec![start]);
359 }
360 let start_f = start.to_f64();
361 let stop_f = stop.to_f64();
362 let divisor = if endpoint {
363 (num - 1) as f64
364 } else {
365 num as f64
366 };
367 let step = (stop_f - start_f) / divisor;
368 let mut data = Vec::with_capacity(num);
369 for i in 0..num {
370 data.push(T::from_f64(start_f + (i as f64) * step));
371 }
372 Array::from_vec(Ix1::new([num]), data)
373}
374
375pub fn logspace<T: LinspaceNum>(
384 start: T,
385 stop: T,
386 num: usize,
387 endpoint: bool,
388 base: f64,
389) -> FerrayResult<Array<T, Ix1>> {
390 let lin = linspace(start, stop, num, endpoint)?;
391 let data: Vec<T> = lin
392 .iter()
393 .map(|v| T::from_f64(base.powf(v.clone().to_f64())))
394 .collect();
395 Array::from_vec(Ix1::new([num]), data)
396}
397
398pub fn geomspace<T: LinspaceNum>(
408 start: T,
409 stop: T,
410 num: usize,
411 endpoint: bool,
412) -> FerrayResult<Array<T, Ix1>> {
413 let start_f = start.clone().to_f64();
414 let stop_f = stop.clone().to_f64();
415 if start_f == 0.0 || stop_f == 0.0 {
416 return Err(FerrayError::invalid_value(
417 "geomspace: start and stop must be non-zero",
418 ));
419 }
420 if (start_f < 0.0) != (stop_f < 0.0) {
421 return Err(FerrayError::invalid_value(
422 "geomspace: start and stop must have the same sign",
423 ));
424 }
425 if num == 0 {
426 return Array::from_vec(Ix1::new([0]), vec![]);
427 }
428 if num == 1 {
429 return Array::from_vec(Ix1::new([1]), vec![start]);
430 }
431 let log_start = start_f.abs().ln();
432 let log_stop = stop_f.abs().ln();
433 let sign = if start_f < 0.0 { -1.0 } else { 1.0 };
434 let divisor = if endpoint {
435 (num - 1) as f64
436 } else {
437 num as f64
438 };
439 let step = (log_stop - log_start) / divisor;
440 let mut data = Vec::with_capacity(num);
441 for i in 0..num {
442 let log_val = log_start + (i as f64) * step;
443 data.push(T::from_f64(sign * log_val.exp()));
444 }
445 Array::from_vec(Ix1::new([num]), data)
446}
447
448pub fn meshgrid(
462 arrays: &[Array<f64, Ix1>],
463 indexing: &str,
464) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
465 if indexing != "xy" && indexing != "ij" {
466 return Err(FerrayError::invalid_value(
467 "meshgrid: indexing must be 'xy' or 'ij'",
468 ));
469 }
470 let ndim = arrays.len();
471 if ndim == 0 {
472 return Ok(vec![]);
473 }
474
475 let mut shapes: Vec<usize> = arrays.iter().map(|a| a.shape()[0]).collect();
476 if indexing == "xy" && ndim >= 2 {
477 shapes.swap(0, 1);
478 }
479
480 let total: usize = shapes.iter().product();
481 let mut results = Vec::with_capacity(ndim);
482
483 for (k, arr) in arrays.iter().enumerate() {
484 let src_data: Vec<f64> = arr.iter().copied().collect();
485 let mut data = Vec::with_capacity(total);
486 let effective_k = if indexing == "xy" && ndim >= 2 {
488 match k {
489 0 => 1,
490 1 => 0,
491 other => other,
492 }
493 } else {
494 k
495 };
496
497 for flat in 0..total {
499 let mut rem = flat;
501 let mut idx_k = 0;
502 for (d, &s) in shapes.iter().enumerate().rev() {
503 if d == effective_k {
504 idx_k = rem % s;
505 }
506 rem /= s;
507 }
508 data.push(src_data[idx_k]);
509 }
510
511 let dim = IxDyn::new(&shapes);
512 results.push(Array::from_vec(dim, data)?);
513 }
514 Ok(results)
515}
516
517pub fn mgrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
527 let mut arrs: Vec<Array<f64, Ix1>> = Vec::with_capacity(ranges.len());
528 for &(start, stop, step) in ranges {
529 arrs.push(arange(start, stop, step)?);
530 }
531 meshgrid(&arrs, "ij")
532}
533
534pub fn ogrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
544 let ndim = ranges.len();
545 let mut results = Vec::with_capacity(ndim);
546 for (i, &(start, stop, step)) in ranges.iter().enumerate() {
547 let arr1d = arange(start, stop, step)?;
548 let n = arr1d.shape()[0];
549 let data: Vec<f64> = arr1d.iter().copied().collect();
550 let mut shape = vec![1usize; ndim];
552 shape[i] = n;
553 let dim = IxDyn::new(&shape);
554 results.push(Array::from_vec(dim, data)?);
555 }
556 Ok(results)
557}
558
559pub fn identity<T: Element>(n: usize) -> FerrayResult<Array<T, Ix2>> {
567 eye(n, n, 0)
568}
569
570pub fn eye<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
576 let mut data = vec![T::zero(); n * m];
577 for i in 0..n {
578 let j = i as isize + k;
579 if j >= 0 && (j as usize) < m {
580 data[i * m + j as usize] = T::one();
581 }
582 }
583 Array::from_vec(Ix2::new([n, m]), data)
584}
585
586pub fn diag<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
596 let shape = a.shape();
597 match shape.len() {
598 1 => {
599 let n = shape[0];
601 let size = n + k.unsigned_abs();
602 let mut data = vec![T::zero(); size * size];
603 let src: Vec<T> = a.iter().cloned().collect();
604 for (i, val) in src.into_iter().enumerate() {
605 let row = if k >= 0 { i } else { i + k.unsigned_abs() };
606 let col = if k >= 0 { i + k as usize } else { i };
607 data[row * size + col] = val;
608 }
609 Array::from_vec(IxDyn::new(&[size, size]), data)
610 }
611 2 => {
612 let (n, m) = (shape[0], shape[1]);
614 let src: Vec<T> = a.iter().cloned().collect();
615 let mut diag_vals = Vec::new();
616 for i in 0..n {
617 let j = i as isize + k;
618 if j >= 0 && (j as usize) < m {
619 diag_vals.push(src[i * m + j as usize].clone());
620 }
621 }
622 let len = diag_vals.len();
623 Array::from_vec(IxDyn::new(&[len]), diag_vals)
624 }
625 _ => Err(FerrayError::invalid_value("diag: input must be 1-D or 2-D")),
626 }
627}
628
629pub fn diagflat<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
636 let flat: Vec<T> = a.iter().cloned().collect();
638 let n = flat.len();
639 let arr1d = Array::from_vec(IxDyn::new(&[n]), flat)?;
640 diag(&arr1d, k)
641}
642
643pub fn tri<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
649 let mut data = vec![T::zero(); n * m];
650 for i in 0..n {
651 for j in 0..m {
652 if (i as isize) >= (j as isize) - k {
653 data[i * m + j] = T::one();
654 }
655 }
656 }
657 Array::from_vec(Ix2::new([n, m]), data)
658}
659
660pub fn tril<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
669 let shape = a.shape();
670 if shape.len() != 2 {
671 return Err(FerrayError::invalid_value("tril: input must be 2-D"));
672 }
673 let (n, m) = (shape[0], shape[1]);
674 let src: Vec<T> = a.iter().cloned().collect();
675 let mut data = vec![T::zero(); n * m];
676 for i in 0..n {
677 for j in 0..m {
678 if (i as isize) >= (j as isize) - k {
679 data[i * m + j] = src[i * m + j].clone();
680 }
681 }
682 }
683 Array::from_vec(IxDyn::new(&[n, m]), data)
684}
685
686pub fn triu<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
695 let shape = a.shape();
696 if shape.len() != 2 {
697 return Err(FerrayError::invalid_value("triu: input must be 2-D"));
698 }
699 let (n, m) = (shape[0], shape[1]);
700 let src: Vec<T> = a.iter().cloned().collect();
701 let mut data = vec![T::zero(); n * m];
702 for i in 0..n {
703 for j in 0..m {
704 if (i as isize) <= (j as isize) - k {
705 data[i * m + j] = src[i * m + j].clone();
706 }
707 }
708 }
709 Array::from_vec(IxDyn::new(&[n, m]), data)
710}
711
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, IxDyn};

    #[test]
    fn test_array_creation() {
        let a = array(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.size(), 6);
    }

    #[test]
    fn test_asarray() {
        let a = asarray(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn test_frombuffer() {
        let bytes: Vec<u8> = vec![1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0];
        let a = frombuffer::<i32, Ix1>(Ix1::new([3]), &bytes).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn test_frombuffer_bad_length() {
        let bytes: Vec<u8> = vec![1, 0, 0];
        assert!(frombuffer::<i32, Ix1>(Ix1::new([1]), &bytes).is_err());
    }

    #[test]
    fn test_frombuffer_shape_mismatch() {
        // 8 bytes hold two i32 values, but the requested shape needs three.
        let bytes: Vec<u8> = vec![0; 8];
        assert!(frombuffer::<i32, Ix1>(Ix1::new([3]), &bytes).is_err());
    }

    #[test]
    fn test_fromiter() {
        let a = fromiter((0..5).map(|x| x as f64)).unwrap();
        assert_eq!(a.shape(), &[5]);
        assert_eq!(a.as_slice().unwrap(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_zeros() {
        let a = zeros::<f64, Ix2>(Ix2::new([3, 4])).unwrap();
        assert_eq!(a.shape(), &[3, 4]);
        assert!(a.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_ones() {
        let a = ones::<f64, Ix1>(Ix1::new([5])).unwrap();
        assert!(a.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_full() {
        let a = full(Ix1::new([4]), 42i32).unwrap();
        assert!(a.iter().all(|&v| v == 42));
    }

    #[test]
    fn test_zeros_like() {
        let a = ones::<f64, Ix2>(Ix2::new([2, 3])).unwrap();
        let b = zeros_like(&a).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert!(b.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_ones_like() {
        let a = zeros::<f64, Ix1>(Ix1::new([4])).unwrap();
        let b = ones_like(&a).unwrap();
        assert!(b.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_full_like() {
        let a = zeros::<i32, Ix1>(Ix1::new([3])).unwrap();
        let b = full_like(&a, 7).unwrap();
        assert!(b.iter().all(|&v| v == 7));
    }

    #[test]
    fn test_empty_and_init() {
        let mut u = empty::<f64, Ix1>(Ix1::new([3]));
        assert_eq!(u.shape(), &[3]);
        u.write_at(0, 1.0).unwrap();
        u.write_at(1, 2.0).unwrap();
        u.write_at(2, 3.0).unwrap();
        // SAFETY: all three slots were written above.
        let a = unsafe { u.assume_init() };
        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_empty_write_oob() {
        let mut u = empty::<f64, Ix1>(Ix1::new([2]));
        assert!(u.write_at(5, 1.0).is_err());
    }

    #[test]
    fn test_arange_int() {
        let a = arange(0i32, 5, 1).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_arange_int_step() {
        let a = arange(0i32, 10, 3).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[0, 3, 6, 9]);
    }

    #[test]
    fn test_arange_float() {
        let a = arange(0.0_f64, 1.0, 0.25).unwrap();
        assert_eq!(a.shape(), &[4]);
        let data = a.as_slice().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-10);
        assert!((data[1] - 0.25).abs() < 1e-10);
        assert!((data[2] - 0.5).abs() < 1e-10);
        assert!((data[3] - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_arange_negative_step() {
        let a = arange(5.0_f64, 0.0, -1.0).unwrap();
        assert_eq!(a.shape(), &[5]);
    }

    #[test]
    fn test_arange_zero_step() {
        assert!(arange(0.0_f64, 1.0, 0.0).is_err());
    }

    #[test]
    fn test_arange_empty() {
        let a = arange(5i32, 0, 1).unwrap();
        assert_eq!(a.shape(), &[0]);
    }

    #[test]
    fn test_linspace() {
        let a = linspace(0.0_f64, 1.0, 5, true).unwrap();
        assert_eq!(a.shape(), &[5]);
        let data = a.as_slice().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-10);
        assert!((data[4] - 1.0).abs() < 1e-10);
        assert!((data[2] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_linspace_no_endpoint() {
        let a = linspace(0.0_f64, 1.0, 4, false).unwrap();
        assert_eq!(a.shape(), &[4]);
        let data = a.as_slice().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-10);
        assert!((data[1] - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_linspace_single() {
        let a = linspace(5.0_f64, 10.0, 1, true).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[5.0]);
    }

    #[test]
    fn test_linspace_empty() {
        let a = linspace(0.0_f64, 1.0, 0, true).unwrap();
        assert_eq!(a.shape(), &[0]);
    }

    #[test]
    fn test_logspace() {
        let a = logspace(0.0_f64, 2.0, 3, true, 10.0).unwrap();
        let data = a.as_slice().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-10);
        assert!((data[1] - 10.0).abs() < 1e-10);
        assert!((data[2] - 100.0).abs() < 1e-10);
    }

    #[test]
    fn test_logspace_base_two() {
        let a = logspace(0.0_f64, 3.0, 4, true, 2.0).unwrap();
        let data = a.as_slice().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-10);
        assert!((data[1] - 2.0).abs() < 1e-10);
        assert!((data[2] - 4.0).abs() < 1e-10);
        assert!((data[3] - 8.0).abs() < 1e-10);
    }

    #[test]
    fn test_geomspace() {
        let a = geomspace(1.0_f64, 1000.0, 4, true).unwrap();
        let data = a.as_slice().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-10);
        assert!((data[1] - 10.0).abs() < 1e-8);
        assert!((data[2] - 100.0).abs() < 1e-6);
        assert!((data[3] - 1000.0).abs() < 1e-4);
    }

    #[test]
    fn test_geomspace_negative() {
        let a = geomspace(-1.0_f64, -1000.0, 4, true).unwrap();
        let data = a.as_slice().unwrap();
        assert!((data[0] + 1.0).abs() < 1e-10);
        assert!((data[1] + 10.0).abs() < 1e-8);
        assert!((data[2] + 100.0).abs() < 1e-6);
        assert!((data[3] + 1000.0).abs() < 1e-4);
    }

    #[test]
    fn test_geomspace_zero_start() {
        assert!(geomspace(0.0_f64, 1.0, 5, true).is_err());
    }

    #[test]
    fn test_geomspace_different_signs() {
        assert!(geomspace(-1.0_f64, 1.0, 5, true).is_err());
    }

    #[test]
    fn test_meshgrid_xy() {
        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
        let grids = meshgrid(&[x, y], "xy").unwrap();
        assert_eq!(grids.len(), 2);
        assert_eq!(grids[0].shape(), &[2, 3]);
        assert_eq!(grids[1].shape(), &[2, 3]);
        let xdata: Vec<f64> = grids[0].iter().copied().collect();
        assert_eq!(xdata, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
        let ydata: Vec<f64> = grids[1].iter().copied().collect();
        assert_eq!(ydata, vec![4.0, 4.0, 4.0, 5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_meshgrid_ij() {
        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
        let grids = meshgrid(&[x, y], "ij").unwrap();
        assert_eq!(grids.len(), 2);
        assert_eq!(grids[0].shape(), &[3, 2]);
        assert_eq!(grids[1].shape(), &[3, 2]);
    }

    #[test]
    fn test_meshgrid_single_array() {
        // With one input, "xy" has no axes to swap and the grid is the input itself.
        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let grids = meshgrid(&[x], "xy").unwrap();
        assert_eq!(grids.len(), 1);
        assert_eq!(grids[0].shape(), &[3]);
        let data: Vec<f64> = grids[0].iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_meshgrid_bad_indexing() {
        assert!(meshgrid(&[], "zz").is_err());
    }

    #[test]
    fn test_mgrid() {
        let grids = mgrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
        assert_eq!(grids.len(), 2);
        assert_eq!(grids[0].shape(), &[3, 2]);
    }

    #[test]
    fn test_ogrid() {
        let grids = ogrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
        assert_eq!(grids.len(), 2);
        assert_eq!(grids[0].shape(), &[3, 1]);
        assert_eq!(grids[1].shape(), &[1, 2]);
    }

    #[test]
    fn test_identity() {
        let a = identity::<f64>(3).unwrap();
        assert_eq!(a.shape(), &[3, 3]);
        let data = a.as_slice().unwrap();
        assert_eq!(data, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_eye() {
        let a = eye::<f64>(3, 4, 0).unwrap();
        assert_eq!(a.shape(), &[3, 4]);
        let data = a.as_slice().unwrap();
        assert_eq!(
            data,
            &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
        );
    }

    #[test]
    fn test_eye_positive_k() {
        let a = eye::<f64>(3, 3, 1).unwrap();
        let data = a.as_slice().unwrap();
        assert_eq!(data, &[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_eye_negative_k() {
        let a = eye::<f64>(3, 3, -1).unwrap();
        let data = a.as_slice().unwrap();
        assert_eq!(data, &[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_eye_k_outside_matrix() {
        let a = eye::<f64>(2, 3, 5).unwrap();
        assert_eq!(a.shape(), &[2, 3]);
        assert!(a.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_diag_from_1d() {
        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
        let d = diag(&a, 0).unwrap();
        assert_eq!(d.shape(), &[3, 3]);
        let data: Vec<f64> = d.iter().copied().collect();
        assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
    }

    #[test]
    fn test_diag_from_2d() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0],
        )
        .unwrap();
        let d = diag(&a, 0).unwrap();
        assert_eq!(d.shape(), &[3]);
        let data: Vec<f64> = d.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_diag_k_positive() {
        let a = Array::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let d = diag(&a, 1).unwrap();
        assert_eq!(d.shape(), &[3, 3]);
        let data: Vec<f64> = d.iter().copied().collect();
        assert_eq!(data, vec![0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_diag_k_negative() {
        let a = Array::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let d = diag(&a, -1).unwrap();
        assert_eq!(d.shape(), &[3, 3]);
        let data: Vec<f64> = d.iter().copied().collect();
        assert_eq!(data, vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0]);
    }

    #[test]
    fn test_diag_3d_rejected() {
        let a = Array::from_vec(IxDyn::new(&[1, 1, 1]), vec![1.0]).unwrap();
        assert!(diag(&a, 0).is_err());
    }

    #[test]
    fn test_diagflat() {
        let a = Array::from_vec(IxDyn::new(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let d = diagflat(&a, 0).unwrap();
        assert_eq!(d.shape(), &[4, 4]);
        let extracted = diag(&d, 0).unwrap();
        let data: Vec<f64> = extracted.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_tri() {
        let a = tri::<f64>(3, 3, 0).unwrap();
        let data = a.as_slice().unwrap();
        assert_eq!(data, &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_tril() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .unwrap();
        let t = tril(&a, 0).unwrap();
        let data: Vec<f64> = t.iter().copied().collect();
        assert_eq!(data, vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_tril_negative_k() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .unwrap();
        let t = tril(&a, -1).unwrap();
        let data: Vec<f64> = t.iter().copied().collect();
        assert_eq!(data, vec![0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 7.0, 8.0, 0.0]);
    }

    #[test]
    fn test_triu() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .unwrap();
        let t = triu(&a, 0).unwrap();
        let data: Vec<f64> = t.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]);
    }

    #[test]
    fn test_triu_positive_k() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .unwrap();
        let t = triu(&a, 1).unwrap();
        let data: Vec<f64> = t.iter().copied().collect();
        assert_eq!(data, vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_tril_not_2d() {
        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(tril(&a, 0).is_err());
    }

    #[test]
    fn test_triu_not_2d() {
        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(triu(&a, 0).is_err());
    }
}