1use alloc::string::String;
4use std::{
5 fs::{
6 File,
7 remove_file,
8 },
9 io::Write,
10 path::PathBuf,
11};
12
13use anyhow::bail;
14use burn::{
15 config::Config,
16 data::network::downloader,
17};
18
/// Configuration for an on-disk cache located under the user's home directory.
///
/// The cache root resolves to `<home>/.cache/<root_cache_key>`; see
/// [`DiskCacheConfig::base_cache_dir`].
#[derive(Config, Debug)]
pub struct DiskCacheConfig {
    /// Name of the cache's root directory inside `<home>/.cache`.
    ///
    /// Defaults to `"bimm"`.
    #[config(default = "\"bimm\".to_string()")]
    pub root_cache_key: String,
}
26
27impl Default for DiskCacheConfig {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl DiskCacheConfig {
34 pub fn base_cache_dir(&self) -> anyhow::Result<PathBuf> {
38 Ok(dirs::home_dir()
39 .expect("Should be able to get home directory")
40 .join(".cache")
41 .join(&self.root_cache_key))
42 }
43
44 pub fn ensure_base_cache_dir(&self) -> anyhow::Result<PathBuf> {
48 let dir = self.base_cache_dir()?;
49 if !dir.exists() {
50 std::fs::create_dir_all(&dir)?;
51 }
52 Ok(dir)
53 }
54
55 pub fn resource_to_path(
59 &self,
60 resource_key: &[String],
61 ) -> anyhow::Result<PathBuf> {
62 let path = self.base_cache_dir()?;
63 Ok(resource_key.iter().fold(path, |acc, s| acc.join(s)))
64 }
65
66 pub fn ensure_resource_parent_dir(
69 &self,
70 resource_key: &[String],
71 ) -> anyhow::Result<PathBuf> {
72 let path = self.resource_to_path(resource_key)?;
73 if !path.exists() {
74 std::fs::create_dir_all(path.parent().unwrap())?;
75 }
76 Ok(path)
77 }
78
79 pub fn fetch_resource(
81 &self,
82 url: &str,
83 resource: &[String],
84 ) -> anyhow::Result<PathBuf> {
85 let cache_file_path = self.ensure_resource_parent_dir(resource)?;
86 try_cache_download_to_path(url, cache_file_path)
87 }
88}
89
90pub fn try_cache_download_to_path(
98 url: &str,
99 cache_file_path: PathBuf,
100) -> anyhow::Result<PathBuf> {
101 if !cache_file_path.exists() {
102 let file_name = cache_file_path
103 .file_name()
104 .unwrap()
105 .to_string_lossy()
106 .to_string();
107
108 let bytes = downloader::download_file_as_bytes(url, &file_name);
111
112 let mut output_file = File::create(&cache_file_path)?;
114 let bytes_written = output_file.write(&bytes)?;
115
116 if bytes_written != bytes.len() {
117 remove_file(cache_file_path)?;
118 bail!("Failed to write the whole model weights file.");
119 }
120 }
121
122 Ok(cache_file_path)
123}