tl_cuda 0.4.0

CUDA GPU tensor library for TL
1
2
3
4
5
6
7
8
9
10
11
12
13
14
//! CUDA Autograd

pub mod ops;

use crate::tensor::{CudaTensor, TensorRef};
use tl_backend::BackendResult;

/// Gradient-function trait for autograd (V5.0, `Arc`-based).
///
/// An implementor represents one node that can propagate a gradient
/// backwards and report which tensors it took as inputs. The
/// `Send + Sync` supertraits require every implementor to be safe to move
/// and share across threads.
pub trait GradFn: Send + Sync {
    /// Runs the backward computation.
    ///
    /// `grad_output` is the incoming gradient, borrowed for the duration of
    /// the call. On success the result is a `Vec<CudaTensor>` of gradients;
    /// failures are reported through `BackendResult`.
    ///
    /// NOTE(review): presumably the returned vector holds one gradient per
    /// tensor returned by `inputs()`, in the same order — confirm against
    /// the implementors in `ops`.
    fn backward(&self, grad_output: &CudaTensor) -> BackendResult<Vec<CudaTensor>>;
    /// Returns references to the input tensors
    /// (`TensorRef` = `Arc<UnsafeCell<CudaTensor>>`).
    fn inputs(&self) -> Vec<TensorRef>;
}