use std::cell::RefCell;
use std::fmt;
use std::rc::Rc;

use crate::tensor::{DType, Device, Result, Tensor};
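
/// Interior state shared by every clone of a `Variable`: the wrapped tensor.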
pub(crate) struct VariableInner {
    pub data: Tensor,
}
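
/// A shared, mutable handle around a [`Tensor`]. Cloning a `Variable` clones
/// the `Rc` handle, so every clone observes the same underlying tensor.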
#[derive(Clone)]
pub struct Variable {
    pub(crate) inner: Rc<RefCell<VariableInner>>,
}

impl Variable {
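    /// Creates a new `Variable`, enabling gradient tracking on `data` when
    /// `requires_grad` is true.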
    pub fn new(data: Tensor, requires_grad: bool) -> Self {
        let data = if requires_grad {
            // If enabling gradients fails, fall back to the original tensor.
            data.set_requires_grad(true).unwrap_or(data)
        } else {
            data
        };
        Variable {
            inner: Rc::new(RefCell::new(VariableInner { data })),
        }
    }
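
    /// Wraps an existing tensor without changing its `requires_grad` state.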
    pub(crate) fn wrap(data: Tensor) -> Self {
        Variable {
            inner: Rc::new(RefCell::new(VariableInner { data })),
        }
    }
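
    /// Returns a clone of the wrapped tensor.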
    pub fn data(&self) -> Tensor {
        self.inner.borrow().data.clone()
    }
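
    /// Returns the accumulated gradient, if one exists.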
    pub fn grad(&self) -> Option<Tensor> {
        self.inner.borrow().data.grad()
    }
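
    /// Stores `grad` as the gradient of the wrapped tensor.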
    pub fn set_grad(&self, grad: Tensor) {
        // Any error reported by the underlying tensor is discarded.
        let _ = self.inner.borrow().data.set_grad(&grad);
    }
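
    /// Whether the wrapped tensor is tracking gradients.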
    pub fn requires_grad(&self) -> bool {
        self.inner.borrow().data.requires_grad()
    }
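
    /// Enables or disables gradient tracking, replacing the stored tensor
    /// with the reconfigured one.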
    pub fn set_requires_grad(&self, requires_grad: bool) -> Result<()> {
        // The shared borrow taken on the first line is dropped at the end of
        // that statement, so the `borrow_mut` below cannot panic.
        let data = self.inner.borrow().data.set_requires_grad(requires_grad)?;
        self.inner.borrow_mut().data = data;
        Ok(())
    }
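
    /// Whether the wrapped tensor is a leaf node of the autograd graph.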
    pub fn is_leaf(&self) -> bool {
        self.inner.borrow().data.is_leaf()
    }
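
    /// Delegates to the underlying tensor's `ensure_grad_accumulator`.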
    pub fn ensure_grad_accumulator(&self) -> Result<Option<crate::tensor::GradAccumulatorHandle>> {
        self.inner.borrow().data.ensure_grad_accumulator()
    }
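
    /// Reports the underlying tensor's autograd node count.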
    pub fn autograd_node_count(&self) -> i64 {
        self.inner.borrow().data.autograd_node_count()
    }
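
    /// The shape of the wrapped tensor.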
    pub fn shape(&self) -> Vec<i64> {
        self.inner.borrow().data.shape()
    }
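
    /// The element type of the wrapped tensor.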
    pub fn dtype(&self) -> DType {
        self.inner.borrow().data.dtype()
    }
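
    /// The device on which the wrapped tensor lives.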
    pub fn device(&self) -> Device {
        self.inner.borrow().data.device()
    }
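
    /// Returns the value of a single-element tensor as an `f64`.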
    pub fn item(&self) -> Result<f64> {
        self.inner.borrow().data.item()
    }
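
    /// Zeroes the accumulated gradient.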
    pub fn zero_grad(&self) {
        // Any error reported by the underlying tensor is discarded.
        let _ = self.inner.borrow().data.zero_grad();
    }
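
    /// Drops the accumulated gradient entirely instead of zeroing it in place.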
    pub fn zero_grad_set_to_none(&self) {
        self.inner.borrow().data.zero_grad_set_to_none();
    }
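
    /// Returns a new `Variable` holding a tensor detached from the autograd graph.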
    pub fn detach(&self) -> Variable {
        // Fall back to cloning the tensor handle if detaching fails. Both
        // borrows are shared `borrow()`s, so re-borrowing inside the closure
        // while the outer temporary is still alive cannot panic.
        let detached = self.inner.borrow().data.detach()
            .unwrap_or_else(|_| self.inner.borrow().data.clone());
        Variable::wrap(detached)
    }
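
    /// Moves the variable to `device`, returning a clone of `self` when the
    /// tensor already lives there.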
    pub fn to_device(&self, device: Device) -> Result<Variable> {
        // Fast path: nothing to do when the tensor is already on the target device.
        if self.device() == device {
            return Ok(self.clone());
        }
        let moved = self.inner.borrow().data.to_device(device)?;
        Ok(Variable::new(moved, self.requires_grad()))
    }
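
    /// Replaces the wrapped tensor while preserving the current `requires_grad` setting.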
    pub fn set_data(&self, data: Tensor) {
        // Carry the existing requires_grad flag over to the replacement tensor;
        // if enabling gradients fails, fall back to the tensor as given.
        let rg = self.requires_grad();
        let data = if rg {
            data.set_requires_grad(true).unwrap_or(data)
        } else {
            data
        };
        self.inner.borrow_mut().data = data;
    }
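
    /// Total number of elements in the wrapped tensor.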
    pub fn numel(&self) -> i64 {
        self.inner.borrow().data.numel()
    }
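
    /// Runs backpropagation from this variable, then detaches its tensor in place.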
    pub fn backward(&self) -> Result<()> {
        let inner = self.inner.borrow();
        inner.data.backward()?;
        // After the gradients have been computed, detach in place so this
        // tensor no longer holds on to the recorded autograd graph.
        inner.data.detach_()?;
        Ok(())
    }
}

impl fmt::Debug for Variable {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner = self.inner.borrow();
        write!(
            f,
            "Variable({:?}, {:?}, {:?}, requires_grad={})",
            inner.data.shape(),
            inner.data.dtype(),
            inner.data.device(),
            inner.data.requires_grad(),
        )
    }
}