1use candle::{ModuleT, Result, Tensor};
17use candle_nn::{FuncT, VarBuilder};
18
/// The VGG architecture variants supported by [`Vgg::new`].
///
/// Each variant selects a different stack of convolutional stages
/// (13, 16, or 19 weight layers respectively).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Models {
    Vgg13,
    Vgg16,
    Vgg19,
}
25
/// A VGG network, stored as a sequence of type-erased blocks that are
/// applied in order by the [`ModuleT`] implementation.
#[derive(Debug)]
pub struct Vgg<'a> {
    // Convolutional stages followed by the fully-connected classifier head,
    // each wrapped in a `FuncT` closure built by the `vgg*_blocks` helpers.
    blocks: Vec<FuncT<'a>>,
}
31
/// Shape configuration for one pre-logit fully-connected layer
/// (see `get_weights_and_biases`).
struct PreLogitConfig {
    // 4-D shape under which the weight tensor is fetched/initialized before
    // being flattened to a 2-D matrix.
    in_dim: (usize, usize, usize, usize),
    // Row count of the flattened weight matrix; also the bias length.
    // NOTE: despite the name, this is the layer's *output* feature count
    // (the bias is fetched with this dimension below).
    target_in: usize,
    // Column count of the flattened weight matrix, i.e. the *input* feature
    // count — `fully_connected` reshapes its input to `(1, target_out)`.
    target_out: usize,
}
38
39impl<'a> Vgg<'a> {
41 pub fn new(vb: VarBuilder<'a>, model: Models) -> Result<Self> {
43 let blocks = match model {
44 Models::Vgg13 => vgg13_blocks(vb)?,
45 Models::Vgg16 => vgg16_blocks(vb)?,
46 Models::Vgg19 => vgg19_blocks(vb)?,
47 };
48 Ok(Self { blocks })
49 }
50}
51
52impl ModuleT for Vgg<'_> {
54 fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
55 let mut xs = xs.unsqueeze(0)?;
56 for block in self.blocks.iter() {
57 xs = xs.apply_t(block, train)?;
58 }
59 Ok(xs)
60 }
61}
62
63fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
66 let layers = convs
67 .iter()
68 .map(|&(in_c, out_c, name)| {
69 candle_nn::conv2d(
70 in_c,
71 out_c,
72 3,
73 candle_nn::Conv2dConfig {
74 stride: 1,
75 padding: 1,
76 ..Default::default()
77 },
78 vb.pp(name),
79 )
80 })
81 .collect::<Result<Vec<_>>>()?;
82
83 Ok(FuncT::new(move |xs, _train| {
84 let mut xs = xs.clone();
85 for layer in layers.iter() {
86 xs = xs.apply(layer)?.relu()?
87 }
88 xs = xs.max_pool2d_with_stride(2, 2)?;
89 Ok(xs)
90 }))
91}
92
93fn fully_connected(
96 num_classes: usize,
97 pre_logit_1: PreLogitConfig,
98 pre_logit_2: PreLogitConfig,
99 vb: VarBuilder,
100) -> Result<FuncT> {
101 let lin = get_weights_and_biases(
102 &vb.pp("pre_logits.fc1"),
103 pre_logit_1.in_dim,
104 pre_logit_1.target_in,
105 pre_logit_1.target_out,
106 )?;
107 let lin2 = get_weights_and_biases(
108 &vb.pp("pre_logits.fc2"),
109 pre_logit_2.in_dim,
110 pre_logit_2.target_in,
111 pre_logit_2.target_out,
112 )?;
113 let dropout1 = candle_nn::Dropout::new(0.5);
114 let dropout2 = candle_nn::Dropout::new(0.5);
115 let dropout3 = candle_nn::Dropout::new(0.5);
116 Ok(FuncT::new(move |xs, train| {
117 let xs = xs.reshape((1, pre_logit_1.target_out))?;
118 let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?;
119 let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?;
120 let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
121 let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?;
122 Ok(xs)
123 }))
124}
125
/// Fetches (or initializes) a linear layer's weight and bias from `vs`.
///
/// The weight is looked up under the 4-D shape `in_dim` (Kaiming-normal init
/// when absent) and flattened to a `(target_in, target_out)` matrix; the bias
/// has length `target_in`, initialized uniformly in
/// `[-1/sqrt(target_out), 1/sqrt(target_out)]`.
///
/// Naming caveat: `target_in` is the row count of the resulting weight matrix
/// (the layer's output features — note the bias uses it), while `target_out`
/// is the column count (the input features).
fn get_weights_and_biases(
    vs: &VarBuilder,
    in_dim: (usize, usize, usize, usize),
    target_in: usize,
    target_out: usize,
) -> Result<candle_nn::Linear> {
    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
    let ws = vs.get_with_hints(in_dim, "weight", init_ws)?;
    // Flatten the stored 4-D weight into the 2-D matrix a Linear layer expects.
    let ws = ws.reshape((target_in, target_out))?;
    let bound = 1. / (target_out as f64).sqrt();
    let init_bs = candle_nn::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let bs = vs.get_with_hints(target_in, "bias", init_bs)?;
    Ok(candle_nn::Linear::new(ws, Some(bs)))
}
145
146fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
147 let num_classes = 1000;
148 let blocks = vec![
149 conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
150 conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
151 conv2d_block(&[(128, 256, "features.10"), (256, 256, "features.12")], &vb)?,
152 conv2d_block(&[(256, 512, "features.15"), (512, 512, "features.17")], &vb)?,
153 conv2d_block(&[(512, 512, "features.20"), (512, 512, "features.22")], &vb)?,
154 fully_connected(
155 num_classes,
156 PreLogitConfig {
157 in_dim: (4096, 512, 7, 7),
158 target_in: 4096,
159 target_out: 512 * 7 * 7,
160 },
161 PreLogitConfig {
162 in_dim: (4096, 4096, 1, 1),
163 target_in: 4096,
164 target_out: 4096,
165 },
166 vb.clone(),
167 )?,
168 ];
169 Ok(blocks)
170}
171
172fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
173 let num_classes = 1000;
174 let blocks = vec![
175 conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
176 conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
177 conv2d_block(
178 &[
179 (128, 256, "features.10"),
180 (256, 256, "features.12"),
181 (256, 256, "features.14"),
182 ],
183 &vb,
184 )?,
185 conv2d_block(
186 &[
187 (256, 512, "features.17"),
188 (512, 512, "features.19"),
189 (512, 512, "features.21"),
190 ],
191 &vb,
192 )?,
193 conv2d_block(
194 &[
195 (512, 512, "features.24"),
196 (512, 512, "features.26"),
197 (512, 512, "features.28"),
198 ],
199 &vb,
200 )?,
201 fully_connected(
202 num_classes,
203 PreLogitConfig {
204 in_dim: (4096, 512, 7, 7),
205 target_in: 4096,
206 target_out: 512 * 7 * 7,
207 },
208 PreLogitConfig {
209 in_dim: (4096, 4096, 1, 1),
210 target_in: 4096,
211 target_out: 4096,
212 },
213 vb.clone(),
214 )?,
215 ];
216 Ok(blocks)
217}
218
219fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
220 let num_classes = 1000;
221 let blocks = vec![
222 conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
223 conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
224 conv2d_block(
225 &[
226 (128, 256, "features.10"),
227 (256, 256, "features.12"),
228 (256, 256, "features.14"),
229 (256, 256, "features.16"),
230 ],
231 &vb,
232 )?,
233 conv2d_block(
234 &[
235 (256, 512, "features.19"),
236 (512, 512, "features.21"),
237 (512, 512, "features.23"),
238 (512, 512, "features.25"),
239 ],
240 &vb,
241 )?,
242 conv2d_block(
243 &[
244 (512, 512, "features.28"),
245 (512, 512, "features.30"),
246 (512, 512, "features.32"),
247 (512, 512, "features.34"),
248 ],
249 &vb,
250 )?,
251 fully_connected(
252 num_classes,
253 PreLogitConfig {
254 in_dim: (4096, 512, 7, 7),
255 target_in: 4096,
256 target_out: 512 * 7 * 7,
257 },
258 PreLogitConfig {
259 in_dim: (4096, 4096, 1, 1),
260 target_in: 4096,
261 target_out: 4096,
262 },
263 vb.clone(),
264 )?,
265 ];
266 Ok(blocks)
267}