use anyhow::{bail, Result};
/// Upscaler network architectures that [`detect_architecture`] can report,
/// together with the hyper-parameters it fills in for each.
///
/// NOTE(review): the field names look like they mirror the constructor
/// arguments of the reference model implementations — confirm against the
/// code that consumes this enum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpscalerArch {
    /// Reported when the state dict has `body.0.weight` but no
    /// `conv_first.weight`.
    SRVGGNetCompact {
        /// Feature width (presumably the channel count of the intermediate
        /// convolutions — TODO confirm).
        num_feat: usize,
        /// Number of body convolutions (presumably not counting the first
        /// and last convolution of the network — TODO confirm).
        num_conv: usize,
        /// Upscaling factor.
        scale: u32,
    },
    /// Reported when the state dict has `conv_first.weight`.
    RRDBNet {
        /// Feature width (presumably the channel count of the trunk
        /// convolutions — TODO confirm).
        num_feat: usize,
        /// Growth channel count (presumably per dense block — TODO confirm).
        num_grow_ch: usize,
        /// Number of `body.<N>` blocks that contain `rdb*` tensors.
        num_block: usize,
        /// Upscaling factor.
        scale: u32,
    },
}
/// Picks the upscaler architecture that matches a state dict's tensor names.
///
/// A `conv_first.weight` entry selects the RRDBNet detector and takes
/// precedence; otherwise a `body.0.weight` entry selects the SRVGGNetCompact
/// detector.
///
/// # Errors
///
/// Returns an error when neither marker tensor is present in `tensor_names`.
pub fn detect_architecture(tensor_names: &[&str]) -> Result<UpscalerArch> {
    // `conv_first.weight` is checked first, so it wins when both markers exist.
    if tensor_names.contains(&"conv_first.weight") {
        return detect_rrdbnet(tensor_names);
    }
    if tensor_names.contains(&"body.0.weight") {
        return detect_srvggnet(tensor_names);
    }
    bail!(
        "unknown upscaler architecture: no conv_first.weight or body.0.weight found in state dict"
    );
}
/// Derives `SRVGGNetCompact` hyper-parameters from state-dict tensor names.
///
/// Counts the names of the exact form `body.<N>.weight` where `N` is even;
/// odd-indexed `body.<N>.weight` entries are skipped (presumably activation
/// parameters — confirm against the model definition). `num_conv` is that
/// count minus two, saturating at zero.
///
/// `num_feat` and `scale` are not inferred here; they are fixed at 64 and 4.
/// This function currently never returns an error.
fn detect_srvggnet(tensor_names: &[&str]) -> Result<UpscalerArch> {
    let conv_layers = tensor_names
        .iter()
        .filter_map(|name| {
            // Only `body.<N>.weight` qualifies; anything between the prefix
            // and suffix that is not a plain integer fails the parse.
            name.strip_prefix("body.")?
                .strip_suffix(".weight")?
                .parse::<usize>()
                .ok()
        })
        .filter(|idx| idx % 2 == 0)
        .count();

    Ok(UpscalerArch::SRVGGNetCompact {
        num_feat: 64,
        // Two of the counted layers are not part of `num_conv`; saturate so
        // a truncated state dict yields 0 instead of underflowing.
        num_conv: conv_layers.saturating_sub(2),
        scale: 4,
    })
}
/// Derives `RRDBNet` hyper-parameters from state-dict tensor names.
///
/// `num_block` is one more than the highest `N` among names shaped like
/// `body.<N>.rdb…`, falling back to 23 when no such name exists. `scale` is 2
/// only when `conv_up1.weight` is present and `conv_up2.weight` is absent;
/// in every other case it is 4.
///
/// `num_feat` and `num_grow_ch` are not inferred here; they are fixed at 64
/// and 32. This function currently never returns an error.
fn detect_rrdbnet(tensor_names: &[&str]) -> Result<UpscalerArch> {
    // Track the highest block index that owns at least one `rdb*` tensor.
    let mut highest_block: Option<usize> = None;
    for name in tensor_names {
        let parts = name
            .strip_prefix("body.")
            .and_then(|rest| rest.split_once('.'));
        if let Some((index, tail)) = parts {
            if tail.starts_with("rdb") {
                if let Ok(block) = index.parse::<usize>() {
                    highest_block = Some(match highest_block {
                        Some(current) if current >= block => current,
                        _ => block,
                    });
                }
            }
        }
    }
    let num_block = match highest_block {
        // Indices are zero-based, so the count is the top index plus one.
        Some(top) => top + 1,
        None => 23,
    };

    let up1_present = tensor_names.contains(&"conv_up1.weight");
    let up2_present = tensor_names.contains(&"conv_up2.weight");
    let scale = match (up1_present, up2_present) {
        (true, false) => 2,
        _ => 4,
    };

    Ok(UpscalerArch::RRDBNet {
        num_feat: 64,
        num_grow_ch: 32,
        num_block,
        scale,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs detection on `keys` and returns `(num_block, scale)`, panicking
    /// when the detected architecture is not an RRDBNet.
    fn rrdbnet_params(keys: &[&str]) -> (usize, u32) {
        if let UpscalerArch::RRDBNet {
            num_block, scale, ..
        } = detect_architecture(keys).unwrap()
        {
            (num_block, scale)
        } else {
            panic!("expected RRDBNet")
        }
    }

    #[test]
    fn detect_rrdbnet_from_keys() {
        // Highest `rdb` block index is 22 and both upsampling convs exist.
        let keys = [
            "conv_first.weight",
            "conv_first.bias",
            "body.0.rdb1.conv1.weight",
            "body.0.rdb1.conv1.bias",
            "body.0.rdb1.conv2.weight",
            "body.22.rdb3.conv5.weight",
            "body.23.weight",
            "conv_up1.weight",
            "conv_up2.weight",
            "conv_hr.weight",
            "conv_last.weight",
        ];
        let (num_block, scale) = rrdbnet_params(&keys);
        assert_eq!(num_block, 23);
        assert_eq!(scale, 4);
    }

    #[test]
    fn detect_rrdbnet_x2_from_keys() {
        // Only `conv_up1.weight` is present, which maps to scale 2.
        let keys = [
            "conv_first.weight",
            "body.0.rdb1.conv1.weight",
            "body.22.rdb3.conv5.weight",
            "body.23.weight",
            "conv_up1.weight",
            "conv_hr.weight",
            "conv_last.weight",
        ];
        let (_, scale) = rrdbnet_params(&keys);
        assert_eq!(scale, 2);
    }

    #[test]
    fn detect_rrdbnet_anime_6b() {
        // Highest `rdb` block index is 5, so six blocks are expected.
        let keys = [
            "conv_first.weight",
            "body.0.rdb1.conv1.weight",
            "body.5.rdb3.conv5.weight",
            "body.6.weight",
            "conv_up1.weight",
            "conv_up2.weight",
            "conv_hr.weight",
            "conv_last.weight",
        ];
        let (num_block, scale) = rrdbnet_params(&keys);
        assert_eq!(num_block, 6);
        assert_eq!(scale, 4);
    }

    #[test]
    fn detect_srvggnet_from_keys() {
        // No `conv_first.weight`, but `body.0.weight` is present.
        let keys = [
            "body.0.weight",
            "body.0.bias",
            "body.1.weight",
            "body.2.weight",
            "body.2.bias",
            "body.3.weight",
            "body.30.weight",
            "body.30.bias",
            "body.31.weight",
        ];
        let arch = detect_architecture(&keys).unwrap();
        assert!(matches!(arch, UpscalerArch::SRVGGNetCompact { .. }));
    }

    #[test]
    fn unknown_architecture_errors() {
        // Neither marker tensor is present, so detection must fail.
        let keys = ["something_unknown.weight"];
        assert!(detect_architecture(&keys).is_err());
    }
}