use serde_json::Value;
use std::path::PathBuf;
/// Converts a model identifier into a string usable as a single directory name.
///
/// Every `/` is replaced with `--` (so `org/model` becomes `org--model`), then
/// every character that is not alphanumeric, `-`, `_` or `.` is dropped. Path
/// separators therefore never survive.
///
/// A result that would be interpreted specially as a path component is
/// rewritten to underscores: one character per dot, or a single `_` for an
/// empty result. These are empty, `.` and `..` (and, generally, any all-dot
/// name). Otherwise joining it onto a parent directory could point at that
/// parent, or escape it.
///
/// # Examples
///
/// ```ignore
/// assert_eq!(sanitize_model_name("BAAI/bge-small-en-v1.5"), "BAAI--bge-small-en-v1.5");
/// assert_eq!(sanitize_model_name(".."), "__");
/// ```
pub fn sanitize_model_name(model_id: &str) -> String {
    let sanitized: String = model_id
        .replace('/', "--")
        .chars()
        .filter(|c| c.is_alphanumeric() || matches!(c, '-' | '_' | '.'))
        .collect();

    // `all` is true for the empty string as well, so this also catches ids that
    // were stripped to nothing (which would collapse onto the parent directory).
    if sanitized.chars().all(|c| c == '.') {
        return "_".repeat(sanitized.len().max(1));
    }
    sanitized
}
/// Environment variable that overrides the root directory of the model cache.
pub const CACHE_ROOT_ENV: &str = "UNI_CACHE_DIR";
/// Cache root used when [`CACHE_ROOT_ENV`] is unset or empty. It is a relative
/// path, so it resolves against the current working directory.
const DEFAULT_CACHE_ROOT: &str = ".uni_cache";
/// Returns the root directory under which all provider caches live.
///
/// Uses the value of [`CACHE_ROOT_ENV`] when it is set to a non-empty value,
/// otherwise [`DEFAULT_CACHE_ROOT`].
///
/// The variable is read with `std::env::var_os`, so non-UTF-8 paths are
/// honoured rather than silently ignored.
///
/// An empty value is treated as unset. Otherwise `PathBuf::from("")` would
/// silently put the cache in the current working directory.
fn cache_root() -> PathBuf {
    std::env::var_os(CACHE_ROOT_ENV)
        .filter(|dir| !dir.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from(DEFAULT_CACHE_ROOT))
}
/// Returns the cache directory shared by every model of `provider`.
///
/// The result is `<cache root>/<provider>`, where the cache root is whatever
/// `cache_root()` resolves to. `provider` is appended as-is (it is not
/// sanitized).
pub fn resolve_provider_cache_root(provider: &str) -> PathBuf {
    let mut provider_dir = cache_root();
    provider_dir.push(provider);
    provider_dir
}
pub fn resolve_cache_dir(provider: &str, model_id: &str, options: &Value) -> PathBuf {
if let Some(dir) = options.get("cache_dir").and_then(|v| v.as_str()) {
return PathBuf::from(dir);
}
cache_root()
.join(provider)
.join(sanitize_model_name(model_id))
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test in this module that reads or writes
    /// `CACHE_ROOT_ENV`. The process environment is shared, and the test
    /// harness runs tests on parallel threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning.
    ///
    /// The mutex guards no data, so a test that panicked while holding it
    /// leaves nothing inconsistent. Calling `unwrap()` here would instead turn
    /// one failure into a cascade of `PoisonError` panics in unrelated tests.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Overrides `CACHE_ROOT_ENV` for the guard's lifetime. When dropped, it
    /// restores the previous value, even if the test panicked.
    ///
    /// Bind it *after* the `lock_env()` guard. Locals drop in reverse order,
    /// so the variable is restored while the lock is still held.
    struct CacheRootEnv {
        previous: Option<OsString>,
    }

    impl CacheRootEnv {
        /// `Some(v)` sets the variable to `v`; `None` removes it.
        fn set(value: Option<&str>) -> Self {
            let previous = std::env::var_os(CACHE_ROOT_ENV);
            // SAFETY: callers hold `ENV_LOCK`, which serializes all
            // environment access performed by the tests in this module.
            unsafe {
                match value {
                    Some(v) => std::env::set_var(CACHE_ROOT_ENV, v),
                    None => std::env::remove_var(CACHE_ROOT_ENV),
                }
            }
            Self { previous }
        }
    }

    impl Drop for CacheRootEnv {
        fn drop(&mut self) {
            // SAFETY: same invariant as in `CacheRootEnv::set`. This guard is
            // dropped before the `ENV_LOCK` guard bound ahead of it.
            unsafe {
                match self.previous.take() {
                    Some(v) => std::env::set_var(CACHE_ROOT_ENV, v),
                    None => std::env::remove_var(CACHE_ROOT_ENV),
                }
            }
        }
    }

    #[test]
    fn sanitize_slash_replaced_with_double_dash() {
        assert_eq!(
            sanitize_model_name("sentence-transformers/all-MiniLM-L6-v2"),
            "sentence-transformers--all-MiniLM-L6-v2"
        );
    }

    #[test]
    fn sanitize_strips_unsafe_chars() {
        assert_eq!(sanitize_model_name("foo:bar@baz"), "foobarbaz");
    }

    #[test]
    fn sanitize_keeps_safe_chars() {
        assert_eq!(
            sanitize_model_name("BAAI--bge-small-en-v1.5"),
            "BAAI--bge-small-en-v1.5"
        );
    }

    #[test]
    fn resolve_default_path() {
        let _lock = lock_env();
        let _env = CacheRootEnv::set(None);
        let path = resolve_cache_dir("fastembed", "BAAI/bge-small-en-v1.5", &json!({}));
        assert_eq!(
            path,
            PathBuf::from(".uni_cache/fastembed/BAAI--bge-small-en-v1.5")
        );
    }

    #[test]
    fn resolve_env_var_root() {
        let _lock = lock_env();
        let _env = CacheRootEnv::set(Some("/data/models"));
        let path = resolve_cache_dir("fastembed", "BAAI/bge-small-en-v1.5", &json!({}));
        assert_eq!(
            path,
            PathBuf::from("/data/models/fastembed/BAAI--bge-small-en-v1.5")
        );
    }

    #[test]
    fn resolve_options_cache_dir_takes_priority_over_env() {
        let _lock = lock_env();
        let _env = CacheRootEnv::set(Some("/data/models"));
        let opts = json!({ "cache_dir": "/tmp/my_cache" });
        let path = resolve_cache_dir("fastembed", "some-model", &opts);
        assert_eq!(path, PathBuf::from("/tmp/my_cache"));
    }

    #[test]
    fn resolve_user_override() {
        let _lock = lock_env();
        let _env = CacheRootEnv::set(None);
        let opts = json!({ "cache_dir": "/tmp/my_cache" });
        let path = resolve_cache_dir("fastembed", "some-model", &opts);
        assert_eq!(path, PathBuf::from("/tmp/my_cache"));
    }

    #[test]
    fn resolve_candle_path() {
        let _lock = lock_env();
        let _env = CacheRootEnv::set(None);
        let path = resolve_cache_dir(
            "candle",
            "sentence-transformers/all-MiniLM-L6-v2",
            &json!({}),
        );
        assert_eq!(
            path,
            PathBuf::from(".uni_cache/candle/sentence-transformers--all-MiniLM-L6-v2")
        );
    }
}