// candle_transformers/models/vgg.rs

//! VGG-16 model implementation.
//!
//! VGG-16 is a convolutional neural network architecture. It consists of 13
//! convolutional layers followed by 3 fully connected layers.
//!
//! Key characteristics:
//! - Conv layers with 3x3 filters
//! - Max pooling after every 2-3 conv layers
//! - Three fully connected layers of 4096, 4096, 1000 units
//! - ReLU activation and dropout
//!
//! References:
//! - [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556)
//!
16use candle::{ModuleT, Result, Tensor};
17use candle_nn::{FuncT, VarBuilder};
18
/// The supported VGG variants, named after their number of weight layers
/// (13, 16 or 19 — see the paper referenced in the module docs).
///
/// Derives `Debug`/`Clone`/`Copy`/`PartialEq`/`Eq`/`Hash` so the variant can
/// be logged, compared and used as a map key; the enum is a pure tag with no
/// payload, so `Copy` is free.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Models {
    Vgg13,
    Vgg16,
    Vgg19,
}
25
// Struct representing a VGG model
/// A VGG network: an ordered sequence of blocks (conv stages plus the
/// classifier head) that `forward_t` applies one after another. The `'a`
/// lifetime ties the blocks to the `VarBuilder` they were built from.
#[derive(Debug)]
pub struct Vgg<'a> {
    // Each entry is a closure capturing its own weights; order matters.
    blocks: Vec<FuncT<'a>>,
}
31
// Struct representing the configuration for the pre-logit layer
/// Describes how one pre-logit linear layer is loaded from the checkpoint:
/// the weights are stored under a 4-D (conv-style) shape and reshaped to the
/// 2-D shape a `Linear` layer expects.
struct PreLogitConfig {
    // 4-D shape of the weight tensor as stored in the checkpoint.
    in_dim: (usize, usize, usize, usize),
    // First dimension of the reshaped 2-D weight (and the bias length).
    target_in: usize,
    // Second dimension of the reshaped 2-D weight.
    target_out: usize,
}
38
39// Implementation of the VGG model
40impl<'a> Vgg<'a> {
41    // Function to create a new VGG model
42    pub fn new(vb: VarBuilder<'a>, model: Models) -> Result<Self> {
43        let blocks = match model {
44            Models::Vgg13 => vgg13_blocks(vb)?,
45            Models::Vgg16 => vgg16_blocks(vb)?,
46            Models::Vgg19 => vgg19_blocks(vb)?,
47        };
48        Ok(Self { blocks })
49    }
50}
51
52// Implementation of the forward pass for the VGG model
53impl ModuleT for Vgg<'_> {
54    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
55        let mut xs = xs.unsqueeze(0)?;
56        for block in self.blocks.iter() {
57            xs = xs.apply_t(block, train)?;
58        }
59        Ok(xs)
60    }
61}
62
63// Function to create a conv2d block
64// The block is composed of two conv2d layers followed by a max pool layer
65fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
66    let layers = convs
67        .iter()
68        .map(|&(in_c, out_c, name)| {
69            candle_nn::conv2d(
70                in_c,
71                out_c,
72                3,
73                candle_nn::Conv2dConfig {
74                    stride: 1,
75                    padding: 1,
76                    ..Default::default()
77                },
78                vb.pp(name),
79            )
80        })
81        .collect::<Result<Vec<_>>>()?;
82
83    Ok(FuncT::new(move |xs, _train| {
84        let mut xs = xs.clone();
85        for layer in layers.iter() {
86            xs = xs.apply(layer)?.relu()?
87        }
88        xs = xs.max_pool2d_with_stride(2, 2)?;
89        Ok(xs)
90    }))
91}
92
93// Function to create a fully connected layer
94// The layer is composed of two linear layers followed by a dropout layer
95fn fully_connected(
96    num_classes: usize,
97    pre_logit_1: PreLogitConfig,
98    pre_logit_2: PreLogitConfig,
99    vb: VarBuilder,
100) -> Result<FuncT> {
101    let lin = get_weights_and_biases(
102        &vb.pp("pre_logits.fc1"),
103        pre_logit_1.in_dim,
104        pre_logit_1.target_in,
105        pre_logit_1.target_out,
106    )?;
107    let lin2 = get_weights_and_biases(
108        &vb.pp("pre_logits.fc2"),
109        pre_logit_2.in_dim,
110        pre_logit_2.target_in,
111        pre_logit_2.target_out,
112    )?;
113    let dropout1 = candle_nn::Dropout::new(0.5);
114    let dropout2 = candle_nn::Dropout::new(0.5);
115    let dropout3 = candle_nn::Dropout::new(0.5);
116    Ok(FuncT::new(move |xs, train| {
117        let xs = xs.reshape((1, pre_logit_1.target_out))?;
118        let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?;
119        let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?;
120        let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
121        let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?;
122        Ok(xs)
123    }))
124}
125
126// Function to get the weights and biases for a layer
127// This is required because the weights and biases are stored in different format than our linear layer expects
128fn get_weights_and_biases(
129    vs: &VarBuilder,
130    in_dim: (usize, usize, usize, usize),
131    target_in: usize,
132    target_out: usize,
133) -> Result<candle_nn::Linear> {
134    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
135    let ws = vs.get_with_hints(in_dim, "weight", init_ws)?;
136    let ws = ws.reshape((target_in, target_out))?;
137    let bound = 1. / (target_out as f64).sqrt();
138    let init_bs = candle_nn::Init::Uniform {
139        lo: -bound,
140        up: bound,
141    };
142    let bs = vs.get_with_hints(target_in, "bias", init_bs)?;
143    Ok(candle_nn::Linear::new(ws, Some(bs)))
144}
145
146fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
147    let num_classes = 1000;
148    let blocks = vec![
149        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
150        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
151        conv2d_block(&[(128, 256, "features.10"), (256, 256, "features.12")], &vb)?,
152        conv2d_block(&[(256, 512, "features.15"), (512, 512, "features.17")], &vb)?,
153        conv2d_block(&[(512, 512, "features.20"), (512, 512, "features.22")], &vb)?,
154        fully_connected(
155            num_classes,
156            PreLogitConfig {
157                in_dim: (4096, 512, 7, 7),
158                target_in: 4096,
159                target_out: 512 * 7 * 7,
160            },
161            PreLogitConfig {
162                in_dim: (4096, 4096, 1, 1),
163                target_in: 4096,
164                target_out: 4096,
165            },
166            vb.clone(),
167        )?,
168    ];
169    Ok(blocks)
170}
171
172fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
173    let num_classes = 1000;
174    let blocks = vec![
175        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
176        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
177        conv2d_block(
178            &[
179                (128, 256, "features.10"),
180                (256, 256, "features.12"),
181                (256, 256, "features.14"),
182            ],
183            &vb,
184        )?,
185        conv2d_block(
186            &[
187                (256, 512, "features.17"),
188                (512, 512, "features.19"),
189                (512, 512, "features.21"),
190            ],
191            &vb,
192        )?,
193        conv2d_block(
194            &[
195                (512, 512, "features.24"),
196                (512, 512, "features.26"),
197                (512, 512, "features.28"),
198            ],
199            &vb,
200        )?,
201        fully_connected(
202            num_classes,
203            PreLogitConfig {
204                in_dim: (4096, 512, 7, 7),
205                target_in: 4096,
206                target_out: 512 * 7 * 7,
207            },
208            PreLogitConfig {
209                in_dim: (4096, 4096, 1, 1),
210                target_in: 4096,
211                target_out: 4096,
212            },
213            vb.clone(),
214        )?,
215    ];
216    Ok(blocks)
217}
218
219fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
220    let num_classes = 1000;
221    let blocks = vec![
222        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
223        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
224        conv2d_block(
225            &[
226                (128, 256, "features.10"),
227                (256, 256, "features.12"),
228                (256, 256, "features.14"),
229                (256, 256, "features.16"),
230            ],
231            &vb,
232        )?,
233        conv2d_block(
234            &[
235                (256, 512, "features.19"),
236                (512, 512, "features.21"),
237                (512, 512, "features.23"),
238                (512, 512, "features.25"),
239            ],
240            &vb,
241        )?,
242        conv2d_block(
243            &[
244                (512, 512, "features.28"),
245                (512, 512, "features.30"),
246                (512, 512, "features.32"),
247                (512, 512, "features.34"),
248            ],
249            &vb,
250        )?,
251        fully_connected(
252            num_classes,
253            PreLogitConfig {
254                in_dim: (4096, 512, 7, 7),
255                target_in: 4096,
256                target_out: 512 * 7 * 7,
257            },
258            PreLogitConfig {
259                in_dim: (4096, 4096, 1, 1),
260                target_in: 4096,
261                target_out: 4096,
262            },
263            vb.clone(),
264        )?,
265    ];
266    Ok(blocks)
267}