numrs/autograd/ops.rs

//! Operations with autograd support
//!
//! Each operation:
//! 1. Runs the forward pass using NumRs ops
//! 2. If autograd is enabled, registers a backward function
//! 3. Creates a new Tensor holding a compute node
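//!
//! A minimal usage sketch (illustrative only): it assumes the two-argument
//! `Tensor::new(array, requires_grad)` constructor used throughout this file,
//! plus a `backward()` entry point on `Tensor` that is not defined here.
//!
//! ```ignore
//! let x = Tensor::new(Array::new(vec![1], vec![3.0]), true);
//! let y = x.mul(&x)?.sum()?; // y = x * x, summed to a scalar
//! y.backward()?;             // populates x's gradient: dy/dx = 2x = 6.0
//! ```
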
use crate::array::Array;
use crate::autograd::{Tensor, ComputeNode, OpKind, is_grad_enabled};
use crate::autograd::backward;
use crate::ops;
use anyhow::Result;

impl Tensor {
    /// Add: self + other
    pub fn add(&self, other: &Tensor) -> Result<Self> {
        // Forward pass
        let result = ops::add(&self.data, &other.data)?;

        // If autograd is disabled or neither input requires grad,
        // skip graph construction entirely
        if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
            return Ok(Tensor::new(result, false));
        }

        // Create a compute node with the backward function
        let requires_grad = self.requires_grad || other.requires_grad;
        let backward_fn = Box::new(backward::add_backward);

        let node = ComputeNode::new(
            OpKind::Add,
            vec![self.clone(), other.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, requires_grad))
    }

    /// Sub: self - other
    pub fn sub(&self, other: &Tensor) -> Result<Self> {
        let result = ops::sub(&self.data, &other.data)?;

        if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
            return Ok(Tensor::new(result, false));
        }

        let requires_grad = self.requires_grad || other.requires_grad;
        let backward_fn = Box::new(backward::sub_backward);

        let node = ComputeNode::new(
            OpKind::Sub,
            vec![self.clone(), other.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, requires_grad))
    }

    /// Mul (elementwise): self * other
    pub fn mul(&self, other: &Tensor) -> Result<Self> {
        // Forward pass (elementwise multiply).
        // Use ops::mul to handle broadcasting and promotion automatically.
        let result = ops::mul(&self.data, &other.data)?;

        if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
            return Ok(Tensor::new(result, false));
        }

        let requires_grad = self.requires_grad || other.requires_grad;
        let backward_fn = Box::new(backward::mul_backward);

        let node = ComputeNode::new(
            OpKind::Mul,
            vec![self.clone(), other.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, requires_grad))
    }
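
    // Conventional gradients for c = a * b: dc/da = b and dc/db = a, each
    // scaled by the upstream gradient (assumed to be what
    // backward::mul_backward implements).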

    /// Div: self / other
    pub fn div(&self, other: &Tensor) -> Result<Self> {
        let result = ops::div(&self.data, &other.data)?;

        if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
            return Ok(Tensor::new(result, false));
        }

        let requires_grad = self.requires_grad || other.requires_grad;
        let backward_fn = Box::new(backward::div_backward);

        let node = ComputeNode::new(
            OpKind::Div,
            vec![self.clone(), other.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, requires_grad))
    }
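
    // Conventional gradients for c = a / b (standard quotient rule, assumed
    // to match backward::div_backward): dc/da = 1 / b and dc/db = -a / b^2.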

    /// MatMul: self @ other
    pub fn matmul(&self, other: &Tensor) -> Result<Self> {
        // Forward pass
        let result = ops::matmul(&self.data, &other.data)?;

        if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
            return Ok(Tensor::new(result, false));
        }

        let requires_grad = self.requires_grad || other.requires_grad;
        let backward_fn = Box::new(backward::matmul_backward);

        let node = ComputeNode::new(
            OpKind::MatMul,
            vec![self.clone(), other.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, requires_grad))
    }
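
    // Conventional matmul gradients for C = A @ B: dL/dA = dL/dC @ B^T and
    // dL/dB = A^T @ dL/dC (assumed to match backward::matmul_backward).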

    /// ReLU: max(0, self)
    pub fn relu(&self) -> Result<Self> {
        // Forward pass
        let result = Array::new(
            self.data.shape.clone(),
            self.data.data.iter().map(|&x| x.max(0.0)).collect()
        );

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::relu_backward);

        let node = ComputeNode::new(
            OpKind::ReLU,
            vec![self.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, true))
    }
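
    // ReLU gradient: 1 where the input is positive, 0 elsewhere, so the
    // backward pass masks the upstream gradient with (x > 0). The
    // subgradient at x == 0 is conventionally taken to be 0.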

    /// Sigmoid: 1 / (1 + exp(-self))
    pub fn sigmoid(&self) -> Result<Self> {
        // Forward pass
        let result = Array::new(
            self.data.shape.clone(),
            self.data.data.iter()
                .map(|&x| 1.0 / (1.0 + (-x).exp()))
                .collect()
        );

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::sigmoid_backward);

        let node = ComputeNode::new(
            OpKind::Sigmoid,
            vec![self.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, true))
    }
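
    // Sigmoid gradient: with s = sigmoid(x), ds/dx = s * (1 - s), so the
    // backward pass can be computed from the forward output alone.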

    /// Exp: exp(self)
    pub fn exp(&self) -> Result<Self> {
        let result = Array::new(
            self.data.shape.clone(),
            self.data.data.iter().map(|&x| x.exp()).collect()
        );

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::exp_backward);

        let node = ComputeNode::new(
            OpKind::Exp,
            vec![self.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, true))
    }
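
    // Exp gradient: d(e^x)/dx = e^x, i.e. the forward output itself.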

    /// Log: natural log, ln(self)
    pub fn log(&self) -> Result<Self> {
        let result = Array::new(
            self.data.shape.clone(),
            self.data.data.iter().map(|&x| x.ln()).collect()
        );

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::log_backward);

        let node = ComputeNode::new(
            OpKind::Log,
            vec![self.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, true))
    }
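
    // Log gradient: d(ln x)/dx = 1 / x. Note that f32::ln returns NaN for
    // negative inputs and -inf at zero; this forward pass does not guard
    // against either.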

    /// Sum: sums all elements
    pub fn sum(&self) -> Result<Self> {
        // Forward pass
        let result = ops::sum(&self.data, None)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::sum_backward);

        let node = ComputeNode::new(
            OpKind::Sum { axis: None },
            vec![self.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, true))
    }
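
    // Sum gradient: each element contributes with weight 1, so the backward
    // pass broadcasts the scalar upstream gradient across the input shape.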

    /// Mean: average of all elements
    pub fn mean(&self) -> Result<Self> {
        let sum_val: f32 = self.data.data.iter().sum();
        let n = self.data.data.len() as f32;
        let result = Array::new(vec![1], vec![sum_val / n]);

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::mean_backward);

        let node = ComputeNode::new(
            OpKind::Mean { axis: None },
            vec![self.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, true))
    }
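
    // Mean gradient: each element contributes 1 / n, so the backward pass
    // spreads upstream_grad / n uniformly over the input shape.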

    /// MSE Loss: mean((self - target)^2)
    pub fn mse_loss(&self, target: &Tensor) -> Result<Self> {
        // Forward pass
        let diff_squared: f32 = self.data.data.iter()
            .zip(target.data.data.iter())
            .map(|(p, t)| (p - t).powi(2))
            .sum();
        let n = self.data.data.len() as f32;
        let result = Array::new(vec![1], vec![diff_squared / n]);

        if !is_grad_enabled() || (!self.requires_grad && !target.requires_grad) {
            return Ok(Tensor::new(result, false));
        }

        let requires_grad = self.requires_grad || target.requires_grad;
        let backward_fn = Box::new(backward::mse_backward);

        let node = ComputeNode::new(
            OpKind::MSE,
            vec![self.clone(), target.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, requires_grad))
    }
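
    // MSE gradient: for L = mean((p - t)^2), dL/dp = 2 (p - t) / n and
    // dL/dt = -2 (p - t) / n (the standard result; assumed to match
    // backward::mse_backward).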

    /// Cross-entropy loss with integrated softmax (batch-aware)
    pub fn cross_entropy_loss(&self, target: &Tensor) -> Result<Self> {
        // Validation
        if self.data.shape.len() != 2 || target.data.shape.len() != 2 {
            return Err(anyhow::anyhow!("CrossEntropy expects 2D tensors [batch, classes]"));
        }

        let batch_size = self.data.shape[0];
        let num_classes = self.data.shape[1];

        let mut total_loss = 0.0;

        // Iterate row-wise (sample by sample)
        for i in 0..batch_size {
            let start = i * num_classes;
            let end = start + num_classes;
            let logits = &self.data.data[start..end];
            let targets = &target.data.data[start..end];

            // Softmax per sample, stabilized by subtracting the row max
            let max_val = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = logits.iter().map(|x| (x - max_val).exp()).sum();

            // Cross-entropy per sample: -sum(target * log(softmax))
            // log(exp(x - max) / sum) = (x - max) - log(sum)
            let sample_loss: f32 = logits.iter()
                .zip(targets.iter())
                .map(|(&x, &t)| {
                    let log_softmax = (x - max_val) - exp_sum.ln();
                    -t * log_softmax
                })
                .sum();

            total_loss += sample_loss;
        }

        // Mean loss over the batch
        let mean_loss = total_loss / batch_size as f32;
        let result = Array::new(vec![1], vec![mean_loss]);

        if !is_grad_enabled() || (!self.requires_grad && !target.requires_grad) {
            return Ok(Tensor::new(result, false));
        }

        let requires_grad = self.requires_grad || target.requires_grad;
        let backward_fn = Box::new(backward::cross_entropy_backward);

        let node = ComputeNode::new(
            OpKind::CrossEntropy,
            vec![self.clone(), target.clone()],
            Some(backward_fn),
        );

        Ok(Tensor::from_operation(result, node, requires_grad))
    }
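
    // Fusing softmax with cross-entropy yields the well-known combined
    // gradient dL/dlogits = (softmax(logits) - target) / batch_size, which is
    // both cheaper and numerically safer than differentiating the two steps
    // separately (assumed to be what backward::cross_entropy_backward
    // implements).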

    /// Flatten: flatten(start, end)
    pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Self> {
        let result = ops::flatten(&self.data, start_dim, end_dim)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::flatten_backward);
        let node = ComputeNode::new(
            OpKind::Flatten { start_dim, end_dim },
            vec![self.clone()],
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }

    /// Reshape: input.reshape([d1, d2, ...])
    pub fn reshape(&self, shape: Vec<usize>) -> Result<Self> {
        let shape_isize: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
        let result = ops::reshape(&self.data, &shape_isize)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::reshape_backward);
        let node = ComputeNode::new(
            OpKind::Reshape { shape },
            vec![self.clone()],
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }

    /// Conv1D
    pub fn conv1d(&self, weight: &Tensor, bias: Option<&Tensor>, stride: usize, padding: usize) -> Result<Self> {
        let bias_data = bias.map(|b| &b.data);
        let result = ops::conv::conv1d(&self.data, &weight.data, bias_data, stride, padding)?;

        let mut inputs = vec![self.clone(), weight.clone()];
        let mut requires_grad = self.requires_grad || weight.requires_grad;

        if let Some(b) = bias {
            inputs.push(b.clone());
            requires_grad = requires_grad || b.requires_grad;
        }

        if !is_grad_enabled() || !requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::conv1d_backward);
        let node = ComputeNode::new(
            OpKind::Conv1D { stride, padding },
            inputs,
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }

    /// BatchNorm
    #[allow(clippy::too_many_arguments)]
    pub fn batch_norm(&self, running_mean: &mut Tensor, running_var: &mut Tensor, weight: &Tensor, bias: &Tensor, training: bool, momentum: f32, eps: f32) -> Result<Self> {
        let result = ops::batchnorm::batch_norm(
            &self.data,
            &mut running_mean.data,
            &mut running_var.data,
            &weight.data,
            &bias.data,
            training,
            momentum,
            eps,
        )?;

        let requires_grad = self.requires_grad || weight.requires_grad || bias.requires_grad;

        if !is_grad_enabled() || !requires_grad {
            return Ok(Tensor::new(result, false));
        }

        // Note: running stats are usually not part of the gradient graph; they
        // are updated in place. For ONNX export, however, they MUST appear as
        // inputs to the node.
        let inputs = vec![
            self.clone(),
            weight.clone(),
            bias.clone(),
            running_mean.clone(),
            running_var.clone(),
        ];
        let backward_fn = Box::new(backward::batchnorm_backward);

        let node = ComputeNode::new(
            OpKind::BatchNorm { training, momentum, eps },
            inputs,
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }

    /// Dropout
    pub fn dropout(&self, p: f32, training: bool) -> Result<Self> {
        let result = ops::dropout::dropout(&self.data, p, training)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::dropout_backward);
        let node = ComputeNode::new(
            OpKind::Dropout { p, training },
            vec![self.clone()],
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }

    /// Pow: self^exponent
    pub fn pow(&self, exponent: f32) -> Result<Self> {
        let exponent_arr = Array::new(vec![1], vec![exponent]);
        let result = ops::pow(&self.data, &exponent_arr)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::pow_backward);
        let node = ComputeNode::new(
            OpKind::Pow(exponent),
            vec![self.clone()],
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }

    /// Sqrt: sqrt(self)
    pub fn sqrt(&self) -> Result<Self> {
        let result = ops::sqrt(&self.data)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::sqrt_backward);
        let node = ComputeNode::new(
            OpKind::Sqrt,
            vec![self.clone()],
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }
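
    // Sqrt gradient: d(sqrt(x))/dx = 1 / (2 sqrt(x)); it diverges as x -> 0,
    // so callers should expect inf/NaN gradients at or below zero.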

    /// Sin: sin(self)
    pub fn sin(&self) -> Result<Self> {
        let result = ops::sin(&self.data)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::sin_backward);
        let node = ComputeNode::new(
            OpKind::Sin,
            vec![self.clone()],
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }

    /// Cos: cos(self)
    pub fn cos(&self) -> Result<Self> {
        let result = ops::cos(&self.data)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::cos_backward);
        let node = ComputeNode::new(
            OpKind::Cos,
            vec![self.clone()],
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }

    /// Tan: tan(self)
    pub fn tan(&self) -> Result<Self> {
        let result = ops::tan(&self.data)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::tan_backward);
        let node = ComputeNode::new(
            OpKind::Tan,
            vec![self.clone()],
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }

    /// Tanh: tanh(self)
    pub fn tanh(&self) -> Result<Self> {
        let result = ops::tanh(&self.data)?;

        if !is_grad_enabled() || !self.requires_grad {
            return Ok(Tensor::new(result, false));
        }

        let backward_fn = Box::new(backward::tanh_backward);
        let node = ComputeNode::new(
            OpKind::Tanh,
            vec![self.clone()],
            Some(backward_fn),
        );
        Ok(Tensor::from_operation(result, node, true))
    }
}
555}