use std::ops::{Add, Div, Mul, Neg, Sub};
use std::sync::Arc;

use parking_lot::RwLock;

use axonml_tensor::Tensor;

use crate::functions::{
    AddBackward, CatBackward, DivBackward, ExpandBackward, MatMulBackward, MeanBackward,
    MulBackward, NarrowBackward, NegBackward, PowBackward, ReluBackward, ReshapeBackward,
    SelectBackward, SigmoidBackward, SubBackward, SumBackward, SumDimBackward, TanhBackward,
    TransposeBackward, UnsqueezeBackward,
};
use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
use crate::graph::{with_graph, GraphNode};
use crate::no_grad::is_grad_enabled;

/// A tensor with autograd bookkeeping.
///
/// A `Variable` wraps a `Tensor<f32>` and records the operation that
/// produced it, so that reverse-mode automatic differentiation can walk
/// the resulting graph and accumulate gradients into leaf variables.
#[derive(Clone)]
pub struct Variable {
    /// The underlying tensor data, shared between clones.
    data: Arc<RwLock<Tensor<f32>>>,
    /// The accumulated gradient, populated by the backward pass.
    grad: GradAccumulator,
    /// Whether this variable participates in gradient computation.
    requires_grad: bool,
    /// Whether this is a leaf (created directly, not by an operation).
    is_leaf: bool,
    /// The backward function that produced this variable, if any.
    grad_fn: Option<GradFn>,
    /// The node registered for this variable in the global graph, if any.
    node: Option<Arc<GraphNode>>,
}

impl Variable {
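    /// Creates a new leaf variable from a tensor.
    ///
    /// When `requires_grad` is true, the variable is registered with the
    /// global graph and given an `AccumulateGrad` function so the backward
    /// pass can deposit gradients into it.
    ///
    /// A minimal usage sketch; the `axonml_autograd` crate path is an
    /// assumption, so the example is not compiled:
    ///
    /// ```ignore
    /// use axonml_autograd::Variable;
    /// use axonml_tensor::Tensor;
    ///
    /// let w = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
    /// assert!(w.requires_grad() && w.is_leaf());
    /// ```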
    #[must_use]
    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
        let grad: GradAccumulator = Arc::new(RwLock::new(None));

        // Register a node in the global graph only for gradient-tracking leaves.
        let node = if requires_grad {
            Some(with_graph(|g| g.register_leaf(true)))
        } else {
            None
        };

        // Leaves accumulate incoming gradients rather than propagating further.
        let grad_fn = if requires_grad {
            Some(GradFn::new(AccumulateGrad::new(Arc::clone(&grad))))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad,
            requires_grad,
            is_leaf: true,
            grad_fn,
            node,
        }
    }

    /// Creates a leaf variable that does not require gradients.
    #[must_use]
    pub fn from_tensor(data: Tensor<f32>) -> Self {
        Self::new(data, false)
    }

    /// Creates a non-leaf variable produced by an operation.
    ///
    /// The supplied `grad_fn` links this variable to its inputs in the
    /// autograd graph when `requires_grad` is true.
    pub fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
        let node = if requires_grad {
            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad: Arc::new(RwLock::new(None)),
            requires_grad,
            is_leaf: false,
            grad_fn: if requires_grad { Some(grad_fn) } else { None },
            node,
        }
    }

    /// Returns a clone of the underlying tensor.
    #[must_use]
    pub fn data(&self) -> Tensor<f32> {
        self.data.read().clone()
    }

    /// Returns the shape of the underlying tensor.
    #[must_use]
    pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape().to_vec()
    }

    /// Returns the number of dimensions.
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.data.read().ndim()
    }

    /// Returns the total number of elements.
    #[must_use]
    pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    /// Returns whether this variable requires gradients.
    #[must_use]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Returns whether this variable is a leaf.
    #[must_use]
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Returns a clone of the accumulated gradient, if one has been set.
    #[must_use]
    pub fn grad(&self) -> Option<Tensor<f32>> {
        self.grad.read().clone()
    }

    /// Returns the backward function that produced this variable, if any.
    #[must_use]
    pub fn grad_fn(&self) -> Option<&GradFn> {
        self.grad_fn.as_ref()
    }

    /// Replaces the stored gradient.
    pub fn set_grad(&self, grad: Tensor<f32>) {
        *self.grad.write() = Some(grad);
    }

    /// Adds `grad` to the stored gradient, initializing it if absent.
    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
        let mut grad_lock = self.grad.write();
        if let Some(ref existing) = *grad_lock {
            *grad_lock = Some(existing.add(grad).unwrap());
        } else {
            *grad_lock = Some(grad.clone());
        }
    }

    /// Clears the stored gradient.
    pub fn zero_grad(&self) {
        *self.grad.write() = None;
    }

    /// Returns a new leaf variable with a copy of the data and no autograd
    /// history. Gradients never flow through a detached variable.
    #[must_use]
    pub fn detach(&self) -> Self {
        Self {
            data: Arc::new(RwLock::new(self.data.read().clone())),
            grad: Arc::new(RwLock::new(None)),
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            node: None,
        }
    }

    /// Sets `requires_grad` in builder style, wiring up gradient
    /// accumulation and graph registration when enabling it on a leaf.
    #[must_use]
    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
        self.requires_grad = requires_grad;
        if requires_grad && self.is_leaf {
            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
            self.node = Some(with_graph(|g| g.register_leaf(true)));
        }
        self
    }

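    /// Runs the backward pass from this variable, accumulating gradients
    /// into every leaf that requires them.
    ///
    /// # Panics
    ///
    /// Panics if this variable does not require gradients or is not a
    /// scalar (single-element) tensor.
    ///
    /// A minimal usage sketch; the `axonml_autograd` crate path is an
    /// assumption, so the example is not compiled:
    ///
    /// ```ignore
    /// use axonml_autograd::Variable;
    /// use axonml_tensor::Tensor;
    ///
    /// let x = Variable::new(Tensor::from_vec(vec![3.0], &[1]).unwrap(), true);
    /// let y = x.mul_var(&x); // y = x^2
    /// y.backward();          // dy/dx = 2x = 6.0
    /// assert_eq!(x.grad().unwrap().to_vec(), vec![6.0]);
    /// ```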
    pub fn backward(&self) {
        assert!(
            self.requires_grad,
            "Cannot call backward on a variable that doesn't require gradients"
        );

        assert!(
            self.numel() == 1,
            "backward() can only be called on scalar tensors"
        );

        // Seed the pass with d(out)/d(out) = 1.
        let grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
        crate::backward::backward(self, &grad_output);
    }

    /// Element-wise addition of two variables.
    #[must_use]
    pub fn add_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().add(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(AddBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise subtraction of two variables.
    #[must_use]
    pub fn sub_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().sub(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SubBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise multiplication; both inputs are saved for the backward pass.
    #[must_use]
    pub fn mul_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.mul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise division; both inputs are saved for the backward pass.
    #[must_use]
    pub fn div_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.div(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DivBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Element-wise negation.
    #[must_use]
    pub fn neg_var(&self) -> Variable {
        let result = self.data.read().neg();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Matrix multiplication; both inputs are saved for the backward pass.
    #[must_use]
    pub fn matmul(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.matmul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MatMulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Raises each element to the power `exponent`.
    #[must_use]
    pub fn pow(&self, exponent: f32) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.pow(exponent);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Applies the ReLU activation; the input is saved to mask the gradient.
    #[must_use]
    pub fn relu(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.relu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Applies the sigmoid activation; the output is saved because
    /// sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)).
    #[must_use]
    pub fn sigmoid(&self) -> Variable {
        let result = self.data.read().sigmoid();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Applies the tanh activation; the output is saved because
    /// tanh'(x) = 1 - tanh(x)^2.
    #[must_use]
    pub fn tanh(&self) -> Variable {
        let result = self.data.read().tanh();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Sums all elements.
    #[must_use]
    pub fn sum(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Sums along `dim`, dropping the reduced dimension.
    #[must_use]
    pub fn sum_dim(&self, dim: usize) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum_dim(dim as i32, false);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumDimBackward::new(
                self.grad_fn.clone(),
                self.shape(),
                dim,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Averages all elements.
    #[must_use]
    pub fn mean(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.mean().unwrap();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Mean squared error loss: mean((self - target)^2).
    #[must_use]
    pub fn mse_loss(&self, target: &Variable) -> Variable {
        let diff = self.sub_var(target);
        let squared = diff.pow(2.0);
        squared.mean()
    }

    /// Binary cross-entropy loss, with inputs offset by a small epsilon
    /// for numerical stability.
    ///
    /// Note: the logarithms are taken on detached tensors, so gradients do
    /// not flow through the `ln` itself, only through the surrounding ops.
    #[must_use]
    pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
        let one = Variable::from_tensor(Tensor::scalar(1.0));

        let log_p = self.add_var(&eps);
        let log_1_p = one.sub_var(self).add_var(&eps);

        let term1 = target.mul_var(&Variable::from_tensor(log_p.data().ln()));
        let term2 = one
            .sub_var(target)
            .mul_var(&Variable::from_tensor(log_1_p.data().ln()));

        term1.add_var(&term2).neg_var().mean()
    }

    /// Reshapes to `shape`, falling back to the original data if the
    /// reshape fails.
    #[must_use]
    pub fn reshape(&self, shape: &[usize]) -> Variable {
        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
        let original_shape = self.shape();
        let new_data = self
            .data()
            .reshape(&isize_shape)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Swaps dimensions `dim0` and `dim1`, falling back to the original
    /// data if the transpose fails.
    #[must_use]
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
        let new_data = self
            .data()
            .transpose(dim0 as i64, dim1 as i64)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Slices along the given ranges.
    ///
    /// Note: the result is a fresh leaf, so gradients do not flow back
    /// through a slice.
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }

    /// Narrows dimension `dim` to `length` elements starting at `start`,
    /// falling back to the original data if the narrow fails.
    #[must_use]
    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Variable {
        let input_shape = self.shape();
        let new_data = self
            .data()
            .narrow(dim, start, length)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NarrowBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim,
                start,
            ));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Broadcasts this variable to `shape`.
    #[must_use]
    pub fn expand(&self, shape: &[usize]) -> Variable {
        let input_shape = self.shape();
        let new_data = self.data().broadcast_to(shape);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ExpandBackward::new(self.grad_fn.clone(), input_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Selects the sub-tensor at `index` along `dim`, falling back to the
    /// original data if the select fails.
    #[must_use]
    pub fn select(&self, dim: usize, index: usize) -> Variable {
        let input_shape = self.shape();
        let new_data = self
            .data()
            .select(dim, index)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SelectBackward::new(
                self.grad_fn.clone(),
                input_shape,
                dim,
                index,
            ));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Inserts a dimension of size one at `dim`, falling back to the
    /// original data if the unsqueeze fails.
    #[must_use]
    pub fn unsqueeze(&self, dim: usize) -> Variable {
        let new_data = self
            .data()
            .unsqueeze(dim as i64)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(UnsqueezeBackward::new(self.grad_fn.clone(), dim));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    /// Concatenates `variables` along `dim`; the per-input sizes along
    /// `dim` are saved so the backward pass can split the gradient.
    #[must_use]
    pub fn cat(variables: &[&Variable], dim: usize) -> Variable {
        let tensors: Vec<Tensor<f32>> = variables.iter().map(|v| v.data()).collect();
        let tensor_refs: Vec<&Tensor<f32>> = tensors.iter().collect();
        let result = Tensor::cat(&tensor_refs, dim).unwrap();

        let requires_grad = variables.iter().any(|v| v.requires_grad) && is_grad_enabled();

        if requires_grad {
            let next_fns: Vec<Option<GradFn>> =
                variables.iter().map(|v| v.grad_fn.clone()).collect();
            let sizes: Vec<usize> = variables.iter().map(|v| v.shape()[dim]).collect();
            let grad_fn = GradFn::new(CatBackward::new(next_fns, sizes, dim));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    /// Multiplies every element by `scalar`, implemented as an
    /// element-wise multiply with a constant tensor of the same shape.
    #[must_use]
    pub fn mul_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let shape = data.shape();
        let numel: usize = shape.iter().product();
        let scalar_tensor = Tensor::from_vec(vec![scalar; numel], shape).unwrap();
        let scalar_var = Variable::new(scalar_tensor, false);
        self.mul_var(&scalar_var)
    }

    /// Adds `scalar` to every element.
    #[must_use]
    pub fn add_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let shape = data.shape();
        let numel: usize = shape.iter().product();
        let scalar_tensor = Tensor::from_vec(vec![scalar; numel], shape).unwrap();
        let scalar_var = Variable::new(scalar_tensor, false);
        self.add_var(&scalar_var)
    }

    /// Subtracts `scalar` from every element.
    #[must_use]
    pub fn sub_scalar(&self, scalar: f32) -> Variable {
        self.add_scalar(-scalar)
    }

    /// Divides every element by `scalar`.
    #[must_use]
    pub fn div_scalar(&self, scalar: f32) -> Variable {
        self.mul_scalar(1.0 / scalar)
    }

    /// Applies the GELU activation.
    ///
    /// Note: no backward function is attached, so gradients do not flow
    /// through this operation.
    #[must_use]
    pub fn gelu(&self) -> Variable {
        let data = self.data();
        let result = data.gelu();
        Variable::new(result, self.requires_grad())
    }

    /// Applies the SiLU (sigmoid-weighted linear) activation.
    ///
    /// Note: no backward function is attached.
    #[must_use]
    pub fn silu(&self) -> Variable {
        let data = self.data();
        let result = data.silu();
        Variable::new(result, self.requires_grad())
    }

    /// Takes the element-wise square root.
    ///
    /// Note: no backward function is attached.
    #[must_use]
    pub fn sqrt(&self) -> Variable {
        let data = self.data();
        let result = data.sqrt();
        Variable::new(result, self.requires_grad())
    }

    /// Applies softmax along `dim`.
    ///
    /// Note: no backward function is attached.
    #[must_use]
    pub fn softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.softmax(dim);
        Variable::new(result, self.requires_grad())
    }

    /// Applies log-softmax along `dim`.
    ///
    /// Note: no backward function is attached.
    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.log_softmax(dim);
        Variable::new(result, self.requires_grad())
    }

    /// Averages along `dim`, optionally keeping the reduced dimension.
    ///
    /// Note: no backward function is attached.
    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let result = data.mean_dim(dim, keepdim);
        Variable::new(result, self.requires_grad())
    }

    /// Computes the variance along `dim`, optionally keeping the reduced
    /// dimension.
    ///
    /// Note: no backward function is attached.
    #[must_use]
    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let result = data.var_dim(dim, keepdim);
        Variable::new(result, self.requires_grad())
    }

    /// Creates a leaf variable from a tensor; alias for [`Variable::new`].
    #[must_use]
    pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
        Variable::new(data, requires_grad)
    }

    /// Returns a clone of this variable (shares data and gradient state).
    #[must_use]
    pub fn clone_var(&self) -> Variable {
        self.clone()
    }

    /// Alias for [`Variable::add_var`].
    #[must_use]
    pub fn add(&self, other: &Variable) -> Variable {
        self.add_var(other)
    }

    /// Alias for [`Variable::sub_var`].
    #[must_use]
    pub fn sub(&self, other: &Variable) -> Variable {
        self.sub_var(other)
    }

    /// Alias for [`Variable::mul_var`].
    #[must_use]
    pub fn mul(&self, other: &Variable) -> Variable {
        self.mul_var(other)
    }

    /// Alias for [`Variable::div_var`].
    #[must_use]
    pub fn div(&self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

// Operator overloads for references, delegating to the explicit `*_var`
// methods so that `&a + &b` builds the same graph as `a.add_var(&b)`.
impl Add for &Variable {
    type Output = Variable;

    fn add(self, other: &Variable) -> Variable {
        self.add_var(other)
    }
}

impl Sub for &Variable {
    type Output = Variable;

    fn sub(self, other: &Variable) -> Variable {
        self.sub_var(other)
    }
}

impl Mul for &Variable {
    type Output = Variable;

    fn mul(self, other: &Variable) -> Variable {
        self.mul_var(other)
    }
}

impl Div for &Variable {
    type Output = Variable;

    fn div(self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

impl Neg for &Variable {
    type Output = Variable;

    fn neg(self) -> Variable {
        self.neg_var()
    }
}

impl std::fmt::Debug for Variable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Variable")
            .field("shape", &self.shape())
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field(
                "grad_fn",
                &self.grad_fn.as_ref().map(super::grad_fn::GradFn::name),
            )
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::zeros;

    #[test]
    fn test_variable_creation() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::new(t, true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }

    #[test]
    fn test_variable_no_grad() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::from_tensor(t);
        assert!(!v.requires_grad());
    }

    #[test]
    fn test_variable_add() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(), true);
        let c = &a + &b;
        assert_eq!(c.data().to_vec(), vec![5.0, 7.0, 9.0]);
        assert!(c.requires_grad());
        assert!(!c.is_leaf());
    }
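
    // A gradient smoke test added for illustration: it assumes the backward
    // pass in `crate::backward` propagates d(x*x)/dx = 2x into the leaf, as
    // the `Variable` API above is designed to do.
    #[test]
    fn test_variable_mul_backward() {
        let x = Variable::new(Tensor::from_vec(vec![3.0], &[1]).unwrap(), true);
        let y = x.mul_var(&x); // y = x^2, a scalar
        y.backward();
        let grad = x.grad().expect("backward should populate the leaf gradient");
        assert_eq!(grad.to_vec(), vec![6.0]);
    }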

    #[test]
    fn test_variable_detach() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = a.detach();
        assert!(!b.requires_grad());
        assert!(b.is_leaf());
    }
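
    // An illustrative check of the scalar convenience ops; it assumes only
    // that element-wise tensor arithmetic behaves as in the tests above.
    #[test]
    fn test_scalar_ops() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = a.mul_scalar(2.0).add_scalar(1.0);
        assert_eq!(b.data().to_vec(), vec![3.0, 5.0, 7.0]);
        assert!(b.requires_grad());
    }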

    #[test]
    fn test_mse_loss() {
        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let target = Variable::from_tensor(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let loss = pred.mse_loss(&target);
        assert_eq!(loss.numel(), 1);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }
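
    // A shape-level sketch of the view ops; it assumes `Tensor::cat` and
    // `unsqueeze` follow the usual conventions (concatenation along `dim`,
    // insertion of a size-one axis).
    #[test]
    fn test_cat_and_unsqueeze_shapes() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), false);
        let c = Variable::cat(&[&a, &b], 0);
        assert_eq!(c.shape(), vec![4]);
        assert!(c.requires_grad());

        let d = a.unsqueeze(0);
        assert_eq!(d.shape(), vec![1, 2]);
    }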
}