use burn_core as burn;
use burn::tensor::backend::Backend;
use burn_std::network::downloader::download_file_as_bytes;
use burn_store::{ModuleSnapshot, PytorchStore};
use std::fs::{File, create_dir_all};
use std::io::Write;
use std::path::{Path, PathBuf};
use super::metric::{Lpips, LpipsNet};
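
// Pretrained LPIPS linear-head weights (v0.1) from the reference
// PerceptualSimilarity repository.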
const LPIPS_VGG_URL: &str =
"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/vgg.pth";
const LPIPS_ALEX_URL: &str =
"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/alex.pth";
const LPIPS_SQUEEZE_URL: &str =
"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/squeeze.pth";
const VGG16_IMAGENET_URL: &str = "https://download.pytorch.org/models/vgg16-397923af.pth";
const ALEXNET_IMAGENET_URL: &str = "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth";
const SQUEEZENET_IMAGENET_URL: &str =
"https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth";
pub fn get_lpips_weights_url(net: LpipsNet) -> &'static str {
match net {
LpipsNet::Vgg => LPIPS_VGG_URL,
LpipsNet::Alex => LPIPS_ALEX_URL,
LpipsNet::Squeeze => LPIPS_SQUEEZE_URL,
}
}
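
/// Returns the download URL for the ImageNet-pretrained backbone weights of
/// the given network variant.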
pub fn get_backbone_weights_url(net: LpipsNet) -> &'static str {
match net {
LpipsNet::Vgg => VGG16_IMAGENET_URL,
LpipsNet::Alex => ALEXNET_IMAGENET_URL,
LpipsNet::Squeeze => SQUEEZENET_IMAGENET_URL,
}
}
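
/// Returns the local cache directory for LPIPS weights, creating it if needed.
///
/// Panics if the platform cache directory cannot be determined or created.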
fn get_cache_dir() -> PathBuf {
let cache_dir = dirs::cache_dir()
.expect("Could not get cache directory")
.join("burn-dataset")
.join("lpips");
if !cache_dir.exists() {
create_dir_all(&cache_dir).expect("Failed to create cache directory");
}
cache_dir
}
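
/// Downloads `url` into `cache_path` unless a cached copy already exists,
/// displaying `message` as the progress label. Panics if the file cannot be
/// written.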
fn download_if_needed(url: &str, cache_path: &Path, message: &str) {
if !cache_path.exists() {
let bytes = download_file_as_bytes(url, message);
let mut file = File::create(cache_path).expect("Failed to create cache file");
file.write_all(&bytes).expect("Failed to write weights");
}
}
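
/// Downloads (if needed) and loads both the ImageNet backbone weights and the
/// LPIPS linear-head weights for `net` into the given model.
///
/// A minimal usage sketch; the `Lpips::new` constructor shown here is an
/// assumption for illustration, not part of this module:
///
/// ```rust,ignore
/// let device = Default::default();
/// let lpips = Lpips::<MyBackend>::new(LpipsNet::Alex, &device);
/// let lpips = load_pretrained_weights(lpips, LpipsNet::Alex);
/// ```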
pub fn load_pretrained_weights<B: Backend>(mut lpips: Lpips<B>, net: LpipsNet) -> Lpips<B> {
let cache_dir = get_cache_dir();
let backbone_url = get_backbone_weights_url(net);
let backbone_cache_path = cache_dir.join(format!("{:?}_backbone.pth", net).to_lowercase());
let backbone_message = match net {
LpipsNet::Vgg => "Downloading VGG16 ImageNet weights...",
LpipsNet::Alex => "Downloading AlexNet ImageNet weights...",
LpipsNet::Squeeze => "Downloading SqueezeNet ImageNet weights...",
};
download_if_needed(backbone_url, &backbone_cache_path, backbone_message);
let lpips_url = get_lpips_weights_url(net);
let lpips_cache_path = cache_dir.join(format!("{:?}_lpips.pth", net).to_lowercase());
let lpips_message = match net {
LpipsNet::Vgg => "Downloading LPIPS VGG weights...",
LpipsNet::Alex => "Downloading LPIPS AlexNet weights...",
LpipsNet::Squeeze => "Downloading LPIPS SqueezeNet weights...",
};
download_if_needed(lpips_url, &lpips_cache_path, lpips_message);
lpips = load_backbone_weights(lpips, &backbone_cache_path);
lpips = load_lpips_weights(lpips, &lpips_cache_path);
lpips
}
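
/// Loads ImageNet backbone weights from a torchvision checkpoint, remapping
/// the flat `features.N` keys onto the named layers of each extractor.
/// `allow_partial` skips checkpoint tensors with no counterpart in the model,
/// such as the classifier head.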
fn load_backbone_weights<B: Backend>(lpips: Lpips<B>, cache_path: &Path) -> Lpips<B> {
match lpips {
Lpips::Vgg(mut inner) => {
let mut store = PytorchStore::from_file(cache_path)
.allow_partial(true)
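// torchvision's VGG16 keeps conv layers at sparse indices of the
// `features` Sequential; ReLU and MaxPool layers occupy the gaps.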
.with_key_remapping(r"^features\.0\.", "extractor.conv1_1.")
.with_key_remapping(r"^features\.2\.", "extractor.conv1_2.")
.with_key_remapping(r"^features\.5\.", "extractor.conv2_1.")
.with_key_remapping(r"^features\.7\.", "extractor.conv2_2.")
.with_key_remapping(r"^features\.10\.", "extractor.conv3_1.")
.with_key_remapping(r"^features\.12\.", "extractor.conv3_2.")
.with_key_remapping(r"^features\.14\.", "extractor.conv3_3.")
.with_key_remapping(r"^features\.17\.", "extractor.conv4_1.")
.with_key_remapping(r"^features\.19\.", "extractor.conv4_2.")
.with_key_remapping(r"^features\.21\.", "extractor.conv4_3.")
.with_key_remapping(r"^features\.24\.", "extractor.conv5_1.")
.with_key_remapping(r"^features\.26\.", "extractor.conv5_2.")
.with_key_remapping(r"^features\.28\.", "extractor.conv5_3.");
if let Err(e) = inner.load_from(&mut store) {
log::warn!("Some VGG backbone weights could not be loaded: {:?}", e);
}
Lpips::Vgg(inner)
}
Lpips::Alex(mut inner) => {
let mut store = PytorchStore::from_file(cache_path)
.allow_partial(true)
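// AlexNet conv layers sit at `features` indices 0, 3, 6, 8, and 10.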
.with_key_remapping(r"^features\.0\.", "extractor.conv1.")
.with_key_remapping(r"^features\.3\.", "extractor.conv2.")
.with_key_remapping(r"^features\.6\.", "extractor.conv3.")
.with_key_remapping(r"^features\.8\.", "extractor.conv4.")
.with_key_remapping(r"^features\.10\.", "extractor.conv5.");
if let Err(e) = inner.load_from(&mut store) {
log::warn!("Some AlexNet backbone weights could not be loaded: {:?}", e);
}
Lpips::Alex(inner)
}
Lpips::Squeeze(mut inner) => {
let mut store = PytorchStore::from_file(cache_path)
.allow_partial(true)
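// SqueezeNet 1.1: the stem conv is at index 0 and the eight Fire
// modules at indices 3, 4, 6, 7, 9, 10, 11, and 12 of `features`.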
.with_key_remapping(r"^features\.0\.", "extractor.conv1.")
.with_key_remapping(r"^features\.3\.", "extractor.fire1.")
.with_key_remapping(r"^features\.4\.", "extractor.fire2.")
.with_key_remapping(r"^features\.6\.", "extractor.fire3.")
.with_key_remapping(r"^features\.7\.", "extractor.fire4.")
.with_key_remapping(r"^features\.9\.", "extractor.fire5.")
.with_key_remapping(r"^features\.10\.", "extractor.fire6.")
.with_key_remapping(r"^features\.11\.", "extractor.fire7.")
.with_key_remapping(r"^features\.12\.", "extractor.fire8.");
if let Err(e) = inner.load_from(&mut store) {
log::warn!(
"Some SqueezeNet backbone weights could not be loaded: {:?}",
e
);
}
Lpips::Squeeze(inner)
}
}
}
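
/// Loads the LPIPS linear-head weights. In the reference checkpoint each head
/// is stored as `linN.model.1.*` (index 0 is a Dropout layer, index 1 the
/// 1x1 conv), so the `model.1.` segment is stripped during remapping.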
fn load_lpips_weights<B: Backend>(lpips: Lpips<B>, cache_path: &Path) -> Lpips<B> {
match lpips {
Lpips::Vgg(mut inner) => {
let mut store = PytorchStore::from_file(cache_path)
.allow_partial(true)
.with_key_remapping(r"^lin0\.model\.1\.", "lin0.")
.with_key_remapping(r"^lin1\.model\.1\.", "lin1.")
.with_key_remapping(r"^lin2\.model\.1\.", "lin2.")
.with_key_remapping(r"^lin3\.model\.1\.", "lin3.")
.with_key_remapping(r"^lin4\.model\.1\.", "lin4.");
if let Err(e) = inner.load_from(&mut store) {
log::warn!("Some VGG LPIPS weights could not be loaded: {:?}", e);
}
Lpips::Vgg(inner)
}
Lpips::Alex(mut inner) => {
let mut store = PytorchStore::from_file(cache_path)
.allow_partial(true)
.with_key_remapping(r"^lin0\.model\.1\.", "lin0.")
.with_key_remapping(r"^lin1\.model\.1\.", "lin1.")
.with_key_remapping(r"^lin2\.model\.1\.", "lin2.")
.with_key_remapping(r"^lin3\.model\.1\.", "lin3.")
.with_key_remapping(r"^lin4\.model\.1\.", "lin4.");
if let Err(e) = inner.load_from(&mut store) {
log::warn!("Some AlexNet LPIPS weights could not be loaded: {:?}", e);
}
Lpips::Alex(inner)
}
Lpips::Squeeze(mut inner) => {
let mut store = PytorchStore::from_file(cache_path)
.allow_partial(true)
.with_key_remapping(r"^lin0\.model\.1\.", "lin0.")
.with_key_remapping(r"^lin1\.model\.1\.", "lin1.")
.with_key_remapping(r"^lin2\.model\.1\.", "lin2.")
.with_key_remapping(r"^lin3\.model\.1\.", "lin3.")
.with_key_remapping(r"^lin4\.model\.1\.", "lin4.")
.with_key_remapping(r"^lin5\.model\.1\.", "lin5.")
.with_key_remapping(r"^lin6\.model\.1\.", "lin6.");
if let Err(e) = inner.load_from(&mut store) {
log::warn!("Some SqueezeNet LPIPS weights could not be loaded: {:?}", e);
}
Lpips::Squeeze(inner)
}
}
}