Trait burn_tensor::ops::ModuleOps
source · pub trait ModuleOps<B: Backend> {
fn embedding(
weights: &B::TensorPrimitive<2>,
indexes: &<B::IntegerBackend as Backend>::TensorPrimitive<2>
) -> B::TensorPrimitive<3>;
fn embedding_backward(
weights: &B::TensorPrimitive<2>,
output: &B::TensorPrimitive<3>,
indexes: &<B::IntegerBackend as Backend>::TensorPrimitive<2>
) -> B::TensorPrimitive<2>;
}