use std::ops::{Add, Div, Mul, Neg, Sub};
use std::sync::Arc;

use parking_lot::RwLock;

use axonml_tensor::Tensor;

use crate::functions::{
    AddBackward, DivBackward, MatMulBackward, MeanBackward, MulBackward, NegBackward, PowBackward,
    ReluBackward, ReshapeBackward, SigmoidBackward, SubBackward, SumBackward, TanhBackward,
    TransposeBackward,
};
use crate::grad_fn::{AccumulateGrad, GradAccumulator, GradFn};
use crate::graph::{with_graph, GraphNode};
use crate::no_grad::is_grad_enabled;

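/// A differentiable wrapper around a `Tensor<f32>`.
///
/// A `Variable` holds the tensor data, an optional accumulated gradient,
/// the `requires_grad` / `is_leaf` flags, the backward function (`grad_fn`)
/// that produced it, and its node in the autograd graph.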
#[derive(Clone)]
pub struct Variable {
    data: Arc<RwLock<Tensor<f32>>>,
    grad: GradAccumulator,
    requires_grad: bool,
    is_leaf: bool,
    grad_fn: Option<GradFn>,
    node: Option<Arc<GraphNode>>,
}

impl Variable {
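    /// Creates a new leaf `Variable` from a tensor.
    ///
    /// When `requires_grad` is true, the variable is registered as a leaf in
    /// the autograd graph and given an `AccumulateGrad` function so that
    /// gradients flowing into it are stored in `grad`.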
    #[must_use]
    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
        let grad: GradAccumulator = Arc::new(RwLock::new(None));

        let node = if requires_grad {
            Some(with_graph(|g| g.register_leaf(true)))
        } else {
            None
        };

        let grad_fn = if requires_grad {
            Some(GradFn::new(AccumulateGrad::new(Arc::clone(&grad))))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad,
            requires_grad,
            is_leaf: true,
            grad_fn,
            node,
        }
    }

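    /// Creates a `Variable` that does not require gradients.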
    #[must_use]
    pub fn from_tensor(data: Tensor<f32>) -> Self {
        Self::new(data, false)
    }

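    /// Creates a non-leaf `Variable` produced by an operation.
    ///
    /// The result is registered in the autograd graph with the given
    /// `grad_fn` when `requires_grad` is true.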
    fn from_operation(data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool) -> Self {
        let node = if requires_grad {
            Some(with_graph(|g| g.register_operation(grad_fn.clone(), true)))
        } else {
            None
        };

        Self {
            data: Arc::new(RwLock::new(data)),
            grad: Arc::new(RwLock::new(None)),
            requires_grad,
            is_leaf: false,
            grad_fn: if requires_grad { Some(grad_fn) } else { None },
            node,
        }
    }

    #[must_use]
    pub fn data(&self) -> Tensor<f32> {
        self.data.read().clone()
    }

    #[must_use]
    pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape().to_vec()
    }

    #[must_use]
    pub fn ndim(&self) -> usize {
        self.data.read().ndim()
    }

    #[must_use]
    pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    #[must_use]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    #[must_use]
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    #[must_use]
    pub fn grad(&self) -> Option<Tensor<f32>> {
        self.grad.read().clone()
    }

    #[must_use]
    pub fn grad_fn(&self) -> Option<&GradFn> {
        self.grad_fn.as_ref()
    }

    pub fn set_grad(&self, grad: Tensor<f32>) {
        *self.grad.write() = Some(grad);
    }

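    /// Adds `grad` to the stored gradient, initializing it on first use.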
    pub fn accumulate_grad(&self, grad: &Tensor<f32>) {
        let mut grad_lock = self.grad.write();
        if let Some(ref existing) = *grad_lock {
            *grad_lock = Some(existing.add(grad).unwrap());
        } else {
            *grad_lock = Some(grad.clone());
        }
    }

    pub fn zero_grad(&self) {
        *self.grad.write() = None;
    }

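    /// Returns a copy of this variable that is detached from the autograd
    /// graph: it shares no state with `self` and does not require gradients.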
    #[must_use]
    pub fn detach(&self) -> Self {
        Self {
            data: Arc::new(RwLock::new(self.data.read().clone())),
            grad: Arc::new(RwLock::new(None)),
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            node: None,
        }
    }

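    /// Sets the `requires_grad` flag, consuming and returning `self`.
    ///
    /// Newly enabling gradients on a leaf registers it in the autograd graph
    /// and attaches an `AccumulateGrad` function. Disabling gradients does
    /// not remove an existing `grad_fn`.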
    #[must_use]
    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
        self.requires_grad = requires_grad;
        if requires_grad && self.is_leaf {
            self.grad_fn = Some(GradFn::new(AccumulateGrad::new(Arc::clone(&self.grad))));
            self.node = Some(with_graph(|g| g.register_leaf(true)));
        }
        self
    }

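    /// Runs backpropagation from this variable, seeding it with a gradient of 1.0.
    ///
    /// # Panics
    ///
    /// Panics if the variable does not require gradients or is not a scalar
    /// (i.e. `numel() != 1`).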
    pub fn backward(&self) {
        assert!(
            self.requires_grad,
            "Cannot call backward on a variable that doesn't require gradients"
        );
        assert!(
            self.numel() == 1,
            "backward() can only be called on scalar tensors"
        );

        let grad_output = Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap();
        crate::backward::backward(self, &grad_output);
    }

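    /// Element-wise addition. Like the other binary operations below
    /// (`sub_var`, `mul_var`, `div_var`), the result records a backward
    /// function only when one of the operands requires gradients and gradient
    /// tracking is enabled.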
    #[must_use]
    pub fn add_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().add(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(AddBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sub_var(&self, other: &Variable) -> Variable {
        let result = self.data.read().sub(&other.data.read()).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SubBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self.shape(),
                other.shape(),
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mul_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.mul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn div_var(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.div(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DivBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn neg_var(&self) -> Variable {
        let result = self.data.read().neg();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(NegBackward::new(self.grad_fn.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

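    /// Matrix multiplication; both operands are saved for the backward pass.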
    #[must_use]
    pub fn matmul(&self, other: &Variable) -> Variable {
        let self_data = self.data.read().clone();
        let other_data = other.data.read().clone();
        let result = self_data.matmul(&other_data).unwrap();
        let requires_grad = (self.requires_grad || other.requires_grad) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MatMulBackward::new(
                self.grad_fn.clone(),
                other.grad_fn.clone(),
                self_data,
                other_data,
            ));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn pow(&self, exponent: f32) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.pow(exponent);
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(PowBackward::new(self.grad_fn.clone(), self_data, exponent));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn relu(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.relu();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReluBackward::new(self.grad_fn.clone(), self_data));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn sigmoid(&self) -> Variable {
        let result = self.data.read().sigmoid();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SigmoidBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn tanh(&self) -> Variable {
        let result = self.data.read().tanh();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TanhBackward::new(self.grad_fn.clone(), result.clone()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

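    /// Sums all elements into a scalar; the input shape is saved for the
    /// backward pass. `mean` below follows the same pattern.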
    #[must_use]
    pub fn sum(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.sum();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(SumBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

    #[must_use]
    pub fn mean(&self) -> Variable {
        let self_data = self.data.read().clone();
        let result = self_data.mean().unwrap();
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(MeanBackward::new(self.grad_fn.clone(), self.shape()));
            Variable::from_operation(result, grad_fn, true)
        } else {
            Variable::from_tensor(result)
        }
    }

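    /// Mean squared error loss: `mean((self - target)^2)`.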
    #[must_use]
    pub fn mse_loss(&self, target: &Variable) -> Variable {
        let diff = self.sub_var(target);
        let squared = diff.pow(2.0);
        squared.mean()
    }

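    /// Binary cross-entropy loss with an epsilon of `1e-7` for numerical
    /// stability.
    ///
    /// Note: the logarithms are computed on detached tensors
    /// (`Variable::from_tensor(...ln())`), so gradients do not flow through
    /// the `ln` terms of this loss.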
    #[must_use]
    pub fn binary_cross_entropy(&self, target: &Variable) -> Variable {
        let eps = Variable::from_tensor(Tensor::scalar(1e-7));
        let one = Variable::from_tensor(Tensor::scalar(1.0));

        let log_p = self.add_var(&eps);
        let log_1_p = one.sub_var(self).add_var(&eps);

        let term1 = target.mul_var(&Variable::from_tensor(log_p.data().ln()));
        let term2 = one
            .sub_var(target)
            .mul_var(&Variable::from_tensor(log_1_p.data().ln()));

        term1.add_var(&term2).neg_var().mean()
    }

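    /// Reshapes the variable, recording the original shape for the backward
    /// pass. If the underlying reshape fails, the data is returned unchanged.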
    #[must_use]
    pub fn reshape(&self, shape: &[usize]) -> Variable {
        let isize_shape: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
        let original_shape = self.shape();
        let new_data = self
            .data()
            .reshape(&isize_shape)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(ReshapeBackward::new(self.grad_fn.clone(), original_shape));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

    #[must_use]
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable {
        let new_data = self
            .data()
            .transpose(dim0 as i64, dim1 as i64)
            .unwrap_or_else(|_| self.data());
        let requires_grad = self.requires_grad && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(TransposeBackward::new(self.grad_fn.clone(), dim0, dim1));
            Variable::from_operation(new_data, grad_fn, true)
        } else {
            Variable::from_tensor(new_data)
        }
    }

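    /// Slices the data along the given ranges. Like `expand`, this returns a
    /// new leaf variable, so gradients are not propagated back through this
    /// operation.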
    #[must_use]
    pub fn slice(&self, ranges: &[std::ops::Range<usize>]) -> Variable {
        let new_data = self.data().slice(ranges);
        Variable::new(new_data, self.requires_grad())
    }

    #[must_use]
    pub fn expand(&self, shape: &[usize]) -> Variable {
        let new_data = self.data().broadcast_to(shape);
        Variable::new(new_data, self.requires_grad())
    }

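    /// Multiplies by a scalar by broadcasting it to a full tensor of the same
    /// shape and reusing `mul_var`, so gradients are tracked as usual.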
    #[must_use]
    pub fn mul_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let shape = data.shape();
        let numel: usize = shape.iter().product();
        let scalar_tensor = Tensor::from_vec(vec![scalar; numel], shape).unwrap();
        let scalar_var = Variable::new(scalar_tensor, false);
        self.mul_var(&scalar_var)
    }

    #[must_use]
    pub fn add_scalar(&self, scalar: f32) -> Variable {
        let data = self.data();
        let shape = data.shape();
        let numel: usize = shape.iter().product();
        let scalar_tensor = Tensor::from_vec(vec![scalar; numel], shape).unwrap();
        let scalar_var = Variable::new(scalar_tensor, false);
        self.add_var(&scalar_var)
    }

    #[must_use]
    pub fn sub_scalar(&self, scalar: f32) -> Variable {
        self.add_scalar(-scalar)
    }

    #[must_use]
    pub fn div_scalar(&self, scalar: f32) -> Variable {
        self.mul_scalar(1.0 / scalar)
    }

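    /// Applies the GELU activation element-wise.
    ///
    /// This and the remaining ops below (`silu`, `sqrt`, `softmax`,
    /// `log_softmax`, `mean_dim`, `var_dim`) are forward-only: each returns a
    /// new leaf variable, so gradients do not propagate back through them.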
    #[must_use]
    pub fn gelu(&self) -> Variable {
        let data = self.data();
        let result = data.gelu();
        Variable::new(result, self.requires_grad())
    }

    #[must_use]
    pub fn silu(&self) -> Variable {
        let data = self.data();
        let result = data.silu();
        Variable::new(result, self.requires_grad())
    }

    #[must_use]
    pub fn sqrt(&self) -> Variable {
        let data = self.data();
        let result = data.sqrt();
        Variable::new(result, self.requires_grad())
    }

    #[must_use]
    pub fn softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.softmax(dim);
        Variable::new(result, self.requires_grad())
    }

    #[must_use]
    pub fn log_softmax(&self, dim: i32) -> Variable {
        let data = self.data();
        let result = data.log_softmax(dim);
        Variable::new(result, self.requires_grad())
    }

    #[must_use]
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let result = data.mean_dim(dim, keepdim);
        Variable::new(result, self.requires_grad())
    }

    #[must_use]
    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable {
        let data = self.data();
        let result = data.var_dim(dim, keepdim);
        Variable::new(result, self.requires_grad())
    }

    #[must_use]
    pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable {
        Variable::new(data, requires_grad)
    }

    #[must_use]
    pub fn clone_var(&self) -> Variable {
        self.clone()
    }

    #[must_use]
    pub fn add(&self, other: &Variable) -> Variable {
        self.add_var(other)
    }

    #[must_use]
    pub fn sub(&self, other: &Variable) -> Variable {
        self.sub_var(other)
    }

    #[must_use]
    pub fn mul(&self, other: &Variable) -> Variable {
        self.mul_var(other)
    }

    #[must_use]
    pub fn div(&self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

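// Operator overloads so that `&a + &b`, `&a - &b`, `&a * &b`, `&a / &b`, and
// `-&a` delegate to the corresponding *_var methods.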
impl Add for &Variable {
    type Output = Variable;

    fn add(self, other: &Variable) -> Variable {
        self.add_var(other)
    }
}

impl Sub for &Variable {
    type Output = Variable;

    fn sub(self, other: &Variable) -> Variable {
        self.sub_var(other)
    }
}

impl Mul for &Variable {
    type Output = Variable;

    fn mul(self, other: &Variable) -> Variable {
        self.mul_var(other)
    }
}

impl Div for &Variable {
    type Output = Variable;

    fn div(self, other: &Variable) -> Variable {
        self.div_var(other)
    }
}

impl Neg for &Variable {
    type Output = Variable;

    fn neg(self) -> Variable {
        self.neg_var()
    }
}

impl std::fmt::Debug for Variable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Variable")
            .field("shape", &self.shape())
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field("grad_fn", &self.grad_fn.as_ref().map(GradFn::name))
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::zeros;

    #[test]
    fn test_variable_creation() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::new(t, true);
        assert!(v.requires_grad());
        assert!(v.is_leaf());
        assert_eq!(v.shape(), vec![2, 3]);
    }

    #[test]
    fn test_variable_no_grad() {
        let t = zeros::<f32>(&[2, 3]);
        let v = Variable::from_tensor(t);
        assert!(!v.requires_grad());
    }

    #[test]
    fn test_variable_add() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = Variable::new(Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(), true);
        let c = &a + &b;
        assert_eq!(c.data().to_vec(), vec![5.0, 7.0, 9.0]);
        assert!(c.requires_grad());
        assert!(!c.is_leaf());
    }

    #[test]
    fn test_variable_detach() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = a.detach();
        assert!(!b.requires_grad());
        assert!(b.is_leaf());
    }

    #[test]
    fn test_mse_loss() {
        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let target = Variable::from_tensor(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let loss = pred.mse_loss(&target);
        assert_eq!(loss.numel(), 1);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }
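
    // A small additional check of the scalar helpers: a sketch that assumes
    // the same element-wise tensor semantics exercised by test_variable_add.
    #[test]
    fn test_variable_mul_scalar() {
        let a = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let b = a.mul_scalar(2.0);
        assert_eq!(b.data().to_vec(), vec![2.0, 4.0, 6.0]);
        assert!(b.requires_grad());
    }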
}