Skip to main content

memlink_runtime/
pool.rs

//! Module instance pooling and load balancing.
//!
//! Provides multiple instances of the same module for parallel execution
//! with load-aware request routing.
5
6use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use dashmap::DashMap;
11
12use crate::instance::ModuleInstance;
13
/// Tunable limits and thresholds for a module instance pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Lower bound on the number of pooled instances.
    pub min_instances: usize,
    /// Upper bound on the number of pooled instances.
    pub max_instances: usize,
    /// Desired load factor per instance, in the range 0.0-1.0.
    pub target_load: f32,
    /// Average load factor at or above which the pool reports it should grow.
    pub scale_up_threshold: f32,
    /// Average load factor at or below which the pool reports it should shrink.
    pub scale_down_threshold: f32,
    /// Minimum time to wait between two scaling operations.
    pub scale_cooldown: Duration,
}
30
31impl Default for PoolConfig {
32    fn default() -> Self {
33        Self {
34            min_instances: 1,
35            max_instances: 4,
36            target_load: 0.6,
37            scale_up_threshold: 0.8,
38            scale_down_threshold: 0.3,
39            scale_cooldown: Duration::from_secs(30),
40        }
41    }
42}
43
/// Point-in-time statistics for a single pooled instance.
#[derive(Debug, Clone)]
pub struct InstanceStats {
    /// Identifier of the instance these numbers belong to.
    pub instance_id: usize,
    /// Number of requests started on the instance so far.
    pub total_requests: u64,
    /// Requests that have started but not yet completed.
    pub pending_requests: usize,
    /// Mean latency of completed requests, in microseconds.
    pub avg_latency_us: f64,
    /// Time of the most recent request, when available.
    pub last_request: Option<Instant>,
    /// Load factor of the instance, in the range 0.0-1.0.
    pub load_factor: f32,
}
54
55/// A pooled module instance with load tracking.
56#[derive(Debug)]
57pub struct PooledInstance {
58    /// Instance identifier.
59    id: usize,
60    /// The module instance.
61    instance: ModuleInstance,
62    /// Total requests handled.
63    total_requests: AtomicU64,
64    /// Current pending requests.
65    pending_requests: AtomicUsize,
66    /// Total latency for averaging (microseconds).
67    total_latency_us: AtomicU64,
68    /// Latency count for averaging.
69    latency_count: AtomicU64,
70    /// Last request timestamp.
71    last_request: AtomicU64,
72}
73
74impl PooledInstance {
75    /// Creates a new pooled instance.
76    pub fn new(id: usize, instance: ModuleInstance) -> Self {
77        let now = Instant::now().duration_since(Instant::now()).as_secs();
78        Self {
79            id,
80            instance,
81            total_requests: AtomicU64::new(0),
82            pending_requests: AtomicUsize::new(0),
83            total_latency_us: AtomicU64::new(0),
84            latency_count: AtomicU64::new(0),
85            last_request: AtomicU64::new(now),
86        }
87    }
88
89    /// Records a request start.
90    pub fn record_request_start(&self) {
91        self.pending_requests.fetch_add(1, Ordering::AcqRel);
92        self.total_requests.fetch_add(1, Ordering::Relaxed);
93
94        let now = Instant::now().duration_since(Instant::now()).as_secs();
95        self.last_request.store(now, Ordering::Relaxed);
96    }
97
98    /// Records a request completion.
99    pub fn record_request_end(&self, latency_us: u64) {
100        self.pending_requests.fetch_sub(1, Ordering::AcqRel);
101        self.total_latency_us.fetch_add(latency_us, Ordering::Relaxed);
102        self.latency_count.fetch_add(1, Ordering::Relaxed);
103    }
104
105    /// Returns the instance ID.
106    pub fn id(&self) -> usize {
107        self.id
108    }
109
110    /// Returns the module instance.
111    pub fn instance(&self) -> &ModuleInstance {
112        &self.instance
113    }
114
115    /// Returns the load factor (0.0-1.0).
116    pub fn load_factor(&self) -> f32 {
117        let pending = self.pending_requests.load(Ordering::Acquire) as f32;
118        (pending / 100.0).min(1.0)
119    }
120
121    /// Returns statistics.
122    pub fn stats(&self) -> InstanceStats {
123        let total_latency = self.total_latency_us.load(Ordering::Acquire);
124        let count = self.latency_count.load(Ordering::Acquire);
125
126        InstanceStats {
127            instance_id: self.id,
128            total_requests: self.total_requests.load(Ordering::Acquire),
129            pending_requests: self.pending_requests.load(Ordering::Acquire),
130            avg_latency_us: if count > 0 { total_latency as f64 / count as f64 } else { 0.0 },
131            last_request: None, // Simplified
132            load_factor: self.load_factor(),
133        }
134    }
135
136    /// Returns the module instance (mutable).
137    pub fn instance_mut(&mut self) -> &mut ModuleInstance {
138        &mut self.instance
139    }
140}
141
/// How a pool picks the instance that serves the next request.
#[derive(Debug, Clone, Copy, Default)]
pub enum LoadBalanceStrategy {
    /// Pick the instance whose load factor is currently lowest (the default).
    #[default]
    LeastLoaded,
    /// Cycle through the instances one after another.
    RoundRobin,
    /// Pick an instance at random.
    Random,
}
153
154/// Module instance pool with load balancing.
155#[derive(Debug)]
156pub struct ModulePool {
157    /// Module name.
158    module_name: String,
159    /// Pooled instances.
160    instances: DashMap<usize, Arc<PooledInstance>>,
161    /// Pool configuration.
162    config: PoolConfig,
163    /// Load balancing strategy.
164    strategy: LoadBalanceStrategy,
165    /// Round-robin counter.
166    rr_counter: AtomicUsize,
167}
168
169impl ModulePool {
170    /// Creates a new module pool.
171    pub fn new(module_name: String, config: PoolConfig) -> Self {
172        Self {
173            module_name,
174            instances: DashMap::new(),
175            config,
176            strategy: LoadBalanceStrategy::default(),
177            rr_counter: AtomicUsize::new(0),
178        }
179    }
180
181    /// Creates a pool with a single instance.
182    pub fn with_instance(module_name: String, instance: ModuleInstance) -> Self {
183        let pool = Self::new(module_name, PoolConfig::default());
184        pool.add_instance(instance);
185        pool
186    }
187
188    /// Adds an instance to the pool.
189    pub fn add_instance(&self, instance: ModuleInstance) -> usize {
190        let id = self.instances.len();
191        let pooled = Arc::new(PooledInstance::new(id, instance));
192        self.instances.insert(id, pooled);
193        id
194    }
195
196    /// Removes an instance from the pool.
197    pub fn remove_instance(&self, id: usize) -> Option<Arc<PooledInstance>> {
198        if self.instances.len() <= self.config.min_instances {
199            return None; // Can't go below minimum
200        }
201        self.instances.remove(&id).map(|(_, v)| v)
202    }
203
204    /// Selects the best instance for a new request.
205    pub fn select_instance(&self) -> Option<Arc<PooledInstance>> {
206        if self.instances.is_empty() {
207            return None;
208        }
209
210        match self.strategy {
211            LoadBalanceStrategy::LeastLoaded => self.select_least_loaded(),
212            LoadBalanceStrategy::RoundRobin => self.select_round_robin(),
213            LoadBalanceStrategy::Random => self.select_random(),
214        }
215    }
216
217    /// Selects the instance with lowest load.
218    fn select_least_loaded(&self) -> Option<Arc<PooledInstance>> {
219        let mut best: Option<(usize, f32)> = None;
220
221        for entry in self.instances.iter() {
222            let load = entry.value().load_factor();
223            if best.is_none() || load < best.unwrap().1 {
224                best = Some((*entry.key(), load));
225            }
226        }
227
228        best.and_then(|(id, _)| self.instances.get(&id).map(|e| e.clone()))
229    }
230
231    /// Selects instance in round-robin order.
232    fn select_round_robin(&self) -> Option<Arc<PooledInstance>> {
233        let count = self.instances.len();
234        if count == 0 {
235            return None;
236        }
237
238        let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % count;
239        self.instances.get(&idx).map(|e| e.clone())
240    }
241
242    /// Selects a random instance.
243    fn select_random(&self) -> Option<Arc<PooledInstance>> {
244        let count = self.instances.len();
245        if count == 0 {
246            return None;
247        }
248
249        // Simple hash-based selection
250        use std::collections::hash_map::RandomState;
251        use std::hash::{BuildHasher, Hasher};
252
253        let hasher = RandomState::new().build_hasher();
254        let idx = hasher.finish() as usize % count;
255        self.instances.get(&idx).map(|e| e.clone())
256    }
257
258    /// Returns the average load factor across all instances.
259    pub fn avg_load(&self) -> f32 {
260        let mut total = 0.0;
261        let count = self.instances.len();
262
263        if count == 0 {
264            return 0.0;
265        }
266
267        for entry in self.instances.iter() {
268            total += entry.value().load_factor();
269        }
270
271        total / count as f32
272    }
273
274    /// Checks if the pool should scale up.
275    pub fn should_scale_up(&self) -> bool {
276        if self.instances.len() >= self.config.max_instances {
277            return false;
278        }
279
280        let avg_load = self.avg_load();
281        avg_load >= self.config.scale_up_threshold
282    }
283
284    /// Checks if the pool should scale down.
285    pub fn should_scale_down(&self) -> bool {
286        if self.instances.len() <= self.config.min_instances {
287            return false;
288        }
289
290        let avg_load = self.avg_load();
291        avg_load <= self.config.scale_down_threshold
292    }
293
294    /// Returns pool statistics.
295    pub fn stats(&self) -> PoolStats {
296        let mut total_requests = 0;
297        let mut total_pending = 0;
298        let mut instance_stats = Vec::new();
299
300        for entry in self.instances.iter() {
301            let stats = entry.value().stats();
302            total_requests += stats.total_requests;
303            total_pending += stats.pending_requests;
304            instance_stats.push(stats);
305        }
306
307        PoolStats {
308            module_name: self.module_name.clone(),
309            instance_count: self.instances.len(),
310            total_requests,
311            total_pending,
312            avg_load: self.avg_load(),
313            instances: instance_stats,
314        }
315    }
316
317    /// Returns the module name.
318    pub fn module_name(&self) -> &str {
319        &self.module_name
320    }
321
322    /// Returns the number of instances.
323    pub fn instance_count(&self) -> usize {
324        self.instances.len()
325    }
326
327    /// Returns all instances.
328    pub fn all_instances(&self) -> Vec<Arc<PooledInstance>> {
329        self.instances.iter().map(|e| e.value().clone()).collect()
330    }
331}
332
333/// Pool statistics.
334#[derive(Debug, Clone)]
335pub struct PoolStats {
336    pub module_name: String,
337    pub instance_count: usize,
338    pub total_requests: u64,
339    pub total_pending: usize,
340    pub avg_load: f32,
341    pub instances: Vec<InstanceStats>,
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// A new pool starts empty and remembers its module name.
    #[test]
    fn test_pool_creation() {
        let pool = ModulePool::new("test".to_string(), PoolConfig::default());
        assert_eq!(pool.instance_count(), 0);
        assert_eq!(pool.module_name(), "test");
    }

    /// Exercises every query path on an empty pool.
    ///
    /// Note: a real `ModuleInstance` can't be created without loading a
    /// module, so paths that need a populated pool are not covered here.
    #[test]
    fn test_empty_pool_behavior() {
        let pool = ModulePool::new("test".to_string(), PoolConfig::default());

        assert!(pool.select_instance().is_none());
        assert!(pool.all_instances().is_empty());
        assert_eq!(pool.avg_load(), 0.0);
        assert!(!pool.should_scale_up());
        assert!(!pool.should_scale_down());
        assert!(pool.remove_instance(0).is_none());

        let stats = pool.stats();
        assert_eq!(stats.module_name, "test");
        assert_eq!(stats.instance_count, 0);
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_pending, 0);
        assert_eq!(stats.avg_load, 0.0);
        assert!(stats.instances.is_empty());
    }

    /// The default configuration carries the documented stock values.
    #[test]
    fn test_default_config() {
        let config = PoolConfig::default();

        assert_eq!(config.min_instances, 1);
        assert_eq!(config.max_instances, 4);
        assert_eq!(config.target_load, 0.6);
        assert_eq!(config.scale_up_threshold, 0.8);
        assert_eq!(config.scale_down_threshold, 0.3);
        assert_eq!(config.scale_cooldown, Duration::from_secs(30));
    }

    /// A hand-built configuration keeps every field it was given.
    #[test]
    fn test_pool_config() {
        let config = PoolConfig {
            min_instances: 2,
            max_instances: 8,
            target_load: 0.5,
            scale_up_threshold: 0.7,
            scale_down_threshold: 0.2,
            scale_cooldown: Duration::from_secs(60),
        };

        assert_eq!(config.min_instances, 2);
        assert_eq!(config.max_instances, 8);
        assert_eq!(config.target_load, 0.5);
        assert_eq!(config.scale_up_threshold, 0.7);
        assert_eq!(config.scale_down_threshold, 0.2);
        assert_eq!(config.scale_cooldown, Duration::from_secs(60));
    }
}