use std::{collections::HashMap, error::Error, sync::Arc};
use tokio::sync::Mutex;
use crate::{
discover::MODEL_DISCOVERER,
engine::{EngineConfig, InferenceEngine}, };
/// Pool of loaded inference engines, keyed by model name.
///
/// Engines are handed out as `Arc<InferenceEngine>`, so a caller's handle
/// stays valid even after the pool drops its own reference to that engine.
pub struct ModelPool {
// Model name -> shared engine handle. Guarded by the async `tokio::sync::Mutex`
// imported at the top of the file, so acquiring it requires `.lock().await`.
models: Mutex<HashMap<String, Arc<InferenceEngine>>>,
}
impl ModelPool {
pub fn new() -> Self {
ModelPool {
models: Mutex::new(HashMap::new()),
}
}
pub async fn get_model(
&self,
model_name: &str,
) -> Result<Arc<InferenceEngine>, Box<dyn Error>> {
{
let models_guard = self.models.lock().await; if let Some(engine_arc) = models_guard.get(model_name) {
println!("[ModelPool] Model '{}' found in pool.", model_name);
return Ok(Arc::clone(engine_arc));
}
}
println!(
"[ModelPool] Model '{}' not found in pool. Loading...",
model_name
);
let model = {
let discoverer_guard = MODEL_DISCOVERER.lock().unwrap(); discoverer_guard
.find_model(model_name)
.map_err(|e| -> Box<dyn Error> {
Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Model '{}' not found: {}", model_name, e),
)) as Box<dyn Error>
})?
};
let engine_config = EngineConfig {
n_ctx: 4096,
n_len: None, temperature: 0.8,
top_k: 40,
top_p: 0.9,
repeat_penalty: 1.1,
};
let concrete_engine = crate::engine::InferenceEngine::new(&engine_config, &model).map_err(
|e| -> Box<dyn Error> {
Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Failed to load model '{}': {}", model_name, e),
)) as Box<dyn Error>
},
)?;
#[cfg(feature = "engine-llama-cpp")]
{
llama_cpp_2::send_logs_to_tracing(
llama_cpp_2::LogOptions::default().with_logs_enabled(true),
);
}
let new_engine_arc: Arc<InferenceEngine> = Arc::new(concrete_engine);
let mut models_guard = self.models.lock().await; models_guard.insert(model_name.to_string(), Arc::clone(&new_engine_arc));
println!(
"[ModelPool] Model '{}' loaded and added to pool.",
model_name
);
Ok(new_engine_arc)
}
pub async fn unload_model(&self, model_name: &str) {
{
let mut models_guard = self.models.lock().await; if models_guard.remove(model_name).is_some() {
println!("[ModelPool] Model '{}' unloaded from pool.", model_name);
}
}
}
}