use super::GradFn;
use crate::runtime::Runtime;
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;
/// A [`Tensor`] paired with the metadata the autograd machinery needs:
/// an identifier, a gradient-tracking flag and an optional backward function.
pub struct Var<R: Runtime> {
    /// Identifier of this variable. Set by the constructors: either taken from
    /// the tensor, supplied by the caller, or newly generated.
    id: TensorId,
    /// The wrapped tensor value.
    tensor: Tensor<R>,
    /// Whether gradients should be tracked for this variable.
    requires_grad: bool,
    /// Backward function of the operation that produced this variable, if any.
    /// `None` for variables built via `new` / `with_id` or after `detach`.
    grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> Var<R> {
    /// Wraps `tensor` as a variable that reuses the tensor's own id
    /// (`tensor.id()`) and has no backward function attached.
    pub fn new(tensor: Tensor<R>, requires_grad: bool) -> Self {
        // Read the id before `tensor` is moved into the delegate constructor.
        let tensor_id = tensor.id();
        Self::with_id(tensor, tensor_id, requires_grad)
    }

    /// Wraps `tensor` under the caller-supplied `id`, with no backward
    /// function attached.
    pub fn with_id(tensor: Tensor<R>, id: TensorId, requires_grad: bool) -> Self {
        Self {
            tensor,
            id,
            requires_grad,
            grad_fn: None,
        }
    }

    /// Wraps `tensor` under the caller-supplied `id` with an optional
    /// backward function.
    ///
    /// The result always has `requires_grad() == true`, even when `grad_fn`
    /// is `None`.
    pub fn with_id_and_grad_fn(
        tensor: Tensor<R>,
        id: TensorId,
        grad_fn: Option<Arc<dyn GradFn<R>>>,
    ) -> Self {
        Self {
            grad_fn,
            requires_grad: true,
            tensor,
            id,
        }
    }

    /// Builds the output variable of an operation.
    ///
    /// Unlike [`Var::new`], the id is newly generated with `TensorId::new()`
    /// rather than taken from `tensor.id()`. The variable requires gradients
    /// and stores `grad_fn`.
    pub fn from_op(tensor: Tensor<R>, grad_fn: Arc<dyn GradFn<R>>) -> Self {
        Self::with_id_and_grad_fn(tensor, TensorId::new(), Some(grad_fn))
    }

    /// Identifier of this variable.
    #[inline]
    pub fn id(&self) -> TensorId {
        self.id
    }

    /// Borrow of the wrapped tensor.
    #[inline]
    pub fn tensor(&self) -> &Tensor<R> {
        &self.tensor
    }

    /// Whether gradient tracking is enabled for this variable.
    #[inline]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Backward function of the producing operation, if one is attached.
    #[inline]
    pub fn grad_fn(&self) -> Option<&Arc<dyn GradFn<R>>> {
        Option::as_ref(&self.grad_fn)
    }

    /// Returns a copy of this variable with gradient tracking removed.
    ///
    /// The tensor is cloned (whether storage is shared depends on
    /// `Tensor::clone`), a new id is generated with `TensorId::new()`,
    /// `requires_grad` is `false`, and no backward function is kept.
    pub fn detach(&self) -> Self {
        Self::with_id(self.tensor.clone(), TensorId::new(), false)
    }

    /// Enables or disables gradient tracking.
    ///
    /// Disabling also drops any stored backward function; re-enabling does
    /// not restore it.
    pub fn set_requires_grad(&mut self, requires_grad: bool) {
        self.requires_grad = requires_grad;
        if requires_grad {
            return;
        }
        self.grad_fn = None;
    }

    /// Shape of the wrapped tensor.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        self.tensor().shape()
    }

    /// Total element count of the wrapped tensor.
    #[inline]
    pub fn numel(&self) -> usize {
        self.tensor().numel()
    }

    /// Number of dimensions of the wrapped tensor.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.tensor().ndim()
    }
}
/// Cloning yields another handle to the *same* autograd variable: same
/// tensor (via `Tensor::clone`), same id, same `requires_grad` flag and a
/// shared backward function.
///
/// Use [`Var::detach`] when a variable with a new identity is wanted.
///
/// Written by hand instead of `#[derive(Clone)]` so that `R` itself is not
/// required to implement `Clone`.
impl<R: Runtime> Clone for Var<R> {
    fn clone(&self) -> Self {
        Self {
            tensor: self.tensor.clone(),
            // Keep the identity. Minting a new `TensorId` here, while keeping
            // the same `grad_fn`, makes a clone that is neither a faithful
            // copy nor a detach. Assuming gradients are keyed by `TensorId`
            // (as `id()` suggests), a cloned leaf (no `grad_fn` linking back)
            // would receive gradients under an unrelated id, and the original
            // variable would never see them.
            id: self.id,
            requires_grad: self.requires_grad,
            grad_fn: self.grad_fn.clone(),
        }
    }
}
impl<R: Runtime> std::fmt::Debug for Var<R> {
    /// Prints a compact summary: id, shape, the gradient-tracking flag and
    /// whether a backward function is attached. Tensor contents are not
    /// printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_grad_fn = self.grad_fn.is_some();
        let mut builder = f.debug_struct("Var");
        builder.field("id", &self.id);
        builder.field("shape", &self.tensor.shape());
        builder.field("requires_grad", &self.requires_grad);
        builder.field("has_grad_fn", &has_grad_fn);
        builder.finish()
    }
}