axonml_autograd/variable.rs

//! Variable - Tensor with Gradient Tracking
//!
//! The Variable struct wraps a Tensor and adds automatic differentiation
//! capabilities. Variables track their computational history to enable
//! gradient computation via backpropagation.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
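//!
//! # Examples
//!
//! A minimal end-to-end sketch (not compiled as a doc-test); it assumes the
//! crate root re-exports `Variable` and uses `Tensor::from_vec` as in the
//! tests at the bottom of this file:
//!
//! ```ignore
//! use axonml_autograd::Variable;
//! use axonml_tensor::Tensor;
//!
//! // Leaf variables with `requires_grad = true` are tracked in the graph.
//! let x = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
//! let y = Variable::new(Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(), true);
//!
//! // Operations record their backward functions in the computational graph.
//! let loss = (&x * &y).sum();
//!
//! // Backpropagation populates the gradients of the leaf variables.
//! loss.backward();
//! assert!(x.grad().is_some());
//! ```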

use std::ops::{Add, Div, Mul, Neg, Sub};
use std::sync::Arc;

use parking_lot::RwLock;

use axonml_tensor::Tensor;

use crate::functions::{
    AddBackward, DivBackward, MatMulBackward, MeanBackward, MulBackward, NegBackward, PowBackward,
    ReluBackward, ReshapeBackward, SigmoidBackward, SubBackward, SumBackward, TanhBackward,
    TransposeBackward,
};
use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
use crate::graph::{with_graph, GraphNode};
use crate::no_grad::is_grad_enabled;

// =============================================================================
// Variable Struct
// =============================================================================

/// A tensor with automatic differentiation support.
///
/// Variable wraps a Tensor and tracks operations performed on it to enable
/// automatic gradient computation. When `requires_grad` is true, all operations
/// are recorded in a computational graph.
#[derive(Clone)]
pub struct Variable {
    /// The underlying tensor data.
    data: Arc<RwLock<Tensor<f32>>>,
    /// Shared gradient accumulator (for leaf variables, shared with `AccumulateGrad`).
    grad: GradAccumulator,
    /// Whether this variable requires gradient computation.
    requires_grad: bool,
    /// Whether this is a leaf variable (created by user, not an operation).
    is_leaf: bool,
    /// The gradient function for backpropagation.
    grad_fn: Option<GradFn>,
    /// Graph node for this variable.
    node: Option<Arc<GraphNode>>,
}

impl Variable {
    /// Creates a new variable from a tensor.
    ///
    /// # Arguments
    /// * `data` - The tensor data
    /// * `requires_grad` - Whether to track gradients for this variable
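    ///
    /// # Examples
    ///
    /// A small sketch (not compiled as a doc-test), using `Tensor::from_vec`
    /// as in the tests at the bottom of this file:
    ///
    /// ```ignore
    /// let t = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
    /// let v = Variable::new(t, true);
    /// assert!(v.requires_grad());
    /// assert!(v.is_leaf());
    /// ```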
    #[must_use] pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
        // Create shared gradient accumulator
        let grad: GradAccumulator = Arc::new(RwLock::new(None));

        let node = if requires_grad {
            Some(with_graph(|g| g.register_leaf(true)))
        } else {
            None
        };

        // Create AccumulateGrad with shared gradient storage
        let grad_fn = if requires_grad {
            Some(GradFn::new(AccumulateGrad::new(Arc::clone(&grad))))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad,
            requires_grad,
            is_leaf: true,
            grad_fn,
            node,
        }
    }

    /// Creates a variable that doesn't require gradients.
    #[must_use] pub fn from_tensor(data: Tensor<f32>) -> Self {
        Self::new(data, false)
    }

    /// Creates a new variable from an operation result.
    fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
        let node = if requires_grad {
            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad: Arc::new(RwLock::new(None)),
            requires_grad,
            is_leaf: false,
            grad_fn: if requires_grad { Some(grad_fn) } else { None },
            node,
        }
    }

    /// Returns a clone of the underlying tensor data.
    #[must_use] pub fn data(&self) -> Tensor<f32> {
        self.data.read().clone()
    }

    /// Returns the shape of the tensor.
    #[must_use] pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape().to_vec()
    }

    /// Returns the number of dimensions.
    #[must_use] pub fn ndim(&self) -> usize {
        self.data.read().ndim()
    }

    /// Returns the total number of elements.
    #[must_use] pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    /// Returns whether this variable requires gradients.
    #[must_use] pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Returns whether this is a leaf variable.
    #[must_use] pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Returns the gradient of this variable.
    ///
    /// Only available for leaf variables after `backward()` has been called.
    #[must_use] pub fn grad(&self) -> Option<Tensor<f32>> {
        self.grad.read().clone()
    }

    /// Returns the gradient function.
    #[must_use] pub fn grad_fn(&self) -> Option<&GradFn> {
        self.grad_fn.as_ref()
    }

    /// Sets the gradient (used during backward pass).
    pub fn set_grad(&self, grad: Tensor<f32>) {
        *self.grad.write() = Some(grad);
    }

    /// Accumulates gradient (adds to existing gradient).
    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
        let mut grad_lock = self.grad.write();
        if let Some(ref existing) = *grad_lock {
            *grad_lock = Some(existing.add(grad).unwrap());
        } else {
            *grad_lock = Some(grad.clone());
        }
    }

    /// Clears the gradient.
    pub fn zero_grad(&self) {
        *self.grad.write() = None;
    }

    /// Detaches this variable from the computation graph.
    ///
    /// Returns a new variable with the same data but no gradient history.
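    ///
    /// # Examples
    ///
    /// A small sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
    /// let b = a.detach();
    /// assert!(!b.requires_grad());
    /// assert_eq!(b.data().to_vec(), a.data().to_vec());
    /// ```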
    #[must_use] pub fn detach(&self) -> Self {
        Self {
            data: Arc::new(RwLock::new(self.data.read().clone())),
            grad: Arc::new(RwLock::new(None)),
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            node: None,
        }
    }

    /// Returns a new variable with `requires_grad` set.
    #[must_use] pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
        self.requires_grad = requires_grad;
        if requires_grad && self.is_leaf {
            // AccumulateGrad shares the gradient accumulator with this variable
            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
            self.node = Some(with_graph(|g| g.register_leaf(true)));
        }
        self
    }

    /// Computes gradients via backpropagation.
    ///
    /// This should only be called on scalar (single-element) tensors,
    /// typically the loss value.
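    ///
    /// # Panics
    ///
    /// Panics if this variable does not require gradients or if it is not a
    /// single-element tensor.
    ///
    /// # Examples
    ///
    /// A small sketch (not compiled as a doc-test); for `y = sum(x)`, each
    /// element of `x` is expected to receive a gradient:
    ///
    /// ```ignore
    /// let x = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
    /// let y = x.sum();
    /// y.backward();
    /// assert!(x.grad().is_some());
    /// ```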
    pub fn backward(&self) {
        assert!(
            self.requires_grad,
            "Cannot call backward on a variable that doesn't require gradients"
        );

        assert!(self.numel() == 1, "backward() can only be called on scalar tensors");

        // Start with gradient of 1.0 for the output
        let grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
        crate::backward::backward(self, &grad_output);
    }

    // =========================================================================
    // Arithmetic Operations
    // =========================================================================

    /// Element-wise addition.
    #[must_use] pub fn add_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().add(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(AddBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise subtraction.
    #[must_use] pub fn sub_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().sub(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SubBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise multiplication.
    #[must_use] pub fn mul_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.mul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise division.
    #[must_use] pub fn div_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.div(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DivBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Negation.
    #[must_use] pub fn neg_var(&self) -> Variable {
        let result = self.data.read().neg();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Matrix multiplication.
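    ///
    /// # Examples
    ///
    /// A shape-level sketch (not compiled as a doc-test): a `[2, 3]` operand
    /// times a `[3, 2]` operand yields a `[2, 2]` result.
    ///
    /// ```ignore
    /// let a = Variable::new(Tensor::from_vec(vec![1.0; 6], &[2, 3]).unwrap(), true);
    /// let b = Variable::new(Tensor::from_vec(vec![1.0; 6], &[3, 2]).unwrap(), true);
    /// let c = a.matmul(&b);
    /// assert_eq!(c.shape(), vec![2, 2]);
    /// ```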
    #[must_use] pub fn matmul(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.matmul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MatMulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Power operation.
    #[must_use] pub fn pow(&self, exponent: f32) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.pow(exponent);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    // =========================================================================
    // Activation Functions
    // =========================================================================

    /// `ReLU` activation.
    #[must_use] pub fn relu(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.relu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Sigmoid activation.
    #[must_use] pub fn sigmoid(&self) -> Variable {
        let result = self.data.read().sigmoid();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Tanh activation.
    #[must_use] pub fn tanh(&self) -> Variable {
        let result = self.data.read().tanh();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    // =========================================================================
    // Reduction Operations
    // =========================================================================

    /// Sum all elements.
    #[must_use] pub fn sum(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum(); // Returns a scalar Tensor
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Mean of all elements.
    #[must_use] pub fn mean(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.mean().unwrap(); // Returns a scalar Tensor
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    // =========================================================================
    // Loss Functions
    // =========================================================================

    /// Mean Squared Error loss.
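    ///
    /// # Examples
    ///
    /// A small sketch (not compiled as a doc-test); identical prediction and
    /// target give a scalar loss of zero:
    ///
    /// ```ignore
    /// let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
    /// let target = Variable::from_tensor(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap());
    /// let loss = pred.mse_loss(&target);
    /// assert_eq!(loss.numel(), 1);
    /// ```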
    #[must_use] pub fn mse_loss(&self, target: &Variable) -> Variable {
        let diff = self.sub_var(target);
        let squared = diff.pow(2.0);
        squared.mean()
    }

    /// Binary Cross Entropy loss (expects sigmoid output).
    #[must_use] pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
        let one = Variable::from_tensor(Tensor::scalar(1.0));

        // -[y * log(p + eps) + (1 - y) * log(1 - p + eps)]
        let log_p = self.add_var(&eps);
        let log_1_p = one.sub_var(self).add_var(&eps);

        let term1 = target.mul_var(&Variable::from_tensor(log_p.data().ln()));
        let term2 = one
            .sub_var(target)
            .mul_var(&Variable::from_tensor(log_1_p.data().ln()));

        term1.add_var(&term2).neg_var().mean()
    }

    // =========================================================================
    // Shape Operations
    // =========================================================================

    /// Reshapes the variable to a new shape.
    #[must_use] pub fn reshape(&self, shape: &[usize]) -> Variable {
        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
        let original_shape = self.shape();
        let new_data = self.data().reshape(&isize_shape).unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Transposes two dimensions.
    #[must_use] pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
        let new_data = self
            .data()
            .transpose(dim0 as i64, dim1 as i64)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Slices the variable along specified ranges.
    #[must_use] pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }

    /// Expands the variable to a new shape (broadcast).
    #[must_use] pub fn expand(&self, shape: &[usize]) -> Variable {
        let new_data = self.data().broadcast_to(shape);
        Variable::new(new_data, self.requires_grad())
    }

    // =========================================================================
    // Scalar Operations
    // =========================================================================

    /// Multiplies by a scalar.
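    ///
    /// # Examples
    ///
    /// A small sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// let x = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
    /// let y = x.mul_scalar(3.0);
    /// assert_eq!(y.data().to_vec(), vec![3.0, 6.0]);
    /// ```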
    #[must_use] pub fn mul_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let shape = data.shape();
        let numel: usize = shape.iter().product();
        let scalar_tensor = Tensor::from_vec(vec![scalar; numel], shape).unwrap();
        let scalar_var = Variable::new(scalar_tensor, false);
        self.mul_var(&scalar_var)
    }

    /// Adds a scalar.
    #[must_use] pub fn add_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let shape = data.shape();
        let numel: usize = shape.iter().product();
        let scalar_tensor = Tensor::from_vec(vec![scalar; numel], shape).unwrap();
        let scalar_var = Variable::new(scalar_tensor, false);
        self.add_var(&scalar_var)
    }

    /// Subtracts a scalar.
    #[must_use] pub fn sub_scalar(&self, scalar: f32) -> Variable {
        self.add_scalar(-scalar)
    }

    /// Divides by a scalar.
    #[must_use] pub fn div_scalar(&self, scalar: f32) -> Variable {
        self.mul_scalar(1.0 / scalar)
    }

    // =========================================================================
    // Additional Activations
    // =========================================================================

    /// GELU activation function (Gaussian Error Linear Unit).
    #[must_use] pub fn gelu(&self) -> Variable {
        // Approximate GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        let data = self.data();
        let result = data.gelu();
        Variable::new(result, self.requires_grad())
    }

    /// SiLU/Swish activation function (x * sigmoid(x)).
    #[must_use] pub fn silu(&self) -> Variable {
        let data = self.data();
        let result = data.silu();
        Variable::new(result, self.requires_grad())
    }

    /// Square root.
    #[must_use] pub fn sqrt(&self) -> Variable {
        let data = self.data();
        let result = data.sqrt();
        Variable::new(result, self.requires_grad())
    }

    // =========================================================================
    // Softmax Operations
    // =========================================================================

    /// Softmax along specified dimension.
    #[must_use] pub fn softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.softmax(dim);
        Variable::new(result, self.requires_grad())
    }

    /// Log softmax along specified dimension.
    #[must_use] pub fn log_softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.log_softmax(dim);
        Variable::new(result, self.requires_grad())
    }

    // =========================================================================
    // Reduction Operations with Dimensions
    // =========================================================================

    /// Mean along a dimension, optionally keeping the dimension.
    #[must_use] pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let result = data.mean_dim(dim, keepdim);
        Variable::new(result, self.requires_grad())
    }

    /// Variance along a dimension, optionally keeping the dimension.
    #[must_use] pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let result = data.var_dim(dim, keepdim);
        Variable::new(result, self.requires_grad())
    }

    // =========================================================================
    // Utility Methods
    // =========================================================================

    /// Creates a Variable from a tensor and requires_grad flag (for weight access).
    /// This is typically used internally by Parameter types.
    #[must_use] pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
        Variable::new(data, requires_grad)
    }

    /// Clones the variable (alias for Clone trait).
    #[must_use] pub fn clone_var(&self) -> Variable {
        self.clone()
    }

    /// Adds another variable (alias for add_var for method chaining).
    #[must_use] pub fn add(&self, other: &Variable) -> Variable {
        self.add_var(other)
    }

    /// Subtracts another variable (alias for sub_var for method chaining).
    #[must_use] pub fn sub(&self, other: &Variable) -> Variable {
        self.sub_var(other)
    }

    /// Multiplies by another variable (alias for mul_var for method chaining).
    #[must_use] pub fn mul(&self, other: &Variable) -> Variable {
        self.mul_var(other)
    }

    /// Divides by another variable (alias for div_var for method chaining).
    #[must_use] pub fn div(&self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

// =============================================================================
// Operator Overloads
// =============================================================================

impl Add for &Variable {
    type Output = Variable;

    fn add(self, other: &Variable) -> Variable {
        self.add_var(other)
    }
}

impl Sub for &Variable {
    type Output = Variable;

    fn sub(self, other: &Variable) -> Variable {
        self.sub_var(other)
    }
}

impl Mul for &Variable {
    type Output = Variable;

    fn mul(self, other: &Variable) -> Variable {
        self.mul_var(other)
    }
}

impl Div for &Variable {
    type Output = Variable;

    fn div(self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

impl Neg for &Variable {
    type Output = Variable;

    fn neg(self) -> Variable {
        self.neg_var()
    }
}

impl std::fmt::Debug for Variable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Variable")
            .field("shape", &self.shape())
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field("grad_fn", &self.grad_fn.as_ref().map(super::grad_fn::GradFn::name))
            .finish()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::zeros;

    #[test]
    fn test_variable_creation() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::new(t, true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }

    #[test]
    fn test_variable_no_grad() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::from_tensor(t);
        assert!(!v.requires_grad());
    }

    #[test]
    fn test_variable_add() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(), true);
        let c = &a + &b;
        assert_eq!(c.data().to_vec(), vec![5.0, 7.0, 9.0]);
        assert!(c.requires_grad());
        assert!(!c.is_leaf());
    }

    #[test]
    fn test_variable_detach() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = a.detach();
        assert!(!b.requires_grad());
        assert!(b.is_leaf());
    }

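    // A gradient-propagation sketch: for y = sum(x), backward() is expected to
    // populate a gradient on the leaf x; only presence and shape are checked,
    // since the backward machinery lives outside this module.
    #[test]
    fn test_sum_backward_populates_grad() {
        let x = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let y = x.sum();
        y.backward();
        let grad = x.grad().expect("gradient should be set after backward()");
        assert_eq!(grad.shape().to_vec(), vec![3]);
    }
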
    #[test]
    fn test_mse_loss() {
        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let target = Variable::from_tensor(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let loss = pred.mse_loss(&target);
        assert_eq!(loss.numel(), 1);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }
}
721}