use std::fs::{self, File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use fs2::FileExt;
use sha2::{Digest, Sha256};
use thiserror::Error;
/// Hugging Face repository id of the pinned default embedding model.
pub(crate) const HF_REPO: &str = "BAAI/bge-small-en-v1.5";
/// Exact git revision of `HF_REPO` that the pinned checksums below belong to.
///
/// Public only under `cfg(test)` / the `loader-test-hooks` feature; crate-private otherwise.
#[cfg(any(test, feature = "loader-test-hooks"))]
pub const HF_REVISION: &str = "5c38ec7c405ec4b44b94cc5a9bb96e735b38267a";
/// Exact git revision of `HF_REPO` that the pinned checksums below belong to.
#[cfg(not(any(test, feature = "loader-test-hooks")))]
pub(crate) const HF_REVISION: &str = "5c38ec7c405ec4b44b94cc5a9bb96e735b38267a";
/// Base URL for download URLs of the form `{base}/{repo}/resolve/{revision}/{file}`.
pub(crate) const HF_BASE_URL: &str = "https://huggingface.co";
/// Expected lowercase-hex SHA-256 of `config.json` at `HF_REVISION`.
pub(crate) const CONFIG_JSON_SHA256: &str =
    "094f8e891b932f2000c92cfc663bac4c62069f5d8af5b5278c4306aef3084750";
/// Expected lowercase-hex SHA-256 of `tokenizer.json` at `HF_REVISION`.
pub(crate) const TOKENIZER_JSON_SHA256: &str =
    "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66";
/// Expected lowercase-hex SHA-256 of `model.safetensors` at `HF_REVISION`.
pub(crate) const MODEL_SAFETENSORS_SHA256: &str =
    "3c9f31665447c8911517620762200d2245a2518d6e7208acc78cd9db317e21ad";
/// Default TCP connect timeout for downloads (overridden by `ENV_CONNECT_TIMEOUT` in production).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Default read timeout for downloads (overridden by `ENV_READ_TIMEOUT` in production).
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(60);
/// Total number of download attempts per file: the first try plus retries.
const MAX_ATTEMPTS: u32 = 3;
/// Default maximum wait for the cache-directory lock (overridden by `ENV_LOCK_TIMEOUT`).
const DEFAULT_LOCK_TIMEOUT: Duration = Duration::from_secs(120);
/// Env var with the lock timeout, in whole seconds.
pub(crate) const ENV_LOCK_TIMEOUT: &str = "FATHOMDB_EMBEDDER_LOCK_TIMEOUT_S";
/// Env var with the HTTP connect timeout, in whole seconds.
pub(crate) const ENV_CONNECT_TIMEOUT: &str = "FATHOMDB_EMBEDDER_CONNECT_TIMEOUT_S";
/// Env var with the HTTP read timeout, in whole seconds.
pub(crate) const ENV_READ_TIMEOUT: &str = "FATHOMDB_EMBEDDER_READ_TIMEOUT_S";
/// Reads the environment variable `var` as a whole number of seconds.
///
/// An unset (or non-UTF-8) variable silently yields `default`. A value that is
/// set but does not parse as a `u64` also yields `default`, after printing a
/// warning to stderr.
fn parse_secs_env_or_default(var: &str, default: Duration) -> Duration {
    let Ok(raw) = std::env::var(var) else {
        return default;
    };
    raw.parse::<u64>().map(Duration::from_secs).unwrap_or_else(|_| {
        eprintln!(
            "fathomdb-embedder: invalid value for {var} ({raw:?}); falling back to default \
             {default:?}"
        );
        default
    })
}
/// Short, stable directory name for the pinned model: the first 12 lowercase hex
/// characters of `SHA-256("{HF_REPO}@{HF_REVISION}")`.
///
/// Computed once per process and cached.
fn model_sha_prefix() -> &'static str {
    static PREFIX: OnceLock<String> = OnceLock::new();
    PREFIX
        .get_or_init(|| {
            let pinned_id = format!("{HF_REPO}@{HF_REVISION}");
            let digest = Sha256::digest(pinned_id.as_bytes());
            let mut hex = format!("{digest:x}");
            hex.truncate(12);
            hex
        })
        .as_str()
}
/// Local paths of the pinned embedder files, plus a record of how each was obtained.
///
/// Every path lives in the loader's cache directory. Each file's contents matched its
/// pinned SHA-256 when the loader returned.
#[derive(Debug, Clone)]
pub struct LoadedWeights {
    /// Path to the model's `config.json`.
    pub config_json_path: PathBuf,
    /// Path to `tokenizer.json`.
    pub tokenizer_json_path: PathBuf,
    /// Path to `model.safetensors`.
    pub model_safetensors_path: PathBuf,
    /// Bytes fetched from the network during this load; 0 when every file was a cache hit.
    /// For a resumed download, only the bytes received by the final attempt are counted.
    pub bytes_downloaded: u64,
    /// One event per file, in config / tokenizer / model order (cache hit or download).
    pub events: Vec<EmbedderEvent>,
}
pub use super::EmbedderEvent;
/// Errors raised while locating, downloading or verifying the pinned default embedder.
#[derive(Debug, Error)]
pub enum EmbedderLoadError {
    /// Downloading failed with a network error on every attempt, or the server returned
    /// an HTTP status that is not worth retrying.
    #[error("network unavailable after {attempts} attempts: {source}")]
    NetworkUnavailable {
        /// The last network error observed.
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
        /// Number of attempts made before giving up.
        attempts: u32,
    },
    /// A downloaded file did not hash to its pinned SHA-256 (the bad file is deleted).
    #[error("checksum mismatch for {file:?}: expected {expected}, actual {actual}")]
    ChecksumMismatch { file: PathBuf, expected: String, actual: String },
    /// The model's output dimension differs from the expected one.
    /// NOTE(review): not constructed by the loader code in this file; presumably raised
    /// by model construction elsewhere — confirm.
    #[error("model dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: u32, actual: u32 },
    /// A filesystem operation failed on the embedder cache, its lock file, or a
    /// Hugging Face hub snapshot file.
    #[error("cache I/O error at {path:?}: {source}")]
    CacheIoError {
        /// Path the failing operation targeted. It is the placeholder `<dirs::cache_dir>`
        /// when the platform cache directory could not be determined.
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },
    /// Deserializing model weights failed.
    /// NOTE(review): not constructed in this file; presumably produced by the candle-based
    /// model loader — confirm.
    #[error("model deserialize: {source}")]
    ModelDeserialize {
        #[source]
        source: candle_core::Error,
    },
    /// Loading the tokenizer failed.
    /// NOTE(review): not constructed in this file — confirm the producing call site.
    #[error("tokenizer load: {source}")]
    TokenizerLoad {
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },
    /// The cache-directory lock could not be acquired in time.
    /// `waited_s` is the configured timeout, in whole seconds.
    #[error("timed out acquiring embedder cache lock at {lock_path:?} after {waited_s}s")]
    LockTimeout { lock_path: PathBuf, waited_s: u64 },
}
/// Failure of a single download attempt, classified so the retry loop can decide
/// whether another attempt makes sense.
enum DownloadAttemptError {
    /// The HTTP request failed, or returned a status other than 200/206.
    Network(Box<ureq::Error>),
    /// Reading the response body failed mid-stream (always retried).
    NetworkStreamIo(std::io::Error),
    /// A local filesystem operation on `path` failed (never retried).
    CacheIo { path: PathBuf, source: std::io::Error },
}
fn retry_decision_ureq(err: &ureq::Error) -> RetryDecision {
match err {
ureq::Error::Status(code, _) => {
if (500..=599).contains(code) || *code == 408 || *code == 429 {
RetryDecision::Retry
} else {
RetryDecision::FailFast
}
}
ureq::Error::Transport(_) => RetryDecision::Retry,
}
}
/// Whether a failed download attempt is worth repeating.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RetryDecision {
    /// Transient failure: back off and try again if attempts remain.
    Retry,
    /// Permanent failure: report it immediately without further attempts.
    FailFast,
}
/// Everything the loader needs to locate, download and verify the pinned files.
///
/// Built by `LoaderConfig::production()` in normal use; tests use the
/// `loader-test-hooks` builder methods. `Debug` is implemented by hand so that the
/// Hugging Face access token is never written to logs.
#[derive(Clone)]
pub struct LoaderConfig {
    /// Scheme and host that `<base>/<repo>/resolve/<revision>/<file>` URLs are built on.
    base_url: String,
    /// Directory under which `fathomdb/embedders/<prefix>/` is created.
    cache_root: PathBuf,
    /// Bearer token for the `Authorization` header, if any. Secret: redacted in `Debug`.
    hf_token: Option<String>,
    /// Expected lowercase-hex SHA-256 of `config.json`.
    config_sha: String,
    /// Expected lowercase-hex SHA-256 of `tokenizer.json`.
    tokenizer_sha: String,
    /// Expected lowercase-hex SHA-256 of `model.safetensors`.
    model_sha: String,
    /// TCP connect timeout for each download attempt.
    connect_timeout: Duration,
    /// Read timeout for each download attempt.
    read_timeout: Duration,
    /// Maximum time to wait for the cache-directory lock.
    lock_timeout: Duration,
    /// Hugging Face home directory (the parent of `hub/`) to probe for existing
    /// snapshots; `None` means resolve it from the environment.
    hf_hub_root: Option<PathBuf>,
}

impl std::fmt::Debug for LoaderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoaderConfig")
            .field("base_url", &self.base_url)
            .field("cache_root", &self.cache_root)
            // Only reveal whether a token is configured, never its value.
            .field("hf_token", &self.hf_token.as_ref().map(|_| "<redacted>"))
            .field("config_sha", &self.config_sha)
            .field("tokenizer_sha", &self.tokenizer_sha)
            .field("model_sha", &self.model_sha)
            .field("connect_timeout", &self.connect_timeout)
            .field("read_timeout", &self.read_timeout)
            .field("lock_timeout", &self.lock_timeout)
            .field("hf_hub_root", &self.hf_hub_root)
            .finish()
    }
}
impl LoaderConfig {
    /// Builds the configuration used by `load_pinned_default_embedder`.
    ///
    /// - Cache root: the platform cache directory (`dirs::cache_dir()`).
    /// - Timeouts: `FATHOMDB_EMBEDDER_{LOCK,CONNECT,READ}_TIMEOUT_S`, in whole seconds.
    ///   Unset or unparsable values fall back to the defaults; unparsable ones print a
    ///   warning on stderr.
    /// - Auth: `HF_TOKEN`. Surrounding whitespace (e.g. a trailing newline from a token
    ///   file) is trimmed, and an empty value is treated as unset. Otherwise a blank
    ///   variable would send an empty `Authorization: Bearer ` header.
    ///
    /// # Errors
    /// `CacheIoError` (with the placeholder path `<dirs::cache_dir>`) when the platform
    /// has no cache directory.
    pub(crate) fn production() -> Result<Self, EmbedderLoadError> {
        let cache_root = dirs::cache_dir().ok_or_else(|| EmbedderLoadError::CacheIoError {
            path: PathBuf::from("<dirs::cache_dir>"),
            source: std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "platform cache dir unavailable",
            ),
        })?;
        let lock_timeout = parse_secs_env_or_default(ENV_LOCK_TIMEOUT, DEFAULT_LOCK_TIMEOUT);
        let connect_timeout =
            parse_secs_env_or_default(ENV_CONNECT_TIMEOUT, DEFAULT_CONNECT_TIMEOUT);
        let read_timeout = parse_secs_env_or_default(ENV_READ_TIMEOUT, DEFAULT_READ_TIMEOUT);
        let hf_token = std::env::var("HF_TOKEN")
            .ok()
            .map(|token| token.trim().to_string())
            .filter(|token| !token.is_empty());
        Ok(Self {
            base_url: HF_BASE_URL.to_string(),
            cache_root,
            hf_token,
            config_sha: CONFIG_JSON_SHA256.to_string(),
            tokenizer_sha: TOKENIZER_JSON_SHA256.to_string(),
            model_sha: MODEL_SAFETENSORS_SHA256.to_string(),
            connect_timeout,
            read_timeout,
            lock_timeout,
            hf_hub_root: None,
        })
    }

    /// Offline starting point for tests.
    ///
    /// It uses an unroutable base URL (`127.0.0.1:0`), a fixed cache root under
    /// `/tmp`, empty checksum pins, and a hub root pointing at a nonexistent directory,
    /// so the developer's real Hugging Face cache is not consulted.
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn for_tests() -> Self {
        Self {
            base_url: "http://127.0.0.1:0".to_string(),
            cache_root: PathBuf::from("/tmp/fathomdb-embedder-tests"),
            hf_token: None,
            config_sha: String::new(),
            tokenizer_sha: String::new(),
            model_sha: String::new(),
            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
            read_timeout: DEFAULT_READ_TIMEOUT,
            lock_timeout: DEFAULT_LOCK_TIMEOUT,
            hf_hub_root: Some(PathBuf::from("/nonexistent-fathomdb-embedder-test-hub")),
        }
    }

    /// Overrides the Hugging Face home directory; `None` resolves it from the environment.
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn with_hf_hub_root(mut self, root: Option<PathBuf>) -> Self {
        self.hf_hub_root = root;
        self
    }

    /// Overrides the download base URL (e.g. a local mock server).
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    /// Overrides the cache root directory.
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn with_cache_root(mut self, cache_root: PathBuf) -> Self {
        self.cache_root = cache_root;
        self
    }

    /// Overrides the bearer token, used as given (no trimming).
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn with_hf_token(mut self, token: Option<String>) -> Self {
        self.hf_token = token;
        self
    }

    /// Replaces the three expected SHA-256 pins (lowercase hex).
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn with_test_pins(
        mut self,
        config_sha: String,
        tokenizer_sha: String,
        model_sha: String,
    ) -> Self {
        self.config_sha = config_sha;
        self.tokenizer_sha = tokenizer_sha;
        self.model_sha = model_sha;
        self
    }

    /// Directory the loader will place the files in for this configuration.
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn expected_cache_dir(&self) -> PathBuf {
        self.cache_dir_internal()
    }

    /// `for_tests()`, but with the connect/read timeouts taken from the same env vars
    /// that `production()` reads. The lock timeout stays at its default.
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn for_tests_reading_timeout_env() -> Self {
        let connect_timeout =
            parse_secs_env_or_default(ENV_CONNECT_TIMEOUT, DEFAULT_CONNECT_TIMEOUT);
        let read_timeout = parse_secs_env_or_default(ENV_READ_TIMEOUT, DEFAULT_READ_TIMEOUT);
        let mut cfg = Self::for_tests();
        cfg.connect_timeout = connect_timeout;
        cfg.read_timeout = read_timeout;
        cfg
    }

    /// Configured HTTP connect timeout.
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn connect_timeout(&self) -> Duration {
        self.connect_timeout
    }

    /// Configured HTTP read timeout.
    #[cfg(any(test, feature = "loader-test-hooks"))]
    pub fn read_timeout(&self) -> Duration {
        self.read_timeout
    }

    /// `<cache_root>/fathomdb/embedders/<prefix>`, where `<prefix>` is derived from the
    /// pinned repository and revision (see `model_sha_prefix`).
    fn cache_dir_internal(&self) -> PathBuf {
        self.cache_root.join("fathomdb").join("embedders").join(model_sha_prefix())
    }
}
pub fn load_pinned_default_embedder() -> Result<LoadedWeights, EmbedderLoadError> {
load_with_config_internal(LoaderConfig::production()?)
}
/// Test-only entry point: runs the loader with an explicit [`LoaderConfig`].
///
/// Available under `cfg(test)` or the `loader-test-hooks` feature.
#[cfg(any(test, feature = "loader-test-hooks"))]
pub fn load_with_config(cfg: LoaderConfig) -> Result<LoadedWeights, EmbedderLoadError> {
    load_with_config_internal(cfg)
}
fn load_with_config_internal(cfg: LoaderConfig) -> Result<LoadedWeights, EmbedderLoadError> {
let cache_dir = cfg.cache_dir_internal();
fs::create_dir_all(&cache_dir)
.map_err(|source| EmbedderLoadError::CacheIoError { path: cache_dir.clone(), source })?;
let mut events = Vec::new();
let mut bytes_downloaded: u64 = 0;
let files = [
("config.json", cfg.config_sha.clone()),
("tokenizer.json", cfg.tokenizer_sha.clone()),
("model.safetensors", cfg.model_sha.clone()),
];
let mut paths = Vec::with_capacity(3);
for (file_name, expected_sha) in &files {
let final_path = cache_dir.join(file_name);
if file_matches_sha(&final_path, expected_sha)? {
events.push(EmbedderEvent::DefaultEmbedderCacheHit {
file: (*file_name).to_string(),
sha256: expected_sha.clone(),
cache_path: final_path.clone(),
});
paths.push(final_path);
continue;
}
if let Some(hub_path) = hf_hub_candidate_path(&cfg, file_name) {
if file_matches_sha(&hub_path, expected_sha)? {
materialize_from_hf_hub(&hub_path, &final_path)?;
events.push(EmbedderEvent::DefaultEmbedderCacheHit {
file: (*file_name).to_string(),
sha256: expected_sha.clone(),
cache_path: final_path.clone(),
});
paths.push(final_path);
continue;
}
}
let (n, fetched_event) = fetch_under_lock(&cfg, &cache_dir, file_name, expected_sha)?;
bytes_downloaded = bytes_downloaded.saturating_add(n);
match fetched_event {
FetchOutcome::Downloaded(ev) => events.push(ev),
FetchOutcome::CacheHitAfterLock(ev) => events.push(ev),
}
paths.push(final_path);
}
Ok(LoadedWeights {
config_json_path: paths[0].clone(),
tokenizer_json_path: paths[1].clone(),
model_safetensors_path: paths[2].clone(),
bytes_downloaded,
events,
})
}
/// How `fetch_under_lock` satisfied a request; each variant carries the event to record.
enum FetchOutcome {
    /// The file was downloaded, verified and moved into the cache.
    Downloaded(EmbedderEvent),
    /// A valid file was already in the cache once the lock was acquired
    /// (e.g. written by another process that held the lock first).
    CacheHitAfterLock(EmbedderEvent),
}
/// Downloads `file_name` into `cache_dir` while holding the directory's `.lock` file.
///
/// After the lock is taken the cache is checked again. A caller that waited for another
/// downloader therefore reports a cache hit (0 bytes) instead of downloading twice.
/// Otherwise the file is streamed into `<file_name>.partial` and its SHA-256 compared
/// with `expected_sha`. On success it is renamed to `<file_name>`; on Unix the directory
/// is then fsynced. On mismatch the partial file is deleted.
///
/// Returns the byte count reported by the download step together with the outcome event.
///
/// # Errors
/// `CacheIoError`, `LockTimeout`, `NetworkUnavailable`, or `ChecksumMismatch`. Note that
/// `ChecksumMismatch.file` names the `.partial` path, which has already been removed.
fn fetch_under_lock(
    cfg: &LoaderConfig,
    cache_dir: &Path,
    file_name: &str,
    expected_sha: &str,
) -> Result<(u64, FetchOutcome), EmbedderLoadError> {
    // One lock file per cache directory, so downloads of different files are serialized too.
    let lock_path = cache_dir.join(".lock");
    // `truncate(false)`: the file is only a lock handle; its contents are never used.
    let lock_file = OpenOptions::new()
        .create(true)
        .read(true)
        .write(true)
        .truncate(false)
        .open(&lock_path)
        .map_err(|source| EmbedderLoadError::CacheIoError { path: lock_path.clone(), source })?;
    acquire_exclusive_with_timeout(&lock_file, &lock_path, cfg.lock_timeout)?;
    // Unlocks on every return path below.
    let _guard = LockGuard(&lock_file);
    let final_path = cache_dir.join(file_name);
    // Someone else may have finished this file while we were waiting for the lock.
    if file_matches_sha(&final_path, expected_sha)? {
        return Ok((
            0,
            FetchOutcome::CacheHitAfterLock(EmbedderEvent::DefaultEmbedderCacheHit {
                file: file_name.to_string(),
                sha256: expected_sha.to_string(),
                cache_path: final_path,
            }),
        ));
    }
    // Lives next to the final file, so the rename below stays within one directory.
    let partial_path = cache_dir.join(format!("{file_name}.partial"));
    let url = format!("{}/{}/resolve/{}/{}", cfg.base_url, HF_REPO, HF_REVISION, file_name);
    let start = Instant::now();
    // `bytes` counts only what the final (successful) attempt received.
    let bytes = download_with_retries(cfg, &url, &partial_path, file_name)?;
    // Wall-clock time including retries and backoff sleeps.
    let duration_ms = start.elapsed().as_millis() as u64;
    let observed_sha = sha256_file(&partial_path)
        .map_err(|source| EmbedderLoadError::CacheIoError { path: partial_path.clone(), source })?;
    if observed_sha != expected_sha {
        // Drop the bad bytes so the next attempt starts from scratch instead of resuming them.
        let _ = fs::remove_file(&partial_path);
        return Err(EmbedderLoadError::ChecksumMismatch {
            file: partial_path.clone(),
            expected: expected_sha.to_string(),
            actual: observed_sha,
        });
    }
    // Replace in one step, so readers never see a half-written `final_path`
    // (atomic on POSIX filesystems for a same-directory rename).
    fs::rename(&partial_path, &final_path)
        .map_err(|source| EmbedderLoadError::CacheIoError { path: final_path.clone(), source })?;
    // Persist the directory entry created by the rename.
    #[cfg(unix)]
    fsync_parent_dir(&final_path)?;
    Ok((
        bytes,
        FetchOutcome::Downloaded(EmbedderEvent::DefaultEmbedderDownload {
            file: file_name.to_string(),
            url,
            bytes,
            sha256: observed_sha,
            cache_path: final_path,
            duration_ms,
        }),
    ))
}
#[cfg(unix)]
fn fsync_parent_dir(path: &Path) -> Result<(), EmbedderLoadError> {
if let Some(parent) = path.parent() {
let dir = File::open(parent).map_err(|source| EmbedderLoadError::CacheIoError {
path: parent.to_path_buf(),
source,
})?;
dir.sync_all().map_err(|source| EmbedderLoadError::CacheIoError {
path: parent.to_path_buf(),
source,
})?;
}
Ok(())
}
/// Where a local Hugging Face hub snapshot of `file_name` would live for the pinned
/// repository and revision. The path is not checked for existence.
///
/// When `cfg.hf_hub_root` is set (test hook), it is used as the HF home and the
/// environment is ignored. Otherwise the hub cache directory is resolved the same way
/// `huggingface_hub` does, first match wins:
/// 1. `$HF_HUB_CACHE`, then `$HUGGINGFACE_HUB_CACHE` (legacy);
/// 2. `$HF_HOME/hub`;
/// 3. `$XDG_CACHE_HOME/huggingface/hub`;
/// 4. `~/.cache/huggingface/hub`.
///
/// Empty variables count as unset, because an empty path would resolve against the
/// current directory. Returns `None` only when none of these applies and no home
/// directory can be determined.
fn hf_hub_candidate_path(cfg: &LoaderConfig, file_name: &str) -> Option<PathBuf> {
    fn env_path(var: &str) -> Option<PathBuf> {
        std::env::var_os(var).filter(|value| !value.is_empty()).map(PathBuf::from)
    }

    let hub_dir = match &cfg.hf_hub_root {
        Some(hf_home) => hf_home.join("hub"),
        None => match env_path("HF_HUB_CACHE").or_else(|| env_path("HUGGINGFACE_HUB_CACHE")) {
            Some(dir) => dir,
            None => {
                let hf_home = match env_path("HF_HOME") {
                    Some(dir) => dir,
                    None => env_path("XDG_CACHE_HOME")
                        .or_else(|| dirs::home_dir().map(|home| home.join(".cache")))?
                        .join("huggingface"),
                };
                hf_home.join("hub")
            }
        },
    };
    let repo_encoded = format!("models--{}", HF_REPO.replace('/', "--"));
    Some(hub_dir.join(repo_encoded).join("snapshots").join(HF_REVISION).join(file_name))
}
/// Places a verified file from a Hugging Face hub snapshot into the embedder cache at `dst`.
///
/// Hub snapshot entries are normally relative symlinks into `../../blobs/`, so `src` is
/// canonicalized first. On Linux, `fs::hard_link` links the symlink itself rather than
/// its target, which would leave a dangling link in the cache directory.
///
/// On Unix, a hard link to the resolved file is tried first (no extra disk space);
/// otherwise the bytes are copied. Either way, the result is staged under a
/// process-unique temporary name next to `dst` and then renamed over it. A stale `dst`
/// is thus replaced in one step, and readers never see a partially written file.
///
/// # Errors
/// `CacheIoError` naming `src` if it cannot be resolved, or naming `dst` if staging or
/// the final rename fails. The staging file is removed on failure.
fn materialize_from_hf_hub(src: &Path, dst: &Path) -> Result<(), EmbedderLoadError> {
    let resolved = fs::canonicalize(src)
        .map_err(|source| EmbedderLoadError::CacheIoError { path: src.to_path_buf(), source })?;

    let mut staging_name = dst.file_name().map(|name| name.to_os_string()).unwrap_or_default();
    staging_name.push(format!(".hub-{}.tmp", std::process::id()));
    let staging = dst.with_file_name(staging_name);
    // A leftover from an earlier crashed run would make both link and copy misbehave.
    let _ = fs::remove_file(&staging);

    let linked = cfg!(unix) && fs::hard_link(&resolved, &staging).is_ok();
    let staged = if linked { Ok(()) } else { fs::copy(&resolved, &staging).map(|_| ()) };
    let result = staged.and_then(|()| fs::rename(&staging, dst));
    if result.is_err() {
        let _ = fs::remove_file(&staging);
    }
    result.map_err(|source| EmbedderLoadError::CacheIoError { path: dst.to_path_buf(), source })
}
fn acquire_exclusive_with_timeout(
f: &File,
lock_path: &Path,
timeout: Duration,
) -> Result<(), EmbedderLoadError> {
let deadline = Instant::now() + timeout;
loop {
match f.try_lock_exclusive() {
Ok(()) => return Ok(()),
Err(e) => {
if e.kind() != std::io::ErrorKind::WouldBlock {
return Err(EmbedderLoadError::CacheIoError {
path: lock_path.to_path_buf(),
source: e,
});
}
if Instant::now() >= deadline {
return Err(EmbedderLoadError::LockTimeout {
lock_path: lock_path.to_path_buf(),
waited_s: timeout.as_secs(),
});
}
std::thread::sleep(Duration::from_millis(25));
}
}
}
}
/// Releases the lock held on the wrapped file when dropped, so every return path of
/// the owning scope unlocks. Unlock errors are ignored.
struct LockGuard<'a>(&'a File);
impl Drop for LockGuard<'_> {
    fn drop(&mut self) {
        // Fully qualified call. NOTE(review): presumably chosen to avoid ambiguity with
        // std's newer inherent `File::unlock` — confirm.
        let _ = fs2::FileExt::unlock(self.0);
    }
}
/// Runs `download_once` up to `MAX_ATTEMPTS` times, with exponential backoff
/// (1 s, 2 s, ...) between attempts.
///
/// Local I/O failures and non-retryable HTTP statuses end the loop immediately.
/// Transport errors, retryable statuses and mid-stream read errors are retried.
///
/// # Errors
/// `CacheIoError` for local I/O problems; `NetworkUnavailable` carrying the last network
/// error and the number of attempts made.
fn download_with_retries(
    cfg: &LoaderConfig,
    url: &str,
    partial_path: &Path,
    _file_name: &str,
) -> Result<u64, EmbedderLoadError> {
    // No sleep after the final attempt; there is nothing left to wait for.
    let back_off = |attempt: u32| {
        if attempt + 1 < MAX_ATTEMPTS {
            std::thread::sleep(Duration::from_secs(1u64 << attempt));
        }
    };
    let mut last_net_err: Option<Box<dyn std::error::Error + Send + Sync>> = None;
    let mut attempts_made: u32 = 0;
    for attempt in 0..MAX_ATTEMPTS {
        attempts_made = attempt + 1;
        match download_once(cfg, url, partial_path) {
            Ok(n) => return Ok(n),
            Err(DownloadAttemptError::CacheIo { path, source }) => {
                return Err(EmbedderLoadError::CacheIoError { path, source });
            }
            Err(DownloadAttemptError::Network(e))
                if retry_decision_ureq(&e) == RetryDecision::FailFast =>
            {
                return Err(EmbedderLoadError::NetworkUnavailable {
                    source: e,
                    attempts: attempts_made,
                });
            }
            Err(DownloadAttemptError::Network(e)) => last_net_err = Some(e),
            Err(DownloadAttemptError::NetworkStreamIo(io)) => last_net_err = Some(Box::new(io)),
        }
        back_off(attempt);
    }
    Err(EmbedderLoadError::NetworkUnavailable {
        source: last_net_err.expect("at least one retryable attempt produced an error"),
        attempts: attempts_made,
    })
}
fn download_once(
cfg: &LoaderConfig,
url: &str,
partial_path: &Path,
) -> Result<u64, DownloadAttemptError> {
let agent = ureq::AgentBuilder::new()
.timeout_connect(cfg.connect_timeout)
.timeout_read(cfg.read_timeout)
.redirects(3)
.build();
let existing = fs::metadata(partial_path).map(|m| m.len()).unwrap_or(0);
let mut req = agent.get(url);
if let Some(token) = &cfg.hf_token {
req = req.set("Authorization", &format!("Bearer {token}"));
}
if existing > 0 {
req = req.set("Range", &format!("bytes={existing}-"));
}
let resp = req.call().map_err(|e| DownloadAttemptError::Network(Box::new(e)))?;
let status = resp.status();
if !(status == 200 || status == 206) {
return Err(DownloadAttemptError::Network(Box::new(ureq::Error::Status(status, resp))));
}
let mk_io = |source: std::io::Error| DownloadAttemptError::CacheIo {
path: partial_path.to_path_buf(),
source,
};
let mut f = if status == 206 && existing > 0 {
let mut f = OpenOptions::new().write(true).open(partial_path).map_err(mk_io)?;
f.seek(SeekFrom::End(0)).map_err(mk_io)?;
f
} else {
match OpenOptions::new().write(true).create_new(true).open(partial_path) {
Ok(f) => f,
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
fs::remove_file(partial_path).map_err(mk_io)?;
OpenOptions::new().write(true).create_new(true).open(partial_path).map_err(mk_io)?
}
Err(source) => {
return Err(mk_io(source));
}
}
};
let mut reader = resp.into_reader();
let mut buf = [0u8; 64 * 1024];
let mut written: u64 = 0;
loop {
match reader.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
f.write_all(&buf[..n]).map_err(mk_io)?;
written += n as u64;
}
Err(source) => {
return Err(DownloadAttemptError::NetworkStreamIo(source));
}
}
}
f.sync_all().map_err(mk_io)?;
Ok(written)
}
/// Reports whether `path` is a regular file (symlinks followed) whose lowercase-hex
/// SHA-256 equals `expected`.
///
/// A missing path, a directory or a broken symlink yields `Ok(false)`.
///
/// # Errors
/// `CacheIoError` naming `path` if the file exists but cannot be read.
fn file_matches_sha(path: &Path, expected: &str) -> Result<bool, EmbedderLoadError> {
    if path.is_file() {
        sha256_file(path)
            .map(|observed| observed == expected)
            .map_err(|source| EmbedderLoadError::CacheIoError { path: path.to_path_buf(), source })
    } else {
        Ok(false)
    }
}
/// Streams the file at `path` through SHA-256 in 64 KiB chunks and returns the
/// digest as lowercase hex.
///
/// Reads interrupted by a signal (`ErrorKind::Interrupted`) are retried, as
/// `std::io::Read` requires, instead of being reported as failures.
///
/// # Errors
/// Any other I/O error from opening or reading the file.
fn sha256_file(path: &Path) -> std::io::Result<String> {
    let mut f = File::open(path)?;
    let mut hasher = Sha256::new();
    let mut buf = [0u8; 64 * 1024];
    loop {
        let n = match f.read(&mut buf) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        hasher.update(&buf[..n]);
    }
    Ok(format!("{:x}", hasher.finalize()))
}