mod optim;
#[cfg(feature = "alloc")]
mod vec;
use core::marker::PhantomData;
#[cfg(not(feature = "dyntensor"))]
pub use tensor_optim::ConstTensorOps;
pub use tensor_optim::TensorOps;
use crate::nn::TensorFloat;
pub use self::optim::{Flatten, StaticShape, Tensor, TensorGrad};
#[cfg(feature = "alloc")]
pub use vec::VecTensor;
pub trait IntoWithGrad<T>: TensorGrad<T> + Sized {
fn with_grad(self) -> WithGrad<Self, T> {
WithGrad::new(self)
}
fn grad_of(self, grad: Self) -> WithGrad<Self, T> {
let mut w = WithGrad::new(self);
w.set_grad(grad);
w
}
}
/// Every type implementing [`TensorGrad`] gets the [`IntoWithGrad`]
/// convenience methods through their default implementations.
impl<X, E> IntoWithGrad<E> for X
where
    X: TensorGrad<E>,
{
}
/// A value of type `T` paired with a gradient of the same type.
///
/// `U` is only carried as a type-level marker (it is the parameter passed to
/// [`TensorGrad`]) and defaults to [`TensorFloat`].
///
/// The derived impls require `U` to implement the respective trait as well
/// (e.g. `Eq` needs `U: Eq`), even though no `U` is stored. Derived
/// `PartialOrd`/`Ord` compare `value` first, then `grad`.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct WithGrad<T, U = TensorFloat> {
    /// The wrapped value.
    value: T,
    /// The gradient associated with `value`.
    grad: T,
    /// Binds the otherwise-unused `U` parameter to this type.
    _marker: PhantomData<U>,
}
impl<T: TensorGrad<U>, U> WithGrad<T, U> {
    /// Wraps `value`, initialising the gradient with `value.zeros_like()`.
    #[must_use]
    pub fn new(value: T) -> Self {
        let grad = value.zeros_like();
        Self {
            value,
            grad,
            _marker: PhantomData,
        }
    }
}

// The accessors below need nothing from `TensorGrad`, so they are provided for
// every `T`, not just gradient-capable tensors (e.g. values obtained through the
// derived `Default` impl).
impl<T, U> WithGrad<T, U> {
    /// Replaces the stored gradient with `grad`.
    pub fn set_grad(&mut self, grad: T) {
        self.grad = grad;
    }

    /// Borrows `(value, grad)` at the same time.
    pub const fn split(&self) -> (&T, &T) {
        (&self.value, &self.grad)
    }

    /// Mutably borrows `(value, grad)` at the same time.
    pub const fn split_mut(&mut self) -> (&mut T, &mut T) {
        (&mut self.value, &mut self.grad)
    }

    /// Borrows the gradient.
    pub const fn get_grad(&self) -> &T {
        &self.grad
    }

    /// Borrows the value.
    pub const fn get_value(&self) -> &T {
        &self.value
    }

    /// Mutably borrows the gradient.
    pub const fn get_grad_mut(&mut self) -> &mut T {
        &mut self.grad
    }

    /// Mutably borrows the value.
    pub const fn get_value_mut(&mut self) -> &mut T {
        &mut self.value
    }

    /// Consumes the wrapper and returns the value, dropping the gradient.
    #[must_use]
    pub fn into_value(self) -> T {
        self.value
    }

    /// Consumes the wrapper and returns the gradient, dropping the value.
    #[must_use]
    pub fn into_grad(self) -> T {
        self.grad
    }
}