use std::path::Path;
use snafu::ResultExt;
use svod_dtype::DType;
use svod_ir::SInt;
use svod_tensor::{BoundVariable, Tensor};
use crate::blocks::{BatchNormWeights, Conv2dWeights, ResidualStage, remap};
use crate::init::fan_in_uniform;
use crate::state::{self, HasStateDict, StateDict, get_tensor};
use super::config::{OutputMode, ResNetConfig, ResNetDepth};
use super::error::{HubSnafu, Result, StateSnafu, TensorSnafu};
/// ResNet backbone with an optional fully connected classification head.
///
/// Parameters are exchanged with state dicts under the keys `conv1`, `bn1`,
/// `layer1`..`layer4` and, for classifiers, `fc.weight` / `fc.bias`.
#[derive(Clone)]
pub struct ResNet {
    /// Depth and output mode this model was built from.
    ///
    /// NOTE: `head` is created from `config.output` at construction time. Mutating
    /// this public field later does not add or drop the head; `forward` applies the
    /// head only when the head exists AND `output` is `Classification`.
    pub config: ResNetConfig,
    /// Stem convolution, stored under the `conv1` key.
    stem_conv: Conv2dWeights,
    /// Normalisation applied to the stem convolution output, stored under `bn1`.
    stem_bn: BatchNormWeights,
    /// First residual stage, stored under `layer1`.
    stage1: ResidualStage,
    /// Second residual stage, stored under `layer2`.
    stage2: ResidualStage,
    /// Third residual stage, stored under `layer3`.
    stage3: ResidualStage,
    /// Fourth residual stage, stored under `layer4`.
    stage4: ResidualStage,
    /// Final linear layer; `Some` only when built with `OutputMode::Classification`.
    head: Option<HeadWeights>,
}
/// Parameters of the classification head: a linear layer applied to the
/// spatially averaged output of the last residual stage.
///
/// Tensors loaded from a state dict replace these fields as-is; no shape check
/// against the configured `num_classes` is performed here.
#[derive(Clone)]
struct HeadWeights {
    /// Linear weight, initially created with shape `[num_classes, 512 * expansion]`.
    weight: Tensor,
    /// Linear bias, initially created with shape `[num_classes]`.
    bias: Tensor,
}
impl ResNet {
pub fn with_zero_weights(config: ResNetConfig) -> Self {
let depth = config.depth;
let block = depth.block();
let expansion = depth.expansion();
let layers = depth.layers();
let stage1 = ResidualStage::empty(block, 64, 64, layers[0], 1);
let stage2 = ResidualStage::empty(block, 64 * expansion, 128, layers[1], 2);
let stage3 = ResidualStage::empty(block, 128 * expansion, 256, layers[2], 2);
let stage4 = ResidualStage::empty(block, 256 * expansion, 512, layers[3], 2);
let head = match &config.output {
OutputMode::Classification { num_classes } => {
let fan_in = 512 * expansion;
Some(HeadWeights {
weight: fan_in_uniform(&[*num_classes, fan_in], fan_in, DType::Float32),
bias: fan_in_uniform(&[*num_classes], fan_in, DType::Float32),
})
}
OutputMode::Features => None,
};
Self {
config,
stem_conv: Conv2dWeights::empty(64, 3, 7, 2, 3),
stem_bn: BatchNormWeights::empty(64),
stage1,
stage2,
stage3,
stage4,
head,
}
}
pub fn feature_channels(&self) -> usize {
512 * self.config.depth.expansion()
}
pub fn from_hub(model_id: &str, depth: ResNetDepth, output: OutputMode) -> Result<Self> {
Self::from_hub_with_revision(model_id, "main", depth, output)
}
pub fn from_hub_with_revision(
model_id: &str,
revision: &str,
depth: ResNetDepth,
output: OutputMode,
) -> Result<Self> {
let api = hf_hub::api::sync::Api::new().context(HubSnafu)?;
let repo =
api.repo(hf_hub::Repo::with_revision(model_id.to_string(), hf_hub::RepoType::Model, revision.to_string()));
let weights_path = repo.get("model.safetensors").context(HubSnafu)?;
Self::from_safetensors(&weights_path, depth, output)
}
pub fn from_safetensors(path: &Path, depth: ResNetDepth, output: OutputMode) -> Result<Self> {
let sd = state::load_safetensors(path).context(StateSnafu)?;
Self::from_state_dict(&sd, ResNetConfig::new(depth, output))
}
pub fn from_state_dict(sd: &StateDict, config: ResNetConfig) -> Result<Self> {
let sd = remap::fold_batchnorm(sd.clone())?;
let mut model = Self::with_zero_weights(config);
model.stem_conv.load_state_dict(&sd, "conv1").context(StateSnafu)?;
model.stem_bn.load_state_dict(&sd, "bn1").context(StateSnafu)?;
model.stage1.load_state_dict(&sd, "layer1").context(StateSnafu)?;
model.stage2.load_state_dict(&sd, "layer2").context(StateSnafu)?;
model.stage3.load_state_dict(&sd, "layer3").context(StateSnafu)?;
model.stage4.load_state_dict(&sd, "layer4").context(StateSnafu)?;
if let Some(head) = model.head.as_mut() {
head.weight = get_tensor(&sd, "fc.weight").context(StateSnafu)?;
head.bias = get_tensor(&sd, "fc.bias").context(StateSnafu)?;
}
Ok(model)
}
pub fn forward(&self, images: &Tensor, batch: &BoundVariable) -> Result<Tensor> {
let b = batch.as_sint();
let x = images.try_shrink([Some((SInt::Const(0), b)), None, None, None]).context(TensorSnafu)?;
let x = self.stem_bn.forward(&self.stem_conv.forward(&x)?)?.relu().context(TensorSnafu)?;
let x = x
.max_pool2d()
.kernel_size(&[3, 3])
.stride(&[2, 2])
.padding(&[(1, 1), (1, 1)])
.call()
.context(TensorSnafu)?;
let x = self.stage1.forward(&x)?;
let x = self.stage2.forward(&x)?;
let x = self.stage3.forward(&x)?;
let x = self.stage4.forward(&x)?;
match (&self.head, &self.config.output) {
(Some(fc), OutputMode::Classification { .. }) => {
let pooled = x.mean_with().axes(vec![2isize, 3]).keepdim(false).call().context(TensorSnafu)?;
pooled.linear().weight(&fc.weight).bias(&fc.bias).call().context(TensorSnafu)
}
_ => Ok(x),
}
}
}
impl HasStateDict for ResNet {
fn state_dict(&self, prefix: &str) -> StateDict {
let mut sd = self.stem_conv.state_dict(&state::prefixed(prefix, "conv1"));
sd.extend(self.stem_bn.state_dict(&state::prefixed(prefix, "bn1")));
sd.extend(self.stage1.state_dict(&state::prefixed(prefix, "layer1")));
sd.extend(self.stage2.state_dict(&state::prefixed(prefix, "layer2")));
sd.extend(self.stage3.state_dict(&state::prefixed(prefix, "layer3")));
sd.extend(self.stage4.state_dict(&state::prefixed(prefix, "layer4")));
if let Some(head) = &self.head {
sd.insert(state::prefixed(prefix, "fc.weight"), head.weight.clone());
sd.insert(state::prefixed(prefix, "fc.bias"), head.bias.clone());
}
sd
}
fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
self.stem_conv.load_state_dict(sd, &state::prefixed(prefix, "conv1"))?;
self.stem_bn.load_state_dict(sd, &state::prefixed(prefix, "bn1"))?;
self.stage1.load_state_dict(sd, &state::prefixed(prefix, "layer1"))?;
self.stage2.load_state_dict(sd, &state::prefixed(prefix, "layer2"))?;
self.stage3.load_state_dict(sd, &state::prefixed(prefix, "layer3"))?;
self.stage4.load_state_dict(sd, &state::prefixed(prefix, "layer4"))?;
if let Some(head) = self.head.as_mut() {
head.weight = get_tensor(sd, &state::prefixed(prefix, "fc.weight"))?;
head.bias = get_tensor(sd, &state::prefixed(prefix, "fc.bias"))?;
}
Ok(())
}
}