1#![allow(deprecated)] #![allow(clippy::result_large_err)] #![allow(clippy::needless_range_loop)] #![allow(clippy::too_many_arguments)] #![allow(clippy::identity_op)] #![allow(clippy::approx_constant)] #![allow(clippy::excessive_precision)] pub mod algorithms;
68pub mod array;
69pub mod array_ops;
70pub mod array_ops_legacy;
71pub mod arrays;
72#[cfg(feature = "arrow")]
73pub mod arrow;
74pub mod autodiff;
75pub mod axis_ops;
76pub mod bitwise_ops;
77pub mod blas;
78pub mod char;
79pub mod cluster;
80pub mod comparisons;
81pub mod comparisons_broadcast;
82pub mod complex_ops;
83pub mod conversions;
84pub mod derivative;
85pub mod distance;
86pub mod error;
87pub mod error_handling;
88pub mod expr;
89pub mod fft;
90pub mod financial;
91#[cfg(feature = "gpu")]
92pub mod gpu;
93pub mod indexing;
94pub mod integrate;
95pub mod interop;
96pub mod interpolate;
97pub mod io;
98pub mod linalg;
99pub mod linalg_accelerated;
100pub mod linalg_extended;
101pub mod linalg_optimized;
102pub mod linalg_parallel;
103pub mod optimized_ops; pub mod linalg_stable;
106pub mod masked;
107pub mod math;
108pub mod math_extended;
109pub mod matrix;
110pub mod memory_alloc;
111pub mod memory_optimize;
112pub mod mmap;
113pub mod ndimage;
114pub mod ode;
115pub mod optimize;
116pub mod parallel;
117pub mod parallel_optimize;
118pub mod pde;
119pub mod printing;
120#[cfg(feature = "python")]
121pub mod python;
122pub mod random;
123pub mod roots;
124pub mod set_ops;
125pub mod shared_array;
126pub mod signal;
127pub mod simd;
128pub mod simd_optimize;
129pub mod sparse;
130pub mod sparse_enhanced;
131pub mod spatial;
132pub mod special;
133pub mod stats;
134pub mod stride_tricks;
135pub mod testing;
136pub mod traits;
137pub mod types;
138pub mod ufuncs;
139pub mod unique;
140pub mod unique_optimized;
141pub mod util;
142pub mod views;
143
/// A second module tree, declared inline at the crate root.
///
/// Its children are resolved relative to a `new_modules/` directory next to the
/// crate root file. `fft`, `sparse` and `special` here are distinct modules from
/// the top-level modules of the same name. Parts of `eigenvalues`, `fft`,
/// `matrix_decomp`, `polynomial`, `sparse` and `special` are re-exported from
/// the [`prelude`].
pub mod new_modules {
    pub mod eigenvalues;
    pub mod fft;
    pub mod fft_enhanced;
    pub mod frequency_analysis;
    // Only compiled when the `matrix_decomp` Cargo feature is enabled.
    #[cfg(feature = "matrix_decomp")]
    pub mod matrix_decomp;
    pub mod polynomial;
    pub mod signal_processing;
    pub mod sparse;
    pub mod special;
    pub mod spectral_analysis;
}
159
160pub use error::{NumRs2Error, Result};
161
162pub use random::random_base;
164
165#[cfg(doctest)]
167pub mod doctests {}
168
/// Convenience re-exports of the crate's commonly used items.
///
/// Glob-import this module (the crate's own tests use `use crate::prelude::*;`)
/// to bring the array type, element-wise functions, linear algebra, statistics
/// and supporting types into scope at once. Where two re-exported items share a
/// name, one of them is renamed with `as` (for example `complex_*`, `*_basic`,
/// `gpu_*`). Items behind `#[cfg(feature = ...)]` exist only when that Cargo
/// feature is enabled.
pub mod prelude {
    // Core array type and general array manipulation.
    pub use crate::array::Array;
    pub use crate::array_ops::*;
    pub use crate::array_ops_legacy::rollaxis;
    pub use crate::axis_ops::*;
    // NOTE(review): these three are already covered by the glob above. They are
    // presumably listed explicitly so they win over any same-named item from
    // another glob (explicit imports shadow glob imports) — confirm before
    // removing.
    pub use crate::axis_ops::{apply_along_axis, apply_over_axes, vectorize};
    pub use crate::bitwise_ops::{
        bitwise_and, bitwise_not, bitwise_or, bitwise_xor, invert, left_shift, left_shift_scalar,
        right_shift, right_shift_scalar,
    };
    // The `char` module itself is re-exported, plus selected items from it.
    pub use crate::char;
    pub use crate::char::{array_from_strings, StringArray, StringElement};
    pub use crate::comparisons::{
        all, allclose, allclose_with_tol, any, array_equal, count_nonzero, equal, flatnonzero,
        greater, greater_equal, isclose, isclose_array, less, less_equal, logical_and, logical_not,
        logical_or, logical_xor, not_equal,
    };
    // Renamed with a `complex_` prefix because `math` (below) also exports
    // `angle`, `conj`, `imag` and `real`, and `ufuncs` exports `absolute`.
    pub use crate::complex_ops::{
        absolute as complex_abs, angle as complex_angle, conj as complex_conj, from_polar,
        imag as complex_imag, iscomplex, iscomplexobj, isreal, isrealobj, real as complex_real,
        to_complex,
    };
    pub use crate::conversions::*;
    pub use crate::error::{NumRs2Error, Result};
    // Error-state controls (`seterr`/`geterr`/`errstate` and related types).
    pub use crate::error_handling::{
        errstate, geterr, geterrcall, handle_error, seterr, seterrcall, ErrorAction, ErrorState,
        ErrorStateBuilder, ErrorStateGuard, FloatingPointError,
    };
    // Items from `crate::financial`.
    pub use crate::financial::{
        accrued_interest,
        amortization_schedule,
        binomial_option_price,
        black_scholes,
        black_scholes_greeks,
        bond_convexity,
        bond_duration,
        bond_equivalent_yield,
        bond_price,
        bond_yield,
        cumipmt,
        cumprinc,
        db,
        ddb,
        effect,
        fv,
        fv_array,
        implied_volatility,
        ipmt,
        irr,
        irr_multiple_series,
        mirr,
        modified_duration,
        nominal,
        nper,
        nper_array,
        npv,
        npv_multiple_series,
        npv_rates,
        pmt,
        pmt_array,
        ppmt,
        pv,
        pv_array,
        rate,
        rate_array,
        sln,
        syd,
        AmortizationSchedule,
    };
    // NOTE(review): `put`/`putmask` are renamed, but the name they would
    // collide with is not visible in this file. It is presumably an item
    // reached through one of the glob re-exports — confirm.
    pub use crate::indexing::{
        diag_indices, diag_indices_from, extract, indices_grid, ix_, mask_indices,
        put as indexing_put, put_along_axis, putmask as indexing_putmask, ravel_multi_index, take,
        take_along_axis, tril_indices, tril_indices_from, triu_indices, triu_indices_from,
        unravel_index, IndexSpec,
    };
    pub use crate::io::{array_to_vec2d, vec2d_to_array, vec_to_array, SerializeFormat};
    // Linear algebra. The `_basic` names avoid clashing with
    // `new_modules::matrix_decomp::{cholesky, qr, svd}`, which is re-exported
    // further down under the same two features.
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    pub use crate::linalg::{
        cholesky as cholesky_basic, eig, inv, qr as qr_basic, solve, svd as svd_basic,
    };
    #[cfg(feature = "lapack")]
    pub use crate::linalg::{det, matrix_power};
    pub use crate::linalg::{inner, kron, norm, outer, tensordot, trace, vdot};

    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    pub use crate::linalg::{matrix_rank, pinv};
    pub use crate::linalg_extended::eigenvalue;
    pub use crate::linalg_optimized::{lu_optimized, transpose_optimized, OptimizedBlas};
    pub use crate::linalg_parallel::ParallelLinAlg;
    pub use crate::linalg_stable::{
        CholeskyStableResult, QRPivotedResult, SVDStableResult, StableDecompositions,
    };
    pub use crate::masked::MaskedArray;
    pub use crate::ufuncs::{abs, ceil, exp, floor, log, round, sqrt};
    pub use crate::math_extended::{erf, erfc, gamma, gammaln};
    // NOTE: the `std` listed here is `crate::math::std` (presumably standard
    // deviation). As a function it lives in the value namespace, so paths such
    // as `std::f64::consts::PI` still resolve to the standard library for glob
    // importers.
    pub use crate::math::{
        amax, amin, angle, arange, argmax, argmin, argpartition, argsort, around, bartlett,
        bincount, blackman, clip, conj, copysign, cumprod, cumsum, cumulative_prod, cumulative_sum,
        diff, diff_extended, digitize, divmod, ediff1d, empty, fmod, frexp, gcd, geomspace,
        gradient, hamming, hanning, heaviside, i0, imag, interp, isfinite, isinf, isnan, kaiser,
        lcm, ldexp, linspace, logspace, max, mean, median, min, modf, nan_to_num, nanmax, nanmean,
        nanmin, nanstd, nansum, nanvar, nextafter, nonzero, ones, partition, prod, real,
        real_if_close, remainder, resize, searchsorted, sinc, sort, std, sum, trapz, var, zeros,
        ElementWiseMath,
    };
    pub use crate::matrix::{
        asmatrix, matrix, matrix_from_nested, matrix_from_scalar, BandedMatrix, Matrix,
    };
    pub use crate::mmap::MmapArray;
    pub use crate::random::advanced_distributions;
    pub use crate::random::distributions;
    pub use crate::random::generator::{default_rng, BitGenerator, Generator, StdBitGenerator};
    // `self` re-exports the `random` module itself.
    pub use crate::random::{self, RandomState};
    pub use crate::set_ops::{
        in1d, intersect1d, isin, setdiff1d, setxor1d, union1d, unique_axis, unique_with_options,
    };
    pub use crate::signal::{convolve, convolve2d, correlate, correlate2d};
    pub use crate::simd::get_simd_implementation_name;
    // The top-level `sparse` module (distinct from `new_modules::sparse`).
    pub use crate::sparse;
    pub use crate::sparse_enhanced::SparseOpsAdvanced;
    pub use crate::stats::{
        average, corrcoef, cov, histogram, histogram_dd, max_along_axis, min_along_axis, mode,
        percentile, ptp, quantile, HistBins, Statistics,
    };
    pub use crate::stride_tricks::{
        as_strided, broadcast_arrays, broadcast_to, byte_strides, set_strides, sliding_window_view,
    };
    pub use crate::testing::{
        arrays_close, assert_array_all_finite, assert_array_almost_equal, assert_array_equal,
        assert_array_no_nan, assert_array_same_shape, assert_scalar_almost_equal, is_finite_array,
        test_summary, tolerances, TestResult, ToleranceConfig,
    };
    // NOTE(review): re-exported from the crate root. This is presumably a
    // `#[macro_export]` macro defined elsewhere (e.g. in `testing`) — confirm.
    pub use crate::run_tests;
    pub use crate::traits::{
        ArrayIndexing, ArrayMath, ArrayOps, ArrayReduction, ComplexElement, FloatingPoint,
        IntegerElement, LinearAlgebra, MatrixDecomposition, NumericElement,
    };
    pub use crate::ufuncs::{
        absolute, add, add_scalar, arctan2, cbrt, divide, divide_scalar, dot, exp2, expm1, fma,
        hypot, log10, log1p, log2, maximum, minimum, multiply, multiply_scalar, negative, norm_l1,
        norm_l2, power, power_scalar, reciprocal, subtract, subtract_scalar, BinaryUfunc,
        UnaryUfunc,
    };
    pub use crate::unique::{unique, UniqueResult};
    pub use crate::unique_optimized::unique_optimized;
    pub use crate::util::{
        astype, can_operate_inplace, fast_sum, optimize_layout, parallel_map, MemoryLayout,
    };
    pub use crate::views::*;

    // Conversions to and from `ndarray` types.
    pub use crate::interop::ndarray_compat::{from_ndarray, to_ndarray};
    // `optimize_layout` is renamed because `util::optimize_layout` is
    // re-exported above.
    pub use crate::memory_optimize::{
        align_data, optimize_layout as memory_optimize_layout, optimize_placement,
        AlignmentStrategy, LayoutStrategy, PlacementStrategy,
    };

    pub use crate::parallel_optimize::{
        adaptive_threshold, optimize_parallel_computation, optimize_scheduling, partition_workload,
    };
    pub use crate::parallel_optimize::{
        ParallelConfig, ParallelizationThreshold, SchedulingStrategy, WorkloadPartitioning,
    };

    pub use crate::printing::{
        array_str, get_printoptions, reset_printoptions, set_printoptions, PrintOptions,
    };

    pub use crate::memory_alloc::{
        get_default_allocator, get_global_allocator_strategy, init_global_allocator,
        reset_global_allocator,
    };
    pub use crate::memory_alloc::{
        AlignedAllocator, AlignmentConfig, AllocStrategy, ArenaAllocator, ArenaConfig, CacheConfig,
        CacheLevel, CacheOptimizedAllocator, PoolAllocator, PoolConfig,
    };

    pub use crate::algorithms::{
        BandwidthEstimate, BandwidthOptimizer, CacheAwareArrayOps, CacheAwareConvolution,
        CacheAwareFFT, MemoryOperation,
    };

    // Renamed because `parallel_optimize::ParallelConfig` is re-exported above.
    pub use crate::parallel::parallel_algorithms::ParallelConfig as ParallelAlgorithmConfig;
    pub use crate::parallel::{
        global_parallel_context, initialize_parallel_context, shutdown_parallel_context, task,
        BalancingStrategy, LoadBalancer, ParallelAllocator, ParallelAllocatorConfig,
        ParallelArrayOps, ParallelContext, ParallelFFT, ParallelMatrixOps, ParallelScheduler,
        SchedulerConfig, Task, TaskPriority, TaskResult, ThreadLocalAllocator, WorkStealingPool,
        WorkloadMetrics,
    };

    pub use crate::memory_alloc::{
        EnhancedAllocatorBridge, IntelligentAllocationStrategy, NumericalArrayAllocator,
    };
    pub use crate::traits::{
        AllocationFrequency, AllocationLifetime, AllocationRequirements, AllocationStats,
        AllocationStrategy, MemoryAllocator, MemoryAware, MemoryOptimization, MemoryUsage,
        OptimizationType, SpecializedAllocator, ThreadingRequirements,
    };

    // `eig_general` avoids clashing with `linalg::eig`, which is re-exported
    // above when both `matrix_decomp` and `lapack` are enabled.
    #[cfg(feature = "lapack")]
    pub use crate::new_modules::eigenvalues::{eig as eig_general, eigh, eigvals, eigvalsh};
    pub use crate::new_modules::fft::FFT;
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    pub use crate::new_modules::matrix_decomp::{
        cholesky, cod, condition_number, lu, pivoted_cholesky, qr, rcond, schur, svd,
    };
    pub use crate::new_modules::polynomial::{
        poly, polyadd, polychebyshev, polycompanion, polycompose, polyder, polydiv, polyextrap,
        polyfit, polyfit_weighted, polyfromroots, polygcd, polygrid2d, polyhermite, polyint,
        polyjacobi, polylaguerre, polylegendre, polymul, polymulx, polypower, polyresidual,
        polyscale, polysub, polytrim, polyval2d, polyvander, polyvander2d, CubicSpline, Polynomial,
        PolynomialInterpolation,
    };

    #[cfg(feature = "lapack")]
    pub use crate::optimized_ops::parallel_matrix_ops;
    pub use crate::optimized_ops::{
        adaptive_array_sum, chunked_array_processing, get_optimization_info,
        parallel_column_statistics, should_use_parallel, simd_elementwise_ops, simd_matmul,
        simd_vector_ops, ColumnStats, SimdOpsResult, SimdVectorResult,
    };

    // GPU arithmetic is renamed with a `gpu_` prefix to avoid clashing with the
    // `ufuncs` re-exports above.
    // NOTE(review): `matmul` and `transpose` are not renamed. When `gpu` is
    // enabled, these explicit imports take precedence over any same-named item
    // reached through the glob re-exports above — confirm this is intended.
    #[cfg(feature = "gpu")]
    pub use crate::gpu::{
        add as gpu_add, divide as gpu_divide, matmul, multiply as gpu_multiply,
        subtract as gpu_subtract, transpose, GpuArray, GpuContext,
    };
    pub use crate::new_modules::sparse::{SparseArray, SparseMatrix, SparseMatrixFormat};
    pub use crate::new_modules::special::{
        airy_ai, airy_bi, associated_legendre_p, bessel_i, bessel_j, bessel_k, bessel_y, beta,
        betainc, digamma, ellipe, ellipeinc, ellipf, ellipk, erfcinv, erfinv, exp1, expi, fresnel,
        gammainc, jacobi_elliptic, lambertw, lambertwm1, legendre_p, polylog, shichi, sici,
        spherical_harmonic, struve_h, zeta,
    };
    pub use crate::arrays::{
        ArrayView, BooleanCombineOp, BroadcastEngine, BroadcastOp, BroadcastReduction,
        FancyIndexEngine, FancyIndexResult, ResolvedIndex, Shape, SpecializedIndexing,
    };

    pub use crate::types::custom::CustomDType;
    pub use crate::types::datetime::{
        business_days,
        datetime64,
        datetime_array,
        datetime_as_string,
        datetime_data,
        timedelta64,
        DateTime64,
        DateTimeUnit,
        DateUnit,
        TimeDelta64,
        Timezone,
        TimezoneDateTime,
    };
    pub use crate::types::structured::{DType, Field, RecordArray, StructuredArray};

    pub use crate::shared_array::{SharedArray, SharedArrayView};

    // Items from `crate::expr`.
    pub use crate::expr::{
        ArrayExpr,
        BinaryExpr,
        CSEOptimizer,
        CSESupport,
        CachedExpr,
        Expr,
        ExprBuilder,
        ExprCache,
        ExprId,
        ExprKey,
        LazyEval,
        ScalarExpr,
        SharedArrayExpr,
        SharedBinaryExpr,
        SharedExpr,
        SharedExprBuilder,
        SharedScalarExpr,
        SharedUnaryExpr,
        UnaryExpr,
    };

    pub use crate::memory_optimize::access_patterns::{
        cache_aware_binary_op, cache_aware_copy, cache_aware_transform, detect_layout,
        AccessPattern, AccessStats, Block, BlockedIterator, OptimizationHints, StrideOptimizer,
        Tile2D, TiledIterator2D,
    };

    // Frequently needed types from the `scirs2_core` dependency.
    pub use scirs2_core::ndarray::{Axis, Dimension, IxDyn, ShapeBuilder};
    pub use scirs2_core::{Complex, Complex64};
}
517
518#[cfg(test)]
519mod tests {
520 use crate::prelude::*;
521 use crate::simd::{simd_add, simd_div, simd_mul, simd_prod, simd_sqrt, simd_sum};
522 use approx::assert_relative_eq;
523
524 #[test]
525 fn basic_array_ops() {
526 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
527 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
528
529 let c = a.add(&b);
531 assert_eq!(c.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
532
533 let d = a.subtract(&b);
535 assert_eq!(d.to_vec(), vec![-4.0, -4.0, -4.0, -4.0]);
536
537 let e = a.multiply(&b);
539 assert_eq!(e.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
540
541 let f = a.divide(&b);
543 assert_relative_eq!(f.to_vec()[0], 0.2, epsilon = 1e-10);
544 assert_relative_eq!(f.to_vec()[1], 1.0 / 3.0, epsilon = 1e-10);
545 assert_relative_eq!(f.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
546 assert_relative_eq!(f.to_vec()[3], 0.5, epsilon = 1e-10);
547 }
548
549 #[test]
550 fn test_broadcasting() {
551 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
553
554 let b = a.add_scalar(5.0);
556 assert_eq!(b.to_vec(), vec![6.0, 7.0, 8.0]);
557
558 let c = a.multiply_scalar(2.0);
560 assert_eq!(c.to_vec(), vec![2.0, 4.0, 6.0]);
561
562 let row = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[1, 3]);
564 let col = Array::<f64>::from_vec(vec![4.0, 5.0]).reshape(&[2, 1]);
565
566 let result = row.add_broadcast(&col).unwrap();
568 assert_eq!(result.shape(), vec![2, 3]);
569 assert_eq!(result.to_vec(), vec![5.0, 6.0, 7.0, 6.0, 7.0, 8.0]);
570
571 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
573 let b = Array::<f64>::from_vec(vec![10.0, 20.0]).reshape(&[1, 2]);
574
575 let result = a.multiply_broadcast(&b).unwrap();
577 assert_eq!(result.shape(), vec![2, 2]);
578 assert_eq!(result.to_vec(), vec![10.0, 40.0, 30.0, 80.0]);
579
580 let shape1 = vec![3, 1, 4];
582 let shape2 = vec![2, 1];
583 let broadcast_shape = Array::<f64>::broadcast_shape(&shape1, &shape2).unwrap();
584 assert_eq!(broadcast_shape, vec![3, 2, 4]);
585 }
586
587 #[test]
588 fn test_array_creation() {
589 let zeros = Array::<f64>::zeros(&[2, 3]);
591 assert_eq!(zeros.shape(), vec![2, 3]);
592 assert_eq!(zeros.to_vec(), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
593
594 let ones = Array::<f64>::ones(&[2, 2]);
596 assert_eq!(ones.shape(), vec![2, 2]);
597 assert_eq!(ones.to_vec(), vec![1.0, 1.0, 1.0, 1.0]);
598
599 let fives = Array::<f64>::full(&[2, 2], 5.0);
601 assert_eq!(fives.shape(), vec![2, 2]);
602 assert_eq!(fives.to_vec(), vec![5.0, 5.0, 5.0, 5.0]);
603
604 let arr = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
606 let reshaped = arr.reshape(&[2, 3]);
607 assert_eq!(reshaped.shape(), vec![2, 3]);
608 assert_eq!(reshaped.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
609 }
610
611 #[test]
612 fn test_array_methods() {
613 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
614
615 assert_eq!(a.shape(), vec![2, 3]);
617 assert_eq!(a.ndim(), 2);
618 assert_eq!(a.size(), 6);
619
620 let at = a.transpose();
622 assert_eq!(at.shape(), vec![3, 2]);
623
624 let at_vec = at.to_vec();
627 assert_eq!(at_vec.len(), 6);
628 assert!(at_vec.contains(&1.0));
629 assert!(at_vec.contains(&2.0));
630 assert!(at_vec.contains(&3.0));
631 assert!(at_vec.contains(&4.0));
632 assert!(at_vec.contains(&5.0));
633 assert!(at_vec.contains(&6.0));
634
635 let slice = a.slice(0, 1).unwrap();
637 assert_eq!(slice.shape(), vec![3]);
638 assert_eq!(slice.to_vec(), vec![4.0, 5.0, 6.0]);
639 }
640
641 #[test]
642 fn test_map_operations() {
643 let a = Array::<f64>::from_vec(vec![1.0, 4.0, 9.0, 16.0]);
644
645 let sqrt_a = a.map(|x| x.sqrt());
647 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
648 assert_relative_eq!(sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
649 assert_relative_eq!(sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
650 assert_relative_eq!(sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
651
652 let par_sqrt_a = a.par_map(|x| x.sqrt());
654 assert_relative_eq!(par_sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
655 assert_relative_eq!(par_sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
656 assert_relative_eq!(par_sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
657 assert_relative_eq!(par_sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
658 }
659
660 #[cfg(feature = "lapack")]
661 #[test]
662 fn test_linalg_ops() {
663 let a = Array::<f64>::from_vec(vec![4.0, 7.0, 2.0, 6.0]).reshape(&[2, 2]);
665
666 let det_a = det(&a).unwrap();
668 assert_relative_eq!(det_a, 10.0, epsilon = 1e-10);
669
670 let inv_a = inv(&a).unwrap();
672 let expected_inv = [0.6, -0.7, -0.2, 0.4];
673 for (actual, expected) in inv_a.to_vec().iter().zip(expected_inv.iter()) {
674 assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
675 }
676
677 let identity = a.matmul(&inv_a).unwrap();
679 assert_relative_eq!(identity.to_vec()[0], 1.0, epsilon = 1e-10);
680 assert_relative_eq!(identity.to_vec()[1], 0.0, epsilon = 1e-10);
681 assert_relative_eq!(identity.to_vec()[2], 0.0, epsilon = 1e-10);
682 assert_relative_eq!(identity.to_vec()[3], 1.0, epsilon = 1e-10);
683
684 let b = Array::<f64>::from_vec(vec![1.0, 3.0]);
686 let x = solve(&a, &b).unwrap();
687
688 assert_relative_eq!(x.to_vec()[0], -1.5, epsilon = 1e-10);
690 assert_relative_eq!(x.to_vec()[1], 1.0, epsilon = 1e-10);
691
692 let b_check = match a.matmul(&x.reshape(&[2, 1])) {
694 Ok(result) => result.reshape(&[2]),
695 Err(_) => panic!("Matrix multiplication failed"),
696 };
697 assert_relative_eq!(b_check.to_vec()[0], b.to_vec()[0], epsilon = 1e-10);
698 assert_relative_eq!(b_check.to_vec()[1], b.to_vec()[1], epsilon = 1e-10);
699 }
700
701 #[test]
702 fn test_tensor_operations() {
703 let a = Array::<f64>::from_vec(vec![1.0, 2.0]).reshape(&[1, 2]);
705 let b = Array::<f64>::from_vec(vec![3.0, 4.0]).reshape(&[2, 1]);
706
707 let kron_result = kron(&a, &b).unwrap();
708 assert_eq!(kron_result.shape(), &[2, 2]);
709 assert_eq!(kron_result.to_vec(), vec![3.0, 6.0, 4.0, 8.0]);
710
711 let tensordot_result = tensordot(&a, &b, &[1, 0]).unwrap();
713 assert_eq!(tensordot_result.shape(), &[1, 1]);
714 assert_relative_eq!(tensordot_result.to_vec()[0], 11.0, epsilon = 1e-10);
715 }
716
717 #[test]
718 fn test_matrix_operations() {
719 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
721 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
722
723 let c = a.matmul(&b).unwrap();
725 assert_eq!(c.shape(), vec![2, 2]);
726 assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
727
728 let v = Array::<f64>::from_vec(vec![1.0, 2.0]);
730 let result = a.matmul(&v.reshape(&[2, 1])).unwrap().reshape(&[2]);
731 assert_eq!(result.to_vec(), vec![5.0, 11.0]);
732 }
733
734 #[test]
735 fn test_simd_operations() {
736 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
737 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]);
738
739 let c = simd_add(&a, &b).unwrap();
741 assert_eq!(c.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
742
743 let d = simd_mul(&a, &b).unwrap();
745 assert_eq!(d.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
746
747 let e = simd_div(&a, &b).unwrap();
749 assert_relative_eq!(e.to_vec()[0], 0.2, epsilon = 1e-10);
750 assert_relative_eq!(e.to_vec()[1], 1.0 / 3.0, epsilon = 1e-10);
751 assert_relative_eq!(e.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
752 assert_relative_eq!(e.to_vec()[3], 0.5, epsilon = 1e-10);
753
754 let sqrt_a = simd_sqrt(&a);
756 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
757 assert_relative_eq!(
758 sqrt_a.to_vec()[1],
759 std::f64::consts::SQRT_2,
760 epsilon = 1e-10
761 );
762 assert_relative_eq!(sqrt_a.to_vec()[2], 1.7320508075688772, epsilon = 1e-10);
763 assert_relative_eq!(sqrt_a.to_vec()[3], 2.0, epsilon = 1e-10);
764
765 assert_eq!(simd_sum(&a), 10.0);
767 assert_eq!(simd_prod(&a), 24.0);
768 }
769
770 #[test]
771 fn test_norm_functions() {
772 let v = Array::<f64>::from_vec(vec![3.0, 4.0]);
774
775 let norm_1 = norm(&v, Some(1.0)).unwrap();
777 assert_relative_eq!(norm_1, 7.0, epsilon = 1e-10);
778
779 let norm_2 = norm(&v, Some(2.0)).unwrap();
781 assert_relative_eq!(norm_2, 5.0, epsilon = 1e-10);
782
783 let norm_inf = norm(&v, Some(f64::INFINITY)).unwrap();
785 assert_relative_eq!(norm_inf, 4.0, epsilon = 1e-10);
786
787 let m = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
789
790 let matrix_norm_1 = norm(&m, Some(1.0)).unwrap();
792 assert_relative_eq!(matrix_norm_1, 6.0, epsilon = 1e-10);
793
794 let matrix_norm_inf = norm(&m, Some(f64::INFINITY)).unwrap();
796 assert_relative_eq!(matrix_norm_inf, 7.0, epsilon = 1e-10);
797 }
798
799 #[test]
800 fn test_math_operations() {
801 use crate::math::*;
802
803 let a = Array::<f64>::from_vec(vec![1.0, 4.0, 9.0, 16.0]);
805
806 let neg_a = a.map(|x| -x);
808 let abs_a = neg_a.abs();
809 for (expected, actual) in a.to_vec().iter().zip(abs_a.to_vec().iter()) {
810 assert_relative_eq!(*expected, *actual, epsilon = 1e-10);
811 }
812
813 let exp_a = a.exp();
815 assert_relative_eq!(exp_a.to_vec()[0], 1.0_f64.exp(), epsilon = 1e-10);
816 assert_relative_eq!(exp_a.to_vec()[1], 4.0_f64.exp(), epsilon = 1e-10);
817 assert_relative_eq!(exp_a.to_vec()[2], 9.0_f64.exp(), epsilon = 1e-10);
818 assert_relative_eq!(exp_a.to_vec()[3], 16.0_f64.exp(), epsilon = 1e-10);
819
820 let log_a = a.log();
822 assert_relative_eq!(log_a.to_vec()[0], 1.0_f64.ln(), epsilon = 1e-10);
823 assert_relative_eq!(log_a.to_vec()[1], 4.0_f64.ln(), epsilon = 1e-10);
824 assert_relative_eq!(log_a.to_vec()[2], 9.0_f64.ln(), epsilon = 1e-10);
825 assert_relative_eq!(log_a.to_vec()[3], 16.0_f64.ln(), epsilon = 1e-10);
826
827 let sqrt_a = a.sqrt();
829 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
830 assert_relative_eq!(sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
831 assert_relative_eq!(sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
832 assert_relative_eq!(sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
833
834 let pow_a = a.pow(2.0);
836 assert_relative_eq!(pow_a.to_vec()[0], 1.0, epsilon = 1e-10);
837 assert_relative_eq!(pow_a.to_vec()[1], 16.0, epsilon = 1e-10);
838 assert_relative_eq!(pow_a.to_vec()[2], 81.0, epsilon = 1e-10);
839 assert_relative_eq!(pow_a.to_vec()[3], 256.0, epsilon = 1e-10);
840
841 let angles = Array::<f64>::from_vec(vec![
843 0.0,
844 std::f64::consts::PI / 6.0,
845 std::f64::consts::PI / 4.0,
846 std::f64::consts::PI / 3.0,
847 ]);
848
849 let sin_angles = angles.sin();
850 assert_relative_eq!(sin_angles.to_vec()[0], 0.0, epsilon = 1e-10);
851 assert_relative_eq!(sin_angles.to_vec()[1], 0.5, epsilon = 1e-10);
852 assert_relative_eq!(
853 sin_angles.to_vec()[2],
854 1.0 / std::f64::consts::SQRT_2,
855 epsilon = 1e-10
856 );
857 assert_relative_eq!(sin_angles.to_vec()[3], 0.8660254037844386, epsilon = 1e-10);
858
859 let cos_angles = angles.cos();
860 assert_relative_eq!(cos_angles.to_vec()[0], 1.0, epsilon = 1e-10);
861 assert_relative_eq!(cos_angles.to_vec()[1], 0.8660254037844386, epsilon = 1e-10);
862 assert_relative_eq!(
863 cos_angles.to_vec()[2],
864 1.0 / std::f64::consts::SQRT_2,
865 epsilon = 1e-10
866 );
867 assert_relative_eq!(cos_angles.to_vec()[3], 0.5, epsilon = 1e-10);
868
869 let lin = linspace(0.0, 10.0, 6);
871 assert_eq!(lin.size(), 6);
872 assert_relative_eq!(lin.to_vec()[0], 0.0, epsilon = 1e-10);
873 assert_relative_eq!(lin.to_vec()[1], 2.0, epsilon = 1e-10);
874 assert_relative_eq!(lin.to_vec()[2], 4.0, epsilon = 1e-10);
875 assert_relative_eq!(lin.to_vec()[3], 6.0, epsilon = 1e-10);
876 assert_relative_eq!(lin.to_vec()[4], 8.0, epsilon = 1e-10);
877 assert_relative_eq!(lin.to_vec()[5], 10.0, epsilon = 1e-10);
878
879 let range = arange(0.0, 5.0, 1.0);
881 assert_eq!(range.size(), 5);
882 assert_eq!(range.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
883
884 let rev_range = arange(5.0, 0.0, -1.0);
886 assert_eq!(rev_range.size(), 5);
887 assert_eq!(rev_range.to_vec(), vec![5.0, 4.0, 3.0, 2.0, 1.0]);
888
889 let log_space = logspace(0.0, 3.0, 4, None);
891 assert_eq!(log_space.size(), 4);
892 assert_relative_eq!(log_space.to_vec()[0], 1.0, epsilon = 1e-10);
893 assert_relative_eq!(log_space.to_vec()[1], 10.0, epsilon = 1e-10);
894 assert_relative_eq!(log_space.to_vec()[2], 100.0, epsilon = 1e-10);
895 assert_relative_eq!(log_space.to_vec()[3], 1000.0, epsilon = 1e-10);
896
897 let geom_space = geomspace(1.0, 1000.0, 4);
899 assert_eq!(geom_space.size(), 4);
900 assert_relative_eq!(geom_space.to_vec()[0], 1.0, epsilon = 1e-10);
901 assert_relative_eq!(geom_space.to_vec()[1], 10.0, epsilon = 1e-10);
902 assert_relative_eq!(geom_space.to_vec()[2], 100.0, epsilon = 1e-10);
903 assert_relative_eq!(geom_space.to_vec()[3], 1000.0, epsilon = 1e-10);
904 }
905
906 #[test]
907 fn test_array_operations() {
908 use crate::array_ops::*;
909
910 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
912 let tiled = tile(&a, &[2]).unwrap();
913 assert_eq!(tiled.shape(), vec![6]);
914 assert_eq!(tiled.to_vec(), vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
915
916 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
917 let tiled_2d = tile(&a_2d, &[2, 1]).unwrap();
918 assert_eq!(tiled_2d.shape(), vec![4, 2]);
919 assert_eq!(
920 tiled_2d.to_vec(),
921 vec![1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]
922 );
923
924 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
926 let repeated = repeat(&a, 2, None).unwrap();
927 assert_eq!(repeated.shape(), vec![6]);
928 assert_eq!(repeated.to_vec(), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
929
930 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
931 let repeated_axis0 = repeat(&a_2d, 2, Some(0)).unwrap();
932 assert_eq!(repeated_axis0.shape(), vec![4, 2]);
933 assert_eq!(
934 repeated_axis0.to_vec(),
935 vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]
936 );
937
938 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
940 let b = Array::<f64>::from_vec(vec![4.0, 5.0, 6.0]);
941 let c = concatenate(&[&a, &b], 0).unwrap();
942 assert_eq!(c.shape(), vec![6]);
943 assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
944
945 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
946 let b_2d = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
947 let c_axis0 = concatenate(&[&a_2d, &b_2d], 0).unwrap();
948 assert_eq!(c_axis0.shape(), vec![4, 2]);
949 assert_eq!(
950 c_axis0.to_vec(),
951 vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
952 );
953
954 let c_axis1 = concatenate(&[&a_2d, &b_2d], 1).unwrap();
955 assert_eq!(c_axis1.shape(), vec![2, 4]);
956 let c_vec = c_axis1.to_vec();
957 assert_eq!(c_vec.len(), 8);
959 assert!(c_vec.contains(&1.0));
960 assert!(c_vec.contains(&2.0));
961 assert!(c_vec.contains(&3.0));
962 assert!(c_vec.contains(&4.0));
963 assert!(c_vec.contains(&5.0));
964 assert!(c_vec.contains(&6.0));
965 assert!(c_vec.contains(&7.0));
966 assert!(c_vec.contains(&8.0));
967
968 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
970 let b = Array::<f64>::from_vec(vec![4.0, 5.0, 6.0]);
971 let c = stack(&[&a, &b], 0).unwrap();
972 assert_eq!(c.shape(), vec![2, 3]);
973 let c_vec = c.to_vec();
974 assert_eq!(c_vec.len(), 6);
976 assert!(c_vec.contains(&1.0));
977 assert!(c_vec.contains(&2.0));
978 assert!(c_vec.contains(&3.0));
979 assert!(c_vec.contains(&4.0));
980 assert!(c_vec.contains(&5.0));
981 assert!(c_vec.contains(&6.0));
982
983 let d = stack(&[&a, &b], 1).unwrap();
984 assert_eq!(d.shape(), vec![3, 2]);
985 let d_vec = d.to_vec();
986 assert_eq!(d_vec.len(), 6);
988 assert!(d_vec.contains(&1.0));
989 assert!(d_vec.contains(&2.0));
990 assert!(d_vec.contains(&3.0));
991 assert!(d_vec.contains(&4.0));
992 assert!(d_vec.contains(&5.0));
993 assert!(d_vec.contains(&6.0));
994
995 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
997 let splits = split(&a, &[2, 4], 0).unwrap();
998 assert_eq!(splits.len(), 3);
999 assert_eq!(splits[0].to_vec(), vec![1.0, 2.0]);
1000 assert_eq!(splits[1].to_vec(), vec![3.0, 4.0]);
1001 assert_eq!(splits[2].to_vec(), vec![5.0, 6.0]);
1002
1003 let _a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
1004 let splits_a = split(&a, &[2, 4], 0).unwrap();
1006 assert_eq!(splits_a.len(), 3);
1007
1008 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
1017 let expanded = expand_dims(&a, 0).unwrap();
1018 assert_eq!(expanded.shape(), vec![1, 3]);
1019 assert_eq!(expanded.to_vec(), vec![1.0, 2.0, 3.0]);
1020
1021 let expanded_end = expand_dims(&a, 1).unwrap();
1022 assert_eq!(expanded_end.shape(), vec![3, 1]);
1023 assert_eq!(expanded_end.to_vec(), vec![1.0, 2.0, 3.0]);
1024
1025 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[1, 3, 1]);
1027 let squeezed = squeeze(&a, None).unwrap();
1028 assert_eq!(squeezed.shape(), vec![3]);
1029 assert_eq!(squeezed.to_vec(), vec![1.0, 2.0, 3.0]);
1030
1031 let squeezed_axis = squeeze(&a, Some(0)).unwrap();
1032 assert_eq!(squeezed_axis.shape(), vec![3, 1]);
1033 assert_eq!(squeezed_axis.to_vec(), vec![1.0, 2.0, 3.0]);
1034 }
1035
1036 #[test]
1037 fn test_statistics_functions() {
1038 use crate::stats::*;
1039
1040 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1042
1043 assert_relative_eq!(a.mean(), 3.0, epsilon = 1e-10);
1045
1046 assert_relative_eq!(a.var(), 2.0, epsilon = 1e-10);
1048
1049 assert_relative_eq!(a.std(), std::f64::consts::SQRT_2, epsilon = 1e-10);
1051
1052 assert_relative_eq!(a.min(), 1.0, epsilon = 1e-10);
1054 assert_relative_eq!(a.max(), 5.0, epsilon = 1e-10);
1055
1056 assert_relative_eq!(a.percentile(0.0), 1.0, epsilon = 1e-10);
1058 assert_relative_eq!(a.percentile(0.5), 3.0, epsilon = 1e-10);
1059 assert_relative_eq!(a.percentile(1.0), 5.0, epsilon = 1e-10);
1060 assert_relative_eq!(a.percentile(0.25), 2.0, epsilon = 1e-10);
1061 assert_relative_eq!(a.percentile(0.75), 4.0, epsilon = 1e-10);
1062
1063 let b = Array::<f64>::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);
1065 let cov_result = cov(&a, Some(&b), None, None, None).unwrap();
1066 assert_relative_eq!(cov_result.get(&[0, 1]).unwrap(), -2.5, epsilon = 1e-10);
1067 let corrcoef_result = corrcoef(&a, Some(&b), None).unwrap();
1068 assert_relative_eq!(corrcoef_result.get(&[0, 1]).unwrap(), -1.0, epsilon = 1e-10);
1069
1070 let data = Array::<f64>::from_vec(vec![1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]);
1072 let (counts, bins) = histogram(&data, 4, None, None).unwrap();
1073 assert_eq!(counts.to_vec(), vec![2.0, 2.0, 2.0, 3.0]);
1074 assert_eq!(bins.size(), 5);
1075 assert_relative_eq!(bins.to_vec()[0], 1.0, epsilon = 1e-10);
1076 assert_relative_eq!(bins.to_vec()[4], 5.0, epsilon = 1e-10);
1077 }
1078
1079 #[test]
1080 fn test_boolean_indexing() {
1081 use crate::indexing::*;
1082
1083 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
1085
1086 let mask = vec![true, false, true, false, true];
1088
1089 let _bool_array = Array::<bool>::from_vec(mask.clone());
1092
1093 let mut filtered = Array::<f64>::zeros(&[5]);
1095 let values = Array::<f64>::from_vec(vec![1.0, 3.0, 5.0]);
1096
1097 let mut value_idx = 0;
1099 for (i, &m) in mask.iter().enumerate() {
1100 if m {
1101 filtered
1102 .set(&[i], values.get(&[value_idx]).unwrap())
1103 .unwrap();
1104 value_idx += 1;
1105 }
1106 }
1107
1108 assert_eq!(filtered.to_vec(), vec![1.0, 0.0, 3.0, 0.0, 5.0]);
1110
1111 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
1113 .reshape(&[3, 3]);
1114
1115 let _row_indices = [0]; let _col_indices = [0]; let row_result = a_2d.index(&[IndexSpec::Index(0), IndexSpec::All]).unwrap();
1121 assert_eq!(row_result.shape(), vec![3]); let row_vec = row_result.to_vec();
1125 assert_eq!(row_vec.len(), 3);
1126 assert_eq!(row_vec, vec![1.0, 2.0, 3.0]);
1127
1128 let col_result = a_2d.index(&[IndexSpec::All, IndexSpec::Index(0)]).unwrap();
1129 assert_eq!(col_result.shape(), vec![3]); assert_eq!(col_result.to_vec(), vec![1.0, 4.0, 7.0]);
1131
1132 let mut a_copy = a.clone();
1134 a_copy
1135 .set_mask(
1136 &Array::<bool>::from_vec(vec![true, false, true, false, true]),
1137 &Array::<f64>::from_vec(vec![10.0, 30.0, 50.0]),
1138 )
1139 .unwrap();
1140
1141 assert_eq!(a_copy.to_vec(), vec![10.0, 2.0, 30.0, 4.0, 50.0]);
1142 }
1143
1144 #[test]
1145 fn test_fancy_indexing() {
1146 use crate::indexing::*;
1147
1148 let _a = Array::<f64>::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
1150
1151 let _indices = [0, 1, 2];
1154 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
1158 .reshape(&[3, 3]);
1159
1160 let single_element = a_2d
1162 .index(&[IndexSpec::Index(1), IndexSpec::Index(1)])
1163 .unwrap();
1164 assert_eq!(single_element.to_vec(), vec![5.0]);
1165
1166 let slice_result = a_2d
1168 .index(&[IndexSpec::Slice(0, Some(2), None), IndexSpec::All])
1169 .unwrap();
1170 assert_eq!(slice_result.shape(), vec![2, 3]);
1171 assert_eq!(slice_result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1172 }
1173
1174 #[test]
1175 fn test_axis_operations() {
1176 use crate::axis_ops::*;
1177
1178 let mut array = Array::<f64>::zeros(&[2, 3]);
1180 array.set(&[0, 0], 1.0).unwrap();
1181 array.set(&[0, 1], 2.0).unwrap();
1182 array.set(&[0, 2], 3.0).unwrap();
1183 array.set(&[1, 0], 4.0).unwrap();
1184 array.set(&[1, 1], 5.0).unwrap();
1185 array.set(&[1, 2], 6.0).unwrap();
1186
1187 let sum_axis0 = array.sum_axis(0).unwrap();
1189 assert_eq!(sum_axis0.shape(), vec![3]);
1190 assert_eq!(sum_axis0.to_vec(), vec![5.0, 7.0, 9.0]);
1191
1192 let sum_axis1 = array.sum_axis(1).unwrap();
1194 assert_eq!(sum_axis1.shape(), vec![2]);
1195 assert_eq!(sum_axis1.to_vec(), vec![6.0, 15.0]);
1196
1197 let mean_axis0 = array.mean_axis(Some(0)).unwrap();
1199 assert_eq!(mean_axis0.shape(), vec![3]);
1200 assert_eq!(mean_axis0.to_vec(), vec![2.5, 3.5, 4.5]);
1201
1202 let mean_axis1 = array.mean_axis(Some(1)).unwrap();
1204 assert_eq!(mean_axis1.shape(), vec![2]);
1205 assert_eq!(mean_axis1.to_vec(), vec![2.0, 5.0]);
1206
1207 let min_axis0 = array.min_axis(Some(0)).unwrap();
1210 assert_eq!(min_axis0.shape(), vec![3]);
1211 let min_axis0_vec = min_axis0.to_vec();
1213 assert_eq!(min_axis0_vec, vec![1.0, 2.0, 3.0]);
1214
1215 let min_axis1 = array.min_axis(Some(1)).unwrap();
1217 assert_eq!(min_axis1.shape(), vec![2]);
1218 assert_eq!(min_axis1.to_vec(), vec![1.0, 4.0]);
1220
1221 let max_axis1 = array.max_axis(Some(1)).unwrap();
1223 assert_eq!(max_axis1.shape(), vec![2]);
1224 assert_eq!(max_axis1.to_vec(), vec![3.0, 6.0]);
1226
1227 let mut array2 = Array::<f64>::zeros(&[2, 3]);
1229 array2.set(&[0, 0], 3.0).unwrap();
1230 array2.set(&[0, 1], 2.0).unwrap();
1231 array2.set(&[0, 2], 1.0).unwrap();
1232 array2.set(&[1, 0], 0.0).unwrap();
1233 array2.set(&[1, 1], 5.0).unwrap();
1234 array2.set(&[1, 2], 6.0).unwrap();
1235
1236 let argmin_axis0 = array2.argmin_axis(0).unwrap();
1238 assert_eq!(argmin_axis0.shape(), vec![3]);
1239 assert_eq!(argmin_axis0.to_vec(), vec![1, 0, 0]);
1240
1241 let var_axis0 = array.var_axis(Some(0)).unwrap();
1255 assert_eq!(var_axis0.shape(), vec![3]);
1256 assert_relative_eq!(var_axis0.get(&[0]).unwrap(), 2.25, epsilon = 1e-10);
1257
1258 let std_axis1 = array.std_axis(Some(1)).unwrap();
1260 assert_eq!(std_axis1.shape(), vec![2]);
1261
1262 let std_row1 = std_axis1.get(&[0]).unwrap();
1265 assert!(
1266 std_row1 > 0.8 && std_row1 < 1.1,
1267 "std_row1 ({}) should be approximately 1.0 or 0.82",
1268 std_row1
1269 );
1270
1271 let std_row2 = std_axis1.get(&[1]).unwrap();
1272 assert!(
1273 std_row2 > 0.8 && std_row2 < 1.1,
1274 "std_row2 ({}) should be approximately 1.0 or 0.82",
1275 std_row2
1276 );
1277 }
1278
1279 #[test]
1280 fn test_views_and_strides() {
1281 use crate::views::SliceOrIndex;
1282
1283 let mut a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
1285 .reshape(&[3, 3]);
1286
1287 let view = a.view();
1289 assert_eq!(view.shape(), vec![3, 3]);
1290
1291 let mut view_mut = a.view_mut();
1293 view_mut.set(&[0, 0], 10.0).unwrap();
1294 assert_eq!(a.get(&[0, 0]).unwrap(), 10.0);
1295
1296 a.set(&[0, 0], 1.0).unwrap();
1298
1299 let strided = a.strided_view(&[2, 2]).unwrap();
1301 assert_eq!(strided.shape(), vec![2, 2]);
1302 let flat_data = strided.to_vec();
1303 assert!(flat_data.contains(&1.0));
1304 assert!(flat_data.contains(&3.0));
1305 assert!(flat_data.contains(&7.0));
1306 assert!(flat_data.contains(&9.0));
1307
1308 let slices = vec![
1310 SliceOrIndex::Slice(0, Some(2), None),
1311 SliceOrIndex::Slice(0, Some(2), None),
1312 ];
1313 let sliced = a.sliced_view(&slices).unwrap();
1314 assert_eq!(sliced.shape(), vec![2, 2]);
1315 assert_eq!(sliced.to_vec(), vec![1.0, 2.0, 4.0, 5.0]);
1316
1317 let transposed = a.transposed_view();
1319 assert_eq!(transposed.shape(), vec![3, 3]);
1320 let _t_flat = transposed.to_vec();
1321 assert_eq!(transposed.get(&[0, 1]).unwrap(), 4.0);
1323 assert_eq!(transposed.get(&[1, 0]).unwrap(), 2.0);
1324
1325 let broadcast = a.broadcast_view(&[3, 3, 3]).unwrap();
1327 assert_eq!(broadcast.shape(), vec![3, 3, 3]);
1328 assert_eq!(broadcast.get(&[0, 0, 0]).unwrap(), 1.0);
1329 assert_eq!(broadcast.get(&[1, 0, 0]).unwrap(), 1.0);
1330 }
1331
1332 #[test]
1333 fn test_universal_functions() {
1334 use crate::ufuncs::*;
1335
1336 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
1338 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]);
1339
1340 let result = add(&a, &b).unwrap();
1342 assert_eq!(result.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
1343
1344 let result = subtract(&a, &b).unwrap();
1345 assert_eq!(result.to_vec(), vec![-4.0, -4.0, -4.0, -4.0]);
1346
1347 let result = multiply(&a, &b).unwrap();
1348 assert_eq!(result.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
1349
1350 let result = divide(&a, &b).unwrap();
1351 assert_relative_eq!(result.to_vec()[0], 0.2, epsilon = 1e-10);
1352 assert_relative_eq!(result.to_vec()[1], 1.0 / 3.0, epsilon = 1e-10);
1353 assert_relative_eq!(result.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
1354 assert_relative_eq!(result.to_vec()[3], 0.5, epsilon = 1e-10);
1355
1356 let result = power(&a, &b).unwrap();
1357 assert_relative_eq!(result.to_vec()[0], 1.0, epsilon = 1e-10);
1358 assert_relative_eq!(result.to_vec()[1], 64.0, epsilon = 1e-10);
1359 assert_relative_eq!(result.to_vec()[2], 2187.0, epsilon = 1e-10);
1360 assert_relative_eq!(result.to_vec()[3], 65536.0, epsilon = 1e-10);
1361
1362 let result = square(&a);
1364 assert_eq!(result.to_vec(), vec![1.0, 4.0, 9.0, 16.0]);
1365
1366 let result = sqrt(&a);
1367 assert_relative_eq!(result.to_vec()[0], 1.0, epsilon = 1e-10);
1368 assert_relative_eq!(
1369 result.to_vec()[1],
1370 std::f64::consts::SQRT_2,
1371 epsilon = 1e-10
1372 );
1373 assert_relative_eq!(result.to_vec()[2], 1.7320508075688772, epsilon = 1e-10);
1374 assert_relative_eq!(result.to_vec()[3], 2.0, epsilon = 1e-10);
1375
1376 let result = exp(&a);
1377 assert_relative_eq!(result.to_vec()[0], 1.0_f64.exp(), epsilon = 1e-10);
1378 assert_relative_eq!(result.to_vec()[1], 2.0_f64.exp(), epsilon = 1e-10);
1379 assert_relative_eq!(result.to_vec()[2], 3.0_f64.exp(), epsilon = 1e-10);
1380 assert_relative_eq!(result.to_vec()[3], 4.0_f64.exp(), epsilon = 1e-10);
1381
1382 let result = log(&a);
1383 assert_relative_eq!(result.to_vec()[0], 1.0_f64.ln(), epsilon = 1e-10);
1384 assert_relative_eq!(result.to_vec()[1], 2.0_f64.ln(), epsilon = 1e-10);
1385 assert_relative_eq!(result.to_vec()[2], 3.0_f64.ln(), epsilon = 1e-10);
1386 assert_relative_eq!(result.to_vec()[3], 4.0_f64.ln(), epsilon = 1e-10);
1387
1388 let result = multiply_scalar(&a, 2.0);
1390 assert_eq!(result.to_vec(), vec![2.0, 4.0, 6.0, 8.0]);
1391
1392 let row = Array::<f64>::from_vec(vec![10.0, 20.0]).reshape(&[1, 2]);
1394 let col = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[3, 1]);
1395 let result = add(&row, &col).unwrap();
1396 assert_eq!(result.shape(), vec![3, 2]);
1397 assert_eq!(result.to_vec(), vec![11.0, 21.0, 12.0, 22.0, 13.0, 23.0]);
1398 }
1399}