use std::ops::{Add, Div, Mul, Neg, Sub};
use std::sync::Arc;

use parking_lot::RwLock;

use axonml_tensor::Tensor;

use crate::functions::{
    AddBackward, AddScalarBackward, CatBackward, ClampBackward, DivBackward, EluBackward,
    ExpBackward, ExpandBackward, GeluBackward, LeakyReluBackward, LogBackward, LogSoftmaxBackward,
    MatMulBackward, MeanBackward, MeanDimBackward, MulBackward, MulScalarBackward, NarrowBackward,
    NegBackward, PowBackward, ReluBackward, ReshapeBackward, SelectBackward, SigmoidBackward,
    SiluBackward, SoftmaxBackward, SqrtBackward, SubBackward, SumBackward, SumDimBackward,
    TanhBackward, TransposeBackward, UnsqueezeBackward, VarDimBackward,
};
use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
use crate::graph::{GraphNode, with_graph};
use crate::no_grad::is_grad_enabled;

/// A tensor wrapper that records the operations applied to it so gradients
/// can be computed with reverse-mode automatic differentiation.
#[derive(Clone)]
pub struct Variable {
    /// The underlying tensor data, shared between clones.
    data: Arc<RwLock<Tensor<f32>>>,
    /// The accumulated gradient, populated during the backward pass.
    grad: GradAccumulator,
    /// Whether this variable participates in gradient computation.
    requires_grad: bool,
    /// Whether this variable was created directly rather than by an operation.
    is_leaf: bool,
    /// The backward function that produced this variable, if any.
    grad_fn: Option<GradFn>,
    /// The node registered in the autograd graph, if any.
    node: Option<Arc<GraphNode>>,
}

impl Variable {
    /// Creates a new leaf variable from a tensor, optionally tracking gradients.
    #[must_use]
    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
        let grad: GradAccumulator = Arc::new(RwLock::new(None));

        let node = if requires_grad {
            Some(with_graph(|g| g.register_leaf(true)))
        } else {
            None
        };

        let grad_fn = if requires_grad {
            Some(GradFn::new(AccumulateGrad::new(Arc::clone(&grad))))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad,
            requires_grad,
            is_leaf: true,
            grad_fn,
            node,
        }
    }

    /// Creates a leaf variable that does not require gradients.
    #[must_use]
    pub fn from_tensor(data: Tensor<f32>) -> Self {
        Self::new(data, false)
    }

    /// Creates a non-leaf variable produced by an operation, wired to the
    /// given backward function.
    pub fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
        let node = if requires_grad {
            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad: Arc::new(RwLock::new(None)),
            requires_grad,
            is_leaf: false,
            grad_fn: if requires_grad { Some(grad_fn) } else { None },
            node,
        }
    }

    /// Returns a clone of the underlying tensor.
    #[must_use]
    pub fn data(&self) -> Tensor<f32> {
        self.data.read().clone()
    }

    /// Returns the shape of the underlying tensor.
    #[must_use]
    pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape().to_vec()
    }

    /// Returns the number of dimensions.
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.data.read().ndim()
    }

    /// Returns the total number of elements.
    #[must_use]
    pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    /// Returns the device the underlying tensor lives on.
    #[must_use]
    pub fn device(&self) -> axonml_tensor::Device {
        self.data.read().device()
    }

    /// Returns a copy of this variable on `device`. The copy is a fresh leaf,
    /// so it is not connected to the original's graph.
    pub fn to_device(&self, device: axonml_tensor::Device) -> Self {
        let current = self.data.read().clone();
        if current.device() == device {
            return self.clone();
        }
        let moved = current
            .to_device(device)
            .expect("Failed to move variable to device");
        Variable::new(moved, self.requires_grad)
    }

    /// Returns whether this variable requires gradients.
    #[must_use]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Returns whether this variable is a leaf (created directly, not by an operation).
    #[must_use]
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Returns the accumulated gradient, if one has been computed.
    #[must_use]
    pub fn grad(&self) -> Option<Tensor<f32>> {
        self.grad.read().clone()
    }

    /// Returns the backward function that produced this variable, if any.
    #[must_use]
    pub fn grad_fn(&self) -> Option<&GradFn> {
        self.grad_fn.as_ref()
    }

    /// Replaces the stored gradient with `grad`.
    pub fn set_grad(&self, grad: Tensor<f32>) {
        *self.grad.write() = Some(grad);
    }

    /// Adds `grad` to the stored gradient, initializing it if empty.
    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
        let mut grad_lock = self.grad.write();
        if let Some(ref existing) = *grad_lock {
            *grad_lock = Some(existing.add(grad).unwrap());
        } else {
            *grad_lock = Some(grad.clone());
        }
    }

    /// Clears the stored gradient.
    pub fn zero_grad(&self) {
        *self.grad.write() = None;
    }

    /// Returns a copy of this variable detached from the autograd graph.
    /// The copy does not require gradients and has no gradient history.
    #[must_use]
    pub fn detach(&self) -> Self {
        Self {
            data: Arc::new(RwLock::new(self.data.read().clone())),
            grad: Arc::new(RwLock::new(None)),
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            node: None,
        }
    }

    /// Sets `requires_grad` in builder style and returns the variable.
    #[must_use]
    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
        self.requires_grad = requires_grad;
        if requires_grad && self.is_leaf {
            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
            self.node = Some(with_graph(|g| g.register_leaf(true)));
        }
        self
    }

    /// Runs the backward pass from this variable, seeding the gradient with 1.0.
    ///
    /// # Panics
    ///
    /// Panics if the variable does not require gradients or is not a scalar.
    pub fn backward(&self) {
        assert!(
            self.requires_grad,
            "Cannot call backward on a variable that doesn't require gradients"
        );

        assert!(
            self.numel() == 1,
            "backward() can only be called on scalar tensors"
        );

        let mut grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
        let device = self.data.read().device();
        if device.is_gpu() {
            grad_output = grad_output.to_device(device).unwrap();
        }
        crate::backward::backward(self, &grad_output);
    }

    /// Runs the backward pass with an explicit upstream gradient. Does nothing
    /// if this variable does not require gradients.
    pub fn backward_with_grad(&self, grad_output: &Tensor<f32>) {
        if !self.requires_grad {
            return;
        }
        let device = self.data.read().device();
        let grad = if grad_output.device() != device && device.is_gpu() {
            grad_output.to_device(device).unwrap()
        } else {
            grad_output.clone()
        };
        crate::backward::backward(self, &grad);
    }

    #[must_use]
    pub fn add_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().add(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(AddBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sub_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().sub(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SubBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mul_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.mul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn div_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.div(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DivBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn neg_var(&self) -> Variable {
        let result = self.data.read().neg();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn matmul(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.matmul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MatMulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn pow(&self, exponent: f32) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.pow(exponent);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn relu(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.relu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn leaky_relu(&self, negative_slope: f32) -> Variable {
        let self_data = self.data.read().clone();
        let device = self_data.device();
        let result_vec: Vec<f32> = self_data
            .to_vec()
            .iter()
            .map(|&x| if x > 0.0 { x } else { x * negative_slope })
            .collect();
        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
        if device.is_gpu() {
            result = result.to_device(device).unwrap();
        }
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LeakyReluBackward::new(
                self.grad_fn.clone(),
                self_data,
                negative_slope,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn elu(&self, alpha: f32) -> Variable {
        let self_data = self.data.read().clone();
        let device = self_data.device();
        let result_vec: Vec<f32> = self_data
            .to_vec()
            .iter()
            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
            .collect();
        let mut result = Tensor::from_vec(result_vec, self_data.shape()).unwrap();
        if device.is_gpu() {
            result = result.to_device(device).unwrap();
        }
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(EluBackward::new(self.grad_fn.clone(), self_data, alpha));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sigmoid(&self) -> Variable {
        let result = self.data.read().sigmoid();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn tanh(&self) -> Variable {
        let result = self.data.read().tanh();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn exp(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.exp();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ExpBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn log(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.ln();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LogBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn clamp(&self, min_val: f32, max_val: f32) -> Variable {
        let self_data = self.data.read().clone();
        let device = self_data.device();
        let result_data: Vec<f32> = self_data
            .to_vec()
            .iter()
            .map(|&x| x.clamp(min_val, max_val))
            .collect();
        let mut result = Tensor::from_vec(result_data, self_data.shape()).unwrap();
        if device.is_gpu() {
            result = result.to_device(device).unwrap();
        }
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ClampBackward::new(
                self.grad_fn.clone(),
                self_data,
                min_val,
                max_val,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sum(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sum_dim(&self, dim: usize) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum_dim(dim as i32, false);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumDimBackward::new(self.grad_fn.clone(), self.shape(), dim));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mean(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.mean().unwrap();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mse_loss(&self, target: &Variable) -> Variable {
        let diff = self.sub_var(target);
        let squared = diff.pow(2.0);
        squared.mean()
    }

    #[must_use]
    pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
        let one = Variable::from_tensor(Tensor::scalar(1.0));

        let log_p = self.add_var(&eps);
        let log_1_p = one.sub_var(self).add_var(&eps);

        let term1 = target.mul_var(&Variable::from_tensor(log_p.data().ln()));
        let term2 = one
            .sub_var(target)
            .mul_var(&Variable::from_tensor(log_1_p.data().ln()));

        term1.add_var(&term2).neg_var().mean()
    }

    #[must_use]
    pub fn reshape(&self, shape: &[usize]) -> Variable {
        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
        let original_shape = self.shape();
        let new_data = self
            .data()
            .reshape(&isize_shape)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn flatten(&self, start_dim: usize) -> Variable {
        let shape = self.shape();
        if start_dim >= shape.len() {
            return self.clone();
        }
        let mut new_shape: Vec<usize> = shape[..start_dim].to_vec();
        let flat: usize = shape[start_dim..].iter().product();
        new_shape.push(flat);
        self.reshape(&new_shape)
    }

    #[must_use]
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
        let new_data = self
            .data()
            .transpose(dim0 as i64, dim1 as i64)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }

    #[must_use]
    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Variable {
        let input_shape = self.shape();
        let new_data = self
            .data()
            .narrow(dim, start, length)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NarrowBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim,
                start,
            ));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn expand(&self, shape: &[usize]) -> Variable {
        let input_shape = self.shape();
        let new_data = self.data().broadcast_to(shape);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ExpandBackward::new(self.grad_fn.clone(), input_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn select(&self, dim: usize, index: usize) -> Variable {
        let input_shape = self.shape();
        let new_data = self
            .data()
            .select(dim, index)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SelectBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim,
                index,
            ));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn unsqueeze(&self, dim: usize) -> Variable {
        let new_data = self
            .data()
            .unsqueeze(dim as i64)
            .unwrap_or_else(|_| self.data().clone());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(UnsqueezeBackward::new(self.grad_fn.clone(), dim));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn cat(variables: &[&Variable], dim: usize) -> Variable {
        let tensors: Vec<Tensor<f32>> = variables.iter().map(|v| v.data()).collect();
        let tensor_refs: Vec<&Tensor<f32>> = tensors.iter().collect();
        let result = Tensor::cat(&tensor_refs, dim).unwrap();

        let requires_grad = variables.iter().any(|v| v.requires_grad) && is_grad_enabled();

        if requires_grad {
            let next_fns: Vec<Option<GradFn>> =
                variables.iter().map(|v| v.grad_fn.clone()).collect();
            let sizes: Vec<usize> = variables.iter().map(|v| v.shape()[dim]).collect();
            let grad_fn = GradFn::new(CatBackward::new(next_fns, sizes, dim));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mul_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let result = data.mul_scalar(scalar);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MulScalarBackward::new(self.grad_fn.clone(), scalar));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn add_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let result = data.add_scalar(scalar);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(AddScalarBackward::new(self.grad_fn.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sub_scalar(&self, scalar: f32) -> Variable {
        self.add_scalar(-scalar)
    }

    #[must_use]
    pub fn div_scalar(&self, scalar: f32) -> Variable {
        self.mul_scalar(1.0 / scalar)
    }

    #[must_use]
    pub fn gelu(&self) -> Variable {
        let self_data = self.data();
        let result = self_data.gelu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(GeluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn silu(&self) -> Variable {
        let self_data = self.data();
        let result = self_data.silu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SiluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sqrt(&self) -> Variable {
        let data = self.data();
        let result = data.sqrt();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SqrtBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.softmax(dim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SoftmaxBackward::new(
                self.grad_fn.clone(),
                result.clone(),
                dim as i64,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.log_softmax(dim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LogSoftmaxBackward::new(
                self.grad_fn.clone(),
                result.clone(),
                dim as i64,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let input_shape = data.shape().to_vec();
        let ndim = input_shape.len();
        let dim_usize = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };
        let result = data.mean_dim(dim, keepdim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MeanDimBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim_usize,
                keepdim,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let self_data = self.data();
        let input_shape = self_data.shape().to_vec();
        let ndim = input_shape.len();
        let dim_usize = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };
        let result = self_data.var_dim(dim, keepdim);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(VarDimBackward::new(
                self.grad_fn.clone(),
                self_data,
                dim_usize,
                keepdim,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
        Variable::new(data, requires_grad)
    }

    #[must_use]
    pub fn clone_var(&self) -> Variable {
        self.clone()
    }

    #[must_use]
    pub fn add(&self, other: &Variable) -> Variable {
        self.add_var(other)
    }

    #[must_use]
    pub fn sub(&self, other: &Variable) -> Variable {
        self.sub_var(other)
    }

    #[must_use]
    pub fn mul(&self, other: &Variable) -> Variable {
        self.mul_var(other)
    }

    #[must_use]
    pub fn div(&self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

impl Add for &Variable {
    type Output = Variable;

    fn add(self, other: &Variable) -> Variable {
        self.add_var(other)
    }
}

impl Sub for &Variable {
    type Output = Variable;

    fn sub(self, other: &Variable) -> Variable {
        self.sub_var(other)
    }
}

impl Mul for &Variable {
    type Output = Variable;

    fn mul(self, other: &Variable) -> Variable {
        self.mul_var(other)
    }
}

impl Div for &Variable {
    type Output = Variable;

    fn div(self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

impl Neg for &Variable {
    type Output = Variable;

    fn neg(self) -> Variable {
        self.neg_var()
    }
}

impl std::fmt::Debug for Variable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Variable")
            .field("shape", &self.shape())
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field(
                "grad_fn",
                &self.grad_fn.as_ref().map(super::grad_fn::GradFn::name),
            )
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::zeros;

    #[test]
    fn test_variable_creation() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::new(t, true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }
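
    // Sketch test exercising the manual gradient helpers set_grad,
    // accumulate_grad, and zero_grad exactly as defined on `Variable` above.
    #[test]
    fn test_grad_accumulation_helpers() {
        let v = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
        assert!(v.grad().is_none());

        let g = Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap();
        v.set_grad(g.clone());
        v.accumulate_grad(&g);
        assert_eq!(v.grad().unwrap().to_vec(), vec![1.0, 1.0]);

        v.zero_grad();
        assert!(v.grad().is_none());
    }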

    #[test]
    fn test_variable_no_grad() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::from_tensor(t);
        assert!(!v.requires_grad());
    }

    #[test]
    fn test_variable_add() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(), true);
        let c = &a + &b;
        assert_eq!(c.data().to_vec(), vec![5.0, 7.0, 9.0]);
        assert!(c.requires_grad());
        assert!(!c.is_leaf());
    }
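
    // Sketch test assuming MulBackward implements the usual product rule
    // (d(a*b)/da = b and d(a*b)/db = a); the values here are exact in f32.
    #[test]
    fn test_variable_mul_backward() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(), true);
        let c = (&a * &b).sum();
        c.backward();
        assert_eq!(a.grad().unwrap().to_vec(), vec![4.0, 5.0, 6.0]);
        assert_eq!(b.grad().unwrap().to_vec(), vec![1.0, 2.0, 3.0]);
    }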

    #[test]
    fn test_variable_detach() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = a.detach();
        assert!(!b.requires_grad());
        assert!(b.is_leaf());
    }
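
    // Sketch test assuming MulScalarBackward scales the incoming gradient by
    // the same scalar, so d(2x)/dx = 2 element-wise.
    #[test]
    fn test_mul_scalar_backward() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = a.mul_scalar(2.0);
        assert_eq!(b.data().to_vec(), vec![2.0, 4.0, 6.0]);

        b.sum().backward();
        assert_eq!(a.grad().unwrap().to_vec(), vec![2.0, 2.0, 2.0]);
    }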

    #[test]
    fn test_mse_loss() {
        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let target = Variable::from_tensor(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let loss = pred.mse_loss(&target);
        assert_eq!(loss.numel(), 1);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }
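
    // Sketch test assuming MeanBackward spreads the upstream gradient as 1/n
    // over every input element.
    #[test]
    fn test_mean_backward() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap(), true);
        let m = a.mean();
        assert!((m.data().to_vec()[0] - 2.5).abs() < 1e-6);

        m.backward();
        for g in a.grad().unwrap().to_vec() {
            assert!((g - 0.25).abs() < 1e-6);
        }
    }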

    #[test]
    fn test_exp() {
        let a = Variable::new(Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap(), true);
        let b = a.exp();
        assert!((b.data().to_vec()[0] - 1.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - std::f32::consts::E).abs() < 1e-4);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - std::f32::consts::E).abs() < 1e-4);
    }
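
    // Sketch test assuming ReluBackward passes the gradient through where the
    // input was positive and zeroes it elsewhere (mirrors the clamp test below).
    #[test]
    fn test_relu_backward() {
        let a = Variable::new(Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap(), true);
        let b = a.relu();
        assert_eq!(b.data().to_vec(), vec![0.0, 0.5, 2.0]);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert_eq!(grad[0], 0.0);
        assert_eq!(grad[1], 1.0);
        assert_eq!(grad[2], 1.0);
    }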

    #[test]
    fn test_log() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, std::f32::consts::E, 10.0], &[3]).unwrap(),
            true,
        );
        let b = a.log();
        assert!((b.data().to_vec()[0] - 0.0).abs() < 1e-5);
        assert!((b.data().to_vec()[1] - 1.0).abs() < 1e-5);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-5);
        assert!((grad[1] - 1.0 / std::f32::consts::E).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let a = Variable::new(Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap(), true);
        let b = a.clamp(0.0, 1.0);
        assert_eq!(b.data().to_vec(), vec![0.0, 0.5, 1.0]);

        b.sum().backward();
        let grad = a.grad().unwrap().to_vec();
        assert_eq!(grad[0], 0.0);
        assert_eq!(grad[1], 1.0);
        assert_eq!(grad[2], 0.0);
    }
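
    // Forward-only sketch test assuming Tensor::softmax behaves like a standard
    // softmax along the given dimension, so each row sums to roughly 1.
    #[test]
    fn test_softmax_rows_sum_to_one() {
        let a = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let s = a.softmax(1);
        let v = s.data().to_vec();
        assert!((v[0] + v[1] - 1.0).abs() < 1e-5);
        assert!((v[2] + v[3] - 1.0).abs() < 1e-5);
    }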
}