1use alloc::{
4 collections::BTreeMap,
5 format,
6 string::{
7 String,
8 ToString,
9 },
10 vec,
11 vec::Vec,
12};
13use std::path::PathBuf;
14
15use anyhow::bail;
16use serde::{
17 Deserialize,
18 Serialize,
19};
20
21use crate::DiskCacheConfig;
22
23const X25: crc::Crc<u16> = crc::Crc::<u16>::new(&crc::CRC_16_IBM_SDLC);
24
25pub fn fetch_model_weights<S: AsRef<str>>(url: S) -> anyhow::Result<PathBuf> {
28 let cache_key = url_to_cache_key(Some("model"), url.as_ref());
29 let resource = pretrained_weights_resource_key(&cache_key);
30
31 let disk_cache = DiskCacheConfig::default();
32 disk_cache.fetch_resource(url.as_ref(), &resource)
33}
34
35pub fn url_to_cache_key(
37 name: Option<&str>,
38 url: &str,
39) -> String {
40 let hash = X25.checksum(url.as_bytes()).to_string();
41 let base_name = url.rsplit_once('/').unwrap().1;
42 match name {
43 Some(n) => format!("{}-{}-{}", n, hash, base_name),
44 None => format!("{}-{}", hash, base_name),
45 }
46}
47
/// Builds the two-part resource key `["weights", cache_key]` used to address
/// pretrained weights in the disk cache.
pub fn pretrained_weights_resource_key(cache_key: &str) -> Vec<String> {
    ["weights", cache_key]
        .iter()
        .map(|&part| String::from(part))
        .collect()
}
60
/// Borrowed counterpart of [`PretrainedWeightsDescriptor`], suitable for declaring
/// descriptors from string literals.
///
/// Every field is a plain borrow, so the type is `Copy`. Use
/// [`to_descriptor`](Self::to_descriptor) to obtain an owned descriptor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StaticPretrainedWeightsDescriptor<'a> {
    /// Name of the weights; becomes the key in [`PretrainedWeightsMap`].
    pub name: &'a str,

    /// Human-readable description.
    pub description: &'a str,

    /// Optional license identifier (e.g. `"MIT"`).
    pub license: Option<&'a str>,

    /// Optional origin, e.g. an upstream project URL.
    pub origin: Option<&'a str>,

    /// Download URLs; the first one is used when deriving the cache key and fetching.
    pub urls: &'a [&'a str],
}
79
80impl<'a> StaticPretrainedWeightsDescriptor<'a> {
81 pub fn to_descriptor(&self) -> PretrainedWeightsDescriptor {
83 PretrainedWeightsDescriptor {
84 name: self.name.to_string(),
85 description: self.description.to_string(),
86 license: self.license.map(|s| s.to_string()),
87 origin: self.origin.map(|s| s.to_string()),
88 urls: self.urls.iter().map(|s| s.to_string()).collect(),
89 }
90 }
91}
92
93impl From<&StaticPretrainedWeightsDescriptor<'_>> for PretrainedWeightsDescriptor {
94 fn from(descriptor: &StaticPretrainedWeightsDescriptor) -> Self {
95 descriptor.to_descriptor()
96 }
97}
98
99#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
101pub struct PretrainedWeightsDescriptor {
102 pub name: String,
104
105 pub description: String,
107
108 pub license: Option<String>,
110
111 pub origin: Option<String>,
113
114 pub urls: Vec<String>,
116}
117
118impl PretrainedWeightsDescriptor {
119 pub fn cache_key(&self) -> String {
123 url_to_cache_key(Some(&self.name), self.urls.first().unwrap())
124 }
125
126 pub fn fetch_weights(
132 &self,
133 disk_cache: &DiskCacheConfig,
134 ) -> anyhow::Result<PathBuf> {
135 let url = self.urls.first().unwrap();
136 let cache_key = &self.cache_key();
137 let resource = pretrained_weights_resource_key(cache_key);
138
139 disk_cache.fetch_resource(url, &resource)
140 }
141}
142
143#[derive(Debug)]
145pub struct StaticPretrainedWeightsMap<'a> {
146 pub items: &'a [&'a StaticPretrainedWeightsDescriptor<'a>],
148}
149
150impl<'a> StaticPretrainedWeightsMap<'a> {
151 pub fn to_directory(&self) -> PretrainedWeightsMap {
153 PretrainedWeightsMap {
154 items: self
155 .items
156 .iter()
157 .map(|d| {
158 let desc = d.to_descriptor();
159 (desc.name.clone(), desc)
160 })
161 .collect(),
162 }
163 }
164}
165
166impl<'a> From<&StaticPretrainedWeightsMap<'a>> for PretrainedWeightsMap {
167 fn from(directory: &StaticPretrainedWeightsMap) -> Self {
168 directory.to_directory()
169 }
170}
171
172#[derive(Debug, Clone)]
174pub struct PretrainedWeightsMap {
175 pub items: BTreeMap<String, PretrainedWeightsDescriptor>,
177}
178
179impl PretrainedWeightsMap {
180 pub fn lookup_by_name(
182 &self,
183 name: &str,
184 ) -> Option<PretrainedWeightsDescriptor> {
185 self.items.get(name).cloned()
186 }
187
188 pub fn try_lookup_by_name(
190 &self,
191 name: &str,
192 ) -> anyhow::Result<PretrainedWeightsDescriptor> {
193 match self.lookup_by_name(name) {
194 Some(d) => Ok(d),
195 None => bail!("Descriptor not found: {}", name),
196 }
197 }
198
199 pub fn expect_lookup_by_name(
201 &self,
202 name: &str,
203 ) -> PretrainedWeightsDescriptor {
204 match self.try_lookup_by_name(name) {
205 Ok(p) => p,
206 Err(e) => panic!("{}", e),
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use alloc::string::{
        String,
        ToString,
    };

    use super::*;

    #[test]
    fn test_static_descriptor_to_descriptor() {
        let s_desc = StaticPretrainedWeightsDescriptor {
            name: "my_model",
            description: "some description of my model.",
            urls: &["foo", "bar"],
            license: Some("MIT"),
            origin: Some("https://github.com/my_org/my_model"),
        };
        let d_desc = s_desc.to_descriptor();

        assert_eq!(d_desc.name, s_desc.name.to_string());
        assert_eq!(d_desc.description, s_desc.description.to_string());
        assert_eq!(d_desc.license, s_desc.license.map(|s| s.to_string()));
        assert_eq!(d_desc.origin, s_desc.origin.map(|s| s.to_string()));
        assert_eq!(
            d_desc.urls,
            s_desc
                .urls
                .iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>()
        );

        // The `From` conversion must agree with `to_descriptor`.
        assert_eq!(PretrainedWeightsDescriptor::from(&s_desc), d_desc);
    }

    #[test]
    fn test_url_to_cache_key() {
        let url = "https://example.com/files/weights.bin";
        let hash = X25.checksum(url.as_bytes());

        assert_eq!(
            url_to_cache_key(Some("model"), url),
            format!("model-{}-weights.bin", hash)
        );
        assert_eq!(url_to_cache_key(None, url), format!("{}-weights.bin", hash));
    }

    #[test]
    fn test_pretrained_weights_resource_key() {
        assert_eq!(
            pretrained_weights_resource_key("abc"),
            vec!["weights".to_string(), "abc".to_string()]
        );
    }

    #[test]
    fn test_directory_lookup() {
        let s_desc = StaticPretrainedWeightsDescriptor {
            name: "my_model",
            description: "some description of my model.",
            urls: &["https://example.com/my_model.bin"],
            license: None,
            origin: None,
        };
        let static_map = StaticPretrainedWeightsMap { items: &[&s_desc] };
        let directory = static_map.to_directory();

        assert_eq!(directory.items.len(), 1);
        assert_eq!(
            directory.lookup_by_name("my_model"),
            Some(s_desc.to_descriptor())
        );
        assert_eq!(
            directory.expect_lookup_by_name("my_model"),
            s_desc.to_descriptor()
        );
        assert!(directory.lookup_by_name("missing").is_none());
        assert_eq!(
            directory
                .try_lookup_by_name("missing")
                .unwrap_err()
                .to_string(),
            "Descriptor not found: missing"
        );
    }

    #[test]
    #[should_panic(expected = "Descriptor not found: missing")]
    fn test_expect_lookup_by_name_panics_when_missing() {
        let static_map = StaticPretrainedWeightsMap { items: &[] };
        let _ = static_map.to_directory().expect_lookup_by_name("missing");
    }
}