1use std::path::Path;
19
20use snafu::ResultExt;
21use svod_dtype::DType;
22use svod_ir::SInt;
23use svod_tensor::{BoundVariable, Tensor};
24
25use crate::blocks::{BatchNormWeights, Conv2dWeights, ResidualStage, remap};
26use crate::init::fan_in_uniform;
27use crate::state::{self, HasStateDict, StateDict, get_tensor};
28
29use super::config::{OutputMode, ResNetConfig, ResNetDepth};
30use super::error::{HubSnafu, Result, StateSnafu, TensorSnafu};
31
32#[derive(Clone)]
37pub struct ResNet {
38 pub config: ResNetConfig,
39 stem_conv: Conv2dWeights,
40 stem_bn: BatchNormWeights,
41 stage1: ResidualStage,
42 stage2: ResidualStage,
43 stage3: ResidualStage,
44 stage4: ResidualStage,
45 head: Option<HeadWeights>,
46}
47
48#[derive(Clone)]
49struct HeadWeights {
50 weight: Tensor,
51 bias: Tensor,
52}
53
54impl ResNet {
55 pub fn with_zero_weights(config: ResNetConfig) -> Self {
58 let depth = config.depth;
59 let block = depth.block();
60 let expansion = depth.expansion();
61 let layers = depth.layers();
62
63 let stage1 = ResidualStage::empty(block, 64, 64, layers[0], 1);
67 let stage2 = ResidualStage::empty(block, 64 * expansion, 128, layers[1], 2);
68 let stage3 = ResidualStage::empty(block, 128 * expansion, 256, layers[2], 2);
69 let stage4 = ResidualStage::empty(block, 256 * expansion, 512, layers[3], 2);
70
71 let head = match &config.output {
72 OutputMode::Classification { num_classes } => {
73 let fan_in = 512 * expansion;
74 Some(HeadWeights {
75 weight: fan_in_uniform(&[*num_classes, fan_in], fan_in, DType::Float32),
76 bias: fan_in_uniform(&[*num_classes], fan_in, DType::Float32),
77 })
78 }
79 OutputMode::Features => None,
80 };
81
82 Self {
83 config,
84 stem_conv: Conv2dWeights::empty(64, 3, 7, 2, 3),
85 stem_bn: BatchNormWeights::empty(64),
86 stage1,
87 stage2,
88 stage3,
89 stage4,
90 head,
91 }
92 }
93
94 pub fn feature_channels(&self) -> usize {
97 512 * self.config.depth.expansion()
98 }
99
100 pub fn from_hub(model_id: &str, depth: ResNetDepth, output: OutputMode) -> Result<Self> {
108 Self::from_hub_with_revision(model_id, "main", depth, output)
109 }
110
111 pub fn from_hub_with_revision(
112 model_id: &str,
113 revision: &str,
114 depth: ResNetDepth,
115 output: OutputMode,
116 ) -> Result<Self> {
117 let api = hf_hub::api::sync::Api::new().context(HubSnafu)?;
118 let repo =
119 api.repo(hf_hub::Repo::with_revision(model_id.to_string(), hf_hub::RepoType::Model, revision.to_string()));
120 let weights_path = repo.get("model.safetensors").context(HubSnafu)?;
121 Self::from_safetensors(&weights_path, depth, output)
122 }
123
124 pub fn from_safetensors(path: &Path, depth: ResNetDepth, output: OutputMode) -> Result<Self> {
127 let sd = state::load_safetensors(path).context(StateSnafu)?;
128 Self::from_state_dict(&sd, ResNetConfig::new(depth, output))
129 }
130
131 pub fn from_state_dict(sd: &StateDict, config: ResNetConfig) -> Result<Self> {
136 let sd = remap::fold_batchnorm(sd.clone())?;
137 let mut model = Self::with_zero_weights(config);
138 model.stem_conv.load_state_dict(&sd, "conv1").context(StateSnafu)?;
139 model.stem_bn.load_state_dict(&sd, "bn1").context(StateSnafu)?;
140 model.stage1.load_state_dict(&sd, "layer1").context(StateSnafu)?;
141 model.stage2.load_state_dict(&sd, "layer2").context(StateSnafu)?;
142 model.stage3.load_state_dict(&sd, "layer3").context(StateSnafu)?;
143 model.stage4.load_state_dict(&sd, "layer4").context(StateSnafu)?;
144
145 if let Some(head) = model.head.as_mut() {
146 head.weight = get_tensor(&sd, "fc.weight").context(StateSnafu)?;
147 head.bias = get_tensor(&sd, "fc.bias").context(StateSnafu)?;
148 }
149 Ok(model)
150 }
151
152 pub fn forward(&self, images: &Tensor, batch: &BoundVariable) -> Result<Tensor> {
162 let b = batch.as_sint();
163
164 let x = images.try_shrink([Some((SInt::Const(0), b)), None, None, None]).context(TensorSnafu)?;
165 let x = self.stem_bn.forward(&self.stem_conv.forward(&x)?)?.relu().context(TensorSnafu)?;
166 let x = x
167 .max_pool2d()
168 .kernel_size(&[3, 3])
169 .stride(&[2, 2])
170 .padding(&[(1, 1), (1, 1)])
171 .call()
172 .context(TensorSnafu)?;
173
174 let x = self.stage1.forward(&x)?;
175 let x = self.stage2.forward(&x)?;
176 let x = self.stage3.forward(&x)?;
177 let x = self.stage4.forward(&x)?;
178
179 match (&self.head, &self.config.output) {
180 (Some(fc), OutputMode::Classification { .. }) => {
181 let pooled = x.mean_with().axes(vec![2isize, 3]).keepdim(false).call().context(TensorSnafu)?;
183 pooled.linear().weight(&fc.weight).bias(&fc.bias).call().context(TensorSnafu)
184 }
185 _ => Ok(x),
186 }
187 }
188}
189
190impl HasStateDict for ResNet {
191 fn state_dict(&self, prefix: &str) -> StateDict {
192 let mut sd = self.stem_conv.state_dict(&state::prefixed(prefix, "conv1"));
193 sd.extend(self.stem_bn.state_dict(&state::prefixed(prefix, "bn1")));
194 sd.extend(self.stage1.state_dict(&state::prefixed(prefix, "layer1")));
195 sd.extend(self.stage2.state_dict(&state::prefixed(prefix, "layer2")));
196 sd.extend(self.stage3.state_dict(&state::prefixed(prefix, "layer3")));
197 sd.extend(self.stage4.state_dict(&state::prefixed(prefix, "layer4")));
198 if let Some(head) = &self.head {
199 sd.insert(state::prefixed(prefix, "fc.weight"), head.weight.clone());
200 sd.insert(state::prefixed(prefix, "fc.bias"), head.bias.clone());
201 }
202 sd
203 }
204
205 fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
206 self.stem_conv.load_state_dict(sd, &state::prefixed(prefix, "conv1"))?;
207 self.stem_bn.load_state_dict(sd, &state::prefixed(prefix, "bn1"))?;
208 self.stage1.load_state_dict(sd, &state::prefixed(prefix, "layer1"))?;
209 self.stage2.load_state_dict(sd, &state::prefixed(prefix, "layer2"))?;
210 self.stage3.load_state_dict(sd, &state::prefixed(prefix, "layer3"))?;
211 self.stage4.load_state_dict(sd, &state::prefixed(prefix, "layer4"))?;
212 if let Some(head) = self.head.as_mut() {
213 head.weight = get_tensor(sd, &state::prefixed(prefix, "fc.weight"))?;
214 head.bias = get_tensor(sd, &state::prefixed(prefix, "fc.bias"))?;
215 }
216 Ok(())
217 }
218}