1use crate::{nn, nn::ModuleT, Tensor};
3
4fn conv_bn(p: nn::Path, c_in: i64, c_out: i64, ksize: i64, pad: i64, stride: i64) -> impl ModuleT {
5 let conv2d_cfg = nn::ConvConfig { stride, padding: pad, bias: false, ..Default::default() };
6 let bn_cfg = nn::BatchNormConfig { eps: 0.001, ..Default::default() };
7 nn::seq_t()
8 .add(nn::conv2d(&p / "conv", c_in, c_out, ksize, conv2d_cfg))
9 .add(nn::batch_norm2d(&p / "bn", c_out, bn_cfg))
10 .add_fn(|xs| xs.relu())
11}
12
13fn conv_bn2(p: nn::Path, c_in: i64, c_out: i64, ksize: [i64; 2], pad: [i64; 2]) -> impl ModuleT {
14 let conv2d_cfg =
15 nn::ConvConfigND::<[i64; 2]> { padding: pad, bias: false, ..Default::default() };
16 let bn_cfg = nn::BatchNormConfig { eps: 0.001, ..Default::default() };
17 nn::seq_t()
18 .add(nn::conv(&p / "conv", c_in, c_out, ksize, conv2d_cfg))
19 .add(nn::batch_norm2d(&p / "bn", c_out, bn_cfg))
20 .add_fn(|xs| xs.relu())
21}
22
23fn max_pool2d(xs: &Tensor, ksize: i64, stride: i64) -> Tensor {
24 xs.max_pool2d([ksize, ksize], [stride, stride], [0, 0], [1, 1], false)
25}
26
27fn inception_a(p: nn::Path, c_in: i64, c_pool: i64) -> impl ModuleT {
28 let b1 = conv_bn(&p / "branch1x1", c_in, 64, 1, 0, 1);
29 let b2_1 = conv_bn(&p / "branch5x5_1", c_in, 48, 1, 0, 1);
30 let b2_2 = conv_bn(&p / "branch5x5_2", 48, 64, 5, 2, 1);
31 let b3_1 = conv_bn(&p / "branch3x3dbl_1", c_in, 64, 1, 0, 1);
32 let b3_2 = conv_bn(&p / "branch3x3dbl_2", 64, 96, 3, 1, 1);
33 let b3_3 = conv_bn(&p / "branch3x3dbl_3", 96, 96, 3, 1, 1);
34 let bpool = conv_bn(&p / "branch_pool", c_in, c_pool, 1, 0, 1);
35 nn::func_t(move |xs, tr| {
36 let b1 = xs.apply_t(&b1, tr);
37 let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr);
38 let b3 = xs.apply_t(&b3_1, tr).apply_t(&b3_2, tr).apply_t(&b3_3, tr);
39 let bpool = xs.avg_pool2d([3, 3], [1, 1], [1, 1], false, true, 9).apply_t(&bpool, tr);
40 Tensor::cat(&[b1, b2, b3, bpool], 1)
41 })
42}
43
44fn inception_b(p: nn::Path, c_in: i64) -> impl ModuleT {
45 let b1 = conv_bn(&p / "branch3x3", c_in, 384, 3, 0, 2);
46 let b2_1 = conv_bn(&p / "branch3x3dbl_1", c_in, 64, 1, 0, 1);
47 let b2_2 = conv_bn(&p / "branch3x3dbl_2", 64, 96, 3, 1, 1);
48 let b2_3 = conv_bn(&p / "branch3x3dbl_3", 96, 96, 3, 0, 2);
49 nn::func_t(move |xs, tr| {
50 let b1 = xs.apply_t(&b1, tr);
51 let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr).apply_t(&b2_3, tr);
52 let bpool = max_pool2d(xs, 3, 2);
53 Tensor::cat(&[b1, b2, bpool], 1)
54 })
55}
56
57fn inception_c(p: nn::Path, c_in: i64, c7: i64) -> impl ModuleT {
58 let b1 = conv_bn(&p / "branch1x1", c_in, 192, 1, 0, 1);
59
60 let b2_1 = conv_bn(&p / "branch7x7_1", c_in, c7, 1, 0, 1);
61 let b2_2 = conv_bn2(&p / "branch7x7_2", c7, c7, [1, 7], [0, 3]);
62 let b2_3 = conv_bn2(&p / "branch7x7_3", c7, 192, [7, 1], [3, 0]);
63
64 let b3_1 = conv_bn(&p / "branch7x7dbl_1", c_in, c7, 1, 0, 1);
65 let b3_2 = conv_bn2(&p / "branch7x7dbl_2", c7, c7, [7, 1], [3, 0]);
66 let b3_3 = conv_bn2(&p / "branch7x7dbl_3", c7, c7, [1, 7], [0, 3]);
67 let b3_4 = conv_bn2(&p / "branch7x7dbl_4", c7, c7, [7, 1], [3, 0]);
68 let b3_5 = conv_bn2(&p / "branch7x7dbl_5", c7, 192, [1, 7], [0, 3]);
69
70 let bpool = conv_bn(&p / "branch_pool", c_in, 192, 1, 0, 1);
71
72 nn::func_t(move |xs, tr| {
73 let b1 = xs.apply_t(&b1, tr);
74 let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr).apply_t(&b2_3, tr);
75 let b3 = xs
76 .apply_t(&b3_1, tr)
77 .apply_t(&b3_2, tr)
78 .apply_t(&b3_3, tr)
79 .apply_t(&b3_4, tr)
80 .apply_t(&b3_5, tr);
81 let bpool = xs.avg_pool2d([3, 3], [1, 1], [1, 1], false, true, 9).apply_t(&bpool, tr);
82 Tensor::cat(&[b1, b2, b3, bpool], 1)
83 })
84}
85
86fn inception_d(p: nn::Path, c_in: i64) -> impl ModuleT {
87 let b1_1 = conv_bn(&p / "branch3x3_1", c_in, 192, 1, 0, 1);
88 let b1_2 = conv_bn(&p / "branch3x3_2", 192, 320, 3, 0, 2);
89
90 let b2_1 = conv_bn(&p / "branch7x7x3_1", c_in, 192, 1, 0, 1);
91 let b2_2 = conv_bn2(&p / "branch7x7x3_2", 192, 192, [1, 7], [0, 3]);
92 let b2_3 = conv_bn2(&p / "branch7x7x3_3", 192, 192, [7, 1], [3, 0]);
93 let b2_4 = conv_bn(&p / "branch7x7x3_4", 192, 192, 3, 0, 2);
94
95 nn::func_t(move |xs, tr| {
96 let b1 = xs.apply_t(&b1_1, tr).apply_t(&b1_2, tr);
97 let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr).apply_t(&b2_3, tr).apply_t(&b2_4, tr);
98 let bpool = max_pool2d(xs, 3, 2);
99 Tensor::cat(&[b1, b2, bpool], 1)
100 })
101}
102
103fn inception_e(p: nn::Path, c_in: i64) -> impl ModuleT {
104 let b1 = conv_bn(&p / "branch1x1", c_in, 320, 1, 0, 1);
105
106 let b2_1 = conv_bn(&p / "branch3x3_1", c_in, 384, 1, 0, 1);
107 let b2_2a = conv_bn2(&p / "branch3x3_2a", 384, 384, [1, 3], [0, 1]);
108 let b2_2b = conv_bn2(&p / "branch3x3_2b", 384, 384, [3, 1], [1, 0]);
109
110 let b3_1 = conv_bn(&p / "branch3x3dbl_1", c_in, 448, 1, 0, 1);
111 let b3_2 = conv_bn(&p / "branch3x3dbl_2", 448, 384, 3, 1, 1);
112 let b3_3a = conv_bn2(&p / "branch3x3dbl_3a", 384, 384, [1, 3], [0, 1]);
113 let b3_3b = conv_bn2(&p / "branch3x3dbl_3b", 384, 384, [3, 1], [1, 0]);
114
115 let bpool = conv_bn(&p / "branch_pool", c_in, 192, 1, 0, 1);
116
117 nn::func_t(move |xs, tr| {
118 let b1 = xs.apply_t(&b1, tr);
119
120 let b2 = xs.apply_t(&b2_1, tr);
121 let b2 = Tensor::cat(&[b2.apply_t(&b2_2a, tr), b2.apply_t(&b2_2b, tr)], 1);
122
123 let b3 = xs.apply_t(&b3_1, tr).apply_t(&b3_2, tr);
124 let b3 = Tensor::cat(&[b3.apply_t(&b3_3a, tr), b3.apply_t(&b3_3b, tr)], 1);
125
126 let bpool = xs.avg_pool2d([3, 3], [1, 1], [1, 1], false, true, 9).apply_t(&bpool, tr);
127
128 Tensor::cat(&[b1, b2, b3, bpool], 1)
129 })
130}
131
132pub fn v3(p: &nn::Path, nclasses: i64) -> impl ModuleT {
133 nn::seq_t()
134 .add(conv_bn(p / "Conv2d_1a_3x3", 3, 32, 3, 0, 2))
135 .add(conv_bn(p / "Conv2d_2a_3x3", 32, 32, 3, 0, 1))
136 .add(conv_bn(p / "Conv2d_2b_3x3", 32, 64, 3, 1, 1))
137 .add_fn(|xs| max_pool2d(&xs.relu(), 3, 2))
138 .add(conv_bn(p / "Conv2d_3b_1x1", 64, 80, 1, 0, 1))
139 .add(conv_bn(p / "Conv2d_4a_3x3", 80, 192, 3, 0, 1))
140 .add_fn(|xs| max_pool2d(&xs.relu(), 3, 2))
141 .add(inception_a(p / "Mixed_5b", 192, 32))
142 .add(inception_a(p / "Mixed_5c", 256, 64))
143 .add(inception_a(p / "Mixed_5d", 288, 64))
144 .add(inception_b(p / "Mixed_6a", 288))
145 .add(inception_c(p / "Mixed_6b", 768, 128))
146 .add(inception_c(p / "Mixed_6c", 768, 160))
147 .add(inception_c(p / "Mixed_6d", 768, 160))
148 .add(inception_c(p / "Mixed_6e", 768, 192))
149 .add(inception_d(p / "Mixed_7a", 768))
150 .add(inception_e(p / "Mixed_7b", 1280))
151 .add(inception_e(p / "Mixed_7c", 2048))
152 .add_fn_t(|xs, train| xs.adaptive_avg_pool2d([1, 1]).dropout(0.5, train).flat_view())
153 .add(nn::linear(p / "fc", 2048, nclasses, Default::default()))
154}