1use std::mem::MaybeUninit;
7
8use crate::array::owned::Array;
9use crate::dimension::{Dimension, Ix1, Ix2, IxDyn};
10use crate::dtype::Element;
11use crate::error::{FerrayError, FerrayResult};
12
13pub fn array<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
26 Array::from_vec(dim, data)
27}
28
29pub fn asarray<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
37 Array::from_vec(dim, data)
38}
39
40pub fn frombuffer<T: Element, D: Dimension>(dim: D, buf: &[u8]) -> FerrayResult<Array<T, D>> {
48 let elem_size = std::mem::size_of::<T>();
49 if elem_size == 0 {
50 return Err(FerrayError::invalid_value("zero-sized type"));
51 }
52 if buf.len() % elem_size != 0 {
53 return Err(FerrayError::invalid_value(format!(
54 "buffer length {} is not a multiple of element size {}",
55 buf.len(),
56 elem_size,
57 )));
58 }
59 let n_elems = buf.len() / elem_size;
60 let expected = dim.size();
61 if n_elems != expected {
62 return Err(FerrayError::shape_mismatch(format!(
63 "buffer contains {} elements but shape {:?} requires {}",
64 n_elems,
65 dim.as_slice(),
66 expected,
67 )));
68 }
69 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<bool>() {
72 for &byte in buf {
73 if byte > 1 {
74 return Err(FerrayError::invalid_value(format!(
75 "invalid byte {byte:#04x} for bool (must be 0x00 or 0x01)"
76 )));
77 }
78 }
79 }
80
81 let mut data = Vec::with_capacity(n_elems);
83 for i in 0..n_elems {
84 let start = i * elem_size;
85 let end = start + elem_size;
86 let slice = &buf[start..end];
87 let val = unsafe {
91 let mut val = MaybeUninit::<T>::uninit();
92 std::ptr::copy_nonoverlapping(slice.as_ptr(), val.as_mut_ptr() as *mut u8, elem_size);
93 val.assume_init()
94 };
95 data.push(val);
96 }
97 Array::from_vec(dim, data)
98}
99
100pub fn fromiter<T: Element>(iter: impl IntoIterator<Item = T>) -> FerrayResult<Array<T, Ix1>> {
107 Array::from_iter_1d(iter)
108}
109
110pub fn zeros<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
114 Array::zeros(dim)
115}
116
117pub fn ones<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
121 Array::ones(dim)
122}
123
124pub fn full<T: Element, D: Dimension>(dim: D, fill_value: T) -> FerrayResult<Array<T, D>> {
128 Array::from_elem(dim, fill_value)
129}
130
131pub fn zeros_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
135 Array::zeros(other.dim().clone())
136}
137
138pub fn ones_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
142 Array::ones(other.dim().clone())
143}
144
145pub fn full_like<T: Element, D: Dimension>(
149 other: &Array<T, D>,
150 fill_value: T,
151) -> FerrayResult<Array<T, D>> {
152 Array::from_elem(other.dim().clone(), fill_value)
153}
154
155pub struct UninitArray<T: Element, D: Dimension> {
164 data: Vec<MaybeUninit<T>>,
165 dim: D,
166}
167
168impl<T: Element, D: Dimension> UninitArray<T, D> {
169 #[inline]
171 pub fn shape(&self) -> &[usize] {
172 self.dim.as_slice()
173 }
174
175 #[inline]
177 pub fn size(&self) -> usize {
178 self.data.len()
179 }
180
181 #[inline]
183 pub fn ndim(&self) -> usize {
184 self.dim.ndim()
185 }
186
187 #[inline]
192 pub fn as_mut_ptr(&mut self) -> *mut MaybeUninit<T> {
193 self.data.as_mut_ptr()
194 }
195
196 pub fn write_at(&mut self, flat_index: usize, value: T) -> FerrayResult<()> {
201 let size = self.size();
202 if flat_index >= size {
203 return Err(FerrayError::IndexOutOfBounds {
204 index: flat_index as isize,
205 axis: 0,
206 size,
207 });
208 }
209 self.data[flat_index] = MaybeUninit::new(value);
210 Ok(())
211 }
212
213 pub unsafe fn assume_init(self) -> Array<T, D> {
220 let nd_dim = self.dim.to_ndarray_dim();
221 let len = self.data.len();
222
223 let mut raw_vec = std::mem::ManuallyDrop::new(self.data);
227 let data: Vec<T> =
228 unsafe { Vec::from_raw_parts(raw_vec.as_mut_ptr() as *mut T, len, raw_vec.capacity()) };
229
230 let inner = ndarray::Array::from_shape_vec(nd_dim, data)
231 .expect("UninitArray assume_init: shape/data mismatch (this is a bug)");
232 Array::from_ndarray(inner)
233 }
234}
235
236pub fn empty<T: Element, D: Dimension>(dim: D) -> UninitArray<T, D> {
244 let size = dim.size();
245 let mut data = Vec::with_capacity(size);
246 unsafe {
249 data.set_len(size);
250 }
251 UninitArray { data, dim }
252}
253
254pub trait ArangeNum: Element + PartialOrd {
261 fn from_f64(v: f64) -> Self;
263 fn to_f64(self) -> f64;
265}
266
267macro_rules! impl_arange_int {
268 ($($ty:ty),*) => {
269 $(
270 impl ArangeNum for $ty {
271 #[inline]
272 fn from_f64(v: f64) -> Self { v as Self }
273 #[inline]
274 fn to_f64(self) -> f64 { self as f64 }
275 }
276 )*
277 };
278}
279
280macro_rules! impl_arange_float {
281 ($($ty:ty),*) => {
282 $(
283 impl ArangeNum for $ty {
284 #[inline]
285 fn from_f64(v: f64) -> Self { v as Self }
286 #[inline]
287 fn to_f64(self) -> f64 { self as f64 }
288 }
289 )*
290 };
291}
292
293impl_arange_int!(u8, u16, u32, u64, i8, i16, i32, i64);
294impl_arange_float!(f32, f64);
295
296pub fn arange<T: ArangeNum>(start: T, stop: T, step: T) -> FerrayResult<Array<T, Ix1>> {
303 let step_f = step.to_f64();
304 if step_f == 0.0 {
305 return Err(FerrayError::invalid_value("step cannot be zero"));
306 }
307 let start_f = start.to_f64();
308 let stop_f = stop.to_f64();
309 let n = ((stop_f - start_f) / step_f).ceil();
310 let n = if n < 0.0 { 0 } else { n as usize };
311
312 let mut data = Vec::with_capacity(n);
313 for i in 0..n {
314 data.push(T::from_f64(start_f + (i as f64) * step_f));
315 }
316 let dim = Ix1::new([data.len()]);
317 Array::from_vec(dim, data)
318}
319
320pub trait LinspaceNum: Element + PartialOrd {
322 fn from_f64(v: f64) -> Self;
324 fn to_f64(self) -> f64;
326}
327
328impl LinspaceNum for f32 {
329 #[inline]
330 fn from_f64(v: f64) -> Self {
331 v as f32
332 }
333 #[inline]
334 fn to_f64(self) -> f64 {
335 self as f64
336 }
337}
338
339impl LinspaceNum for f64 {
340 #[inline]
341 fn from_f64(v: f64) -> Self {
342 v
343 }
344 #[inline]
345 fn to_f64(self) -> f64 {
346 self
347 }
348}
349
350pub fn linspace<T: LinspaceNum>(
361 start: T,
362 stop: T,
363 num: usize,
364 endpoint: bool,
365) -> FerrayResult<Array<T, Ix1>> {
366 if num == 0 {
367 return Array::from_vec(Ix1::new([0]), vec![]);
368 }
369 if num == 1 {
370 return Array::from_vec(Ix1::new([1]), vec![start]);
371 }
372 let start_f = start.to_f64();
373 let stop_f = stop.to_f64();
374 let divisor = if endpoint {
375 (num - 1) as f64
376 } else {
377 num as f64
378 };
379 let step = (stop_f - start_f) / divisor;
380 let mut data = Vec::with_capacity(num);
381 for i in 0..num {
382 data.push(T::from_f64(start_f + (i as f64) * step));
383 }
384 Array::from_vec(Ix1::new([num]), data)
385}
386
387pub fn logspace<T: LinspaceNum>(
396 start: T,
397 stop: T,
398 num: usize,
399 endpoint: bool,
400 base: f64,
401) -> FerrayResult<Array<T, Ix1>> {
402 let lin = linspace(start, stop, num, endpoint)?;
403 let data: Vec<T> = lin
404 .iter()
405 .map(|v| T::from_f64(base.powf(v.clone().to_f64())))
406 .collect();
407 Array::from_vec(Ix1::new([num]), data)
408}
409
410pub fn geomspace<T: LinspaceNum>(
420 start: T,
421 stop: T,
422 num: usize,
423 endpoint: bool,
424) -> FerrayResult<Array<T, Ix1>> {
425 let start_f = start.clone().to_f64();
426 let stop_f = stop.clone().to_f64();
427 if start_f == 0.0 || stop_f == 0.0 {
428 return Err(FerrayError::invalid_value(
429 "geomspace: start and stop must be non-zero",
430 ));
431 }
432 if (start_f < 0.0) != (stop_f < 0.0) {
433 return Err(FerrayError::invalid_value(
434 "geomspace: start and stop must have the same sign",
435 ));
436 }
437 if num == 0 {
438 return Array::from_vec(Ix1::new([0]), vec![]);
439 }
440 if num == 1 {
441 return Array::from_vec(Ix1::new([1]), vec![start]);
442 }
443 let log_start = start_f.abs().ln();
444 let log_stop = stop_f.abs().ln();
445 let sign = if start_f < 0.0 { -1.0 } else { 1.0 };
446 let divisor = if endpoint {
447 (num - 1) as f64
448 } else {
449 num as f64
450 };
451 let step = (log_stop - log_start) / divisor;
452 let mut data = Vec::with_capacity(num);
453 for i in 0..num {
454 let log_val = log_start + (i as f64) * step;
455 data.push(T::from_f64(sign * log_val.exp()));
456 }
457 Array::from_vec(Ix1::new([num]), data)
458}
459
460pub fn meshgrid(
474 arrays: &[Array<f64, Ix1>],
475 indexing: &str,
476) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
477 if indexing != "xy" && indexing != "ij" {
478 return Err(FerrayError::invalid_value(
479 "meshgrid: indexing must be 'xy' or 'ij'",
480 ));
481 }
482 let ndim = arrays.len();
483 if ndim == 0 {
484 return Ok(vec![]);
485 }
486
487 let mut shapes: Vec<usize> = arrays.iter().map(|a| a.shape()[0]).collect();
488 if indexing == "xy" && ndim >= 2 {
489 shapes.swap(0, 1);
490 }
491
492 let total: usize = shapes.iter().product();
493 let mut results = Vec::with_capacity(ndim);
494
495 for (k, arr) in arrays.iter().enumerate() {
496 let src_data: Vec<f64> = arr.iter().copied().collect();
497 let mut data = Vec::with_capacity(total);
498 let effective_k = if indexing == "xy" && ndim >= 2 {
500 match k {
501 0 => 1,
502 1 => 0,
503 other => other,
504 }
505 } else {
506 k
507 };
508
509 for flat in 0..total {
511 let mut rem = flat;
513 let mut idx_k = 0;
514 for (d, &s) in shapes.iter().enumerate().rev() {
515 if d == effective_k {
516 idx_k = rem % s;
517 }
518 rem /= s;
519 }
520 data.push(src_data[idx_k]);
521 }
522
523 let dim = IxDyn::new(&shapes);
524 results.push(Array::from_vec(dim, data)?);
525 }
526 Ok(results)
527}
528
529pub fn mgrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
539 let mut arrs: Vec<Array<f64, Ix1>> = Vec::with_capacity(ranges.len());
540 for &(start, stop, step) in ranges {
541 arrs.push(arange(start, stop, step)?);
542 }
543 meshgrid(&arrs, "ij")
544}
545
546pub fn ogrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
556 let ndim = ranges.len();
557 let mut results = Vec::with_capacity(ndim);
558 for (i, &(start, stop, step)) in ranges.iter().enumerate() {
559 let arr1d = arange(start, stop, step)?;
560 let n = arr1d.shape()[0];
561 let data: Vec<f64> = arr1d.iter().copied().collect();
562 let mut shape = vec![1usize; ndim];
564 shape[i] = n;
565 let dim = IxDyn::new(&shape);
566 results.push(Array::from_vec(dim, data)?);
567 }
568 Ok(results)
569}
570
571pub fn identity<T: Element>(n: usize) -> FerrayResult<Array<T, Ix2>> {
579 eye(n, n, 0)
580}
581
582pub fn eye<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
588 let mut data = vec![T::zero(); n * m];
589 for i in 0..n {
590 let j = i as isize + k;
591 if j >= 0 && (j as usize) < m {
592 data[i * m + j as usize] = T::one();
593 }
594 }
595 Array::from_vec(Ix2::new([n, m]), data)
596}
597
598pub fn diag<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
608 let shape = a.shape();
609 match shape.len() {
610 1 => {
611 let n = shape[0];
613 let size = n + k.unsigned_abs();
614 let mut data = vec![T::zero(); size * size];
615 let src: Vec<T> = a.iter().cloned().collect();
616 for (i, val) in src.into_iter().enumerate() {
617 let row = if k >= 0 { i } else { i + k.unsigned_abs() };
618 let col = if k >= 0 { i + k as usize } else { i };
619 data[row * size + col] = val;
620 }
621 Array::from_vec(IxDyn::new(&[size, size]), data)
622 }
623 2 => {
624 let (n, m) = (shape[0], shape[1]);
626 let src: Vec<T> = a.iter().cloned().collect();
627 let mut diag_vals = Vec::new();
628 for i in 0..n {
629 let j = i as isize + k;
630 if j >= 0 && (j as usize) < m {
631 diag_vals.push(src[i * m + j as usize].clone());
632 }
633 }
634 let len = diag_vals.len();
635 Array::from_vec(IxDyn::new(&[len]), diag_vals)
636 }
637 _ => Err(FerrayError::invalid_value("diag: input must be 1-D or 2-D")),
638 }
639}
640
641pub fn diagflat<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
648 let flat: Vec<T> = a.iter().cloned().collect();
650 let n = flat.len();
651 let arr1d = Array::from_vec(IxDyn::new(&[n]), flat)?;
652 diag(&arr1d, k)
653}
654
655pub fn tri<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
661 let mut data = vec![T::zero(); n * m];
662 for i in 0..n {
663 for j in 0..m {
664 if (i as isize) >= (j as isize) - k {
665 data[i * m + j] = T::one();
666 }
667 }
668 }
669 Array::from_vec(Ix2::new([n, m]), data)
670}
671
672pub fn tril<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
681 let shape = a.shape();
682 if shape.len() != 2 {
683 return Err(FerrayError::invalid_value("tril: input must be 2-D"));
684 }
685 let (n, m) = (shape[0], shape[1]);
686 let src: Vec<T> = a.iter().cloned().collect();
687 let mut data = vec![T::zero(); n * m];
688 for i in 0..n {
689 for j in 0..m {
690 if (i as isize) >= (j as isize) - k {
691 data[i * m + j] = src[i * m + j].clone();
692 }
693 }
694 }
695 Array::from_vec(IxDyn::new(&[n, m]), data)
696}
697
698pub fn triu<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
707 let shape = a.shape();
708 if shape.len() != 2 {
709 return Err(FerrayError::invalid_value("triu: input must be 2-D"));
710 }
711 let (n, m) = (shape[0], shape[1]);
712 let src: Vec<T> = a.iter().cloned().collect();
713 let mut data = vec![T::zero(); n * m];
714 for i in 0..n {
715 for j in 0..m {
716 if (i as isize) <= (j as isize) - k {
717 data[i * m + j] = src[i * m + j].clone();
718 }
719 }
720 }
721 Array::from_vec(IxDyn::new(&[n, m]), data)
722}
723
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, IxDyn};

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (tolerance {tol}), got {actual}"
        );
    }

    // --- construction from data ---

    #[test]
    fn test_array_creation() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = array(Ix2::new([2, 3]), values).unwrap();
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.size(), 6);
    }

    #[test]
    fn test_asarray() {
        let a = asarray(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn test_frombuffer() {
        // Little-endian byte encodings of the i32 values 1, 2 and 3.
        let bytes: Vec<u8> = vec![1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0];
        let a = frombuffer::<i32, Ix1>(Ix1::new([3]), &bytes).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn test_frombuffer_bad_length() {
        // Three bytes are not a whole number of i32 elements.
        let bytes: Vec<u8> = vec![1, 0, 0];
        let outcome = frombuffer::<i32, Ix1>(Ix1::new([1]), &bytes);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_fromiter() {
        let a = fromiter((0..5).map(|x| x as f64)).unwrap();
        assert_eq!(a.shape(), &[5]);
        assert_eq!(a.as_slice().unwrap(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    // --- filled arrays ---

    #[test]
    fn test_zeros() {
        let a = zeros::<f64, Ix2>(Ix2::new([3, 4])).unwrap();
        assert_eq!(a.shape(), &[3, 4]);
        assert!(a.iter().copied().all(|v| v == 0.0));
    }

    #[test]
    fn test_ones() {
        let a = ones::<f64, Ix1>(Ix1::new([5])).unwrap();
        assert!(a.iter().copied().all(|v| v == 1.0));
    }

    #[test]
    fn test_full() {
        let a = full(Ix1::new([4]), 42i32).unwrap();
        assert!(a.iter().copied().all(|v| v == 42));
    }

    #[test]
    fn test_zeros_like() {
        let template = ones::<f64, Ix2>(Ix2::new([2, 3])).unwrap();
        let b = zeros_like(&template).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert!(b.iter().copied().all(|v| v == 0.0));
    }

    #[test]
    fn test_ones_like() {
        let template = zeros::<f64, Ix1>(Ix1::new([4])).unwrap();
        let b = ones_like(&template).unwrap();
        assert!(b.iter().copied().all(|v| v == 1.0));
    }

    #[test]
    fn test_full_like() {
        let template = zeros::<i32, Ix1>(Ix1::new([3])).unwrap();
        let b = full_like(&template, 7).unwrap();
        assert!(b.iter().copied().all(|v| v == 7));
    }

    // --- uninitialized arrays ---

    #[test]
    fn test_empty_and_init() {
        let mut u = empty::<f64, Ix1>(Ix1::new([3]));
        assert_eq!(u.shape(), &[3]);
        for (slot, value) in [1.0, 2.0, 3.0].iter().copied().enumerate() {
            u.write_at(slot, value).unwrap();
        }
        // SAFETY: every one of the three slots was written above.
        let a = unsafe { u.assume_init() };
        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_empty_write_oob() {
        let mut u = empty::<f64, Ix1>(Ix1::new([2]));
        assert!(u.write_at(5, 1.0).is_err());
    }

    // --- ranges ---

    #[test]
    fn test_arange_int() {
        let a = arange(0i32, 5, 1).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_arange_float() {
        let a = arange(0.0_f64, 1.0, 0.25).unwrap();
        assert_eq!(a.shape(), &[4]);
        let data = a.as_slice().unwrap();
        for (i, &expected) in [0.0, 0.25, 0.5, 0.75].iter().enumerate() {
            assert_close(data[i], expected, 1e-10);
        }
    }

    #[test]
    fn test_arange_negative_step() {
        let a = arange(5.0_f64, 0.0, -1.0).unwrap();
        assert_eq!(a.shape(), &[5]);
    }

    #[test]
    fn test_arange_zero_step() {
        assert!(arange(0.0_f64, 1.0, 0.0).is_err());
    }

    #[test]
    fn test_arange_empty() {
        let a = arange(5i32, 0, 1).unwrap();
        assert_eq!(a.shape(), &[0]);
    }

    #[test]
    fn test_linspace() {
        let a = linspace(0.0_f64, 1.0, 5, true).unwrap();
        assert_eq!(a.shape(), &[5]);
        let data = a.as_slice().unwrap();
        assert_close(data[0], 0.0, 1e-10);
        assert_close(data[4], 1.0, 1e-10);
        assert_close(data[2], 0.5, 1e-10);
    }

    #[test]
    fn test_linspace_no_endpoint() {
        let a = linspace(0.0_f64, 1.0, 4, false).unwrap();
        assert_eq!(a.shape(), &[4]);
        let data = a.as_slice().unwrap();
        assert_close(data[0], 0.0, 1e-10);
        assert_close(data[1], 0.25, 1e-10);
    }

    #[test]
    fn test_linspace_single() {
        let a = linspace(5.0_f64, 10.0, 1, true).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[5.0]);
    }

    #[test]
    fn test_linspace_empty() {
        let a = linspace(0.0_f64, 1.0, 0, true).unwrap();
        assert_eq!(a.shape(), &[0]);
    }

    #[test]
    fn test_logspace() {
        let a = logspace(0.0_f64, 2.0, 3, true, 10.0).unwrap();
        let data = a.as_slice().unwrap();
        for (i, &expected) in [1.0, 10.0, 100.0].iter().enumerate() {
            assert_close(data[i], expected, 1e-10);
        }
    }

    #[test]
    fn test_geomspace() {
        let a = geomspace(1.0_f64, 1000.0, 4, true).unwrap();
        let data = a.as_slice().unwrap();
        let expected = [(1.0, 1e-10), (10.0, 1e-8), (100.0, 1e-6), (1000.0, 1e-4)];
        for (i, &(want, tol)) in expected.iter().enumerate() {
            assert_close(data[i], want, tol);
        }
    }

    #[test]
    fn test_geomspace_zero_start() {
        assert!(geomspace(0.0_f64, 1.0, 5, true).is_err());
    }

    #[test]
    fn test_geomspace_different_signs() {
        assert!(geomspace(-1.0_f64, 1.0, 5, true).is_err());
    }

    // --- grids ---

    #[test]
    fn test_meshgrid_xy() {
        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
        let grids = meshgrid(&[x, y], "xy").unwrap();
        assert_eq!(grids.len(), 2);
        assert_eq!(grids[0].shape(), &[2, 3]);
        assert_eq!(grids[1].shape(), &[2, 3]);
        // X repeats the x coordinates along each row.
        let xs: Vec<f64> = grids[0].iter().copied().collect();
        assert_eq!(xs, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
        // Y is constant along each row.
        let ys: Vec<f64> = grids[1].iter().copied().collect();
        assert_eq!(ys, vec![4.0, 4.0, 4.0, 5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_meshgrid_ij() {
        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
        let grids = meshgrid(&[x, y], "ij").unwrap();
        assert_eq!(grids.len(), 2);
        assert_eq!(grids[0].shape(), &[3, 2]);
        assert_eq!(grids[1].shape(), &[3, 2]);
    }

    #[test]
    fn test_meshgrid_bad_indexing() {
        assert!(meshgrid(&[], "zz").is_err());
    }

    #[test]
    fn test_mgrid() {
        let grids = mgrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
        assert_eq!(grids.len(), 2);
        assert_eq!(grids[0].shape(), &[3, 2]);
    }

    #[test]
    fn test_ogrid() {
        let grids = ogrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
        assert_eq!(grids.len(), 2);
        assert_eq!(grids[0].shape(), &[3, 1]);
        assert_eq!(grids[1].shape(), &[1, 2]);
    }

    // --- matrices ---

    #[test]
    fn test_identity() {
        let a = identity::<f64>(3).unwrap();
        assert_eq!(a.shape(), &[3, 3]);
        let expected = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        assert_eq!(a.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_eye() {
        let a = eye::<f64>(3, 4, 0).unwrap();
        assert_eq!(a.shape(), &[3, 4]);
        let expected = [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        assert_eq!(a.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_eye_positive_k() {
        let a = eye::<f64>(3, 3, 1).unwrap();
        let expected = [0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        assert_eq!(a.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_eye_negative_k() {
        let a = eye::<f64>(3, 3, -1).unwrap();
        let expected = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        assert_eq!(a.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_diag_from_1d() {
        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
        let d = diag(&a, 0).unwrap();
        assert_eq!(d.shape(), &[3, 3]);
        let values: Vec<f64> = d.iter().copied().collect();
        assert_eq!(values, vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
    }

    #[test]
    fn test_diag_from_2d() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0],
        )
        .unwrap();
        let d = diag(&a, 0).unwrap();
        assert_eq!(d.shape(), &[3]);
        let values: Vec<f64> = d.iter().copied().collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_diag_k_positive() {
        let a = Array::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let d = diag(&a, 1).unwrap();
        assert_eq!(d.shape(), &[3, 3]);
        let values: Vec<f64> = d.iter().copied().collect();
        assert_eq!(values, vec![0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_diagflat() {
        let a = Array::from_vec(IxDyn::new(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let d = diagflat(&a, 0).unwrap();
        assert_eq!(d.shape(), &[4, 4]);
        // Reading the main diagonal back recovers the flattened input.
        let recovered: Vec<f64> = diag(&d, 0).unwrap().iter().copied().collect();
        assert_eq!(recovered, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_tri() {
        let a = tri::<f64>(3, 3, 0).unwrap();
        let expected = [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0];
        assert_eq!(a.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_tril() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .unwrap();
        let values: Vec<f64> = tril(&a, 0).unwrap().iter().copied().collect();
        assert_eq!(values, vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_triu() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .unwrap();
        let values: Vec<f64> = triu(&a, 0).unwrap().iter().copied().collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]);
    }

    #[test]
    fn test_tril_not_2d() {
        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(tril(&a, 0).is_err());
    }

    #[test]
    fn test_triu_not_2d() {
        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(triu(&a, 0).is_err());
    }
}