use kapsl_engine_api::{Engine, EngineError};
use lru::LruCache;
use serde::{Deserialize, Serialize};
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
/// Shared, async-locked LRU cache of engine handles keyed by a `(u32, usize)` pair.
/// `EnginePool` uses the key as `(model_id, device_id)`.
type EngineCache = Arc<Mutex<LruCache<(u32, usize), Arc<dyn Engine>>>>;
/// Callback receiving the two key components and the engine handle of an entry
/// that `EnginePool::put` displaced from the cache.
type EvictionCallback = Arc<dyn Fn(u32, usize, Arc<dyn Engine>) + Send + Sync>;
/// Shared, async-locked slot holding the optional eviction callback.
type EvictionCallbackSlot = Arc<Mutex<Option<EvictionCallback>>>;
/// Snapshot of the counters an `EnginePool` keeps about its own activity.
#[derive(Debug, Clone, Copy)]
pub struct PoolMetrics {
    /// Ratio `hit / (hit + failure)`; refreshed by `EnginePool::pool_metrics`.
    pub hit_rate: f64,
    /// Number of cache lookups that returned an engine passing its health check.
    pub hit: u64,
    /// Number of entries that `EnginePool::put` reported as displaced from the cache.
    pub evictions: u64,
    /// Number of cached engines removed after failing a health check.
    pub failure: u64,
}
/// Tunable settings for an `EnginePool`.
///
/// Every field may be omitted when deserializing; the per-field serde
/// defaults below are then used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnginePoolConfig {
    /// Maximum number of engines kept in the pool's LRU cache.
    #[serde(default = "default_max_size")]
    pub max_size: usize,

    /// Configured minimum pool size. Only exposed through an accessor in this
    /// file; presumably enforced by callers — TODO confirm.
    #[serde(default = "default_min_size")]
    pub min_size: usize,

    /// Configured time-to-live for pooled engines. Only exposed through an
    /// accessor in this file; presumably enforced by callers — TODO confirm.
    #[serde(default = "default_ttl")]
    pub ttl: Duration,

    /// Period of the background task started by
    /// `EnginePool::start_health_check_task`.
    #[serde(default = "default_health_check_interval")]
    pub health_check_interval: Duration,

    /// Whether warmup is requested. Not read anywhere in this file.
    ///
    /// NOTE(review): when absent from the input this deserializes to `false`,
    /// whereas `EnginePoolConfig::default()` sets `true` — confirm which is
    /// intended.
    #[serde(default)]
    pub warmup_enabled: bool,

    /// Number of engines to warm up. Not read anywhere in this file.
    ///
    /// NOTE(review): when absent from the input this deserializes to `0`,
    /// whereas `EnginePoolConfig::default()` uses `default_min_size()` —
    /// confirm which is intended.
    #[serde(default)]
    pub warmup_size: usize,
}
/// Serde default for `EnginePoolConfig::max_size`.
fn default_max_size() -> usize {
    const DEFAULT_MAX_SIZE: usize = 5;
    DEFAULT_MAX_SIZE
}
/// Serde default for `EnginePoolConfig::min_size`.
fn default_min_size() -> usize {
    const DEFAULT_MIN_SIZE: usize = 1;
    DEFAULT_MIN_SIZE
}
/// Serde default for `EnginePoolConfig::ttl`: one minute.
fn default_ttl() -> Duration {
    const DEFAULT_TTL_SECS: u64 = 60;
    Duration::from_secs(DEFAULT_TTL_SECS)
}
/// Serde default for `EnginePoolConfig::health_check_interval`: ten seconds.
fn default_health_check_interval() -> Duration {
    const DEFAULT_HEALTH_CHECK_SECS: u64 = 10;
    Duration::from_secs(DEFAULT_HEALTH_CHECK_SECS)
}
impl Default for EnginePoolConfig {
fn default() -> Self {
Self {
max_size: default_max_size(),
min_size: default_min_size(),
ttl: default_ttl(),
health_check_interval: default_health_check_interval(),
warmup_enabled: true,
warmup_size: default_min_size(),
}
}
}
/// Async LRU pool of engine handles keyed by `(model_id, device_id)`.
///
/// Cloning is cheap: the cache, metrics and eviction-callback slot are all
/// behind `Arc`, so every clone observes and mutates the same shared state.
/// Only `config` is copied by value.
#[derive(Clone)]
pub struct EnginePool {
    /// Settings this pool was created with.
    config: EnginePoolConfig,
    /// Hit / eviction / failure counters shared between clones.
    metrics: Arc<Mutex<PoolMetrics>>,
    /// The LRU cache itself, shared between clones.
    cache: EngineCache,
    /// Optional hook invoked for entries displaced by `put`.
    eviction_callback: EvictionCallbackSlot,
}
impl EnginePool {
    /// Creates an empty pool whose LRU capacity is `config.max_size`.
    ///
    /// A `max_size` of zero is clamped to a capacity of one because the
    /// underlying `LruCache` needs a non-zero capacity. [`Self::max_size`]
    /// still reports the configured value.
    pub fn new(config: EnginePoolConfig) -> Self {
        let capacity = NonZeroUsize::new(config.max_size.max(1))
            .expect("capacity is clamped to at least 1");
        Self {
            config,
            cache: Arc::new(Mutex::new(LruCache::new(capacity))),
            metrics: Arc::new(Mutex::new(PoolMetrics {
                hit_rate: 0.0,
                hit: 0,
                evictions: 0,
                failure: 0,
            })),
            eviction_callback: Arc::new(Mutex::new(None)),
        }
    }

    /// Installs `cb` as the eviction callback, replacing any previous one.
    ///
    /// The callback receives the model id, device id and engine handle of each
    /// entry that [`Self::put`] displaces from the cache.
    pub async fn set_eviction_callback<F>(&self, cb: F)
    where
        F: Fn(u32, usize, Arc<dyn Engine>) + Send + Sync + 'static,
    {
        let mut guard = self.eviction_callback.lock().await;
        *guard = Some(Arc::new(cb));
    }

    /// Removes the eviction callback, if one is installed.
    pub async fn clear_eviction_callback(&self) {
        let mut guard = self.eviction_callback.lock().await;
        *guard = None;
    }

    /// Spawns a background task that health-checks every cached engine once
    /// per `health_check_interval` and removes the ones that fail.
    ///
    /// The task owns a clone of the pool and loops forever; abort the returned
    /// handle to stop it. Background probes neither refresh an entry's LRU
    /// position nor count as hits, so they do not disturb eviction order or
    /// the hit rate. Failed probes are still counted in `failure`.
    ///
    /// If the configured interval is zero the task logs a warning and exits
    /// immediately, because `tokio::time::interval` panics on a zero period.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (requirement of `tokio::spawn`).
    pub fn start_health_check_task(&self) -> tokio::task::JoinHandle<()> {
        let pool = self.clone();
        let interval = self.config.health_check_interval;
        tokio::spawn(async move {
            if interval.is_zero() {
                log::warn!(
                    "Health check interval is zero; background health checks are disabled"
                );
                return;
            }
            let mut ticker = tokio::time::interval(interval);
            loop {
                ticker.tick().await;
                log::debug!("Running background health checks...");
                // Snapshot the keys so the cache lock is not held for the whole sweep.
                let keys: Vec<(u32, usize)> = {
                    let cache = pool.cache.lock().await;
                    cache.iter().map(|(k, _)| *k).collect()
                };
                for (model_id, device_id) in keys {
                    if pool.probe_health(model_id, device_id).await {
                        log::trace!("Engine ({}, {}) is healthy", model_id, device_id);
                    }
                }
            }
        })
    }

    /// Health-checks one cached engine without touching its LRU position or
    /// the hit counter. Returns `true` when the engine is present and healthy;
    /// an unhealthy engine is removed and counted as a failure.
    async fn probe_health(&self, model_id: u32, device_id: usize) -> bool {
        let key = (model_id, device_id);
        let mut cache = self.cache.lock().await;
        // `peek` (unlike `get`) leaves the recency order untouched.
        if let Some(engine) = cache.peek(&key) {
            match engine.health_check() {
                Ok(()) => true,
                Err(e) => {
                    log::warn!(
                        "Engine (model_id={}, device_id={}) failed health check: {}. Removing from pool.",
                        model_id,
                        device_id,
                        e
                    );
                    self.metrics.lock().await.failure += 1;
                    cache.pop(&key);
                    false
                }
            }
        } else {
            false
        }
    }

    /// Returns the configured maximum pool size.
    pub fn max_size(&self) -> usize {
        self.config.max_size
    }

    /// Returns the configured minimum pool size.
    pub fn min_size(&self) -> usize {
        self.config.min_size
    }

    /// Returns the configured time-to-live. This `impl` does not itself
    /// expire entries based on it.
    pub fn ttl(&self) -> Duration {
        self.config.ttl
    }

    /// Returns the configured period of the background health check.
    pub fn health_check_interval(&self) -> Duration {
        self.config.health_check_interval
    }

    /// Looks up the engine for `(model_id, device_id)` and health-checks it.
    ///
    /// A healthy engine is marked most-recently-used, counted as a hit and
    /// returned. An engine that fails its health check is removed from the
    /// pool, counted as a failure, and `None` is returned. A missing key
    /// returns `None` without touching the metrics.
    ///
    /// The cache lock is held for the duration of the health check.
    pub async fn get(&self, model_id: u32, device_id: usize) -> Option<Arc<dyn Engine>> {
        let key = (model_id, device_id);
        let mut cache = self.cache.lock().await;
        if let Some(engine) = cache.get(&key) {
            match engine.health_check() {
                Ok(()) => {
                    self.metrics.lock().await.hit += 1;
                    Some(engine.clone())
                }
                Err(e) => {
                    log::warn!(
                        "Engine (model_id={}, device_id={}) failed health check: {}. Removing from pool.",
                        model_id,
                        device_id,
                        e
                    );
                    self.metrics.lock().await.failure += 1;
                    cache.pop(&key);
                    None
                }
            }
        } else {
            None
        }
    }

    /// Inserts `engine` under `(model_id, device_id)`.
    ///
    /// If the insertion displaces an entry, the eviction counter is bumped,
    /// the event is logged, and the eviction callback (if any) is run on a
    /// freshly spawned task with the displaced key and engine.
    ///
    /// NOTE(review): `LruCache::push` also hands back the previous entry when
    /// the key already exists, so replacing an engine under the same key is
    /// counted and reported as an eviction too — confirm this is intended.
    pub async fn put(&self, model_id: u32, device_id: usize, engine: Arc<dyn Engine>) {
        // Scope the guard so the cache lock is released before metrics/callback work.
        let evicted_entry = {
            let mut cache = self.cache.lock().await;
            cache.push((model_id, device_id), engine)
        };
        if let Some((evicted_key, evicted_engine)) = evicted_entry {
            let (evicted_model_id, evicted_device_id) = evicted_key;
            {
                let mut metrics = self.metrics.lock().await;
                metrics.evictions += 1;
                log::info!(
                    "Engine evicted from pool for model_id={}, device_id={}. Evictions total={}",
                    evicted_model_id,
                    evicted_device_id,
                    metrics.evictions
                );
            }
            // Clone the callback out of the slot so the lock is not held while it runs.
            let cb_opt = self.eviction_callback.lock().await.clone();
            if let Some(cb) = cb_opt {
                tokio::spawn(async move {
                    (cb)(evicted_model_id, evicted_device_id, evicted_engine);
                });
            }
        }
    }

    /// Removes the entry for `(model_id, device_id)`, if present.
    ///
    /// Neither the eviction counter nor the eviction callback is involved.
    pub async fn remove(&self, model_id: u32, device_id: usize) {
        let mut cache = self.cache.lock().await;
        cache.pop(&(model_id, device_id));
    }

    /// Returns the number of engines currently cached.
    pub async fn len(&self) -> usize {
        let cache = self.cache.lock().await;
        cache.len()
    }

    /// Returns `true` when no engine is cached.
    pub async fn is_empty(&self) -> bool {
        self.cache.lock().await.is_empty()
    }

    /// Pre-populates the pool by calling `engine_factory` for each
    /// `(model_id, device_id)` in `engine_configs`, one at a time, and
    /// inserting every engine that was created successfully.
    ///
    /// Factory failures are logged and skipped. This method does not consult
    /// `warmup_enabled` or `warmup_size`; callers decide whether and what to
    /// warm up.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; the `Result` is kept for interface
    /// stability.
    pub async fn warmup<F, Fut>(
        &self,
        engine_configs: Vec<(u32, usize)>,
        engine_factory: F,
    ) -> Result<(), EngineError>
    where
        F: Fn(u32, usize) -> Fut,
        Fut: std::future::Future<Output = Result<Arc<dyn Engine>, EngineError>>,
    {
        log::info!("Starting pool warmup with {} engines", engine_configs.len());
        for (model_id, device_id) in engine_configs {
            match engine_factory(model_id, device_id).await {
                Ok(engine) => {
                    self.put(model_id, device_id, engine).await;
                    log::info!(
                        "Warmed up engine for model_id={}, device_id={}",
                        model_id,
                        device_id
                    );
                }
                Err(e) => {
                    log::warn!(
                        "Failed to warm up engine for model_id={}, device_id={}: {}",
                        model_id,
                        device_id,
                        e
                    );
                }
            }
        }
        log::info!("Pool warmup complete. Pool size: {}", self.len().await);
        Ok(())
    }

    /// Returns a snapshot of the pool counters with `hit_rate` refreshed.
    ///
    /// `hit_rate` is `hit / (hit + failure)`, or `0.0` while nothing has been
    /// recorded yet (instead of the `NaN` that `0.0 / 0.0` would produce).
    pub async fn pool_metrics(&self) -> PoolMetrics {
        let mut metrics = self.metrics.lock().await;
        let total = metrics.hit + metrics.failure;
        metrics.hit_rate = if total == 0 {
            0.0
        } else {
            metrics.hit as f64 / total as f64
        };
        *metrics
    }
}
// Unit tests live in the sibling file named below and are compiled only for test builds.
#[cfg(test)]
#[path = "engine_pool_tests.rs"]
mod engine_pool_tests;