Skip to main content

svod_model/resnet/
model.rs

1//! [`ResNet`] — the unified depth-parameterised ResNet model. Construction is
2//! driven by [`ResNetDepth`] and [`OutputMode`]; the forward pass is identical
3//! for every variant and only the loader's key probing is depth-aware.
4//!
5//! Layout matches `timm` / `torchvision`:
6//!
7//! ```text
8//! conv1.weight              # stem 7x7
9//! bn1.{...}                 # stem BN
10//! layer{1..4}.{i}.{...}     # stage blocks
11//!   conv{1..N}.weight
12//!   bn{1..N}.{...}
13//!   downsample.0.weight     # 1x1 downsample conv (when first block downsamples)
14//!   downsample.1.{...}      # downsample BN
15//! fc.weight, fc.bias        # classification head (optional)
16//! ```
17
18use std::path::Path;
19
20use snafu::ResultExt;
21use svod_dtype::DType;
22use svod_ir::SInt;
23use svod_tensor::{BoundVariable, Tensor};
24
25use crate::blocks::{BatchNormWeights, Conv2dWeights, ResidualStage, remap};
26use crate::init::fan_in_uniform;
27use crate::state::{self, HasStateDict, StateDict, get_tensor};
28
29use super::config::{OutputMode, ResNetConfig, ResNetDepth};
30use super::error::{HubSnafu, Result, StateSnafu, TensorSnafu};
31
32/// Image classification / feature backbone. Construct via one of the loaders
33/// ([`ResNet::from_hub`], [`ResNet::from_safetensors`], or
34/// [`ResNet::from_state_dict`]) — the empty-tensor placeholders in the layer
35/// structs are not usable until weights are loaded.
36#[derive(Clone)]
37pub struct ResNet {
38    pub config: ResNetConfig,
39    stem_conv: Conv2dWeights,
40    stem_bn: BatchNormWeights,
41    stage1: ResidualStage,
42    stage2: ResidualStage,
43    stage3: ResidualStage,
44    stage4: ResidualStage,
45    head: Option<HeadWeights>,
46}
47
48#[derive(Clone)]
49struct HeadWeights {
50    weight: Tensor,
51    bias: Tensor,
52}
53
54impl ResNet {
55    /// Build with all-zero weight placeholders. Used by every loader before a
56    /// `load_state_dict` call, and exposed publicly for round-trip tests.
57    pub fn with_zero_weights(config: ResNetConfig) -> Self {
58        let depth = config.depth;
59        let block = depth.block();
60        let expansion = depth.expansion();
61        let layers = depth.layers();
62
63        // timm/torchvision channel schedule: stem emits 64, each stage doubles.
64        // Block-internal expansion (×4 for Bottleneck) multiplies the next
65        // stage's in_planes.
66        let stage1 = ResidualStage::empty(block, 64, 64, layers[0], 1);
67        let stage2 = ResidualStage::empty(block, 64 * expansion, 128, layers[1], 2);
68        let stage3 = ResidualStage::empty(block, 128 * expansion, 256, layers[2], 2);
69        let stage4 = ResidualStage::empty(block, 256 * expansion, 512, layers[3], 2);
70
71        let head = match &config.output {
72            OutputMode::Classification { num_classes } => {
73                let fan_in = 512 * expansion;
74                Some(HeadWeights {
75                    weight: fan_in_uniform(&[*num_classes, fan_in], fan_in, DType::Float32),
76                    bias: fan_in_uniform(&[*num_classes], fan_in, DType::Float32),
77                })
78            }
79            OutputMode::Features => None,
80        };
81
82        Self {
83            config,
84            stem_conv: Conv2dWeights::empty(64, 3, 7, 2, 3),
85            stem_bn: BatchNormWeights::empty(64),
86            stage1,
87            stage2,
88            stage3,
89            stage4,
90            head,
91        }
92    }
93
94    /// Number of output channels after stage 4 (before any FC head). Useful
95    /// when consumers want to pre-allocate downstream buffers.
96    pub fn feature_channels(&self) -> usize {
97        512 * self.config.depth.expansion()
98    }
99
100    // -----------------------------------------------------------------------
101    // Loaders
102    // -----------------------------------------------------------------------
103
104    /// Download `model.safetensors` from a HuggingFace Hub repository at the
105    /// `main` revision and load it. The repo must publish a flat timm /
106    /// torchvision-style state dict.
107    pub fn from_hub(model_id: &str, depth: ResNetDepth, output: OutputMode) -> Result<Self> {
108        Self::from_hub_with_revision(model_id, "main", depth, output)
109    }
110
111    pub fn from_hub_with_revision(
112        model_id: &str,
113        revision: &str,
114        depth: ResNetDepth,
115        output: OutputMode,
116    ) -> Result<Self> {
117        let api = hf_hub::api::sync::Api::new().context(HubSnafu)?;
118        let repo =
119            api.repo(hf_hub::Repo::with_revision(model_id.to_string(), hf_hub::RepoType::Model, revision.to_string()));
120        let weights_path = repo.get("model.safetensors").context(HubSnafu)?;
121        Self::from_safetensors(&weights_path, depth, output)
122    }
123
124    /// Load from a local `model.safetensors`. The file must use the timm /
125    /// torchvision key layout (see the module-level docs for the keys).
126    pub fn from_safetensors(path: &Path, depth: ResNetDepth, output: OutputMode) -> Result<Self> {
127        let sd = state::load_safetensors(path).context(StateSnafu)?;
128        Self::from_state_dict(&sd, ResNetConfig::new(depth, output))
129    }
130
131    /// Build from a preloaded state dict. Runs [`remap::fold_batchnorm`] first
132    /// to translate `running_var` into `invstd` and drop
133    /// `num_batches_tracked` — the loaded layer structs read directly from the
134    /// post-fold layout.
135    pub fn from_state_dict(sd: &StateDict, config: ResNetConfig) -> Result<Self> {
136        let sd = remap::fold_batchnorm(sd.clone())?;
137        let mut model = Self::with_zero_weights(config);
138        model.stem_conv.load_state_dict(&sd, "conv1").context(StateSnafu)?;
139        model.stem_bn.load_state_dict(&sd, "bn1").context(StateSnafu)?;
140        model.stage1.load_state_dict(&sd, "layer1").context(StateSnafu)?;
141        model.stage2.load_state_dict(&sd, "layer2").context(StateSnafu)?;
142        model.stage3.load_state_dict(&sd, "layer3").context(StateSnafu)?;
143        model.stage4.load_state_dict(&sd, "layer4").context(StateSnafu)?;
144
145        if let Some(head) = model.head.as_mut() {
146            head.weight = get_tensor(&sd, "fc.weight").context(StateSnafu)?;
147            head.bias = get_tensor(&sd, "fc.bias").context(StateSnafu)?;
148        }
149        Ok(model)
150    }
151
152    // -----------------------------------------------------------------------
153    // Forward
154    // -----------------------------------------------------------------------
155
156    /// Run the full network on `images` `[max_b, 3, H, W]`, shrunk to the
157    /// `batch` variable's bound value before the stem. Returns either
158    /// classification logits `[B, num_classes]` or the final feature map
159    /// `[B, 512*exp, H/32, W/32]`, depending on
160    /// [`ResNetConfig::output`].
161    pub fn forward(&self, images: &Tensor, batch: &BoundVariable) -> Result<Tensor> {
162        let b = batch.as_sint();
163
164        let x = images.try_shrink([Some((SInt::Const(0), b)), None, None, None]).context(TensorSnafu)?;
165        let x = self.stem_bn.forward(&self.stem_conv.forward(&x)?)?.relu().context(TensorSnafu)?;
166        let x = x
167            .max_pool2d()
168            .kernel_size(&[3, 3])
169            .stride(&[2, 2])
170            .padding(&[(1, 1), (1, 1)])
171            .call()
172            .context(TensorSnafu)?;
173
174        let x = self.stage1.forward(&x)?;
175        let x = self.stage2.forward(&x)?;
176        let x = self.stage3.forward(&x)?;
177        let x = self.stage4.forward(&x)?;
178
179        match (&self.head, &self.config.output) {
180            (Some(fc), OutputMode::Classification { .. }) => {
181                // Global average pool over the two spatial axes.
182                let pooled = x.mean_with().axes(vec![2isize, 3]).keepdim(false).call().context(TensorSnafu)?;
183                pooled.linear().weight(&fc.weight).bias(&fc.bias).call().context(TensorSnafu)
184            }
185            _ => Ok(x),
186        }
187    }
188}
189
190impl HasStateDict for ResNet {
191    fn state_dict(&self, prefix: &str) -> StateDict {
192        let mut sd = self.stem_conv.state_dict(&state::prefixed(prefix, "conv1"));
193        sd.extend(self.stem_bn.state_dict(&state::prefixed(prefix, "bn1")));
194        sd.extend(self.stage1.state_dict(&state::prefixed(prefix, "layer1")));
195        sd.extend(self.stage2.state_dict(&state::prefixed(prefix, "layer2")));
196        sd.extend(self.stage3.state_dict(&state::prefixed(prefix, "layer3")));
197        sd.extend(self.stage4.state_dict(&state::prefixed(prefix, "layer4")));
198        if let Some(head) = &self.head {
199            sd.insert(state::prefixed(prefix, "fc.weight"), head.weight.clone());
200            sd.insert(state::prefixed(prefix, "fc.bias"), head.bias.clone());
201        }
202        sd
203    }
204
205    fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
206        self.stem_conv.load_state_dict(sd, &state::prefixed(prefix, "conv1"))?;
207        self.stem_bn.load_state_dict(sd, &state::prefixed(prefix, "bn1"))?;
208        self.stage1.load_state_dict(sd, &state::prefixed(prefix, "layer1"))?;
209        self.stage2.load_state_dict(sd, &state::prefixed(prefix, "layer2"))?;
210        self.stage3.load_state_dict(sd, &state::prefixed(prefix, "layer3"))?;
211        self.stage4.load_state_dict(sd, &state::prefixed(prefix, "layer4"))?;
212        if let Some(head) = self.head.as_mut() {
213            head.weight = get_tensor(sd, &state::prefixed(prefix, "fc.weight"))?;
214            head.bias = get_tensor(sd, &state::prefixed(prefix, "fc.bias"))?;
215        }
216        Ok(())
217    }
218}