pub struct Tape<'d> { /* private fields */ }
Autograd tape. Records forward operations and runs the backward pass to compute gradients.
Implementations
impl<'d> Tape<'d>
pub fn new(dev: &'d GpuDevice) -> Self
pub fn leaf(&mut self, data: &[f32]) -> TensorId
Registers a leaf tensor (a parameter or input data). A leaf has no backward operation, so gradient propagation does not continue through it.
pub fn read_grad(&self, id: TensorId) -> Result<Option<Vec<f32>>>
Reads the gradient data for `id` back to the CPU. Returns `None` if no gradient has been computed for it.
pub fn add(&mut self, a: TensorId, b: TensorId) -> Result<TensorId>
pub fn sub(&mut self, a: TensorId, b: TensorId) -> Result<TensorId>
pub fn mul(&mut self, a: TensorId, b: TensorId) -> Result<TensorId>
pub fn scale(&mut self, a: TensorId, s: f32) -> Result<TensorId>
pub fn relu(&mut self, a: TensorId) -> Result<TensorId>
pub fn sigmoid(&mut self, a: TensorId) -> Result<TensorId>
pub fn swish(&mut self, a: TensorId) -> Result<TensorId>
pub fn tanh_act(&mut self, a: TensorId) -> Result<TensorId>
pub fn matmul( &mut self, a: TensorId, b: TensorId, m: u32, n: u32, k: u32, ) -> Result<TensorId>
pub fn mse_loss(&mut self, pred: TensorId, target: TensorId) -> Result<TensorId>
pub fn conv2d( &mut self, input: TensorId, weight: TensorId, bias: Option<TensorId>, batch: u32, in_c: u32, in_h: u32, in_w: u32, out_c: u32, kh: u32, kw: u32, stride: (u32, u32), padding: (u32, u32), dilation: (u32, u32), groups: u32, ) -> Result<TensorId>
Auto Trait Implementations
impl<'d> Freeze for Tape<'d>
impl<'d> !RefUnwindSafe for Tape<'d>
impl<'d> Send for Tape<'d>
impl<'d> Sync for Tape<'d>
impl<'d> Unpin for Tape<'d>
impl<'d> UnsafeUnpin for Tape<'d>
impl<'d> !UnwindSafe for Tape<'d>
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.