Skip to main content

kapsl_backends/
engine_pool.rs

1use kapsl_engine_api::{Engine, EngineError};
2use lru::LruCache;
3use serde::{Deserialize, Serialize};
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::Mutex;
8
9type EngineCache = Arc<Mutex<LruCache<(u32, usize), Arc<dyn Engine>>>>;
10type EvictionCallback = Arc<dyn Fn(u32, usize, Arc<dyn Engine>) + Send + Sync>;
11type EvictionCallbackSlot = Arc<Mutex<Option<EvictionCallback>>>;
12
/// Snapshot of the pool's usage counters.
#[derive(Clone, Copy, Debug)]
pub struct PoolMetrics {
    /// `hit / (hit + failure)`, recomputed whenever the metrics are read from the pool.
    pub hit_rate: f64,
    /// Lookups that found an engine which passed its health check.
    pub hit: u64,
    /// Total number of evictions (entries displaced by inserting into the pool).
    pub evictions: u64,
    /// Engines removed from the pool after failing a health check.
    pub failure: u64,
}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EnginePoolConfig {
23    /// Maximum number of engines to keep in the pool
24    #[serde(default = "default_max_size")]
25    pub max_size: usize,
26
27    /// Minimum number of engines to keep in the pool
28    #[serde(default = "default_min_size")]
29    pub min_size: usize,
30
31    /// Time-to-live for engines in the pool
32    #[serde(default = "default_ttl")]
33    pub ttl: Duration,
34
35    /// Health check interval
36    #[serde(default = "default_health_check_interval")]
37    pub health_check_interval: Duration,
38
39    // NEW: Warmup configuration
40    #[serde(default)]
41    pub warmup_enabled: bool,
42
43    #[serde(default)]
44    pub warmup_size: usize, // How many engines to pre-create (usually = min_size)
45}
46
/// Fallback for `EnginePoolConfig::max_size`: at most five pooled engines.
fn default_max_size() -> usize {
    const DEFAULT_MAX_ENGINES: usize = 5;
    DEFAULT_MAX_ENGINES
}
50
/// Fallback for `EnginePoolConfig::min_size`: keep at least one engine.
fn default_min_size() -> usize {
    const DEFAULT_MIN_ENGINES: usize = 1;
    DEFAULT_MIN_ENGINES
}
54
/// Fallback for `EnginePoolConfig::ttl`: one minute.
fn default_ttl() -> Duration {
    Duration::from_millis(60_000)
}
58
/// Fallback for `EnginePoolConfig::health_check_interval`: every ten seconds.
fn default_health_check_interval() -> Duration {
    Duration::new(10, 0)
}
62
63impl Default for EnginePoolConfig {
64    fn default() -> Self {
65        Self {
66            max_size: default_max_size(),
67            min_size: default_min_size(),
68            ttl: default_ttl(),
69            health_check_interval: default_health_check_interval(),
70            warmup_enabled: true,
71            warmup_size: default_min_size(),
72        }
73    }
74}
75
76/// A pool for reusing backend engine instances
77#[derive(Clone)]
78pub struct EnginePool {
79    config: EnginePoolConfig,
80    metrics: Arc<Mutex<PoolMetrics>>,
81    // LRU cache mapping (model_id, device_id) -> Engine
82    // We use a tuple key to distinguish instances on different devices
83    cache: EngineCache,
84    // Optional eviction callback called when an engine is evicted from the pool.
85    // Signature: (model_id, device_id, evicted_engine)
86    eviction_callback: EvictionCallbackSlot,
87}
88
89impl EnginePool {
90    pub fn new(config: EnginePoolConfig) -> Self {
91        let capacity = NonZeroUsize::new(config.max_size).unwrap_or(NonZeroUsize::new(1).unwrap());
92        Self {
93            config,
94            cache: Arc::new(Mutex::new(LruCache::new(capacity))),
95            metrics: Arc::new(Mutex::new(PoolMetrics {
96                hit_rate: 0.0,
97                hit: 0,
98                evictions: 0,
99                failure: 0,
100            })),
101            eviction_callback: Arc::new(Mutex::new(None)),
102        }
103    }
104
105    /// Set an eviction callback that will be invoked whenever an engine is evicted.
106    /// The callback receives (model_id, device_id, evicted_engine).
107    pub async fn set_eviction_callback<F>(&self, cb: F)
108    where
109        F: Fn(u32, usize, Arc<dyn Engine>) + Send + Sync + 'static,
110    {
111        let mut guard = self.eviction_callback.lock().await;
112        *guard = Some(Arc::new(cb));
113    }
114
115    /// Clear any previously set eviction callback.
116    pub async fn clear_eviction_callback(&self) {
117        let mut guard = self.eviction_callback.lock().await;
118        *guard = None;
119    }
120
121    pub fn start_health_check_task(&self) -> tokio::task::JoinHandle<()> {
122        let pool = self.clone(); // Clone the Arc references
123        let interval = self.config.health_check_interval;
124
125        tokio::spawn(async move {
126            let mut ticker = tokio::time::interval(interval);
127
128            loop {
129                ticker.tick().await;
130                log::debug!("Running background health checks...");
131
132                // Get all keys in the cache
133                let keys: Vec<(u32, usize)> = {
134                    let cache = pool.cache.lock().await;
135                    cache.iter().map(|(k, _)| *k).collect()
136                };
137
138                // Check each engine's health
139                for (model_id, device_id) in keys {
140                    if let Some(_engine) = pool.get(model_id, device_id).await {
141                        // get() already does health check, so this removes unhealthy ones
142                        log::trace!("Engine ({}, {}) is healthy", model_id, device_id);
143                    }
144                }
145            }
146        })
147    }
148
149    pub fn max_size(&self) -> usize {
150        self.config.max_size
151    }
152
153    pub fn min_size(&self) -> usize {
154        self.config.min_size
155    }
156
157    pub fn ttl(&self) -> Duration {
158        self.config.ttl
159    }
160
161    pub fn health_check_interval(&self) -> Duration {
162        self.config.health_check_interval
163    }
164
165    /// Get an existing engine from the pool if available
166    /// Returns None if engine doesn't exist or fails health check
167    pub async fn get(&self, model_id: u32, device_id: usize) -> Option<Arc<dyn Engine>> {
168        let mut cache = self.cache.lock().await;
169
170        if let Some(engine) = cache.get(&(model_id, device_id)) {
171            // Perform health check before returning
172            match engine.health_check() {
173                Ok(()) => {
174                    // Engine is healthy, return it
175                    self.metrics.lock().await.hit += 1;
176                    Some(engine.clone())
177                }
178                Err(e) => {
179                    // Engine failed health check, remove from pool
180                    log::warn!(
181                        "Engine (model_id={}, device_id={}) failed health check: {}. Removing from pool.",
182                        model_id,
183                        device_id,
184                        e
185                    );
186                    self.metrics.lock().await.failure += 1;
187                    cache.pop(&(model_id, device_id));
188                    None
189                }
190            }
191        } else {
192            None
193        }
194    }
195
196    /// Add an engine to the pool
197    ///
198    /// If adding this engine causes an eviction, the eviction callback (if set) will be
199    /// invoked asynchronously in a separate task to avoid blocking pool operations.
200    pub async fn put(&self, model_id: u32, device_id: usize, engine: Arc<dyn Engine>) {
201        // First, handle cache update and get evicted engine (if any)
202        let evicted_entry = {
203            let mut cache = self.cache.lock().await;
204            cache.push((model_id, device_id), engine)
205        }; // cache lock released here
206
207        if let Some((evicted_key, evicted_engine)) = evicted_entry {
208            let (evicted_model_id, evicted_device_id) = evicted_key;
209            // Update metrics without holding cache lock
210            {
211                let mut metrics = self.metrics.lock().await;
212                metrics.evictions += 1;
213                log::info!(
214                    "Engine evicted from pool for model_id={}, device_id={}. Evictions total={}",
215                    evicted_model_id,
216                    evicted_device_id,
217                    metrics.evictions
218                );
219            } // metrics lock released here
220
221            // Invoke callback without holding any locks to prevent deadlocks
222            let cb_opt = self.eviction_callback.lock().await.clone();
223            if let Some(cb) = cb_opt {
224                // Spawn callback in separate task to avoid blocking pool operations
225                tokio::spawn(async move {
226                    (cb)(evicted_model_id, evicted_device_id, evicted_engine);
227                });
228            }
229        }
230    }
231
232    /// Remove an engine from the pool (e.g. on error or unload)
233    pub async fn remove(&self, model_id: u32, device_id: usize) {
234        let mut cache = self.cache.lock().await;
235        cache.pop(&(model_id, device_id));
236    }
237
238    pub async fn len(&self) -> usize {
239        let cache = self.cache.lock().await;
240        cache.len()
241    }
242
243    pub async fn is_empty(&self) -> bool {
244        self.cache.lock().await.is_empty()
245    }
246
247    /// Warm up the pool by pre-creating engines
248    ///
249    /// `engine_factory` is an async function that creates an engine for given (model_id, device_id)
250    pub async fn warmup<F, Fut>(
251        &self,
252        engine_configs: Vec<(u32, usize)>, // Vec of (model_id, device_id) pairs
253        engine_factory: F,
254    ) -> Result<(), EngineError>
255    where
256        F: Fn(u32, usize) -> Fut,
257        Fut: std::future::Future<Output = Result<Arc<dyn Engine>, EngineError>>,
258    {
259        log::info!("Starting pool warmup with {} engines", engine_configs.len());
260
261        for (model_id, device_id) in engine_configs {
262            match engine_factory(model_id, device_id).await {
263                Ok(engine) => {
264                    self.put(model_id, device_id, engine).await;
265                    log::info!(
266                        "Warmed up engine for model_id={}, device_id={}",
267                        model_id,
268                        device_id
269                    );
270                }
271                Err(e) => {
272                    log::warn!(
273                        "Failed to warm up engine for model_id={}, device_id={}: {}",
274                        model_id,
275                        device_id,
276                        e
277                    );
278                    // Continue warming up other engines even if one fails
279                }
280            }
281        }
282
283        log::info!("Pool warmup complete. Pool size: {}", self.len().await);
284        Ok(())
285    }
286
287    pub async fn pool_metrics(&self) -> PoolMetrics {
288        let mut metrics = self.metrics.lock().await;
289        metrics.hit_rate = (metrics.hit as f64) / (metrics.hit + metrics.failure) as f64;
290        *metrics
291    }
292}
293
// Unit tests for this module live in a sibling file and are compiled only for test builds.
#[cfg(test)]
#[path = "engine_pool_tests.rs"]
mod engine_pool_tests;