1#![allow(deprecated)] #![allow(clippy::result_large_err)] #![allow(clippy::needless_range_loop)] #![allow(clippy::too_many_arguments)] #![allow(clippy::identity_op)] #![allow(clippy::approx_constant)] #![allow(clippy::excessive_precision)] pub mod algorithms;
69pub mod array;
70pub mod array_ops;
71pub mod array_ops_legacy;
72pub mod arrays;
73#[cfg(feature = "arrow")]
74pub mod arrow;
75pub mod autodiff;
76pub mod axis_ops;
77pub mod bitwise_ops;
78pub mod blas;
79pub mod char;
80pub mod cluster;
81pub mod comparisons;
82pub mod comparisons_broadcast;
83pub mod complex_ops;
84pub mod conversions;
85pub mod derivative;
86pub mod distance;
87#[cfg(feature = "distributed")]
88pub mod distributed;
89pub mod error;
90pub mod error_handling;
91pub mod expr;
92pub mod fft;
93pub mod financial;
94#[cfg(feature = "gpu")]
95pub mod gpu;
96pub mod indexing;
97pub mod integrate;
98pub mod interop;
99pub mod interpolate;
100pub mod io;
101pub mod linalg;
102pub mod linalg_accelerated;
103pub mod linalg_extended;
104pub mod linalg_optimized;
105pub mod linalg_parallel;
106pub mod optimized_ops; pub mod linalg_stable;
109pub mod masked;
110pub mod math;
111pub mod math_extended;
112pub mod matrix;
113pub mod memory_alloc;
114pub mod memory_optimize;
115pub mod mmap;
116pub mod ndimage;
117pub mod nn;
118pub mod ode;
119pub mod optimize;
120pub mod parallel;
121pub mod parallel_optimize;
122pub mod pde;
123pub mod printing;
124#[cfg(feature = "python")]
125pub mod python;
126pub mod random;
127pub mod roots;
128pub mod set_ops;
129pub mod shared_array;
130pub mod signal;
131pub mod simd;
132pub mod simd_optimize;
133pub mod sparse;
134pub mod sparse_enhanced;
135pub mod spatial;
136pub mod special;
137pub mod stats;
138pub mod stride_tricks;
139pub mod symbolic;
140pub mod testing;
141pub mod traits;
142pub mod types;
143pub mod ufuncs;
144pub mod unique;
145pub mod unique_optimized;
146pub mod util;
147pub mod views;
148#[cfg(feature = "visualization")]
149pub mod viz;
150#[cfg(feature = "wasm")]
151pub mod wasm;
152
153pub mod new_modules;
156
157pub use error::{NumRs2Error, Result};
158
159pub use random::random_base;
161
162#[cfg(doctest)]
164pub mod doctests {}
165
166pub mod prelude {
168 pub use crate::array::Array;
169 pub use crate::array_ops::*;
170 pub use crate::array_ops_legacy::rollaxis;
172 pub use crate::axis_ops::*;
174 pub use crate::axis_ops::{apply_along_axis, apply_over_axes, vectorize};
175 pub use crate::bitwise_ops::{
176 bitwise_and, bitwise_not, bitwise_or, bitwise_xor, invert, left_shift, left_shift_scalar,
177 right_shift, right_shift_scalar,
178 };
179 pub use crate::char;
180 pub use crate::char::{array_from_strings, StringArray, StringElement};
181 pub use crate::comparisons::{
182 all, allclose, allclose_with_tol, any, array_equal, count_nonzero, equal, flatnonzero,
183 greater, greater_equal, isclose, isclose_array, less, less_equal, logical_and, logical_not,
184 logical_or, logical_xor, not_equal,
185 };
186 pub use crate::complex_ops::{
187 absolute as complex_abs, angle as complex_angle, conj as complex_conj, from_polar,
188 imag as complex_imag, iscomplex, iscomplexobj, isreal, isrealobj, real as complex_real,
189 to_complex,
190 };
191 pub use crate::conversions::*;
192 pub use crate::error::{NumRs2Error, Result};
193 pub use crate::error_handling::{
194 errstate, geterr, geterrcall, handle_error, seterr, seterrcall, ErrorAction, ErrorState,
195 ErrorStateBuilder, ErrorStateGuard, FloatingPointError,
196 };
197 pub use crate::financial::{
198 accrued_interest,
200 amortization_schedule,
202 binomial_option_price,
204 black_scholes,
205 black_scholes_greeks,
206 bond_convexity,
207 bond_duration,
208 bond_equivalent_yield,
209 bond_price,
210 bond_yield,
211 cumipmt,
213 cumprinc,
214 db,
216 ddb,
217 effect,
219 fv,
221 fv_array,
222 implied_volatility,
223 ipmt,
225 irr,
226 irr_multiple_series,
227 mirr,
228 modified_duration,
229 nominal,
230 nper,
231 nper_array,
232 npv,
233 npv_multiple_series,
234 npv_rates,
235 pmt,
236 pmt_array,
237 ppmt,
238 pv,
239 pv_array,
240 rate,
241 rate_array,
242 sln,
244 syd,
245 AmortizationSchedule,
246 };
247 pub use crate::indexing::{
249 diag_indices, diag_indices_from, extract, indices_grid, ix_, mask_indices,
250 put as indexing_put, put_along_axis, putmask as indexing_putmask, ravel_multi_index, take,
251 take_along_axis, tril_indices, tril_indices_from, triu_indices, triu_indices_from,
252 unravel_index, IndexSpec,
253 };
254 pub use crate::io::{array_to_vec2d, vec2d_to_array, vec_to_array, SerializeFormat};
255 #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
257 pub use crate::linalg::{
258 cholesky as cholesky_basic, eig, inv, qr as qr_basic, solve, svd as svd_basic,
259 };
260 #[cfg(feature = "lapack")]
261 pub use crate::linalg::{det, matrix_power};
262 pub use crate::linalg::{inner, kron, norm, outer, tensordot, trace, vdot};
263
264 #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
266 pub use crate::linalg::{matrix_rank, pinv};
267 pub use crate::linalg_extended::eigenvalue;
269 pub use crate::linalg_optimized::{lu_optimized, transpose_optimized, OptimizedBlas};
270 pub use crate::linalg_parallel::ParallelLinAlg;
271 pub use crate::linalg_stable::{
272 CholeskyStableResult, QRPivotedResult, SVDStableResult, StableDecompositions,
273 };
274 pub use crate::masked::MaskedArray;
275 pub use crate::ufuncs::{abs, ceil, exp, floor, log, round, sqrt};
277 pub use crate::math_extended::{erf, erfc, gamma, gammaln};
281 pub use crate::math::{
284 amax, amin, angle, arange, argmax, argmin, argpartition, argsort, around, bartlett,
285 bincount, blackman, clip, conj, copysign, cumprod, cumsum, cumulative_prod, cumulative_sum,
286 diff, diff_extended, digitize, divmod, ediff1d, empty, fmod, frexp, gcd, geomspace,
287 gradient, hamming, hanning, heaviside, i0, imag, interp, isfinite, isinf, isnan, kaiser,
288 lcm, ldexp, linspace, logspace, max, mean, median, min, modf, nan_to_num, nanmax, nanmean,
289 nanmin, nanstd, nansum, nanvar, nextafter, nonzero, ones, partition, prod, real,
290 real_if_close, remainder, resize, searchsorted, sinc, sort, std, sum, trapz, var, zeros,
291 ElementWiseMath,
292 };
293 pub use crate::matrix::{
294 asmatrix, matrix, matrix_from_nested, matrix_from_scalar, BandedMatrix, Matrix,
295 };
296 pub use crate::mmap::MmapArray;
297 pub use crate::random::advanced_distributions;
298 pub use crate::random::distributions;
299 pub use crate::random::generator::{default_rng, BitGenerator, Generator, StdBitGenerator};
300 pub use crate::random::{self, RandomState};
301 pub use crate::set_ops::{
302 in1d, intersect1d, isin, setdiff1d, setxor1d, union1d, unique_axis, unique_with_options,
303 };
304 pub use crate::signal::{convolve, convolve2d, correlate, correlate2d};
305 pub use crate::simd::get_simd_implementation_name;
307 pub use crate::sparse;
308 pub use crate::sparse_enhanced::SparseOpsAdvanced;
309 pub use crate::stats::{
311 average, corrcoef, cov, histogram, histogram_dd, max_along_axis, min_along_axis, mode,
312 percentile, ptp, quantile, HistBins, Statistics,
313 };
314 pub use crate::stride_tricks::{
315 as_strided, broadcast_arrays, broadcast_to, byte_strides, set_strides, sliding_window_view,
316 };
317 pub use crate::testing::{
319 arrays_close, assert_array_all_finite, assert_array_almost_equal, assert_array_equal,
320 assert_array_no_nan, assert_array_same_shape, assert_scalar_almost_equal, is_finite_array,
321 test_summary, tolerances, TestResult, ToleranceConfig,
322 };
323 pub use crate::run_tests;
325 pub use crate::traits::{
327 ArrayIndexing, ArrayMath, ArrayOps, ArrayReduction, ComplexElement, FloatingPoint,
328 IntegerElement, LinearAlgebra, MatrixDecomposition, NumericElement,
329 };
330 pub use crate::ufuncs::{
333 absolute, add, add_scalar, arctan2, cbrt, divide, divide_scalar, dot, exp2, expm1, fma,
334 hypot, log10, log1p, log2, maximum, minimum, multiply, multiply_scalar, negative, norm_l1,
335 norm_l2, power, power_scalar, reciprocal, subtract, subtract_scalar, BinaryUfunc,
336 UnaryUfunc,
337 };
338 pub use crate::unique::{unique, UniqueResult};
339 pub use crate::unique_optimized::unique_optimized;
340 pub use crate::util::{
341 astype, can_operate_inplace, fast_sum, optimize_layout, parallel_map, MemoryLayout,
342 };
343 pub use crate::views::*;
344
345 pub use crate::interop::ndarray_compat::{from_ndarray, to_ndarray};
348 pub use crate::memory_optimize::{
352 align_data, optimize_layout as memory_optimize_layout, optimize_placement,
353 AlignmentStrategy, LayoutStrategy, PlacementStrategy,
354 };
355
356 pub use crate::parallel_optimize::{
358 adaptive_threshold, optimize_parallel_computation, optimize_scheduling, partition_workload,
359 };
360 pub use crate::parallel_optimize::{
361 ParallelConfig, ParallelizationThreshold, SchedulingStrategy, WorkloadPartitioning,
362 };
363
364 pub use crate::printing::{
366 array_str, get_printoptions, reset_printoptions, set_printoptions, PrintOptions,
367 };
368
369 pub use crate::memory_alloc::{
371 get_default_allocator, get_global_allocator_strategy, init_global_allocator,
372 reset_global_allocator,
373 };
374 pub use crate::memory_alloc::{
375 AlignedAllocator, AlignmentConfig, AllocStrategy, ArenaAllocator, ArenaConfig, CacheConfig,
376 CacheLevel, CacheOptimizedAllocator, PoolAllocator, PoolConfig,
377 };
378
379 pub use crate::algorithms::{
381 BandwidthEstimate, BandwidthOptimizer, CacheAwareArrayOps, CacheAwareConvolution,
382 CacheAwareFFT, MemoryOperation,
383 };
384
385 pub use crate::parallel::parallel_algorithms::ParallelConfig as ParallelAlgorithmConfig;
387 pub use crate::parallel::{
388 global_parallel_context, initialize_parallel_context, shutdown_parallel_context, task,
389 BalancingStrategy, LoadBalancer, ParallelAllocator, ParallelAllocatorConfig,
390 ParallelArrayOps, ParallelContext, ParallelFFT, ParallelMatrixOps, ParallelScheduler,
391 SchedulerConfig, Task, TaskPriority, TaskResult, ThreadLocalAllocator, WorkStealingPool,
392 WorkloadMetrics,
393 };
394
395 pub use crate::memory_alloc::{
397 EnhancedAllocatorBridge, IntelligentAllocationStrategy, NumericalArrayAllocator,
398 };
399 pub use crate::traits::{
400 AllocationFrequency, AllocationLifetime, AllocationRequirements, AllocationStats,
401 AllocationStrategy, MemoryAllocator, MemoryAware, MemoryOptimization, MemoryUsage,
402 OptimizationType, SpecializedAllocator, ThreadingRequirements,
403 };
404
405 #[cfg(feature = "lapack")]
407 pub use crate::new_modules::eigenvalues::{eig as eig_general, eigh, eigvals, eigvalsh};
408 pub use crate::new_modules::fft::FFT;
409 #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
410 pub use crate::new_modules::matrix_decomp::{
411 cholesky, cod, condition_number, lu, pivoted_cholesky, qr, rcond, schur, svd,
412 };
413 pub use crate::new_modules::polynomial::{
414 poly, polyadd, polychebyshev, polycompanion, polycompose, polyder, polydiv, polyextrap,
415 polyfit, polyfit_weighted, polyfromroots, polygcd, polygrid2d, polyhermite, polyint,
416 polyjacobi, polylaguerre, polylegendre, polymul, polymulx, polypower, polyresidual,
417 polyscale, polysub, polytrim, polyval2d, polyvander, polyvander2d, CubicSpline, Polynomial,
418 PolynomialInterpolation,
419 };
420
421 #[cfg(feature = "lapack")]
423 pub use crate::optimized_ops::parallel_matrix_ops;
424 pub use crate::optimized_ops::{
425 adaptive_array_sum, chunked_array_processing, get_optimization_info,
426 parallel_column_statistics, should_use_parallel, simd_elementwise_ops, simd_matmul,
427 simd_vector_ops, ColumnStats, SimdOpsResult, SimdVectorResult,
428 };
429
430 #[cfg(feature = "gpu")]
432 pub use crate::gpu::{
433 add as gpu_add, divide as gpu_divide, matmul, multiply as gpu_multiply,
434 subtract as gpu_subtract, transpose, GpuArray, GpuContext,
435 };
436 pub use crate::new_modules::sparse::{SparseArray, SparseMatrix, SparseMatrixFormat};
437 pub use crate::new_modules::special::{
438 airy_ai, airy_bi, associated_legendre_p, bessel_i, bessel_j, bessel_k, bessel_y, beta,
439 betainc, digamma, ellipe, ellipeinc, ellipf, ellipk, erfcinv, erfinv, exp1, expi, fresnel,
440 gammainc, jacobi_elliptic, lambertw, lambertwm1, legendre_p, polylog, shichi, sici,
441 spherical_harmonic, struve_h, zeta,
442 };
443 pub use crate::arrays::{
447 ArrayView, BooleanCombineOp, BroadcastEngine, BroadcastOp, BroadcastReduction,
448 FancyIndexEngine, FancyIndexResult, ResolvedIndex, Shape, SpecializedIndexing,
449 };
450
451 pub use crate::types::custom::CustomDType;
453 pub use crate::types::datetime::{
454 business_days,
455 datetime64,
457 datetime_array,
458 datetime_as_string,
459 datetime_data,
460 timedelta64,
461 DateTime64,
462 DateTimeUnit,
463 DateUnit,
464 TimeDelta64,
465 Timezone,
466 TimezoneDateTime,
467 };
468 pub use crate::types::structured::{DType, Field, RecordArray, StructuredArray};
469
470 pub use crate::shared_array::{SharedArray, SharedArrayView};
472
473 pub use crate::expr::{
475 ArrayExpr,
476 BinaryExpr,
477 CSEOptimizer,
478 CSESupport,
479 CachedExpr,
481 Expr,
483 ExprBuilder,
485 ExprCache,
486 ExprId,
487 ExprKey,
488 LazyEval,
489 ScalarExpr,
490 SharedArrayExpr,
491 SharedBinaryExpr,
492 SharedExpr,
494 SharedExprBuilder,
495 SharedScalarExpr,
496 SharedUnaryExpr,
497 UnaryExpr,
498 };
499
500 pub use crate::memory_optimize::access_patterns::{
504 cache_aware_binary_op, cache_aware_copy, cache_aware_transform, detect_layout,
505 AccessPattern, AccessStats, Block, BlockedIterator, OptimizationHints, StrideOptimizer,
506 Tile2D, TiledIterator2D,
507 };
508
509 pub use scirs2_core::ndarray::{Axis, Dimension, IxDyn, ShapeBuilder};
511 pub use scirs2_core::{Complex, Complex64};
513}
514
515#[cfg(test)]
516mod tests {
517 use crate::prelude::*;
518 use crate::simd::{simd_add, simd_div, simd_mul, simd_prod, simd_sqrt, simd_sum};
519 use approx::assert_relative_eq;
520
521 #[test]
522 fn basic_array_ops() {
523 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
524 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
525
526 let c = a.add(&b);
528 assert_eq!(c.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
529
530 let d = a.subtract(&b);
532 assert_eq!(d.to_vec(), vec![-4.0, -4.0, -4.0, -4.0]);
533
534 let e = a.multiply(&b);
536 assert_eq!(e.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
537
538 let f = a.divide(&b);
540 assert_relative_eq!(f.to_vec()[0], 0.2, epsilon = 1e-10);
541 assert_relative_eq!(f.to_vec()[1], 1.0 / 3.0, epsilon = 1e-10);
542 assert_relative_eq!(f.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
543 assert_relative_eq!(f.to_vec()[3], 0.5, epsilon = 1e-10);
544 }
545
546 #[test]
547 fn test_broadcasting() {
548 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
550
551 let b = a.add_scalar(5.0);
553 assert_eq!(b.to_vec(), vec![6.0, 7.0, 8.0]);
554
555 let c = a.multiply_scalar(2.0);
557 assert_eq!(c.to_vec(), vec![2.0, 4.0, 6.0]);
558
559 let row = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[1, 3]);
561 let col = Array::<f64>::from_vec(vec![4.0, 5.0]).reshape(&[2, 1]);
562
563 let result = row
565 .add_broadcast(&col)
566 .expect("test: broadcast addition should succeed");
567 assert_eq!(result.shape(), vec![2, 3]);
568 assert_eq!(result.to_vec(), vec![5.0, 6.0, 7.0, 6.0, 7.0, 8.0]);
569
570 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
572 let b = Array::<f64>::from_vec(vec![10.0, 20.0]).reshape(&[1, 2]);
573
574 let result = a
576 .multiply_broadcast(&b)
577 .expect("test: broadcast multiplication should succeed");
578 assert_eq!(result.shape(), vec![2, 2]);
579 assert_eq!(result.to_vec(), vec![10.0, 40.0, 30.0, 80.0]);
580
581 let shape1 = vec![3, 1, 4];
583 let shape2 = vec![2, 1];
584 let broadcast_shape = Array::<f64>::broadcast_shape(&shape1, &shape2)
585 .expect("test: broadcast shape computation should succeed");
586 assert_eq!(broadcast_shape, vec![3, 2, 4]);
587 }
588
589 #[test]
590 fn test_array_creation() {
591 let zeros = Array::<f64>::zeros(&[2, 3]);
593 assert_eq!(zeros.shape(), vec![2, 3]);
594 assert_eq!(zeros.to_vec(), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
595
596 let ones = Array::<f64>::ones(&[2, 2]);
598 assert_eq!(ones.shape(), vec![2, 2]);
599 assert_eq!(ones.to_vec(), vec![1.0, 1.0, 1.0, 1.0]);
600
601 let fives = Array::<f64>::full(&[2, 2], 5.0);
603 assert_eq!(fives.shape(), vec![2, 2]);
604 assert_eq!(fives.to_vec(), vec![5.0, 5.0, 5.0, 5.0]);
605
606 let arr = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
608 let reshaped = arr.reshape(&[2, 3]);
609 assert_eq!(reshaped.shape(), vec![2, 3]);
610 assert_eq!(reshaped.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
611 }
612
613 #[test]
614 fn test_array_methods() {
615 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
616
617 assert_eq!(a.shape(), vec![2, 3]);
619 assert_eq!(a.ndim(), 2);
620 assert_eq!(a.size(), 6);
621
622 let at = a.transpose();
624 assert_eq!(at.shape(), vec![3, 2]);
625
626 let at_vec = at.to_vec();
629 assert_eq!(at_vec.len(), 6);
630 assert!(at_vec.contains(&1.0));
631 assert!(at_vec.contains(&2.0));
632 assert!(at_vec.contains(&3.0));
633 assert!(at_vec.contains(&4.0));
634 assert!(at_vec.contains(&5.0));
635 assert!(at_vec.contains(&6.0));
636
637 let slice = a
639 .slice(0, 1)
640 .expect("test: slice should succeed for valid axis");
641 assert_eq!(slice.shape(), vec![3]);
642 assert_eq!(slice.to_vec(), vec![4.0, 5.0, 6.0]);
643 }
644
645 #[test]
646 fn test_map_operations() {
647 let a = Array::<f64>::from_vec(vec![1.0, 4.0, 9.0, 16.0]);
648
649 let sqrt_a = a.map(|x| x.sqrt());
651 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
652 assert_relative_eq!(sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
653 assert_relative_eq!(sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
654 assert_relative_eq!(sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
655
656 let par_sqrt_a = a.par_map(|x| x.sqrt());
658 assert_relative_eq!(par_sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
659 assert_relative_eq!(par_sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
660 assert_relative_eq!(par_sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
661 assert_relative_eq!(par_sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
662 }
663
664 #[cfg(feature = "lapack")]
665 #[test]
666 fn test_linalg_ops() {
667 let a = Array::<f64>::from_vec(vec![4.0, 7.0, 2.0, 6.0]).reshape(&[2, 2]);
669
670 let det_a = det(&a).expect("test: determinant computation should succeed");
672 assert_relative_eq!(det_a, 10.0, epsilon = 1e-10);
673
674 let inv_a = inv(&a).expect("test: matrix inverse should succeed for invertible matrix");
676 let expected_inv = [0.6, -0.7, -0.2, 0.4];
677 for (actual, expected) in inv_a.to_vec().iter().zip(expected_inv.iter()) {
678 assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
679 }
680
681 let identity = a
683 .matmul(&inv_a)
684 .expect("test: matrix multiplication should succeed");
685 assert_relative_eq!(identity.to_vec()[0], 1.0, epsilon = 1e-10);
686 assert_relative_eq!(identity.to_vec()[1], 0.0, epsilon = 1e-10);
687 assert_relative_eq!(identity.to_vec()[2], 0.0, epsilon = 1e-10);
688 assert_relative_eq!(identity.to_vec()[3], 1.0, epsilon = 1e-10);
689
690 let b = Array::<f64>::from_vec(vec![1.0, 3.0]);
692 let x = solve(&a, &b).expect("test: linear system solve should succeed");
693
694 assert_relative_eq!(x.to_vec()[0], -1.5, epsilon = 1e-10);
696 assert_relative_eq!(x.to_vec()[1], 1.0, epsilon = 1e-10);
697
698 let b_check = a
700 .matmul(&x.reshape(&[2, 1]))
701 .expect("test: matrix-vector multiplication should succeed")
702 .reshape(&[2]);
703 assert_relative_eq!(b_check.to_vec()[0], b.to_vec()[0], epsilon = 1e-10);
704 assert_relative_eq!(b_check.to_vec()[1], b.to_vec()[1], epsilon = 1e-10);
705 }
706
707 #[test]
708 fn test_tensor_operations() {
709 let a = Array::<f64>::from_vec(vec![1.0, 2.0]).reshape(&[1, 2]);
711 let b = Array::<f64>::from_vec(vec![3.0, 4.0]).reshape(&[2, 1]);
712
713 let kron_result = kron(&a, &b).expect("test: Kronecker product should succeed");
714 assert_eq!(kron_result.shape(), &[2, 2]);
715 assert_eq!(kron_result.to_vec(), vec![3.0, 6.0, 4.0, 8.0]);
716
717 let tensordot_result = tensordot(&a, &b, &[1, 0]).expect("test: tensordot should succeed");
719 assert_eq!(tensordot_result.shape(), &[1, 1]);
720 assert_relative_eq!(tensordot_result.to_vec()[0], 11.0, epsilon = 1e-10);
721 }
722
723 #[test]
724 fn test_matrix_operations() {
725 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
727 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
728
729 let c = a
731 .matmul(&b)
732 .expect("test: matrix multiplication should succeed");
733 assert_eq!(c.shape(), vec![2, 2]);
734 assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
735
736 let v = Array::<f64>::from_vec(vec![1.0, 2.0]);
738 let result = a
739 .matmul(&v.reshape(&[2, 1]))
740 .expect("test: matrix-vector multiplication should succeed")
741 .reshape(&[2]);
742 assert_eq!(result.to_vec(), vec![5.0, 11.0]);
743 }
744
745 #[test]
746 fn test_simd_operations() {
747 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
748 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]);
749
750 let c = simd_add(&a, &b).expect("test: SIMD addition should succeed");
752 assert_eq!(c.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
753
754 let d = simd_mul(&a, &b).expect("test: SIMD multiplication should succeed");
756 assert_eq!(d.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
757
758 let e = simd_div(&a, &b).expect("test: SIMD division should succeed");
760 assert_relative_eq!(e.to_vec()[0], 0.2, epsilon = 1e-10);
761 assert_relative_eq!(e.to_vec()[1], 1.0 / 3.0, epsilon = 1e-10);
762 assert_relative_eq!(e.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
763 assert_relative_eq!(e.to_vec()[3], 0.5, epsilon = 1e-10);
764
765 let sqrt_a = simd_sqrt(&a);
767 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
768 assert_relative_eq!(
769 sqrt_a.to_vec()[1],
770 std::f64::consts::SQRT_2,
771 epsilon = 1e-10
772 );
773 assert_relative_eq!(sqrt_a.to_vec()[2], 1.7320508075688772, epsilon = 1e-10);
774 assert_relative_eq!(sqrt_a.to_vec()[3], 2.0, epsilon = 1e-10);
775
776 assert_eq!(simd_sum(&a), 10.0);
778 assert_eq!(simd_prod(&a), 24.0);
779 }
780
781 #[test]
782 fn test_norm_functions() {
783 let v = Array::<f64>::from_vec(vec![3.0, 4.0]);
785
786 let norm_1 = norm(&v, Some(1.0)).expect("test: L1 norm computation should succeed");
788 assert_relative_eq!(norm_1, 7.0, epsilon = 1e-10);
789
790 let norm_2 = norm(&v, Some(2.0)).expect("test: L2 norm computation should succeed");
792 assert_relative_eq!(norm_2, 5.0, epsilon = 1e-10);
793
794 let norm_inf =
796 norm(&v, Some(f64::INFINITY)).expect("test: infinity norm computation should succeed");
797 assert_relative_eq!(norm_inf, 4.0, epsilon = 1e-10);
798
799 let m = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
801
802 let matrix_norm_1 =
804 norm(&m, Some(1.0)).expect("test: matrix L1 norm computation should succeed");
805 assert_relative_eq!(matrix_norm_1, 6.0, epsilon = 1e-10);
806
807 let matrix_norm_inf = norm(&m, Some(f64::INFINITY))
809 .expect("test: matrix infinity norm computation should succeed");
810 assert_relative_eq!(matrix_norm_inf, 7.0, epsilon = 1e-10);
811 }
812
813 #[test]
814 fn test_math_operations() {
815 use crate::math::*;
816
817 let a = Array::<f64>::from_vec(vec![1.0, 4.0, 9.0, 16.0]);
819
820 let neg_a = a.map(|x| -x);
822 let abs_a = neg_a.abs();
823 for (expected, actual) in a.to_vec().iter().zip(abs_a.to_vec().iter()) {
824 assert_relative_eq!(*expected, *actual, epsilon = 1e-10);
825 }
826
827 let exp_a = a.exp();
829 assert_relative_eq!(exp_a.to_vec()[0], 1.0_f64.exp(), epsilon = 1e-10);
830 assert_relative_eq!(exp_a.to_vec()[1], 4.0_f64.exp(), epsilon = 1e-10);
831 assert_relative_eq!(exp_a.to_vec()[2], 9.0_f64.exp(), epsilon = 1e-10);
832 assert_relative_eq!(exp_a.to_vec()[3], 16.0_f64.exp(), epsilon = 1e-10);
833
834 let log_a = a.log();
836 assert_relative_eq!(log_a.to_vec()[0], 1.0_f64.ln(), epsilon = 1e-10);
837 assert_relative_eq!(log_a.to_vec()[1], 4.0_f64.ln(), epsilon = 1e-10);
838 assert_relative_eq!(log_a.to_vec()[2], 9.0_f64.ln(), epsilon = 1e-10);
839 assert_relative_eq!(log_a.to_vec()[3], 16.0_f64.ln(), epsilon = 1e-10);
840
841 let sqrt_a = a.sqrt();
843 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
844 assert_relative_eq!(sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
845 assert_relative_eq!(sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
846 assert_relative_eq!(sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
847
848 let pow_a = a.pow(2.0);
850 assert_relative_eq!(pow_a.to_vec()[0], 1.0, epsilon = 1e-10);
851 assert_relative_eq!(pow_a.to_vec()[1], 16.0, epsilon = 1e-10);
852 assert_relative_eq!(pow_a.to_vec()[2], 81.0, epsilon = 1e-10);
853 assert_relative_eq!(pow_a.to_vec()[3], 256.0, epsilon = 1e-10);
854
855 let angles = Array::<f64>::from_vec(vec![
857 0.0,
858 std::f64::consts::PI / 6.0,
859 std::f64::consts::PI / 4.0,
860 std::f64::consts::PI / 3.0,
861 ]);
862
863 let sin_angles = angles.sin();
864 assert_relative_eq!(sin_angles.to_vec()[0], 0.0, epsilon = 1e-10);
865 assert_relative_eq!(sin_angles.to_vec()[1], 0.5, epsilon = 1e-10);
866 assert_relative_eq!(
867 sin_angles.to_vec()[2],
868 1.0 / std::f64::consts::SQRT_2,
869 epsilon = 1e-10
870 );
871 assert_relative_eq!(sin_angles.to_vec()[3], 0.8660254037844386, epsilon = 1e-10);
872
873 let cos_angles = angles.cos();
874 assert_relative_eq!(cos_angles.to_vec()[0], 1.0, epsilon = 1e-10);
875 assert_relative_eq!(cos_angles.to_vec()[1], 0.8660254037844386, epsilon = 1e-10);
876 assert_relative_eq!(
877 cos_angles.to_vec()[2],
878 1.0 / std::f64::consts::SQRT_2,
879 epsilon = 1e-10
880 );
881 assert_relative_eq!(cos_angles.to_vec()[3], 0.5, epsilon = 1e-10);
882
883 let lin = linspace(0.0, 10.0, 6);
885 assert_eq!(lin.size(), 6);
886 assert_relative_eq!(lin.to_vec()[0], 0.0, epsilon = 1e-10);
887 assert_relative_eq!(lin.to_vec()[1], 2.0, epsilon = 1e-10);
888 assert_relative_eq!(lin.to_vec()[2], 4.0, epsilon = 1e-10);
889 assert_relative_eq!(lin.to_vec()[3], 6.0, epsilon = 1e-10);
890 assert_relative_eq!(lin.to_vec()[4], 8.0, epsilon = 1e-10);
891 assert_relative_eq!(lin.to_vec()[5], 10.0, epsilon = 1e-10);
892
893 let range = arange(0.0, 5.0, 1.0);
895 assert_eq!(range.size(), 5);
896 assert_eq!(range.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
897
898 let rev_range = arange(5.0, 0.0, -1.0);
900 assert_eq!(rev_range.size(), 5);
901 assert_eq!(rev_range.to_vec(), vec![5.0, 4.0, 3.0, 2.0, 1.0]);
902
903 let log_space = logspace(0.0, 3.0, 4, None);
905 assert_eq!(log_space.size(), 4);
906 assert_relative_eq!(log_space.to_vec()[0], 1.0, epsilon = 1e-10);
907 assert_relative_eq!(log_space.to_vec()[1], 10.0, epsilon = 1e-10);
908 assert_relative_eq!(log_space.to_vec()[2], 100.0, epsilon = 1e-10);
909 assert_relative_eq!(log_space.to_vec()[3], 1000.0, epsilon = 1e-10);
910
911 let geom_space = geomspace(1.0, 1000.0, 4);
913 assert_eq!(geom_space.size(), 4);
914 assert_relative_eq!(geom_space.to_vec()[0], 1.0, epsilon = 1e-10);
915 assert_relative_eq!(geom_space.to_vec()[1], 10.0, epsilon = 1e-10);
916 assert_relative_eq!(geom_space.to_vec()[2], 100.0, epsilon = 1e-10);
917 assert_relative_eq!(geom_space.to_vec()[3], 1000.0, epsilon = 1e-10);
918 }
919
920 #[test]
921 fn test_array_operations() {
922 use crate::array_ops::*;
923
924 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
926 let tiled = tile(&a, &[2]).expect("test: tile operation should succeed");
927 assert_eq!(tiled.shape(), vec![6]);
928 assert_eq!(tiled.to_vec(), vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
929
930 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
931 let tiled_2d = tile(&a_2d, &[2, 1]).expect("test: 2D tile operation should succeed");
932 assert_eq!(tiled_2d.shape(), vec![4, 2]);
933 assert_eq!(
934 tiled_2d.to_vec(),
935 vec![1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]
936 );
937
938 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
940 let repeated = repeat(&a, 2, None).expect("test: repeat operation should succeed");
941 assert_eq!(repeated.shape(), vec![6]);
942 assert_eq!(repeated.to_vec(), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
943
944 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
945 let repeated_axis0 =
946 repeat(&a_2d, 2, Some(0)).expect("test: repeat along axis 0 should succeed");
947 assert_eq!(repeated_axis0.shape(), vec![4, 2]);
948 assert_eq!(
949 repeated_axis0.to_vec(),
950 vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]
951 );
952
953 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
955 let b = Array::<f64>::from_vec(vec![4.0, 5.0, 6.0]);
956 let c = concatenate(&[&a, &b], 0).expect("test: concatenate should succeed");
957 assert_eq!(c.shape(), vec![6]);
958 assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
959
960 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
961 let b_2d = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
962 let c_axis0 =
963 concatenate(&[&a_2d, &b_2d], 0).expect("test: concatenate along axis 0 should succeed");
964 assert_eq!(c_axis0.shape(), vec![4, 2]);
965 assert_eq!(
966 c_axis0.to_vec(),
967 vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
968 );
969
970 let c_axis1 =
971 concatenate(&[&a_2d, &b_2d], 1).expect("test: concatenate along axis 1 should succeed");
972 assert_eq!(c_axis1.shape(), vec![2, 4]);
973 let c_vec = c_axis1.to_vec();
974 assert_eq!(c_vec.len(), 8);
976 assert!(c_vec.contains(&1.0));
977 assert!(c_vec.contains(&2.0));
978 assert!(c_vec.contains(&3.0));
979 assert!(c_vec.contains(&4.0));
980 assert!(c_vec.contains(&5.0));
981 assert!(c_vec.contains(&6.0));
982 assert!(c_vec.contains(&7.0));
983 assert!(c_vec.contains(&8.0));
984
985 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
987 let b = Array::<f64>::from_vec(vec![4.0, 5.0, 6.0]);
988 let c = stack(&[&a, &b], 0).expect("test: stack along axis 0 should succeed");
989 assert_eq!(c.shape(), vec![2, 3]);
990 let c_vec = c.to_vec();
991 assert_eq!(c_vec.len(), 6);
993 assert!(c_vec.contains(&1.0));
994 assert!(c_vec.contains(&2.0));
995 assert!(c_vec.contains(&3.0));
996 assert!(c_vec.contains(&4.0));
997 assert!(c_vec.contains(&5.0));
998 assert!(c_vec.contains(&6.0));
999
1000 let d = stack(&[&a, &b], 1).expect("test: stack along axis 1 should succeed");
1001 assert_eq!(d.shape(), vec![3, 2]);
1002 let d_vec = d.to_vec();
1003 assert_eq!(d_vec.len(), 6);
1005 assert!(d_vec.contains(&1.0));
1006 assert!(d_vec.contains(&2.0));
1007 assert!(d_vec.contains(&3.0));
1008 assert!(d_vec.contains(&4.0));
1009 assert!(d_vec.contains(&5.0));
1010 assert!(d_vec.contains(&6.0));
1011
1012 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1014 let splits = split(&a, &[2, 4], 0).expect("test: split should succeed");
1015 assert_eq!(splits.len(), 3);
1016 assert_eq!(splits[0].to_vec(), vec![1.0, 2.0]);
1017 assert_eq!(splits[1].to_vec(), vec![3.0, 4.0]);
1018 assert_eq!(splits[2].to_vec(), vec![5.0, 6.0]);
1019
1020 let _a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
1021 let splits_a =
1023 split(&a, &[2, 4], 0).expect("test: split with multiple indices should succeed");
1024 assert_eq!(splits_a.len(), 3);
1025
1026 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
1035 let expanded = expand_dims(&a, 0).expect("test: expand_dims should succeed");
1036 assert_eq!(expanded.shape(), vec![1, 3]);
1037 assert_eq!(expanded.to_vec(), vec![1.0, 2.0, 3.0]);
1038
1039 let expanded_end = expand_dims(&a, 1).expect("test: expand_dims at end should succeed");
1040 assert_eq!(expanded_end.shape(), vec![3, 1]);
1041 assert_eq!(expanded_end.to_vec(), vec![1.0, 2.0, 3.0]);
1042
1043 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[1, 3, 1]);
1045 let squeezed = squeeze(&a, None).expect("test: squeeze should succeed");
1046 assert_eq!(squeezed.shape(), vec![3]);
1047 assert_eq!(squeezed.to_vec(), vec![1.0, 2.0, 3.0]);
1048
1049 let squeezed_axis = squeeze(&a, Some(0)).expect("test: squeeze at axis 0 should succeed");
1050 assert_eq!(squeezed_axis.shape(), vec![3, 1]);
1051 assert_eq!(squeezed_axis.to_vec(), vec![1.0, 2.0, 3.0]);
1052 }
1053
#[test]
fn test_statistics_functions() {
    use crate::stats::*;

    // Five evenly spaced samples: mean 3, divide-by-n variance 2.
    let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

    // Whole-array summary statistics.
    assert_relative_eq!(a.mean(), 3.0, epsilon = 1e-10);
    assert_relative_eq!(a.var(), 2.0, epsilon = 1e-10);
    assert_relative_eq!(a.std(), std::f64::consts::SQRT_2, epsilon = 1e-10);
    assert_relative_eq!(a.min(), 1.0, epsilon = 1e-10);
    assert_relative_eq!(a.max(), 5.0, epsilon = 1e-10);

    // Percentiles are requested as fractions (0.5 -> median); checked in a fixed order.
    let percentile_cases = [(0.0, 1.0), (0.5, 3.0), (1.0, 5.0), (0.25, 2.0), (0.75, 4.0)];
    for &(q, expected) in &percentile_cases {
        assert_relative_eq!(a.percentile(q), expected, epsilon = 1e-10);
    }

    // `b` is `a` reversed, so the two series are perfectly anti-correlated.
    let b = Array::<f64>::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);

    // Sum of centred products is -10; the expected -2.5 corresponds to an (n - 1) divisor.
    let cov_matrix =
        cov(&a, Some(&b), None, None, None).expect("test: covariance should succeed");
    let cross_cov = cov_matrix
        .get(&[0, 1])
        .expect("test: cov element access should succeed");
    assert_relative_eq!(cross_cov, -2.5, epsilon = 1e-10);

    let corr_matrix =
        corrcoef(&a, Some(&b), None).expect("test: correlation coefficient should succeed");
    let cross_corr = corr_matrix
        .get(&[0, 1])
        .expect("test: corrcoef element access should succeed");
    assert_relative_eq!(cross_corr, -1.0, epsilon = 1e-10);

    // Nine samples into four equal-width bins over [1, 5]; the final count of 3 shows the
    // right-most bin also takes the maximum value 5.0.
    let data = Array::<f64>::from_vec(vec![1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]);
    let (counts, bins) =
        histogram(&data, 4, None, None).expect("test: histogram should succeed");
    assert_eq!(counts.to_vec(), vec![2.0, 2.0, 2.0, 3.0]);
    assert_eq!(bins.size(), 5);
    let edges = bins.to_vec();
    assert_relative_eq!(edges[0], 1.0, epsilon = 1e-10);
    assert_relative_eq!(edges[4], 5.0, epsilon = 1e-10);
}
1111
#[test]
fn test_boolean_indexing() {
    use crate::indexing::*;

    let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

    // Emulate a boolean-mask assignment by hand: the k-th selected position (where the mask
    // is true) receives values[k]; unselected slots keep the zero they were initialised with.
    let mask = [true, false, true, false, true];
    let values = Array::<f64>::from_vec(vec![1.0, 3.0, 5.0]);
    let mut filtered = Array::<f64>::zeros(&[5]);
    let selected_positions = mask
        .iter()
        .enumerate()
        .filter(|&(_, &keep)| keep)
        .map(|(pos, _)| pos);
    for (value_idx, pos) in selected_positions.enumerate() {
        let value = values
            .get(&[value_idx])
            .expect("test: value access should succeed");
        filtered
            .set(&[pos], value)
            .expect("test: set filtered value should succeed");
    }
    assert_eq!(filtered.to_vec(), vec![1.0, 0.0, 3.0, 0.0, 5.0]);

    // 3x3 matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
    let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
        .reshape(&[3, 3]);

    // An integer index on one axis plus `All` on the other selects a single row/column
    // and drops the indexed dimension (result shape is [3]).
    let row_result = a_2d
        .index(&[IndexSpec::Index(0), IndexSpec::All])
        .expect("test: row indexing should succeed");
    assert_eq!(row_result.shape(), vec![3]);
    assert_eq!(row_result.to_vec(), vec![1.0, 2.0, 3.0]);

    let col_result = a_2d
        .index(&[IndexSpec::All, IndexSpec::Index(0)])
        .expect("test: column indexing should succeed");
    assert_eq!(col_result.shape(), vec![3]);
    assert_eq!(col_result.to_vec(), vec![1.0, 4.0, 7.0]);

    // `set_mask` writes the replacement values, in order, into the masked positions of a
    // copy, leaving the unmasked elements untouched.
    let mut a_copy = a.clone();
    a_copy
        .set_mask(
            &Array::<bool>::from_vec(vec![true, false, true, false, true]),
            &Array::<f64>::from_vec(vec![10.0, 30.0, 50.0]),
        )
        .expect("test: set_mask should succeed");
    assert_eq!(a_copy.to_vec(), vec![10.0, 2.0, 30.0, 4.0, 50.0]);
}
1185
#[test]
fn test_fancy_indexing() {
    use crate::indexing::*;

    // 3x3 matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
    let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
        .reshape(&[3, 3]);

    // An integer index on every axis selects the single element at (1, 1).
    let single_element = a_2d
        .index(&[IndexSpec::Index(1), IndexSpec::Index(1)])
        .expect("test: single element indexing should succeed");
    assert_eq!(single_element.to_vec(), vec![5.0]);

    // Slice(0, Some(2), None) on axis 0 keeps rows 0 and 1; `All` keeps every column.
    let slice_result = a_2d
        .index(&[IndexSpec::Slice(0, Some(2), None), IndexSpec::All])
        .expect("test: slice indexing should succeed");
    assert_eq!(slice_result.shape(), vec![2, 3]);
    assert_eq!(slice_result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
1215
#[test]
fn test_axis_operations() {
    use crate::axis_ops::*;

    // Build [[1, 2, 3], [4, 5, 6]] element by element so that `set` is exercised as well.
    let mut array = Array::<f64>::zeros(&[2, 3]);
    let array_cells = [
        ([0, 0], 1.0, "test: set [0,0] should succeed"),
        ([0, 1], 2.0, "test: set [0,1] should succeed"),
        ([0, 2], 3.0, "test: set [0,2] should succeed"),
        ([1, 0], 4.0, "test: set [1,0] should succeed"),
        ([1, 1], 5.0, "test: set [1,1] should succeed"),
        ([1, 2], 6.0, "test: set [1,2] should succeed"),
    ];
    for &(idx, value, msg) in &array_cells {
        array.set(&idx, value).expect(msg);
    }

    // Axis 0 reduces down the rows (one result per column); axis 1 reduces across a row.
    let sum_axis0 = array.sum_axis(0).expect("test: sum_axis(0) should succeed");
    assert_eq!(sum_axis0.shape(), vec![3]);
    assert_eq!(sum_axis0.to_vec(), vec![5.0, 7.0, 9.0]);

    let sum_axis1 = array.sum_axis(1).expect("test: sum_axis(1) should succeed");
    assert_eq!(sum_axis1.shape(), vec![2]);
    assert_eq!(sum_axis1.to_vec(), vec![6.0, 15.0]);

    let mean_axis0 = array
        .mean_axis(Some(0))
        .expect("test: mean_axis(Some(0)) should succeed");
    assert_eq!(mean_axis0.shape(), vec![3]);
    assert_eq!(mean_axis0.to_vec(), vec![2.5, 3.5, 4.5]);

    let mean_axis1 = array
        .mean_axis(Some(1))
        .expect("test: mean_axis(Some(1)) should succeed");
    assert_eq!(mean_axis1.shape(), vec![2]);
    assert_eq!(mean_axis1.to_vec(), vec![2.0, 5.0]);

    let min_axis0 = array
        .min_axis(Some(0))
        .expect("test: min_axis(Some(0)) should succeed");
    assert_eq!(min_axis0.shape(), vec![3]);
    assert_eq!(min_axis0.to_vec(), vec![1.0, 2.0, 3.0]);

    let min_axis1 = array
        .min_axis(Some(1))
        .expect("test: min_axis(Some(1)) should succeed");
    assert_eq!(min_axis1.shape(), vec![2]);
    assert_eq!(min_axis1.to_vec(), vec![1.0, 4.0]);

    let max_axis1 = array
        .max_axis(Some(1))
        .expect("test: max_axis(Some(1)) should succeed");
    assert_eq!(max_axis1.shape(), vec![2]);
    assert_eq!(max_axis1.to_vec(), vec![3.0, 6.0]);

    // [[3, 2, 1], [0, 5, 6]]: column 0's minimum sits in row 1, the others in row 0.
    let mut array2 = Array::<f64>::zeros(&[2, 3]);
    let array2_cells = [
        ([0, 0], 3.0, "test: set array2[0,0] should succeed"),
        ([0, 1], 2.0, "test: set array2[0,1] should succeed"),
        ([0, 2], 1.0, "test: set array2[0,2] should succeed"),
        ([1, 0], 0.0, "test: set array2[1,0] should succeed"),
        ([1, 1], 5.0, "test: set array2[1,1] should succeed"),
        ([1, 2], 6.0, "test: set array2[1,2] should succeed"),
    ];
    for &(idx, value, msg) in &array2_cells {
        array2.set(&idx, value).expect(msg);
    }

    let argmin_axis0 = array2
        .argmin_axis(0)
        .expect("test: argmin_axis(0) should succeed");
    assert_eq!(argmin_axis0.shape(), vec![3]);
    assert_eq!(argmin_axis0.to_vec(), vec![1, 0, 0]);

    // Column 0 of `array` is {1, 4}: mean 2.5, so the divide-by-n variance is 2.25.
    let var_axis0 = array
        .var_axis(Some(0))
        .expect("test: var_axis(Some(0)) should succeed");
    assert_eq!(var_axis0.shape(), vec![3]);
    let var_col0 = var_axis0
        .get(&[0])
        .expect("test: var_axis0 element access should succeed");
    assert_relative_eq!(var_col0, 2.25, epsilon = 1e-10);

    // Each row ({1, 2, 3} and {4, 5, 6}) has std ~0.816 with an n divisor or exactly 1.0
    // with n - 1; the bound deliberately accepts either convention.
    let std_axis1 = array
        .std_axis(Some(1))
        .expect("test: std_axis(Some(1)) should succeed");
    assert_eq!(std_axis1.shape(), vec![2]);
    let std_checks = [
        (0, "std_row1", "test: std_axis1[0] access should succeed"),
        (1, "std_row2", "test: std_axis1[1] access should succeed"),
    ];
    for &(row, label, access_msg) in &std_checks {
        let row_std = std_axis1.get(&[row]).expect(access_msg);
        assert!(
            row_std > 0.8 && row_std < 1.1,
            "{} ({}) should be approximately 1.0 or 0.82",
            label,
            row_std
        );
    }
}
1370
#[test]
fn test_views_and_strides() {
    use crate::views::SliceOrIndex;

    // 3x3 matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
    let mut a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
        .reshape(&[3, 3]);

    // A read-only view reports the same shape as its source.
    let view = a.view();
    assert_eq!(view.shape(), vec![3, 3]);

    // A write through a mutable view must be visible in the underlying array.
    let mut view_mut = a.view_mut();
    view_mut
        .set(&[0, 0], 10.0)
        .expect("test: view_mut set should succeed");
    assert_eq!(
        a.get(&[0, 0])
            .expect("test: get after view_mut set should succeed"),
        10.0
    );

    // Restore the original value so the views below see the unmodified matrix.
    a.set(&[0, 0], 1.0)
        .expect("test: reset value should succeed");

    // Strides [2, 2] keep rows and columns 0 and 2, i.e. the four corners. Only membership
    // is checked; the flattening order of a strided view is not pinned here.
    let strided = a
        .strided_view(&[2, 2])
        .expect("test: strided_view should succeed");
    assert_eq!(strided.shape(), vec![2, 2]);
    let flat_data = strided.to_vec();
    assert!(flat_data.contains(&1.0));
    assert!(flat_data.contains(&3.0));
    assert!(flat_data.contains(&7.0));
    assert!(flat_data.contains(&9.0));

    // Top-left 2x2 block.
    let slices = vec![
        SliceOrIndex::Slice(0, Some(2), None),
        SliceOrIndex::Slice(0, Some(2), None),
    ];
    let sliced = a
        .sliced_view(&slices)
        .expect("test: sliced_view should succeed");
    assert_eq!(sliced.shape(), vec![2, 2]);
    assert_eq!(sliced.to_vec(), vec![1.0, 2.0, 4.0, 5.0]);

    // Transposition swaps the index order: t[i, j] == a[j, i].
    let transposed = a.transposed_view();
    assert_eq!(transposed.shape(), vec![3, 3]);
    // Flattening the transposed view must still produce every element.
    assert_eq!(transposed.to_vec().len(), 9);
    assert_eq!(
        transposed
            .get(&[0, 1])
            .expect("test: transposed get [0,1] should succeed"),
        4.0
    );
    assert_eq!(
        transposed
            .get(&[1, 0])
            .expect("test: transposed get [1,0] should succeed"),
        2.0
    );

    // Broadcasting adds a leading axis that repeats the 3x3 data, so [1, 0, 0] reads the
    // same element as [0, 0, 0].
    let broadcast = a
        .broadcast_view(&[3, 3, 3])
        .expect("test: broadcast_view should succeed");
    assert_eq!(broadcast.shape(), vec![3, 3, 3]);
    assert_eq!(
        broadcast
            .get(&[0, 0, 0])
            .expect("test: broadcast get [0,0,0] should succeed"),
        1.0
    );
    assert_eq!(
        broadcast
            .get(&[1, 0, 0])
            .expect("test: broadcast get [1,0,0] should succeed"),
        1.0
    );
}
1456
#[test]
fn test_universal_functions() {
    use crate::ufuncs::*;

    let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]);

    // Element-wise tolerance check: actual[i] must match expected[i] within 1e-10.
    let check_close = |actual: &[f64], expected: &[f64]| {
        for (i, want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[i], *want, epsilon = 1e-10);
        }
    };

    // Binary ops whose results are exactly representable can be compared exactly.
    let sum = add(&a, &b).expect("test: ufunc add should succeed").to_vec();
    assert_eq!(sum, vec![6.0, 8.0, 10.0, 12.0]);

    let difference = subtract(&a, &b)
        .expect("test: ufunc subtract should succeed")
        .to_vec();
    assert_eq!(difference, vec![-4.0, -4.0, -4.0, -4.0]);

    let product = multiply(&a, &b)
        .expect("test: ufunc multiply should succeed")
        .to_vec();
    assert_eq!(product, vec![5.0, 12.0, 21.0, 32.0]);

    // Inexact results are compared with a tolerance.
    let quotient = divide(&a, &b)
        .expect("test: ufunc divide should succeed")
        .to_vec();
    check_close(&quotient, &[0.2, 1.0 / 3.0, 3.0 / 7.0, 0.5]);

    let powers = power(&a, &b)
        .expect("test: ufunc power should succeed")
        .to_vec();
    check_close(&powers, &[1.0, 64.0, 2187.0, 65536.0]);

    // Unary ops.
    let squares = square(&a).to_vec();
    assert_eq!(squares, vec![1.0, 4.0, 9.0, 16.0]);

    let roots = sqrt(&a).to_vec();
    check_close(
        &roots,
        &[1.0, std::f64::consts::SQRT_2, 1.7320508075688772, 2.0],
    );

    let inputs = [1.0_f64, 2.0, 3.0, 4.0];

    let exps = exp(&a).to_vec();
    let expected_exps: Vec<f64> = inputs.iter().map(|x| x.exp()).collect();
    check_close(&exps, &expected_exps);

    let logs = log(&a).to_vec();
    let expected_logs: Vec<f64> = inputs.iter().map(|x| x.ln()).collect();
    check_close(&logs, &expected_logs);

    let doubled = multiply_scalar(&a, 2.0).to_vec();
    assert_eq!(doubled, vec![2.0, 4.0, 6.0, 8.0]);

    // A [1, 2] row plus a [3, 1] column broadcasts to a [3, 2] grid of pairwise sums.
    let row = Array::<f64>::from_vec(vec![10.0, 20.0]).reshape(&[1, 2]);
    let col = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[3, 1]);
    let grid = add(&row, &col).expect("test: ufunc add with broadcasting should succeed");
    assert_eq!(grid.shape(), vec![3, 2]);
    assert_eq!(grid.to_vec(), vec![11.0, 21.0, 12.0, 22.0, 13.0, 23.0]);
}
1524}