// axonml_autograd/variable.rs
//! `Variable` — tensor with automatic gradient tracking.
//!
//! Wraps `Arc<RwLock<Tensor<f32>>>` with a `GradAccumulator` and optional
//! `GradFn` recording the operation that produced it. Provides differentiable
//! versions of all tensor ops: arithmetic (add_var, sub_var, mul_var, div_var,
//! neg, add_scalar, mul_scalar, pow), activations (relu, sigmoid, tanh, gelu,
//! silu, elu, leaky_relu, softmax, log_softmax), reductions (sum, mean,
//! sum_dim, mean_dim, var_dim), shape (reshape, transpose, t, narrow, select,
//! unsqueeze, expand, cat), matmul, and `from_operation` (custom GradFn
//! attachment). `backward()` triggers reverse-mode autodiff from this
//! variable. `detach()` / `requires_grad_()` control tracking. `data()` gives
//! read access to the underlying tensor.
//!
//! # File
//! `crates/axonml-autograd/src/variable.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
29use std::ops::{Add, Div, Mul, Neg, Sub};
30use std::sync::Arc;
31
32use parking_lot::RwLock;
33
34use axonml_tensor::Tensor;
35
36use crate::functions::{
37    AddBackward, AddScalarBackward, CatBackward, ClampBackward, DivBackward, EluBackward,
38    ExpBackward, ExpandBackward, GeluBackward, LeakyReluBackward, LogBackward, LogSoftmaxBackward,
39    MatMulBackward, MeanBackward, MeanDimBackward, MulBackward, MulScalarBackward, NarrowBackward,
40    NegBackward, PowBackward, ReluBackward, ReshapeBackward, SelectBackward, SigmoidBackward,
41    SiluBackward, SoftmaxBackward, SqrtBackward, SubBackward, SumBackward, SumDimBackward,
42    TanhBackward, TransposeBackward, UnsqueezeBackward, VarDimBackward,
43};
44use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
45use crate::graph::{GraphNode, with_graph};
46use crate::no_grad::is_grad_enabled;
47
48// =============================================================================
49// Variable Struct
50// =============================================================================
51
/// A tensor with automatic differentiation support.
///
/// Variable wraps a Tensor and tracks operations performed on it to enable
/// automatic gradient computation. When `requires_grad` is true, all operations
/// are recorded in a computational graph.
///
/// Cloning a `Variable` is cheap: all heavy state lives behind `Arc`s, so a
/// clone shares the same tensor storage and gradient accumulator.
#[derive(Clone)]
pub struct Variable {
    /// The underlying tensor data. Shared across clones of this variable.
    data: Arc<RwLock<Tensor<f32>>>,
    /// Shared gradient accumulator (for leaf variables, shared with `AccumulateGrad`).
    /// Holds `None` until a backward pass writes a gradient.
    grad: GradAccumulator,
    /// Whether this variable requires gradient computation.
    requires_grad: bool,
    /// Whether this is a leaf variable (created by user, not an operation).
    is_leaf: bool,
    /// The gradient function for backpropagation. For leaves that require
    /// grad this is an `AccumulateGrad`; for op results, the op's backward.
    grad_fn: Option<GradFn>,
    /// Graph node for this variable. `None` when gradients are not tracked.
    node: Option<Arc<GraphNode>>,
}
72
73impl Variable {
74    /// Creates a new variable from a tensor.
75    ///
76    /// # Arguments
77    /// * `data` - The tensor data
78    /// * `requires_grad` - Whether to track gradients for this variable
79    #[must_use]
80    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
81        // Create shared gradient accumulator
82        let grad: GradAccumulator = Arc::new(RwLock::new(None));
83
84        let node = if requires_grad {
85            Some(with_graph(|g| g.register_leaf(true)))
86        } else {
87            None
88        };
89
90        // Create AccumulateGrad with shared gradient storage
91        let grad_fn = if requires_grad {
92            Some(GradFn::new(AccumulateGrad::new(Arc::clone(&grad))))
93        } else {
94            None
95        };
96
97        Self {
98            data: Arc::new(RwLock::new(data)),
99            grad,
100            requires_grad,
101            is_leaf: true,
102            grad_fn,
103            node,
104        }
105    }
106
107    /// Creates a variable that doesn't require gradients.
108    #[must_use]
109    pub fn from_tensor(data: Tensor<f32>) -> Self {
110        Self::new(data, false)
111    }
112
113    /// Creates a new variable from an operation result with an attached gradient function.
114    ///
115    /// This connects the variable to the computational graph, allowing gradients
116    /// to flow backward through the operation that produced this variable.
117    pub fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
118        let node = if requires_grad {
119            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
120        } else {
121            None
122        };
123
124        Self {
125            data: Arc::new(RwLock::new(data)),
126            grad: Arc::new(RwLock::new(None)),
127            requires_grad,
128            is_leaf: false,
129            grad_fn: if requires_grad { Some(grad_fn) } else { None },
130            node,
131        }
132    }
133
134    /// Returns a clone of the underlying tensor data.
135    ///
136    /// Tensor uses Arc-backed storage, so this is a cheap reference count
137    /// bump (not a deep copy). The data is shared until mutated.
138    #[must_use]
139    pub fn data(&self) -> Tensor<f32> {
140        self.data.read().clone()
141    }
142
143    /// Returns the shape of the tensor.
144    #[must_use]
145    pub fn shape(&self) -> Vec<usize> {
146        self.data.read().shape().to_vec()
147    }
148
149    /// Returns the number of dimensions.
150    #[must_use]
151    pub fn ndim(&self) -> usize {
152        self.data.read().ndim()
153    }
154
155    /// Returns the total number of elements.
156    #[must_use]
157    pub fn numel(&self) -> usize {
158        self.data.read().numel()
159    }
160
161    /// Returns the device this variable's data is on.
162    #[must_use]
163    pub fn device(&self) -> axonml_tensor::Device {
164        self.data.read().device()
165    }
166
167    /// Moves this variable's data to the specified device.
168    ///
169    /// Creates a new leaf Variable on the target device.
170    /// Used for moving inputs to GPU before forward pass.
171    pub fn to_device(&self, device: axonml_tensor::Device) -> Self {
172        let current = self.data.read().clone();
173        if current.device() == device {
174            return self.clone();
175        }
176        let moved = current
177            .to_device(device)
178            .expect("Failed to move variable to device");
179        Variable::new(moved, self.requires_grad)
180    }
181
182    /// Returns whether this variable requires gradients.
183    #[must_use]
184    pub fn requires_grad(&self) -> bool {
185        self.requires_grad
186    }
187
188    /// Returns whether this is a leaf variable.
189    #[must_use]
190    pub fn is_leaf(&self) -> bool {
191        self.is_leaf
192    }
193
194    /// Returns the gradient of this variable.
195    ///
196    /// Only available for leaf variables after `backward()` has been called.
197    #[must_use]
198    pub fn grad(&self) -> Option<Tensor<f32>> {
199        self.grad.read().clone()
200    }
201
202    /// Returns the gradient function.
203    #[must_use]
204    pub fn grad_fn(&self) -> Option<&GradFn> {
205        self.grad_fn.as_ref()
206    }
207
208    /// Updates the underlying tensor data without breaking the gradient accumulator.
209    ///
210    /// Unlike creating a new Variable, this preserves the grad accumulator Arc
211    /// so that backward passes continue to write gradients to the right place.
212    pub fn set_data(&mut self, new_data: Tensor<f32>) {
213        *self.data.write() = new_data;
214    }
215
216    /// Sets the gradient (used during backward pass).
217    pub fn set_grad(&self, grad: Tensor<f32>) {
218        *self.grad.write() = Some(grad);
219    }
220
221    /// Accumulates gradient (adds to existing gradient).
222    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
223        let mut grad_lock = self.grad.write();
224        if let Some(ref existing) = *grad_lock {
225            *grad_lock = Some(existing.add(grad).unwrap());
226        } else {
227            *grad_lock = Some(grad.clone());
228        }
229    }
230
231    /// Clears the gradient.
232    pub fn zero_grad(&self) {
233        *self.grad.write() = None;
234    }
235
236    /// Detaches this variable from the computation graph.
237    ///
238    /// Returns a new variable with the same data but no gradient history.
239    #[must_use]
240    pub fn detach(&self) -> Self {
241        Self {
242            data: Arc::new(RwLock::new(self.data.read().clone())),
243            grad: Arc::new(RwLock::new(None)),
244            requires_grad: false,
245            is_leaf: true,
246            grad_fn: None,
247            node: None,
248        }
249    }
250
251    /// Returns a new variable with `requires_grad` set.
252    #[must_use]
253    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
254        self.requires_grad = requires_grad;
255        if requires_grad && self.is_leaf {
256            // AccumulateGrad shares the gradient accumulator with this variable
257            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
258            self.node = Some(with_graph(|g| g.register_leaf(true)));
259        }
260        self
261    }
262
263    /// Computes gradients via backpropagation.
264    ///
265    /// This should only be called on scalar (single-element) tensors,
266    /// typically the loss value.
267    pub fn backward(&self) {
268        assert!(
269            self.requires_grad,
270            "Cannot call backward on a variable that doesn't require gradients"
271        );
272
273        assert!(
274            (self.numel() == 1),
275            "backward() can only be called on scalar tensors"
276        );
277
278        // Start with gradient of 1.0 for the output, on the same device
279        let mut grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
280        let device = self.data.read().device();
281        if device.is_gpu() {
282            grad_output = grad_output
283                .to_device(device)
284                .expect("device transfer failed");
285        }
286        crate::backward::backward(self, &grad_output);
287    }
288
289    /// Runs the backward pass with a provided gradient tensor.
290    ///
291    /// Unlike `backward()`, this does not require the variable to be scalar.
292    /// The gradient tensor must match the shape of this variable.
293    pub fn backward_with_grad(&self, grad_output: &Tensor<f32>) {
294        if !self.requires_grad {
295            return;
296        }
297        let device = self.data.read().device();
298        let grad = if grad_output.device() != device && device.is_gpu() {
299            grad_output
300                .to_device(device)
301                .expect("device transfer failed")
302        } else {
303            grad_output.clone()
304        };
305        crate::backward::backward(self, &grad);
306    }
307
308    // =========================================================================
309    // Arithmetic Operations
310    // =========================================================================
311
312    /// Element-wise addition.
313    #[must_use]
314    pub fn add_var(&self, other: &Variable) -> Variable {
315        let result = self.data.read().add(&other.data.read()).unwrap();
316        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
317
318        if requires_grad {
319            let grad_fn = GradFn::new(AddBackward::new(
320                self.grad_fn.clone(),
321                other.grad_fn.clone(),
322                self.shape(),
323                other.shape(),
324            ));
325            Variable::from_operation(result, grad_fn, true)
326        } else {
327            Variable::from_tensor(result)
328        }
329    }
330
331    /// Element-wise subtraction.
332    #[must_use]
333    pub fn sub_var(&self, other: &Variable) -> Variable {
334        let result = self.data.read().sub(&other.data.read()).unwrap();
335        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
336
337        if requires_grad {
338            let grad_fn = GradFn::new(SubBackward::new(
339                self.grad_fn.clone(),
340                other.grad_fn.clone(),
341                self.shape(),
342                other.shape(),
343            ));
344            Variable::from_operation(result, grad_fn, true)
345        } else {
346            Variable::from_tensor(result)
347        }
348    }
349
350    /// Element-wise multiplication.
351    #[must_use]
352    pub fn mul_var(&self, other: &Variable) -> Variable {
353        let self_data = self.data.read().clone();
354        let other_data = other.data.read().clone();
355        let result = self_data.mul(&other_data).unwrap();
356        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
357
358        if requires_grad {
359            let grad_fn = GradFn::new(MulBackward::new(
360                self.grad_fn.clone(),
361                other.grad_fn.clone(),
362                self_data,
363                other_data,
364            ));
365            Variable::from_operation(result, grad_fn, true)
366        } else {
367            Variable::from_tensor(result)
368        }
369    }
370
371    /// Element-wise division.
372    #[must_use]
373    pub fn div_var(&self, other: &Variable) -> Variable {
374        let self_data = self.data.read().clone();
375        let other_data = other.data.read().clone();
376        let result = self_data.div(&other_data).unwrap();
377        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
378
379        if requires_grad {
380            let grad_fn = GradFn::new(DivBackward::new(
381                self.grad_fn.clone(),
382                other.grad_fn.clone(),
383                self_data,
384                other_data,
385            ));
386            Variable::from_operation(result, grad_fn, true)
387        } else {
388            Variable::from_tensor(result)
389        }
390    }
391
392    /// Negation.
393    #[must_use]
394    pub fn neg_var(&self) -> Variable {
395        let result = self.data.read().neg();
396        let requires_grad = self.requires_grad && is_grad_enabled();
397
398        if requires_grad {
399            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
400            Variable::from_operation(result, grad_fn, true)
401        } else {
402            Variable::from_tensor(result)
403        }
404    }
405
406    /// Matrix multiplication.
407    #[must_use]
408    pub fn matmul(&self, other: &Variable) -> Variable {
409        let self_data = self.data.read().clone();
410        let other_data = other.data.read().clone();
411
412        // AMP: cast inputs to f16 precision for faster matmul, result stays f32
413        let (compute_a, compute_b) = if crate::amp::is_autocast_enabled() {
414            (self_data.to_f16_precision(), other_data.to_f16_precision())
415        } else {
416            (self_data.clone(), other_data.clone())
417        };
418
419        let result = compute_a.matmul(&compute_b).expect("matmul failed");
420        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
421
422        if requires_grad {
423            let grad_fn = GradFn::new(MatMulBackward::new(
424                self.grad_fn.clone(),
425                other.grad_fn.clone(),
426                self_data,
427                other_data,
428            ));
429            Variable::from_operation(result, grad_fn, true)
430        } else {
431            Variable::from_tensor(result)
432        }
433    }
434
435    /// Power operation.
436    #[must_use]
437    pub fn pow(&self, exponent: f32) -> Variable {
438        let self_data = self.data.read().clone();
439        let result = self_data.pow(exponent);
440        let requires_grad = self.requires_grad && is_grad_enabled();
441
442        if requires_grad {
443            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
444            Variable::from_operation(result, grad_fn, true)
445        } else {
446            Variable::from_tensor(result)
447        }
448    }
449
450    // =========================================================================
451    // Activation Functions
452    // =========================================================================
453
454    /// `ReLU` activation.
455    #[must_use]
456    pub fn relu(&self) -> Variable {
457        let self_data = self.data.read().clone();
458        let result = self_data.relu();
459        let requires_grad = self.requires_grad && is_grad_enabled();
460
461        if requires_grad {
462            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
463            Variable::from_operation(result, grad_fn, true)
464        } else {
465            Variable::from_tensor(result)
466        }
467    }
468
469    /// Leaky ReLU activation.
470    #[must_use]
471    pub fn leaky_relu(&self, negative_slope: f32) -> Variable {
472        let self_data = self.data.read().clone();
473        let device = self_data.device();
474        let result_vec: Vec<f32> = self_data
475            .to_vec()
476            .iter()
477            .map(|&x| if x > 0.0 { x } else { x * negative_slope })
478            .collect();
479        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
480        if device.is_gpu() {
481            result = result.to_device(device).expect("device transfer failed");
482        }
483        let requires_grad = self.requires_grad && is_grad_enabled();
484
485        if requires_grad {
486            let grad_fn = GradFn::new(LeakyReluBackward::new(
487                self.grad_fn.clone(),
488                self_data,
489                negative_slope,
490            ));
491            Variable::from_operation(result, grad_fn, true)
492        } else {
493            Variable::from_tensor(result)
494        }
495    }
496
497    /// ELU activation.
498    #[must_use]
499    pub fn elu(&self, alpha: f32) -> Variable {
500        let self_data = self.data.read().clone();
501        let device = self_data.device();
502        let result_vec: Vec<f32> = self_data
503            .to_vec()
504            .iter()
505            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
506            .collect();
507        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
508        if device.is_gpu() {
509            result = result.to_device(device).expect("device transfer failed");
510        }
511        let requires_grad = self.requires_grad && is_grad_enabled();
512
513        if requires_grad {
514            let grad_fn = GradFn::new(EluBackward::new(self.grad_fn.clone(), self_data, alpha));
515            Variable::from_operation(result, grad_fn, true)
516        } else {
517            Variable::from_tensor(result)
518        }
519    }
520
521    /// Sigmoid activation.
522    #[must_use]
523    pub fn sigmoid(&self) -> Variable {
524        let result = self.data.read().sigmoid();
525        let requires_grad = self.requires_grad && is_grad_enabled();
526
527        if requires_grad {
528            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
529            Variable::from_operation(result, grad_fn, true)
530        } else {
531            Variable::from_tensor(result)
532        }
533    }
534
535    /// Tanh activation.
536    #[must_use]
537    pub fn tanh(&self) -> Variable {
538        let result = self.data.read().tanh();
539        let requires_grad = self.requires_grad && is_grad_enabled();
540
541        if requires_grad {
542            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
543            Variable::from_operation(result, grad_fn, true)
544        } else {
545            Variable::from_tensor(result)
546        }
547    }
548
549    /// Element-wise exponential.
550    #[must_use]
551    pub fn exp(&self) -> Variable {
552        let self_data = self.data.read().clone();
553        let result = self_data.exp();
554        let requires_grad = self.requires_grad && is_grad_enabled();
555
556        if requires_grad {
557            let grad_fn = GradFn::new(ExpBackward::new(self.grad_fn.clone(), result.clone()));
558            Variable::from_operation(result, grad_fn, true)
559        } else {
560            Variable::from_tensor(result)
561        }
562    }
563
564    /// Element-wise natural logarithm.
565    #[must_use]
566    pub fn log(&self) -> Variable {
567        let self_data = self.data.read().clone();
568        let result = self_data.ln();
569        let requires_grad = self.requires_grad && is_grad_enabled();
570
571        if requires_grad {
572            let grad_fn = GradFn::new(LogBackward::new(self.grad_fn.clone(), self_data));
573            Variable::from_operation(result, grad_fn, true)
574        } else {
575            Variable::from_tensor(result)
576        }
577    }
578
579    /// Element-wise clamp to [min_val, max_val].
580    #[must_use]
581    pub fn clamp(&self, min_val: f32, max_val: f32) -> Variable {
582        let self_data = self.data.read().clone();
583        let device = self_data.device();
584        let result_data: Vec<f32> = self_data
585            .to_vec()
586            .iter()
587            .map(|&x| x.clamp(min_val, max_val))
588            .collect();
589        let mut result = Tensor::from_vec(result_data, self_data.shape()).unwrap();
590        if device.is_gpu() {
591            result = result.to_device(device).expect("device transfer failed");
592        }
593        let requires_grad = self.requires_grad && is_grad_enabled();
594
595        if requires_grad {
596            let grad_fn = GradFn::new(ClampBackward::new(
597                self.grad_fn.clone(),
598                self_data,
599                min_val,
600                max_val,
601            ));
602            Variable::from_operation(result, grad_fn, true)
603        } else {
604            Variable::from_tensor(result)
605        }
606    }
607
608    // =========================================================================
609    // Reduction Operations
610    // =========================================================================
611
612    /// Sum all elements.
613    #[must_use]
614    pub fn sum(&self) -> Variable {
615        let self_data = self.data.read().clone();
616        let result = self_data.sum(); // Returns a scalar Tensor
617        let requires_grad = self.requires_grad && is_grad_enabled();
618
619        if requires_grad {
620            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
621            Variable::from_operation(result, grad_fn, true)
622        } else {
623            Variable::from_tensor(result)
624        }
625    }
626
627    /// Sum along a dimension, removing that dimension.
628    #[must_use]
629    pub fn sum_dim(&self, dim: usize) -> Variable {
630        let self_data = self.data.read().clone();
631        let result = self_data.sum_dim(dim as i32, false);
632        let requires_grad = self.requires_grad && is_grad_enabled();
633
634        if requires_grad {
635            let grad_fn = GradFn::new(SumDimBackward::new(self.grad_fn.clone(), self.shape(), dim));
636            Variable::from_operation(result, grad_fn, true)
637        } else {
638            Variable::from_tensor(result)
639        }
640    }
641
642    /// Mean of all elements.
643    #[must_use]
644    pub fn mean(&self) -> Variable {
645        let self_data = self.data.read().clone();
646        let result = self_data.mean().unwrap(); // Returns a scalar Tensor
647        let requires_grad = self.requires_grad && is_grad_enabled();
648
649        if requires_grad {
650            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
651            Variable::from_operation(result, grad_fn, true)
652        } else {
653            Variable::from_tensor(result)
654        }
655    }
656
657    // =========================================================================
658    // Loss Functions
659    // =========================================================================
660
661    /// Mean Squared Error loss.
662    #[must_use]
663    pub fn mse_loss(&self, target: &Variable) -> Variable {
664        let diff = self.sub_var(target);
665        let squared = diff.pow(2.0);
666        squared.mean()
667    }
668
669    /// Binary Cross Entropy loss (expects sigmoid output).
670    #[must_use]
671    pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
672        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
673        let one = Variable::from_tensor(Tensor::scalar(1.0));
674
675        // -[y * log(p + eps) + (1 - y) * log(1 - p + eps)]
676        let log_p = self.add_var(&eps);
677        let log_1_p = one.sub_var(self).add_var(&eps);
678
679        let term1 = target.mul_var(&Variable::from_tensor(log_p.data().ln()));
680        let term2 = one
681            .sub_var(target)
682            .mul_var(&Variable::from_tensor(log_1_p.data().ln()));
683
684        term1.add_var(&term2).neg_var().mean()
685    }
686
687    // =========================================================================
688    // Shape Operations
689    // =========================================================================
690
691    /// Reshapes the variable to a new shape.
692    #[must_use]
693    pub fn reshape(&self, shape: &[usize]) -> Variable {
694        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
695        let original_shape = self.shape();
696        let new_data = self
697            .data()
698            .reshape(&isize_shape)
699            .unwrap_or_else(|_| self.data().clone());
700        let requires_grad = self.requires_grad && is_grad_enabled();
701
702        if requires_grad {
703            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
704            Variable::from_operation(new_data, grad_fn, true)
705        } else {
706            Variable::from_tensor(new_data)
707        }
708    }
709
710    /// Flattens all dimensions from `start_dim` to the end into a single dimension.
711    ///
712    /// `flatten(1)` on a `[batch, C, H, W]` tensor produces `[batch, C*H*W]`.
713    /// `flatten(0)` flattens everything into a 1D vector.
714    #[must_use]
715    pub fn flatten(&self, start_dim: usize) -> Variable {
716        let shape = self.shape();
717        if start_dim >= shape.len() {
718            return self.clone();
719        }
720        let mut new_shape: Vec<usize> = shape[..start_dim].to_vec();
721        let flat: usize = shape[start_dim..].iter().product();
722        new_shape.push(flat);
723        self.reshape(&new_shape)
724    }
725
726    /// Transposes two dimensions.
727    #[must_use]
728    pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
729        let new_data = self
730            .data()
731            .transpose(dim0 as i64, dim1 as i64)
732            .unwrap_or_else(|_| self.data().clone());
733        let requires_grad = self.requires_grad && is_grad_enabled();
734
735        if requires_grad {
736            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
737            Variable::from_operation(new_data, grad_fn, true)
738        } else {
739            Variable::from_tensor(new_data)
740        }
741    }
742
    /// Slices the variable along specified ranges.
    ///
    /// NOTE(review): unlike `narrow`/`select`, this attaches no backward
    /// function — the result is a fresh leaf, so gradients do NOT flow
    /// through a slice even when `requires_grad` is true. Confirm this is
    /// intentional before using it inside a differentiable path.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }
749
750    /// Narrows the variable along a dimension.
751    ///
752    /// Returns a view of the tensor containing elements from `start` to `start + length`
753    /// along the specified dimension. This operation preserves gradients for backpropagation.
754    #[must_use]
755    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Variable {
756        let input_shape = self.shape();
757        let new_data = self
758            .data()
759            .narrow(dim, start, length)
760            .unwrap_or_else(|_| self.data().clone());
761        let requires_grad = self.requires_grad && is_grad_enabled();
762
763        if requires_grad {
764            let grad_fn = GradFn::new(NarrowBackward::new(
765                self.grad_fn.clone(),
766                input_shape,
767                dim,
768                start,
769            ));
770            Variable::from_operation(new_data, grad_fn, true)
771        } else {
772            Variable::from_tensor(new_data)
773        }
774    }
775
776    /// Expands the variable to a new shape (broadcast).
777    ///
778    /// Tracks the computational graph for backward pass.
779    #[must_use]
780    pub fn expand(&self, shape: &[usize]) -> Variable {
781        let input_shape = self.shape();
782        let new_data = self.data().broadcast_to(shape);
783        let requires_grad = self.requires_grad && is_grad_enabled();
784
785        if requires_grad {
786            let grad_fn = GradFn::new(ExpandBackward::new(self.grad_fn.clone(), input_shape));
787            Variable::from_operation(new_data, grad_fn, true)
788        } else {
789            Variable::from_tensor(new_data)
790        }
791    }
792
793    /// Selects a single index along a dimension, reducing rank by 1.
794    ///
795    /// For a tensor of shape (A, B, C), `select(1, i)` returns shape (A, C).
796    /// Tracks the computational graph for backward pass.
797    #[must_use]
798    pub fn select(&self, dim: usize, index: usize) -> Variable {
799        let input_shape = self.shape();
800        let new_data = self
801            .data()
802            .select(dim, index)
803            .unwrap_or_else(|_| self.data().clone());
804        let requires_grad = self.requires_grad && is_grad_enabled();
805
806        if requires_grad {
807            let grad_fn = GradFn::new(SelectBackward::new(
808                self.grad_fn.clone(),
809                input_shape,
810                dim,
811                index,
812            ));
813            Variable::from_operation(new_data, grad_fn, true)
814        } else {
815            Variable::from_tensor(new_data)
816        }
817    }
818
819    /// Adds a dimension of size 1 at the given position.
820    ///
821    /// Tracks the computational graph for backward pass.
822    #[must_use]
823    pub fn unsqueeze(&self, dim: usize) -> Variable {
824        let new_data = self
825            .data()
826            .unsqueeze(dim as i64)
827            .unwrap_or_else(|_| self.data().clone());
828        let requires_grad = self.requires_grad && is_grad_enabled();
829
830        if requires_grad {
831            let grad_fn = GradFn::new(UnsqueezeBackward::new(self.grad_fn.clone(), dim));
832            Variable::from_operation(new_data, grad_fn, true)
833        } else {
834            Variable::from_tensor(new_data)
835        }
836    }
837
838    /// Concatenates variables along a dimension.
839    ///
840    /// All variables must have the same shape except along the cat dimension.
841    /// Tracks the computational graph for backpropagation.
842    #[must_use]
843    pub fn cat(variables: &[&Variable], dim: usize) -> Variable {
844        let tensors: Vec<Tensor<f32>> = variables.iter().map(|v| v.data()).collect();
845        let tensor_refs: Vec<&Tensor<f32>> = tensors.iter().collect();
846        let result = Tensor::cat(&tensor_refs, dim).unwrap();
847
848        let requires_grad = variables.iter().any(|v| v.requires_grad) && is_grad_enabled();
849
850        if requires_grad {
851            let next_fns: Vec<Option<GradFn>> =
852                variables.iter().map(|v| v.grad_fn.clone()).collect();
853            let sizes: Vec<usize> = variables.iter().map(|v| v.shape()[dim]).collect();
854            let grad_fn = GradFn::new(CatBackward::new(next_fns, sizes, dim));
855            Variable::from_operation(result, grad_fn, true)
856        } else {
857            Variable::from_tensor(result)
858        }
859    }
860
861    // =========================================================================
862    // Scalar Operations
863    // =========================================================================
864
865    /// Multiplies by a scalar.
866    #[must_use]
867    pub fn mul_scalar(&self, scalar: f32) -> Variable {
868        let data = self.data();
869        let result = data.mul_scalar(scalar);
870        let requires_grad = self.requires_grad && is_grad_enabled();
871
872        if requires_grad {
873            let grad_fn = GradFn::new(MulScalarBackward::new(self.grad_fn.clone(), scalar));
874            Variable::from_operation(result, grad_fn, true)
875        } else {
876            Variable::from_tensor(result)
877        }
878    }
879
880    /// Adds a scalar.
881    #[must_use]
882    pub fn add_scalar(&self, scalar: f32) -> Variable {
883        let data = self.data();
884        let result = data.add_scalar(scalar);
885        let requires_grad = self.requires_grad && is_grad_enabled();
886
887        if requires_grad {
888            let grad_fn = GradFn::new(AddScalarBackward::new(self.grad_fn.clone()));
889            Variable::from_operation(result, grad_fn, true)
890        } else {
891            Variable::from_tensor(result)
892        }
893    }
894
    /// Subtracts a scalar.
    ///
    /// Implemented as `add_scalar(-scalar)`, so it shares the same gradient
    /// path (identity) as scalar addition.
    #[must_use]
    pub fn sub_scalar(&self, scalar: f32) -> Variable {
        self.add_scalar(-scalar)
    }
900
    /// Divides by a scalar.
    ///
    /// Implemented as multiplication by the reciprocal (`1.0 / scalar`),
    /// which may round slightly differently from a true element-wise
    /// division, and produces infinities/NaNs when `scalar` is zero.
    #[must_use]
    pub fn div_scalar(&self, scalar: f32) -> Variable {
        self.mul_scalar(1.0 / scalar)
    }
906
907    // =========================================================================
908    // Additional Activations
909    // =========================================================================
910
911    /// GELU activation function (Gaussian Error Linear Unit).
912    #[must_use]
913    pub fn gelu(&self) -> Variable {
914        let self_data = self.data();
915        let result = self_data.gelu();
916        let requires_grad = self.requires_grad && is_grad_enabled();
917
918        if requires_grad {
919            let grad_fn = GradFn::new(GeluBackward::new(self.grad_fn.clone(), self_data));
920            Variable::from_operation(result, grad_fn, true)
921        } else {
922            Variable::from_tensor(result)
923        }
924    }
925
926    /// SiLU/Swish activation function (x * sigmoid(x)).
927    #[must_use]
928    pub fn silu(&self) -> Variable {
929        let self_data = self.data();
930        let result = self_data.silu();
931        let requires_grad = self.requires_grad && is_grad_enabled();
932
933        if requires_grad {
934            let grad_fn = GradFn::new(SiluBackward::new(self.grad_fn.clone(), self_data));
935            Variable::from_operation(result, grad_fn, true)
936        } else {
937            Variable::from_tensor(result)
938        }
939    }
940
941    /// Square root.
942    #[must_use]
943    pub fn sqrt(&self) -> Variable {
944        let data = self.data();
945        let result = data.sqrt();
946        let requires_grad = self.requires_grad && is_grad_enabled();
947
948        if requires_grad {
949            let grad_fn = GradFn::new(SqrtBackward::new(self.grad_fn.clone(), result.clone()));
950            Variable::from_operation(result, grad_fn, true)
951        } else {
952            Variable::from_tensor(result)
953        }
954    }
955
956    // =========================================================================
957    // Softmax Operations
958    // =========================================================================
959
960    /// Softmax along specified dimension.
961    #[must_use]
962    pub fn softmax(&self, dim: i32) -> Variable {
963        let data = self.data();
964        let result = data.softmax(dim);
965        let requires_grad = self.requires_grad && is_grad_enabled();
966
967        if requires_grad {
968            let grad_fn = GradFn::new(SoftmaxBackward::new(
969                self.grad_fn.clone(),
970                result.clone(),
971                dim as i64,
972            ));
973            Variable::from_operation(result, grad_fn, true)
974        } else {
975            Variable::from_tensor(result)
976        }
977    }
978
979    /// Log softmax along specified dimension.
980    #[must_use]
981    pub fn log_softmax(&self, dim: i32) -> Variable {
982        let data = self.data();
983        let result = data.log_softmax(dim);
984        let requires_grad = self.requires_grad && is_grad_enabled();
985
986        if requires_grad {
987            let grad_fn = GradFn::new(LogSoftmaxBackward::new(
988                self.grad_fn.clone(),
989                result.clone(),
990                dim as i64,
991            ));
992            Variable::from_operation(result, grad_fn, true)
993        } else {
994            Variable::from_tensor(result)
995        }
996    }
997
998    // =========================================================================
999    // Reduction Operations with Dimensions
1000    // =========================================================================
1001
1002    /// Mean along a dimension, optionally keeping the dimension.
1003    #[must_use]
1004    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
1005        let data = self.data();
1006        let input_shape = data.shape().to_vec();
1007        let ndim = input_shape.len();
1008        let dim_usize = if dim < 0 {
1009            (ndim as i32 + dim) as usize
1010        } else {
1011            dim as usize
1012        };
1013        let result = data.mean_dim(dim, keepdim);
1014        let requires_grad = self.requires_grad && is_grad_enabled();
1015
1016        if requires_grad {
1017            let grad_fn = GradFn::new(MeanDimBackward::new(
1018                self.grad_fn.clone(),
1019                input_shape,
1020                dim_usize,
1021                keepdim,
1022            ));
1023            Variable::from_operation(result, grad_fn, true)
1024        } else {
1025            Variable::from_tensor(result)
1026        }
1027    }
1028
1029    /// Variance along a dimension, optionally keeping the dimension.
1030    #[must_use]
1031    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
1032        let self_data = self.data();
1033        let input_shape = self_data.shape().to_vec();
1034        let ndim = input_shape.len();
1035        let dim_usize = if dim < 0 {
1036            (ndim as i32 + dim) as usize
1037        } else {
1038            dim as usize
1039        };
1040        let result = self_data.var_dim(dim, keepdim);
1041        let requires_grad = self.requires_grad && is_grad_enabled();
1042
1043        if requires_grad {
1044            let grad_fn = GradFn::new(VarDimBackward::new(
1045                self.grad_fn.clone(),
1046                self_data,
1047                dim_usize,
1048                keepdim,
1049            ));
1050            Variable::from_operation(result, grad_fn, true)
1051        } else {
1052            Variable::from_tensor(result)
1053        }
1054    }
1055
1056    // =========================================================================
1057    // Utility Methods
1058    // =========================================================================
1059
    /// Creates a Variable from a tensor and requires_grad flag (for weight access).
    /// This is typically used internally by Parameter types.
    ///
    /// Thin alias for [`Variable::new`], kept for clarity at call sites.
    #[must_use]
    pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
        Variable::new(data, requires_grad)
    }
1066
    /// Clones the variable (alias for Clone trait).
    ///
    /// Per the module docs, `Variable` wraps `Arc`-shared state, so the
    /// clone shares the same underlying tensor rather than copying data.
    #[must_use]
    pub fn clone_var(&self) -> Variable {
        self.clone()
    }
1072
    /// Adds another variable (alias for add_var for method chaining).
    ///
    /// Delegates directly to [`Self::add_var`]; gradient tracking is
    /// handled there.
    #[must_use]
    pub fn add(&self, other: &Variable) -> Variable {
        self.add_var(other)
    }
1078
    /// Subtracts another variable (alias for sub_var for method chaining).
    ///
    /// Delegates directly to [`Self::sub_var`]; gradient tracking is
    /// handled there.
    #[must_use]
    pub fn sub(&self, other: &Variable) -> Variable {
        self.sub_var(other)
    }
1084
    /// Multiplies by another variable (alias for mul_var for method chaining).
    ///
    /// Delegates directly to [`Self::mul_var`]; gradient tracking is
    /// handled there.
    #[must_use]
    pub fn mul(&self, other: &Variable) -> Variable {
        self.mul_var(other)
    }
1090
    /// Divides by another variable (alias for div_var for method chaining).
    ///
    /// Delegates directly to [`Self::div_var`]; gradient tracking is
    /// handled there.
    #[must_use]
    pub fn div(&self, other: &Variable) -> Variable {
        self.div_var(other)
    }
1096}
1097
1098// =============================================================================
1099// Operator Overloads
1100// =============================================================================
1101
impl Add for &Variable {
    type Output = Variable;

    /// `&a + &b` — delegates to [`Variable::add_var`], which records the
    /// operation on the autograd graph when gradients are enabled.
    fn add(self, other: &Variable) -> Variable {
        self.add_var(other)
    }
}
1109
impl Sub for &Variable {
    type Output = Variable;

    /// `&a - &b` — delegates to [`Variable::sub_var`], which records the
    /// operation on the autograd graph when gradients are enabled.
    fn sub(self, other: &Variable) -> Variable {
        self.sub_var(other)
    }
}
1117
impl Mul for &Variable {
    type Output = Variable;

    /// `&a * &b` — delegates to [`Variable::mul_var`], which records the
    /// operation on the autograd graph when gradients are enabled.
    fn mul(self, other: &Variable) -> Variable {
        self.mul_var(other)
    }
}
1125
impl Div for &Variable {
    type Output = Variable;

    /// `&a / &b` — delegates to [`Variable::div_var`], which records the
    /// operation on the autograd graph when gradients are enabled.
    fn div(self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}
1133
impl Neg for &Variable {
    type Output = Variable;

    /// `-&a` — delegates to `Variable::neg_var`, which records the
    /// operation on the autograd graph when gradients are enabled.
    fn neg(self) -> Variable {
        self.neg_var()
    }
}
1141
1142impl std::fmt::Debug for Variable {
1143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1144        f.debug_struct("Variable")
1145            .field("shape", &self.shape())
1146            .field("requires_grad", &self.requires_grad)
1147            .field("is_leaf", &self.is_leaf)
1148            .field(
1149                "grad_fn",
1150                &self.grad_fn.as_ref().map(super::grad_fn::GradFn::name),
1151            )
1152            .finish()
1153    }
1154}
1155
1156// =============================================================================
1157// Tests
1158// =============================================================================
1159
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::zeros;

    #[test]
    fn test_variable_creation() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::new(t, true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }

    #[test]
    fn test_variable_no_grad() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::from_tensor(t);
        assert!(!v.requires_grad());
    }

    #[test]
    fn test_variable_add() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = Variable::new(
            Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let c = &a + &b;
        assert_eq!(c.data().to_vec(), vec![5.0, 7.0, 9.0]);
        assert!(c.requires_grad());
        assert!(!c.is_leaf());
    }

    #[test]
    fn test_variable_detach() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = a.detach();
        assert!(!b.requires_grad());
        assert!(b.is_leaf());
    }

    #[test]
    fn test_mse_loss() {
        let pred = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::from_tensor(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
        );
        let loss = pred.mse_loss(&target);
        assert_eq!(loss.numel(), 1);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_exp() {
        let a = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = a.exp();
        assert!((b.data().to_vec()[0] - 1.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - std::f32::consts::E).abs() < 1e-4);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        // d/dx(exp(x)) = exp(x)
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - std::f32::consts::E).abs() < 1e-4);
    }

    #[test]
    fn test_log() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, std::f32::consts::E, 10.0], &[3])
                .expect("tensor creation failed"),
            true,
        );
        let b = a.log();
        assert!((b.data().to_vec()[0] - 0.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - 1.0).abs() < 1e-5);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        // d/dx(log(x)) = 1/x
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - 1.0 / std::f32::consts::E).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let a = Variable::new(
            Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = a.clamp(0.0, 1.0);
        assert_eq!(b.data().to_vec(), vec![0.0, 0.5, 1.0]);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        // Gradient passes through only where not clamped
        assert_eq!(grad[0], 0.0); // clamped at min
        assert_eq!(grad[1], 1.0); // not clamped
        assert_eq!(grad[2], 0.0); // clamped at max
    }

    // =========================================================================
    // Backward Pass — Arithmetic Gradients
    // =========================================================================

    #[test]
    fn test_add_backward() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let c = a.add_var(&b);
        c.sum().backward();

        // d(a+b)/da = 1, d(a+b)/db = 1
        let ga = a.grad().expect("a should have grad");
        let gb = b.grad().expect("b should have grad");
        assert_eq!(ga.to_vec(), vec![1.0, 1.0]);
        assert_eq!(gb.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sub_backward() {
        let a = Variable::new(Tensor::from_vec(vec![5.0, 3.0], &[2]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![2.0, 1.0], &[2]).unwrap(), true);
        let c = a.sub_var(&b);

        assert_eq!(c.data().to_vec(), vec![3.0, 2.0]);
        c.sum().backward();

        // d(a-b)/da = 1, d(a-b)/db = -1
        let ga = a.grad().unwrap().to_vec();
        let gb = b.grad().unwrap().to_vec();
        assert_eq!(ga, vec![1.0, 1.0]);
        assert_eq!(gb, vec![-1.0, -1.0]);
    }

    #[test]
    fn test_mul_backward() {
        let a = Variable::new(Tensor::from_vec(vec![2.0, 3.0], &[2]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap(), true);
        let c = a.mul_var(&b);

        assert_eq!(c.data().to_vec(), vec![8.0, 15.0]);
        c.sum().backward();

        // d(a*b)/da = b, d(a*b)/db = a
        let ga = a.grad().unwrap().to_vec();
        let gb = b.grad().unwrap().to_vec();
        assert_eq!(ga, vec![4.0, 5.0]);
        assert_eq!(gb, vec![2.0, 3.0]);
    }

    #[test]
    fn test_div_backward() {
        let a = Variable::new(Tensor::from_vec(vec![6.0, 10.0], &[2]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![2.0, 5.0], &[2]).unwrap(), true);
        let c = a.div_var(&b);

        assert_eq!(c.data().to_vec(), vec![3.0, 2.0]);
        c.sum().backward();

        // d(a/b)/da = 1/b, d(a/b)/db = -a/b^2
        let ga = a.grad().unwrap().to_vec();
        let gb = b.grad().unwrap().to_vec();
        assert!((ga[0] - 0.5).abs() < 1e-5, "da = 1/b = 0.5, got {}", ga[0]);
        assert!((ga[1] - 0.2).abs() < 1e-5, "da = 1/b = 0.2, got {}", ga[1]);
        assert!(
            (gb[0] - (-1.5)).abs() < 1e-5,
            "db = -a/b^2 = -6/4 = -1.5, got {}",
            gb[0]
        );
        assert!(
            (gb[1] - (-0.4)).abs() < 1e-5,
            "db = -a/b^2 = -10/25 = -0.4, got {}",
            gb[1]
        );
    }

    #[test]
    fn test_mul_scalar_backward() {
        let a = Variable::new(Tensor::from_vec(vec![2.0, 3.0], &[2]).unwrap(), true);
        let c = a.mul_scalar(5.0);

        assert_eq!(c.data().to_vec(), vec![10.0, 15.0]);
        c.sum().backward();

        // d(5*a)/da = 5
        let ga = a.grad().unwrap().to_vec();
        assert_eq!(ga, vec![5.0, 5.0]);
    }

    // =========================================================================
    // Backward Pass — Activations
    // =========================================================================

    #[test]
    fn test_relu_backward() {
        let a = Variable::new(Tensor::from_vec(vec![-2.0, 0.0, 3.0], &[3]).unwrap(), true);
        let b = a.relu();

        assert_eq!(b.data().to_vec(), vec![0.0, 0.0, 3.0]);
        b.sum().backward();

        // d(relu(x))/dx = 0 if x<0, 1 if x>0
        // (x == 0 is deliberately unchecked: the subgradient there is
        // implementation-defined.)
        let ga = a.grad().unwrap().to_vec();
        assert_eq!(ga[0], 0.0); // negative → 0
        assert_eq!(ga[2], 1.0); // positive → 1
    }

    #[test]
    fn test_sigmoid_backward() {
        let a = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), true);
        let b = a.sigmoid();

        // sigmoid(0) = 0.5
        assert!((b.data().to_vec()[0] - 0.5).abs() < 1e-5);
        b.backward();

        // d(sigmoid(x))/dx = sigmoid(x)*(1-sigmoid(x)) = 0.5*0.5 = 0.25
        let ga = a.grad().unwrap().to_vec();
        assert!(
            (ga[0] - 0.25).abs() < 1e-4,
            "sigmoid'(0) = 0.25, got {}",
            ga[0]
        );
    }

    #[test]
    fn test_tanh_backward() {
        let a = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), true);
        let b = a.tanh();

        // tanh(0) = 0
        assert!(b.data().to_vec()[0].abs() < 1e-5);
        b.backward();

        // d(tanh(x))/dx = 1 - tanh(x)^2 = 1 - 0 = 1
        let ga = a.grad().unwrap().to_vec();
        assert!((ga[0] - 1.0).abs() < 1e-4, "tanh'(0) = 1.0, got {}", ga[0]);
    }

    // =========================================================================
    // Backward Pass — Chain Rule
    // =========================================================================

    #[test]
    fn test_chain_rule_mul_then_add() {
        // f(a,b) = a*b + a → df/da = b+1, df/db = a
        let a = Variable::new(Tensor::from_vec(vec![3.0], &[1]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0], &[1]).unwrap(), true);
        let ab = a.mul_var(&b);
        let result = ab.add_var(&a);
        result.backward();

        let ga = a.grad().unwrap().to_vec()[0];
        let gb = b.grad().unwrap().to_vec()[0];
        assert!((ga - 5.0).abs() < 1e-4, "df/da = b+1 = 5, got {}", ga);
        assert!((gb - 3.0).abs() < 1e-4, "df/db = a = 3, got {}", gb);
    }

    #[test]
    fn test_chain_rule_nested_operations() {
        // f(x) = relu(x^2 - 1) → df/dx = 2x if x^2 > 1, else 0
        let x = Variable::new(Tensor::from_vec(vec![2.0], &[1]).unwrap(), true);
        let x_sq = x.mul_var(&x); // x^2 = 4
        let shifted = x_sq.add_scalar(-1.0); // x^2 - 1 = 3
        let out = shifted.relu(); // relu(3) = 3

        assert!((out.data().to_vec()[0] - 3.0).abs() < 1e-5);
        out.backward();

        // df/dx = 2x * 1 (relu passes through since input > 0) = 4
        let gx = x.grad().unwrap().to_vec()[0];
        assert!((gx - 4.0).abs() < 1e-4, "df/dx = 2x = 4, got {}", gx);
    }

    #[test]
    fn test_sum_backward() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap(),
            true,
        );
        let s = a.sum();

        assert!((s.data().to_vec()[0] - 10.0).abs() < 1e-5);
        s.backward();

        // d(sum)/dx_i = 1 for all i
        let ga = a.grad().unwrap().to_vec();
        assert_eq!(ga, vec![1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_mean_backward() {
        let a = Variable::new(
            Tensor::from_vec(vec![2.0, 4.0, 6.0, 8.0], &[4]).unwrap(),
            true,
        );
        let m = a.mean();

        assert!((m.data().to_vec()[0] - 5.0).abs() < 1e-5);
        m.backward();

        // d(mean)/dx_i = 1/N = 0.25
        let ga = a.grad().unwrap().to_vec();
        for g in &ga {
            assert!(
                (g - 0.25).abs() < 1e-5,
                "d(mean)/dx = 1/4 = 0.25, got {}",
                g
            );
        }
    }

    // =========================================================================
    // Backward Pass — Matmul
    // =========================================================================

    #[test]
    fn test_matmul_backward() {
        // C = A @ B where A=[2,3], B=[3,2]
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
            true,
        );
        let b = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap(),
            true,
        );
        let c = a.matmul(&b); // [2, 2]
        assert_eq!(c.shape(), vec![2, 2]);

        c.sum().backward();

        // dL/dA = ones @ B^T, dL/dB = A^T @ ones
        let ga = a.grad().expect("A should have grad");
        let gb = b.grad().expect("B should have grad");
        assert_eq!(ga.shape(), &[2, 3]);
        assert_eq!(gb.shape(), &[3, 2]);

        // All gradients should be finite and non-zero
        assert!(ga.to_vec().iter().all(|g| g.is_finite() && g.abs() > 0.0));
        assert!(gb.to_vec().iter().all(|g| g.is_finite() && g.abs() > 0.0));
    }

    // =========================================================================
    // Edge Cases
    // =========================================================================

    #[test]
    fn test_no_grad_skips_backward() {
        let a = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), false);
        let b = a.mul_scalar(2.0);
        // Should not panic even though requires_grad=false
        assert!((b.data().to_vec()[0] - 2.0).abs() < 1e-5);
        assert!(a.grad().is_none());
    }

    #[test]
    fn test_detach_stops_gradient() {
        // detach() creates a new variable without gradient tracking
        let a = Variable::new(Tensor::from_vec(vec![3.0], &[1]).unwrap(), true);
        let b = a.mul_scalar(2.0);
        let c = b.detach();

        // Detached variable should not require grad
        assert!(
            !c.requires_grad(),
            "Detached variable should not require grad"
        );
        assert!(c.is_leaf(), "Detached variable should be a leaf");

        // Original chain should still work
        b.backward();
        let ga = a.grad().unwrap().to_vec()[0];
        assert!(
            (ga - 2.0).abs() < 1e-4,
            "Gradient through b=2*a should be 2: got {}",
            ga
        );
    }

    #[test]
    fn test_backward_twice_accumulates() {
        let a = Variable::new(Tensor::from_vec(vec![2.0], &[1]).unwrap(), true);
        let b = a.mul_scalar(3.0);
        b.backward();
        let g1 = a.grad().unwrap().to_vec()[0];
        assert!((g1 - 3.0).abs() < 1e-4, "d(3a)/da = 3, got {}", g1);

        // Second backward through a fresh graph should accumulate into the
        // same gradient buffer rather than overwrite it.
        let c = a.mul_scalar(3.0);
        c.backward();
        let g2 = a.grad().unwrap().to_vec()[0];

        // Gradient should have accumulated: 3 + 3 = 6. The previous check
        // (`… || g2 >= g1`) also passed when the gradient was merely
        // overwritten (g2 == g1), making the test nearly vacuous.
        assert!(
            (g2 - 6.0).abs() < 1e-4,
            "Second backward should accumulate: g1={}, g2={}",
            g1,
            g2
        );
    }

    #[test]
    fn test_reshape_preserves_gradient() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            true,
        );
        let b = a.reshape(&[4]);
        let c = b.sum();
        c.backward();

        let ga = a.grad().expect("Should have gradient through reshape");
        assert_eq!(ga.shape(), &[2, 2]);
        assert_eq!(ga.to_vec(), vec![1.0, 1.0, 1.0, 1.0]);
    }
}
1589}