use std::collections::HashSet;
use std::sync::{OnceLock, RwLock};
use std::time::Duration;
/// Process-wide set of (server_url, model) pairs whose warmup has completed.
static WARMUP_CACHE: OnceLock<RwLock<HashSet<(String, String)>>> = OnceLock::new();

/// Returns the shared warmup set, initialising it on first access.
fn get_cache() -> &'static RwLock<HashSet<(String, String)>> {
    WARMUP_CACHE.get_or_init(RwLock::default)
}
/// Returns true if a warmup for this (server_url, model) pair has already
/// completed successfully in this process.
fn is_warmed(server_url: &str, model: &str) -> bool {
    // Recover the data on lock poisoning: the set is never left in a
    // half-updated state by insert/clear, so a panic in another thread
    // should not permanently disable warmup tracking.
    let cache = get_cache()
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    cache.contains(&(server_url.to_string(), model.to_string()))
}
/// Records that the given (server_url, model) pair has been warmed up.
fn set_warmed(server_url: &str, model: &str) {
    // Recover from lock poisoning rather than panicking: the set stays
    // internally consistent across these operations (see is_warmed).
    let mut cache = get_cache()
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    cache.insert((server_url.to_string(), model.to_string()));
}
/// Removes every recorded (server, model) pair from the warmup cache, so
/// subsequent `preload_ollama` calls will warm up again.
pub fn clear_warmup_cache() {
    // Recover from lock poisoning rather than panicking: the set stays
    // internally consistent across these operations (see is_warmed).
    get_cache()
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .clear();
}
/// Kicks off a best-effort background warmup of an Ollama model.
///
/// Returns immediately; all network work runs on a detached thread. If the
/// (server_url, model) pair was already warmed in this process, nothing is
/// spawned. Every failure is logged via `crate::verbose!` and then dropped —
/// warmup is an optimization, never a hard requirement.
pub fn preload_ollama(server_url: &str, model: &str) {
    if is_warmed(server_url, model) {
        crate::verbose!("Ollama model already warmed, skipping preload");
        return;
    }
    // Own the strings so the worker thread does not borrow from the caller.
    let server_url = server_url.to_string();
    let model = model.to_string();
    std::thread::spawn(move || {
        crate::verbose!("Preloading Ollama model '{}' in background...", model);
        if let Err(e) = crate::ollama::ensure_ollama_running(&server_url) {
            crate::verbose!("Ollama preload: server startup failed: {}", e);
            return;
        }
        let model_present = match crate::ollama::has_model(&server_url, &model) {
            Ok(found) => found,
            Err(e) => {
                crate::verbose!("Ollama preload: model check failed: {}", e);
                return;
            }
        };
        if !model_present {
            crate::verbose!(
                "Ollama preload: model '{}' not found, skipping warmup (will pull later if needed)",
                model
            );
            return;
        }
        crate::verbose!("Ollama preload: model '{}' found, warming up...", model);
        if let Err(e) = warm_model(&server_url, &model) {
            crate::verbose!("Ollama preload: warmup request failed: {}", e);
            return;
        }
        // Only mark as warmed after a confirmed successful warmup request.
        set_warmed(&server_url, &model);
        crate::verbose!("Ollama model '{}' preloaded successfully", model);
    });
}
/// Sends a minimal `/api/chat` request so Ollama loads the model into memory.
///
/// An empty `messages` array asks Ollama to load the model without generating
/// anything; `keep_alive` keeps it resident for 5 minutes afterwards.
///
/// # Errors
/// Returns a human-readable `String` on client-construction, connection,
/// HTTP-status, body-read, or JSON-parse failure.
fn warm_model(server_url: &str, model: &str) -> Result<(), String> {
    let url = format!("{}/api/chat", server_url.trim_end_matches('/'));
    let client = reqwest::blocking::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()
        .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
    let response = client
        .post(&url)
        .json(&serde_json::json!({
            "model": model,
            "messages": [],
            "stream": false,
            "keep_alive": "5m"
        }))
        .send()
        .map_err(|e| {
            // Distinguish "server not there" from other transport errors.
            if e.is_connect() {
                format!("Cannot connect to Ollama at {}", server_url)
            } else {
                format!("Warmup request failed: {}", e)
            }
        })?;
    let status = response.status();
    if !status.is_success() {
        // Best-effort body read for diagnostics; the status is the real error.
        return Err(format!(
            "Ollama warmup failed: {} - {}",
            status,
            response.text().unwrap_or_default()
        ));
    }
    // Surface a body-read failure as its own error instead of collapsing it
    // into an empty string and misreporting it as an "empty response".
    let response_text = response
        .text()
        .map_err(|e| format!("Failed to read warmup response: {}", e))?;
    if response_text.is_empty() {
        return Err("Ollama warmup returned empty response".to_string());
    }
    // Validate the body parses as JSON; the contents are not otherwise used.
    serde_json::from_str::<serde_json::Value>(&response_text)
        .map_err(|e| format!("Invalid warmup response: {}", e))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Mutex, MutexGuard};

    /// Monotonic counter so every test uses distinct cache keys.
    static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Serializes tests that touch the process-global warmup cache.
    ///
    /// Cargo runs tests on parallel threads, and `clear_warmup_cache` wipes
    /// ALL entries — without this lock it can run between another test's
    /// `set_warmed` and `assert!(is_warmed(..))`, making that test flaky.
    static CACHE_LOCK: Mutex<()> = Mutex::new(());

    fn unique_id() -> usize {
        TEST_COUNTER.fetch_add(1, Ordering::SeqCst)
    }

    /// Acquires the serialization lock, recovering from poisoning so one
    /// failed test does not cascade failures into the rest.
    fn cache_guard() -> MutexGuard<'static, ()> {
        CACHE_LOCK.lock().unwrap_or_else(|e| e.into_inner())
    }

    #[test]
    fn test_cache_set_and_check() {
        let _guard = cache_guard();
        let id = unique_id();
        let url = format!("http://test-set-{}:11434", id);
        let model = format!("model-set-{}", id);
        assert!(!is_warmed(&url, &model));
        set_warmed(&url, &model);
        assert!(is_warmed(&url, &model));
    }

    #[test]
    fn test_cache_different_urls() {
        let _guard = cache_guard();
        let id = unique_id();
        let url1 = format!("http://test-url1-{}:11434", id);
        let url2 = format!("http://test-url2-{}:11434", id);
        let model = format!("model-urls-{}", id);
        set_warmed(&url1, &model);
        assert!(is_warmed(&url1, &model));
        assert!(!is_warmed(&url2, &model));
    }

    #[test]
    fn test_cache_different_models() {
        let _guard = cache_guard();
        let id = unique_id();
        let url = format!("http://test-models-{}:11434", id);
        let model1 = format!("model1-{}", id);
        let model2 = format!("model2-{}", id);
        set_warmed(&url, &model1);
        assert!(is_warmed(&url, &model1));
        assert!(!is_warmed(&url, &model2));
    }

    #[test]
    fn test_clear_cache() {
        let _guard = cache_guard();
        let id = unique_id();
        let url = format!("http://test-clear-{}:11434", id);
        let model = format!("model-clear-{}", id);
        set_warmed(&url, &model);
        assert!(is_warmed(&url, &model));
        clear_warmup_cache();
        assert!(!is_warmed(&url, &model));
    }
}