mod optim;
#[cfg(feature = "alloc")]
mod vec;
#[cfg(not(feature = "dyntensor"))]
pub use tensor_optim::ConstTensorOps;
pub use tensor_optim::TensorOps;
pub use self::optim::{Flatten, Tensor, TensorGrad};
#[cfg(feature = "alloc")]
pub use vec::VecTensor;
/// Extension trait that lets any [`TensorGrad`] value be wrapped in a [`WithGrad`]
/// by calling `.with_grad()` on it.
pub trait IntoWithGrad: Sized + TensorGrad {
    /// Consumes `self` and returns it wrapped by `WithGrad::new`, which pairs the
    /// value with a freshly created gradient buffer.
    fn with_grad(self) -> WithGrad<Self> {
        <WithGrad<Self>>::new(self)
    }
}

// Blanket implementation: every `TensorGrad` type gets `.with_grad()` for free.
impl<T> IntoWithGrad for T where T: TensorGrad {}
/// A value paired with a gradient buffer of the same type `T`.
///
/// The usual way to build one is [`WithGrad::new`], which initializes the gradient
/// with `value.zeros_like()`. Note that the derived `Default` instead fills *both*
/// fields with `T::default()`; whether that equals a "zeroed" gradient depends on `T`.
#[derive(Debug, Clone, Default)]
pub struct WithGrad<T> {
    /// The wrapped value.
    value: T,
    /// Gradient buffer associated with `value`.
    grad: T,
}

// Accessors only move references or replace a field, so they deliberately carry no
// trait bound on `T`: they work for every `WithGrad<T>`, including ones produced by
// the unbounded `Default`/`Clone` derives above.
impl<T> WithGrad<T> {
    /// Replaces the stored gradient with `grad`, dropping the previous one.
    ///
    /// No consistency check between `grad` and the stored value is performed here.
    pub fn set_grad(&mut self, grad: T) {
        self.grad = grad;
    }

    /// Borrows the value and the gradient at the same time, as `(value, grad)`.
    #[must_use]
    pub const fn split(&self) -> (&T, &T) {
        (&self.value, &self.grad)
    }

    /// Mutably borrows the value and the gradient at the same time, as
    /// `(value, grad)`; this is the disjoint-borrow way to update both at once.
    pub const fn split_mut(&mut self) -> (&mut T, &mut T) {
        (&mut self.value, &mut self.grad)
    }

    /// Returns a shared reference to the gradient.
    #[must_use]
    pub const fn get_grad(&self) -> &T {
        &self.grad
    }

    /// Returns a shared reference to the value.
    #[must_use]
    pub const fn get_value(&self) -> &T {
        &self.value
    }

    /// Returns a mutable reference to the gradient.
    pub const fn get_grad_mut(&mut self) -> &mut T {
        &mut self.grad
    }

    /// Returns a mutable reference to the value.
    pub const fn get_value_mut(&mut self) -> &mut T {
        &mut self.value
    }
}

// Construction is the only operation that needs `TensorGrad`, because it has to
// produce an initial gradient from the value.
impl<T: TensorGrad> WithGrad<T> {
    /// Wraps `value`, initializing the gradient with `value.zeros_like()`.
    #[must_use]
    pub fn new(value: T) -> Self {
        let grad = value.zeros_like();
        Self { value, grad }
    }
}