use alloc::string::String;
use std::{
fs::{
File,
remove_file,
},
io::Write,
path::PathBuf,
};
use anyhow::bail;
use burn::{
config::Config,
data::network::downloader,
};
/// Configuration for an on-disk cache of downloaded resources.
///
/// The cache lives under the user's home directory at
/// `~/.cache/<root_cache_key>/`.
#[derive(Config, Debug)]
pub struct DiskCacheConfig {
    /// Name of the top-level directory inside `~/.cache`; defaults to `"bimm"`.
    #[config(default = "\"bimm\".to_string()")]
    pub root_cache_key: String,
}
impl Default for DiskCacheConfig {
fn default() -> Self {
Self::new()
}
}
impl DiskCacheConfig {
    /// Returns the cache root, `$HOME/.cache/<root_cache_key>`, without creating it.
    ///
    /// # Errors
    ///
    /// Returns an error if the user's home directory cannot be determined.
    pub fn base_cache_dir(&self) -> anyhow::Result<PathBuf> {
        // The function already returns `Result`, so a missing home directory is
        // reported to the caller instead of panicking.
        let home = dirs::home_dir().ok_or_else(|| {
            anyhow::anyhow!("Unable to determine the user's home directory")
        })?;
        Ok(home.join(".cache").join(&self.root_cache_key))
    }

    /// Returns the cache root, creating it (and any missing ancestors) first.
    ///
    /// # Errors
    ///
    /// Returns an error if the home directory cannot be determined, or if the
    /// directory cannot be created (including when a non-directory already
    /// occupies that path).
    pub fn ensure_base_cache_dir(&self) -> anyhow::Result<PathBuf> {
        let dir = self.base_cache_dir()?;
        // `create_dir_all` succeeds when the directory already exists, so no
        // separate (racy) existence check is needed.
        std::fs::create_dir_all(&dir)?;
        Ok(dir)
    }

    /// Maps a resource key to its path under the cache root, appending each key
    /// component in order.
    ///
    /// Components are appended with [`PathBuf::push`] semantics, so an absolute
    /// component replaces everything before it. Nothing is created or checked on
    /// disk.
    ///
    /// # Errors
    ///
    /// Returns an error if the home directory cannot be determined.
    pub fn resource_to_path(
        &self,
        resource_key: &[String],
    ) -> anyhow::Result<PathBuf> {
        let mut path = self.base_cache_dir()?;
        path.extend(resource_key);
        Ok(path)
    }

    /// Returns the path for `resource_key`, creating its parent directory when
    /// the resource does not exist yet. The resource file itself is not created.
    ///
    /// # Errors
    ///
    /// Returns an error if the home directory cannot be determined or the parent
    /// directory cannot be created.
    pub fn ensure_resource_parent_dir(
        &self,
        resource_key: &[String],
    ) -> anyhow::Result<PathBuf> {
        let path = self.resource_to_path(resource_key)?;
        if !path.exists() {
            // `parent()` is `None` only for a root/empty path; nothing to create.
            if let Some(parent) = path.parent() {
                std::fs::create_dir_all(parent)?;
            }
        }
        Ok(path)
    }

    /// Returns the local path of `resource`, downloading it from `url` into the
    /// cache first if it is not already present.
    ///
    /// # Errors
    ///
    /// Propagates errors from resolving/creating the cache directories and from
    /// [`try_cache_download_to_path`].
    pub fn fetch_resource(
        &self,
        url: &str,
        resource: &[String],
    ) -> anyhow::Result<PathBuf> {
        let cache_file_path = self.ensure_resource_parent_dir(resource)?;
        try_cache_download_to_path(url, cache_file_path)
    }
}
pub fn try_cache_download_to_path(
url: &str,
cache_file_path: PathBuf,
) -> anyhow::Result<PathBuf> {
if !cache_file_path.exists() {
let file_name = cache_file_path
.file_name()
.unwrap()
.to_string_lossy()
.to_string();
let bytes = downloader::download_file_as_bytes(url, &file_name);
let mut output_file = File::create(&cache_file_path)?;
let bytes_written = output_file.write(&bytes)?;
if bytes_written != bytes.len() {
remove_file(cache_file_path)?;
bail!("Failed to write the whole model weights file.");
}
}
Ok(cache_file_path)
}