// Crate-wide lint configuration; each `#![allow]` applies to every module below.
// NOTE(review): blanket crate-level allows hide these lints everywhere. Consider
// scoping them to the modules that actually need them.
#![allow(deprecated)]
// Presumably because `NumRs2Error` is large enough to trip this lint — confirm.
#![allow(clippy::result_large_err)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::identity_op)]
// Presumably for hand-written numeric literals in numerical code — confirm.
#![allow(clippy::approx_constant)]
#![allow(clippy::excessive_precision)]

// Public module tree. Modules preceded by `#[cfg(feature = "...")]` are only
// compiled when that Cargo feature is enabled (arrow, distributed, gpu, python,
// visualization -> `viz`, wasm).
pub mod algorithms;
pub mod array;
pub mod array_ops;
pub mod array_ops_legacy;
pub mod arrays;
#[cfg(feature = "arrow")]
pub mod arrow;
pub mod autodiff;
pub mod axis_ops;
pub mod bitwise_ops;
pub mod blas;
pub mod char;
pub mod cluster;
pub mod comparisons;
pub mod comparisons_broadcast;
pub mod complex_ops;
pub mod conversions;
pub mod derivative;
pub mod distance;
#[cfg(feature = "distributed")]
pub mod distributed;
pub mod error;
pub mod error_handling;
pub mod expr;
pub mod fft;
pub mod financial;
#[cfg(feature = "gpu")]
pub mod gpu;
pub mod indexing;
pub mod integrate;
pub mod interop;
pub mod interpolate;
pub mod io;
pub mod linalg;
pub mod linalg_accelerated;
pub mod linalg_extended;
pub mod linalg_optimized;
pub mod linalg_parallel;
pub mod optimized_ops;
pub mod linalg_stable;
pub mod masked;
pub mod math;
pub mod math_extended;
pub mod matrix;
pub mod memory_alloc;
pub mod memory_optimize;
pub mod mmap;
pub mod ndimage;
pub mod nn;
pub mod ode;
pub mod optimize;
pub mod parallel;
pub mod parallel_optimize;
pub mod pde;
pub mod printing;
#[cfg(feature = "python")]
pub mod python;
pub mod random;
pub mod roots;
pub mod set_ops;
pub mod shared_array;
pub mod signal;
pub mod simd;
pub mod simd_optimize;
pub mod sparse;
pub mod sparse_enhanced;
pub mod spatial;
pub mod special;
pub mod stats;
pub mod stride_tricks;
pub mod symbolic;
pub mod testing;
pub mod traits;
pub mod types;
pub mod ufuncs;
pub mod unique;
pub mod unique_optimized;
pub mod util;
pub mod views;
#[cfg(feature = "visualization")]
pub mod viz;
#[cfg(feature = "wasm")]
pub mod wasm;

// Container for the `eigenvalues`, `fft`, `matrix_decomp`, `polynomial`,
// `sparse` and `special` submodules whose items the prelude re-exports.
pub mod new_modules;
158
159pub use error::{NumRs2Error, Result};
160
161pub use random::random_base;
163
// Compiled only while rustdoc collects doctests (`cfg(doctest)`). The module is
// currently empty, so it contributes no doctests of its own.
// NOTE(review): if this was meant to doctest external docs (e.g. a README), it
// would need a `#[doc = include_str!(...)]` attribute — confirm intent.
#[cfg(doctest)]
pub mod doctests {}
167
/// Convenience re-exports of the crate's commonly used items.
///
/// Intended to be glob-imported, as the crate's own tests do with
/// `use crate::prelude::*;`. Several items are renamed on re-export because
/// another module already contributes the unqualified name to this prelude.
pub mod prelude {
    // ---- Core array type and array manipulation ----
    pub use crate::array::Array;
    pub use crate::array_ops::*;
    pub use crate::array_ops_legacy::rollaxis;
    pub use crate::axis_ops::*;
    // Also named explicitly. An explicit import takes precedence over any glob
    // import that happens to provide the same name.
    pub use crate::axis_ops::{apply_along_axis, apply_over_axes, vectorize};
    pub use crate::bitwise_ops::{
        bitwise_and, bitwise_not, bitwise_or, bitwise_xor, invert, left_shift, left_shift_scalar,
        right_shift, right_shift_scalar,
    };
    pub use crate::char;
    pub use crate::char::{array_from_strings, StringArray, StringElement};
    pub use crate::comparisons::{
        all, allclose, allclose_with_tol, any, array_equal, count_nonzero, equal, flatnonzero,
        greater, greater_equal, isclose, isclose_array, less, less_equal, logical_and, logical_not,
        logical_or, logical_xor, not_equal,
    };
    // Renamed with a `complex_` prefix because `crate::math` (below) also exports
    // `angle`, `conj`, `imag` and `real`, and `crate::ufuncs` exports `absolute`.
    pub use crate::complex_ops::{
        absolute as complex_abs, angle as complex_angle, conj as complex_conj, from_polar,
        imag as complex_imag, iscomplex, iscomplexobj, isreal, isrealobj, real as complex_real,
        to_complex,
    };
    pub use crate::conversions::*;

    // ---- Errors and floating-point error state ----
    pub use crate::error::{NumRs2Error, Result};
    pub use crate::error_handling::{
        errstate, geterr, geterrcall, handle_error, seterr, seterrcall, ErrorAction, ErrorState,
        ErrorStateBuilder, ErrorStateGuard, FloatingPointError,
    };

    // ---- Financial functions ----
    pub use crate::financial::{
        accrued_interest,
        amortization_schedule,
        binomial_option_price,
        black_scholes,
        black_scholes_greeks,
        bond_convexity,
        bond_duration,
        bond_equivalent_yield,
        bond_price,
        bond_yield,
        cumipmt,
        cumprinc,
        db,
        ddb,
        effect,
        fv,
        fv_array,
        implied_volatility,
        ipmt,
        irr,
        irr_multiple_series,
        mirr,
        modified_duration,
        nominal,
        nper,
        nper_array,
        npv,
        npv_multiple_series,
        npv_rates,
        pmt,
        pmt_array,
        ppmt,
        pv,
        pv_array,
        rate,
        rate_array,
        sln,
        syd,
        AmortizationSchedule,
    };

    // ---- Indexing and I/O ----
    // NOTE(review): `put`/`putmask` are presumably renamed to avoid a clash with
    // same-named items from one of the glob imports — confirm.
    pub use crate::indexing::{
        diag_indices, diag_indices_from, extract, indices_grid, ix_, mask_indices,
        put as indexing_put, put_along_axis, putmask as indexing_putmask, ravel_multi_index, take,
        take_along_axis, tril_indices, tril_indices_from, triu_indices, triu_indices_from,
        unravel_index, IndexSpec,
    };
    pub use crate::io::{array_to_vec2d, vec2d_to_array, vec_to_array, SerializeFormat};

    // ---- Linear algebra ----
    // `cholesky`, `qr` and `svd` from `crate::linalg` get a `_basic` suffix because
    // `crate::new_modules::matrix_decomp` (re-exported below under the same
    // features) provides the unsuffixed names. `inv` and `solve` are only
    // available when BOTH `matrix_decomp` and `lapack` are enabled.
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    pub use crate::linalg::{
        cholesky as cholesky_basic, eig, inv, qr as qr_basic, solve, svd as svd_basic,
    };
    #[cfg(feature = "lapack")]
    pub use crate::linalg::{det, matrix_power};
    pub use crate::linalg::{inner, kron, norm, outer, tensordot, trace, vdot};

    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    pub use crate::linalg::{matrix_rank, pinv};
    pub use crate::linalg_extended::eigenvalue;
    pub use crate::linalg_optimized::{lu_optimized, transpose_optimized, OptimizedBlas};
    pub use crate::linalg_parallel::ParallelLinAlg;
    pub use crate::linalg_stable::{
        CholeskyStableResult, QRPivotedResult, SVDStableResult, StableDecompositions,
    };
    pub use crate::masked::MaskedArray;

    // ---- Elementary and array math ----
    pub use crate::ufuncs::{abs, ceil, exp, floor, log, round, sqrt};
    pub use crate::math_extended::{erf, erfc, gamma, gammaln};
    pub use crate::math::{
        amax, amin, angle, arange, argmax, argmin, argpartition, argsort, around, bartlett,
        bincount, blackman, clip, conj, copysign, cumprod, cumsum, cumulative_prod, cumulative_sum,
        diff, diff_extended, digitize, divmod, ediff1d, empty, fmod, frexp, gcd, geomspace,
        gradient, hamming, hanning, heaviside, i0, imag, interp, isfinite, isinf, isnan, kaiser,
        kurtosis, lcm, ldexp, linspace, logspace, max, mean, median, min, modf, nan_to_num, nanmax,
        nanmean, nanmin, nanstd, nansum, nanvar, nextafter, nonzero, ones, partition, prod, real,
        real_if_close, remainder, resize, searchsorted, sinc, skew, sort, std, sum, trapz, var,
        zeros, ElementWiseMath,
    };
    pub use crate::matrix::{
        asmatrix, matrix, matrix_from_nested, matrix_from_scalar, BandedMatrix, Matrix,
    };
    pub use crate::mmap::MmapArray;

    // ---- Random numbers ----
    pub use crate::random::advanced_distributions;
    pub use crate::random::distributions;
    pub use crate::random::generator::{default_rng, BitGenerator, Generator, StdBitGenerator};
    pub use crate::random::{self, RandomState};

    // ---- Sets, signal processing, SIMD, sparse, statistics ----
    pub use crate::set_ops::{
        in1d, intersect1d, isin, setdiff1d, setxor1d, union1d, unique_axis, unique_with_options,
    };
    pub use crate::signal::{convolve, convolve2d, correlate, correlate2d};
    pub use crate::simd::get_simd_implementation_name;
    pub use crate::sparse;
    pub use crate::sparse_enhanced::SparseOpsAdvanced;
    pub use crate::stats::{
        average, corrcoef, cov, histogram, histogram_dd, max_along_axis, min_along_axis, mode,
        percentile, ptp, quantile, HistBins, Statistics,
    };
    pub use crate::stride_tricks::{
        as_strided, broadcast_arrays, broadcast_to, byte_strides, set_strides, sliding_window_view,
    };

    // ---- Testing helpers ----
    pub use crate::testing::{
        arrays_close, assert_array_all_finite, assert_array_almost_equal, assert_array_equal,
        assert_array_no_nan, assert_array_same_shape, assert_scalar_almost_equal, is_finite_array,
        test_summary, tolerances, TestResult, ToleranceConfig,
    };
    pub use crate::run_tests;

    // ---- Core traits ----
    pub use crate::traits::{
        ArrayIndexing, ArrayMath, ArrayOps, ArrayReduction, ComplexElement, FloatingPoint,
        IntegerElement, LinearAlgebra, MatrixDecomposition, NumericElement,
    };

    // ---- Universal functions ----
    pub use crate::ufuncs::{
        absolute, add, add_scalar, arctan2, cbrt, divide, divide_scalar, dot, exp2, expm1, fma,
        hypot, log10, log1p, log2, maximum, minimum, multiply, multiply_scalar, negative, norm_l1,
        norm_l2, power, power_scalar, reciprocal, subtract, subtract_scalar, BinaryUfunc,
        UnaryUfunc,
    };
    pub use crate::unique::{unique, UniqueResult};
    pub use crate::unique_optimized::unique_optimized;
    pub use crate::util::{
        astype, can_operate_inplace, fast_sum, optimize_layout, parallel_map, MemoryLayout,
    };
    pub use crate::views::*;

    // ---- Interop with `ndarray` ----
    pub use crate::interop::ndarray_compat::{from_ndarray, to_ndarray};

    // ---- Memory layout optimisation ----
    // `optimize_layout` is renamed because `crate::util::optimize_layout` is
    // already re-exported above.
    pub use crate::memory_optimize::{
        align_data, optimize_layout as memory_optimize_layout, optimize_placement,
        AlignmentStrategy, LayoutStrategy, PlacementStrategy,
    };

    // ---- Parallel execution tuning ----
    pub use crate::parallel_optimize::{
        adaptive_threshold, optimize_parallel_computation, optimize_scheduling, partition_workload,
    };
    pub use crate::parallel_optimize::{
        ParallelConfig, ParallelizationThreshold, SchedulingStrategy, WorkloadPartitioning,
    };

    // ---- Printing ----
    pub use crate::printing::{
        array_str, get_printoptions, reset_printoptions, set_printoptions, PrintOptions,
    };

    // ---- Allocators ----
    pub use crate::memory_alloc::{
        get_default_allocator, get_global_allocator_strategy, init_global_allocator,
        reset_global_allocator,
    };
    pub use crate::memory_alloc::{
        AlignedAllocator, AlignmentConfig, AllocStrategy, ArenaAllocator, ArenaConfig, CacheConfig,
        CacheLevel, CacheOptimizedAllocator, PoolAllocator, PoolConfig,
    };

    // ---- Cache-aware algorithms ----
    pub use crate::algorithms::{
        BandwidthEstimate, BandwidthOptimizer, CacheAwareArrayOps, CacheAwareConvolution,
        CacheAwareFFT, MemoryOperation,
    };

    // ---- Parallel runtime ----
    // Renamed because `crate::parallel_optimize::ParallelConfig` is already
    // re-exported above.
    pub use crate::parallel::parallel_algorithms::ParallelConfig as ParallelAlgorithmConfig;
    pub use crate::parallel::{
        global_parallel_context, initialize_parallel_context, shutdown_parallel_context, task,
        BalancingStrategy, LoadBalancer, ParallelAllocator, ParallelAllocatorConfig,
        ParallelArrayOps, ParallelContext, ParallelFFT, ParallelMatrixOps, ParallelScheduler,
        SchedulerConfig, Task, TaskPriority, TaskResult, ThreadLocalAllocator, WorkStealingPool,
        WorkloadMetrics,
    };

    // ---- Allocation strategy traits and bridges ----
    pub use crate::memory_alloc::{
        EnhancedAllocatorBridge, IntelligentAllocationStrategy, NumericalArrayAllocator,
    };
    pub use crate::traits::{
        AllocationFrequency, AllocationLifetime, AllocationRequirements, AllocationStats,
        AllocationStrategy, MemoryAllocator, MemoryAware, MemoryOptimization, MemoryUsage,
        OptimizationType, SpecializedAllocator, ThreadingRequirements,
    };

    // ---- `new_modules`: eigenvalues, FFT, decompositions, polynomials ----
    // `eig` here is exposed as `eig_general`, because `crate::linalg::eig` is
    // re-exported unrenamed above.
    #[cfg(feature = "lapack")]
    pub use crate::new_modules::eigenvalues::{eig as eig_general, eigh, eigvals, eigvalsh};
    pub use crate::new_modules::fft::FFT;
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    pub use crate::new_modules::matrix_decomp::{
        cholesky, cod, condition_number, lu, pivoted_cholesky, qr, rcond, schur, svd,
    };
    pub use crate::new_modules::polynomial::{
        poly, polyadd, polychebyshev, polycompanion, polycompose, polyder, polydiv, polyextrap,
        polyfit, polyfit_weighted, polyfromroots, polygcd, polygrid2d, polyhermite, polyint,
        polyjacobi, polylaguerre, polylegendre, polymul, polymulx, polypower, polyresidual,
        polyscale, polysub, polytrim, polyval2d, polyvander, polyvander2d, CubicSpline, Polynomial,
        PolynomialInterpolation,
    };

    // ---- Optimised kernels ----
    #[cfg(feature = "lapack")]
    pub use crate::optimized_ops::parallel_matrix_ops;
    pub use crate::optimized_ops::{
        adaptive_array_sum, chunked_array_processing, get_optimization_info,
        parallel_column_statistics, should_use_parallel, simd_elementwise_ops, simd_matmul,
        simd_vector_ops, ColumnStats, SimdOpsResult, SimdVectorResult,
    };

    // ---- GPU (feature = "gpu") ----
    // NOTE(review): unlike the other GPU functions, `matmul` and `transpose` are
    // re-exported without a `gpu_` prefix. When the feature is on, they take
    // precedence over any same-named item arriving via the glob imports above.
    // Confirm this is intended.
    #[cfg(feature = "gpu")]
    pub use crate::gpu::{
        add as gpu_add, divide as gpu_divide, matmul, multiply as gpu_multiply,
        subtract as gpu_subtract, transpose, GpuArray, GpuContext,
    };

    // ---- Sparse and special functions from `new_modules` ----
    pub use crate::new_modules::sparse::{SparseArray, SparseMatrix, SparseMatrixFormat};
    pub use crate::new_modules::special::{
        airy_ai, airy_bi, associated_legendre_p, bessel_i, bessel_j, bessel_k, bessel_y, beta,
        betainc, digamma, ellipe, ellipeinc, ellipf, ellipk, erfcinv, erfinv, exp1, expi, fresnel,
        gammainc, jacobi_elliptic, lambertw, lambertwm1, legendre_p, polylog, shichi, sici,
        spherical_harmonic, struve_h, zeta,
    };

    // ---- Broadcasting and fancy indexing engine ----
    pub use crate::arrays::{
        ArrayView, BooleanCombineOp, BroadcastEngine, BroadcastOp, BroadcastReduction,
        FancyIndexEngine, FancyIndexResult, ResolvedIndex, Shape, SpecializedIndexing,
    };

    // ---- Custom, datetime and structured dtypes ----
    pub use crate::types::custom::CustomDType;
    pub use crate::types::datetime::{
        business_days,
        datetime64,
        datetime_array,
        datetime_as_string,
        datetime_data,
        timedelta64,
        DateTime64,
        DateTimeUnit,
        DateUnit,
        TimeDelta64,
        Timezone,
        TimezoneDateTime,
    };
    pub use crate::types::structured::{DType, Field, RecordArray, StructuredArray};

    pub use crate::shared_array::{SharedArray, SharedArrayView};

    // ---- Expression templates / lazy evaluation ----
    pub use crate::expr::{
        ArrayExpr,
        BinaryExpr,
        CSEOptimizer,
        CSESupport,
        CachedExpr,
        Expr,
        ExprBuilder,
        ExprCache,
        ExprId,
        ExprKey,
        LazyEval,
        ScalarExpr,
        SharedArrayExpr,
        SharedBinaryExpr,
        SharedExpr,
        SharedExprBuilder,
        SharedScalarExpr,
        SharedUnaryExpr,
        UnaryExpr,
    };

    // ---- Access-pattern helpers ----
    pub use crate::memory_optimize::access_patterns::{
        cache_aware_binary_op, cache_aware_copy, cache_aware_transform, detect_layout,
        AccessPattern, AccessStats, Block, BlockedIterator, OptimizationHints, StrideOptimizer,
        Tile2D, TiledIterator2D,
    };

    // ---- Selected `ndarray` / complex types, re-exported via `scirs2_core` ----
    pub use scirs2_core::ndarray::{Axis, Dimension, IxDyn, ShapeBuilder};
    pub use scirs2_core::{Complex, Complex64};
}
516
517#[cfg(test)]
518mod tests {
519 use crate::prelude::*;
520 use crate::simd::{simd_add, simd_div, simd_mul, simd_prod, simd_sqrt, simd_sum};
521 use approx::assert_relative_eq;
522
523 #[test]
524 fn basic_array_ops() {
525 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
526 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
527
528 let c = a.add(&b);
530 assert_eq!(c.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
531
532 let d = a.subtract(&b);
534 assert_eq!(d.to_vec(), vec![-4.0, -4.0, -4.0, -4.0]);
535
536 let e = a.multiply(&b);
538 assert_eq!(e.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
539
540 let f = a.divide(&b);
542 assert_relative_eq!(f.to_vec()[0], 0.2, epsilon = 1e-10);
543 assert_relative_eq!(f.to_vec()[1], 1.0 / 3.0, epsilon = 1e-10);
544 assert_relative_eq!(f.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
545 assert_relative_eq!(f.to_vec()[3], 0.5, epsilon = 1e-10);
546 }
547
548 #[test]
549 fn test_broadcasting() {
550 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
552
553 let b = a.add_scalar(5.0);
555 assert_eq!(b.to_vec(), vec![6.0, 7.0, 8.0]);
556
557 let c = a.multiply_scalar(2.0);
559 assert_eq!(c.to_vec(), vec![2.0, 4.0, 6.0]);
560
561 let row = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[1, 3]);
563 let col = Array::<f64>::from_vec(vec![4.0, 5.0]).reshape(&[2, 1]);
564
565 let result = row
567 .add_broadcast(&col)
568 .expect("test: broadcast addition should succeed");
569 assert_eq!(result.shape(), vec![2, 3]);
570 assert_eq!(result.to_vec(), vec![5.0, 6.0, 7.0, 6.0, 7.0, 8.0]);
571
572 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
574 let b = Array::<f64>::from_vec(vec![10.0, 20.0]).reshape(&[1, 2]);
575
576 let result = a
578 .multiply_broadcast(&b)
579 .expect("test: broadcast multiplication should succeed");
580 assert_eq!(result.shape(), vec![2, 2]);
581 assert_eq!(result.to_vec(), vec![10.0, 40.0, 30.0, 80.0]);
582
583 let shape1 = vec![3, 1, 4];
585 let shape2 = vec![2, 1];
586 let broadcast_shape = Array::<f64>::broadcast_shape(&shape1, &shape2)
587 .expect("test: broadcast shape computation should succeed");
588 assert_eq!(broadcast_shape, vec![3, 2, 4]);
589 }
590
591 #[test]
592 fn test_array_creation() {
593 let zeros = Array::<f64>::zeros(&[2, 3]);
595 assert_eq!(zeros.shape(), vec![2, 3]);
596 assert_eq!(zeros.to_vec(), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
597
598 let ones = Array::<f64>::ones(&[2, 2]);
600 assert_eq!(ones.shape(), vec![2, 2]);
601 assert_eq!(ones.to_vec(), vec![1.0, 1.0, 1.0, 1.0]);
602
603 let fives = Array::<f64>::full(&[2, 2], 5.0);
605 assert_eq!(fives.shape(), vec![2, 2]);
606 assert_eq!(fives.to_vec(), vec![5.0, 5.0, 5.0, 5.0]);
607
608 let arr = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
610 let reshaped = arr.reshape(&[2, 3]);
611 assert_eq!(reshaped.shape(), vec![2, 3]);
612 assert_eq!(reshaped.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
613 }
614
615 #[test]
616 fn test_array_methods() {
617 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
618
619 assert_eq!(a.shape(), vec![2, 3]);
621 assert_eq!(a.ndim(), 2);
622 assert_eq!(a.size(), 6);
623
624 let at = a.transpose();
626 assert_eq!(at.shape(), vec![3, 2]);
627
628 let at_vec = at.to_vec();
631 assert_eq!(at_vec.len(), 6);
632 assert!(at_vec.contains(&1.0));
633 assert!(at_vec.contains(&2.0));
634 assert!(at_vec.contains(&3.0));
635 assert!(at_vec.contains(&4.0));
636 assert!(at_vec.contains(&5.0));
637 assert!(at_vec.contains(&6.0));
638
639 let slice = a
641 .slice(0, 1)
642 .expect("test: slice should succeed for valid axis");
643 assert_eq!(slice.shape(), vec![3]);
644 assert_eq!(slice.to_vec(), vec![4.0, 5.0, 6.0]);
645 }
646
647 #[test]
648 fn test_map_operations() {
649 let a = Array::<f64>::from_vec(vec![1.0, 4.0, 9.0, 16.0]);
650
651 let sqrt_a = a.map(|x| x.sqrt());
653 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
654 assert_relative_eq!(sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
655 assert_relative_eq!(sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
656 assert_relative_eq!(sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
657
658 let par_sqrt_a = a.par_map(|x| x.sqrt());
660 assert_relative_eq!(par_sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
661 assert_relative_eq!(par_sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
662 assert_relative_eq!(par_sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
663 assert_relative_eq!(par_sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
664 }
665
666 #[cfg(feature = "lapack")]
667 #[test]
668 fn test_linalg_ops() {
669 let a = Array::<f64>::from_vec(vec![4.0, 7.0, 2.0, 6.0]).reshape(&[2, 2]);
671
672 let det_a = det(&a).expect("test: determinant computation should succeed");
674 assert_relative_eq!(det_a, 10.0, epsilon = 1e-10);
675
676 let inv_a = inv(&a).expect("test: matrix inverse should succeed for invertible matrix");
678 let expected_inv = [0.6, -0.7, -0.2, 0.4];
679 for (actual, expected) in inv_a.to_vec().iter().zip(expected_inv.iter()) {
680 assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
681 }
682
683 let identity = a
685 .matmul(&inv_a)
686 .expect("test: matrix multiplication should succeed");
687 assert_relative_eq!(identity.to_vec()[0], 1.0, epsilon = 1e-10);
688 assert_relative_eq!(identity.to_vec()[1], 0.0, epsilon = 1e-10);
689 assert_relative_eq!(identity.to_vec()[2], 0.0, epsilon = 1e-10);
690 assert_relative_eq!(identity.to_vec()[3], 1.0, epsilon = 1e-10);
691
692 let b = Array::<f64>::from_vec(vec![1.0, 3.0]);
694 let x = solve(&a, &b).expect("test: linear system solve should succeed");
695
696 assert_relative_eq!(x.to_vec()[0], -1.5, epsilon = 1e-10);
698 assert_relative_eq!(x.to_vec()[1], 1.0, epsilon = 1e-10);
699
700 let b_check = a
702 .matmul(&x.reshape(&[2, 1]))
703 .expect("test: matrix-vector multiplication should succeed")
704 .reshape(&[2]);
705 assert_relative_eq!(b_check.to_vec()[0], b.to_vec()[0], epsilon = 1e-10);
706 assert_relative_eq!(b_check.to_vec()[1], b.to_vec()[1], epsilon = 1e-10);
707 }
708
709 #[test]
710 fn test_tensor_operations() {
711 let a = Array::<f64>::from_vec(vec![1.0, 2.0]).reshape(&[1, 2]);
713 let b = Array::<f64>::from_vec(vec![3.0, 4.0]).reshape(&[2, 1]);
714
715 let kron_result = kron(&a, &b).expect("test: Kronecker product should succeed");
716 assert_eq!(kron_result.shape(), &[2, 2]);
717 assert_eq!(kron_result.to_vec(), vec![3.0, 6.0, 4.0, 8.0]);
718
719 let tensordot_result = tensordot(&a, &b, &[1, 0]).expect("test: tensordot should succeed");
721 assert_eq!(tensordot_result.shape(), &[1, 1]);
722 assert_relative_eq!(tensordot_result.to_vec()[0], 11.0, epsilon = 1e-10);
723 }
724
725 #[test]
726 fn test_matrix_operations() {
727 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
729 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
730
731 let c = a
733 .matmul(&b)
734 .expect("test: matrix multiplication should succeed");
735 assert_eq!(c.shape(), vec![2, 2]);
736 assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
737
738 let v = Array::<f64>::from_vec(vec![1.0, 2.0]);
740 let result = a
741 .matmul(&v.reshape(&[2, 1]))
742 .expect("test: matrix-vector multiplication should succeed")
743 .reshape(&[2]);
744 assert_eq!(result.to_vec(), vec![5.0, 11.0]);
745 }
746
747 #[test]
748 fn test_simd_operations() {
749 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
750 let b = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]);
751
752 let c = simd_add(&a, &b).expect("test: SIMD addition should succeed");
754 assert_eq!(c.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
755
756 let d = simd_mul(&a, &b).expect("test: SIMD multiplication should succeed");
758 assert_eq!(d.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
759
760 let e = simd_div(&a, &b).expect("test: SIMD division should succeed");
762 assert_relative_eq!(e.to_vec()[0], 0.2, epsilon = 1e-10);
763 assert_relative_eq!(e.to_vec()[1], 1.0 / 3.0, epsilon = 1e-10);
764 assert_relative_eq!(e.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
765 assert_relative_eq!(e.to_vec()[3], 0.5, epsilon = 1e-10);
766
767 let sqrt_a = simd_sqrt(&a);
769 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
770 assert_relative_eq!(
771 sqrt_a.to_vec()[1],
772 std::f64::consts::SQRT_2,
773 epsilon = 1e-10
774 );
775 assert_relative_eq!(sqrt_a.to_vec()[2], 1.7320508075688772, epsilon = 1e-10);
776 assert_relative_eq!(sqrt_a.to_vec()[3], 2.0, epsilon = 1e-10);
777
778 assert_eq!(simd_sum(&a), 10.0);
780 assert_eq!(simd_prod(&a), 24.0);
781 }
782
783 #[test]
784 fn test_norm_functions() {
785 let v = Array::<f64>::from_vec(vec![3.0, 4.0]);
787
788 let norm_1 = norm(&v, Some(1.0)).expect("test: L1 norm computation should succeed");
790 assert_relative_eq!(norm_1, 7.0, epsilon = 1e-10);
791
792 let norm_2 = norm(&v, Some(2.0)).expect("test: L2 norm computation should succeed");
794 assert_relative_eq!(norm_2, 5.0, epsilon = 1e-10);
795
796 let norm_inf =
798 norm(&v, Some(f64::INFINITY)).expect("test: infinity norm computation should succeed");
799 assert_relative_eq!(norm_inf, 4.0, epsilon = 1e-10);
800
801 let m = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
803
804 let matrix_norm_1 =
806 norm(&m, Some(1.0)).expect("test: matrix L1 norm computation should succeed");
807 assert_relative_eq!(matrix_norm_1, 6.0, epsilon = 1e-10);
808
809 let matrix_norm_inf = norm(&m, Some(f64::INFINITY))
811 .expect("test: matrix infinity norm computation should succeed");
812 assert_relative_eq!(matrix_norm_inf, 7.0, epsilon = 1e-10);
813 }
814
815 #[test]
816 fn test_math_operations() {
817 use crate::math::*;
818
819 let a = Array::<f64>::from_vec(vec![1.0, 4.0, 9.0, 16.0]);
821
822 let neg_a = a.map(|x| -x);
824 let abs_a = neg_a.abs();
825 for (expected, actual) in a.to_vec().iter().zip(abs_a.to_vec().iter()) {
826 assert_relative_eq!(*expected, *actual, epsilon = 1e-10);
827 }
828
829 let exp_a = a.exp();
831 assert_relative_eq!(exp_a.to_vec()[0], 1.0_f64.exp(), epsilon = 1e-10);
832 assert_relative_eq!(exp_a.to_vec()[1], 4.0_f64.exp(), epsilon = 1e-10);
833 assert_relative_eq!(exp_a.to_vec()[2], 9.0_f64.exp(), epsilon = 1e-10);
834 assert_relative_eq!(exp_a.to_vec()[3], 16.0_f64.exp(), epsilon = 1e-10);
835
836 let log_a = a.log();
838 assert_relative_eq!(log_a.to_vec()[0], 1.0_f64.ln(), epsilon = 1e-10);
839 assert_relative_eq!(log_a.to_vec()[1], 4.0_f64.ln(), epsilon = 1e-10);
840 assert_relative_eq!(log_a.to_vec()[2], 9.0_f64.ln(), epsilon = 1e-10);
841 assert_relative_eq!(log_a.to_vec()[3], 16.0_f64.ln(), epsilon = 1e-10);
842
843 let sqrt_a = a.sqrt();
845 assert_relative_eq!(sqrt_a.to_vec()[0], 1.0, epsilon = 1e-10);
846 assert_relative_eq!(sqrt_a.to_vec()[1], 2.0, epsilon = 1e-10);
847 assert_relative_eq!(sqrt_a.to_vec()[2], 3.0, epsilon = 1e-10);
848 assert_relative_eq!(sqrt_a.to_vec()[3], 4.0, epsilon = 1e-10);
849
850 let pow_a = a.pow(2.0);
852 assert_relative_eq!(pow_a.to_vec()[0], 1.0, epsilon = 1e-10);
853 assert_relative_eq!(pow_a.to_vec()[1], 16.0, epsilon = 1e-10);
854 assert_relative_eq!(pow_a.to_vec()[2], 81.0, epsilon = 1e-10);
855 assert_relative_eq!(pow_a.to_vec()[3], 256.0, epsilon = 1e-10);
856
857 let angles = Array::<f64>::from_vec(vec![
859 0.0,
860 std::f64::consts::PI / 6.0,
861 std::f64::consts::PI / 4.0,
862 std::f64::consts::PI / 3.0,
863 ]);
864
865 let sin_angles = angles.sin();
866 assert_relative_eq!(sin_angles.to_vec()[0], 0.0, epsilon = 1e-10);
867 assert_relative_eq!(sin_angles.to_vec()[1], 0.5, epsilon = 1e-10);
868 assert_relative_eq!(
869 sin_angles.to_vec()[2],
870 1.0 / std::f64::consts::SQRT_2,
871 epsilon = 1e-10
872 );
873 assert_relative_eq!(sin_angles.to_vec()[3], 0.8660254037844386, epsilon = 1e-10);
874
875 let cos_angles = angles.cos();
876 assert_relative_eq!(cos_angles.to_vec()[0], 1.0, epsilon = 1e-10);
877 assert_relative_eq!(cos_angles.to_vec()[1], 0.8660254037844386, epsilon = 1e-10);
878 assert_relative_eq!(
879 cos_angles.to_vec()[2],
880 1.0 / std::f64::consts::SQRT_2,
881 epsilon = 1e-10
882 );
883 assert_relative_eq!(cos_angles.to_vec()[3], 0.5, epsilon = 1e-10);
884
885 let lin = linspace(0.0, 10.0, 6);
887 assert_eq!(lin.size(), 6);
888 assert_relative_eq!(lin.to_vec()[0], 0.0, epsilon = 1e-10);
889 assert_relative_eq!(lin.to_vec()[1], 2.0, epsilon = 1e-10);
890 assert_relative_eq!(lin.to_vec()[2], 4.0, epsilon = 1e-10);
891 assert_relative_eq!(lin.to_vec()[3], 6.0, epsilon = 1e-10);
892 assert_relative_eq!(lin.to_vec()[4], 8.0, epsilon = 1e-10);
893 assert_relative_eq!(lin.to_vec()[5], 10.0, epsilon = 1e-10);
894
895 let range = arange(0.0, 5.0, 1.0);
897 assert_eq!(range.size(), 5);
898 assert_eq!(range.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
899
900 let rev_range = arange(5.0, 0.0, -1.0);
902 assert_eq!(rev_range.size(), 5);
903 assert_eq!(rev_range.to_vec(), vec![5.0, 4.0, 3.0, 2.0, 1.0]);
904
905 let log_space = logspace(0.0, 3.0, 4, None);
907 assert_eq!(log_space.size(), 4);
908 assert_relative_eq!(log_space.to_vec()[0], 1.0, epsilon = 1e-10);
909 assert_relative_eq!(log_space.to_vec()[1], 10.0, epsilon = 1e-10);
910 assert_relative_eq!(log_space.to_vec()[2], 100.0, epsilon = 1e-10);
911 assert_relative_eq!(log_space.to_vec()[3], 1000.0, epsilon = 1e-10);
912
913 let geom_space = geomspace(1.0, 1000.0, 4);
915 assert_eq!(geom_space.size(), 4);
916 assert_relative_eq!(geom_space.to_vec()[0], 1.0, epsilon = 1e-10);
917 assert_relative_eq!(geom_space.to_vec()[1], 10.0, epsilon = 1e-10);
918 assert_relative_eq!(geom_space.to_vec()[2], 100.0, epsilon = 1e-10);
919 assert_relative_eq!(geom_space.to_vec()[3], 1000.0, epsilon = 1e-10);
920 }
921
922 #[test]
923 fn test_array_operations() {
924 use crate::array_ops::*;
925
926 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
928 let tiled = tile(&a, &[2]).expect("test: tile operation should succeed");
929 assert_eq!(tiled.shape(), vec![6]);
930 assert_eq!(tiled.to_vec(), vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
931
932 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
933 let tiled_2d = tile(&a_2d, &[2, 1]).expect("test: 2D tile operation should succeed");
934 assert_eq!(tiled_2d.shape(), vec![4, 2]);
935 assert_eq!(
936 tiled_2d.to_vec(),
937 vec![1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]
938 );
939
940 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
942 let repeated = repeat(&a, 2, None).expect("test: repeat operation should succeed");
943 assert_eq!(repeated.shape(), vec![6]);
944 assert_eq!(repeated.to_vec(), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
945
946 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
947 let repeated_axis0 =
948 repeat(&a_2d, 2, Some(0)).expect("test: repeat along axis 0 should succeed");
949 assert_eq!(repeated_axis0.shape(), vec![4, 2]);
950 assert_eq!(
951 repeated_axis0.to_vec(),
952 vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]
953 );
954
955 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
957 let b = Array::<f64>::from_vec(vec![4.0, 5.0, 6.0]);
958 let c = concatenate(&[&a, &b], 0).expect("test: concatenate should succeed");
959 assert_eq!(c.shape(), vec![6]);
960 assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
961
962 let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
963 let b_2d = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
964 let c_axis0 =
965 concatenate(&[&a_2d, &b_2d], 0).expect("test: concatenate along axis 0 should succeed");
966 assert_eq!(c_axis0.shape(), vec![4, 2]);
967 assert_eq!(
968 c_axis0.to_vec(),
969 vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
970 );
971
972 let c_axis1 =
973 concatenate(&[&a_2d, &b_2d], 1).expect("test: concatenate along axis 1 should succeed");
974 assert_eq!(c_axis1.shape(), vec![2, 4]);
975 let c_vec = c_axis1.to_vec();
976 assert_eq!(c_vec.len(), 8);
978 assert!(c_vec.contains(&1.0));
979 assert!(c_vec.contains(&2.0));
980 assert!(c_vec.contains(&3.0));
981 assert!(c_vec.contains(&4.0));
982 assert!(c_vec.contains(&5.0));
983 assert!(c_vec.contains(&6.0));
984 assert!(c_vec.contains(&7.0));
985 assert!(c_vec.contains(&8.0));
986
987 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
989 let b = Array::<f64>::from_vec(vec![4.0, 5.0, 6.0]);
990 let c = stack(&[&a, &b], 0).expect("test: stack along axis 0 should succeed");
991 assert_eq!(c.shape(), vec![2, 3]);
992 let c_vec = c.to_vec();
993 assert_eq!(c_vec.len(), 6);
995 assert!(c_vec.contains(&1.0));
996 assert!(c_vec.contains(&2.0));
997 assert!(c_vec.contains(&3.0));
998 assert!(c_vec.contains(&4.0));
999 assert!(c_vec.contains(&5.0));
1000 assert!(c_vec.contains(&6.0));
1001
1002 let d = stack(&[&a, &b], 1).expect("test: stack along axis 1 should succeed");
1003 assert_eq!(d.shape(), vec![3, 2]);
1004 let d_vec = d.to_vec();
1005 assert_eq!(d_vec.len(), 6);
1007 assert!(d_vec.contains(&1.0));
1008 assert!(d_vec.contains(&2.0));
1009 assert!(d_vec.contains(&3.0));
1010 assert!(d_vec.contains(&4.0));
1011 assert!(d_vec.contains(&5.0));
1012 assert!(d_vec.contains(&6.0));
1013
1014 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1016 let splits = split(&a, &[2, 4], 0).expect("test: split should succeed");
1017 assert_eq!(splits.len(), 3);
1018 assert_eq!(splits[0].to_vec(), vec![1.0, 2.0]);
1019 assert_eq!(splits[1].to_vec(), vec![3.0, 4.0]);
1020 assert_eq!(splits[2].to_vec(), vec![5.0, 6.0]);
1021
1022 let _a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
1023 let splits_a =
1025 split(&a, &[2, 4], 0).expect("test: split with multiple indices should succeed");
1026 assert_eq!(splits_a.len(), 3);
1027
1028 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
1037 let expanded = expand_dims(&a, 0).expect("test: expand_dims should succeed");
1038 assert_eq!(expanded.shape(), vec![1, 3]);
1039 assert_eq!(expanded.to_vec(), vec![1.0, 2.0, 3.0]);
1040
1041 let expanded_end = expand_dims(&a, 1).expect("test: expand_dims at end should succeed");
1042 assert_eq!(expanded_end.shape(), vec![3, 1]);
1043 assert_eq!(expanded_end.to_vec(), vec![1.0, 2.0, 3.0]);
1044
1045 let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[1, 3, 1]);
1047 let squeezed = squeeze(&a, None).expect("test: squeeze should succeed");
1048 assert_eq!(squeezed.shape(), vec![3]);
1049 assert_eq!(squeezed.to_vec(), vec![1.0, 2.0, 3.0]);
1050
1051 let squeezed_axis = squeeze(&a, Some(0)).expect("test: squeeze at axis 0 should succeed");
1052 assert_eq!(squeezed_axis.shape(), vec![3, 1]);
1053 assert_eq!(squeezed_axis.to_vec(), vec![1.0, 2.0, 3.0]);
1054 }
1055
/// Checks the scalar reductions, percentiles, covariance/correlation and
/// histogram helpers against hand-computed values on small 1-D inputs.
#[test]
fn test_statistics_functions() {
    use crate::stats::*;

    // 1..=5: mean 3, variance 2 (divided by n), std sqrt(2).
    let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

    assert_relative_eq!(a.mean(), 3.0, epsilon = 1e-10);
    assert_relative_eq!(a.var(), 2.0, epsilon = 1e-10);
    assert_relative_eq!(a.std(), std::f64::consts::SQRT_2, epsilon = 1e-10);
    assert_relative_eq!(a.min(), 1.0, epsilon = 1e-10);
    assert_relative_eq!(a.max(), 5.0, epsilon = 1e-10);

    // Percentiles are requested as fractions: 0.0 -> min, 0.5 -> median,
    // 1.0 -> max, then the two quartiles.
    let percentile_cases = [(0.0, 1.0), (0.5, 3.0), (1.0, 5.0), (0.25, 2.0), (0.75, 4.0)];
    for &(fraction, expected) in percentile_cases.iter() {
        assert_relative_eq!(a.percentile(fraction), expected, epsilon = 1e-10);
    }

    // `a` reversed: perfectly anti-correlated with `a`.
    let b = Array::<f64>::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);

    let covariance =
        cov(&a, Some(&b), None, None, None).expect("test: covariance should succeed");
    let cross_covariance = covariance
        .get(&[0, 1])
        .expect("test: cov element access should succeed");
    // Sum of centred products is -10; -2.5 means the divisor here is n - 1.
    assert_relative_eq!(cross_covariance, -2.5, epsilon = 1e-10);

    let correlation =
        corrcoef(&a, Some(&b), None).expect("test: correlation coefficient should succeed");
    let cross_correlation = correlation
        .get(&[0, 1])
        .expect("test: corrcoef element access should succeed");
    assert_relative_eq!(cross_correlation, -1.0, epsilon = 1e-10);

    // Nine evenly spaced samples over [1, 5] in 4 bins. The expected counts
    // put 4.0, 4.5 and 5.0 together, i.e. the maximum lands in the last bin.
    let samples = Array::<f64>::from_vec(vec![1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]);
    let (counts, bins) =
        histogram(&samples, 4, None, None).expect("test: histogram should succeed");
    assert_eq!(counts.to_vec(), vec![2.0, 2.0, 2.0, 3.0]);
    assert_eq!(bins.size(), 5);

    let edges = bins.to_vec();
    assert_relative_eq!(edges[0], 1.0, epsilon = 1e-10);
    assert_relative_eq!(edges[4], 5.0, epsilon = 1e-10);
}
1113
/// Exercises mask-driven access:
/// - a hand-rolled scatter through a boolean mask;
/// - row/column selection with `IndexSpec`;
/// - in-place masked assignment via `set_mask`, reusing the same mask.
#[test]
fn test_boolean_indexing() {
    use crate::indexing::*;

    let a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    let mask = vec![true, false, true, false, true];

    // Scatter `values`, in order, into the positions selected by `mask`;
    // the unselected slots keep their zero.
    let values = Array::<f64>::from_vec(vec![1.0, 3.0, 5.0]);
    let mut filtered = Array::<f64>::zeros(&[5]);
    let selected = mask
        .iter()
        .enumerate()
        .filter(|&(_, &keep)| keep)
        .map(|(i, _)| i);
    for (value_idx, i) in selected.enumerate() {
        let value = values
            .get(&[value_idx])
            .expect("test: value access should succeed");
        filtered
            .set(&[i], value)
            .expect("test: set filtered value should succeed");
    }
    assert_eq!(filtered.to_vec(), vec![1.0, 0.0, 3.0, 0.0, 5.0]);

    // Row-major 3x3 holding 1..=9.
    let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
        .reshape(&[3, 3]);

    let row_result = a_2d
        .index(&[IndexSpec::Index(0), IndexSpec::All])
        .expect("test: row indexing should succeed");
    assert_eq!(row_result.shape(), vec![3]);
    let row_vec = row_result.to_vec();
    assert_eq!(row_vec.len(), 3);
    assert_eq!(row_vec, vec![1.0, 2.0, 3.0]);

    let col_result = a_2d
        .index(&[IndexSpec::All, IndexSpec::Index(0)])
        .expect("test: column indexing should succeed");
    assert_eq!(col_result.shape(), vec![3]);
    assert_eq!(col_result.to_vec(), vec![1.0, 4.0, 7.0]);

    // The same mask, now as an `Array<bool>`, drives in-place assignment.
    let bool_mask = Array::<bool>::from_vec(mask);
    let mut a_copy = a.clone();
    a_copy
        .set_mask(&bool_mask, &Array::<f64>::from_vec(vec![10.0, 30.0, 50.0]))
        .expect("test: set_mask should succeed");

    assert_eq!(a_copy.to_vec(), vec![10.0, 2.0, 30.0, 4.0, 50.0]);
}
1187
/// Checks multi-axis `IndexSpec` selection on a 3x3 array:
/// - a single element via two `Index` specs;
/// - a leading row range via `Slice` combined with `All`.
#[test]
fn test_fancy_indexing() {
    use crate::indexing::*;

    // Row-major fixture:
    //   [[1, 2, 3],
    //    [4, 5, 6],
    //    [7, 8, 9]]
    let a_2d = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
        .reshape(&[3, 3]);

    // Centre element.
    let single_element = a_2d
        .index(&[IndexSpec::Index(1), IndexSpec::Index(1)])
        .expect("test: single element indexing should succeed");
    assert_eq!(single_element.to_vec(), vec![5.0]);

    // `Slice(0, Some(2), None)` on axis 0 keeps rows 0 and 1 in full.
    let slice_result = a_2d
        .index(&[IndexSpec::Slice(0, Some(2), None), IndexSpec::All])
        .expect("test: slice indexing should succeed");
    assert_eq!(slice_result.shape(), vec![2, 3]);
    assert_eq!(slice_result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
1217
/// Verifies the per-axis reductions (`sum_axis`, `mean_axis`, `min_axis`,
/// `max_axis`, `argmin_axis`, `var_axis`, `std_axis`) on 2x3 fixtures.
#[test]
fn test_axis_operations() {
    use crate::axis_ops::*;

    // Primary fixture, filled cell by cell:
    //   [[1, 2, 3],
    //    [4, 5, 6]]
    let mut array = Array::<f64>::zeros(&[2, 3]);
    let fill: [([usize; 2], f64, &str); 6] = [
        ([0, 0], 1.0, "test: set [0,0] should succeed"),
        ([0, 1], 2.0, "test: set [0,1] should succeed"),
        ([0, 2], 3.0, "test: set [0,2] should succeed"),
        ([1, 0], 4.0, "test: set [1,0] should succeed"),
        ([1, 1], 5.0, "test: set [1,1] should succeed"),
        ([1, 2], 6.0, "test: set [1,2] should succeed"),
    ];
    for &(pos, value, context) in fill.iter() {
        array.set(&pos, value).expect(context);
    }

    // Reducing axis 0 yields one value per column; axis 1 one per row.
    let col_sums = array.sum_axis(0).expect("test: sum_axis(0) should succeed");
    assert_eq!(col_sums.shape(), vec![3]);
    assert_eq!(col_sums.to_vec(), vec![5.0, 7.0, 9.0]);

    let row_sums = array.sum_axis(1).expect("test: sum_axis(1) should succeed");
    assert_eq!(row_sums.shape(), vec![2]);
    assert_eq!(row_sums.to_vec(), vec![6.0, 15.0]);

    let col_means = array
        .mean_axis(Some(0))
        .expect("test: mean_axis(Some(0)) should succeed");
    assert_eq!(col_means.shape(), vec![3]);
    assert_eq!(col_means.to_vec(), vec![2.5, 3.5, 4.5]);

    let row_means = array
        .mean_axis(Some(1))
        .expect("test: mean_axis(Some(1)) should succeed");
    assert_eq!(row_means.shape(), vec![2]);
    assert_eq!(row_means.to_vec(), vec![2.0, 5.0]);

    let col_mins = array
        .min_axis(Some(0))
        .expect("test: min_axis(Some(0)) should succeed");
    assert_eq!(col_mins.shape(), vec![3]);
    assert_eq!(col_mins.to_vec(), vec![1.0, 2.0, 3.0]);

    let row_mins = array
        .min_axis(Some(1))
        .expect("test: min_axis(Some(1)) should succeed");
    assert_eq!(row_mins.shape(), vec![2]);
    assert_eq!(row_mins.to_vec(), vec![1.0, 4.0]);

    let row_maxs = array
        .max_axis(Some(1))
        .expect("test: max_axis(Some(1)) should succeed");
    assert_eq!(row_maxs.shape(), vec![2]);
    assert_eq!(row_maxs.to_vec(), vec![3.0, 6.0]);

    // Second fixture, arranged so column 0's minimum sits in row 1:
    //   [[3, 2, 1],
    //    [0, 5, 6]]
    let mut array2 = Array::<f64>::zeros(&[2, 3]);
    let fill2: [([usize; 2], f64, &str); 6] = [
        ([0, 0], 3.0, "test: set array2[0,0] should succeed"),
        ([0, 1], 2.0, "test: set array2[0,1] should succeed"),
        ([0, 2], 1.0, "test: set array2[0,2] should succeed"),
        ([1, 0], 0.0, "test: set array2[1,0] should succeed"),
        ([1, 1], 5.0, "test: set array2[1,1] should succeed"),
        ([1, 2], 6.0, "test: set array2[1,2] should succeed"),
    ];
    for &(pos, value, context) in fill2.iter() {
        array2.set(&pos, value).expect(context);
    }

    let col_argmins = array2
        .argmin_axis(0)
        .expect("test: argmin_axis(0) should succeed");
    assert_eq!(col_argmins.shape(), vec![3]);
    assert_eq!(col_argmins.to_vec(), vec![1, 0, 0]);

    // Column 0 of the primary fixture is [1, 4]; 2.25 is its variance with
    // divisor n.
    let col_vars = array
        .var_axis(Some(0))
        .expect("test: var_axis(Some(0)) should succeed");
    assert_eq!(col_vars.shape(), vec![3]);
    let first_col_var = col_vars
        .get(&[0])
        .expect("test: var_axis0 element access should succeed");
    assert_relative_eq!(first_col_var, 2.25, epsilon = 1e-10);

    // Both rows are an arithmetic step of 1, so their std is ~0.82 (divisor
    // n) or 1.0 (divisor n - 1). The bounds accept either convention.
    let row_stds = array
        .std_axis(Some(1))
        .expect("test: std_axis(Some(1)) should succeed");
    assert_eq!(row_stds.shape(), vec![2]);

    let std_row1 = row_stds
        .get(&[0])
        .expect("test: std_axis1[0] access should succeed");
    assert!(
        0.8 < std_row1 && std_row1 < 1.1,
        "std_row1 ({}) should be approximately 1.0 or 0.82",
        std_row1
    );

    let std_row2 = row_stds
        .get(&[1])
        .expect("test: std_axis1[1] access should succeed");
    assert!(
        0.8 < std_row2 && std_row2 < 1.1,
        "std_row2 ({}) should be approximately 1.0 or 0.82",
        std_row2
    );
}
1372
/// Exercises the view API over a 3x3 array:
/// - read-only and mutable views;
/// - strided and sliced views;
/// - transposed and broadcast views.
#[test]
fn test_views_and_strides() {
    use crate::views::SliceOrIndex;

    // Row-major 3x3 holding 1..=9.
    let mut a = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
        .reshape(&[3, 3]);

    let view = a.view();
    assert_eq!(view.shape(), vec![3, 3]);

    // A write through a mutable view must be visible in the owning array.
    let mut view_mut = a.view_mut();
    view_mut
        .set(&[0, 0], 10.0)
        .expect("test: view_mut set should succeed");
    assert_eq!(
        a.get(&[0, 0])
            .expect("test: get after view_mut set should succeed"),
        10.0
    );

    // Restore the original value so the checks below see 1..=9 again.
    a.set(&[0, 0], 1.0)
        .expect("test: reset value should succeed");

    // With argument [2, 2] the expected contents are the four corners.
    let strided = a
        .strided_view(&[2, 2])
        .expect("test: strided_view should succeed");
    assert_eq!(strided.shape(), vec![2, 2]);
    let flat_data = strided.to_vec();
    assert!(flat_data.contains(&1.0));
    assert!(flat_data.contains(&3.0));
    assert!(flat_data.contains(&7.0));
    assert!(flat_data.contains(&9.0));

    // Top-left 2x2 block.
    let slices = vec![
        SliceOrIndex::Slice(0, Some(2), None),
        SliceOrIndex::Slice(0, Some(2), None),
    ];
    let sliced = a
        .sliced_view(&slices)
        .expect("test: sliced_view should succeed");
    assert_eq!(sliced.shape(), vec![2, 2]);
    assert_eq!(sliced.to_vec(), vec![1.0, 2.0, 4.0, 5.0]);

    let transposed = a.transposed_view();
    assert_eq!(transposed.shape(), vec![3, 3]);
    // Flattening the transpose must yield exactly the nine source values.
    // Only membership is checked (after sorting), because this test does not
    // pin down the order in which a non-contiguous view is flattened.
    let mut t_flat = transposed.to_vec();
    t_flat.sort_by(|x, y| x.partial_cmp(y).expect("test: view values are not NaN"));
    assert_eq!(t_flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    // transposed[i, j] == a[j, i]
    assert_eq!(
        transposed
            .get(&[0, 1])
            .expect("test: transposed get [0,1] should succeed"),
        4.0
    );
    assert_eq!(
        transposed
            .get(&[1, 0])
            .expect("test: transposed get [1,0] should succeed"),
        2.0
    );

    // Broadcasting to 3x3x3 repeats the 3x3 array along a new leading axis.
    let broadcast = a
        .broadcast_view(&[3, 3, 3])
        .expect("test: broadcast_view should succeed");
    assert_eq!(broadcast.shape(), vec![3, 3, 3]);
    assert_eq!(
        broadcast
            .get(&[0, 0, 0])
            .expect("test: broadcast get [0,0,0] should succeed"),
        1.0
    );
    assert_eq!(
        broadcast
            .get(&[1, 0, 0])
            .expect("test: broadcast get [1,0,0] should succeed"),
        1.0
    );
}
1458
/// Checks the element-wise ufuncs on 4-element vectors:
/// - binary: add, subtract, multiply, divide, power;
/// - unary: square, sqrt, exp, log;
/// - scalar multiply;
/// - `add` broadcasting a (1, 2) row against a (3, 1) column.
#[test]
fn test_universal_functions() {
    use crate::ufuncs::*;

    // Element-wise float comparison. `got` is indexed directly, so a result
    // shorter than `want` panics on the out-of-range index.
    let assert_all_close = |got: &[f64], want: &[f64]| {
        for (i, &expected) in want.iter().enumerate() {
            assert_relative_eq!(got[i], expected, epsilon = 1e-10);
        }
    };

    let lhs = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let rhs = Array::<f64>::from_vec(vec![5.0, 6.0, 7.0, 8.0]);

    // Exact-arithmetic binary ops.
    let sums = add(&lhs, &rhs).expect("test: ufunc add should succeed");
    assert_eq!(sums.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);

    let diffs = subtract(&lhs, &rhs).expect("test: ufunc subtract should succeed");
    assert_eq!(diffs.to_vec(), vec![-4.0; 4]);

    let products = multiply(&lhs, &rhs).expect("test: ufunc multiply should succeed");
    assert_eq!(products.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);

    // Binary ops whose results are compared with a tolerance.
    let quotients = divide(&lhs, &rhs)
        .expect("test: ufunc divide should succeed")
        .to_vec();
    assert_all_close(&quotients, &[0.2, 1.0 / 3.0, 3.0 / 7.0, 0.5]);

    let powers = power(&lhs, &rhs)
        .expect("test: ufunc power should succeed")
        .to_vec();
    assert_all_close(&powers, &[1.0, 64.0, 2187.0, 65536.0]);

    // Unary ops.
    assert_eq!(square(&lhs).to_vec(), vec![1.0, 4.0, 9.0, 16.0]);

    let roots = sqrt(&lhs).to_vec();
    assert_all_close(
        &roots,
        &[1.0, std::f64::consts::SQRT_2, 1.7320508075688772, 2.0],
    );

    let inputs = [1.0_f64, 2.0, 3.0, 4.0];

    let expected_exp: Vec<f64> = inputs.iter().map(|x| x.exp()).collect();
    assert_all_close(&exp(&lhs).to_vec(), &expected_exp);

    let expected_ln: Vec<f64> = inputs.iter().map(|x| x.ln()).collect();
    assert_all_close(&log(&lhs).to_vec(), &expected_ln);

    assert_eq!(
        multiply_scalar(&lhs, 2.0).to_vec(),
        vec![2.0, 4.0, 6.0, 8.0]
    );

    // (1, 2) + (3, 1) broadcasts to (3, 2): out[i][j] = col[i] + row[j].
    let row = Array::<f64>::from_vec(vec![10.0, 20.0]).reshape(&[1, 2]);
    let col = Array::<f64>::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[3, 1]);
    let grid = add(&row, &col).expect("test: ufunc add with broadcasting should succeed");
    assert_eq!(grid.shape(), vec![3, 2]);
    assert_eq!(grid.to_vec(), vec![11.0, 21.0, 12.0, 22.0, 13.0, 23.0]);
}
1526}