// axonml_autograd/variable.rs

1//! Variable - Tensor with Gradient Tracking
2//!
3//! # File
4//! `crates/axonml-autograd/src/variable.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::ops::{Add, Div, Mul, Neg, Sub};
18use std::sync::Arc;
19
20use parking_lot::RwLock;
21
22use axonml_tensor::Tensor;
23
24use crate::functions::{
25    AddBackward, AddScalarBackward, CatBackward, ClampBackward, DivBackward, EluBackward,
26    ExpBackward, ExpandBackward, GeluBackward, LeakyReluBackward, LogBackward, LogSoftmaxBackward,
27    MatMulBackward, MeanBackward, MeanDimBackward, MulBackward, MulScalarBackward, NarrowBackward,
28    NegBackward, PowBackward, ReluBackward, ReshapeBackward, SelectBackward, SigmoidBackward,
29    SiluBackward, SoftmaxBackward, SqrtBackward, SubBackward, SumBackward, SumDimBackward,
30    TanhBackward, TransposeBackward, UnsqueezeBackward, VarDimBackward,
31};
32use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
33use crate::graph::{GraphNode, with_graph};
34use crate::no_grad::is_grad_enabled;
35
36// =============================================================================
37// Variable Struct
38// =============================================================================
39
/// A tensor with automatic differentiation support.
///
/// Variable wraps a Tensor and tracks operations performed on it to enable
/// automatic gradient computation. When `requires_grad` is true, all operations
/// are recorded in a computational graph.
///
/// Cloning a `Variable` is cheap: `data` and `grad` are `Arc`-shared, so a
/// clone observes the same tensor storage and gradient accumulator.
#[derive(Clone)]
pub struct Variable {
    /// The underlying tensor data, shared across clones of this variable.
    data: Arc<RwLock<Tensor<f32>>>,
    /// Shared gradient accumulator (for leaf variables, shared with `AccumulateGrad`).
    grad: GradAccumulator,
    /// Whether this variable requires gradient computation.
    requires_grad: bool,
    /// Whether this is a leaf variable (created by user, not an operation).
    is_leaf: bool,
    /// The gradient function for backpropagation (`None` when not tracked).
    grad_fn: Option<GradFn>,
    /// Graph node registered for this variable (`None` when not tracked).
    node: Option<Arc<GraphNode>>,
}
60
61impl Variable {
62    /// Creates a new variable from a tensor.
63    ///
64    /// # Arguments
65    /// * `data` - The tensor data
66    /// * `requires_grad` - Whether to track gradients for this variable
67    #[must_use]
68    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
69        // Create shared gradient accumulator
70        let grad: GradAccumulator = Arc::new(RwLock::new(None));
71
72        let node = if requires_grad {
73            Some(with_graph(|g| g.register_leaf(true)))
74        } else {
75            None
76        };
77
78        // Create AccumulateGrad with shared gradient storage
79        let grad_fn = if requires_grad {
80            Some(GradFn::new(AccumulateGrad::new(Arc::clone(&grad))))
81        } else {
82            None
83        };
84
85        Self {
86            data: Arc::new(RwLock::new(data)),
87            grad,
88            requires_grad,
89            is_leaf: true,
90            grad_fn,
91            node,
92        }
93    }
94
95    /// Creates a variable that doesn't require gradients.
96    #[must_use]
97    pub fn from_tensor(data: Tensor<f32>) -> Self {
98        Self::new(data, false)
99    }
100
101    /// Creates a new variable from an operation result with an attached gradient function.
102    ///
103    /// This connects the variable to the computational graph, allowing gradients
104    /// to flow backward through the operation that produced this variable.
105    pub fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
106        let node = if requires_grad {
107            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
108        } else {
109            None
110        };
111
112        Self {
113            data: Arc::new(RwLock::new(data)),
114            grad: Arc::new(RwLock::new(None)),
115            requires_grad,
116            is_leaf: false,
117            grad_fn: if requires_grad { Some(grad_fn) } else { None },
118            node,
119        }
120    }
121
    /// Returns a clone of the underlying tensor data.
    ///
    /// Tensor uses Arc-backed storage, so this is a cheap reference count
    /// bump (not a deep copy). The data is shared until mutated.
    #[must_use]
    pub fn data(&self) -> Tensor<f32> {
        self.data.read().clone()
    }

    /// Returns the shape of the tensor as an owned vector.
    #[must_use]
    pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape().to_vec()
    }

    /// Returns the number of dimensions.
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.data.read().ndim()
    }

    /// Returns the total number of elements.
    #[must_use]
    pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    /// Returns the device this variable's data is on.
    #[must_use]
    pub fn device(&self) -> axonml_tensor::Device {
        self.data.read().device()
    }
154
155    /// Moves this variable's data to the specified device.
156    ///
157    /// Creates a new leaf Variable on the target device.
158    /// Used for moving inputs to GPU before forward pass.
159    pub fn to_device(&self, device: axonml_tensor::Device) -> Self {
160        let current = self.data.read().clone();
161        if current.device() == device {
162            return self.clone();
163        }
164        let moved = current
165            .to_device(device)
166            .expect("Failed to move variable to device");
167        Variable::new(moved, self.requires_grad)
168    }
169
    /// Returns whether this variable requires gradients.
    #[must_use]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Returns whether this is a leaf variable (created directly by the user,
    /// not produced by an operation).
    #[must_use]
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Returns the gradient of this variable, if one has been accumulated.
    ///
    /// Only available for leaf variables after `backward()` has been called.
    #[must_use]
    pub fn grad(&self) -> Option<Tensor<f32>> {
        self.grad.read().clone()
    }

    /// Returns the gradient function, if this variable is attached to the graph.
    #[must_use]
    pub fn grad_fn(&self) -> Option<&GradFn> {
        self.grad_fn.as_ref()
    }
195
    /// Updates the underlying tensor data without breaking the gradient accumulator.
    ///
    /// Unlike creating a new Variable, this preserves the grad accumulator Arc
    /// so that backward passes continue to write gradients to the right place.
    pub fn set_data(&mut self, new_data: Tensor<f32>) {
        *self.data.write() = new_data;
    }

    /// Sets the gradient, replacing any existing value (used during backward pass).
    pub fn set_grad(&self, grad: Tensor<f32>) {
        *self.grad.write() = Some(grad);
    }

    /// Accumulates gradient (adds to existing gradient).
    ///
    /// # Panics
    /// Panics if the element-wise add of old and new gradient fails
    /// (the `unwrap` below, e.g. on a shape mismatch).
    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
        let mut grad_lock = self.grad.write();
        if let Some(ref existing) = *grad_lock {
            // Sum with the gradient already stored for this variable.
            *grad_lock = Some(existing.add(grad).unwrap());
        } else {
            // First contribution: store a copy of the incoming gradient.
            *grad_lock = Some(grad.clone());
        }
    }

    /// Clears the gradient, resetting the accumulator to `None`.
    pub fn zero_grad(&self) {
        *self.grad.write() = None;
    }
223
224    /// Detaches this variable from the computation graph.
225    ///
226    /// Returns a new variable with the same data but no gradient history.
227    #[must_use]
228    pub fn detach(&self) -> Self {
229        Self {
230            data: Arc::new(RwLock::new(self.data.read().clone())),
231            grad: Arc::new(RwLock::new(None)),
232            requires_grad: false,
233            is_leaf: true,
234            grad_fn: None,
235            node: None,
236        }
237    }
238
239    /// Returns a new variable with `requires_grad` set.
240    #[must_use]
241    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
242        self.requires_grad = requires_grad;
243        if requires_grad && self.is_leaf {
244            // AccumulateGrad shares the gradient accumulator with this variable
245            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
246            self.node = Some(with_graph(|g| g.register_leaf(true)));
247        }
248        self
249    }
250
251    /// Computes gradients via backpropagation.
252    ///
253    /// This should only be called on scalar (single-element) tensors,
254    /// typically the loss value.
255    pub fn backward(&self) {
256        assert!(
257            self.requires_grad,
258            "Cannot call backward on a variable that doesn't require gradients"
259        );
260
261        assert!(
262            (self.numel() == 1),
263            "backward() can only be called on scalar tensors"
264        );
265
266        // Start with gradient of 1.0 for the output, on the same device
267        let mut grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
268        let device = self.data.read().device();
269        if device.is_gpu() {
270            grad_output = grad_output.to_device(device).expect("device transfer failed");
271        }
272        crate::backward::backward(self, &grad_output);
273    }
274
275    /// Runs the backward pass with a provided gradient tensor.
276    ///
277    /// Unlike `backward()`, this does not require the variable to be scalar.
278    /// The gradient tensor must match the shape of this variable.
279    pub fn backward_with_grad(&self, grad_output: &Tensor<f32>) {
280        if !self.requires_grad {
281            return;
282        }
283        let device = self.data.read().device();
284        let grad = if grad_output.device() != device && device.is_gpu() {
285            grad_output.to_device(device).expect("device transfer failed")
286        } else {
287            grad_output.clone()
288        };
289        crate::backward::backward(self, &grad);
290    }
291
292    // =========================================================================
293    // Arithmetic Operations
294    // =========================================================================
295
296    /// Element-wise addition.
297    #[must_use]
298    pub fn add_var(&self, other: &Variable) -> Variable {
299        let result = self.data.read().add(&other.data.read()).unwrap();
300        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
301
302        if requires_grad {
303            let grad_fn = GradFn::new(AddBackward::new(
304                self.grad_fn.clone(),
305                other.grad_fn.clone(),
306                self.shape(),
307                other.shape(),
308            ));
309            Variable::from_operation(result, grad_fn, true)
310        } else {
311            Variable::from_tensor(result)
312        }
313    }
314
315    /// Element-wise subtraction.
316    #[must_use]
317    pub fn sub_var(&self, other: &Variable) -> Variable {
318        let result = self.data.read().sub(&other.data.read()).unwrap();
319        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
320
321        if requires_grad {
322            let grad_fn = GradFn::new(SubBackward::new(
323                self.grad_fn.clone(),
324                other.grad_fn.clone(),
325                self.shape(),
326                other.shape(),
327            ));
328            Variable::from_operation(result, grad_fn, true)
329        } else {
330            Variable::from_tensor(result)
331        }
332    }
333
334    /// Element-wise multiplication.
335    #[must_use]
336    pub fn mul_var(&self, other: &Variable) -> Variable {
337        let self_data = self.data.read().clone();
338        let other_data = other.data.read().clone();
339        let result = self_data.mul(&other_data).unwrap();
340        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
341
342        if requires_grad {
343            let grad_fn = GradFn::new(MulBackward::new(
344                self.grad_fn.clone(),
345                other.grad_fn.clone(),
346                self_data,
347                other_data,
348            ));
349            Variable::from_operation(result, grad_fn, true)
350        } else {
351            Variable::from_tensor(result)
352        }
353    }
354
355    /// Element-wise division.
356    #[must_use]
357    pub fn div_var(&self, other: &Variable) -> Variable {
358        let self_data = self.data.read().clone();
359        let other_data = other.data.read().clone();
360        let result = self_data.div(&other_data).unwrap();
361        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
362
363        if requires_grad {
364            let grad_fn = GradFn::new(DivBackward::new(
365                self.grad_fn.clone(),
366                other.grad_fn.clone(),
367                self_data,
368                other_data,
369            ));
370            Variable::from_operation(result, grad_fn, true)
371        } else {
372            Variable::from_tensor(result)
373        }
374    }
375
376    /// Negation.
377    #[must_use]
378    pub fn neg_var(&self) -> Variable {
379        let result = self.data.read().neg();
380        let requires_grad = self.requires_grad && is_grad_enabled();
381
382        if requires_grad {
383            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
384            Variable::from_operation(result, grad_fn, true)
385        } else {
386            Variable::from_tensor(result)
387        }
388    }
389
390    /// Matrix multiplication.
391    #[must_use]
392    pub fn matmul(&self, other: &Variable) -> Variable {
393        let self_data = self.data.read().clone();
394        let other_data = other.data.read().clone();
395
396        // AMP: cast inputs to f16 precision for faster matmul, result stays f32
397        let (compute_a, compute_b) = if crate::amp::is_autocast_enabled() {
398            (self_data.to_f16_precision(), other_data.to_f16_precision())
399        } else {
400            (self_data.clone(), other_data.clone())
401        };
402
403        let result = compute_a.matmul(&compute_b).expect("matmul failed");
404        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
405
406        if requires_grad {
407            let grad_fn = GradFn::new(MatMulBackward::new(
408                self.grad_fn.clone(),
409                other.grad_fn.clone(),
410                self_data,
411                other_data,
412            ));
413            Variable::from_operation(result, grad_fn, true)
414        } else {
415            Variable::from_tensor(result)
416        }
417    }
418
419    /// Power operation.
420    #[must_use]
421    pub fn pow(&self, exponent: f32) -> Variable {
422        let self_data = self.data.read().clone();
423        let result = self_data.pow(exponent);
424        let requires_grad = self.requires_grad && is_grad_enabled();
425
426        if requires_grad {
427            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
428            Variable::from_operation(result, grad_fn, true)
429        } else {
430            Variable::from_tensor(result)
431        }
432    }
433
434    // =========================================================================
435    // Activation Functions
436    // =========================================================================
437
438    /// `ReLU` activation.
439    #[must_use]
440    pub fn relu(&self) -> Variable {
441        let self_data = self.data.read().clone();
442        let result = self_data.relu();
443        let requires_grad = self.requires_grad && is_grad_enabled();
444
445        if requires_grad {
446            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
447            Variable::from_operation(result, grad_fn, true)
448        } else {
449            Variable::from_tensor(result)
450        }
451    }
452
453    /// Leaky ReLU activation.
454    #[must_use]
455    pub fn leaky_relu(&self, negative_slope: f32) -> Variable {
456        let self_data = self.data.read().clone();
457        let device = self_data.device();
458        let result_vec: Vec<f32> = self_data
459            .to_vec()
460            .iter()
461            .map(|&x| if x > 0.0 { x } else { x * negative_slope })
462            .collect();
463        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
464        if device.is_gpu() {
465            result = result.to_device(device).expect("device transfer failed");
466        }
467        let requires_grad = self.requires_grad && is_grad_enabled();
468
469        if requires_grad {
470            let grad_fn = GradFn::new(LeakyReluBackward::new(
471                self.grad_fn.clone(),
472                self_data,
473                negative_slope,
474            ));
475            Variable::from_operation(result, grad_fn, true)
476        } else {
477            Variable::from_tensor(result)
478        }
479    }
480
481    /// ELU activation.
482    #[must_use]
483    pub fn elu(&self, alpha: f32) -> Variable {
484        let self_data = self.data.read().clone();
485        let device = self_data.device();
486        let result_vec: Vec<f32> = self_data
487            .to_vec()
488            .iter()
489            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
490            .collect();
491        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
492        if device.is_gpu() {
493            result = result.to_device(device).expect("device transfer failed");
494        }
495        let requires_grad = self.requires_grad && is_grad_enabled();
496
497        if requires_grad {
498            let grad_fn = GradFn::new(EluBackward::new(self.grad_fn.clone(), self_data, alpha));
499            Variable::from_operation(result, grad_fn, true)
500        } else {
501            Variable::from_tensor(result)
502        }
503    }
504
505    /// Sigmoid activation.
506    #[must_use]
507    pub fn sigmoid(&self) -> Variable {
508        let result = self.data.read().sigmoid();
509        let requires_grad = self.requires_grad && is_grad_enabled();
510
511        if requires_grad {
512            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
513            Variable::from_operation(result, grad_fn, true)
514        } else {
515            Variable::from_tensor(result)
516        }
517    }
518
519    /// Tanh activation.
520    #[must_use]
521    pub fn tanh(&self) -> Variable {
522        let result = self.data.read().tanh();
523        let requires_grad = self.requires_grad && is_grad_enabled();
524
525        if requires_grad {
526            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
527            Variable::from_operation(result, grad_fn, true)
528        } else {
529            Variable::from_tensor(result)
530        }
531    }
532
533    /// Element-wise exponential.
534    #[must_use]
535    pub fn exp(&self) -> Variable {
536        let self_data = self.data.read().clone();
537        let result = self_data.exp();
538        let requires_grad = self.requires_grad && is_grad_enabled();
539
540        if requires_grad {
541            let grad_fn = GradFn::new(ExpBackward::new(self.grad_fn.clone(), result.clone()));
542            Variable::from_operation(result, grad_fn, true)
543        } else {
544            Variable::from_tensor(result)
545        }
546    }
547
548    /// Element-wise natural logarithm.
549    #[must_use]
550    pub fn log(&self) -> Variable {
551        let self_data = self.data.read().clone();
552        let result = self_data.ln();
553        let requires_grad = self.requires_grad && is_grad_enabled();
554
555        if requires_grad {
556            let grad_fn = GradFn::new(LogBackward::new(self.grad_fn.clone(), self_data));
557            Variable::from_operation(result, grad_fn, true)
558        } else {
559            Variable::from_tensor(result)
560        }
561    }
562
563    /// Element-wise clamp to [min_val, max_val].
564    #[must_use]
565    pub fn clamp(&self, min_val: f32, max_val: f32) -> Variable {
566        let self_data = self.data.read().clone();
567        let device = self_data.device();
568        let result_data: Vec<f32> = self_data
569            .to_vec()
570            .iter()
571            .map(|&x| x.clamp(min_val, max_val))
572            .collect();
573        let mut result = Tensor::from_vec(result_data, self_data.shape()).unwrap();
574        if device.is_gpu() {
575            result = result.to_device(device).expect("device transfer failed");
576        }
577        let requires_grad = self.requires_grad && is_grad_enabled();
578
579        if requires_grad {
580            let grad_fn = GradFn::new(ClampBackward::new(
581                self.grad_fn.clone(),
582                self_data,
583                min_val,
584                max_val,
585            ));
586            Variable::from_operation(result, grad_fn, true)
587        } else {
588            Variable::from_tensor(result)
589        }
590    }
591
592    // =========================================================================
593    // Reduction Operations
594    // =========================================================================
595
596    /// Sum all elements.
597    #[must_use]
598    pub fn sum(&self) -> Variable {
599        let self_data = self.data.read().clone();
600        let result = self_data.sum(); // Returns a scalar Tensor
601        let requires_grad = self.requires_grad && is_grad_enabled();
602
603        if requires_grad {
604            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
605            Variable::from_operation(result, grad_fn, true)
606        } else {
607            Variable::from_tensor(result)
608        }
609    }
610
611    /// Sum along a dimension, removing that dimension.
612    #[must_use]
613    pub fn sum_dim(&self, dim: usize) -> Variable {
614        let self_data = self.data.read().clone();
615        let result = self_data.sum_dim(dim as i32, false);
616        let requires_grad = self.requires_grad && is_grad_enabled();
617
618        if requires_grad {
619            let grad_fn = GradFn::new(SumDimBackward::new(self.grad_fn.clone(), self.shape(), dim));
620            Variable::from_operation(result, grad_fn, true)
621        } else {
622            Variable::from_tensor(result)
623        }
624    }
625
626    /// Mean of all elements.
627    #[must_use]
628    pub fn mean(&self) -> Variable {
629        let self_data = self.data.read().clone();
630        let result = self_data.mean().unwrap(); // Returns a scalar Tensor
631        let requires_grad = self.requires_grad && is_grad_enabled();
632
633        if requires_grad {
634            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
635            Variable::from_operation(result, grad_fn, true)
636        } else {
637            Variable::from_tensor(result)
638        }
639    }
640
641    // =========================================================================
642    // Loss Functions
643    // =========================================================================
644
645    /// Mean Squared Error loss.
646    #[must_use]
647    pub fn mse_loss(&self, target: &Variable) -> Variable {
648        let diff = self.sub_var(target);
649        let squared = diff.pow(2.0);
650        squared.mean()
651    }
652
653    /// Binary Cross Entropy loss (expects sigmoid output).
654    #[must_use]
655    pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
656        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
657        let one = Variable::from_tensor(Tensor::scalar(1.0));
658
659        // -[y * log(p + eps) + (1 - y) * log(1 - p + eps)]
660        let log_p = self.add_var(&eps);
661        let log_1_p = one.sub_var(self).add_var(&eps);
662
663        let term1 = target.mul_var(&Variable::from_tensor(log_p.data().ln()));
664        let term2 = one
665            .sub_var(target)
666            .mul_var(&Variable::from_tensor(log_1_p.data().ln()));
667
668        term1.add_var(&term2).neg_var().mean()
669    }
670
671    // =========================================================================
672    // Shape Operations
673    // =========================================================================
674
675    /// Reshapes the variable to a new shape.
676    #[must_use]
677    pub fn reshape(&self, shape: &[usize]) -> Variable {
678        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
679        let original_shape = self.shape();
680        let new_data = self
681            .data()
682            .reshape(&isize_shape)
683            .unwrap_or_else(|_| self.data().clone());
684        let requires_grad = self.requires_grad && is_grad_enabled();
685
686        if requires_grad {
687            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
688            Variable::from_operation(new_data, grad_fn, true)
689        } else {
690            Variable::from_tensor(new_data)
691        }
692    }
693
694    /// Flattens all dimensions from `start_dim` to the end into a single dimension.
695    ///
696    /// `flatten(1)` on a `[batch, C, H, W]` tensor produces `[batch, C*H*W]`.
697    /// `flatten(0)` flattens everything into a 1D vector.
698    #[must_use]
699    pub fn flatten(&self, start_dim: usize) -> Variable {
700        let shape = self.shape();
701        if start_dim >= shape.len() {
702            return self.clone();
703        }
704        let mut new_shape: Vec<usize> = shape[..start_dim].to_vec();
705        let flat: usize = shape[start_dim..].iter().product();
706        new_shape.push(flat);
707        self.reshape(&new_shape)
708    }
709
710    /// Transposes two dimensions.
711    #[must_use]
712    pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
713        let new_data = self
714            .data()
715            .transpose(dim0 as i64, dim1 as i64)
716            .unwrap_or_else(|_| self.data().clone());
717        let requires_grad = self.requires_grad && is_grad_enabled();
718
719        if requires_grad {
720            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
721            Variable::from_operation(new_data, grad_fn, true)
722        } else {
723            Variable::from_tensor(new_data)
724        }
725    }
726
    /// Slices the variable along specified ranges.
    ///
    /// NOTE(review): unlike `narrow`/`select`, this wraps the result in a new
    /// leaf via `Variable::new`, so no backward function links it to `self` —
    /// gradients will not flow through a slice. Confirm this is intentional.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }
733
734    /// Narrows the variable along a dimension.
735    ///
736    /// Returns a view of the tensor containing elements from `start` to `start + length`
737    /// along the specified dimension. This operation preserves gradients for backpropagation.
738    #[must_use]
739    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Variable {
740        let input_shape = self.shape();
741        let new_data = self
742            .data()
743            .narrow(dim, start, length)
744            .unwrap_or_else(|_| self.data().clone());
745        let requires_grad = self.requires_grad && is_grad_enabled();
746
747        if requires_grad {
748            let grad_fn = GradFn::new(NarrowBackward::new(
749                self.grad_fn.clone(),
750                input_shape,
751                dim,
752                start,
753            ));
754            Variable::from_operation(new_data, grad_fn, true)
755        } else {
756            Variable::from_tensor(new_data)
757        }
758    }
759
760    /// Expands the variable to a new shape (broadcast).
761    ///
762    /// Tracks the computational graph for backward pass.
763    #[must_use]
764    pub fn expand(&self, shape: &[usize]) -> Variable {
765        let input_shape = self.shape();
766        let new_data = self.data().broadcast_to(shape);
767        let requires_grad = self.requires_grad && is_grad_enabled();
768
769        if requires_grad {
770            let grad_fn = GradFn::new(ExpandBackward::new(self.grad_fn.clone(), input_shape));
771            Variable::from_operation(new_data, grad_fn, true)
772        } else {
773            Variable::from_tensor(new_data)
774        }
775    }
776
777    /// Selects a single index along a dimension, reducing rank by 1.
778    ///
779    /// For a tensor of shape (A, B, C), `select(1, i)` returns shape (A, C).
780    /// Tracks the computational graph for backward pass.
781    #[must_use]
782    pub fn select(&self, dim: usize, index: usize) -> Variable {
783        let input_shape = self.shape();
784        let new_data = self
785            .data()
786            .select(dim, index)
787            .unwrap_or_else(|_| self.data().clone());
788        let requires_grad = self.requires_grad && is_grad_enabled();
789
790        if requires_grad {
791            let grad_fn = GradFn::new(SelectBackward::new(
792                self.grad_fn.clone(),
793                input_shape,
794                dim,
795                index,
796            ));
797            Variable::from_operation(new_data, grad_fn, true)
798        } else {
799            Variable::from_tensor(new_data)
800        }
801    }
802
803    /// Adds a dimension of size 1 at the given position.
804    ///
805    /// Tracks the computational graph for backward pass.
806    #[must_use]
807    pub fn unsqueeze(&self, dim: usize) -> Variable {
808        let new_data = self
809            .data()
810            .unsqueeze(dim as i64)
811            .unwrap_or_else(|_| self.data().clone());
812        let requires_grad = self.requires_grad && is_grad_enabled();
813
814        if requires_grad {
815            let grad_fn = GradFn::new(UnsqueezeBackward::new(self.grad_fn.clone(), dim));
816            Variable::from_operation(new_data, grad_fn, true)
817        } else {
818            Variable::from_tensor(new_data)
819        }
820    }
821
822    /// Concatenates variables along a dimension.
823    ///
824    /// All variables must have the same shape except along the cat dimension.
825    /// Tracks the computational graph for backpropagation.
826    #[must_use]
827    pub fn cat(variables: &[&Variable], dim: usize) -> Variable {
828        let tensors: Vec<Tensor<f32>> = variables.iter().map(|v| v.data()).collect();
829        let tensor_refs: Vec<&Tensor<f32>> = tensors.iter().collect();
830        let result = Tensor::cat(&tensor_refs, dim).unwrap();
831
832        let requires_grad = variables.iter().any(|v| v.requires_grad) && is_grad_enabled();
833
834        if requires_grad {
835            let next_fns: Vec<Option<GradFn>> =
836                variables.iter().map(|v| v.grad_fn.clone()).collect();
837            let sizes: Vec<usize> = variables.iter().map(|v| v.shape()[dim]).collect();
838            let grad_fn = GradFn::new(CatBackward::new(next_fns, sizes, dim));
839            Variable::from_operation(result, grad_fn, true)
840        } else {
841            Variable::from_tensor(result)
842        }
843    }
844
845    // =========================================================================
846    // Scalar Operations
847    // =========================================================================
848
849    /// Multiplies by a scalar.
850    #[must_use]
851    pub fn mul_scalar(&self, scalar: f32) -> Variable {
852        let data = self.data();
853        let result = data.mul_scalar(scalar);
854        let requires_grad = self.requires_grad && is_grad_enabled();
855
856        if requires_grad {
857            let grad_fn = GradFn::new(MulScalarBackward::new(self.grad_fn.clone(), scalar));
858            Variable::from_operation(result, grad_fn, true)
859        } else {
860            Variable::from_tensor(result)
861        }
862    }
863
864    /// Adds a scalar.
865    #[must_use]
866    pub fn add_scalar(&self, scalar: f32) -> Variable {
867        let data = self.data();
868        let result = data.add_scalar(scalar);
869        let requires_grad = self.requires_grad && is_grad_enabled();
870
871        if requires_grad {
872            let grad_fn = GradFn::new(AddScalarBackward::new(self.grad_fn.clone()));
873            Variable::from_operation(result, grad_fn, true)
874        } else {
875            Variable::from_tensor(result)
876        }
877    }
878
879    /// Subtracts a scalar.
880    #[must_use]
881    pub fn sub_scalar(&self, scalar: f32) -> Variable {
882        self.add_scalar(-scalar)
883    }
884
885    /// Divides by a scalar.
886    #[must_use]
887    pub fn div_scalar(&self, scalar: f32) -> Variable {
888        self.mul_scalar(1.0 / scalar)
889    }
890
891    // =========================================================================
892    // Additional Activations
893    // =========================================================================
894
895    /// GELU activation function (Gaussian Error Linear Unit).
896    #[must_use]
897    pub fn gelu(&self) -> Variable {
898        let self_data = self.data();
899        let result = self_data.gelu();
900        let requires_grad = self.requires_grad && is_grad_enabled();
901
902        if requires_grad {
903            let grad_fn = GradFn::new(GeluBackward::new(self.grad_fn.clone(), self_data));
904            Variable::from_operation(result, grad_fn, true)
905        } else {
906            Variable::from_tensor(result)
907        }
908    }
909
910    /// SiLU/Swish activation function (x * sigmoid(x)).
911    #[must_use]
912    pub fn silu(&self) -> Variable {
913        let self_data = self.data();
914        let result = self_data.silu();
915        let requires_grad = self.requires_grad && is_grad_enabled();
916
917        if requires_grad {
918            let grad_fn = GradFn::new(SiluBackward::new(self.grad_fn.clone(), self_data));
919            Variable::from_operation(result, grad_fn, true)
920        } else {
921            Variable::from_tensor(result)
922        }
923    }
924
925    /// Square root.
926    #[must_use]
927    pub fn sqrt(&self) -> Variable {
928        let data = self.data();
929        let result = data.sqrt();
930        let requires_grad = self.requires_grad && is_grad_enabled();
931
932        if requires_grad {
933            let grad_fn = GradFn::new(SqrtBackward::new(self.grad_fn.clone(), result.clone()));
934            Variable::from_operation(result, grad_fn, true)
935        } else {
936            Variable::from_tensor(result)
937        }
938    }
939
940    // =========================================================================
941    // Softmax Operations
942    // =========================================================================
943
944    /// Softmax along specified dimension.
945    #[must_use]
946    pub fn softmax(&self, dim: i32) -> Variable {
947        let data = self.data();
948        let result = data.softmax(dim);
949        let requires_grad = self.requires_grad && is_grad_enabled();
950
951        if requires_grad {
952            let grad_fn = GradFn::new(SoftmaxBackward::new(
953                self.grad_fn.clone(),
954                result.clone(),
955                dim as i64,
956            ));
957            Variable::from_operation(result, grad_fn, true)
958        } else {
959            Variable::from_tensor(result)
960        }
961    }
962
963    /// Log softmax along specified dimension.
964    #[must_use]
965    pub fn log_softmax(&self, dim: i32) -> Variable {
966        let data = self.data();
967        let result = data.log_softmax(dim);
968        let requires_grad = self.requires_grad && is_grad_enabled();
969
970        if requires_grad {
971            let grad_fn = GradFn::new(LogSoftmaxBackward::new(
972                self.grad_fn.clone(),
973                result.clone(),
974                dim as i64,
975            ));
976            Variable::from_operation(result, grad_fn, true)
977        } else {
978            Variable::from_tensor(result)
979        }
980    }
981
982    // =========================================================================
983    // Reduction Operations with Dimensions
984    // =========================================================================
985
986    /// Mean along a dimension, optionally keeping the dimension.
987    #[must_use]
988    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
989        let data = self.data();
990        let input_shape = data.shape().to_vec();
991        let ndim = input_shape.len();
992        let dim_usize = if dim < 0 {
993            (ndim as i32 + dim) as usize
994        } else {
995            dim as usize
996        };
997        let result = data.mean_dim(dim, keepdim);
998        let requires_grad = self.requires_grad && is_grad_enabled();
999
1000        if requires_grad {
1001            let grad_fn = GradFn::new(MeanDimBackward::new(
1002                self.grad_fn.clone(),
1003                input_shape,
1004                dim_usize,
1005                keepdim,
1006            ));
1007            Variable::from_operation(result, grad_fn, true)
1008        } else {
1009            Variable::from_tensor(result)
1010        }
1011    }
1012
1013    /// Variance along a dimension, optionally keeping the dimension.
1014    #[must_use]
1015    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
1016        let self_data = self.data();
1017        let input_shape = self_data.shape().to_vec();
1018        let ndim = input_shape.len();
1019        let dim_usize = if dim < 0 {
1020            (ndim as i32 + dim) as usize
1021        } else {
1022            dim as usize
1023        };
1024        let result = self_data.var_dim(dim, keepdim);
1025        let requires_grad = self.requires_grad && is_grad_enabled();
1026
1027        if requires_grad {
1028            let grad_fn = GradFn::new(VarDimBackward::new(
1029                self.grad_fn.clone(),
1030                self_data,
1031                dim_usize,
1032                keepdim,
1033            ));
1034            Variable::from_operation(result, grad_fn, true)
1035        } else {
1036            Variable::from_tensor(result)
1037        }
1038    }
1039
1040    // =========================================================================
1041    // Utility Methods
1042    // =========================================================================
1043
1044    /// Creates a Variable from a tensor and requires_grad flag (for weight access).
1045    /// This is typically used internally by Parameter types.
1046    #[must_use]
1047    pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
1048        Variable::new(data, requires_grad)
1049    }
1050
1051    /// Clones the variable (alias for Clone trait).
1052    #[must_use]
1053    pub fn clone_var(&self) -> Variable {
1054        self.clone()
1055    }
1056
1057    /// Adds another variable (alias for add_var for method chaining).
1058    #[must_use]
1059    pub fn add(&self, other: &Variable) -> Variable {
1060        self.add_var(other)
1061    }
1062
1063    /// Subtracts another variable (alias for sub_var for method chaining).
1064    #[must_use]
1065    pub fn sub(&self, other: &Variable) -> Variable {
1066        self.sub_var(other)
1067    }
1068
1069    /// Multiplies by another variable (alias for mul_var for method chaining).
1070    #[must_use]
1071    pub fn mul(&self, other: &Variable) -> Variable {
1072        self.mul_var(other)
1073    }
1074
1075    /// Divides by another variable (alias for div_var for method chaining).
1076    #[must_use]
1077    pub fn div(&self, other: &Variable) -> Variable {
1078        self.div_var(other)
1079    }
1080}
1081
1082// =============================================================================
1083// Operator Overloads
1084// =============================================================================
1085
1086impl Add for &Variable {
1087    type Output = Variable;
1088
1089    fn add(self, other: &Variable) -> Variable {
1090        self.add_var(other)
1091    }
1092}
1093
1094impl Sub for &Variable {
1095    type Output = Variable;
1096
1097    fn sub(self, other: &Variable) -> Variable {
1098        self.sub_var(other)
1099    }
1100}
1101
1102impl Mul for &Variable {
1103    type Output = Variable;
1104
1105    fn mul(self, other: &Variable) -> Variable {
1106        self.mul_var(other)
1107    }
1108}
1109
1110impl Div for &Variable {
1111    type Output = Variable;
1112
1113    fn div(self, other: &Variable) -> Variable {
1114        self.div_var(other)
1115    }
1116}
1117
1118impl Neg for &Variable {
1119    type Output = Variable;
1120
1121    fn neg(self) -> Variable {
1122        self.neg_var()
1123    }
1124}
1125
1126impl std::fmt::Debug for Variable {
1127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1128        f.debug_struct("Variable")
1129            .field("shape", &self.shape())
1130            .field("requires_grad", &self.requires_grad)
1131            .field("is_leaf", &self.is_leaf)
1132            .field(
1133                "grad_fn",
1134                &self.grad_fn.as_ref().map(super::grad_fn::GradFn::name),
1135            )
1136            .finish()
1137    }
1138}
1139
1140// =============================================================================
1141// Tests
1142// =============================================================================
1143
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::zeros;

    /// Test-only helper: builds an f32 tensor, panicking with a uniform
    /// message on failure, to avoid repeating `expect` at every call site.
    fn tensor(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(data, shape).expect("tensor creation failed")
    }

    #[test]
    fn test_variable_creation() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::new(t, true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }

    #[test]
    fn test_variable_no_grad() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::from_tensor(t);
        assert!(!v.requires_grad());
    }

    #[test]
    fn test_variable_add() {
        let a = Variable::new(tensor(vec![1.0, 2.0, 3.0], &[3]), true);
        let b = Variable::new(tensor(vec![4.0, 5.0, 6.0], &[3]), true);
        let c = &a + &b;
        assert_eq!(c.data().to_vec(), vec![5.0, 7.0, 9.0]);
        assert!(c.requires_grad());
        assert!(!c.is_leaf());
    }

    #[test]
    fn test_variable_detach() {
        let a = Variable::new(tensor(vec![1.0, 2.0, 3.0], &[3]), true);
        let b = a.detach();
        assert!(!b.requires_grad());
        assert!(b.is_leaf());
    }

    #[test]
    fn test_mse_loss() {
        let pred = Variable::new(tensor(vec![1.0, 2.0, 3.0], &[3]), true);
        let target = Variable::from_tensor(tensor(vec![1.0, 2.0, 3.0], &[3]));
        let loss = pred.mse_loss(&target);
        assert_eq!(loss.numel(), 1);
        // Identical prediction and target: loss should be ~0.
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_exp() {
        let a = Variable::new(tensor(vec![0.0, 1.0, 2.0], &[3]), true);
        let b = a.exp();
        assert!((b.data().to_vec()[0] - 1.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - std::f32::consts::E).abs() < 1e-4);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        // d/dx(exp(x)) = exp(x)
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - std::f32::consts::E).abs() < 1e-4);
    }

    #[test]
    fn test_log() {
        let a = Variable::new(tensor(vec![1.0, std::f32::consts::E, 10.0], &[3]), true);
        let b = a.log();
        assert!((b.data().to_vec()[0] - 0.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - 1.0).abs() < 1e-5);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        // d/dx(log(x)) = 1/x
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - 1.0 / std::f32::consts::E).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let a = Variable::new(tensor(vec![-1.0, 0.5, 2.0], &[3]), true);
        let b = a.clamp(0.0, 1.0);
        assert_eq!(b.data().to_vec(), vec![0.0, 0.5, 1.0]);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        // Gradient flows only through elements that were not clamped.
        assert_eq!(grad[0], 0.0); // clamped at min
        assert_eq!(grad[1], 1.0); // unclamped
        assert_eq!(grad[2], 0.0); // clamped at max
    }
}