use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::{OnceCell, Semaphore};
use crate::summarizer::backend::{CompactOpts, SummarizerBackend};
use crate::summarizer::error::BackendError;
use crate::summarizer::prompts;
use crate::tokenizer::Tokenizer;
/// Summarizer backend that runs a `mistralrs` model in-process.
///
/// The model is not loaded at construction time; it is loaded on the first call that needs it
/// and then reused.
pub struct LocalMistralRs {
/// Backend name, returned by `SummarizerBackend::name`.
name: String,
/// Model repository id; returned by `SummarizerBackend::model_id`, passed to the model builder,
/// and used to locate the model in the Hugging Face cache.
repo_id: String,
/// Lazily initialised model handle; stays empty until the first successful load.
model: OnceCell<Arc<mistralrs::Model>>,
/// Created with a single permit in `new` and held for the duration of each `compact` call.
permit: Arc<Semaphore>,
// Stored by `new` but not read anywhere in this file, hence the lint allowance.
#[allow(dead_code)]
tokenizer: Tokenizer,
}
/// Manual `Debug` that prints the backend name, the repository id and whether the model has
/// been loaded yet (as the `loaded` flag) instead of the model handle itself.
impl std::fmt::Debug for LocalMistralRs {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `loaded` only reports whether the lazy cell has been filled.
        let loaded = self.model.get().is_some();

        let mut out = f.debug_struct("LocalMistralRs");
        out.field("name", &self.name);
        out.field("repo_id", &self.repo_id);
        out.field("loaded", &loaded);
        out.finish()
    }
}
impl LocalMistralRs {
    /// Creates a backend for `repo_id` without loading the model.
    ///
    /// The model cell starts empty and the semaphore is created with exactly one permit.
    pub fn new(name: &str, repo_id: &str, tokenizer: Tokenizer) -> Self {
        let permit = Arc::new(Semaphore::new(1));
        Self {
            name: name.to_owned(),
            repo_id: repo_id.to_owned(),
            model: OnceCell::new(),
            permit,
            tokenizer,
        }
    }

    /// Returns the shared model handle, loading the model on first use.
    ///
    /// Steps when the model is not loaded yet:
    /// 1. Check whether the repository already has a non-empty cache directory.
    /// 2. If it does, run the integrity check and abort on failure; otherwise print a download
    ///    notice to stderr.
    /// 3. Build the model and store it in the cell.
    /// 4. If the repository was not cached beforehand, record it as a fresh download.
    ///
    /// # Errors
    ///
    /// Returns the mapped integrity error when the cached files fail verification, or
    /// `BackendError::Unavailable` when building the model fails.
    // The download notice is deliberately written straight to stderr.
    #[allow(clippy::print_stderr)]
    async fn model_get_or_load(&self) -> Result<Arc<mistralrs::Model>, BackendError> {
        // Fast path: the model has already been loaded.
        if let Some(ready) = self.model.get() {
            return Ok(Arc::clone(ready));
        }

        let was_cached = hf_cache_has(&self.repo_id);
        if was_cached {
            // Only files that were already on disk are verified before loading.
            crate::model_integrity::enforce(&self.repo_id).map_err(integrity_to_backend_error)?;
        } else {
            eprintln!(
                "downloading {} from HuggingFace; cached at {} — this may take several minutes",
                self.repo_id,
                hf_cache_root().display(),
            );
        }

        let handle = self
            .model
            .get_or_try_init(|| async {
                mistralrs::ModelBuilder::new(&self.repo_id)
                    .with_auto_isq(mistralrs::IsqBits::Eight)
                    .with_logging()
                    .build()
                    .await
                    .map(Arc::new)
                    .map_err(|e| BackendError::Unavailable(format!("model load failed: {e}")))
            })
            .await?;

        if !was_cached {
            crate::model_integrity::record_fresh_download(&self.repo_id);
        }
        Ok(Arc::clone(handle))
    }
}
#[async_trait]
impl SummarizerBackend for LocalMistralRs {
    /// Backend name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Repository id of the model this backend loads.
    fn model_id(&self) -> &str {
        self.repo_id.as_str()
    }

    /// Always `true` for this backend.
    fn uses_model_prompt(&self) -> bool {
        true
    }

    /// Summarises `content` with the local model and returns the trimmed reply text.
    ///
    /// The backend's single semaphore permit is held for the whole call, including the initial
    /// model load, so calls on one instance run one at a time.
    ///
    /// # Errors
    ///
    /// - `BackendError::Unavailable` if the semaphore is closed.
    /// - Whatever `model_get_or_load` returns when the model cannot be loaded.
    /// - `BackendError::ModelError` if the chat request fails, or if the response has no first
    ///   choice or that choice has no content.
    async fn compact(&self, content: &str, opts: &CompactOpts) -> Result<String, BackendError> {
        let _permit = match self.permit.acquire().await {
            Ok(permit) => permit,
            Err(_) => return Err(BackendError::Unavailable("semaphore closed".into())),
        };
        let model = self.model_get_or_load().await?;

        // Build a system + user conversation from the rendered prompt.
        let rendered = prompts::render_abstractive(opts, content);
        let conversation = mistralrs::TextMessages::new()
            .add_message(mistralrs::TextMessageRole::System, &rendered.system)
            .add_message(mistralrs::TextMessageRole::User, &rendered.user);

        // Use the caller's token target as the sampler length cap, defaulting to 4096.
        let request = mistralrs::RequestBuilder::from(conversation)
            .set_sampler_max_len(opts.target_tokens.unwrap_or(4096));

        let response = match model.send_chat_request(request).await {
            Ok(response) => response,
            Err(e) => return Err(BackendError::ModelError(format!("inference failed: {e}"))),
        };

        // Only the first choice is used; a missing choice or missing content is an error.
        let Some(reply) = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.as_ref())
        else {
            return Err(BackendError::ModelError("empty response".into()));
        };
        Ok(reply.trim().to_string())
    }
}
/// Translates a model-integrity error into the backend's error type.
///
/// A `ModelIntegrityFailure` keeps its `file`, `expected` and `actual` fields by moving them into
/// the `BackendError` variant of the same name. Every other variant becomes
/// `BackendError::Unavailable` carrying the error's `Display` text.
fn integrity_to_backend_error(e: crate::model_integrity::IntegrityError) -> BackendError {
    match e {
        crate::model_integrity::IntegrityError::ModelIntegrityFailure {
            file,
            expected,
            actual,
        } => BackendError::ModelIntegrityFailure {
            file,
            expected,
            actual,
        },
        unexpected => {
            let detail = format!("model integrity check failed: {unexpected}");
            BackendError::Unavailable(detail)
        }
    }
}
/// Reports whether the Hugging Face cache holds a non-empty directory for `repo_id`.
///
/// The directory is `models--<repo_id>` under [`hf_cache_root`], with every `/` in the id
/// replaced by `--`. A missing directory, an empty directory, or a directory that cannot be
/// listed all yield `false`.
pub fn hf_cache_has(repo_id: &str) -> bool {
    let dir_name = format!("models--{}", repo_id.replace('/', "--"));
    let repo_dir = hf_cache_root().join(dir_name);

    if !repo_dir.exists() {
        return false;
    }
    match repo_dir.read_dir() {
        // One entry of any kind is enough to count the repository as cached.
        Ok(mut entries) => entries.next().is_some(),
        Err(_) => false,
    }
}
/// Resolves the Hugging Face hub cache directory.
///
/// Resolution order:
/// 1. `$HF_HOME/hub` when `HF_HOME` is set to valid Unicode.
/// 2. `$HOME/.cache/huggingface/hub` when `HOME` is set to valid Unicode.
/// 3. The relative path `.cache/huggingface/hub` otherwise.
pub fn hf_cache_root() -> std::path::PathBuf {
    use std::path::PathBuf;

    std::env::var("HF_HOME")
        .map(|hf_home| PathBuf::from(hf_home).join("hub"))
        .or_else(|_| {
            std::env::var("HOME").map(|home| PathBuf::from(home).join(".cache/huggingface/hub"))
        })
        .unwrap_or_else(|_| PathBuf::from(".cache/huggingface/hub"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::path::Path;

    /// Scoped override of the `HF_HOME` environment variable.
    ///
    /// The previous value is captured on construction and put back on drop, so the variable is
    /// restored even when a test assertion panics. Callers must already hold
    /// `HF_HOME_TEST_MUTEX` and must declare this guard *after* the mutex guard, so that it is
    /// dropped (and the variable restored) before the lock is released.
    struct HfHomeOverride {
        /// Value of `HF_HOME` before the override; `None` if it was unset.
        prior: Option<OsString>,
    }

    impl HfHomeOverride {
        /// Points `HF_HOME` at `value` and remembers what it held before.
        fn set(value: &Path) -> Self {
            // `var_os` keeps a non-UTF-8 prior value intact so it can be restored exactly.
            let prior = std::env::var_os("HF_HOME");
            // SAFETY: the caller holds `HF_HOME_TEST_MUTEX`, which the tests that mutate
            // `HF_HOME` take before touching the environment.
            unsafe { std::env::set_var("HF_HOME", value) };
            Self { prior }
        }
    }

    impl Drop for HfHomeOverride {
        fn drop(&mut self) {
            // SAFETY: this guard is dropped while the caller's `HF_HOME_TEST_MUTEX` guard is
            // still alive (see the type-level docs), so the same serialization applies.
            unsafe {
                match self.prior.take() {
                    Some(previous) => std::env::set_var("HF_HOME", previous),
                    None => std::env::remove_var("HF_HOME"),
                }
            }
        }
    }

    #[test]
    fn hf_cache_root_respects_hf_home_env() {
        // A poisoned lock only means an earlier test panicked; the mutex is still usable.
        let _g = crate::model_integrity::HF_HOME_TEST_MUTEX
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        let tmp = tempfile::tempdir().unwrap();
        let _env = HfHomeOverride::set(tmp.path());

        assert_eq!(hf_cache_root(), tmp.path().join("hub"));
    }

    #[test]
    fn hf_cache_has_returns_false_for_missing_repo() {
        // A poisoned lock only means an earlier test panicked; the mutex is still usable.
        let _g = crate::model_integrity::HF_HOME_TEST_MUTEX
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        let tmp = tempfile::tempdir().unwrap();
        let _env = HfHomeOverride::set(tmp.path());

        // The fresh temp dir has no `hub/models--…` directory at all.
        assert!(!hf_cache_has("Qwen/Qwen3.5-0.8B"));
    }
}