Skip to main content

bunsen_cache/
weights.rs

1//! # Module / Weight Caches
2
3use alloc::{
4    collections::BTreeMap,
5    format,
6    string::{
7        String,
8        ToString,
9    },
10    vec,
11    vec::Vec,
12};
13use std::path::PathBuf;
14
15use anyhow::bail;
16use serde::{
17    Deserialize,
18    Serialize,
19};
20
21use crate::DiskCacheConfig;
22
/// CRC-16/IBM-SDLC checksum engine (the X.25 CRC), used to hash URLs into cache keys.
const X25: crc::Crc<u16> = crc::Crc::<u16>::new(&crc::CRC_16_IBM_SDLC);
24
25/// Returns a local path to model weights file.
26/// If the file does not exist, it will be downloaded from the given URL.
27pub fn fetch_model_weights<S: AsRef<str>>(url: S) -> anyhow::Result<PathBuf> {
28    let cache_key = url_to_cache_key(Some("model"), url.as_ref());
29    let resource = pretrained_weights_resource_key(&cache_key);
30
31    let disk_cache = DiskCacheConfig::default();
32    disk_cache.fetch_resource(url.as_ref(), &resource)
33}
34
35/// Build a cache key (bare cache file name) from a name and URL.
36pub fn url_to_cache_key(
37    name: Option<&str>,
38    url: &str,
39) -> String {
40    let hash = X25.checksum(url.as_bytes()).to_string();
41    let base_name = url.rsplit_once('/').unwrap().1;
42    match name {
43        Some(n) => format!("{}-{}-{}", n, hash, base_name),
44        None => format!("{}-{}", hash, base_name),
45    }
46}
47
/// Get the cache resource key for a pretrained weights file.
///
/// The resource key is the two-segment path `["weights", cache_key]`.
///
/// # Arguments
///
/// - `cache_key`: the cache key (the bare cache file name), e.g. from [`url_to_cache_key`].
///
/// # Returns
///
/// The cache resource key segments.
pub fn pretrained_weights_resource_key(cache_key: &str) -> Vec<String> {
    ["weights", cache_key]
        .iter()
        .map(|&segment| String::from(segment))
        .collect()
}
60
/// Static [`PretrainedWeightsDescriptor`] provider.
///
/// Every field borrows its data, so instances can be built in `const`/`static` context
/// and converted to an owned descriptor with [`to_descriptor`](Self::to_descriptor).
/// Because every field is a shared reference (or an `Option` of one), the type is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StaticPretrainedWeightsDescriptor<'a> {
    /// Name of the model.
    pub name: &'a str,

    /// Description of the model.
    pub description: &'a str,

    /// Optional license.
    pub license: Option<&'a str>,

    /// Optional source URL.
    pub origin: Option<&'a str>,

    /// URLs to download the weights from.
    pub urls: &'a [&'a str],
}
79
80impl<'a> StaticPretrainedWeightsDescriptor<'a> {
81    /// Convert to a [`PretrainedWeightsDescriptor`].
82    pub fn to_descriptor(&self) -> PretrainedWeightsDescriptor {
83        PretrainedWeightsDescriptor {
84            name: self.name.to_string(),
85            description: self.description.to_string(),
86            license: self.license.map(|s| s.to_string()),
87            origin: self.origin.map(|s| s.to_string()),
88            urls: self.urls.iter().map(|s| s.to_string()).collect(),
89        }
90    }
91}
92
93impl From<&StaticPretrainedWeightsDescriptor<'_>> for PretrainedWeightsDescriptor {
94    fn from(descriptor: &StaticPretrainedWeightsDescriptor) -> Self {
95        descriptor.to_descriptor()
96    }
97}
98
/// An owned descriptor for a pretrained weights file.
///
/// Serializable with serde. [`cache_key`](Self::cache_key) and
/// [`fetch_weights`](Self::fetch_weights) use only the first entry of `urls`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct PretrainedWeightsDescriptor {
    /// Name of the model; also the prefix of [`cache_key`](Self::cache_key).
    pub name: String,

    /// Description of the model.
    pub description: String,

    /// Optional license.
    pub license: Option<String>,

    /// Optional source URL.
    pub origin: Option<String>,

    /// URLs to download the weights from; only the first one is used for caching/fetching.
    pub urls: Vec<String>,
}
117
118impl PretrainedWeightsDescriptor {
119    /// Cache Key
120    ///
121    /// The key is ``{name}-{url crc hash}-{url basename}``.
122    pub fn cache_key(&self) -> String {
123        url_to_cache_key(Some(&self.name), self.urls.first().unwrap())
124    }
125
126    /// Read-Through Cache the Model Weights
127    ///
128    /// # Returns
129    ///
130    /// The disk location of the cached weights.
131    pub fn fetch_weights(
132        &self,
133        disk_cache: &DiskCacheConfig,
134    ) -> anyhow::Result<PathBuf> {
135        let url = self.urls.first().unwrap();
136        let cache_key = &self.cache_key();
137        let resource = pretrained_weights_resource_key(cache_key);
138
139        disk_cache.fetch_resource(url, &resource)
140    }
141}
142
/// Static [`PretrainedWeightsMap`] builder.
///
/// Holds borrowed descriptors. [`to_directory`](Self::to_directory) builds an owned map
/// keyed by each descriptor's `name`.
#[derive(Debug)]
pub struct StaticPretrainedWeightsMap<'a> {
    /// List of static descriptors.
    pub items: &'a [&'a StaticPretrainedWeightsDescriptor<'a>],
}
149
150impl<'a> StaticPretrainedWeightsMap<'a> {
151    /// Convert to a [`PretrainedWeightsMap`].
152    pub fn to_directory(&self) -> PretrainedWeightsMap {
153        PretrainedWeightsMap {
154            items: self
155                .items
156                .iter()
157                .map(|d| {
158                    let desc = d.to_descriptor();
159                    (desc.name.clone(), desc)
160                })
161                .collect(),
162        }
163    }
164}
165
166impl<'a> From<&StaticPretrainedWeightsMap<'a>> for PretrainedWeightsMap {
167    fn from(directory: &StaticPretrainedWeightsMap) -> Self {
168        directory.to_directory()
169    }
170}
171
/// Directory of [`PretrainedWeightsDescriptor`]s.
#[derive(Debug, Clone)]
pub struct PretrainedWeightsMap {
    /// Map of descriptors; the `lookup_by_name` family of methods queries it by key.
    pub items: BTreeMap<String, PretrainedWeightsDescriptor>,
}
178
179impl PretrainedWeightsMap {
180    /// Lookup a descriptor by name.
181    pub fn lookup_by_name(
182        &self,
183        name: &str,
184    ) -> Option<PretrainedWeightsDescriptor> {
185        self.items.get(name).cloned()
186    }
187
188    /// Lookup a descriptor.
189    pub fn try_lookup_by_name(
190        &self,
191        name: &str,
192    ) -> anyhow::Result<PretrainedWeightsDescriptor> {
193        match self.lookup_by_name(name) {
194            Some(d) => Ok(d),
195            None => bail!("Descriptor not found: {}", name),
196        }
197    }
198
199    /// Lookup a descriptor.
200    pub fn expect_lookup_by_name(
201        &self,
202        name: &str,
203    ) -> PretrainedWeightsDescriptor {
204        match self.try_lookup_by_name(name) {
205            Ok(p) => p,
206            Err(e) => panic!("{}", e),
207        }
208    }
209}
210
#[cfg(test)]
mod tests {
    use alloc::string::{
        String,
        ToString,
    };

    use super::*;

    /// Fixture descriptor; its URLs are placeholders without any `/`.
    fn my_model() -> StaticPretrainedWeightsDescriptor<'static> {
        StaticPretrainedWeightsDescriptor {
            name: "my_model",
            description: "some description of my model.",
            urls: &["foo", "bar"],
            license: Some("MIT"),
            origin: Some("https://github.com/my_org/my_model"),
        }
    }

    #[test]
    fn test_static_descriptor_to_descriptor() {
        let s_desc = my_model();
        let d_desc = s_desc.to_descriptor();

        assert_eq!(d_desc.name, s_desc.name.to_string());
        assert_eq!(d_desc.description, s_desc.description.to_string());
        assert_eq!(d_desc.license.as_deref(), s_desc.license);
        assert_eq!(d_desc.origin.as_deref(), s_desc.origin);
        assert_eq!(
            d_desc.urls,
            s_desc
                .urls
                .iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>()
        );
    }

    #[test]
    fn test_from_static_descriptor_matches_to_descriptor() {
        let s_desc = my_model();
        assert_eq!(PretrainedWeightsDescriptor::from(&s_desc), s_desc.to_descriptor());
    }

    #[test]
    fn test_url_to_cache_key_format() {
        let url = "https://example.com/models/weights.bin";
        let named = url_to_cache_key(Some("resnet"), url);
        let bare = url_to_cache_key(None, url);

        assert!(named.starts_with("resnet-"));
        assert!(bare.ends_with("-weights.bin"));
        assert_eq!(named, format!("resnet-{}", bare));
    }

    #[test]
    fn test_pretrained_weights_resource_key() {
        assert_eq!(
            pretrained_weights_resource_key("my_key"),
            vec!["weights".to_string(), "my_key".to_string()]
        );
    }

    #[test]
    fn test_descriptor_cache_key_uses_name_and_first_url() {
        let desc = StaticPretrainedWeightsDescriptor {
            urls: &[
                "https://example.com/a/model.bin",
                "https://mirror.example.com/model.bin",
            ],
            ..my_model()
        }
        .to_descriptor();

        assert_eq!(
            desc.cache_key(),
            url_to_cache_key(Some("my_model"), "https://example.com/a/model.bin")
        );
    }

    #[test]
    fn test_weights_map_lookup() {
        let desc = my_model();
        let map = StaticPretrainedWeightsMap { items: &[&desc] }.to_directory();

        assert_eq!(map.lookup_by_name("my_model"), Some(desc.to_descriptor()));
        assert_eq!(map.expect_lookup_by_name("my_model"), desc.to_descriptor());
        assert!(map.lookup_by_name("missing").is_none());
        assert!(map.try_lookup_by_name("missing").is_err());
    }

    #[test]
    #[should_panic(expected = "Descriptor not found: missing")]
    fn test_expect_lookup_by_name_panics_when_missing() {
        let map = PretrainedWeightsMap::from(&StaticPretrainedWeightsMap { items: &[] });
        map.expect_lookup_by_name("missing");
    }
}