1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use crate::{backend::Backend, Int, Tensor};
/// Looks up rows of `weights` selected by `indexes` via the backend's `embedding` op.
///
/// Both inputs are consumed. Their primitives are handed to [`Backend::embedding`]
/// unchanged, and the backend's rank-3 result is wrapped back into a [`Tensor`].
///
/// NOTE(review): presumably `weights` is `[num_embeddings, embedding_dim]` and
/// `indexes` is `[batch, seq]`, but the layout is defined by the backend.
/// TODO confirm against the backend contract.
pub fn embedding<B>(weights: Tensor<B, 2>, indexes: Tensor<B, 2, Int>) -> Tensor<B, 3>
where
    B: Backend,
{
    let embedded = B::embedding(weights.primitive, indexes.primitive);
    Tensor::new(embedded)
}
/// Applies a 1D convolution by forwarding to the backend's `conv1d` op.
///
/// `x` and `weight` are rank-3 tensors and `bias` is an optional rank-1 tensor.
/// All of them are consumed. `stride` and `padding` are passed through to
/// [`Backend::conv1d`] unchanged, and the rank-3 result is wrapped into a [`Tensor`].
///
/// NOTE(review): the axis layout (e.g. `[batch, channels, length]` for `x`) is
/// decided by the backend. TODO confirm.
pub fn conv1d<B>(
    x: Tensor<B, 3>,
    weight: Tensor<B, 3>,
    bias: Option<Tensor<B, 1>>,
    stride: usize,
    padding: usize,
) -> Tensor<B, 3>
where
    B: Backend,
{
    // Unwrap the optional bias to its backend primitive, keeping `None` as-is.
    let bias_primitive = bias.map(|tensor| tensor.primitive);
    let convolved = B::conv1d(
        x.primitive,
        weight.primitive,
        bias_primitive,
        stride,
        padding,
    );
    Tensor::new(convolved)
}
/// Applies a 2D convolution by forwarding to the backend's `conv2d` op.
///
/// `x` and `weight` are rank-4 tensors and `bias` is an optional rank-1 tensor.
/// All of them are consumed. The two-element `stride` and `padding` arrays are
/// passed through to [`Backend::conv2d`] unchanged, and the rank-4 result is
/// wrapped into a [`Tensor`].
///
/// NOTE(review): presumably `stride`/`padding` are ordered `[height, width]`,
/// and `x` is `[batch, channels, height, width]`. TODO confirm against the backend.
pub fn conv2d<B>(
    x: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    bias: Option<Tensor<B, 1>>,
    stride: [usize; 2],
    padding: [usize; 2],
) -> Tensor<B, 4>
where
    B: Backend,
{
    // Unwrap the optional bias to its backend primitive, keeping `None` as-is.
    let bias_primitive = bias.map(|tensor| tensor.primitive);
    let convolved = B::conv2d(
        x.primitive,
        weight.primitive,
        bias_primitive,
        stride,
        padding,
    );
    Tensor::new(convolved)
}
/// Applies 2D max pooling by forwarding to the backend's `max_pool2d` op.
///
/// `x` is a rank-4 tensor and is consumed. `kernel_size`, `stride` and `padding`
/// are two-element arrays passed through to [`Backend::max_pool2d`] unchanged.
/// The rank-4 result is wrapped into a [`Tensor`].
///
/// NOTE(review): the element order of the arrays (presumably `[height, width]`)
/// is defined by the backend. TODO confirm.
pub fn max_pool2d<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
) -> Tensor<B, 4>
where
    B: Backend,
{
    let pooled = B::max_pool2d(x.primitive, kernel_size, stride, padding);
    Tensor::new(pooled)
}
/// Applies 2D max pooling and also returns the indexes reported by the backend.
///
/// `x` is a rank-4 tensor and is consumed. `kernel_size`, `stride` and `padding`
/// are passed through to [`Backend::max_pool2d_with_indexes`] unchanged.
///
/// Returns `(values, indexes)`:
/// - `values` is the backend result's `output` field, wrapped as a rank-4 float tensor.
/// - `indexes` is the backend result's `indexes` field, wrapped as a rank-4 [`Int`] tensor.
///
/// NOTE(review): what the indexes refer to (e.g. flat positions within each input
/// plane) is defined by the backend. TODO confirm.
pub fn max_pool2d_with_indexes<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
where
    B: Backend,
{
    let pooled = B::max_pool2d_with_indexes(x.primitive, kernel_size, stride, padding);
    let values: Tensor<B, 4> = Tensor::new(pooled.output);
    let indexes: Tensor<B, 4, Int> = Tensor::new(pooled.indexes);
    (values, indexes)
}