use alloc::{
collections::BTreeMap,
format,
string::{
String,
ToString,
},
vec,
vec::Vec,
};
use std::path::PathBuf;
use anyhow::bail;
use serde::{
Deserialize,
Serialize,
};
use crate::cache::disk::DiskCacheConfig;
/// CRC-16 instance built from the `crc` crate's `CRC_16_IBM_SDLC` parameters.
///
/// `url_to_cache_key` uses it to checksum URL bytes.
const X25: crc::Crc<u16> = crc::Crc::<u16>::new(&crc::CRC_16_IBM_SDLC);
/// Fetch the resource at `url` through a default-configured [`DiskCacheConfig`].
///
/// The cache key is built by [`url_to_cache_key`] with the fixed name
/// `"model"`, then wrapped by [`pretrained_weights_resource_key`].
///
/// # Errors
///
/// Propagates the result of `DiskCacheConfig::fetch_resource`.
pub fn fetch_model_weights<S: AsRef<str>>(url: S) -> anyhow::Result<PathBuf> {
    let url = url.as_ref();
    // Build the resource key before touching the cache configuration, so the
    // order of operations matches the key -> config -> fetch sequence.
    let resource = pretrained_weights_resource_key(&url_to_cache_key(Some("model"), url));
    DiskCacheConfig::default().fetch_resource(url, &resource)
}
/// Build a cache key for `url`.
///
/// The key has the form `<name>-<checksum>-<base>` when `name` is `Some`, and
/// `<checksum>-<base>` otherwise, where:
///
/// * `<checksum>` is the decimal `X25` checksum of the URL's bytes, and
/// * `<base>` is the text after the last `/` in `url`, or the whole `url`
///   when it contains no `/`.
pub fn url_to_cache_key(
    name: Option<&str>,
    url: &str,
) -> String {
    let hash = X25.checksum(url.as_bytes()).to_string();
    // A URL without any '/' has nothing to strip: use it whole rather than
    // panicking on the missing separator.
    let base_name = url.rsplit_once('/').map_or(url, |(_, tail)| tail);
    match name {
        Some(n) => format!("{}-{}-{}", n, hash, base_name),
        None => format!("{}-{}", hash, base_name),
    }
}
/// Build the two-part resource key for cached weights.
///
/// Returns `["weights", cache_key]` as owned strings.
pub fn pretrained_weights_resource_key(cache_key: &str) -> Vec<String> {
    ["weights", cache_key]
        .iter()
        .map(|part| part.to_string())
        .collect()
}
/// Borrowed description of a set of pretrained weights.
///
/// Every field borrows for `'a`; `to_descriptor` copies them into an owned
/// `PretrainedWeightsDescriptor`.
#[derive(Debug)]
pub struct StaticPretrainedWeightsDescriptor<'a> {
/// Name of the weights; `StaticPretrainedWeightsMap::to_directory` uses the
/// owned copy of it as the map key.
pub name: &'a str,
/// Description text for the weights.
pub description: &'a str,
/// License text or identifier, if recorded.
pub license: Option<&'a str>,
/// Origin of the weights, if recorded.
pub origin: Option<&'a str>,
/// URLs associated with the weights.
pub urls: &'a [&'a str],
}
impl<'a> StaticPretrainedWeightsDescriptor<'a> {
    /// Produce an owned [`PretrainedWeightsDescriptor`].
    ///
    /// Each borrowed field is copied into a newly allocated `String`; the
    /// optional fields stay `None` when they are `None` here.
    pub fn to_descriptor(&self) -> PretrainedWeightsDescriptor {
        let name = String::from(self.name);
        let description = String::from(self.description);
        let license = self.license.map(String::from);
        let origin = self.origin.map(String::from);
        let urls: Vec<String> = self.urls.iter().map(|&url| String::from(url)).collect();

        PretrainedWeightsDescriptor {
            name,
            description,
            license,
            origin,
            urls,
        }
    }
}
impl From<&StaticPretrainedWeightsDescriptor<'_>> for PretrainedWeightsDescriptor {
fn from(descriptor: &StaticPretrainedWeightsDescriptor) -> Self {
descriptor.to_descriptor()
}
}
/// Owned description of a set of pretrained weights.
///
/// Derives `Serialize` and `Deserialize`, so it can be round-tripped through
/// serde formats.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct PretrainedWeightsDescriptor {
/// Name of the weights; `cache_key` uses it as the key prefix.
pub name: String,
/// Description text for the weights.
pub description: String,
/// License text or identifier, if recorded.
pub license: Option<String>,
/// Origin of the weights, if recorded.
pub origin: Option<String>,
/// URLs associated with the weights; `cache_key` and `fetch_weights` read
/// only the first entry.
pub urls: Vec<String>,
}
impl PretrainedWeightsDescriptor {
pub fn cache_key(&self) -> String {
url_to_cache_key(Some(&self.name), self.urls.first().unwrap())
}
pub fn fetch_weights(
&self,
disk_cache: &DiskCacheConfig,
) -> anyhow::Result<PathBuf> {
let url = self.urls.first().unwrap();
let cache_key = &self.cache_key();
let resource = pretrained_weights_resource_key(cache_key);
disk_cache.fetch_resource(url, &resource)
}
}
/// Borrowed list of static weight descriptors.
///
/// `to_directory` converts it into an owned `PretrainedWeightsMap`.
#[derive(Debug)]
pub struct StaticPretrainedWeightsMap<'a> {
/// The descriptors contained in this map.
pub items: &'a [&'a StaticPretrainedWeightsDescriptor<'a>],
}
impl<'a> StaticPretrainedWeightsMap<'a> {
pub fn to_directory(&self) -> PretrainedWeightsMap {
PretrainedWeightsMap {
items: self
.items
.iter()
.map(|d| {
let desc = d.to_descriptor();
(desc.name.clone(), desc)
})
.collect(),
}
}
}
impl<'a> From<&StaticPretrainedWeightsMap<'a>> for PretrainedWeightsMap {
fn from(directory: &StaticPretrainedWeightsMap) -> Self {
directory.to_directory()
}
}
/// Owned collection of weight descriptors, looked up by string key.
#[derive(Debug, Clone)]
pub struct PretrainedWeightsMap {
/// Descriptors by key; `StaticPretrainedWeightsMap::to_directory` keys each
/// entry by the descriptor's `name`.
pub items: BTreeMap<String, PretrainedWeightsDescriptor>,
}
impl PretrainedWeightsMap {
    /// Return a clone of the descriptor stored under `name`, or `None` when
    /// no entry has that key.
    pub fn lookup_by_name(&self, name: &str) -> Option<PretrainedWeightsDescriptor> {
        self.items.get(name).cloned()
    }

    /// Like [`Self::lookup_by_name`], but a missing entry is reported as an
    /// error.
    ///
    /// # Errors
    ///
    /// Returns an error naming `name` when no entry has that key.
    pub fn try_lookup_by_name(&self, name: &str) -> anyhow::Result<PretrainedWeightsDescriptor> {
        if let Some(found) = self.lookup_by_name(name) {
            return Ok(found);
        }
        bail!("Descriptor not found: {}", name)
    }

    /// Like [`Self::try_lookup_by_name`], but a missing entry is a panic.
    ///
    /// # Panics
    ///
    /// Panics with the lookup error's message when no entry has that key.
    pub fn expect_lookup_by_name(&self, name: &str) -> PretrainedWeightsDescriptor {
        self.try_lookup_by_name(name)
            .unwrap_or_else(|err| panic!("{}", err))
    }
}
#[cfg(test)]
mod tests {
    use alloc::string::{
        String,
        ToString,
    };

    use super::*;

    /// Every field of the static descriptor must be carried over into the
    /// owned descriptor, including the optional `license` and `origin`.
    #[test]
    fn test_static_descriptor_to_descriptor() {
        let s_desc = StaticPretrainedWeightsDescriptor {
            name: "my_model",
            description: "some description of my model.",
            urls: &["foo", "bar"],
            license: Some("MIT"),
            origin: Some("https://github.com/my_org/my_model"),
        };
        let d_desc = s_desc.to_descriptor();
        assert_eq!(d_desc.name, s_desc.name.to_string());
        assert_eq!(d_desc.description, s_desc.description.to_string());
        // The optional fields are converted too; check them so a dropped or
        // swapped field cannot slip through.
        assert_eq!(d_desc.license, s_desc.license.map(|s| s.to_string()));
        assert_eq!(d_desc.origin, s_desc.origin.map(|s| s.to_string()));
        assert_eq!(
            d_desc.urls,
            s_desc
                .urls
                .iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>()
        );
    }
}