use super::*;
use crate::common::error::RustBertError;
use cached_path::{Cache, Options, ProgressBar};
use dirs::cache_dir;
use lazy_static::lazy_static;
use std::path::PathBuf;
/// A resource identified by a URL and resolved through the local download cache.
///
/// The `url` is the address handed to the cache lookup; `cache_subdir` selects the
/// sub-directory of the cache directory in which the resolved file is stored.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct RemoteResource {
    /// URL of the resource.
    pub url: String,
    /// Name of the cache sub-directory used to store this resource.
    pub cache_subdir: String,
}

impl RemoteResource {
    /// Builds a `RemoteResource` from a URL and a cache sub-directory name.
    ///
    /// Both arguments are copied into owned `String`s.
    pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
        let url = url.to_owned();
        let cache_subdir = cache_subdir.to_owned();
        Self { url, cache_subdir }
    }

    /// Builds a `RemoteResource` from a `(cache_subdir, url)` pair.
    ///
    /// Note the element order: the cache sub-directory name comes first and the URL
    /// second, which is the reverse of the argument order of [`RemoteResource::new`].
    pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
        let (cache_subdir, url) = name_url_tuple;
        Self::new(url, cache_subdir)
    }
}
impl ResourceProvider for RemoteResource {
    /// Resolves `self.url` through the global [`CACHE`], storing the result under the
    /// `cache_subdir` sub-directory, and returns the local path of the cached file.
    ///
    /// # Errors
    /// Any error returned by the cache lookup is converted into a [`RustBertError`].
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        let options = Options::default().subdir(&self.cache_subdir);
        let local_path = CACHE.cached_path_with_options(&self.url, &options)?;
        Ok(local_path)
    }

    /// Returns the cached file as a [`Resource::PathBuf`].
    ///
    /// # Errors
    /// Propagates any error from [`Self::get_local_path`].
    fn get_resource(&self) -> Result<Resource, RustBertError> {
        self.get_local_path().map(Resource::PathBuf)
    }
}
lazy_static! {
    /// Global download cache shared by every [`RemoteResource`].
    ///
    /// Its directory comes from `_get_cache_directory` (the `RUSTBERT_CACHE`
    /// environment variable if set, otherwise `.rustbert` inside the platform cache
    /// directory), and it is configured with a light progress bar.
    ///
    /// # Panics
    /// The first dereference panics if the cache directory cannot be determined or if
    /// building the cache fails.
    // Attributes inside `lazy_static!` are attached to the generated wrapper type, not
    // to `Cache`; the derive is kept only so that public type's trait impls are unchanged.
    #[derive(Copy, Clone, Debug)]
    pub static ref CACHE: Cache = Cache::builder()
        .dir(_get_cache_directory())
        .progress_bar(Some(ProgressBar::Light))
        .build()
        .expect("failed to build the rust-bert resource cache; set RUSTBERT_CACHE to a writable directory");
}
/// Returns the directory in which downloaded resources are cached.
///
/// If the `RUSTBERT_CACHE` environment variable is set, its value is used as-is. This
/// includes values that are not valid UTF-8. Otherwise the directory is `.rustbert`
/// inside the platform cache directory reported by `dirs::cache_dir`.
///
/// # Panics
/// Panics if `RUSTBERT_CACHE` is unset and no platform cache directory can be
/// determined.
fn _get_cache_directory() -> PathBuf {
    // `var_os` rather than `var`: `var` returns an error for non-UTF-8 values, which
    // would silently ignore a valid user-provided path and fall back to the default.
    match std::env::var_os("RUSTBERT_CACHE") {
        Some(value) => PathBuf::from(value),
        None => {
            let mut home = cache_dir().expect(
                "could not determine the platform cache directory; set RUSTBERT_CACHE",
            );
            home.push(".rustbert");
            home
        }
    }
}