use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use tokio::sync::OnceCell;
use crate::error::Result;
use crate::llm::mistralrs::MistralrsModel;
/// One lazily-filled, shareable slot holding at most one model instance.
///
/// The outer `Arc` lets a caller keep using the slot after the registry's
/// mutex has been released; the inner `Arc` is what callers receive.
type ModelSlot<T> = Arc<OnceCell<Arc<T>>>;

/// Name-keyed registry of models that are built on first request and then
/// shared as `Arc<T>`.
///
/// Each model name maps to its own async [`OnceCell`], so callers asking for
/// the same name end up waiting on the same cell instead of each building
/// their own instance.
pub struct LlmRegistry<T = MistralrsModel> {
    /// Model name -> slot. The `std` mutex only guards map lookups/inserts;
    /// it is never meant to be held while a model is being initialised.
    models: Mutex<BTreeMap<String, ModelSlot<T>>>,
}
impl<T> Default for LlmRegistry<T> {
    /// Builds a registry with no models registered yet.
    ///
    /// Implemented by hand: `#[derive(Default)]` would add a `T: Default`
    /// bound, which the model type does not need to satisfy.
    fn default() -> Self {
        let models = Mutex::new(BTreeMap::default());
        Self { models }
    }
}
impl<T: Send + Sync + 'static> LlmRegistry<T> {
pub fn new() -> Self {
Self::default()
}
pub async fn get_or_init<F, Fut>(&self, model_name: &str, init: F) -> Result<Arc<T>>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let cell = {
let mut map = self.models.lock().expect("registry mutex poisoned");
if let Some(existing) = map.get(model_name) {
existing.clone()
} else {
map.entry(model_name.to_string())
.or_insert_with(|| Arc::new(OnceCell::new()))
.clone()
}
};
let arc = cell
.get_or_try_init(|| async { init().await.map(Arc::new) })
.await?;
Ok(arc.clone())
}
}