use std::ops::{Add, Div, Mul, Neg, Sub};
use std::sync::Arc;

use parking_lot::RwLock;

use axonml_tensor::Tensor;

use crate::functions::{
    AddBackward, AddScalarBackward, CatBackward, ClampBackward, DivBackward, EluBackward,
    ExpBackward, ExpandBackward, GeluBackward, LeakyReluBackward, LogBackward, LogSoftmaxBackward,
    MatMulBackward, MeanBackward, MeanDimBackward, MulBackward, MulScalarBackward, NarrowBackward,
    NegBackward, PowBackward, ReluBackward, ReshapeBackward, SelectBackward, SigmoidBackward,
    SiluBackward, SoftmaxBackward, SqrtBackward, SubBackward, SumBackward, SumDimBackward,
    TanhBackward, TransposeBackward, UnsqueezeBackward, VarDimBackward,
};
use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
use crate::graph::{GraphNode, with_graph};
use crate::no_grad::is_grad_enabled;

/// A tensor with automatic differentiation support.
///
/// `Variable` wraps a `Tensor<f32>` together with the bookkeeping needed to
/// build a computation graph and accumulate gradients during the backward pass.
#[derive(Clone)]
pub struct Variable {
    /// The underlying tensor data, shared across clones.
    data: Arc<RwLock<Tensor<f32>>>,
    /// Accumulated gradient, populated by the backward pass.
    grad: GradAccumulator,
    /// Whether this variable participates in gradient computation.
    requires_grad: bool,
    /// True for user-created variables, false for operation results.
    is_leaf: bool,
    /// The backward function that produced this variable, if any.
    grad_fn: Option<GradFn>,
    /// This variable's node in the computation graph, if registered.
    node: Option<Arc<GraphNode>>,
}

impl Variable {
    /// Creates a new leaf variable from a tensor.
    ///
    /// If `requires_grad` is true, the variable is registered in the
    /// computation graph and receives an `AccumulateGrad` function so that
    /// gradients flow into its `grad` slot during the backward pass.
    #[must_use]
    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
        let grad: GradAccumulator = Arc::new(RwLock::new(None));

        let node = if requires_grad {
            Some(with_graph(|g| g.register_leaf(true)))
        } else {
            None
        };

        let grad_fn = if requires_grad {
            Some(GradFn::new(AccumulateGrad::new(Arc::clone(&grad))))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad,
            requires_grad,
            is_leaf: true,
            grad_fn,
            node,
        }
    }

    /// Creates a variable that does not require gradients.
    #[must_use]
    pub fn from_tensor(data: Tensor<f32>) -> Self {
        Self::new(data, false)
    }

    /// Creates a non-leaf variable as the result of an operation, wiring
    /// `grad_fn` into the computation graph when gradients are required.
    pub fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
        let node = if requires_grad {
            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad: Arc::new(RwLock::new(None)),
            requires_grad,
            is_leaf: false,
            grad_fn: if requires_grad { Some(grad_fn) } else { None },
            node,
        }
    }

    /// Returns a clone of the underlying tensor data.
    #[must_use]
    pub fn data(&self) -> Tensor<f32> {
        self.data.read().clone()
    }

    #[must_use]
    pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape().to_vec()
    }

    #[must_use]
    pub fn ndim(&self) -> usize {
        self.data.read().ndim()
    }

    #[must_use]
    pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    #[must_use]
    pub fn device(&self) -> axonml_tensor::Device {
        self.data.read().device()
    }

    /// Moves the variable to the given device.
    ///
    /// Note that the result is a fresh leaf variable: gradients do not flow
    /// back through the device transfer.
    pub fn to_device(&self, device: axonml_tensor::Device) -> Self {
        let current = self.data.read().clone();
        if current.device() == device {
            return self.clone();
        }
        let moved = current
            .to_device(device)
            .expect("Failed to move variable to device");
        Variable::new(moved, self.requires_grad)
    }

    #[must_use]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    #[must_use]
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Returns the accumulated gradient, if one has been computed.
    #[must_use]
    pub fn grad(&self) -> Option<Tensor<f32>> {
        self.grad.read().clone()
    }

    #[must_use]
    pub fn grad_fn(&self) -> Option<&GradFn> {
        self.grad_fn.as_ref()
    }

    /// Replaces the underlying tensor data in place.
    pub fn set_data(&mut self, new_data: Tensor<f32>) {
        *self.data.write() = new_data;
    }

    /// Overwrites the stored gradient.
    pub fn set_grad(&self, grad: Tensor<f32>) {
        *self.grad.write() = Some(grad);
    }

    /// Adds `grad` to the stored gradient, initializing it on first use.
    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
        let mut grad_lock = self.grad.write();
        if let Some(ref existing) = *grad_lock {
            *grad_lock = Some(existing.add(grad).unwrap());
        } else {
            *grad_lock = Some(grad.clone());
        }
    }

    /// Clears the stored gradient.
    pub fn zero_grad(&self) {
        *self.grad.write() = None;
    }

    /// Returns a new leaf variable sharing no graph history with `self`.
    ///
    /// The returned variable holds a copy of the data, does not require
    /// gradients, and has no `grad_fn`.
    #[must_use]
    pub fn detach(&self) -> Self {
        Self {
            data: Arc::new(RwLock::new(self.data.read().clone())),
            grad: Arc::new(RwLock::new(None)),
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            node: None,
        }
    }

    /// Sets `requires_grad` (builder style), registering the variable in
    /// the computation graph when enabling gradients on a leaf.
    #[must_use]
    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
        self.requires_grad = requires_grad;
        if requires_grad && self.is_leaf {
            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
            self.node = Some(with_graph(|g| g.register_leaf(true)));
        }
        self
    }

    /// Runs the backward pass from this variable, seeding with a gradient
    /// of ones.
    ///
    /// # Panics
    ///
    /// Panics if the variable does not require gradients or is not a scalar.
    pub fn backward(&self) {
        assert!(
            self.requires_grad,
            "Cannot call backward on a variable that doesn't require gradients"
        );

        assert!(
            self.numel() == 1,
            "backward() can only be called on scalar tensors"
        );

        let mut grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
        let device = self.data.read().device();
        if device.is_gpu() {
            grad_output = grad_output.to_device(device).expect("device transfer failed");
        }
        crate::backward::backward(self, &grad_output);
    }

    /// Runs the backward pass with an explicit seed gradient, moving it to
    /// this variable's device if needed. Does nothing when gradients are
    /// not required.
    pub fn backward_with_grad(&self, grad_output: &Tensor<f32>) {
        if !self.requires_grad {
            return;
        }
        let device = self.data.read().device();
        let grad = if grad_output.device() != device && device.is_gpu() {
            grad_output.to_device(device).expect("device transfer failed")
        } else {
            grad_output.clone()
        };
        crate::backward::backward(self, &grad);
    }

    #[must_use]
    pub fn add_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().add(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(AddBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sub_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().sub(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SubBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mul_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.mul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn div_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.div(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DivBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn neg_var(&self) -> Variable {
        let result = self.data.read().neg();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn matmul(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();

        // Under autocast, compute in reduced precision but keep the
        // full-precision inputs for the backward pass.
        let (compute_a, compute_b) = if crate::amp::is_autocast_enabled() {
            (self_data.to_f16_precision(), other_data.to_f16_precision())
        } else {
            (self_data.clone(), other_data.clone())
        };

        let result = compute_a.matmul(&compute_b).expect("matmul failed");
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MatMulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn pow(&self, exponent: f32) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.pow(exponent);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn relu(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.relu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn leaky_relu(&self, negative_slope: f32) -> Variable {
        let self_data = self.data.read().clone();
        let device = self_data.device();
        // Computed element-wise on the host, then moved back to the
        // original device if necessary.
        let result_vec: Vec<f32> = self_data
            .to_vec()
            .iter()
            .map(|&x| if x > 0.0 { x } else { x * negative_slope })
            .collect();
        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
        if device.is_gpu() {
            result = result.to_device(device).expect("device transfer failed");
        }
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LeakyReluBackward::new(
                self.grad_fn.clone(),
                self_data,
                negative_slope,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn elu(&self, alpha: f32) -> Variable {
        let self_data = self.data.read().clone();
        let device = self_data.device();
        let result_vec: Vec<f32> = self_data
            .to_vec()
            .iter()
            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
            .collect();
        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
        if device.is_gpu() {
            result = result.to_device(device).expect("device transfer failed");
        }
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(EluBackward::new(self.grad_fn.clone(), self_data, alpha));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sigmoid(&self) -> Variable {
        let result = self.data.read().sigmoid();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn tanh(&self) -> Variable {
        let result = self.data.read().tanh();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn exp(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.exp();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ExpBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn log(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.ln();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LogBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn clamp(&self, min_val: f32, max_val: f32) -> Variable {
        let self_data = self.data.read().clone();
        let device = self_data.device();
        let result_data: Vec<f32> = self_data
            .to_vec()
            .iter()
            .map(|&x| x.clamp(min_val, max_val))
            .collect();
        let mut result = Tensor::from_vec(result_data, self_data.shape()).unwrap();
        if device.is_gpu() {
            result = result.to_device(device).expect("device transfer failed");
        }
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ClampBackward::new(
                self.grad_fn.clone(),
                self_data,
                min_val,
                max_val,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sum(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sum_dim(&self, dim: usize) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum_dim(dim as i32, false);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumDimBackward::new(self.grad_fn.clone(), self.shape(), dim));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mean(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.mean().unwrap();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Mean squared error loss: `mean((self - target)^2)`.
    #[must_use]
    pub fn mse_loss(&self, target: &Variable) -> Variable {
        let diff = self.sub_var(target);
        let squared = diff.pow(2.0);
        squared.mean()
    }

    /// Binary cross-entropy loss with epsilon clamping for numerical
    /// stability: `-mean(target * ln(p) + (1 - target) * ln(1 - p))`.
    #[must_use]
    pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
        let one = Variable::from_tensor(Tensor::scalar(1.0));

        let p = self.add_var(&eps);
        let one_minus_p = one.sub_var(self).add_var(&eps);

        // Use the graph-aware log() so gradients flow through both terms
        // rather than being detached by a raw tensor ln().
        let term1 = target.mul_var(&p.log());
        let term2 = one.sub_var(target).mul_var(&one_minus_p.log());

        term1.add_var(&term2).neg_var().mean()
    }

    #[must_use]
    pub fn reshape(&self, shape: &[usize]) -> Variable {
        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
        let original_shape = self.shape();
        let new_data = self
            .data()
            .reshape(&isize_shape)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn flatten(&self, start_dim: usize) -> Variable {
        let shape = self.shape();
        if start_dim >= shape.len() {
            return self.clone();
        }
        let mut new_shape: Vec<usize> = shape[..start_dim].to_vec();
        let flat: usize = shape[start_dim..].iter().product();
        new_shape.push(flat);
        self.reshape(&new_shape)
    }

    #[must_use]
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
        let new_data = self
            .data()
            .transpose(dim0 as i64, dim1 as i64)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Slices the variable along the given ranges.
    ///
    /// Note: the result is a new leaf variable, so gradients do not flow
    /// back through `slice`.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }

    #[must_use]
    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Variable {
        let input_shape = self.shape();
        let new_data = self
            .data()
            .narrow(dim, start, length)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NarrowBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim,
                start,
            ));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn expand(&self, shape: &[usize]) -> Variable {
        let input_shape = self.shape();
        let new_data = self.data().broadcast_to(shape);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ExpandBackward::new(self.grad_fn.clone(), input_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn select(&self, dim: usize, index: usize) -> Variable {
        let input_shape = self.shape();
        let new_data = self
            .data()
            .select(dim, index)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SelectBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim,
                index,
            ));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn unsqueeze(&self, dim: usize) -> Variable {
        let new_data = self
            .data()
            .unsqueeze(dim as i64)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(UnsqueezeBackward::new(self.grad_fn.clone(), dim));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn cat(variables: &[&Variable], dim: usize) -> Variable {
        let tensors: Vec<Tensor<f32>> = variables.iter().map(|v| v.data()).collect();
        let tensor_refs: Vec<&Tensor<f32>> = tensors.iter().collect();
        let result = Tensor::cat(&tensor_refs, dim).unwrap();

        let requires_grad = variables.iter().any(|v| v.requires_grad) && is_grad_enabled();

        if requires_grad {
            let next_fns: Vec<Option<GradFn>> =
                variables.iter().map(|v| v.grad_fn.clone()).collect();
            let sizes: Vec<usize> = variables.iter().map(|v| v.shape()[dim]).collect();
            let grad_fn = GradFn::new(CatBackward::new(next_fns, sizes, dim));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mul_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let result = data.mul_scalar(scalar);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MulScalarBackward::new(self.grad_fn.clone(), scalar));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn add_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let result = data.add_scalar(scalar);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(AddScalarBackward::new(self.grad_fn.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sub_scalar(&self, scalar: f32) -> Variable {
        self.add_scalar(-scalar)
    }

    #[must_use]
    pub fn div_scalar(&self, scalar: f32) -> Variable {
        self.mul_scalar(1.0 / scalar)
    }

    #[must_use]
    pub fn gelu(&self) -> Variable {
        let self_data = self.data();
        let result = self_data.gelu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(GeluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn silu(&self) -> Variable {
        let self_data = self.data();
        let result = self_data.silu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SiluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sqrt(&self) -> Variable {
        let data = self.data();
        let result = data.sqrt();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SqrtBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.softmax(dim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SoftmaxBackward::new(
                self.grad_fn.clone(),
                result.clone(),
                dim as i64,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.log_softmax(dim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LogSoftmaxBackward::new(
                self.grad_fn.clone(),
                result.clone(),
                dim as i64,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let input_shape = data.shape().to_vec();
        let ndim = input_shape.len();
        // Normalize a negative dim to a positive axis index.
        let dim_usize = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };
        let result = data.mean_dim(dim, keepdim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MeanDimBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim_usize,
                keepdim,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let self_data = self.data();
        let input_shape = self_data.shape().to_vec();
        let ndim = input_shape.len();
        let dim_usize = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };
        let result = self_data.var_dim(dim, keepdim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(VarDimBackward::new(
                self.grad_fn.clone(),
                self_data,
                dim_usize,
                keepdim,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
        Variable::new(data, requires_grad)
    }

    #[must_use]
    pub fn clone_var(&self) -> Variable {
        self.clone()
    }

    #[must_use]
    pub fn add(&self, other: &Variable) -> Variable {
        self.add_var(other)
    }

    #[must_use]
    pub fn sub(&self, other: &Variable) -> Variable {
        self.sub_var(other)
    }

    #[must_use]
    pub fn mul(&self, other: &Variable) -> Variable {
        self.mul_var(other)
    }

    #[must_use]
    pub fn div(&self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

impl Add for &Variable {
    type Output = Variable;

    fn add(self, other: &Variable) -> Variable {
        self.add_var(other)
    }
}

impl Sub for &Variable {
    type Output = Variable;

    fn sub(self, other: &Variable) -> Variable {
        self.sub_var(other)
    }
}

impl Mul for &Variable {
    type Output = Variable;

    fn mul(self, other: &Variable) -> Variable {
        self.mul_var(other)
    }
}

impl Div for &Variable {
    type Output = Variable;

    fn div(self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

impl Neg for &Variable {
    type Output = Variable;

    fn neg(self) -> Variable {
        self.neg_var()
    }
}

impl std::fmt::Debug for Variable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Variable")
            .field("shape", &self.shape())
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field(
                "grad_fn",
                &self.grad_fn.as_ref().map(super::grad_fn::GradFn::name),
            )
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::zeros;

    #[test]
    fn test_variable_creation() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::new(t, true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }

    #[test]
    fn test_variable_no_grad() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::from_tensor(t);
        assert!(!v.requires_grad());
    }

    #[test]
    fn test_variable_add() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = Variable::new(
            Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let c = &a + &b;
        assert_eq!(c.data().to_vec(), vec![5.0, 7.0, 9.0]);
        assert!(c.requires_grad());
        assert!(!c.is_leaf());
    }
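
    // A product-rule sketch mirroring test_variable_add: for c = a * b,
    // backward through sum() should give dc/da = b and dc/db = a.
    #[test]
    fn test_variable_mul_grad() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = Variable::new(
            Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let c = &a * &b;
        c.sum().backward();
        assert_eq!(a.grad().unwrap().to_vec(), vec![4.0, 5.0, 6.0]);
        assert_eq!(b.grad().unwrap().to_vec(), vec![1.0, 2.0, 3.0]);
    }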

    #[test]
    fn test_variable_detach() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = a.detach();
        assert!(!b.requires_grad());
        assert!(b.is_leaf());
    }
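
    // A small sketch of the gradient-slot contract: set_grad overwrites,
    // accumulate_grad adds element-wise, and zero_grad clears the slot.
    #[test]
    fn test_set_accumulate_zero_grad() {
        let a = Variable::new(zeros::<f32>(&[2]), true);
        a.set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).expect("tensor creation failed"));
        a.accumulate_grad(&Tensor::from_vec(vec![2.0, 3.0], &[2]).expect("tensor creation failed"));
        assert_eq!(a.grad().unwrap().to_vec(), vec![3.0, 4.0]);
        a.zero_grad();
        assert!(a.grad().is_none());
    }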

    #[test]
    fn test_mse_loss() {
        let pred = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::from_tensor(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
        );
        let loss = pred.mse_loss(&target);
        assert_eq!(loss.numel(), 1);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }
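
    // A sketch of the scalar-op chain rule: mul_scalar scales the incoming
    // gradient by the same constant and add_scalar passes it through.
    #[test]
    fn test_scalar_ops_grad() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("tensor creation failed"),
            true,
        );
        let b = a.mul_scalar(3.0).add_scalar(1.0);
        assert_eq!(b.data().to_vec(), vec![4.0, 7.0]);
        b.sum().backward();
        assert_eq!(a.grad().unwrap().to_vec(), vec![3.0, 3.0]);
    }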

    #[test]
    fn test_exp() {
        let a = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = a.exp();
        assert!((b.data().to_vec()[0] - 1.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - std::f32::consts::E).abs() < 1e-4);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - std::f32::consts::E).abs() < 1e-4);
    }
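
    // Sigmoid's derivative is s * (1 - s), so the gradient at x = 0 is 0.25;
    // a minimal check in the style of test_exp above.
    #[test]
    fn test_sigmoid_grad() {
        let a = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            true,
        );
        let b = a.sigmoid();
        assert!((b.data().to_vec()[0] - 0.5).abs() < 1e-6);
        b.sum().backward();
        assert!((a.grad().unwrap().to_vec()[0] - 0.25).abs() < 1e-5);
    }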

    #[test]
    fn test_log() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, std::f32::consts::E, 10.0], &[3])
                .expect("tensor creation failed"),
            true,
        );
        let b = a.log();
        assert!((b.data().to_vec()[0] - 0.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - 1.0).abs() < 1e-5);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - 1.0 / std::f32::consts::E).abs() < 1e-5);
    }
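
    // reshape and flatten are shape-only transformations of the same values;
    // a quick round-trip sanity check of both.
    #[test]
    fn test_reshape_flatten() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
                .expect("tensor creation failed"),
            true,
        );
        let b = a.reshape(&[3, 2]);
        assert_eq!(b.shape(), vec![3, 2]);
        let c = a.flatten(0);
        assert_eq!(c.shape(), vec![6]);
        assert_eq!(c.data().to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }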

    #[test]
    fn test_clamp() {
        let a = Variable::new(
            Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let b = a.clamp(0.0, 1.0);
        assert_eq!(b.data().to_vec(), vec![0.0, 0.5, 1.0]);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert_eq!(grad[0], 0.0);
        assert_eq!(grad[1], 1.0);
        assert_eq!(grad[2], 0.0);
    }
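
    // Variable::cat concatenates along a dimension; a sketch checking the
    // data layout and the requires_grad propagation described above.
    #[test]
    fn test_cat() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("tensor creation failed"),
            true,
        );
        let b = Variable::from_tensor(
            Tensor::from_vec(vec![3.0, 4.0], &[2]).expect("tensor creation failed"),
        );
        let c = Variable::cat(&[&a, &b], 0);
        assert_eq!(c.shape(), vec![4]);
        assert_eq!(c.data().to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        assert!(c.requires_grad());
    }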
}