use burn_core as burn;
use super::vgg19::Vgg19;
use burn::tensor::backend::Backend;
use burn_std::network::downloader::download_file_as_bytes;
use burn_store::{ModuleSnapshot, PytorchStore};
use std::fs::{File, create_dir_all, rename};
use std::io::Write;
use std::path::PathBuf;
const VGG19_URL: &str = "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth";
fn get_cache_dir() -> PathBuf {
let cache_dir = dirs::cache_dir()
.expect("Failed to get cache directory for Gram Matrix Loss")
.join("burn-pretrained-models")
.join("loss")
.join("vgg19");
if !cache_dir.exists() {
create_dir_all(&cache_dir).expect("Failed to create cache directory for Gram Matrix Loss");
}
cache_dir
}
fn download_weights_if_not_saved(cache_path: &PathBuf) {
if !cache_path.exists() {
let bytes = download_file_as_bytes(
VGG19_URL,
"Downloading VGG19 ImageNet weights for Gram Matrix Loss...",
);
let temp_path = cache_path.with_extension("pth.tmp");
let mut file = File::create(&temp_path)
.expect("Failed to create VGG19 cache file for Gram Matrix Loss");
file.write_all(&bytes)
.expect("Failed to write VGG19 weights to the cache file for Gram Matrix Loss");
rename(temp_path, cache_path)
.expect("Failed to rename temporary file to the actual VGG19 cache file name for Gram Matrix Loss");
}
}
pub fn load_vgg19_weights<B: Backend>(mut vgg19: Vgg19<B>) -> Vgg19<B> {
let cache_dir = get_cache_dir();
let cache_path = cache_dir.join("vgg19.pth");
download_weights_if_not_saved(&cache_path);
let mut store = PytorchStore::from_file(cache_path)
.allow_partial(true)
.with_key_remapping(r"^features\.0\.", "conv1_1.")
.with_key_remapping(r"^features\.2\.", "conv1_2.")
.with_key_remapping(r"^features\.5\.", "conv2_1.")
.with_key_remapping(r"^features\.7\.", "conv2_2.")
.with_key_remapping(r"^features\.10\.", "conv3_1.")
.with_key_remapping(r"^features\.12\.", "conv3_2.")
.with_key_remapping(r"^features\.14\.", "conv3_3.")
.with_key_remapping(r"^features\.16\.", "conv3_4.")
.with_key_remapping(r"^features\.19\.", "conv4_1.")
.with_key_remapping(r"^features\.21\.", "conv4_2.")
.with_key_remapping(r"^features\.23\.", "conv4_3.")
.with_key_remapping(r"^features\.25\.", "conv4_4.")
.with_key_remapping(r"^features\.28\.", "conv5_1.");
let result = vgg19.load_from(&mut store);
if let Err(e) = result {
eprintln!("Warning: Some VGG19 weights could not be loaded: {:?}", e);
}
vgg19
}