// crates/axonml-autograd/src/variable.rs
1//! Variable - Tensor with Gradient Tracking
2//!
3//! # File
4//! `crates/axonml-autograd/src/variable.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::ops::{Add, Div, Mul, Neg, Sub};
18use std::sync::Arc;
19
20use parking_lot::RwLock;
21
22use axonml_tensor::Tensor;
23
24use crate::functions::{
25    AddBackward, AddScalarBackward, CatBackward, ClampBackward, DivBackward, EluBackward,
26    ExpBackward, ExpandBackward, GeluBackward, LeakyReluBackward, LogBackward, LogSoftmaxBackward,
27    MatMulBackward, MeanBackward, MeanDimBackward, MulBackward, MulScalarBackward, NarrowBackward,
28    NegBackward, PowBackward, ReluBackward, ReshapeBackward, SelectBackward, SigmoidBackward,
29    SiluBackward, SoftmaxBackward, SqrtBackward, SubBackward, SumBackward, SumDimBackward,
30    TanhBackward, TransposeBackward, UnsqueezeBackward, VarDimBackward,
31};
32use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
33use crate::graph::{GraphNode, with_graph};
34use crate::no_grad::is_grad_enabled;
35
// =============================================================================
// Variable Struct
// =============================================================================

/// A tensor with automatic differentiation support.
///
/// Variable wraps a Tensor and tracks operations performed on it to enable
/// automatic gradient computation. When `requires_grad` is true, all operations
/// are recorded in a computational graph.
///
/// Cloning is cheap: `data` and `grad` sit behind `Arc`s, so clones share the
/// same tensor storage and gradient accumulator.
#[derive(Clone)]
pub struct Variable {
    /// The underlying tensor data, shared across clones of this Variable.
    data: Arc<RwLock<Tensor<f32>>>,
    /// Shared gradient accumulator (for leaf variables, shared with `AccumulateGrad`).
    grad: GradAccumulator,
    /// Whether this variable requires gradient computation.
    requires_grad: bool,
    /// Whether this is a leaf variable (created by user, not an operation).
    is_leaf: bool,
    /// The gradient function for backpropagation; `None` when not tracking.
    grad_fn: Option<GradFn>,
    /// Graph node for this variable; registered only when gradients are tracked.
    node: Option<Arc<GraphNode>>,
}
60
impl Variable {
    /// Creates a new leaf variable from a tensor.
    ///
    /// # Arguments
    /// * `data` - The tensor data
    /// * `requires_grad` - Whether to track gradients for this variable
    #[must_use]
    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
        // Gradient storage shared between this Variable and its AccumulateGrad.
        let grad: GradAccumulator = Arc::new(RwLock::new(None));

        // Leaves are registered in the graph, and get an AccumulateGrad sink
        // over the shared storage, only when gradients are requested.
        let node = requires_grad.then(|| with_graph(|g| g.register_leaf(true)));
        let grad_fn =
            requires_grad.then(|| GradFn::new(AccumulateGrad::new(Arc::clone(&grad))));

        Self {
            data: Arc::new(RwLock::new(data)),
            grad,
            requires_grad,
            is_leaf: true,
            grad_fn,
            node,
        }
    }
94
95    /// Creates a variable that doesn't require gradients.
96    #[must_use]
97    pub fn from_tensor(data: Tensor<f32>) -> Self {
98        Self::new(data, false)
99    }
100
101    /// Creates a new variable from an operation result with an attached gradient function.
102    ///
103    /// This connects the variable to the computational graph, allowing gradients
104    /// to flow backward through the operation that produced this variable.
105    pub fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
106        let node = if requires_grad {
107            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
108        } else {
109            None
110        };
111
112        Self {
113            data: Arc::new(RwLock::new(data)),
114            grad: Arc::new(RwLock::new(None)),
115            requires_grad,
116            is_leaf: false,
117            grad_fn: if requires_grad { Some(grad_fn) } else { None },
118            node,
119        }
120    }
121
122    /// Returns a reference to the underlying tensor data.
123    #[must_use]
124    pub fn data(&self) -> Tensor<f32> {
125        self.data.read().clone()
126    }
127
128    /// Returns the shape of the tensor.
129    #[must_use]
130    pub fn shape(&self) -> Vec<usize> {
131        self.data.read().shape().to_vec()
132    }
133
134    /// Returns the number of dimensions.
135    #[must_use]
136    pub fn ndim(&self) -> usize {
137        self.data.read().ndim()
138    }
139
140    /// Returns the total number of elements.
141    #[must_use]
142    pub fn numel(&self) -> usize {
143        self.data.read().numel()
144    }
145
146    /// Returns the device this variable's data is on.
147    #[must_use]
148    pub fn device(&self) -> axonml_tensor::Device {
149        self.data.read().device()
150    }
151
152    /// Moves this variable's data to the specified device.
153    ///
154    /// Creates a new leaf Variable on the target device.
155    /// Used for moving inputs to GPU before forward pass.
156    pub fn to_device(&self, device: axonml_tensor::Device) -> Self {
157        let current = self.data.read().clone();
158        if current.device() == device {
159            return self.clone();
160        }
161        let moved = current
162            .to_device(device)
163            .expect("Failed to move variable to device");
164        Variable::new(moved, self.requires_grad)
165    }
166
    /// Returns whether this variable requires gradient computation.
    #[must_use]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }
172
    /// Returns whether this is a leaf variable (created directly by the user
    /// rather than produced by an operation).
    #[must_use]
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }
178
179    /// Returns the gradient of this variable.
180    ///
181    /// Only available for leaf variables after `backward()` has been called.
182    #[must_use]
183    pub fn grad(&self) -> Option<Tensor<f32>> {
184        self.grad.read().clone()
185    }
186
    /// Returns the gradient function, if any: the backward node for operation
    /// results, or the `AccumulateGrad` sink for gradient-tracking leaves.
    #[must_use]
    pub fn grad_fn(&self) -> Option<&GradFn> {
        self.grad_fn.as_ref()
    }
192
193    /// Sets the gradient (used during backward pass).
194    pub fn set_grad(&self, grad: Tensor<f32>) {
195        *self.grad.write() = Some(grad);
196    }
197
198    /// Accumulates gradient (adds to existing gradient).
199    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
200        let mut grad_lock = self.grad.write();
201        if let Some(ref existing) = *grad_lock {
202            *grad_lock = Some(existing.add(grad).unwrap());
203        } else {
204            *grad_lock = Some(grad.clone());
205        }
206    }
207
208    /// Clears the gradient.
209    pub fn zero_grad(&self) {
210        *self.grad.write() = None;
211    }
212
213    /// Detaches this variable from the computation graph.
214    ///
215    /// Returns a new variable with the same data but no gradient history.
216    #[must_use]
217    pub fn detach(&self) -> Self {
218        Self {
219            data: Arc::new(RwLock::new(self.data.read().clone())),
220            grad: Arc::new(RwLock::new(None)),
221            requires_grad: false,
222            is_leaf: true,
223            grad_fn: None,
224            node: None,
225        }
226    }
227
    /// Returns a new variable with `requires_grad` set (builder-style, consumes self).
    ///
    /// When enabling gradients on a leaf, a fresh `AccumulateGrad` sink over
    /// the shared gradient storage and a fresh graph node are installed.
    ///
    /// NOTE(review): passing `false` flips the flag but leaves any existing
    /// `grad_fn`/`node` in place, and calling with `true` twice registers a
    /// second leaf node and replaces the grad_fn — confirm both are intended.
    #[must_use]
    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
        self.requires_grad = requires_grad;
        if requires_grad && self.is_leaf {
            // AccumulateGrad shares the gradient accumulator with this variable
            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
            self.node = Some(with_graph(|g| g.register_leaf(true)));
        }
        self
    }
239
240    /// Computes gradients via backpropagation.
241    ///
242    /// This should only be called on scalar (single-element) tensors,
243    /// typically the loss value.
244    pub fn backward(&self) {
245        assert!(
246            self.requires_grad,
247            "Cannot call backward on a variable that doesn't require gradients"
248        );
249
250        assert!(
251            (self.numel() == 1),
252            "backward() can only be called on scalar tensors"
253        );
254
255        // Start with gradient of 1.0 for the output, on the same device
256        let mut grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
257        let device = self.data.read().device();
258        if device.is_gpu() {
259            grad_output = grad_output.to_device(device).unwrap();
260        }
261        crate::backward::backward(self, &grad_output);
262    }
263
264    /// Runs the backward pass with a provided gradient tensor.
265    ///
266    /// Unlike `backward()`, this does not require the variable to be scalar.
267    /// The gradient tensor must match the shape of this variable.
268    pub fn backward_with_grad(&self, grad_output: &Tensor<f32>) {
269        if !self.requires_grad {
270            return;
271        }
272        let device = self.data.read().device();
273        let grad = if grad_output.device() != device && device.is_gpu() {
274            grad_output.to_device(device).unwrap()
275        } else {
276            grad_output.clone()
277        };
278        crate::backward::backward(self, &grad);
279    }
280
281    // =========================================================================
282    // Arithmetic Operations
283    // =========================================================================
284
285    /// Element-wise addition.
286    #[must_use]
287    pub fn add_var(&self, other: &Variable) -> Variable {
288        let result = self.data.read().add(&other.data.read()).unwrap();
289        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
290
291        if requires_grad {
292            let grad_fn = GradFn::new(AddBackward::new(
293                self.grad_fn.clone(),
294                other.grad_fn.clone(),
295                self.shape(),
296                other.shape(),
297            ));
298            Variable::from_operation(result, grad_fn, true)
299        } else {
300            Variable::from_tensor(result)
301        }
302    }
303
304    /// Element-wise subtraction.
305    #[must_use]
306    pub fn sub_var(&self, other: &Variable) -> Variable {
307        let result = self.data.read().sub(&other.data.read()).unwrap();
308        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
309
310        if requires_grad {
311            let grad_fn = GradFn::new(SubBackward::new(
312                self.grad_fn.clone(),
313                other.grad_fn.clone(),
314                self.shape(),
315                other.shape(),
316            ));
317            Variable::from_operation(result, grad_fn, true)
318        } else {
319            Variable::from_tensor(result)
320        }
321    }
322
323    /// Element-wise multiplication.
324    #[must_use]
325    pub fn mul_var(&self, other: &Variable) -> Variable {
326        let self_data = self.data.read().clone();
327        let other_data = other.data.read().clone();
328        let result = self_data.mul(&other_data).unwrap();
329        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
330
331        if requires_grad {
332            let grad_fn = GradFn::new(MulBackward::new(
333                self.grad_fn.clone(),
334                other.grad_fn.clone(),
335                self_data,
336                other_data,
337            ));
338            Variable::from_operation(result, grad_fn, true)
339        } else {
340            Variable::from_tensor(result)
341        }
342    }
343
344    /// Element-wise division.
345    #[must_use]
346    pub fn div_var(&self, other: &Variable) -> Variable {
347        let self_data = self.data.read().clone();
348        let other_data = other.data.read().clone();
349        let result = self_data.div(&other_data).unwrap();
350        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
351
352        if requires_grad {
353            let grad_fn = GradFn::new(DivBackward::new(
354                self.grad_fn.clone(),
355                other.grad_fn.clone(),
356                self_data,
357                other_data,
358            ));
359            Variable::from_operation(result, grad_fn, true)
360        } else {
361            Variable::from_tensor(result)
362        }
363    }
364
365    /// Negation.
366    #[must_use]
367    pub fn neg_var(&self) -> Variable {
368        let result = self.data.read().neg();
369        let requires_grad = self.requires_grad && is_grad_enabled();
370
371        if requires_grad {
372            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
373            Variable::from_operation(result, grad_fn, true)
374        } else {
375            Variable::from_tensor(result)
376        }
377    }
378
379    /// Matrix multiplication.
380    #[must_use]
381    pub fn matmul(&self, other: &Variable) -> Variable {
382        let self_data = self.data.read().clone();
383        let other_data = other.data.read().clone();
384        let result = self_data.matmul(&other_data).unwrap();
385        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
386
387        if requires_grad {
388            let grad_fn = GradFn::new(MatMulBackward::new(
389                self.grad_fn.clone(),
390                other.grad_fn.clone(),
391                self_data,
392                other_data,
393            ));
394            Variable::from_operation(result, grad_fn, true)
395        } else {
396            Variable::from_tensor(result)
397        }
398    }
399
400    /// Power operation.
401    #[must_use]
402    pub fn pow(&self, exponent: f32) -> Variable {
403        let self_data = self.data.read().clone();
404        let result = self_data.pow(exponent);
405        let requires_grad = self.requires_grad && is_grad_enabled();
406
407        if requires_grad {
408            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
409            Variable::from_operation(result, grad_fn, true)
410        } else {
411            Variable::from_tensor(result)
412        }
413    }
414
415    // =========================================================================
416    // Activation Functions
417    // =========================================================================
418
419    /// `ReLU` activation.
420    #[must_use]
421    pub fn relu(&self) -> Variable {
422        let self_data = self.data.read().clone();
423        let result = self_data.relu();
424        let requires_grad = self.requires_grad && is_grad_enabled();
425
426        if requires_grad {
427            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
428            Variable::from_operation(result, grad_fn, true)
429        } else {
430            Variable::from_tensor(result)
431        }
432    }
433
434    /// Leaky ReLU activation.
435    #[must_use]
436    pub fn leaky_relu(&self, negative_slope: f32) -> Variable {
437        let self_data = self.data.read().clone();
438        let device = self_data.device();
439        let result_vec: Vec<f32> = self_data
440            .to_vec()
441            .iter()
442            .map(|&x| if x > 0.0 { x } else { x * negative_slope })
443            .collect();
444        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
445        if device.is_gpu() {
446            result = result.to_device(device).unwrap();
447        }
448        let requires_grad = self.requires_grad && is_grad_enabled();
449
450        if requires_grad {
451            let grad_fn = GradFn::new(LeakyReluBackward::new(
452                self.grad_fn.clone(),
453                self_data,
454                negative_slope,
455            ));
456            Variable::from_operation(result, grad_fn, true)
457        } else {
458            Variable::from_tensor(result)
459        }
460    }
461
462    /// ELU activation.
463    #[must_use]
464    pub fn elu(&self, alpha: f32) -> Variable {
465        let self_data = self.data.read().clone();
466        let device = self_data.device();
467        let result_vec: Vec<f32> = self_data
468            .to_vec()
469            .iter()
470            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
471            .collect();
472        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
473        if device.is_gpu() {
474            result = result.to_device(device).unwrap();
475        }
476        let requires_grad = self.requires_grad && is_grad_enabled();
477
478        if requires_grad {
479            let grad_fn = GradFn::new(EluBackward::new(self.grad_fn.clone(), self_data, alpha));
480            Variable::from_operation(result, grad_fn, true)
481        } else {
482            Variable::from_tensor(result)
483        }
484    }
485
486    /// Sigmoid activation.
487    #[must_use]
488    pub fn sigmoid(&self) -> Variable {
489        let result = self.data.read().sigmoid();
490        let requires_grad = self.requires_grad && is_grad_enabled();
491
492        if requires_grad {
493            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
494            Variable::from_operation(result, grad_fn, true)
495        } else {
496            Variable::from_tensor(result)
497        }
498    }
499
500    /// Tanh activation.
501    #[must_use]
502    pub fn tanh(&self) -> Variable {
503        let result = self.data.read().tanh();
504        let requires_grad = self.requires_grad && is_grad_enabled();
505
506        if requires_grad {
507            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
508            Variable::from_operation(result, grad_fn, true)
509        } else {
510            Variable::from_tensor(result)
511        }
512    }
513
514    /// Element-wise exponential.
515    #[must_use]
516    pub fn exp(&self) -> Variable {
517        let self_data = self.data.read().clone();
518        let result = self_data.exp();
519        let requires_grad = self.requires_grad && is_grad_enabled();
520
521        if requires_grad {
522            let grad_fn = GradFn::new(ExpBackward::new(self.grad_fn.clone(), result.clone()));
523            Variable::from_operation(result, grad_fn, true)
524        } else {
525            Variable::from_tensor(result)
526        }
527    }
528
529    /// Element-wise natural logarithm.
530    #[must_use]
531    pub fn log(&self) -> Variable {
532        let self_data = self.data.read().clone();
533        let result = self_data.ln();
534        let requires_grad = self.requires_grad && is_grad_enabled();
535
536        if requires_grad {
537            let grad_fn = GradFn::new(LogBackward::new(self.grad_fn.clone(), self_data));
538            Variable::from_operation(result, grad_fn, true)
539        } else {
540            Variable::from_tensor(result)
541        }
542    }
543
544    /// Element-wise clamp to [min_val, max_val].
545    #[must_use]
546    pub fn clamp(&self, min_val: f32, max_val: f32) -> Variable {
547        let self_data = self.data.read().clone();
548        let device = self_data.device();
549        let result_data: Vec<f32> = self_data
550            .to_vec()
551            .iter()
552            .map(|&x| x.clamp(min_val, max_val))
553            .collect();
554        let mut result = Tensor::from_vec(result_data, self_data.shape()).unwrap();
555        if device.is_gpu() {
556            result = result.to_device(device).unwrap();
557        }
558        let requires_grad = self.requires_grad && is_grad_enabled();
559
560        if requires_grad {
561            let grad_fn = GradFn::new(ClampBackward::new(
562                self.grad_fn.clone(),
563                self_data,
564                min_val,
565                max_val,
566            ));
567            Variable::from_operation(result, grad_fn, true)
568        } else {
569            Variable::from_tensor(result)
570        }
571    }
572
573    // =========================================================================
574    // Reduction Operations
575    // =========================================================================
576
577    /// Sum all elements.
578    #[must_use]
579    pub fn sum(&self) -> Variable {
580        let self_data = self.data.read().clone();
581        let result = self_data.sum(); // Returns a scalar Tensor
582        let requires_grad = self.requires_grad && is_grad_enabled();
583
584        if requires_grad {
585            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
586            Variable::from_operation(result, grad_fn, true)
587        } else {
588            Variable::from_tensor(result)
589        }
590    }
591
592    /// Sum along a dimension, removing that dimension.
593    #[must_use]
594    pub fn sum_dim(&self, dim: usize) -> Variable {
595        let self_data = self.data.read().clone();
596        let result = self_data.sum_dim(dim as i32, false);
597        let requires_grad = self.requires_grad && is_grad_enabled();
598
599        if requires_grad {
600            let grad_fn = GradFn::new(SumDimBackward::new(self.grad_fn.clone(), self.shape(), dim));
601            Variable::from_operation(result, grad_fn, true)
602        } else {
603            Variable::from_tensor(result)
604        }
605    }
606
607    /// Mean of all elements.
608    #[must_use]
609    pub fn mean(&self) -> Variable {
610        let self_data = self.data.read().clone();
611        let result = self_data.mean().unwrap(); // Returns a scalar Tensor
612        let requires_grad = self.requires_grad && is_grad_enabled();
613
614        if requires_grad {
615            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
616            Variable::from_operation(result, grad_fn, true)
617        } else {
618            Variable::from_tensor(result)
619        }
620    }
621
622    // =========================================================================
623    // Loss Functions
624    // =========================================================================
625
626    /// Mean Squared Error loss.
627    #[must_use]
628    pub fn mse_loss(&self, target: &Variable) -> Variable {
629        let diff = self.sub_var(target);
630        let squared = diff.pow(2.0);
631        squared.mean()
632    }
633
634    /// Binary Cross Entropy loss (expects sigmoid output).
635    #[must_use]
636    pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
637        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
638        let one = Variable::from_tensor(Tensor::scalar(1.0));
639
640        // -[y * log(p + eps) + (1 - y) * log(1 - p + eps)]
641        let log_p = self.add_var(&eps);
642        let log_1_p = one.sub_var(self).add_var(&eps);
643
644        let term1 = target.mul_var(&Variable::from_tensor(log_p.data().ln()));
645        let term2 = one
646            .sub_var(target)
647            .mul_var(&Variable::from_tensor(log_1_p.data().ln()));
648
649        term1.add_var(&term2).neg_var().mean()
650    }
651
652    // =========================================================================
653    // Shape Operations
654    // =========================================================================
655
656    /// Reshapes the variable to a new shape.
657    #[must_use]
658    pub fn reshape(&self, shape: &[usize]) -> Variable {
659        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
660        let original_shape = self.shape();
661        let new_data = self
662            .data()
663            .reshape(&isize_shape)
664            .unwrap_or_else(|_| self.data().clone());
665        let requires_grad = self.requires_grad && is_grad_enabled();
666
667        if requires_grad {
668            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
669            Variable::from_operation(new_data, grad_fn, true)
670        } else {
671            Variable::from_tensor(new_data)
672        }
673    }
674
675    /// Flattens all dimensions from `start_dim` to the end into a single dimension.
676    ///
677    /// `flatten(1)` on a `[batch, C, H, W]` tensor produces `[batch, C*H*W]`.
678    /// `flatten(0)` flattens everything into a 1D vector.
679    #[must_use]
680    pub fn flatten(&self, start_dim: usize) -> Variable {
681        let shape = self.shape();
682        if start_dim >= shape.len() {
683            return self.clone();
684        }
685        let mut new_shape: Vec<usize> = shape[..start_dim].to_vec();
686        let flat: usize = shape[start_dim..].iter().product();
687        new_shape.push(flat);
688        self.reshape(&new_shape)
689    }
690
691    /// Transposes two dimensions.
692    #[must_use]
693    pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
694        let new_data = self
695            .data()
696            .transpose(dim0 as i64, dim1 as i64)
697            .unwrap_or_else(|_| self.data().clone());
698        let requires_grad = self.requires_grad && is_grad_enabled();
699
700        if requires_grad {
701            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
702            Variable::from_operation(new_data, grad_fn, true)
703        } else {
704            Variable::from_tensor(new_data)
705        }
706    }
707
    /// Slices the variable along specified ranges.
    ///
    /// NOTE(review): unlike `narrow`/`select`, this builds a fresh *leaf*
    /// variable via `Variable::new`, so no backward function is attached and
    /// gradients will NOT flow back through the slice even when
    /// `requires_grad` is true — confirm whether a `SliceBackward` is needed.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }
714
715    /// Narrows the variable along a dimension.
716    ///
717    /// Returns a view of the tensor containing elements from `start` to `start + length`
718    /// along the specified dimension. This operation preserves gradients for backpropagation.
719    #[must_use]
720    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Variable {
721        let input_shape = self.shape();
722        let new_data = self
723            .data()
724            .narrow(dim, start, length)
725            .unwrap_or_else(|_| self.data().clone());
726        let requires_grad = self.requires_grad && is_grad_enabled();
727
728        if requires_grad {
729            let grad_fn = GradFn::new(NarrowBackward::new(
730                self.grad_fn.clone(),
731                input_shape,
732                dim,
733                start,
734            ));
735            Variable::from_operation(new_data, grad_fn, true)
736        } else {
737            Variable::from_tensor(new_data)
738        }
739    }
740
741    /// Expands the variable to a new shape (broadcast).
742    ///
743    /// Tracks the computational graph for backward pass.
744    #[must_use]
745    pub fn expand(&self, shape: &[usize]) -> Variable {
746        let input_shape = self.shape();
747        let new_data = self.data().broadcast_to(shape);
748        let requires_grad = self.requires_grad && is_grad_enabled();
749
750        if requires_grad {
751            let grad_fn = GradFn::new(ExpandBackward::new(self.grad_fn.clone(), input_shape));
752            Variable::from_operation(new_data, grad_fn, true)
753        } else {
754            Variable::from_tensor(new_data)
755        }
756    }
757
758    /// Selects a single index along a dimension, reducing rank by 1.
759    ///
760    /// For a tensor of shape (A, B, C), `select(1, i)` returns shape (A, C).
761    /// Tracks the computational graph for backward pass.
762    #[must_use]
763    pub fn select(&self, dim: usize, index: usize) -> Variable {
764        let input_shape = self.shape();
765        let new_data = self
766            .data()
767            .select(dim, index)
768            .unwrap_or_else(|_| self.data().clone());
769        let requires_grad = self.requires_grad && is_grad_enabled();
770
771        if requires_grad {
772            let grad_fn = GradFn::new(SelectBackward::new(
773                self.grad_fn.clone(),
774                input_shape,
775                dim,
776                index,
777            ));
778            Variable::from_operation(new_data, grad_fn, true)
779        } else {
780            Variable::from_tensor(new_data)
781        }
782    }
783
784    /// Adds a dimension of size 1 at the given position.
785    ///
786    /// Tracks the computational graph for backward pass.
787    #[must_use]
788    pub fn unsqueeze(&self, dim: usize) -> Variable {
789        let new_data = self
790            .data()
791            .unsqueeze(dim as i64)
792            .unwrap_or_else(|_| self.data().clone());
793        let requires_grad = self.requires_grad && is_grad_enabled();
794
795        if requires_grad {
796            let grad_fn = GradFn::new(UnsqueezeBackward::new(self.grad_fn.clone(), dim));
797            Variable::from_operation(new_data, grad_fn, true)
798        } else {
799            Variable::from_tensor(new_data)
800        }
801    }
802
803    /// Concatenates variables along a dimension.
804    ///
805    /// All variables must have the same shape except along the cat dimension.
806    /// Tracks the computational graph for backpropagation.
807    #[must_use]
808    pub fn cat(variables: &[&Variable], dim: usize) -> Variable {
809        let tensors: Vec<Tensor<f32>> = variables.iter().map(|v| v.data()).collect();
810        let tensor_refs: Vec<&Tensor<f32>> = tensors.iter().collect();
811        let result = Tensor::cat(&tensor_refs, dim).unwrap();
812
813        let requires_grad = variables.iter().any(|v| v.requires_grad) && is_grad_enabled();
814
815        if requires_grad {
816            let next_fns: Vec<Option<GradFn>> =
817                variables.iter().map(|v| v.grad_fn.clone()).collect();
818            let sizes: Vec<usize> = variables.iter().map(|v| v.shape()[dim]).collect();
819            let grad_fn = GradFn::new(CatBackward::new(next_fns, sizes, dim));
820            Variable::from_operation(result, grad_fn, true)
821        } else {
822            Variable::from_tensor(result)
823        }
824    }
825
826    // =========================================================================
827    // Scalar Operations
828    // =========================================================================
829
830    /// Multiplies by a scalar.
831    #[must_use]
832    pub fn mul_scalar(&self, scalar: f32) -> Variable {
833        let data = self.data();
834        let result = data.mul_scalar(scalar);
835        let requires_grad = self.requires_grad && is_grad_enabled();
836
837        if requires_grad {
838            let grad_fn = GradFn::new(MulScalarBackward::new(self.grad_fn.clone(), scalar));
839            Variable::from_operation(result, grad_fn, true)
840        } else {
841            Variable::from_tensor(result)
842        }
843    }
844
845    /// Adds a scalar.
846    #[must_use]
847    pub fn add_scalar(&self, scalar: f32) -> Variable {
848        let data = self.data();
849        let result = data.add_scalar(scalar);
850        let requires_grad = self.requires_grad && is_grad_enabled();
851
852        if requires_grad {
853            let grad_fn = GradFn::new(AddScalarBackward::new(self.grad_fn.clone()));
854            Variable::from_operation(result, grad_fn, true)
855        } else {
856            Variable::from_tensor(result)
857        }
858    }
859
    /// Subtracts a scalar.
    ///
    /// Implemented as `add_scalar(-scalar)`, so it shares `AddScalarBackward`
    /// for the gradient (d(x - c)/dx = 1).
    #[must_use]
    pub fn sub_scalar(&self, scalar: f32) -> Variable {
        self.add_scalar(-scalar)
    }
865
    /// Divides by a scalar.
    ///
    /// Implemented as multiplication by the reciprocal so it reuses
    /// `MulScalarBackward`. NOTE(review): `x * (1.0 / s)` can differ from
    /// `x / s` in the last ulp, and `s == 0.0` produces `inf`/`nan` rather
    /// than an error — confirm this matches the tensor crate's semantics.
    #[must_use]
    pub fn div_scalar(&self, scalar: f32) -> Variable {
        self.mul_scalar(1.0 / scalar)
    }
871
872    // =========================================================================
873    // Additional Activations
874    // =========================================================================
875
876    /// GELU activation function (Gaussian Error Linear Unit).
877    #[must_use]
878    pub fn gelu(&self) -> Variable {
879        let self_data = self.data();
880        let result = self_data.gelu();
881        let requires_grad = self.requires_grad && is_grad_enabled();
882
883        if requires_grad {
884            let grad_fn = GradFn::new(GeluBackward::new(self.grad_fn.clone(), self_data));
885            Variable::from_operation(result, grad_fn, true)
886        } else {
887            Variable::from_tensor(result)
888        }
889    }
890
891    /// SiLU/Swish activation function (x * sigmoid(x)).
892    #[must_use]
893    pub fn silu(&self) -> Variable {
894        let self_data = self.data();
895        let result = self_data.silu();
896        let requires_grad = self.requires_grad && is_grad_enabled();
897
898        if requires_grad {
899            let grad_fn = GradFn::new(SiluBackward::new(self.grad_fn.clone(), self_data));
900            Variable::from_operation(result, grad_fn, true)
901        } else {
902            Variable::from_tensor(result)
903        }
904    }
905
906    /// Square root.
907    #[must_use]
908    pub fn sqrt(&self) -> Variable {
909        let data = self.data();
910        let result = data.sqrt();
911        let requires_grad = self.requires_grad && is_grad_enabled();
912
913        if requires_grad {
914            let grad_fn = GradFn::new(SqrtBackward::new(self.grad_fn.clone(), result.clone()));
915            Variable::from_operation(result, grad_fn, true)
916        } else {
917            Variable::from_tensor(result)
918        }
919    }
920
921    // =========================================================================
922    // Softmax Operations
923    // =========================================================================
924
925    /// Softmax along specified dimension.
926    #[must_use]
927    pub fn softmax(&self, dim: i32) -> Variable {
928        let data = self.data();
929        let result = data.softmax(dim);
930        let requires_grad = self.requires_grad && is_grad_enabled();
931
932        if requires_grad {
933            let grad_fn = GradFn::new(SoftmaxBackward::new(
934                self.grad_fn.clone(),
935                result.clone(),
936                dim as i64,
937            ));
938            Variable::from_operation(result, grad_fn, true)
939        } else {
940            Variable::from_tensor(result)
941        }
942    }
943
944    /// Log softmax along specified dimension.
945    #[must_use]
946    pub fn log_softmax(&self, dim: i32) -> Variable {
947        let data = self.data();
948        let result = data.log_softmax(dim);
949        let requires_grad = self.requires_grad && is_grad_enabled();
950
951        if requires_grad {
952            let grad_fn = GradFn::new(LogSoftmaxBackward::new(
953                self.grad_fn.clone(),
954                result.clone(),
955                dim as i64,
956            ));
957            Variable::from_operation(result, grad_fn, true)
958        } else {
959            Variable::from_tensor(result)
960        }
961    }
962
963    // =========================================================================
964    // Reduction Operations with Dimensions
965    // =========================================================================
966
967    /// Mean along a dimension, optionally keeping the dimension.
968    #[must_use]
969    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
970        let data = self.data();
971        let input_shape = data.shape().to_vec();
972        let ndim = input_shape.len();
973        let dim_usize = if dim < 0 {
974            (ndim as i32 + dim) as usize
975        } else {
976            dim as usize
977        };
978        let result = data.mean_dim(dim, keepdim);
979        let requires_grad = self.requires_grad && is_grad_enabled();
980
981        if requires_grad {
982            let grad_fn = GradFn::new(MeanDimBackward::new(
983                self.grad_fn.clone(),
984                input_shape,
985                dim_usize,
986                keepdim,
987            ));
988            Variable::from_operation(result, grad_fn, true)
989        } else {
990            Variable::from_tensor(result)
991        }
992    }
993
994    /// Variance along a dimension, optionally keeping the dimension.
995    #[must_use]
996    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
997        let self_data = self.data();
998        let input_shape = self_data.shape().to_vec();
999        let ndim = input_shape.len();
1000        let dim_usize = if dim < 0 {
1001            (ndim as i32 + dim) as usize
1002        } else {
1003            dim as usize
1004        };
1005        let result = self_data.var_dim(dim, keepdim);
1006        let requires_grad = self.requires_grad && is_grad_enabled();
1007
1008        if requires_grad {
1009            let grad_fn = GradFn::new(VarDimBackward::new(
1010                self.grad_fn.clone(),
1011                self_data,
1012                dim_usize,
1013                keepdim,
1014            ));
1015            Variable::from_operation(result, grad_fn, true)
1016        } else {
1017            Variable::from_tensor(result)
1018        }
1019    }
1020
1021    // =========================================================================
1022    // Utility Methods
1023    // =========================================================================
1024
    /// Creates a Variable from a tensor and requires_grad flag (for weight access).
    /// This is typically used internally by Parameter types.
    ///
    /// Thin alias for [`Variable::new`]; kept as a separate name so parameter
    /// wrappers read explicitly at the call site.
    #[must_use]
    pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
        Variable::new(data, requires_grad)
    }
1031
    /// Clones the variable (alias for Clone trait).
    ///
    /// The struct derives `Clone` and stores its tensor behind
    /// `Arc<RwLock<...>>`, so this is a cheap handle copy that shares the
    /// underlying storage rather than deep-copying the tensor.
    #[must_use]
    pub fn clone_var(&self) -> Variable {
        self.clone()
    }
1037
    /// Adds another variable (alias for `add_var` for method chaining).
    ///
    /// Inherent method, so it shadows the `Add` operator impl below when
    /// called as `a.add(&b)`; both delegate to the same `add_var`.
    #[must_use]
    pub fn add(&self, other: &Variable) -> Variable {
        self.add_var(other)
    }

    /// Subtracts another variable (alias for `sub_var` for method chaining).
    #[must_use]
    pub fn sub(&self, other: &Variable) -> Variable {
        self.sub_var(other)
    }

    /// Multiplies by another variable (alias for `mul_var` for method chaining).
    #[must_use]
    pub fn mul(&self, other: &Variable) -> Variable {
        self.mul_var(other)
    }

    /// Divides by another variable (alias for `div_var` for method chaining).
    #[must_use]
    pub fn div(&self, other: &Variable) -> Variable {
        self.div_var(other)
    }
1061}
1062
1063// =============================================================================
1064// Operator Overloads
1065// =============================================================================
1066
1067impl Add for &Variable {
1068    type Output = Variable;
1069
1070    fn add(self, other: &Variable) -> Variable {
1071        self.add_var(other)
1072    }
1073}
1074
1075impl Sub for &Variable {
1076    type Output = Variable;
1077
1078    fn sub(self, other: &Variable) -> Variable {
1079        self.sub_var(other)
1080    }
1081}
1082
1083impl Mul for &Variable {
1084    type Output = Variable;
1085
1086    fn mul(self, other: &Variable) -> Variable {
1087        self.mul_var(other)
1088    }
1089}
1090
1091impl Div for &Variable {
1092    type Output = Variable;
1093
1094    fn div(self, other: &Variable) -> Variable {
1095        self.div_var(other)
1096    }
1097}
1098
1099impl Neg for &Variable {
1100    type Output = Variable;
1101
1102    fn neg(self) -> Variable {
1103        self.neg_var()
1104    }
1105}
1106
1107impl std::fmt::Debug for Variable {
1108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1109        f.debug_struct("Variable")
1110            .field("shape", &self.shape())
1111            .field("requires_grad", &self.requires_grad)
1112            .field("is_leaf", &self.is_leaf)
1113            .field(
1114                "grad_fn",
1115                &self.grad_fn.as_ref().map(super::grad_fn::GradFn::name),
1116            )
1117            .finish()
1118    }
1119}
1120
1121// =============================================================================
1122// Tests
1123// =============================================================================
1124
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::zeros;

    /// Builds a 3-element, grad-tracking 1-D Variable from literal values.
    fn var_1d(values: Vec<f32>) -> Variable {
        Variable::new(Tensor::from_vec(values, &[3]).unwrap(), true)
    }

    #[test]
    fn test_variable_creation() {
        let v = Variable::new(zeros::<f32>(&[2, 3]), true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }

    #[test]
    fn test_variable_no_grad() {
        let v = Variable::from_tensor(zeros::<f32>(&[2, 3]));
        assert!(!v.requires_grad());
    }

    #[test]
    fn test_variable_add() {
        let lhs = var_1d(vec![1.0, 2.0, 3.0]);
        let rhs = var_1d(vec![4.0, 5.0, 6.0]);
        let sum = &lhs + &rhs;
        assert_eq!(sum.data().to_vec(), vec![5.0, 7.0, 9.0]);
        // The result of an op on grad-tracking inputs is a non-leaf node.
        assert!(sum.requires_grad());
        assert!(!sum.is_leaf());
    }

    #[test]
    fn test_variable_detach() {
        let detached = var_1d(vec![1.0, 2.0, 3.0]).detach();
        assert!(!detached.requires_grad());
        assert!(detached.is_leaf());
    }

    #[test]
    fn test_mse_loss() {
        let pred = var_1d(vec![1.0, 2.0, 3.0]);
        let target = Variable::from_tensor(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let loss = pred.mse_loss(&target);
        assert_eq!(loss.numel(), 1);
        // Identical prediction and target -> zero loss.
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_exp() {
        let x = var_1d(vec![0.0, 1.0, 2.0]);
        let y = x.exp();
        let forward = y.data().to_vec();
        assert!((forward[0] - 1.0).abs() < 1e-5);
        assert!((forward[1] - std::f32::consts::E).abs() < 1e-4);

        y.sum().backward();
        // d/dx(exp(x)) = exp(x)
        let grad = x.grad().unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - std::f32::consts::E).abs() < 1e-4);
    }

    #[test]
    fn test_log() {
        let x = var_1d(vec![1.0, std::f32::consts::E, 10.0]);
        let y = x.log();
        let forward = y.data().to_vec();
        assert!((forward[0] - 0.0).abs() < 1e-5);
        assert!((forward[1] - 1.0).abs() < 1e-5);

        y.sum().backward();
        // d/dx(log(x)) = 1/x
        let grad = x.grad().unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - 1.0 / std::f32::consts::E).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let x = var_1d(vec![-1.0, 0.5, 2.0]);
        let y = x.clamp(0.0, 1.0);
        assert_eq!(y.data().to_vec(), vec![0.0, 0.5, 1.0]);

        y.sum().backward();
        // Gradient flows only through elements the clamp left untouched.
        let grad = x.grad().unwrap().to_vec();
        assert_eq!(grad[0], 0.0); // clamped at min
        assert_eq!(grad[1], 1.0); // not clamped
        assert_eq!(grad[2], 0.0); // clamped at max
    }
}