//! Crate root: declares the public module tree of the library.

// Crate-wide lint suppressions.
// NOTE(review): the reason for each `allow` is not recorded here — confirm every one is
// still required and narrow it to the offending items where possible.
#![allow(deprecated)]
#![allow(clippy::result_large_err)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::identity_op)]
#![allow(clippy::approx_constant)]
#![allow(clippy::excessive_precision)]

// Public modules. Each `pub mod name;` pulls in the file or directory of the same name.
pub mod algorithms;
pub mod array;
pub mod array_ops;
pub mod array_ops_legacy;
pub mod arrays;
// Compiled only when the `arrow` feature is enabled.
#[cfg(feature = "arrow")]
pub mod arrow;
pub mod autodiff;
pub mod axis_ops;
pub mod bitwise_ops;
pub mod blas;
pub mod char;
pub mod cluster;
pub mod comparisons;
pub mod comparisons_broadcast;
pub mod complex_ops;
pub mod conversions;
pub mod derivative;
pub mod distance;
pub mod error;
pub mod error_handling;
pub mod expr;
pub mod fft;
pub mod financial;
// Compiled only when the `gpu` feature is enabled.
#[cfg(feature = "gpu")]
pub mod gpu;
pub mod indexing;
pub mod integrate;
pub mod interop;
pub mod interpolate;
pub mod io;
pub mod linalg;
pub mod linalg_extended;
pub mod linalg_optimized;
pub mod linalg_parallel;
pub mod optimized_ops;
pub mod linalg_stable;
pub mod masked;
pub mod math;
pub mod math_extended;
pub mod matrix;
pub mod memory_alloc;
pub mod memory_optimize;
pub mod mmap;
pub mod ndimage;
pub mod ode;
pub mod optimize;
pub mod parallel;
pub mod parallel_optimize;
pub mod pde;
pub mod printing;
// Compiled only when the `python` feature is enabled.
#[cfg(feature = "python")]
pub mod python;
pub mod random;
pub mod roots;
pub mod set_ops;
pub mod signal;
pub mod simd;
pub mod simd_optimize;
pub mod sparse;
pub mod sparse_enhanced;
pub mod spatial;
pub mod special;
pub mod stats;
pub mod stride_tricks;
pub mod testing;
pub mod traits;
pub mod types;
pub mod ufuncs;
pub mod unique;
pub mod unique_optimized;
pub mod util;
pub mod views;
141
/// Namespace grouping a second set of modules.
///
/// Several names here (`fft`, `sparse`, `special`) also exist as top-level modules of the
/// crate; the two are distinct paths (`crate::fft` vs `crate::new_modules::fft`).
pub mod new_modules {
    pub mod eigenvalues;
    pub mod fft;
    pub mod fft_enhanced;
    pub mod frequency_analysis;
    // Compiled only when the `matrix_decomp` feature is enabled.
    #[cfg(feature = "matrix_decomp")]
    pub mod matrix_decomp;
    pub mod polynomial;
    pub mod signal_processing;
    pub mod sparse;
    pub mod special;
    pub mod spectral_analysis;
}
157
158pub use error::{NumRs2Error, Result};
159
160pub use random::random_base;
162
/// Empty module that exists only when the crate is compiled with `cfg(doctest)`.
// NOTE(review): presumably a hook for attaching external documentation as doctests — confirm.
#[cfg(doctest)]
pub mod doctests {}
166
/// Convenience re-exports of the crate's commonly used items.
///
/// Inside the crate this is consumed as `use crate::prelude::*;` (see the unit tests).
/// Several re-exports are feature-gated (`lapack`, `matrix_decomp`, `gpu`), so the exact set
/// of names brought into scope depends on the enabled features.
pub mod prelude {
    // Core array type and array manipulation.
    pub use crate::array::Array;
    pub use crate::array_ops::*;
    pub use crate::array_ops_legacy::rollaxis;
    pub use crate::axis_ops::*;
    // Named explicitly in addition to the glob above; explicit imports take precedence over
    // glob imports. NOTE(review): presumably this resolves a clash with another glob — confirm.
    pub use crate::axis_ops::{apply_along_axis, apply_over_axes, vectorize};
    pub use crate::bitwise_ops::{
        bitwise_and, bitwise_not, bitwise_or, bitwise_xor, invert, left_shift, left_shift_scalar,
        right_shift, right_shift_scalar,
    };
    // The `char` module itself is re-exported, plus a few of its items.
    pub use crate::char;
    pub use crate::char::{array_from_strings, StringArray, StringElement};
    pub use crate::comparisons::{
        all, allclose, allclose_with_tol, any, array_equal, count_nonzero, equal, flatnonzero,
        greater, greater_equal, isclose, isclose_array, less, less_equal, logical_and, logical_not,
        logical_or, logical_xor, not_equal,
    };
    // Complex helpers are aliased with a `complex_` prefix because `crate::math` re-exports
    // `angle`, `conj`, `imag` and `real` under their plain names further down.
    pub use crate::complex_ops::{
        absolute as complex_abs, angle as complex_angle, conj as complex_conj, from_polar,
        imag as complex_imag, iscomplex, iscomplexobj, isreal, isrealobj, real as complex_real,
        to_complex,
    };
    pub use crate::conversions::*;
    pub use crate::error::{NumRs2Error, Result};
    pub use crate::error_handling::{
        errstate, geterr, geterrcall, handle_error, seterr, seterrcall, ErrorAction, ErrorState,
        ErrorStateBuilder, ErrorStateGuard, FloatingPointError,
    };
    // Items from the `financial` module.
    pub use crate::financial::{
        accrued_interest,
        amortization_schedule,
        binomial_option_price,
        black_scholes,
        black_scholes_greeks,
        bond_convexity,
        bond_duration,
        bond_equivalent_yield,
        bond_price,
        bond_yield,
        cumipmt,
        cumprinc,
        db,
        ddb,
        effect,
        fv,
        fv_array,
        implied_volatility,
        ipmt,
        irr,
        irr_multiple_series,
        mirr,
        modified_duration,
        nominal,
        nper,
        nper_array,
        npv,
        npv_multiple_series,
        npv_rates,
        pmt,
        pmt_array,
        ppmt,
        pv,
        pv_array,
        rate,
        rate_array,
        sln,
        syd,
        AmortizationSchedule,
    };
    // `put` and `putmask` are exported under `indexing_`-prefixed aliases.
    pub use crate::indexing::{
        diag_indices, diag_indices_from, extract, indices_grid, ix_, mask_indices,
        put as indexing_put, put_along_axis, putmask as indexing_putmask, ravel_multi_index, take,
        take_along_axis, tril_indices, tril_indices_from, triu_indices, triu_indices_from,
        unravel_index, IndexSpec,
    };
    pub use crate::io::{array_to_vec2d, vec2d_to_array, vec_to_array, SerializeFormat};
    // Linear algebra. Decompositions, `inv` and `solve` need both `matrix_decomp` and
    // `lapack`; the `_basic` aliases keep them apart from the `new_modules::matrix_decomp`
    // functions of the same name re-exported below.
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    pub use crate::linalg::{
        cholesky as cholesky_basic, eig, inv, qr as qr_basic, solve, svd as svd_basic,
    };
    // `det` and `matrix_power` need only `lapack`.
    #[cfg(feature = "lapack")]
    pub use crate::linalg::{det, matrix_power};
    // Always available, independent of features.
    pub use crate::linalg::{inner, kron, norm, outer, tensordot, trace, vdot};

    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    pub use crate::linalg::{matrix_rank, pinv};
    pub use crate::linalg_extended::eigenvalue;
    pub use crate::linalg_optimized::{lu_optimized, transpose_optimized, OptimizedBlas};
    pub use crate::linalg_parallel::ParallelLinAlg;
    pub use crate::linalg_stable::{
        CholeskyStableResult, QRPivotedResult, SVDStableResult, StableDecompositions,
    };
    pub use crate::masked::MaskedArray;
    // Element-wise math: a first group from `ufuncs` (more `ufuncs` items follow further
    // down), then `math_extended` and `math`.
    pub use crate::ufuncs::{abs, ceil, exp, floor, log, round, sqrt};
    pub use crate::math_extended::{erf, erfc, gamma, gammaln};
    pub use crate::math::{
        amax, amin, angle, arange, argmax, argmin, argpartition, argsort, around, bartlett,
        bincount, blackman, clip, conj, copysign, cumprod, cumsum, cumulative_prod, cumulative_sum,
        diff, diff_extended, digitize, divmod, ediff1d, empty, fmod, frexp, gcd, geomspace,
        gradient, hamming, hanning, heaviside, i0, imag, interp, isfinite, isinf, isnan, kaiser,
        lcm, ldexp, linspace, logspace, max, mean, median, min, modf, nan_to_num, nanmax, nanmean,
        nanmin, nanstd, nansum, nanvar, nextafter, nonzero, ones, partition, prod, real,
        real_if_close, remainder, resize, searchsorted, sinc, sort, std, sum, trapz, var, zeros,
        ElementWiseMath,
    };
    pub use crate::matrix::{
        asmatrix, matrix, matrix_from_nested, matrix_from_scalar, BandedMatrix, Matrix,
    };
    pub use crate::mmap::MmapArray;
    // Random numbers: the `random` module itself, two of its submodules, and generator types.
    pub use crate::random::advanced_distributions;
    pub use crate::random::distributions;
    pub use crate::random::generator::{default_rng, BitGenerator, Generator, StdBitGenerator};
    pub use crate::random::{self, RandomState};
    pub use crate::set_ops::{
        in1d, intersect1d, isin, setdiff1d, setxor1d, union1d, unique_axis, unique_with_options,
    };
    pub use crate::signal::{convolve, convolve2d, correlate, correlate2d};
    pub use crate::simd::{get_simd_implementation, get_simd_implementation_name};
    pub use crate::simd_optimize::{detect_cpu_features, CpuFeatures, SimdImplementation};
    // The top-level `sparse` module is re-exported as a module; individual sparse types come
    // from `new_modules::sparse` below.
    pub use crate::sparse;
    pub use crate::sparse_enhanced::SparseOpsAdvanced;
    pub use crate::stats::{
        average, corrcoef, cov, histogram, histogram_dd, max_along_axis, min_along_axis, mode,
        percentile, ptp, quantile, HistBins, Statistics,
    };
    pub use crate::stride_tricks::{
        as_strided, broadcast_arrays, broadcast_to, byte_strides, set_strides, sliding_window_view,
    };
    pub use crate::testing::{
        arrays_close, assert_array_all_finite, assert_array_almost_equal, assert_array_equal,
        assert_array_no_nan, assert_array_same_shape, assert_scalar_almost_equal, is_finite_array,
        test_summary, tolerances, TestResult, ToleranceConfig,
    };
    // NOTE(review): `run_tests` is resolved at the crate root, not in a module — presumably a
    // `#[macro_export]` macro; confirm.
    pub use crate::run_tests;
    pub use crate::traits::{
        ArrayIndexing, ArrayMath, ArrayOps, ArrayReduction, ComplexElement, FloatingPoint,
        IntegerElement, LinearAlgebra, MatrixDecomposition, NumericElement,
    };
    // Remaining `ufuncs` items (the basic group is re-exported above).
    pub use crate::ufuncs::{
        absolute, add, add_scalar, arctan2, cbrt, divide, divide_scalar, dot, exp2, expm1, fma,
        hypot, log10, log1p, log2, maximum, minimum, multiply, multiply_scalar, negative, norm_l1,
        norm_l2, power, power_scalar, reciprocal, subtract, subtract_scalar, BinaryUfunc,
        UnaryUfunc,
    };
    pub use crate::unique::{unique, UniqueResult};
    pub use crate::unique_optimized::unique_optimized;
    pub use crate::util::{
        astype, can_operate_inplace, fast_sum, optimize_layout, parallel_map, MemoryLayout,
    };
    pub use crate::views::*;

    pub use crate::interop::ndarray_compat::{from_ndarray, to_ndarray};
    // `optimize_layout` is aliased because `crate::util::optimize_layout` is already exported
    // under the plain name above.
    pub use crate::memory_optimize::{
        align_data, optimize_layout as memory_optimize_layout, optimize_placement,
        AlignmentStrategy, LayoutStrategy, PlacementStrategy,
    };

    // `parallel_optimize`: functions first, then configuration types.
    pub use crate::parallel_optimize::{
        adaptive_threshold, optimize_parallel_computation, optimize_scheduling, partition_workload,
    };
    pub use crate::parallel_optimize::{
        ParallelConfig, ParallelizationThreshold, SchedulingStrategy, WorkloadPartitioning,
    };

    pub use crate::printing::{
        array_str, get_printoptions, reset_printoptions, set_printoptions, PrintOptions,
    };

    // `memory_alloc`: global-allocator functions first, then allocator and config types.
    pub use crate::memory_alloc::{
        get_default_allocator, get_global_allocator_strategy, init_global_allocator,
        reset_global_allocator,
    };
    pub use crate::memory_alloc::{
        AlignedAllocator, AlignmentConfig, AllocStrategy, ArenaAllocator, ArenaConfig, CacheConfig,
        CacheLevel, CacheOptimizedAllocator, PoolAllocator, PoolConfig,
    };

    pub use crate::algorithms::{
        BandwidthEstimate, BandwidthOptimizer, CacheAwareArrayOps, CacheAwareConvolution,
        CacheAwareFFT, MemoryOperation,
    };

    // Aliased because `parallel_optimize::ParallelConfig` already occupies the plain name.
    pub use crate::parallel::parallel_algorithms::ParallelConfig as ParallelAlgorithmConfig;
    pub use crate::parallel::{
        global_parallel_context, initialize_parallel_context, shutdown_parallel_context, task,
        BalancingStrategy, LoadBalancer, ParallelAllocator, ParallelAllocatorConfig,
        ParallelArrayOps, ParallelContext, ParallelFFT, ParallelMatrixOps, ParallelScheduler,
        SchedulerConfig, Task, TaskPriority, TaskResult, ThreadLocalAllocator, WorkStealingPool,
        WorkloadMetrics,
    };

    // Further allocator types from `memory_alloc` and allocation-related items from `traits`.
    pub use crate::memory_alloc::{
        EnhancedAllocatorBridge, IntelligentAllocationStrategy, NumericalArrayAllocator,
    };
    pub use crate::traits::{
        AllocationFrequency, AllocationLifetime, AllocationRequirements, AllocationStats,
        AllocationStrategy, MemoryAllocator, MemoryAware, MemoryOptimization, MemoryUsage,
        OptimizationType, SpecializedAllocator, ThreadingRequirements,
    };

    // `new_modules` re-exports. `eig` is aliased to `eig_general` because `linalg::eig` is
    // exported under the plain name above (when its feature gate is active).
    #[cfg(feature = "lapack")]
    pub use crate::new_modules::eigenvalues::{eig as eig_general, eigh, eigvals, eigvalsh};
    pub use crate::new_modules::fft::FFT;
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    pub use crate::new_modules::matrix_decomp::{
        cholesky, cod, condition_number, lu, pivoted_cholesky, qr, rcond, schur, svd,
    };
    pub use crate::new_modules::polynomial::{
        poly, polyadd, polychebyshev, polycompanion, polycompose, polyder, polydiv, polyextrap,
        polyfit, polyfit_weighted, polyfromroots, polygcd, polygrid2d, polyhermite, polyint,
        polyjacobi, polylaguerre, polylegendre, polymul, polymulx, polypower, polyresidual,
        polyscale, polysub, polytrim, polyval2d, polyvander, polyvander2d, CubicSpline, Polynomial,
        PolynomialInterpolation,
    };

    // `parallel_matrix_ops` needs `lapack`; the rest of `optimized_ops` is unconditional.
    #[cfg(feature = "lapack")]
    pub use crate::optimized_ops::parallel_matrix_ops;
    pub use crate::optimized_ops::{
        adaptive_array_sum, chunked_array_processing, get_optimization_info,
        parallel_column_statistics, should_use_parallel, simd_elementwise_ops, simd_matmul,
        simd_vector_ops, ColumnStats, SimdOpsResult, SimdVectorResult,
    };

    // GPU items (feature `gpu`). The arithmetic functions get a `gpu_` prefix because
    // `ufuncs::{add, divide, multiply, subtract}` are exported above; `matmul` and
    // `transpose` are exported unprefixed.
    // NOTE(review): verify `transpose`/`matmul` do not clash with names from the glob
    // re-exports when `gpu` is enabled.
    #[cfg(feature = "gpu")]
    pub use crate::gpu::{
        add as gpu_add, divide as gpu_divide, matmul, multiply as gpu_multiply,
        subtract as gpu_subtract, transpose, GpuArray, GpuContext,
    };
    pub use crate::new_modules::sparse::{SparseArray, SparseMatrix, SparseMatrixFormat};
    pub use crate::new_modules::special::{
        airy_ai, airy_bi, associated_legendre_p, bessel_i, bessel_j, bessel_k, bessel_y, beta,
        betainc, digamma, ellipe, ellipeinc, ellipf, ellipk, erfcinv, erfinv, exp1, expi, fresnel,
        gammainc, jacobi_elliptic, lambertw, lambertwm1, legendre_p, polylog, shichi, sici,
        spherical_harmonic, struve_h, zeta,
    };
    pub use crate::arrays::{
        ArrayView, BooleanCombineOp, BroadcastEngine, BroadcastOp, BroadcastReduction,
        FancyIndexEngine, FancyIndexResult, ResolvedIndex, Shape, SpecializedIndexing,
    };

    // Data types: custom dtypes, datetime types and structured/record arrays.
    pub use crate::types::custom::CustomDType;
    pub use crate::types::datetime::{
        business_days,
        datetime64,
        datetime_array,
        datetime_as_string,
        datetime_data,
        timedelta64,
        DateTime64,
        DateTimeUnit,
        DateUnit,
        TimeDelta64,
        Timezone,
        TimezoneDateTime,
    };
    pub use crate::types::structured::{DType, Field, RecordArray, StructuredArray};

    // Re-exports from the external `scirs2_core` crate.
    pub use scirs2_core::ndarray::{Axis, Dimension, IxDyn, ShapeBuilder};
    pub use scirs2_core::{Complex, Complex64};
}
477
478#[cfg(test)]
479mod tests {
480 use crate::prelude::*;
481 use crate::simd::{simd_add, simd_div, simd_mul, simd_prod, simd_sqrt, simd_sum};
482 use approx::assert_relative_eq;
483
484 #[test]
485 fn basic_array_ops() {
486 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
487 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
488
489 let c = a.add(&b);
491 assert_eq!(c.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
492
493 let d = a.subtract(&b);
495 assert_eq!(d.to_vec(), vec![-4.0, -4.0, -4.0, -4.0]);
496
497 let e = a.multiply(&b);
499 assert_eq!(e.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
500
501 let f = a.divide(&b);
503 assert_relative_eq!(f.to_vec()[0], 0.2, epsilon = 1e-10);
504 assert_relative_eq!(f.to_vec()[1], 1.0 / 3.0, epsilon = 1e-10);
505 assert_relative_eq!(f.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
506 assert_relative_eq!(f.to_vec()[3], 0.5, epsilon = 1e-10);
507 }
508
509 #[test]
510 fn test_broadcasting() {
511 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
513
514 let b = a.add_scalar(5.0);
516 assert_eq!(b.to_vec(), vec![6.0, 7.0, 8.0]);
517
518 let c = a.multiply_scalar(2.0);
520 assert_eq!(c.to_vec(), vec![2.0, 4.0, 6.0]);
521
522 let row = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[1, 3]);
524 let col = Array::<f64>::from_vec(vec![4.0, 5.0]).reshape(&[2, 1]);
525
526 let result = row.add_broadcast(&col).unwrap();
528 assert_eq!(result.shape(), vec![2, 3]);
529 assert_eq!(result.to_vec(), vec![5.0, 6.0, 7.0, 6.0, 7.0, 8.0]);
530
531 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
533 let b = Array::<f64>::from_vec(vec![10.0, 20.0]).reshape(&[1, 2]);
534
535 let result = a.multiply_broadcast(&b).unwrap();
537 assert_eq!(result.shape(), vec![2, 2]);
538 assert_eq!(result.to_vec(), vec![10.0, 40.0, 30.0, 80.0]);
539
540 let shape1 = vec![3, 1, 4];
542 let shape2 = vec![2, 1];
543 let broadcast_shape = Array::<f64>::broadcast_shape(&shape1, &shape2).unwrap();
544 assert_eq!(broadcast_shape, vec![3, 2, 4]);
545 }
546
547 #[test]
548 fn test_array_creation() {
549 let zeros = Array::<f64>::zeros(&[2, 3]);
551 assert_eq!(zeros.shape(), vec![2, 3]);
552 assert_eq!(zeros.to_vec(), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
553
554 let ones = Array::<f64>::ones(&[2, 2]);
556 assert_eq!(ones.shape(), vec![2, 2]);
557 assert_eq!(ones.to_vec(), vec![1.0, 1.0, 1.0, 1.0]);
558
559 let fives = Array::<f64>::full(&[2, 2], 5.0);
561 assert_eq!(fives.shape(), vec![2, 2]);
562 assert_eq!(fives.to_vec(), vec![5.0, 5.0, 5.0, 5.0]);
563
564 let arr = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
566 let reshaped = arr.reshape(&[2, 3]);
567 assert_eq!(reshaped.shape(), vec![2, 3]);
568 assert_eq!(reshaped.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
569 }
570
571 #[test]
572 fn test_array_methods() {
573 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
574
575 assert_eq!(a.shape(), vec![2, 3]);
577 assert_eq!(a.ndim(), 2);
578 assert_eq!(a.size(), 6);
579
580 let at = a.transpose();
582 assert_eq!(at.shape(), vec![3, 2]);
583
584 let at_vec = at.to_vec();
587 assert_eq!(at_vec.len(), 6);
588 assert!(at_vec.contains(&1.0));
589 assert!(at_vec.contains(&2.0));
590 assert!(at_vec.contains(&3.0));
591 assert!(at_vec.contains(&4.0));
592 assert!(at_vec.contains(&5.0));
593 assert!(at_vec.contains(&6.0));
594
595 let slice = a.slice(0, 1).unwrap();
597 assert_eq!(slice.shape(), vec![3]);
598 assert_eq!(slice.to_vec(), vec![4.0, 5.0, 6.0]);
599 }
600
601 #[test]
602 fn test_map_operations() {
603 let a = Array::<f64>::from_vec(vec![1.0, 4.0, 9.0, 16.0]);
604
605 let sqrt_a = a.map(|x| x.sqrt());
607 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
608 assert_relative_eq!(sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
609 assert_relative_eq!(sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
610 assert_relative_eq!(sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
611
612 let par_sqrt_a = a.par_map(|x| x.sqrt());
614 assert_relative_eq!(par_sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
615 assert_relative_eq!(par_sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
616 assert_relative_eq!(par_sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
617 assert_relative_eq!(par_sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
618 }
619
620 #[cfg(feature = "lapack")]
621 #[test]
622 fn test_linalg_ops() {
623 let a = Array::<f64>::from_vec(vec![4.0, 7.0, 2.0, 6.0]).reshape(&[2, 2]);
625
626 let det_a = det(&a).unwrap();
628 assert_relative_eq!(det_a, 10.0, epsilon = 1e-10);
629
630 let inv_a = inv(&a).unwrap();
632 let expected_inv = [0.6, -0.7, -0.2, 0.4];
633 for (actual, expected) in inv_a.to_vec().iter().zip(expected_inv.iter()) {
634 assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
635 }
636
637 let identity = a.matmul(&inv_a).unwrap();
639 assert_relative_eq!(identity.to_vec()[0], 1.0, epsilon = 1e-10);
640 assert_relative_eq!(identity.to_vec()[1], 0.0, epsilon = 1e-10);
641 assert_relative_eq!(identity.to_vec()[2], 0.0, epsilon = 1e-10);
642 assert_relative_eq!(identity.to_vec()[3], 1.0, epsilon = 1e-10);
643
644 let b = Array::<f64>::from_vec(vec![1.0, 3.0]);
646 let x = solve(&a, &b).unwrap();
647
648 assert_relative_eq!(x.to_vec()[0], -1.5, epsilon = 1e-10);
650 assert_relative_eq!(x.to_vec()[1], 1.0, epsilon = 1e-10);
651
652 let b_check = match a.matmul(&x.reshape(&[2, 1])) {
654 Ok(result) => result.reshape(&[2]),
655 Err(_) => panic!("Matrix multiplication failed"),
656 };
657 assert_relative_eq!(b_check.to_vec()[0], b.to_vec()[0], epsilon = 1e-10);
658 assert_relative_eq!(b_check.to_vec()[1], b.to_vec()[1], epsilon = 1e-10);
659 }
660
661 #[test]
662 fn test_tensor_operations() {
663 let a = Array::<f64>::from_vec(vec![1.0, 2.0]).reshape(&[1, 2]);
665 let b = Array::<f64>::from_vec(vec![3.0, 4.0]).reshape(&[2, 1]);
666
667 let kron_result = kron(&a, &b).unwrap();
668 assert_eq!(kron_result.shape(), &[2, 2]);
669 assert_eq!(kron_result.to_vec(), vec![3.0, 6.0, 4.0, 8.0]);
670
671 let tensordot_result = tensordot(&a, &b, &[1, 0]).unwrap();
673 assert_eq!(tensordot_result.shape(), &[1, 1]);
674 assert_relative_eq!(tensordot_result.to_vec()[0], 11.0, epsilon = 1e-10);
675 }
676
677 #[test]
678 fn test_matrix_operations() {
679 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
681 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
682
683 let c = a.matmul(&b).unwrap();
685 assert_eq!(c.shape(), vec![2, 2]);
686 assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
687
688 let v = Array::<f64>::from_vec(vec![1.0, 2.0]);
690 let result = a.matmul(&v.reshape(&[2, 1])).unwrap().reshape(&[2]);
691 assert_eq!(result.to_vec(), vec![5.0, 11.0]);
692 }
693
694 #[test]
695 fn test_simd_operations() {
696 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
697 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]);
698
699 let c = simd_add(&a, &b).unwrap();
701 assert_eq!(c.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
702
703 let d = simd_mul(&a, &b).unwrap();
705 assert_eq!(d.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
706
707 let e = simd_div(&a, &b).unwrap();
709 assert_relative_eq!(e.to_vec()[0], 0.2, epsilon = 1e-10);
710 assert_relative_eq!(e.to_vec()[1], 1.0 / 3.0, epsilon = 1e-10);
711 assert_relative_eq!(e.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
712 assert_relative_eq!(e.to_vec()[3], 0.5, epsilon = 1e-10);
713
714 let sqrt_a = simd_sqrt(&a);
716 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
717 assert_relative_eq!(
718 sqrt_a.to_vec()[1],
719 std::f64::consts::SQRT_2,
720 epsilon = 1e-10
721 );
722 assert_relative_eq!(sqrt_a.to_vec()[2], 1.7320508075688772, epsilon = 1e-10);
723 assert_relative_eq!(sqrt_a.to_vec()[3], 2.0, epsilon = 1e-10);
724
725 assert_eq!(simd_sum(&a), 10.0);
727 assert_eq!(simd_prod(&a), 24.0);
728 }
729
730 #[test]
731 fn test_norm_functions() {
732 let v = Array::<f64>::from_vec(vec![3.0, 4.0]);
734
735 let norm_1 = norm(&v, Some(1.0)).unwrap();
737 assert_relative_eq!(norm_1, 7.0, epsilon = 1e-10);
738
739 let norm_2 = norm(&v, Some(2.0)).unwrap();
741 assert_relative_eq!(norm_2, 5.0, epsilon = 1e-10);
742
743 let norm_inf = norm(&v, Some(f64::INFINITY)).unwrap();
745 assert_relative_eq!(norm_inf, 4.0, epsilon = 1e-10);
746
747 let m = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
749
750 let matrix_norm_1 = norm(&m, Some(1.0)).unwrap();
752 assert_relative_eq!(matrix_norm_1, 6.0, epsilon = 1e-10);
753
754 let matrix_norm_inf = norm(&m, Some(f64::INFINITY)).unwrap();
756 assert_relative_eq!(matrix_norm_inf, 7.0, epsilon = 1e-10);
757 }
758
759 #[test]
760 fn test_math_operations() {
761 use crate::math::*;
762
763 let a = Array::<f64>::from_vec(vec![1.0, 4.0, 9.0, 16.0]);
765
766 let neg_a = a.map(|x| -x);
768 let abs_a = neg_a.abs();
769 for (expected, actual) in a.to_vec().iter().zip(abs_a.to_vec().iter()) {
770 assert_relative_eq!(*expected, *actual, epsilon = 1e-10);
771 }
772
773 let exp_a = a.exp();
775 assert_relative_eq!(exp_a.to_vec()[0], 1.0_f64.exp(), epsilon = 1e-10);
776 assert_relative_eq!(exp_a.to_vec()[1], 4.0_f64.exp(), epsilon = 1e-10);
777 assert_relative_eq!(exp_a.to_vec()[2], 9.0_f64.exp(), epsilon = 1e-10);
778 assert_relative_eq!(exp_a.to_vec()[3], 16.0_f64.exp(), epsilon = 1e-10);
779
780 let log_a = a.log();
782 assert_relative_eq!(log_a.to_vec()[0], 1.0_f64.ln(), epsilon = 1e-10);
783 assert_relative_eq!(log_a.to_vec()[1], 4.0_f64.ln(), epsilon = 1e-10);
784 assert_relative_eq!(log_a.to_vec()[2], 9.0_f64.ln(), epsilon = 1e-10);
785 assert_relative_eq!(log_a.to_vec()[3], 16.0_f64.ln(), epsilon = 1e-10);
786
787 let sqrt_a = a.sqrt();
789 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
790 assert_relative_eq!(sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
791 assert_relative_eq!(sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
792 assert_relative_eq!(sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
793
794 let pow_a = a.pow(2.0);
796 assert_relative_eq!(pow_a.to_vec()[0], 1.0, epsilon = 1e-10);
797 assert_relative_eq!(pow_a.to_vec()[1], 16.0, epsilon = 1e-10);
798 assert_relative_eq!(pow_a.to_vec()[2], 81.0, epsilon = 1e-10);
799 assert_relative_eq!(pow_a.to_vec()[3], 256.0, epsilon = 1e-10);
800
801 let angles = Array::<f64>::from_vec(vec![
803 0.0,
804 std::f64::consts::PI / 6.0,
805 std::f64::consts::PI / 4.0,
806 std::f64::consts::PI / 3.0,
807 ]);
808
809 let sin_angles = angles.sin();
810 assert_relative_eq!(sin_angles.to_vec()[0], 0.0, epsilon = 1e-10);
811 assert_relative_eq!(sin_angles.to_vec()[1], 0.5, epsilon = 1e-10);
812 assert_relative_eq!(
813 sin_angles.to_vec()[2],
814 1.0 / std::f64::consts::SQRT_2,
815 epsilon = 1e-10
816 );
817 assert_relative_eq!(sin_angles.to_vec()[3], 0.8660254037844386, epsilon = 1e-10);
818
819 let cos_angles = angles.cos();
820 assert_relative_eq!(cos_angles.to_vec()[0], 1.0, epsilon = 1e-10);
821 assert_relative_eq!(cos_angles.to_vec()[1], 0.8660254037844386, epsilon = 1e-10);
822 assert_relative_eq!(
823 cos_angles.to_vec()[2],
824 1.0 / std::f64::consts::SQRT_2,
825 epsilon = 1e-10
826 );
827 assert_relative_eq!(cos_angles.to_vec()[3], 0.5, epsilon = 1e-10);
828
829 let lin = linspace(0.0, 10.0, 6);
831 assert_eq!(lin.size(), 6);
832 assert_relative_eq!(lin.to_vec()[0], 0.0, epsilon = 1e-10);
833 assert_relative_eq!(lin.to_vec()[1], 2.0, epsilon = 1e-10);
834 assert_relative_eq!(lin.to_vec()[2], 4.0, epsilon = 1e-10);
835 assert_relative_eq!(lin.to_vec()[3], 6.0, epsilon = 1e-10);
836 assert_relative_eq!(lin.to_vec()[4], 8.0, epsilon = 1e-10);
837 assert_relative_eq!(lin.to_vec()[5], 10.0, epsilon = 1e-10);
838
839 let range = arange(0.0, 5.0, 1.0);
841 assert_eq!(range.size(), 5);
842 assert_eq!(range.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
843
844 let rev_range = arange(5.0, 0.0, -1.0);
846 assert_eq!(rev_range.size(), 5);
847 assert_eq!(rev_range.to_vec(), vec![5.0, 4.0, 3.0, 2.0, 1.0]);
848
849 let log_space = logspace(0.0, 3.0, 4, None);
851 assert_eq!(log_space.size(), 4);
852 assert_relative_eq!(log_space.to_vec()[0], 1.0, epsilon = 1e-10);
853 assert_relative_eq!(log_space.to_vec()[1], 10.0, epsilon = 1e-10);
854 assert_relative_eq!(log_space.to_vec()[2], 100.0, epsilon = 1e-10);
855 assert_relative_eq!(log_space.to_vec()[3], 1000.0, epsilon = 1e-10);
856
857 let geom_space = geomspace(1.0, 1000.0, 4);
859 assert_eq!(geom_space.size(), 4);
860 assert_relative_eq!(geom_space.to_vec()[0], 1.0, epsilon = 1e-10);
861 assert_relative_eq!(geom_space.to_vec()[1], 10.0, epsilon = 1e-10);
862 assert_relative_eq!(geom_space.to_vec()[2], 100.0, epsilon = 1e-10);
863 assert_relative_eq!(geom_space.to_vec()[3], 1000.0, epsilon = 1e-10);
864 }
865
866 #[test]
867 fn test_array_operations() {
868 use crate::array_ops::*;
869
870 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
872 let tiled = tile(&a, &[2]).unwrap();
873 assert_eq!(tiled.shape(), vec![6]);
874 assert_eq!(tiled.to_vec(), vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
875
876 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
877 let tiled_2d = tile(&a_2d, &[2, 1]).unwrap();
878 assert_eq!(tiled_2d.shape(), vec![4, 2]);
879 assert_eq!(
880 tiled_2d.to_vec(),
881 vec![1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]
882 );
883
884 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
886 let repeated = repeat(&a, 2, None).unwrap();
887 assert_eq!(repeated.shape(), vec![6]);
888 assert_eq!(repeated.to_vec(), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
889
890 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
891 let repeated_axis0 = repeat(&a_2d, 2, Some(0)).unwrap();
892 assert_eq!(repeated_axis0.shape(), vec![4, 2]);
893 assert_eq!(
894 repeated_axis0.to_vec(),
895 vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]
896 );
897
898 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
900 let b = Array::<f64>::from_vec(vec![4.0, 5.0, 6.0]);
901 let c = concatenate(&[&a, &b], 0).unwrap();
902 assert_eq!(c.shape(), vec![6]);
903 assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
904
905 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
906 let b_2d = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
907 let c_axis0 = concatenate(&[&a_2d, &b_2d], 0).unwrap();
908 assert_eq!(c_axis0.shape(), vec![4, 2]);
909 assert_eq!(
910 c_axis0.to_vec(),
911 vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
912 );
913
914 let c_axis1 = concatenate(&[&a_2d, &b_2d], 1).unwrap();
915 assert_eq!(c_axis1.shape(), vec![2, 4]);
916 let c_vec = c_axis1.to_vec();
917 assert_eq!(c_vec.len(), 8);
919 assert!(c_vec.contains(&1.0));
920 assert!(c_vec.contains(&2.0));
921 assert!(c_vec.contains(&3.0));
922 assert!(c_vec.contains(&4.0));
923 assert!(c_vec.contains(&5.0));
924 assert!(c_vec.contains(&6.0));
925 assert!(c_vec.contains(&7.0));
926 assert!(c_vec.contains(&8.0));
927
928 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
930 let b = Array::<f64>::from_vec(vec![4.0, 5.0, 6.0]);
931 let c = stack(&[&a, &b], 0).unwrap();
932 assert_eq!(c.shape(), vec![2, 3]);
933 let c_vec = c.to_vec();
934 assert_eq!(c_vec.len(), 6);
936 assert!(c_vec.contains(&1.0));
937 assert!(c_vec.contains(&2.0));
938 assert!(c_vec.contains(&3.0));
939 assert!(c_vec.contains(&4.0));
940 assert!(c_vec.contains(&5.0));
941 assert!(c_vec.contains(&6.0));
942
943 let d = stack(&[&a, &b], 1).unwrap();
944 assert_eq!(d.shape(), vec![3, 2]);
945 let d_vec = d.to_vec();
946 assert_eq!(d_vec.len(), 6);
948 assert!(d_vec.contains(&1.0));
949 assert!(d_vec.contains(&2.0));
950 assert!(d_vec.contains(&3.0));
951 assert!(d_vec.contains(&4.0));
952 assert!(d_vec.contains(&5.0));
953 assert!(d_vec.contains(&6.0));
954
955 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
957 let splits = split(&a, &[2, 4], 0).unwrap();
958 assert_eq!(splits.len(), 3);
959 assert_eq!(splits[0].to_vec(), vec![1.0, 2.0]);
960 assert_eq!(splits[1].to_vec(), vec![3.0, 4.0]);
961 assert_eq!(splits[2].to_vec(), vec![5.0, 6.0]);
962
963 let _a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
964 let splits_a = split(&a, &[2, 4], 0).unwrap();
966 assert_eq!(splits_a.len(), 3);
967
968 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
977 let expanded = expand_dims(&a, 0).unwrap();
978 assert_eq!(expanded.shape(), vec![1, 3]);
979 assert_eq!(expanded.to_vec(), vec![1.0, 2.0, 3.0]);
980
981 let expanded_end = expand_dims(&a, 1).unwrap();
982 assert_eq!(expanded_end.shape(), vec![3, 1]);
983 assert_eq!(expanded_end.to_vec(), vec![1.0, 2.0, 3.0]);
984
985 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[1, 3, 1]);
987 let squeezed = squeeze(&a, None).unwrap();
988 assert_eq!(squeezed.shape(), vec![3]);
989 assert_eq!(squeezed.to_vec(), vec![1.0, 2.0, 3.0]);
990
991 let squeezed_axis = squeeze(&a, Some(0)).unwrap();
992 assert_eq!(squeezed_axis.shape(), vec![3, 1]);
993 assert_eq!(squeezed_axis.to_vec(), vec![1.0, 2.0, 3.0]);
994 }
995
996 #[test]
997 fn test_statistics_functions() {
998 use crate::stats::*;
999
1000 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1002
1003 assert_relative_eq!(a.mean(), 3.0, epsilon = 1e-10);
1005
1006 assert_relative_eq!(a.var(), 2.0, epsilon = 1e-10);
1008
1009 assert_relative_eq!(a.std(), std::f64::consts::SQRT_2, epsilon = 1e-10);
1011
1012 assert_relative_eq!(a.min(), 1.0, epsilon = 1e-10);
1014 assert_relative_eq!(a.max(), 5.0, epsilon = 1e-10);
1015
1016 assert_relative_eq!(a.percentile(0.0), 1.0, epsilon = 1e-10);
1018 assert_relative_eq!(a.percentile(0.5), 3.0, epsilon = 1e-10);
1019 assert_relative_eq!(a.percentile(1.0), 5.0, epsilon = 1e-10);
1020 assert_relative_eq!(a.percentile(0.25), 2.0, epsilon = 1e-10);
1021 assert_relative_eq!(a.percentile(0.75), 4.0, epsilon = 1e-10);
1022
1023 let b = Array::<f64>::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);
1025 let cov_result = cov(&a, Some(&b), None, None, None).unwrap();
1026 assert_relative_eq!(cov_result.get(&[0, 1]).unwrap(), -2.5, epsilon = 1e-10);
1027 let corrcoef_result = corrcoef(&a, Some(&b), None).unwrap();
1028 assert_relative_eq!(corrcoef_result.get(&[0, 1]).unwrap(), -1.0, epsilon = 1e-10);
1029
1030 let data = Array::<f64>::from_vec(vec![1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]);
1032 let (counts, bins) = histogram(&data, 4, None, None).unwrap();
1033 assert_eq!(counts.to_vec(), vec![2.0, 2.0, 2.0, 3.0]);
1034 assert_eq!(bins.size(), 5);
1035 assert_relative_eq!(bins.to_vec()[0], 1.0, epsilon = 1e-10);
1036 assert_relative_eq!(bins.to_vec()[4], 5.0, epsilon = 1e-10);
1037 }
1038
/// Exercises mask-driven element assignment and `IndexSpec` row/column
/// selection.
///
/// Three things are checked:
/// 1. scattering a short value array into the positions selected by a
///    boolean mask, using only `get`/`set`;
/// 2. extracting the first row and first column of a 3x3 array with
///    `IndexSpec::Index` / `IndexSpec::All`;
/// 3. `set_mask`, which overwrites the masked positions of an array.
#[test]
fn test_boolean_indexing() {
    use crate::indexing::*;

    let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    let mask = [true, false, true, false, true];

    // Scatter `values` into the positions where the mask is true; the k-th
    // selected position receives the k-th value.
    let mut filtered = Array::<f64>::zeros(&[5]);
    let values = Array::<f64>::from_vec(vec![1.0, 3.0, 5.0]);

    let selected_positions = mask
        .iter()
        .enumerate()
        .filter(|&(_, &keep)| keep)
        .map(|(position, _)| position);
    for (value_idx, position) in selected_positions.enumerate() {
        let value = values.get(&[value_idx]).unwrap();
        filtered.set(&[position], value).unwrap();
    }

    assert_eq!(filtered.to_vec(), vec![1.0, 0.0, 3.0, 0.0, 5.0]);

    let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
        .reshape(&[3, 3]);

    // Fixing the row index and keeping every column yields a 1-D row.
    let row_result = a_2d.index(&[IndexSpec::Index(0), IndexSpec::All]).unwrap();
    assert_eq!(row_result.shape(), vec![3]);
    assert_eq!(row_result.to_vec(), vec![1.0, 2.0, 3.0]);

    // Fixing the column index and keeping every row yields a 1-D column.
    let col_result = a_2d.index(&[IndexSpec::All, IndexSpec::Index(0)]).unwrap();
    assert_eq!(col_result.shape(), vec![3]);
    assert_eq!(col_result.to_vec(), vec![1.0, 4.0, 7.0]);

    // `set_mask` replaces only the masked elements; the others keep their
    // original values.
    let mut a_copy = a.clone();
    a_copy
        .set_mask(
            &Array::<bool>::from_vec(vec![true, false, true, false, true]),
            &Array::<f64>::from_vec(vec![10.0, 30.0, 50.0]),
        )
        .unwrap();

    assert_eq!(a_copy.to_vec(), vec![10.0, 2.0, 30.0, 4.0, 50.0]);
}
1103
/// Checks element and slice selection on a 3x3 array through `IndexSpec`.
///
/// NOTE(review): despite the name, only the scalar (`Index`), range (`Slice`)
/// and `All` variants are exercised here; no integer-array indexing is
/// performed by this test.
#[test]
fn test_fancy_indexing() {
    use crate::indexing::*;

    let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
        .reshape(&[3, 3]);

    // Indexing both axes with a scalar picks out the centre element.
    let single_element = a_2d
        .index(&[IndexSpec::Index(1), IndexSpec::Index(1)])
        .unwrap();
    assert_eq!(single_element.to_vec(), vec![5.0]);

    // Rows 0 and 1 (end bound 2 is exclusive, per the expected shape) with
    // every column kept.
    let slice_result = a_2d
        .index(&[IndexSpec::Slice(0, Some(2), None), IndexSpec::All])
        .unwrap();
    assert_eq!(slice_result.shape(), vec![2, 3]);
    assert_eq!(slice_result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
1133
/// Verifies the per-axis reductions (`sum`, `mean`, `min`, `max`, `argmin`,
/// `var`, `std`) on 2x3 fixtures whose results can be computed by hand.
#[test]
fn test_axis_operations() {
    use crate::axis_ops::*;

    // Builds a 2x3 array by writing every element through `set`, row by row.
    let build_2x3 = |rows: [[f64; 3]; 2]| {
        let mut grid = Array::<f64>::zeros(&[2, 3]);
        for (r, row) in rows.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                grid.set(&[r, c], value).unwrap();
            }
        }
        grid
    };

    let array = build_2x3([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

    // Sums: axis 0 collapses the rows, axis 1 collapses the columns.
    let column_sums = array.sum_axis(0).unwrap();
    assert_eq!(column_sums.shape(), vec![3]);
    assert_eq!(column_sums.to_vec(), vec![5.0, 7.0, 9.0]);

    let row_sums = array.sum_axis(1).unwrap();
    assert_eq!(row_sums.shape(), vec![2]);
    assert_eq!(row_sums.to_vec(), vec![6.0, 15.0]);

    // Means.
    let column_means = array.mean_axis(Some(0)).unwrap();
    assert_eq!(column_means.shape(), vec![3]);
    assert_eq!(column_means.to_vec(), vec![2.5, 3.5, 4.5]);

    let row_means = array.mean_axis(Some(1)).unwrap();
    assert_eq!(row_means.shape(), vec![2]);
    assert_eq!(row_means.to_vec(), vec![2.0, 5.0]);

    // Minima.
    let column_minima = array.min_axis(Some(0)).unwrap();
    assert_eq!(column_minima.shape(), vec![3]);
    let column_minima_values = column_minima.to_vec();
    assert_eq!(column_minima_values, vec![1.0, 2.0, 3.0]);

    let row_minima = array.min_axis(Some(1)).unwrap();
    assert_eq!(row_minima.shape(), vec![2]);
    assert_eq!(row_minima.to_vec(), vec![1.0, 4.0]);

    // Maxima (only along axis 1).
    let row_maxima = array.max_axis(Some(1)).unwrap();
    assert_eq!(row_maxima.shape(), vec![2]);
    assert_eq!(row_maxima.to_vec(), vec![3.0, 6.0]);

    // A second fixture whose column minima do not all sit in the same row.
    let array2 = build_2x3([[3.0, 2.0, 1.0], [0.0, 5.0, 6.0]]);

    let argmin_axis0 = array2.argmin_axis(0).unwrap();
    assert_eq!(argmin_axis0.shape(), vec![3]);
    assert_eq!(argmin_axis0.to_vec(), vec![1, 0, 0]);

    // Variance of column 0 (values 1 and 4): 2.25 is the squared deviation
    // sum (4.5) divided by n = 2.
    let var_axis0 = array.var_axis(Some(0)).unwrap();
    assert_eq!(var_axis0.shape(), vec![3]);
    assert_relative_eq!(var_axis0.get(&[0]).unwrap(), 2.25, epsilon = 1e-10);

    let std_axis1 = array.std_axis(Some(1)).unwrap();
    assert_eq!(std_axis1.shape(), vec![2]);

    // The bounds are intentionally loose: they admit both the divide-by-n
    // (~0.82) and the divide-by-(n - 1) (1.0) result for a row like 1, 2, 3.
    let std_row1 = std_axis1.get(&[0]).unwrap();
    assert!(
        std_row1 > 0.8 && std_row1 < 1.1,
        "std_row1 ({}) should be approximately 1.0 or 0.82",
        std_row1
    );

    let std_row2 = std_axis1.get(&[1]).unwrap();
    assert!(
        std_row2 > 0.8 && std_row2 < 1.1,
        "std_row2 ({}) should be approximately 1.0 or 0.82",
        std_row2
    );
}
1238
/// Walks through the view constructors on a 3x3 array: plain and mutable
/// views, a strided view, a sliced view, a transposed view and a broadcast
/// view.
#[test]
fn test_views_and_strides() {
    use crate::views::SliceOrIndex;

    let mut grid = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
        .reshape(&[3, 3]);

    // A read-only view reports the shape of the array it was taken from.
    let read_only = grid.view();
    assert_eq!(read_only.shape(), vec![3, 3]);

    // A write made through the mutable view is seen when reading the array.
    let mut writable = grid.view_mut();
    writable.set(&[0, 0], 10.0).unwrap();
    assert_eq!(grid.get(&[0, 0]).unwrap(), 10.0);

    // Put the original corner value back before the remaining checks.
    grid.set(&[0, 0], 1.0).unwrap();

    // Only membership of the four corner values is asserted for the strided
    // view; the order of its flattened data is deliberately not checked.
    let every_other = grid.strided_view(&[2, 2]).unwrap();
    assert_eq!(every_other.shape(), vec![2, 2]);
    let strided_values = every_other.to_vec();
    for corner in &[1.0, 3.0, 7.0, 9.0] {
        assert!(strided_values.contains(corner));
    }

    // The leading 2x2 window of the array.
    let window = vec![
        SliceOrIndex::Slice(0, Some(2), None),
        SliceOrIndex::Slice(0, Some(2), None),
    ];
    let top_left = grid.sliced_view(&window).unwrap();
    assert_eq!(top_left.shape(), vec![2, 2]);
    assert_eq!(top_left.to_vec(), vec![1.0, 2.0, 4.0, 5.0]);

    // Transposition swaps the roles of the two indices.
    let flipped = grid.transposed_view();
    assert_eq!(flipped.shape(), vec![3, 3]);
    // Flattening is invoked but its result is not inspected.
    let _flattened = flipped.to_vec();
    assert_eq!(flipped.get(&[0, 1]).unwrap(), 4.0);
    assert_eq!(flipped.get(&[1, 0]).unwrap(), 2.0);

    // Broadcasting to 3x3x3: the same element is read for different indices
    // along the new leading axis.
    let stacked = grid.broadcast_view(&[3, 3, 3]).unwrap();
    assert_eq!(stacked.shape(), vec![3, 3, 3]);
    assert_eq!(stacked.get(&[0, 0, 0]).unwrap(), 1.0);
    assert_eq!(stacked.get(&[1, 0, 0]).unwrap(), 1.0);
}
1291
/// Checks the element-wise universal functions: binary arithmetic, power,
/// square/sqrt, exp/log, scalar multiplication, and addition of a 1x2 row
/// to a 3x1 column with broadcasting.
#[test]
fn test_universal_functions() {
    use crate::ufuncs::*;

    let lhs = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let rhs = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]);

    // Exact binary arithmetic.
    let sums = add(&lhs, &rhs).unwrap();
    assert_eq!(sums.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);

    let differences = subtract(&lhs, &rhs).unwrap();
    assert_eq!(differences.to_vec(), vec![-4.0, -4.0, -4.0, -4.0]);

    let products = multiply(&lhs, &rhs).unwrap();
    assert_eq!(products.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);

    // Division and power are compared with a tolerance. Indexing (rather
    // than zipping) keeps the panic on a result that is too short.
    let quotients = divide(&lhs, &rhs).unwrap().to_vec();
    let expected_quotients = [0.2, 1.0 / 3.0, 3.0 / 7.0, 0.5];
    for (i, expected) in expected_quotients.iter().enumerate() {
        assert_relative_eq!(quotients[i], *expected, epsilon = 1e-10);
    }

    let powers = power(&lhs, &rhs).unwrap().to_vec();
    let expected_powers = [1.0, 64.0, 2187.0, 65536.0];
    for (i, expected) in expected_powers.iter().enumerate() {
        assert_relative_eq!(powers[i], *expected, epsilon = 1e-10);
    }

    // Unary functions.
    let squares = square(&lhs);
    assert_eq!(squares.to_vec(), vec![1.0, 4.0, 9.0, 16.0]);

    let roots = sqrt(&lhs).to_vec();
    let expected_roots = [1.0, std::f64::consts::SQRT_2, 1.7320508075688772, 2.0];
    for (i, expected) in expected_roots.iter().enumerate() {
        assert_relative_eq!(roots[i], *expected, epsilon = 1e-10);
    }

    // exp and log are compared against the scalar std implementations
    // applied to the same inputs.
    let inputs = [1.0_f64, 2.0, 3.0, 4.0];

    let exponentials = exp(&lhs).to_vec();
    for (i, x) in inputs.iter().enumerate() {
        assert_relative_eq!(exponentials[i], x.exp(), epsilon = 1e-10);
    }

    let logarithms = log(&lhs).to_vec();
    for (i, x) in inputs.iter().enumerate() {
        assert_relative_eq!(logarithms[i], x.ln(), epsilon = 1e-10);
    }

    // Scalar multiplication.
    let doubled = multiply_scalar(&lhs, 2.0);
    assert_eq!(doubled.to_vec(), vec![2.0, 4.0, 6.0, 8.0]);

    // Broadcasting: a 1x2 row plus a 3x1 column gives a 3x2 result.
    let row = Array::<f64>::from_vec(vec![10.0, 20.0]).reshape(&[1, 2]);
    let col = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[3, 1]);
    let broadcast_sums = add(&row, &col).unwrap();
    assert_eq!(broadcast_sums.shape(), vec![3, 2]);
    assert_eq!(
        broadcast_sums.to_vec(),
        vec![11.0, 21.0, 12.0, 22.0, 13.0, 23.0]
    );
}
1359}