Skip to main content

bunsen_cache/
disk.rs

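//! # Cache Policy
//!
//! Maps resource keys to files under `$HOME/.cache/<root_cache_key>` and downloads
//! missing resources into that location on demand.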
2
3use alloc::string::String;
4use std::{
5    fs::{
6        File,
7        remove_file,
8    },
9    io::Write,
10    path::PathBuf,
11};
12
13use anyhow::bail;
14use burn::{
15    config::Config,
16    data::network::downloader,
17};
18
19/// Cache Policy
20#[derive(Config, Debug)]
21pub struct DiskCacheConfig {
22    /// Key for the root cache directory.
23    #[config(default = "\"bimm\".to_string()")]
24    pub root_cache_key: String,
25}
26
27impl Default for DiskCacheConfig {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl DiskCacheConfig {
34    /// Fetch the base cache directory.
35    ///
36    /// If the cache directory does not exist, does not create it.
37    pub fn base_cache_dir(&self) -> anyhow::Result<PathBuf> {
38        Ok(dirs::home_dir()
39            .expect("Should be able to get home directory")
40            .join(".cache")
41            .join(&self.root_cache_key))
42    }
43
44    /// Fetch the base cache directory.
45    ///
46    /// If the cache directory does not exist, creates it.
47    pub fn ensure_base_cache_dir(&self) -> anyhow::Result<PathBuf> {
48        let dir = self.base_cache_dir()?;
49        if !dir.exists() {
50            std::fs::create_dir_all(&dir)?;
51        }
52        Ok(dir)
53    }
54
55    /// Map a resource key to a cache path.
56    ///
57    /// Does not ensure that the path (or any of the parents) exist.
58    pub fn resource_to_path(
59        &self,
60        resource_key: &[String],
61    ) -> anyhow::Result<PathBuf> {
62        let path = self.base_cache_dir()?;
63        Ok(resource_key.iter().fold(path, |acc, s| acc.join(s)))
64    }
65
66    /// Map a resource key to a cache path and ensure the parent directory
67    /// exists.
68    pub fn ensure_resource_parent_dir(
69        &self,
70        resource_key: &[String],
71    ) -> anyhow::Result<PathBuf> {
72        let path = self.resource_to_path(resource_key)?;
73        if !path.exists() {
74            std::fs::create_dir_all(path.parent().unwrap())?;
75        }
76        Ok(path)
77    }
78
79    /// Fetch a Resource to the Cache.
80    pub fn fetch_resource(
81        &self,
82        url: &str,
83        resource: &[String],
84    ) -> anyhow::Result<PathBuf> {
85        let cache_file_path = self.ensure_resource_parent_dir(resource)?;
86        try_cache_download_to_path(url, cache_file_path)
87    }
88}
89
90/// Download a URL resource to a given path.
91///
92/// If the path already exists, does nothing.
93///
94/// # Returns
95///
96/// The cache path.
97pub fn try_cache_download_to_path(
98    url: &str,
99    cache_file_path: PathBuf,
100) -> anyhow::Result<PathBuf> {
101    if !cache_file_path.exists() {
102        let file_name = cache_file_path
103            .file_name()
104            .unwrap()
105            .to_string_lossy()
106            .to_string();
107
108        // TODO: download-to-file instead of download-to-memory.
109        // Download file content
110        let bytes = downloader::download_file_as_bytes(url, &file_name);
111
112        // Write content to file
113        let mut output_file = File::create(&cache_file_path)?;
114        let bytes_written = output_file.write(&bytes)?;
115
116        if bytes_written != bytes.len() {
117            remove_file(cache_file_path)?;
118            bail!("Failed to write the whole model weights file.");
119        }
120    }
121
122    Ok(cache_file_path)
123}