// axonml_autograd/variable.rs
//! Variable - Tensor with Gradient Tracking
//!
//! The Variable struct wraps a Tensor and adds automatic differentiation
//! capabilities. Variables track their computational history to enable
//! gradient computation via backpropagation.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
10use std::ops::{Add, Div, Mul, Neg, Sub};
11use std::sync::Arc;
12
13use parking_lot::RwLock;
14
15use axonml_tensor::Tensor;
16
17use crate::functions::{
18    AddBackward, CatBackward, DivBackward, ExpandBackward, MatMulBackward, MeanBackward,
19    MulBackward, NarrowBackward, NegBackward, PowBackward, ReluBackward, ReshapeBackward,
20    SelectBackward, SigmoidBackward, SubBackward, SumBackward, SumDimBackward, TanhBackward,
21    TransposeBackward, UnsqueezeBackward,
22};
23use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
24use crate::graph::{with_graph, GraphNode};
25use crate::no_grad::is_grad_enabled;
26
27// =============================================================================
28// Variable Struct
29// =============================================================================
30
/// A tensor with automatic differentiation support.
///
/// Variable wraps a Tensor and tracks operations performed on it to enable
/// automatic gradient computation. When `requires_grad` is true, all operations
/// are recorded in a computational graph.
///
/// Cloning a `Variable` is shallow: clones share the same underlying tensor
/// storage and gradient accumulator, since both live behind `Arc`s.
#[derive(Clone)]
pub struct Variable {
    /// The underlying tensor data (shared between clones of this variable).
    data: Arc<RwLock<Tensor<f32>>>,
    /// Shared gradient accumulator (for leaf variables, shared with `AccumulateGrad`).
    grad: GradAccumulator,
    /// Whether this variable requires gradient computation.
    requires_grad: bool,
    /// Whether this is a leaf variable (created by user, not an operation).
    is_leaf: bool,
    /// The gradient function for backpropagation (`None` when not tracking).
    grad_fn: Option<GradFn>,
    /// Graph node registered for this variable, if it participates in the graph.
    node: Option<Arc<GraphNode>>,
}
51
52impl Variable {
53    /// Creates a new variable from a tensor.
54    ///
55    /// # Arguments
56    /// * `data` - The tensor data
57    /// * `requires_grad` - Whether to track gradients for this variable
58    #[must_use]
59    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
60        // Create shared gradient accumulator
61        let grad: GradAccumulator = Arc::new(RwLock::new(None));
62
63        let node = if requires_grad {
64            Some(with_graph(|g| g.register_leaf(true)))
65        } else {
66            None
67        };
68
69        // Create AccumulateGrad with shared gradient storage
70        let grad_fn = if requires_grad {
71            Some(GradFn::new(AccumulateGrad::new(Arc::clone(&grad))))
72        } else {
73            None
74        };
75
76        Self {
77            data: Arc::new(RwLock::new(data)),
78            grad,
79            requires_grad,
80            is_leaf: true,
81            grad_fn,
82            node,
83        }
84    }
85
    /// Creates a variable that doesn't require gradients.
    ///
    /// Convenience wrapper around [`Variable::new`] with `requires_grad = false`.
    #[must_use]
    pub fn from_tensor(data: Tensor<f32>) -> Self {
        Self::new(data, false)
    }
91
92    /// Creates a new variable from an operation result with an attached gradient function.
93    ///
94    /// This connects the variable to the computational graph, allowing gradients
95    /// to flow backward through the operation that produced this variable.
96    pub fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
97        let node = if requires_grad {
98            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
99        } else {
100            None
101        };
102
103        Self {
104            data: Arc::new(RwLock::new(data)),
105            grad: Arc::new(RwLock::new(None)),
106            requires_grad,
107            is_leaf: false,
108            grad_fn: if requires_grad { Some(grad_fn) } else { None },
109            node,
110        }
111    }
112
    /// Returns a clone of the underlying tensor data.
    ///
    /// Note: this clones the tensor rather than handing out a reference,
    /// because the data lives behind an `RwLock`.
    #[must_use]
    pub fn data(&self) -> Tensor<f32> {
        self.data.read().clone()
    }

    /// Returns the shape of the tensor as an owned vector.
    #[must_use]
    pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape().to_vec()
    }

    /// Returns the number of dimensions.
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.data.read().ndim()
    }

    /// Returns the total number of elements.
    #[must_use]
    pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    /// Returns whether this variable requires gradients.
    #[must_use]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Returns whether this is a leaf variable.
    #[must_use]
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Returns a clone of this variable's gradient, if one has been set.
    ///
    /// Typically populated for leaf variables after `backward()` has been called.
    #[must_use]
    pub fn grad(&self) -> Option<Tensor<f32>> {
        self.grad.read().clone()
    }

    /// Returns the gradient function, if any.
    #[must_use]
    pub fn grad_fn(&self) -> Option<&GradFn> {
        self.grad_fn.as_ref()
    }
162
163    /// Sets the gradient (used during backward pass).
164    pub fn set_grad(&self, grad: Tensor<f32>) {
165        *self.grad.write() = Some(grad);
166    }
167
168    /// Accumulates gradient (adds to existing gradient).
169    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
170        let mut grad_lock = self.grad.write();
171        if let Some(ref existing) = *grad_lock {
172            *grad_lock = Some(existing.add(grad).unwrap());
173        } else {
174            *grad_lock = Some(grad.clone());
175        }
176    }
177
178    /// Clears the gradient.
179    pub fn zero_grad(&self) {
180        *self.grad.write() = None;
181    }
182
    /// Detaches this variable from the computation graph.
    ///
    /// Returns a new leaf variable with a deep copy of the data, an empty
    /// gradient accumulator, no gradient function, and `requires_grad = false`.
    /// The result does not share tensor storage with `self`.
    #[must_use]
    pub fn detach(&self) -> Self {
        Self {
            data: Arc::new(RwLock::new(self.data.read().clone())),
            grad: Arc::new(RwLock::new(None)),
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            node: None,
        }
    }
197
198    /// Returns a new variable with `requires_grad` set.
199    #[must_use]
200    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
201        self.requires_grad = requires_grad;
202        if requires_grad && self.is_leaf {
203            // AccumulateGrad shares the gradient accumulator with this variable
204            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
205            self.node = Some(with_graph(|g| g.register_leaf(true)));
206        }
207        self
208    }
209
    /// Computes gradients via backpropagation.
    ///
    /// This should only be called on scalar (single-element) tensors,
    /// typically the loss value. Seeds the backward pass with a gradient of
    /// `1.0` and delegates graph traversal to `crate::backward::backward`.
    ///
    /// # Panics
    /// Panics if this variable does not require gradients, or if it contains
    /// more than one element.
    pub fn backward(&self) {
        assert!(
            self.requires_grad,
            "Cannot call backward on a variable that doesn't require gradients"
        );

        assert!(
            (self.numel() == 1),
            "backward() can only be called on scalar tensors"
        );

        // Seed gradient d(out)/d(out) = 1.0, as a one-element tensor of shape [1].
        let grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
        crate::backward::backward(self, &grad_output);
    }
229
230    // =========================================================================
231    // Arithmetic Operations
232    // =========================================================================
233
234    /// Element-wise addition.
235    #[must_use]
236    pub fn add_var(&self, other: &Variable) -> Variable {
237        let result = self.data.read().add(&other.data.read()).unwrap();
238        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
239
240        if requires_grad {
241            let grad_fn = GradFn::new(AddBackward::new(
242                self.grad_fn.clone(),
243                other.grad_fn.clone(),
244                self.shape(),
245                other.shape(),
246            ));
247            Variable::from_operation(result, grad_fn, true)
248        } else {
249            Variable::from_tensor(result)
250        }
251    }
252
253    /// Element-wise subtraction.
254    #[must_use]
255    pub fn sub_var(&self, other: &Variable) -> Variable {
256        let result = self.data.read().sub(&other.data.read()).unwrap();
257        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
258
259        if requires_grad {
260            let grad_fn = GradFn::new(SubBackward::new(
261                self.grad_fn.clone(),
262                other.grad_fn.clone(),
263                self.shape(),
264                other.shape(),
265            ));
266            Variable::from_operation(result, grad_fn, true)
267        } else {
268            Variable::from_tensor(result)
269        }
270    }
271
272    /// Element-wise multiplication.
273    #[must_use]
274    pub fn mul_var(&self, other: &Variable) -> Variable {
275        let self_data = self.data.read().clone();
276        let other_data = other.data.read().clone();
277        let result = self_data.mul(&other_data).unwrap();
278        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
279
280        if requires_grad {
281            let grad_fn = GradFn::new(MulBackward::new(
282                self.grad_fn.clone(),
283                other.grad_fn.clone(),
284                self_data,
285                other_data,
286            ));
287            Variable::from_operation(result, grad_fn, true)
288        } else {
289            Variable::from_tensor(result)
290        }
291    }
292
293    /// Element-wise division.
294    #[must_use]
295    pub fn div_var(&self, other: &Variable) -> Variable {
296        let self_data = self.data.read().clone();
297        let other_data = other.data.read().clone();
298        let result = self_data.div(&other_data).unwrap();
299        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
300
301        if requires_grad {
302            let grad_fn = GradFn::new(DivBackward::new(
303                self.grad_fn.clone(),
304                other.grad_fn.clone(),
305                self_data,
306                other_data,
307            ));
308            Variable::from_operation(result, grad_fn, true)
309        } else {
310            Variable::from_tensor(result)
311        }
312    }
313
314    /// Negation.
315    #[must_use]
316    pub fn neg_var(&self) -> Variable {
317        let result = self.data.read().neg();
318        let requires_grad = self.requires_grad && is_grad_enabled();
319
320        if requires_grad {
321            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
322            Variable::from_operation(result, grad_fn, true)
323        } else {
324            Variable::from_tensor(result)
325        }
326    }
327
328    /// Matrix multiplication.
329    #[must_use]
330    pub fn matmul(&self, other: &Variable) -> Variable {
331        let self_data = self.data.read().clone();
332        let other_data = other.data.read().clone();
333        let result = self_data.matmul(&other_data).unwrap();
334        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();
335
336        if requires_grad {
337            let grad_fn = GradFn::new(MatMulBackward::new(
338                self.grad_fn.clone(),
339                other.grad_fn.clone(),
340                self_data,
341                other_data,
342            ));
343            Variable::from_operation(result, grad_fn, true)
344        } else {
345            Variable::from_tensor(result)
346        }
347    }
348
349    /// Power operation.
350    #[must_use]
351    pub fn pow(&self, exponent: f32) -> Variable {
352        let self_data = self.data.read().clone();
353        let result = self_data.pow(exponent);
354        let requires_grad = self.requires_grad && is_grad_enabled();
355
356        if requires_grad {
357            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
358            Variable::from_operation(result, grad_fn, true)
359        } else {
360            Variable::from_tensor(result)
361        }
362    }
363
364    // =========================================================================
365    // Activation Functions
366    // =========================================================================
367
368    /// `ReLU` activation.
369    #[must_use]
370    pub fn relu(&self) -> Variable {
371        let self_data = self.data.read().clone();
372        let result = self_data.relu();
373        let requires_grad = self.requires_grad && is_grad_enabled();
374
375        if requires_grad {
376            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
377            Variable::from_operation(result, grad_fn, true)
378        } else {
379            Variable::from_tensor(result)
380        }
381    }
382
383    /// Sigmoid activation.
384    #[must_use]
385    pub fn sigmoid(&self) -> Variable {
386        let result = self.data.read().sigmoid();
387        let requires_grad = self.requires_grad && is_grad_enabled();
388
389        if requires_grad {
390            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
391            Variable::from_operation(result, grad_fn, true)
392        } else {
393            Variable::from_tensor(result)
394        }
395    }
396
397    /// Tanh activation.
398    #[must_use]
399    pub fn tanh(&self) -> Variable {
400        let result = self.data.read().tanh();
401        let requires_grad = self.requires_grad && is_grad_enabled();
402
403        if requires_grad {
404            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
405            Variable::from_operation(result, grad_fn, true)
406        } else {
407            Variable::from_tensor(result)
408        }
409    }
410
411    // =========================================================================
412    // Reduction Operations
413    // =========================================================================
414
415    /// Sum all elements.
416    #[must_use]
417    pub fn sum(&self) -> Variable {
418        let self_data = self.data.read().clone();
419        let result = self_data.sum(); // Returns a scalar Tensor
420        let requires_grad = self.requires_grad && is_grad_enabled();
421
422        if requires_grad {
423            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
424            Variable::from_operation(result, grad_fn, true)
425        } else {
426            Variable::from_tensor(result)
427        }
428    }
429
430    /// Sum along a dimension, removing that dimension.
431    #[must_use]
432    pub fn sum_dim(&self, dim: usize) -> Variable {
433        let self_data = self.data.read().clone();
434        let result = self_data.sum_dim(dim as i32, false);
435        let requires_grad = self.requires_grad && is_grad_enabled();
436
437        if requires_grad {
438            let grad_fn = GradFn::new(SumDimBackward::new(
439                self.grad_fn.clone(),
440                self.shape(),
441                dim,
442            ));
443            Variable::from_operation(result, grad_fn, true)
444        } else {
445            Variable::from_tensor(result)
446        }
447    }
448
449    /// Mean of all elements.
450    #[must_use]
451    pub fn mean(&self) -> Variable {
452        let self_data = self.data.read().clone();
453        let result = self_data.mean().unwrap(); // Returns a scalar Tensor
454        let requires_grad = self.requires_grad && is_grad_enabled();
455
456        if requires_grad {
457            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
458            Variable::from_operation(result, grad_fn, true)
459        } else {
460            Variable::from_tensor(result)
461        }
462    }
463
464    // =========================================================================
465    // Loss Functions
466    // =========================================================================
467
468    /// Mean Squared Error loss.
469    #[must_use]
470    pub fn mse_loss(&self, target: &Variable) -> Variable {
471        let diff = self.sub_var(target);
472        let squared = diff.pow(2.0);
473        squared.mean()
474    }
475
    /// Binary Cross Entropy loss (expects sigmoid output).
    ///
    /// Computes `mean(-[y * log(p + eps) + (1 - y) * log(1 - p + eps)])`
    /// with `eps = 1e-7` for numerical stability.
    ///
    /// NOTE(review): the `ln()` results are wrapped in `Variable::from_tensor`,
    /// which detaches them from the computation graph — gradients flow through
    /// the `mul`/`sub`/`neg`/`mean` parts but NOT through the logarithm itself.
    /// Confirm this approximation is intentional (e.g. pending an `LnBackward`).
    #[must_use]
    pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
        // Small constant to keep log() away from zero.
        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
        let one = Variable::from_tensor(Tensor::scalar(1.0));

        // -[y * log(p + eps) + (1 - y) * log(1 - p + eps)]
        let log_p = self.add_var(&eps);
        let log_1_p = one.sub_var(self).add_var(&eps);

        let term1 = target.mul_var(&Variable::from_tensor(log_p.data().ln()));
        let term2 = one
            .sub_var(target)
            .mul_var(&Variable::from_tensor(log_1_p.data().ln()));

        term1.add_var(&term2).neg_var().mean()
    }
493
494    // =========================================================================
495    // Shape Operations
496    // =========================================================================
497
498    /// Reshapes the variable to a new shape.
499    #[must_use]
500    pub fn reshape(&self, shape: &[usize]) -> Variable {
501        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
502        let original_shape = self.shape();
503        let new_data = self
504            .data()
505            .reshape(&isize_shape)
506            .unwrap_or_else(|_| self.data().clone());
507        let requires_grad = self.requires_grad && is_grad_enabled();
508
509        if requires_grad {
510            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
511            Variable::from_operation(new_data, grad_fn, true)
512        } else {
513            Variable::from_tensor(new_data)
514        }
515    }
516
517    /// Transposes two dimensions.
518    #[must_use]
519    pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
520        let new_data = self
521            .data()
522            .transpose(dim0 as i64, dim1 as i64)
523            .unwrap_or_else(|_| self.data().clone());
524        let requires_grad = self.requires_grad && is_grad_enabled();
525
526        if requires_grad {
527            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
528            Variable::from_operation(new_data, grad_fn, true)
529        } else {
530            Variable::from_tensor(new_data)
531        }
532    }
533
    /// Slices the variable along specified ranges.
    ///
    /// NOTE(review): unlike `narrow`/`select`, this wraps the result in a
    /// fresh leaf via `Variable::new`, so it carries `requires_grad` but has
    /// no backward function — gradients will NOT flow back through a slice.
    /// Confirm whether a `SliceBackward` is intended here.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }
540
541    /// Narrows the variable along a dimension.
542    ///
543    /// Returns a view of the tensor containing elements from `start` to `start + length`
544    /// along the specified dimension. This operation preserves gradients for backpropagation.
545    #[must_use]
546    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Variable {
547        let input_shape = self.shape();
548        let new_data = self
549            .data()
550            .narrow(dim, start, length)
551            .unwrap_or_else(|_| self.data().clone());
552        let requires_grad = self.requires_grad && is_grad_enabled();
553
554        if requires_grad {
555            let grad_fn = GradFn::new(NarrowBackward::new(
556                self.grad_fn.clone(),
557                input_shape,
558                dim,
559                start,
560            ));
561            Variable::from_operation(new_data, grad_fn, true)
562        } else {
563            Variable::from_tensor(new_data)
564        }
565    }
566
567    /// Expands the variable to a new shape (broadcast).
568    ///
569    /// Tracks the computational graph for backward pass.
570    #[must_use]
571    pub fn expand(&self, shape: &[usize]) -> Variable {
572        let input_shape = self.shape();
573        let new_data = self.data().broadcast_to(shape);
574        let requires_grad = self.requires_grad && is_grad_enabled();
575
576        if requires_grad {
577            let grad_fn = GradFn::new(ExpandBackward::new(self.grad_fn.clone(), input_shape));
578            Variable::from_operation(new_data, grad_fn, true)
579        } else {
580            Variable::from_tensor(new_data)
581        }
582    }
583
584    /// Selects a single index along a dimension, reducing rank by 1.
585    ///
586    /// For a tensor of shape (A, B, C), `select(1, i)` returns shape (A, C).
587    /// Tracks the computational graph for backward pass.
588    #[must_use]
589    pub fn select(&self, dim: usize, index: usize) -> Variable {
590        let input_shape = self.shape();
591        let new_data = self.data().select(dim, index)
592            .unwrap_or_else(|_| self.data().clone());
593        let requires_grad = self.requires_grad && is_grad_enabled();
594
595        if requires_grad {
596            let grad_fn = GradFn::new(SelectBackward::new(
597                self.grad_fn.clone(),
598                input_shape,
599                dim,
600                index,
601            ));
602            Variable::from_operation(new_data, grad_fn, true)
603        } else {
604            Variable::from_tensor(new_data)
605        }
606    }
607
608    /// Adds a dimension of size 1 at the given position.
609    ///
610    /// Tracks the computational graph for backward pass.
611    #[must_use]
612    pub fn unsqueeze(&self, dim: usize) -> Variable {
613        let new_data = self.data().unsqueeze(dim as i64)
614            .unwrap_or_else(|_| self.data().clone());
615        let requires_grad = self.requires_grad && is_grad_enabled();
616
617        if requires_grad {
618            let grad_fn = GradFn::new(UnsqueezeBackward::new(self.grad_fn.clone(), dim));
619            Variable::from_operation(new_data, grad_fn, true)
620        } else {
621            Variable::from_tensor(new_data)
622        }
623    }
624
625    /// Concatenates variables along a dimension.
626    ///
627    /// All variables must have the same shape except along the cat dimension.
628    /// Tracks the computational graph for backpropagation.
629    #[must_use]
630    pub fn cat(variables: &[&Variable], dim: usize) -> Variable {
631        let tensors: Vec<Tensor<f32>> = variables.iter().map(|v| v.data()).collect();
632        let tensor_refs: Vec<&Tensor<f32>> = tensors.iter().collect();
633        let result = Tensor::cat(&tensor_refs, dim).unwrap();
634
635        let requires_grad = variables.iter().any(|v| v.requires_grad) && is_grad_enabled();
636
637        if requires_grad {
638            let next_fns: Vec<Option<GradFn>> =
639                variables.iter().map(|v| v.grad_fn.clone()).collect();
640            let sizes: Vec<usize> = variables.iter().map(|v| v.shape()[dim]).collect();
641            let grad_fn = GradFn::new(CatBackward::new(next_fns, sizes, dim));
642            Variable::from_operation(result, grad_fn, true)
643        } else {
644            Variable::from_tensor(result)
645        }
646    }
647
648    // =========================================================================
649    // Scalar Operations
650    // =========================================================================
651
652    /// Multiplies by a scalar.
653    #[must_use]
654    pub fn mul_scalar(&self, scalar: f32) -> Variable {
655        let data = self.data();
656        let shape = data.shape();
657        let numel: usize = shape.iter().product();
658        let scalar_tensor = Tensor::from_vec(vec![scalar; numel], shape).unwrap();
659        let scalar_var = Variable::new(scalar_tensor, false);
660        self.mul_var(&scalar_var)
661    }
662
663    /// Adds a scalar.
664    #[must_use]
665    pub fn add_scalar(&self, scalar: f32) -> Variable {
666        let data = self.data();
667        let shape = data.shape();
668        let numel: usize = shape.iter().product();
669        let scalar_tensor = Tensor::from_vec(vec![scalar; numel], shape).unwrap();
670        let scalar_var = Variable::new(scalar_tensor, false);
671        self.add_var(&scalar_var)
672    }
673
674    /// Subtracts a scalar.
675    #[must_use]
676    pub fn sub_scalar(&self, scalar: f32) -> Variable {
677        self.add_scalar(-scalar)
678    }
679
680    /// Divides by a scalar.
681    #[must_use]
682    pub fn div_scalar(&self, scalar: f32) -> Variable {
683        self.mul_scalar(1.0 / scalar)
684    }
685
686    // =========================================================================
687    // Additional Activations
688    // =========================================================================
689
    /// GELU activation function (Gaussian Error Linear Unit).
    ///
    /// NOTE(review): unlike `relu`/`sigmoid`/`tanh`, this wraps the result in
    /// a fresh leaf via `Variable::new`, so no backward function is attached —
    /// gradients will NOT flow through this op even when `requires_grad` is
    /// set. Confirm whether a `GeluBackward` is planned.
    #[must_use]
    pub fn gelu(&self) -> Variable {
        // Approximate GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        let data = self.data();
        let result = data.gelu();
        Variable::new(result, self.requires_grad())
    }

    /// SiLU/Swish activation function (x * sigmoid(x)).
    ///
    /// NOTE(review): same gradient-flow caveat as `gelu` — no backward fn.
    #[must_use]
    pub fn silu(&self) -> Variable {
        let data = self.data();
        let result = data.silu();
        Variable::new(result, self.requires_grad())
    }

    /// Element-wise square root.
    ///
    /// NOTE(review): same gradient-flow caveat as `gelu` — no backward fn.
    #[must_use]
    pub fn sqrt(&self) -> Variable {
        let data = self.data();
        let result = data.sqrt();
        Variable::new(result, self.requires_grad())
    }
714
715    // =========================================================================
716    // Softmax Operations
717    // =========================================================================
718
    /// Softmax along the specified dimension.
    ///
    /// NOTE(review): this returns a fresh leaf via `Variable::new`, with no
    /// backward function attached — gradients will NOT flow through this op.
    /// Confirm whether a `SoftmaxBackward` is planned.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.softmax(dim);
        Variable::new(result, self.requires_grad())
    }

    /// Log softmax along the specified dimension.
    ///
    /// NOTE(review): same gradient-flow caveat as `softmax` — no backward fn.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.log_softmax(dim);
        Variable::new(result, self.requires_grad())
    }
734
735    // =========================================================================
736    // Reduction Operations with Dimensions
737    // =========================================================================
738
    /// Mean along a dimension, optionally keeping the dimension.
    ///
    /// NOTE(review): this returns a fresh leaf via `Variable::new`, with no
    /// backward function attached — gradients will NOT flow through this op
    /// (contrast with `mean()`, which attaches `MeanBackward`).
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let result = data.mean_dim(dim, keepdim);
        Variable::new(result, self.requires_grad())
    }

    /// Variance along a dimension, optionally keeping the dimension.
    ///
    /// NOTE(review): same gradient-flow caveat as `mean_dim` — no backward fn.
    #[must_use]
    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let result = data.var_dim(dim, keepdim);
        Variable::new(result, self.requires_grad())
    }
754
755    // =========================================================================
756    // Utility Methods
757    // =========================================================================
758
    /// Creates a Variable from a tensor and a `requires_grad` flag.
    /// This is typically used internally by Parameter types; it is an alias
    /// for [`Variable::new`].
    #[must_use]
    pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
        Variable::new(data, requires_grad)
    }

    /// Clones the variable (alias for the `Clone` trait; shallow, shares data).
    #[must_use]
    pub fn clone_var(&self) -> Variable {
        self.clone()
    }

    /// Adds another variable (alias for `add_var` for method chaining).
    #[must_use]
    pub fn add(&self, other: &Variable) -> Variable {
        self.add_var(other)
    }

    /// Subtracts another variable (alias for `sub_var` for method chaining).
    #[must_use]
    pub fn sub(&self, other: &Variable) -> Variable {
        self.sub_var(other)
    }

    /// Multiplies by another variable (alias for `mul_var` for method chaining).
    #[must_use]
    pub fn mul(&self, other: &Variable) -> Variable {
        self.mul_var(other)
    }

    /// Divides by another variable (alias for `div_var` for method chaining).
    #[must_use]
    pub fn div(&self, other: &Variable) -> Variable {
        self.div_var(other)
    }
795}
796
797// =============================================================================
798// Operator Overloads
799// =============================================================================
800
/// `&a + &b` — element-wise addition via [`Variable::add_var`].
impl Add for &Variable {
    type Output = Variable;

    fn add(self, other: &Variable) -> Variable {
        self.add_var(other)
    }
}

/// `&a - &b` — element-wise subtraction via [`Variable::sub_var`].
impl Sub for &Variable {
    type Output = Variable;

    fn sub(self, other: &Variable) -> Variable {
        self.sub_var(other)
    }
}

/// `&a * &b` — element-wise multiplication via [`Variable::mul_var`].
impl Mul for &Variable {
    type Output = Variable;

    fn mul(self, other: &Variable) -> Variable {
        self.mul_var(other)
    }
}

/// `&a / &b` — element-wise division via [`Variable::div_var`].
impl Div for &Variable {
    type Output = Variable;

    fn div(self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

/// `-&a` — negation via [`Variable::neg_var`].
impl Neg for &Variable {
    type Output = Variable;

    fn neg(self) -> Variable {
        self.neg_var()
    }
}

/// Debug output shows shape and autograd metadata, not the tensor contents.
impl std::fmt::Debug for Variable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Variable")
            .field("shape", &self.shape())
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field(
                "grad_fn",
                // Report the grad_fn by name only; the closure itself is opaque.
                &self.grad_fn.as_ref().map(super::grad_fn::GradFn::name),
            )
            .finish()
    }
}
854
855// =============================================================================
856// Tests
857// =============================================================================
858
859#[cfg(test)]
860mod tests {
861    use super::*;
862    use axonml_tensor::zeros;
863
    // A leaf created with requires_grad=true reports grad tracking and shape.
    #[test]
    fn test_variable_creation() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::new(t, true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }

    // from_tensor always produces a non-tracking variable.
    #[test]
    fn test_variable_no_grad() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::from_tensor(t);
        assert!(!v.requires_grad());
    }

    // Operator overload `+` adds element-wise and produces a non-leaf
    // that inherits gradient tracking from its operands.
    #[test]
    fn test_variable_add() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(), true);
        let c = &a + &b;
        assert_eq!(c.data().to_vec(), vec![5.0, 7.0, 9.0]);
        assert!(c.requires_grad());
        assert!(!c.is_leaf());
    }

    // detach() drops gradient tracking and yields a fresh leaf.
    #[test]
    fn test_variable_detach() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = a.detach();
        assert!(!b.requires_grad());
        assert!(b.is_leaf());
    }
897
898    #[test]
899    fn test_mse_loss() {
900        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
901        let target = Variable::from_tensor(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
902        let loss = pred.mse_loss(&target);
903        assert_eq!(loss.numel(), 1);
904        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
905    }
906}