1use std::mem::MaybeUninit;
7
8use crate::array::owned::Array;
9use crate::array::view::ArrayView;
10use crate::dimension::{Dimension, Ix1, Ix2, IxDyn};
11use crate::dtype::Element;
12use crate::error::{FerrayError, FerrayResult};
13
14pub fn array<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
27 Array::from_vec(dim, data)
28}
29
30pub fn asarray<T: Element, D: Dimension>(dim: D, data: Vec<T>) -> FerrayResult<Array<T, D>> {
38 Array::from_vec(dim, data)
39}
40
41pub fn frombuffer<T: Element, D: Dimension>(dim: D, buf: &[u8]) -> FerrayResult<Array<T, D>> {
49 let elem_size = std::mem::size_of::<T>();
50 if elem_size == 0 {
51 return Err(FerrayError::invalid_value("zero-sized type"));
52 }
53 if buf.len() % elem_size != 0 {
54 return Err(FerrayError::invalid_value(format!(
55 "buffer length {} is not a multiple of element size {}",
56 buf.len(),
57 elem_size,
58 )));
59 }
60 let n_elems = buf.len() / elem_size;
61 let expected = dim.size();
62 if n_elems != expected {
63 return Err(FerrayError::shape_mismatch(format!(
64 "buffer contains {} elements but shape {:?} requires {}",
65 n_elems,
66 dim.as_slice(),
67 expected,
68 )));
69 }
70 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<bool>() {
73 for &byte in buf {
74 if byte > 1 {
75 return Err(FerrayError::invalid_value(format!(
76 "invalid byte {byte:#04x} for bool (must be 0x00 or 0x01)"
77 )));
78 }
79 }
80 }
81
82 let mut data = Vec::with_capacity(n_elems);
84 for i in 0..n_elems {
85 let start = i * elem_size;
86 let end = start + elem_size;
87 let slice = &buf[start..end];
88 let val = unsafe {
92 let mut val = MaybeUninit::<T>::uninit();
93 std::ptr::copy_nonoverlapping(slice.as_ptr(), val.as_mut_ptr() as *mut u8, elem_size);
94 val.assume_init()
95 };
96 data.push(val);
97 }
98 Array::from_vec(dim, data)
99}
100
101pub fn frombuffer_view<'a, T: Element, D: Dimension>(
118 dim: D,
119 buf: &'a [u8],
120) -> FerrayResult<ArrayView<'a, T, D>> {
121 let elem_size = std::mem::size_of::<T>();
122 if elem_size == 0 {
123 return Err(FerrayError::invalid_value("zero-sized type"));
124 }
125 if buf.len() % elem_size != 0 {
126 return Err(FerrayError::invalid_value(format!(
127 "buffer length {} is not a multiple of element size {}",
128 buf.len(),
129 elem_size,
130 )));
131 }
132 let n_elems = buf.len() / elem_size;
133 let expected = dim.size();
134 if n_elems != expected {
135 return Err(FerrayError::shape_mismatch(format!(
136 "buffer contains {} elements but shape {:?} requires {}",
137 n_elems,
138 dim.as_slice(),
139 expected,
140 )));
141 }
142
143 let align = std::mem::align_of::<T>();
146 let addr = buf.as_ptr() as usize;
147 if addr % align != 0 {
148 return Err(FerrayError::invalid_value(format!(
149 "buffer address 0x{addr:x} is not aligned to {align} bytes required by the element type; \
150 use `frombuffer` for misaligned input"
151 )));
152 }
153
154 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<bool>() {
157 for &byte in buf {
158 if byte > 1 {
159 return Err(FerrayError::invalid_value(format!(
160 "invalid byte {byte:#04x} for bool (must be 0x00 or 0x01)"
161 )));
162 }
163 }
164 }
165
166 let ptr = buf.as_ptr() as *const T;
177 let nd_dim = dim.to_ndarray_dim();
178 let nd_view = unsafe { ndarray::ArrayView::from_shape_ptr(nd_dim, ptr) };
179 Ok(ArrayView::from_ndarray(nd_view))
180}
181
182pub fn fromiter<T: Element>(iter: impl IntoIterator<Item = T>) -> FerrayResult<Array<T, Ix1>> {
189 Array::from_iter_1d(iter)
190}
191
192pub fn zeros<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
196 Array::zeros(dim)
197}
198
199pub fn ones<T: Element, D: Dimension>(dim: D) -> FerrayResult<Array<T, D>> {
203 Array::ones(dim)
204}
205
206pub fn full<T: Element, D: Dimension>(dim: D, fill_value: T) -> FerrayResult<Array<T, D>> {
210 Array::from_elem(dim, fill_value)
211}
212
213pub fn zeros_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
217 Array::zeros(other.dim().clone())
218}
219
220pub fn ones_like<T: Element, D: Dimension>(other: &Array<T, D>) -> FerrayResult<Array<T, D>> {
224 Array::ones(other.dim().clone())
225}
226
227pub fn full_like<T: Element, D: Dimension>(
231 other: &Array<T, D>,
232 fill_value: T,
233) -> FerrayResult<Array<T, D>> {
234 Array::from_elem(other.dim().clone(), fill_value)
235}
236
237pub struct UninitArray<T: Element, D: Dimension> {
246 data: Vec<MaybeUninit<T>>,
247 dim: D,
248}
249
250impl<T: Element, D: Dimension> UninitArray<T, D> {
251 #[inline]
253 pub fn shape(&self) -> &[usize] {
254 self.dim.as_slice()
255 }
256
257 #[inline]
259 pub fn size(&self) -> usize {
260 self.data.len()
261 }
262
263 #[inline]
265 pub fn ndim(&self) -> usize {
266 self.dim.ndim()
267 }
268
269 #[inline]
274 pub fn as_mut_ptr(&mut self) -> *mut MaybeUninit<T> {
275 self.data.as_mut_ptr()
276 }
277
278 pub fn write_at(&mut self, flat_index: usize, value: T) -> FerrayResult<()> {
283 let size = self.size();
284 if flat_index >= size {
285 return Err(FerrayError::IndexOutOfBounds {
286 index: flat_index as isize,
287 axis: 0,
288 size,
289 });
290 }
291 self.data[flat_index] = MaybeUninit::new(value);
292 Ok(())
293 }
294
295 pub unsafe fn assume_init(self) -> Array<T, D> {
302 let nd_dim = self.dim.to_ndarray_dim();
303 let len = self.data.len();
304
305 let mut raw_vec = std::mem::ManuallyDrop::new(self.data);
309 let data: Vec<T> =
310 unsafe { Vec::from_raw_parts(raw_vec.as_mut_ptr() as *mut T, len, raw_vec.capacity()) };
311
312 let inner = ndarray::Array::from_shape_vec(nd_dim, data)
313 .expect("UninitArray assume_init: shape/data mismatch (this is a bug)");
314 Array::from_ndarray(inner)
315 }
316}
317
318pub fn empty<T: Element, D: Dimension>(dim: D) -> UninitArray<T, D> {
326 let size = dim.size();
327 let mut data = Vec::with_capacity(size);
328 unsafe {
331 data.set_len(size);
332 }
333 UninitArray { data, dim }
334}
335
336pub fn empty_like<T: Element, D: Dimension>(other: &Array<T, D>) -> UninitArray<T, D> {
345 empty(other.dim().clone())
346}
347
348pub trait ArangeNum: Element + PartialOrd {
355 fn from_f64(v: f64) -> Self;
357 fn to_f64(self) -> f64;
359}
360
361macro_rules! impl_arange_int {
362 ($($ty:ty),*) => {
363 $(
364 impl ArangeNum for $ty {
365 #[inline]
366 fn from_f64(v: f64) -> Self { v as Self }
367 #[inline]
368 fn to_f64(self) -> f64 { self as f64 }
369 }
370 )*
371 };
372}
373
374macro_rules! impl_arange_float {
375 ($($ty:ty),*) => {
376 $(
377 impl ArangeNum for $ty {
378 #[inline]
379 fn from_f64(v: f64) -> Self { v as Self }
380 #[inline]
381 fn to_f64(self) -> f64 { self as f64 }
382 }
383 )*
384 };
385}
386
387impl_arange_int!(u8, u16, u32, u64, i8, i16, i32, i64);
388impl_arange_float!(f32, f64);
389
390pub fn arange<T: ArangeNum>(start: T, stop: T, step: T) -> FerrayResult<Array<T, Ix1>> {
397 let step_f = step.to_f64();
398 if step_f == 0.0 {
399 return Err(FerrayError::invalid_value("step cannot be zero"));
400 }
401 let start_f = start.to_f64();
402 let stop_f = stop.to_f64();
403 let n = ((stop_f - start_f) / step_f).ceil();
404 let n = if n < 0.0 { 0 } else { n as usize };
405
406 let mut data = Vec::with_capacity(n);
407 for i in 0..n {
408 data.push(T::from_f64(start_f + (i as f64) * step_f));
409 }
410 let dim = Ix1::new([data.len()]);
411 Array::from_vec(dim, data)
412}
413
414pub trait LinspaceNum: Element + PartialOrd {
416 fn from_f64(v: f64) -> Self;
418 fn to_f64(self) -> f64;
420}
421
422impl LinspaceNum for f32 {
423 #[inline]
424 fn from_f64(v: f64) -> Self {
425 v as f32
426 }
427 #[inline]
428 fn to_f64(self) -> f64 {
429 self as f64
430 }
431}
432
433impl LinspaceNum for f64 {
434 #[inline]
435 fn from_f64(v: f64) -> Self {
436 v
437 }
438 #[inline]
439 fn to_f64(self) -> f64 {
440 self
441 }
442}
443
444pub fn linspace<T: LinspaceNum>(
455 start: T,
456 stop: T,
457 num: usize,
458 endpoint: bool,
459) -> FerrayResult<Array<T, Ix1>> {
460 if num == 0 {
461 return Array::from_vec(Ix1::new([0]), vec![]);
462 }
463 if num == 1 {
464 return Array::from_vec(Ix1::new([1]), vec![start]);
465 }
466 let start_f = start.to_f64();
467 let stop_f = stop.to_f64();
468 let divisor = if endpoint {
469 (num - 1) as f64
470 } else {
471 num as f64
472 };
473 let step = (stop_f - start_f) / divisor;
474 let mut data = Vec::with_capacity(num);
475 for i in 0..num {
476 data.push(T::from_f64(start_f + (i as f64) * step));
477 }
478 Array::from_vec(Ix1::new([num]), data)
479}
480
481pub fn logspace<T: LinspaceNum>(
490 start: T,
491 stop: T,
492 num: usize,
493 endpoint: bool,
494 base: f64,
495) -> FerrayResult<Array<T, Ix1>> {
496 let lin = linspace(start, stop, num, endpoint)?;
497 let data: Vec<T> = lin
498 .iter()
499 .map(|v| T::from_f64(base.powf(v.clone().to_f64())))
500 .collect();
501 Array::from_vec(Ix1::new([num]), data)
502}
503
504pub fn geomspace<T: LinspaceNum>(
514 start: T,
515 stop: T,
516 num: usize,
517 endpoint: bool,
518) -> FerrayResult<Array<T, Ix1>> {
519 let start_f = start.clone().to_f64();
520 let stop_f = stop.clone().to_f64();
521 if start_f == 0.0 || stop_f == 0.0 {
522 return Err(FerrayError::invalid_value(
523 "geomspace: start and stop must be non-zero",
524 ));
525 }
526 if (start_f < 0.0) != (stop_f < 0.0) {
527 return Err(FerrayError::invalid_value(
528 "geomspace: start and stop must have the same sign",
529 ));
530 }
531 if num == 0 {
532 return Array::from_vec(Ix1::new([0]), vec![]);
533 }
534 if num == 1 {
535 return Array::from_vec(Ix1::new([1]), vec![start]);
536 }
537 let log_start = start_f.abs().ln();
538 let log_stop = stop_f.abs().ln();
539 let sign = if start_f < 0.0 { -1.0 } else { 1.0 };
540 let divisor = if endpoint {
541 (num - 1) as f64
542 } else {
543 num as f64
544 };
545 let step = (log_stop - log_start) / divisor;
546 let mut data = Vec::with_capacity(num);
547 for i in 0..num {
548 let log_val = log_start + (i as f64) * step;
549 data.push(T::from_f64(sign * log_val.exp()));
550 }
551 Array::from_vec(Ix1::new([num]), data)
552}
553
554pub fn meshgrid(
568 arrays: &[Array<f64, Ix1>],
569 indexing: &str,
570) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
571 if indexing != "xy" && indexing != "ij" {
572 return Err(FerrayError::invalid_value(
573 "meshgrid: indexing must be 'xy' or 'ij'",
574 ));
575 }
576 let ndim = arrays.len();
577 if ndim == 0 {
578 return Ok(vec![]);
579 }
580
581 let mut shapes: Vec<usize> = arrays.iter().map(|a| a.shape()[0]).collect();
582 if indexing == "xy" && ndim >= 2 {
583 shapes.swap(0, 1);
584 }
585
586 let total: usize = shapes.iter().product();
587 let mut results = Vec::with_capacity(ndim);
588
589 for (k, arr) in arrays.iter().enumerate() {
590 let src_data: Vec<f64> = arr.iter().copied().collect();
591 let mut data = Vec::with_capacity(total);
592 let effective_k = if indexing == "xy" && ndim >= 2 {
594 match k {
595 0 => 1,
596 1 => 0,
597 other => other,
598 }
599 } else {
600 k
601 };
602
603 for flat in 0..total {
605 let mut rem = flat;
607 let mut idx_k = 0;
608 for (d, &s) in shapes.iter().enumerate().rev() {
609 if d == effective_k {
610 idx_k = rem % s;
611 }
612 rem /= s;
613 }
614 data.push(src_data[idx_k]);
615 }
616
617 let dim = IxDyn::new(&shapes);
618 results.push(Array::from_vec(dim, data)?);
619 }
620 Ok(results)
621}
622
623pub fn mgrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
633 let mut arrs: Vec<Array<f64, Ix1>> = Vec::with_capacity(ranges.len());
634 for &(start, stop, step) in ranges {
635 arrs.push(arange(start, stop, step)?);
636 }
637 meshgrid(&arrs, "ij")
638}
639
640pub fn ogrid(ranges: &[(f64, f64, f64)]) -> FerrayResult<Vec<Array<f64, IxDyn>>> {
650 let ndim = ranges.len();
651 let mut results = Vec::with_capacity(ndim);
652 for (i, &(start, stop, step)) in ranges.iter().enumerate() {
653 let arr1d = arange(start, stop, step)?;
654 let n = arr1d.shape()[0];
655 let data: Vec<f64> = arr1d.iter().copied().collect();
656 let mut shape = vec![1usize; ndim];
658 shape[i] = n;
659 let dim = IxDyn::new(&shape);
660 results.push(Array::from_vec(dim, data)?);
661 }
662 Ok(results)
663}
664
665pub fn identity<T: Element>(n: usize) -> FerrayResult<Array<T, Ix2>> {
673 eye(n, n, 0)
674}
675
676pub fn eye<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
682 let mut data = vec![T::zero(); n * m];
683 for i in 0..n {
684 let j = i as isize + k;
685 if j >= 0 && (j as usize) < m {
686 data[i * m + j as usize] = T::one();
687 }
688 }
689 Array::from_vec(Ix2::new([n, m]), data)
690}
691
692pub fn diag<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
702 let shape = a.shape();
703 match shape.len() {
704 1 => {
705 let n = shape[0];
707 let size = n + k.unsigned_abs();
708 let mut data = vec![T::zero(); size * size];
709 let src: Vec<T> = a.iter().cloned().collect();
710 for (i, val) in src.into_iter().enumerate() {
711 let row = if k >= 0 { i } else { i + k.unsigned_abs() };
712 let col = if k >= 0 { i + k as usize } else { i };
713 data[row * size + col] = val;
714 }
715 Array::from_vec(IxDyn::new(&[size, size]), data)
716 }
717 2 => {
718 let (n, m) = (shape[0], shape[1]);
720 let src: Vec<T> = a.iter().cloned().collect();
721 let mut diag_vals = Vec::new();
722 for i in 0..n {
723 let j = i as isize + k;
724 if j >= 0 && (j as usize) < m {
725 diag_vals.push(src[i * m + j as usize].clone());
726 }
727 }
728 let len = diag_vals.len();
729 Array::from_vec(IxDyn::new(&[len]), diag_vals)
730 }
731 _ => Err(FerrayError::invalid_value("diag: input must be 1-D or 2-D")),
732 }
733}
734
735pub fn diagflat<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
742 let flat: Vec<T> = a.iter().cloned().collect();
744 let n = flat.len();
745 let arr1d = Array::from_vec(IxDyn::new(&[n]), flat)?;
746 diag(&arr1d, k)
747}
748
749pub fn tri<T: Element>(n: usize, m: usize, k: isize) -> FerrayResult<Array<T, Ix2>> {
755 let mut data = vec![T::zero(); n * m];
756 for i in 0..n {
757 for j in 0..m {
758 if (i as isize) >= (j as isize) - k {
759 data[i * m + j] = T::one();
760 }
761 }
762 }
763 Array::from_vec(Ix2::new([n, m]), data)
764}
765
766pub fn tril<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
775 let shape = a.shape();
776 if shape.len() != 2 {
777 return Err(FerrayError::invalid_value("tril: input must be 2-D"));
778 }
779 let (n, m) = (shape[0], shape[1]);
780 let src: Vec<T> = a.iter().cloned().collect();
781 let mut data = vec![T::zero(); n * m];
782 for i in 0..n {
783 for j in 0..m {
784 if (i as isize) >= (j as isize) - k {
785 data[i * m + j] = src[i * m + j].clone();
786 }
787 }
788 }
789 Array::from_vec(IxDyn::new(&[n, m]), data)
790}
791
792pub fn triu<T: Element>(a: &Array<T, IxDyn>, k: isize) -> FerrayResult<Array<T, IxDyn>> {
801 let shape = a.shape();
802 if shape.len() != 2 {
803 return Err(FerrayError::invalid_value("triu: input must be 2-D"));
804 }
805 let (n, m) = (shape[0], shape[1]);
806 let src: Vec<T> = a.iter().cloned().collect();
807 let mut data = vec![T::zero(); n * m];
808 for i in 0..n {
809 for j in 0..m {
810 if (i as isize) <= (j as isize) - k {
811 data[i * m + j] = src[i * m + j].clone();
812 }
813 }
814 }
815 Array::from_vec(IxDyn::new(&[n, m]), data)
816}
817
// Unit tests for the array-creation routines in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, IxDyn};

    // ---- construction from existing data ----

    #[test]
    fn test_array_creation() {
        let a = array(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.size(), 6);
    }

    #[test]
    fn test_asarray() {
        let a = asarray(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn test_frombuffer() {
        // NOTE(review): these bytes decode to 1, 2, 3 only on a little-endian host, because
        // `frombuffer` copies bytes verbatim (native byte order). Consider `to_ne_bytes`.
        let bytes: Vec<u8> = vec![1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0];
        let a = frombuffer::<i32, Ix1>(Ix1::new([3]), &bytes).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn test_frombuffer_bad_length() {
        // 3 bytes is not a whole number of 4-byte i32 elements.
        let bytes: Vec<u8> = vec![1, 0, 0];
        assert!(frombuffer::<i32, Ix1>(Ix1::new([1]), &bytes).is_err());
    }

    #[test]
    fn test_frombuffer_bool() {
        // 0x00 / 0x01 are the only accepted bool bytes.
        let bytes: Vec<u8> = vec![0, 1, 0, 1, 1];
        let a = frombuffer::<bool, Ix1>(Ix1::new([5]), &bytes).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[false, true, false, true, true]);
    }

    #[test]
    fn test_frombuffer_bool_wrong_length() {
        // Two bytes cannot fill a 3-element shape.
        let bytes: Vec<u8> = vec![0, 1];
        assert!(frombuffer::<bool, Ix1>(Ix1::new([3]), &bytes).is_err());
    }

    /// Copy a typed slice into a fresh byte vector (native byte order).
    ///
    /// NOTE(review): despite the name, a `Vec<u8>` is only guaranteed 1-byte alignment. The
    /// `frombuffer_view` tests below rely on the global allocator happening to return memory
    /// aligned for the element type.
    fn aligned_bytes<T: Copy>(src: &[T]) -> Vec<u8> {
        let n = std::mem::size_of_val(src);
        let mut out = vec![0u8; n];
        // SAFETY: `src` spans exactly `n` readable bytes, `out` has `n` writable bytes, and
        // the two allocations are distinct.
        unsafe {
            std::ptr::copy_nonoverlapping(src.as_ptr() as *const u8, out.as_mut_ptr(), n);
        }
        out
    }

    #[test]
    fn test_frombuffer_view_i32_is_zero_copy() {
        let source: Vec<i32> = vec![10, 20, 30];
        let bytes = aligned_bytes(&source);
        let view = frombuffer_view::<i32, Ix1>(Ix1::new([3]), &bytes).unwrap();
        assert_eq!(view.shape(), &[3]);
        let values: Vec<i32> = view.iter().copied().collect();
        assert_eq!(values, vec![10, 20, 30]);
        // The view must point into `bytes` itself, proving no copy was made.
        assert_eq!(view.as_ptr() as *const u8, bytes.as_ptr());
    }

    #[test]
    fn test_frombuffer_view_f64_2d() {
        let source: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bytes = aligned_bytes(&source);
        let view = frombuffer_view::<f64, Ix2>(Ix2::new([2, 3]), &bytes).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
        let values: Vec<f64> = view.iter().copied().collect();
        assert_eq!(values, source);
    }

    #[test]
    fn test_frombuffer_view_bool_valid() {
        let bytes: Vec<u8> = vec![0, 1, 0, 1];
        let view = frombuffer_view::<bool, Ix1>(Ix1::new([4]), &bytes).unwrap();
        let values: Vec<bool> = view.iter().copied().collect();
        assert_eq!(values, vec![false, true, false, true]);
    }

    #[test]
    fn test_frombuffer_view_bool_rejects_invalid_byte() {
        // 42 is not a valid bool byte.
        let bytes: Vec<u8> = vec![0, 1, 42];
        assert!(frombuffer_view::<bool, Ix1>(Ix1::new([3]), &bytes).is_err());
    }

    #[test]
    fn test_frombuffer_view_rejects_wrong_length() {
        // 13 bytes: not a multiple of size_of::<i32>().
        let bytes = vec![0u8; 13];
        assert!(frombuffer_view::<i32, Ix1>(Ix1::new([3]), &bytes).is_err());
        // 8 bytes: a whole number of i32s, but only 2 of the 3 the shape needs.
        let bytes = vec![0u8; 8];
        assert!(frombuffer_view::<i32, Ix1>(Ix1::new([3]), &bytes).is_err());
    }

    #[test]
    fn test_frombuffer_view_rejects_misalignment() {
        // Skip the first byte of the allocation so the i32 payload starts misaligned.
        let mut backing: Vec<u8> = vec![0u8; 1 + 4 * 3];
        for (i, chunk) in backing[1..].chunks_exact_mut(4).enumerate() {
            chunk.copy_from_slice(&(i as i32).to_ne_bytes());
        }
        let misaligned = &backing[1..];
        // NOTE(review): holds when the allocator returns a 4-aligned base, which is typical
        // but not guaranteed for a `Vec<u8>`.
        assert!((misaligned.as_ptr() as usize) % 4 != 0);
        assert!(frombuffer_view::<i32, Ix1>(Ix1::new([3]), misaligned).is_err());
    }

    #[test]
    fn test_fromiter() {
        let a = fromiter((0..5).map(|x| x as f64)).unwrap();
        assert_eq!(a.shape(), &[5]);
        assert_eq!(a.as_slice().unwrap(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    // ---- filled constructors ----

    #[test]
    fn test_zeros() {
        let a = zeros::<f64, Ix2>(Ix2::new([3, 4])).unwrap();
        assert_eq!(a.shape(), &[3, 4]);
        assert!(a.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_ones() {
        let a = ones::<f64, Ix1>(Ix1::new([5])).unwrap();
        assert!(a.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_full() {
        let a = full(Ix1::new([4]), 42i32).unwrap();
        assert!(a.iter().all(|&v| v == 42));
    }

    #[test]
    fn test_zeros_like() {
        let a = ones::<f64, Ix2>(Ix2::new([2, 3])).unwrap();
        let b = zeros_like(&a).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert!(b.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_ones_like() {
        let a = zeros::<f64, Ix1>(Ix1::new([4])).unwrap();
        let b = ones_like(&a).unwrap();
        assert!(b.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_full_like() {
        let a = zeros::<i32, Ix1>(Ix1::new([3])).unwrap();
        let b = full_like(&a, 7).unwrap();
        assert!(b.iter().all(|&v| v == 7));
    }

    // ---- uninitialised arrays ----

    #[test]
    fn test_empty_and_init() {
        let mut u = empty::<f64, Ix1>(Ix1::new([3]));
        assert_eq!(u.shape(), &[3]);
        u.write_at(0, 1.0).unwrap();
        u.write_at(1, 2.0).unwrap();
        u.write_at(2, 3.0).unwrap();
        // SAFETY: all three slots were written above.
        let a = unsafe { u.assume_init() };
        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_empty_write_oob() {
        let mut u = empty::<f64, Ix1>(Ix1::new([2]));
        assert!(u.write_at(5, 1.0).is_err());
    }

    #[test]
    fn test_empty_like_matches_shape_2d() {
        use crate::dimension::Ix2;
        let src = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let mut u = empty_like(&src);
        assert_eq!(u.shape(), &[2, 3]);
        assert_eq!(u.size(), 6);
        assert_eq!(u.ndim(), 2);

        // Fill every slot with a value distinct from `src` so copying would be detected.
        for i in 0..6 {
            u.write_at(i, -(i as f64)).unwrap();
        }
        // SAFETY: every one of the 6 slots was written in the loop above.
        let out = unsafe { u.assume_init() };
        assert_eq!(out.shape(), &[2, 3]);
        assert_eq!(
            out.as_slice().unwrap(),
            &[0.0, -1.0, -2.0, -3.0, -4.0, -5.0]
        );
        // The source array is untouched.
        assert_eq!(src.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_empty_like_zero_sized() {
        let src = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let u = empty_like(&src);
        assert_eq!(u.shape(), &[0]);
        assert_eq!(u.size(), 0);
        // SAFETY: there are no slots to initialise.
        let out = unsafe { u.assume_init() };
        assert_eq!(out.size(), 0);
    }

    // ---- ranges ----

    #[test]
    fn test_arange_int() {
        let a = arange(0i32, 5, 1).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_arange_float() {
        let a = arange(0.0_f64, 1.0, 0.25).unwrap();
        assert_eq!(a.shape(), &[4]);
        let data = a.as_slice().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-10);
        assert!((data[1] - 0.25).abs() < 1e-10);
        assert!((data[2] - 0.5).abs() < 1e-10);
        assert!((data[3] - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_arange_negative_step() {
        let a = arange(5.0_f64, 0.0, -1.0).unwrap();
        assert_eq!(a.shape(), &[5]);
    }

    #[test]
    fn test_arange_zero_step() {
        assert!(arange(0.0_f64, 1.0, 0.0).is_err());
    }

    #[test]
    fn test_arange_empty() {
        // stop < start with a positive step yields no elements.
        let a = arange(5i32, 0, 1).unwrap();
        assert_eq!(a.shape(), &[0]);
    }

    #[test]
    fn test_linspace() {
        let a = linspace(0.0_f64, 1.0, 5, true).unwrap();
        assert_eq!(a.shape(), &[5]);
        let data = a.as_slice().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-10);
        assert!((data[4] - 1.0).abs() < 1e-10);
        assert!((data[2] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_linspace_no_endpoint() {
        let a = linspace(0.0_f64, 1.0, 4, false).unwrap();
        assert_eq!(a.shape(), &[4]);
        let data = a.as_slice().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-10);
        assert!((data[1] - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_linspace_single() {
        let a = linspace(5.0_f64, 10.0, 1, true).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[5.0]);
    }

    #[test]
    fn test_linspace_empty() {
        let a = linspace(0.0_f64, 1.0, 0, true).unwrap();
        assert_eq!(a.shape(), &[0]);
    }

    #[test]
    fn test_logspace() {
        // 10^0, 10^1, 10^2.
        let a = logspace(0.0_f64, 2.0, 3, true, 10.0).unwrap();
        let data = a.as_slice().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-10);
        assert!((data[1] - 10.0).abs() < 1e-10);
        assert!((data[2] - 100.0).abs() < 1e-10);
    }

    #[test]
    fn test_geomspace() {
        let a = geomspace(1.0_f64, 1000.0, 4, true).unwrap();
        let data = a.as_slice().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-10);
        assert!((data[1] - 10.0).abs() < 1e-8);
        assert!((data[2] - 100.0).abs() < 1e-6);
        assert!((data[3] - 1000.0).abs() < 1e-4);
    }

    #[test]
    fn test_geomspace_zero_start() {
        assert!(geomspace(0.0_f64, 1.0, 5, true).is_err());
    }

    #[test]
    fn test_geomspace_different_signs() {
        assert!(geomspace(-1.0_f64, 1.0, 5, true).is_err());
    }

    // ---- coordinate grids ----

    #[test]
    fn test_meshgrid_xy() {
        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
        let grids = meshgrid(&[x, y], "xy").unwrap();
        assert_eq!(grids.len(), 2);
        // "xy" indexing puts len(y) rows by len(x) columns.
        assert_eq!(grids[0].shape(), &[2, 3]);
        assert_eq!(grids[1].shape(), &[2, 3]);
        // x varies along columns...
        let xdata: Vec<f64> = grids[0].iter().copied().collect();
        assert_eq!(xdata, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
        // ...and y varies along rows.
        let ydata: Vec<f64> = grids[1].iter().copied().collect();
        assert_eq!(ydata, vec![4.0, 4.0, 4.0, 5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_meshgrid_ij() {
        let x = Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let y = Array::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
        let grids = meshgrid(&[x, y], "ij").unwrap();
        assert_eq!(grids.len(), 2);
        // "ij" indexing keeps the input order: len(x) by len(y).
        assert_eq!(grids[0].shape(), &[3, 2]);
        assert_eq!(grids[1].shape(), &[3, 2]);
    }

    #[test]
    fn test_meshgrid_bad_indexing() {
        assert!(meshgrid(&[], "zz").is_err());
    }

    #[test]
    fn test_mgrid() {
        let grids = mgrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
        assert_eq!(grids.len(), 2);
        assert_eq!(grids[0].shape(), &[3, 2]);
    }

    #[test]
    fn test_ogrid() {
        let grids = ogrid(&[(0.0, 3.0, 1.0), (0.0, 2.0, 1.0)]).unwrap();
        assert_eq!(grids.len(), 2);
        // Each output is non-trivial only along its own axis.
        assert_eq!(grids[0].shape(), &[3, 1]);
        assert_eq!(grids[1].shape(), &[1, 2]);
    }

    // ---- matrix constructors ----

    #[test]
    fn test_identity() {
        let a = identity::<f64>(3).unwrap();
        assert_eq!(a.shape(), &[3, 3]);
        let data = a.as_slice().unwrap();
        assert_eq!(data, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_eye() {
        // Non-square: the main diagonal stops at the last row.
        let a = eye::<f64>(3, 4, 0).unwrap();
        assert_eq!(a.shape(), &[3, 4]);
        let data = a.as_slice().unwrap();
        assert_eq!(
            data,
            &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
        );
    }

    #[test]
    fn test_eye_positive_k() {
        let a = eye::<f64>(3, 3, 1).unwrap();
        let data = a.as_slice().unwrap();
        assert_eq!(data, &[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_eye_negative_k() {
        let a = eye::<f64>(3, 3, -1).unwrap();
        let data = a.as_slice().unwrap();
        assert_eq!(data, &[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_diag_from_1d() {
        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
        let d = diag(&a, 0).unwrap();
        assert_eq!(d.shape(), &[3, 3]);
        let data: Vec<f64> = d.iter().copied().collect();
        assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
    }

    #[test]
    fn test_diag_from_2d() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0],
        )
        .unwrap();
        let d = diag(&a, 0).unwrap();
        assert_eq!(d.shape(), &[3]);
        let data: Vec<f64> = d.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_diag_k_positive() {
        // A length-2 vector on the first super-diagonal needs a 3x3 matrix.
        let a = Array::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let d = diag(&a, 1).unwrap();
        assert_eq!(d.shape(), &[3, 3]);
        let data: Vec<f64> = d.iter().copied().collect();
        assert_eq!(data, vec![0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_diagflat() {
        let a = Array::from_vec(IxDyn::new(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let d = diagflat(&a, 0).unwrap();
        assert_eq!(d.shape(), &[4, 4]);
        // Round-trip: extracting the diagonal recovers the flattened input.
        let extracted = diag(&d, 0).unwrap();
        let data: Vec<f64> = extracted.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_tri() {
        let a = tri::<f64>(3, 3, 0).unwrap();
        let data = a.as_slice().unwrap();
        assert_eq!(data, &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_tril() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .unwrap();
        let t = tril(&a, 0).unwrap();
        let data: Vec<f64> = t.iter().copied().collect();
        assert_eq!(data, vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_triu() {
        let a = Array::from_vec(
            IxDyn::new(&[3, 3]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .unwrap();
        let t = triu(&a, 0).unwrap();
        let data: Vec<f64> = t.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]);
    }

    #[test]
    fn test_tril_not_2d() {
        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(tril(&a, 0).is_err());
    }

    #[test]
    fn test_triu_not_2d() {
        let a = Array::from_vec(IxDyn::new(&[3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(triu(&a, 0).is_err());
    }
}