numrs/autograd/
mod.rs

1//! Automatic differentiation system for NumRs
2//!
3//! Este módulo implementa:
4//! 1. Compute graph con referencias a operaciones
5//! 2. Backward pass automático
6//! 3. Gradient accumulation
7//!
8//! Arquitectura:
9//! ```text
10//! Tensor (Array + autograd metadata)
11//!   ↓
12//! ComputeNode (operación + inputs)
13//!   ↓
14//! Backward functions (chain rule)
15//! ```
16
17use crate::array::Array;
18use anyhow::Result;
19use std::cell::RefCell;
20use std::rc::Rc;
21
22pub mod backward;
23pub mod nn;
24pub mod ops;
25pub mod optim;
26pub mod tensor;
27pub mod train;
28
29pub use nn::{
30    BatchNorm1d, Conv1d, Dropout, Flatten, Linear, Module, ReLU, Sequential, Sigmoid, Softmax,
31};
32pub use optim::{
33    AdaBound, AdaDelta, AdaGrad, Adam, AdamW, CosineAnnealingLR, ExponentialLR, LinearWarmup,
34    Lookahead, NAdam, Optimizer, RAdam, RMSprop, ReduceLROnPlateau, Rprop, Scheduler, StepLR, LAMB,
35    LBFGS, SGD,
36};
37pub use tensor::Tensor;
38pub use train::{
39    CrossEntropyLoss, Dataset, LossFunction, MSELoss, Metrics, Trainer, TrainerBuilder,
40};
41
42/// Identificador único para cada nodo en el compute graph
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
44pub struct NodeId(usize);
45
46impl NodeId {
47    pub fn new() -> Self {
48        use std::sync::atomic::{AtomicUsize, Ordering};
49        static COUNTER: AtomicUsize = AtomicUsize::new(0);
50        Self(COUNTER.fetch_add(1, Ordering::SeqCst))
51    }
52}
53
54/// Tipo de operación en el compute graph
55#[derive(Debug, Clone)]
56pub enum OpKind {
57    // Operaciones básicas
58    Add,
59    Mul,
60    Sub,
61    Div,
62    Neg, // -x
63    Abs, // |x|
64
65    // Operaciones de matriz
66    MatMul,
67    Transpose,
68
69    // Operaciones elementwise
70    Exp,
71    Log,
72    Pow(f32),
73    Sqrt,
74
75    // Activaciones
76    ReLU,
77    Sigmoid,
78    Tanh,
79    Softmax,
80
81    // Trig
82    Sin,
83    Cos,
84    Tan,
85
86    // Loss functions
87    MSE,
88    CrossEntropy,
89
90    // Reductions
91    Sum {
92        axis: Option<usize>,
93    },
94    Mean {
95        axis: Option<usize>,
96    },
97    Max {
98        axis: Option<usize>,
99    },
100
101    // Neural Network
102    Conv1D {
103        stride: usize,
104        padding: usize,
105    },
106    Flatten {
107        start_dim: usize,
108        end_dim: usize,
109    },
110    Reshape {
111        shape: Vec<usize>,
112    },
113    BatchNorm {
114        training: bool,
115        momentum: f32,
116        eps: f32,
117    },
118    Dropout {
119        p: f32,
120        training: bool,
121    },
122
123    // Placeholder para leaf nodes
124    Leaf,
125}
126
127/// Función de backward para una operación específica
128///
129/// Argumentos:
130/// - `grad_output`: Gradiente que llega desde arriba (dL/dout)
131/// - `inputs`: Tensors de entrada de la operación forward
132/// - `output`: Tensor de salida de la operación forward
133///
134/// Retorna:
135/// - Vec de gradientes para cada input (dL/dinput_i)
136pub type BackwardFn = Box<dyn Fn(&Array, &[Tensor], &Tensor) -> Result<Vec<Array>>>;
137
138/// Nodo en el compute graph
139#[derive(Clone)]
140pub struct ComputeNode {
141    pub id: NodeId,
142    pub op: OpKind,
143    pub inputs: Vec<Tensor>,
144    pub backward_fn: Option<Rc<BackwardFn>>,
145}
146
147impl ComputeNode {
148    pub fn new(op: OpKind, inputs: Vec<Tensor>, backward_fn: Option<BackwardFn>) -> Self {
149        Self {
150            id: NodeId::new(),
151            op,
152            inputs,
153            backward_fn: backward_fn.map(Rc::new),
154        }
155    }
156
157    pub fn leaf() -> Self {
158        Self {
159            id: NodeId::new(),
160            op: OpKind::Leaf,
161            inputs: vec![],
162            backward_fn: None,
163        }
164    }
165}
166
167impl std::fmt::Debug for ComputeNode {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.debug_struct("ComputeNode")
170            .field("id", &self.id)
171            .field("op", &self.op)
172            .field("inputs", &self.inputs.len())
173            .field("has_backward", &self.backward_fn.is_some())
174            .finish()
175    }
176}
177
178/// Contexto de autograd - guarda el estado del compute graph
179#[derive(Debug, Clone)]
180pub struct AutogradContext {
181    pub enabled: bool,
182}
183
184impl AutogradContext {
185    pub fn new(enabled: bool) -> Self {
186        Self { enabled }
187    }
188
189    pub fn no_grad() -> Self {
190        Self { enabled: false }
191    }
192}
193
194thread_local! {
195    static AUTOGRAD_ENABLED: RefCell<bool> = RefCell::new(true);
196}
197
198/// Verifica si autograd está habilitado en el thread actual
199pub fn is_grad_enabled() -> bool {
200    AUTOGRAD_ENABLED.with(|enabled| *enabled.borrow())
201}
202
203/// Establece el estado de autograd
204pub fn set_grad_enabled(enabled: bool) {
205    AUTOGRAD_ENABLED.with(|e| *e.borrow_mut() = enabled);
206}
207
208/// Context manager para deshabilitar temporalmente autograd
209pub struct NoGrad {
210    prev_state: bool,
211}
212
213impl NoGrad {
214    pub fn new() -> Self {
215        let prev_state = is_grad_enabled();
216        set_grad_enabled(false);
217        Self { prev_state }
218    }
219}
220
221impl Drop for NoGrad {
222    fn drop(&mut self) {
223        set_grad_enabled(self.prev_state);
224    }
225}
226
227/// Macro para ejecutar código sin gradientes
228///
229/// # Ejemplo
230/// ```
231/// use numrs::no_grad;
232/// # struct MockModel;
233/// # impl MockModel { fn forward(&self, _i: &i32) {} }
234/// # let model = MockModel;
235/// # let input = 0;
236/// no_grad! {
237///     let pred = model.forward(&input);  // No construye compute graph
238/// }
239/// ```
240#[macro_export]
241macro_rules! no_grad {
242    ($($body:tt)*) => {{
243        let _guard = $crate::autograd::NoGrad::new();
244        $($body)*
245    }};
246}