// numrs/lib.rs

//! NumRs — core library
//!
//! This crate provides the core runtime, IR, lowering pipeline and basic
//! backends to run simple numeric kernels. It's intentionally small and
//! focused on proving the architecture end-to-end.
//!
//! The stable public surface covers array creation, element-wise ops
//! (add/mul/sub/div and friends), reductions, and the autograd/training API.
pub mod array;
pub mod array_view; // Zero-copy view for FFI
pub mod autograd; // ← Automatic differentiation
pub mod backend;
pub mod codegen;
pub mod ir;
pub mod llo;
pub mod ops;
pub mod ops_inplace; // Zero-copy operations for FFI bindings
pub mod startup; // Startup banner / logging

pub use array::{cast_array, promoted_dtype, Array, DType, DTypeValue};
pub use array_view::ArrayView;
pub use autograd::{is_grad_enabled, set_grad_enabled, NoGrad, Tensor}; // ← Autograd exports
pub use autograd::{AdaGrad, Adam, Optimizer, RMSprop, SGD}; // ← Optimizer exports
pub use autograd::{
    BatchNorm1d, Conv1d, Dropout, Flatten, Linear, Module, ReLU, Sequential, Sigmoid,
}; // ← Neural network modules
pub use autograd::{CrossEntropyLoss, Dataset, MSELoss, Trainer, TrainerBuilder}; // ← Training API
pub use backend::dispatch::{get_backend_override, set_backend_override};
pub use llo::reduction::ReductionKind;
pub use llo::ElementwiseKind;
pub use ops::{
    abs, acos, add, asin, atan, cos, div, exp, log, mul, pow, relu, sigmoid, sin, softmax, sqrt,
    sub, sum, tan, tanh,
};
pub use startup::print_startup_log;
// NOTE(review): a stale comment here claimed the compile-time HLO macro is
// re-exported, but no macro re-export exists in this file — confirm whether it
// was removed intentionally or is still pending.

#[cfg(target_arch = "wasm32")]
pub use backend::webgpu::{init_webgpu_wasm, set_webgpu_available_wasm};
40
#[cfg(test)]
mod tests {
    use crate::array::Array;
    use crate::ops::{add, div, mul, sub, sum};

    /// Smoke test covering the element-wise binary ops and the full reduction.
    #[test]
    fn test_add_mul_sum() {
        let lhs = Array::new(vec![3], vec![1.0, 2.0, 3.0]);
        let rhs = Array::new(vec![3], vec![2.0, 2.0, 2.0]);

        // Element-wise add / mul.
        assert_eq!(add(&lhs, &rhs).expect("add failed").data, vec![3.0, 4.0, 5.0]);
        assert_eq!(mul(&lhs, &rhs).expect("mul failed").data, vec![2.0, 4.0, 6.0]);

        // Reducing over all axes (`None`) of [1, 2, 3] yields a single 6.
        assert_eq!(sum(&lhs, None).expect("sum failed").data, vec![6.0]);

        // Sub / div on a second pair of inputs.
        let x = Array::new(vec![2], vec![1.0, 0.0]);
        let y = Array::new(vec![2], vec![2.0, 1.0]);
        assert_eq!(sub(&y, &x).expect("sub failed").data, vec![1.0, 1.0]);
        assert_eq!(div(&y, &y).expect("div failed").data, vec![1.0, 1.0]);

        // TODO: add sqrt/sin/cos coverage once a fast path for unary ops exists.
    }
}
72}