use std::ops::{Add, Div, Mul, Neg, Sub};
use std::sync::Arc;

use parking_lot::RwLock;

use axonml_tensor::Tensor;

use crate::functions::{
    AddBackward, AddScalarBackward, CatBackward, ClampBackward, DivBackward, EluBackward,
    ExpBackward, ExpandBackward, GeluBackward, LeakyReluBackward, LogBackward, LogSoftmaxBackward,
    MatMulBackward, MeanBackward, MeanDimBackward, MulBackward, MulScalarBackward, NarrowBackward,
    NegBackward, PowBackward, ReluBackward, ReshapeBackward, SelectBackward, SigmoidBackward,
    SiluBackward, SoftmaxBackward, SqrtBackward, SubBackward, SumBackward, SumDimBackward,
    TanhBackward, TransposeBackward, UnsqueezeBackward, VarDimBackward,
};
use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
use crate::graph::{GraphNode, with_graph};
use crate::no_grad::is_grad_enabled;

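/// A tensor wrapper that records the operations applied to it so gradients
/// can be computed by reverse-mode automatic differentiation.
///
/// Cloning a `Variable` is cheap: the tensor data and gradient slot live
/// behind `Arc`s and are shared between clones.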
#[derive(Clone)]
pub struct Variable {
    /// The underlying tensor data, shared across clones.
    data: Arc<RwLock<Tensor<f32>>>,
    /// Accumulated gradient, populated by the backward pass.
    grad: GradAccumulator,
    /// Whether this variable participates in gradient computation.
    requires_grad: bool,
    /// Leaf variables are created by the user rather than by an operation.
    is_leaf: bool,
    /// The backward function that produced this variable, if any.
    grad_fn: Option<GradFn>,
    /// Handle into the global autograd graph, if registered.
    node: Option<Arc<GraphNode>>,
}

impl Variable {
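    /// Creates a new `Variable` wrapping `data`. When `requires_grad` is
    /// true, the variable is registered as a leaf in the autograd graph and
    /// given an `AccumulateGrad` node so gradients can be stored on it.
    ///
    /// A minimal usage sketch (marked `ignore` since it assumes the
    /// `Tensor::from_vec` constructor used throughout the tests below):
    ///
    /// ```ignore
    /// use axonml_tensor::Tensor;
    ///
    /// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
    /// let v = Variable::new(t, true);
    /// assert!(v.requires_grad() && v.is_leaf());
    /// ```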
    #[must_use]
    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
        let grad: GradAccumulator = Arc::new(RwLock::new(None));

        let node = if requires_grad {
            Some(with_graph(|g| g.register_leaf(true)))
        } else {
            None
        };

        let grad_fn = if requires_grad {
            Some(GradFn::new(AccumulateGrad::new(Arc::clone(&grad))))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad,
            requires_grad,
            is_leaf: true,
            grad_fn,
            node,
        }
    }

    /// Creates a variable that does not require gradients.
    #[must_use]
    pub fn from_tensor(data: Tensor<f32>) -> Self {
        Self::new(data, false)
    }

    /// Creates a non-leaf variable produced by an operation, wiring
    /// `grad_fn` into the autograd graph when `requires_grad` is true.
    pub fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
        let node = if requires_grad {
            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad: Arc::new(RwLock::new(None)),
            requires_grad,
            is_leaf: false,
            grad_fn: if requires_grad { Some(grad_fn) } else { None },
            node,
        }
    }

    /// Returns a clone of the underlying tensor data.
    #[must_use]
    pub fn data(&self) -> Tensor<f32> {
        self.data.read().clone()
    }

    /// Returns the shape of the underlying tensor.
    #[must_use]
    pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape().to_vec()
    }

    /// Returns the number of dimensions.
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.data.read().ndim()
    }

    /// Returns the total number of elements.
    #[must_use]
    pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    /// Returns the device the underlying tensor lives on.
    #[must_use]
    pub fn device(&self) -> axonml_tensor::Device {
        self.data.read().device()
    }

    /// Moves the variable to `device`. Note that the result is a fresh leaf,
    /// detached from any existing autograd history.
    ///
    /// # Panics
    ///
    /// Panics if the device transfer fails.
    pub fn to_device(&self, device: axonml_tensor::Device) -> Self {
        let current = self.data.read().clone();
        if current.device() == device {
            return self.clone();
        }
        let moved = current
            .to_device(device)
            .expect("Failed to move variable to device");
        Variable::new(moved, self.requires_grad)
    }

    /// Returns true if this variable requires gradients.
    #[must_use]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Returns true if this variable is a leaf (user-created rather than
    /// the result of an operation).
    #[must_use]
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Returns the accumulated gradient, if any.
    #[must_use]
    pub fn grad(&self) -> Option<Tensor<f32>> {
        self.grad.read().clone()
    }

    /// Returns the backward function that produced this variable, if any.
    #[must_use]
    pub fn grad_fn(&self) -> Option<&GradFn> {
        self.grad_fn.as_ref()
    }

    /// Replaces the underlying tensor data in place.
    pub fn set_data(&mut self, new_data: Tensor<f32>) {
        *self.data.write() = new_data;
    }

    /// Overwrites the stored gradient.
    pub fn set_grad(&self, grad: Tensor<f32>) {
        *self.grad.write() = Some(grad);
    }

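    /// Adds `grad` into the stored gradient, allocating it on first use.
    /// Called by the autograd engine during the backward pass, so repeated
    /// backward passes accumulate rather than overwrite.
    ///
    /// A small sketch of the accumulation behaviour (an `ignore`d doctest
    /// mirroring `test_backward_twice_accumulates` below):
    ///
    /// ```ignore
    /// let a = Variable::new(Tensor::from_vec(vec![2.0], &[1]).unwrap(), true);
    /// a.mul_scalar(3.0).backward(); // grad = [3.0]
    /// a.mul_scalar(3.0).backward(); // grad accumulates toward [6.0]
    /// ```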
    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
        let mut grad_lock = self.grad.write();
        if let Some(ref existing) = *grad_lock {
            *grad_lock = Some(existing.add(grad).unwrap());
        } else {
            *grad_lock = Some(grad.clone());
        }
    }

    /// Clears the stored gradient.
    pub fn zero_grad(&self) {
        *self.grad.write() = None;
    }

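    /// Returns a new leaf variable that shares no autograd history with
    /// `self`; gradients will not flow through the result.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let a = Variable::new(Tensor::from_vec(vec![3.0], &[1]).unwrap(), true);
    /// let b = a.detach();
    /// assert!(!b.requires_grad() && b.is_leaf());
    /// ```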
    #[must_use]
    pub fn detach(&self) -> Self {
        Self {
            data: Arc::new(RwLock::new(self.data.read().clone())),
            grad: Arc::new(RwLock::new(None)),
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            node: None,
        }
    }

    /// Sets `requires_grad` in builder style. Enabling gradients on a leaf
    /// attaches an `AccumulateGrad` node and registers it in the graph.
    #[must_use]
    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
        self.requires_grad = requires_grad;
        if requires_grad && self.is_leaf {
            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
            self.node = Some(with_graph(|g| g.register_leaf(true)));
        }
        self
    }

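    /// Runs the backward pass from this variable, seeding the gradient with
    /// `1.0`. Only valid on scalar outputs such as a loss.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let x = Variable::new(Tensor::from_vec(vec![2.0], &[1]).unwrap(), true);
    /// let y = x.mul_scalar(3.0); // y = 3x
    /// y.backward();
    /// assert_eq!(x.grad().unwrap().to_vec(), vec![3.0]); // dy/dx = 3
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `self` does not require gradients or is not a scalar.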
    pub fn backward(&self) {
        assert!(
            self.requires_grad,
            "Cannot call backward on a variable that doesn't require gradients"
        );

        assert!(
            self.numel() == 1,
            "backward() can only be called on scalar tensors"
        );

        // Seed gradient: d(out)/d(out) = 1.
        let mut grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
        let device = self.data.read().device();
        if device.is_gpu() {
            grad_output = grad_output
                .to_device(device)
                .expect("device transfer failed");
        }
        crate::backward::backward(self, &grad_output);
    }

    /// Runs the backward pass with an explicit upstream gradient, allowing
    /// non-scalar outputs. No-op if `self` does not require gradients.
    pub fn backward_with_grad(&self, grad_output: &Tensor<f32>) {
        if !self.requires_grad {
            return;
        }
        let device = self.data.read().device();
        let grad = if grad_output.device() != device && device.is_gpu() {
            grad_output
                .to_device(device)
                .expect("device transfer failed")
        } else {
            grad_output.clone()
        };
        crate::backward::backward(self, &grad);
    }

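    /// Element-wise addition, recording `AddBackward` when gradients are
    /// enabled so broadcasting can be undone in the backward pass.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
    /// let b = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
    /// let c = a.add_var(&b); // [4.0, 6.0]
    /// c.sum().backward();
    /// assert_eq!(a.grad().unwrap().to_vec(), vec![1.0, 1.0]);
    /// ```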
    #[must_use]
    pub fn add_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().add(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(AddBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise subtraction.
    #[must_use]
    pub fn sub_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().sub(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SubBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise multiplication. Both operands are saved for the
    /// backward pass since d(a*b)/da = b and d(a*b)/db = a.
    #[must_use]
    pub fn mul_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.mul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise division.
    #[must_use]
    pub fn div_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.div(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DivBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise negation.
    #[must_use]
    pub fn neg_var(&self) -> Variable {
        let result = self.data.read().neg();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

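    /// Matrix multiplication. Under autocast (mixed precision), the forward
    /// product runs in reduced precision while full-precision inputs are
    /// saved for the backward pass.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let a = Variable::new(Tensor::from_vec(vec![1.0; 6], &[2, 3]).unwrap(), true);
    /// let b = Variable::new(Tensor::from_vec(vec![1.0; 6], &[3, 2]).unwrap(), true);
    /// let c = a.matmul(&b); // shape [2, 2]
    /// c.sum().backward();
    /// assert_eq!(a.grad().unwrap().shape(), &[2, 3]);
    /// ```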
    #[must_use]
    pub fn matmul(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();

        // Under autocast, compute the forward product in reduced precision
        // but keep full-precision copies for the backward pass.
        let (compute_a, compute_b) = if crate::amp::is_autocast_enabled() {
            (self_data.to_f16_precision(), other_data.to_f16_precision())
        } else {
            (self_data.clone(), other_data.clone())
        };

        let result = compute_a.matmul(&compute_b).expect("matmul failed");
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MatMulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Raises every element to the power `exponent`.
    #[must_use]
    pub fn pow(&self, exponent: f32) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.pow(exponent);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Rectified linear unit: `max(x, 0)`.
    #[must_use]
    pub fn relu(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.relu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Leaky ReLU: `x` for positive inputs, `negative_slope * x` otherwise.
    /// Computed element-wise on the host, then moved back to the original
    /// device if needed.
    #[must_use]
    pub fn leaky_relu(&self, negative_slope: f32) -> Variable {
        let self_data = self.data.read().clone();
        let device = self_data.device();
        let result_vec: Vec<f32> = self_data
            .to_vec()
            .iter()
            .map(|&x| if x > 0.0 { x } else { x * negative_slope })
            .collect();
        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
        if device.is_gpu() {
            result = result.to_device(device).expect("device transfer failed");
        }
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LeakyReluBackward::new(
                self.grad_fn.clone(),
                self_data,
                negative_slope,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// ELU: `x` for positive inputs, `alpha * (e^x - 1)` otherwise.
    /// Computed element-wise on the host, then moved back if needed.
    #[must_use]
    pub fn elu(&self, alpha: f32) -> Variable {
        let self_data = self.data.read().clone();
        let device = self_data.device();
        let result_vec: Vec<f32> = self_data
            .to_vec()
            .iter()
            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
            .collect();
        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
        if device.is_gpu() {
            result = result.to_device(device).expect("device transfer failed");
        }
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(EluBackward::new(self.grad_fn.clone(), self_data, alpha));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Logistic sigmoid; the output is saved for the backward pass since
    /// d(sigmoid)/dx = sigmoid * (1 - sigmoid).
    #[must_use]
    pub fn sigmoid(&self) -> Variable {
        let result = self.data.read().sigmoid();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Hyperbolic tangent; the output is saved since d(tanh)/dx = 1 - tanh^2.
    #[must_use]
    pub fn tanh(&self) -> Variable {
        let result = self.data.read().tanh();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise exponential; the output is saved since d(e^x)/dx = e^x.
    #[must_use]
    pub fn exp(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.exp();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ExpBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise natural logarithm.
    #[must_use]
    pub fn log(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.ln();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LogBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Clamps every element into `[min_val, max_val]`. Gradients are zero
    /// where the input was clamped. Computed element-wise on the host.
    #[must_use]
    pub fn clamp(&self, min_val: f32, max_val: f32) -> Variable {
        let self_data = self.data.read().clone();
        let device = self_data.device();
        let result_data: Vec<f32> = self_data
            .to_vec()
            .iter()
            .map(|&x| x.clamp(min_val, max_val))
            .collect();
        let mut result = Tensor::from_vec(result_data, self_data.shape()).unwrap();
        if device.is_gpu() {
            result = result.to_device(device).expect("device transfer failed");
        }
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ClampBackward::new(
                self.grad_fn.clone(),
                self_data,
                min_val,
                max_val,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Sums all elements to a scalar.
    #[must_use]
    pub fn sum(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Sums along `dim`, dropping the reduced dimension.
    #[must_use]
    pub fn sum_dim(&self, dim: usize) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum_dim(dim as i32, false);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumDimBackward::new(self.grad_fn.clone(), self.shape(), dim));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Averages all elements to a scalar.
    #[must_use]
    pub fn mean(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.mean().unwrap();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

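    /// Mean squared error: `mean((self - target)^2)`.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
    /// let target = Variable::from_tensor(Tensor::from_vec(vec![0.0, 2.0], &[2]).unwrap());
    /// let loss = pred.mse_loss(&target); // ((1-0)^2 + (2-2)^2) / 2 = 0.5
    /// loss.backward();
    /// ```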
    #[must_use]
    pub fn mse_loss(&self, target: &Variable) -> Variable {
        let diff = self.sub_var(target);
        let squared = diff.pow(2.0);
        squared.mean()
    }

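    /// Binary cross-entropy: `mean(-(t*ln(p+eps) + (1-t)*ln(1-p+eps)))`,
    /// with `eps = 1e-7` guarding against `ln(0)`.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let p = Variable::new(Tensor::from_vec(vec![0.9, 0.1], &[2]).unwrap(), true);
    /// let t = Variable::from_tensor(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap());
    /// let loss = p.binary_cross_entropy(&t); // approx. 0.105
    /// ```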
    #[must_use]
    pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
        let one = Variable::from_tensor(Tensor::scalar(1.0));

        // Take the logs through the autograd `log()` op rather than through
        // detached tensors, so gradients flow back to the predictions.
        let log_p = self.add_var(&eps).log();
        let log_1_p = one.sub_var(self).add_var(&eps).log();

        let term1 = target.mul_var(&log_p);
        let term2 = one.sub_var(target).mul_var(&log_1_p);

        term1.add_var(&term2).neg_var().mean()
    }

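    /// Views the same data in `shape`, falling back to the original shape if
    /// the reshape is invalid. The backward pass reshapes the gradient back
    /// to the input shape.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(), true);
    /// let b = a.reshape(&[4]);
    /// b.sum().backward();
    /// assert_eq!(a.grad().unwrap().shape(), &[2, 2]);
    /// ```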
    #[must_use]
    pub fn reshape(&self, shape: &[usize]) -> Variable {
        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
        let original_shape = self.shape();
        let new_data = self
            .data()
            .reshape(&isize_shape)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Flattens dimensions from `start_dim` onward into a single dimension.
    #[must_use]
    pub fn flatten(&self, start_dim: usize) -> Variable {
        let shape = self.shape();
        if start_dim >= shape.len() {
            return self.clone();
        }
        let mut new_shape: Vec<usize> = shape[..start_dim].to_vec();
        let flat: usize = shape[start_dim..].iter().product();
        new_shape.push(flat);
        self.reshape(&new_shape)
    }

    /// Swaps dimensions `dim0` and `dim1`.
    #[must_use]
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
        let new_data = self
            .data()
            .transpose(dim0 as i64, dim1 as i64)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Slices by ranges. Note: the result is a fresh leaf variable, so
    /// gradients do not currently flow back through `slice`.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }

    /// Narrows dimension `dim` to the range `[start, start + length)`.
    #[must_use]
    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Variable {
        let input_shape = self.shape();
        let new_data = self
            .data()
            .narrow(dim, start, length)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NarrowBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim,
                start,
            ));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Broadcasts to `shape`; the backward pass sums the gradient back down
    /// to the input shape.
    #[must_use]
    pub fn expand(&self, shape: &[usize]) -> Variable {
        let input_shape = self.shape();
        let new_data = self.data().broadcast_to(shape);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ExpandBackward::new(self.grad_fn.clone(), input_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Selects index `index` along dimension `dim`, removing that dimension.
    #[must_use]
    pub fn select(&self, dim: usize, index: usize) -> Variable {
        let input_shape = self.shape();
        let new_data = self
            .data()
            .select(dim, index)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SelectBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim,
                index,
            ));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Inserts a size-1 dimension at `dim`.
    #[must_use]
    pub fn unsqueeze(&self, dim: usize) -> Variable {
        let new_data = self
            .data()
            .unsqueeze(dim as i64)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(UnsqueezeBackward::new(self.grad_fn.clone(), dim));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

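    /// Concatenates variables along `dim`. The backward pass splits the
    /// incoming gradient back into per-input chunks using the saved sizes.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
    /// let b = Variable::new(Tensor::from_vec(vec![3.0], &[1]).unwrap(), true);
    /// let c = Variable::cat(&[&a, &b], 0); // [1.0, 2.0, 3.0]
    /// ```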
    #[must_use]
    pub fn cat(variables: &[&Variable], dim: usize) -> Variable {
        let tensors: Vec<Tensor<f32>> = variables.iter().map(|v| v.data()).collect();
        let tensor_refs: Vec<&Tensor<f32>> = tensors.iter().collect();
        let result = Tensor::cat(&tensor_refs, dim).unwrap();

        let requires_grad = variables.iter().any(|v| v.requires_grad) && is_grad_enabled();

        if requires_grad {
            let next_fns: Vec<Option<GradFn>> =
                variables.iter().map(|v| v.grad_fn.clone()).collect();
            let sizes: Vec<usize> = variables.iter().map(|v| v.shape()[dim]).collect();
            let grad_fn = GradFn::new(CatBackward::new(next_fns, sizes, dim));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Multiplies every element by `scalar`; the gradient is scaled by the
    /// same factor.
    #[must_use]
    pub fn mul_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let result = data.mul_scalar(scalar);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MulScalarBackward::new(self.grad_fn.clone(), scalar));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Adds `scalar` to every element; the gradient passes through unchanged.
    #[must_use]
    pub fn add_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let result = data.add_scalar(scalar);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(AddScalarBackward::new(self.grad_fn.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Subtracts `scalar` from every element.
    #[must_use]
    pub fn sub_scalar(&self, scalar: f32) -> Variable {
        self.add_scalar(-scalar)
    }

    /// Divides every element by `scalar`.
    #[must_use]
    pub fn div_scalar(&self, scalar: f32) -> Variable {
        self.mul_scalar(1.0 / scalar)
    }

    /// Gaussian Error Linear Unit activation.
    #[must_use]
    pub fn gelu(&self) -> Variable {
        let self_data = self.data();
        let result = self_data.gelu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(GeluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// SiLU (sigmoid-weighted linear unit): `x * sigmoid(x)`.
    #[must_use]
    pub fn silu(&self) -> Variable {
        let self_data = self.data();
        let result = self_data.silu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SiluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise square root; the output is saved since
    /// d(sqrt(x))/dx = 1 / (2 * sqrt(x)).
    #[must_use]
    pub fn sqrt(&self) -> Variable {
        let data = self.data();
        let result = data.sqrt();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SqrtBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

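    /// Softmax along `dim`. The output is saved for the backward pass, since
    /// the softmax Jacobian is expressed in terms of the softmax values.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let x = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
    /// let p = x.softmax(0); // entries sum to 1.0
    /// ```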
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.softmax(dim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SoftmaxBackward::new(
                self.grad_fn.clone(),
                result.clone(),
                dim as i64,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Log-softmax along `dim`; numerically more stable than
    /// `softmax(dim)` followed by `log()`.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.log_softmax(dim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LogSoftmaxBackward::new(
                self.grad_fn.clone(),
                result.clone(),
                dim as i64,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

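    /// Mean along `dim` (negative indices count from the end), optionally
    /// keeping the reduced dimension.
    ///
    /// Sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let x = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(), true);
    /// let m = x.mean_dim(1, false); // [1.5, 3.5]
    /// ```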
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let input_shape = data.shape().to_vec();
        let ndim = input_shape.len();
        // Normalize negative dims for the backward function.
        let dim_usize = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };
        let result = data.mean_dim(dim, keepdim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MeanDimBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim_usize,
                keepdim,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Variance along `dim` (negative indices count from the end).
    #[must_use]
    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let self_data = self.data();
        let input_shape = self_data.shape().to_vec();
        let ndim = input_shape.len();
        let dim_usize = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };
        let result = self_data.var_dim(dim, keepdim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(VarDimBackward::new(
                self.grad_fn.clone(),
                self_data,
                dim_usize,
                keepdim,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Alias for [`Self::new`].
    #[must_use]
    pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
        Variable::new(data, requires_grad)
    }

    /// Explicitly named clone, sharing data and gradient storage.
    #[must_use]
    pub fn clone_var(&self) -> Variable {
        self.clone()
    }

    /// Alias for [`Self::add_var`].
    #[must_use]
    pub fn add(&self, other: &Variable) -> Variable {
        self.add_var(other)
    }

    /// Alias for [`Self::sub_var`].
    #[must_use]
    pub fn sub(&self, other: &Variable) -> Variable {
        self.sub_var(other)
    }

    /// Alias for [`Self::mul_var`].
    #[must_use]
    pub fn mul(&self, other: &Variable) -> Variable {
        self.mul_var(other)
    }

    /// Alias for [`Self::div_var`].
    #[must_use]
    pub fn div(&self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

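// Operator overloads on `&Variable` delegate to the named methods above so
// arithmetic reads naturally. A usage sketch (hypothetical `a` and `b`):
//
// ```ignore
// let c = &a + &b; // add_var
// let d = &c * &a; // mul_var
// let e = -&d;     // neg_var
// ```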
impl Add for &Variable {
    type Output = Variable;

    fn add(self, other: &Variable) -> Variable {
        self.add_var(other)
    }
}

impl Sub for &Variable {
    type Output = Variable;

    fn sub(self, other: &Variable) -> Variable {
        self.sub_var(other)
    }
}

impl Mul for &Variable {
    type Output = Variable;

    fn mul(self, other: &Variable) -> Variable {
        self.mul_var(other)
    }
}

impl Div for &Variable {
    type Output = Variable;

    fn div(self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

impl Neg for &Variable {
    type Output = Variable;

    fn neg(self) -> Variable {
        self.neg_var()
    }
}

impl std::fmt::Debug for Variable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Variable")
            .field("shape", &self.shape())
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field(
                "grad_fn",
                &self.grad_fn.as_ref().map(super::grad_fn::GradFn::name),
            )
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::zeros;

    #[test]
    fn test_variable_creation() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::new(t, true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }

    #[test]
    fn test_variable_no_grad() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::from_tensor(t);
        assert!(!v.requires_grad());
    }

    #[test]
    fn test_variable_add() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = Variable::new(
            Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let c = &a + &b;
        assert_eq!(c.data().to_vec(), vec![5.0, 7.0, 9.0]);
        assert!(c.requires_grad());
        assert!(!c.is_leaf());
    }

    #[test]
    fn test_variable_detach() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = a.detach();
        assert!(!b.requires_grad());
        assert!(b.is_leaf());
    }

    #[test]
    fn test_mse_loss() {
        let pred = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::from_tensor(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
        );
        let loss = pred.mse_loss(&target);
        assert_eq!(loss.numel(), 1);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_exp() {
        let a = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = a.exp();
        assert!((b.data().to_vec()[0] - 1.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - std::f32::consts::E).abs() < 1e-4);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - std::f32::consts::E).abs() < 1e-4);
    }

    #[test]
    fn test_log() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, std::f32::consts::E, 10.0], &[3])
                .expect("tensor creation failed"),
            true,
        );
        let b = a.log();
        assert!((b.data().to_vec()[0] - 0.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - 1.0).abs() < 1e-5);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - 1.0 / std::f32::consts::E).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let a = Variable::new(
            Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = a.clamp(0.0, 1.0);
        assert_eq!(b.data().to_vec(), vec![0.0, 0.5, 1.0]);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert_eq!(grad[0], 0.0); // clamped at min: zero gradient
        assert_eq!(grad[1], 1.0); // inside range: gradient passes through
        assert_eq!(grad[2], 0.0); // clamped at max: zero gradient
    }

    #[test]
    fn test_add_backward() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let c = a.add_var(&b);
        c.sum().backward();

        let ga = a.grad().expect("a should have grad");
        let gb = b.grad().expect("b should have grad");
        assert_eq!(ga.to_vec(), vec![1.0, 1.0]);
        assert_eq!(gb.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sub_backward() {
        let a = Variable::new(Tensor::from_vec(vec![5.0, 3.0], &[2]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![2.0, 1.0], &[2]).unwrap(), true);
        let c = a.sub_var(&b);

        assert_eq!(c.data().to_vec(), vec![3.0, 2.0]);
        c.sum().backward();

        let ga = a.grad().unwrap().to_vec();
        let gb = b.grad().unwrap().to_vec();
        assert_eq!(ga, vec![1.0, 1.0]);
        assert_eq!(gb, vec![-1.0, -1.0]);
    }

    #[test]
    fn test_mul_backward() {
        let a = Variable::new(Tensor::from_vec(vec![2.0, 3.0], &[2]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0, 5.0], &[2]).unwrap(), true);
        let c = a.mul_var(&b);

        assert_eq!(c.data().to_vec(), vec![8.0, 15.0]);
        c.sum().backward();

        let ga = a.grad().unwrap().to_vec();
        let gb = b.grad().unwrap().to_vec();
        assert_eq!(ga, vec![4.0, 5.0]);
        assert_eq!(gb, vec![2.0, 3.0]);
    }

    #[test]
    fn test_div_backward() {
        let a = Variable::new(Tensor::from_vec(vec![6.0, 10.0], &[2]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![2.0, 5.0], &[2]).unwrap(), true);
        let c = a.div_var(&b);

        assert_eq!(c.data().to_vec(), vec![3.0, 2.0]);
        c.sum().backward();

        let ga = a.grad().unwrap().to_vec();
        let gb = b.grad().unwrap().to_vec();
        assert!((ga[0] - 0.5).abs() < 1e-5, "da = 1/b = 0.5, got {}", ga[0]);
        assert!((ga[1] - 0.2).abs() < 1e-5, "da = 1/b = 0.2, got {}", ga[1]);
        assert!(
            (gb[0] - (-1.5)).abs() < 1e-5,
            "db = -a/b^2 = -6/4 = -1.5, got {}",
            gb[0]
        );
        assert!(
            (gb[1] - (-0.4)).abs() < 1e-5,
            "db = -a/b^2 = -10/25 = -0.4, got {}",
            gb[1]
        );
    }

    #[test]
    fn test_mul_scalar_backward() {
        let a = Variable::new(Tensor::from_vec(vec![2.0, 3.0], &[2]).unwrap(), true);
        let c = a.mul_scalar(5.0);

        assert_eq!(c.data().to_vec(), vec![10.0, 15.0]);
        c.sum().backward();

        let ga = a.grad().unwrap().to_vec();
        assert_eq!(ga, vec![5.0, 5.0]);
    }

    #[test]
    fn test_relu_backward() {
        let a = Variable::new(Tensor::from_vec(vec![-2.0, 0.0, 3.0], &[3]).unwrap(), true);
        let b = a.relu();

        assert_eq!(b.data().to_vec(), vec![0.0, 0.0, 3.0]);
        b.sum().backward();

        let ga = a.grad().unwrap().to_vec();
        assert_eq!(ga[0], 0.0); // negative input: zero gradient
        assert_eq!(ga[2], 1.0); // positive input: gradient passes through
    }

    #[test]
    fn test_sigmoid_backward() {
        let a = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), true);
        let b = a.sigmoid();

        assert!((b.data().to_vec()[0] - 0.5).abs() < 1e-5);
        b.backward();

        let ga = a.grad().unwrap().to_vec();
        assert!(
            (ga[0] - 0.25).abs() < 1e-4,
            "sigmoid'(0) = 0.25, got {}",
            ga[0]
        );
    }

    #[test]
    fn test_tanh_backward() {
        let a = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), true);
        let b = a.tanh();

        assert!(b.data().to_vec()[0].abs() < 1e-5);
        b.backward();

        let ga = a.grad().unwrap().to_vec();
        assert!((ga[0] - 1.0).abs() < 1e-4, "tanh'(0) = 1.0, got {}", ga[0]);
    }

    #[test]
    fn test_chain_rule_mul_then_add() {
        // f(a, b) = a*b + a with a = 3, b = 4.
        let a = Variable::new(Tensor::from_vec(vec![3.0], &[1]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0], &[1]).unwrap(), true);
        let ab = a.mul_var(&b);
        let result = ab.add_var(&a);
        result.backward();

        let ga = a.grad().unwrap().to_vec()[0];
        let gb = b.grad().unwrap().to_vec()[0];
        assert!((ga - 5.0).abs() < 1e-4, "df/da = b+1 = 5, got {}", ga);
        assert!((gb - 3.0).abs() < 1e-4, "df/db = a = 3, got {}", gb);
    }

    #[test]
    fn test_chain_rule_nested_operations() {
        // f(x) = relu(x^2 - 1) at x = 2: f = 3, df/dx = 2x = 4.
        let x = Variable::new(Tensor::from_vec(vec![2.0], &[1]).unwrap(), true);
        let x_sq = x.mul_var(&x);
        let shifted = x_sq.add_scalar(-1.0);
        let out = shifted.relu();
        assert!((out.data().to_vec()[0] - 3.0).abs() < 1e-5);
        out.backward();

        let gx = x.grad().unwrap().to_vec()[0];
        assert!((gx - 4.0).abs() < 1e-4, "df/dx = 2x = 4, got {}", gx);
    }

    #[test]
    fn test_sum_backward() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap(),
            true,
        );
        let s = a.sum();

        assert!((s.data().to_vec()[0] - 10.0).abs() < 1e-5);
        s.backward();

        let ga = a.grad().unwrap().to_vec();
        assert_eq!(ga, vec![1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_mean_backward() {
        let a = Variable::new(
            Tensor::from_vec(vec![2.0, 4.0, 6.0, 8.0], &[4]).unwrap(),
            true,
        );
        let m = a.mean();

        assert!((m.data().to_vec()[0] - 5.0).abs() < 1e-5);
        m.backward();

        let ga = a.grad().unwrap().to_vec();
        for g in &ga {
            assert!(
                (g - 0.25).abs() < 1e-5,
                "d(mean)/dx = 1/4 = 0.25, got {}",
                g
            );
        }
    }

    #[test]
    fn test_matmul_backward() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
            true,
        );
        let b = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap(),
            true,
        );
        let c = a.matmul(&b); // [2, 3] x [3, 2] -> [2, 2]
        assert_eq!(c.shape(), vec![2, 2]);

        c.sum().backward();

        let ga = a.grad().expect("A should have grad");
        let gb = b.grad().expect("B should have grad");
        assert_eq!(ga.shape(), &[2, 3]);
        assert_eq!(gb.shape(), &[3, 2]);

        assert!(ga.to_vec().iter().all(|g| g.is_finite() && g.abs() > 0.0));
        assert!(gb.to_vec().iter().all(|g| g.is_finite() && g.abs() > 0.0));
    }

    #[test]
    fn test_no_grad_skips_backward() {
        let a = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), false);
        let b = a.mul_scalar(2.0);
        assert!((b.data().to_vec()[0] - 2.0).abs() < 1e-5);
        assert!(a.grad().is_none());
    }

    #[test]
    fn test_detach_stops_gradient() {
        let a = Variable::new(Tensor::from_vec(vec![3.0], &[1]).unwrap(), true);
        let b = a.mul_scalar(2.0);
        let c = b.detach();

        assert!(
            !c.requires_grad(),
            "Detached variable should not require grad"
        );
        assert!(c.is_leaf(), "Detached variable should be a leaf");

        b.backward();
        let ga = a.grad().unwrap().to_vec()[0];
        assert!(
            (ga - 2.0).abs() < 1e-4,
            "Gradient through b=2*a should be 2: got {}",
            ga
        );
    }

    #[test]
    fn test_backward_twice_accumulates() {
        let a = Variable::new(Tensor::from_vec(vec![2.0], &[1]).unwrap(), true);
        let b = a.mul_scalar(3.0);
        b.backward();
        let g1 = a.grad().unwrap().to_vec()[0];

        let c = a.mul_scalar(3.0);
        c.backward();
        let g2 = a.grad().unwrap().to_vec()[0];

        assert!(
            (g2 - g1 * 2.0).abs() < 1e-4 || g2 >= g1,
            "Second backward should accumulate: g1={}, g2={}",
            g1,
            g2
        );
    }

    #[test]
    fn test_reshape_preserves_gradient() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            true,
        );
        let b = a.reshape(&[4]);
        let c = b.sum();
        c.backward();

        let ga = a.grad().expect("Should have gradient through reshape");
        assert_eq!(ga.shape(), &[2, 2]);
        assert_eq!(ga.to_vec(), vec![1.0, 1.0, 1.0, 1.0]);
    }
}