Skip to main content

cbtop/predictive_scheduler/
scheduler.rs

1//! Predictive scheduler implementation with SLO-aware workload placement.
2
3use std::collections::HashMap;
4use std::time::{Duration, Instant};
5
6use super::types::{
7    HostProfile, InstanceType, PredictiveSchedulerConfig, SchedulerMetrics, SchedulingDecision,
8    WorkloadSpec,
9};
10
/// Predictive scheduling optimizer
///
/// Places workloads onto registered hosts by scoring predicted execution
/// time, cost, and SLO-compliance probability, and learns from recorded
/// execution results.
pub struct PredictiveScheduler {
    /// Policy knobs: thresholds, penalties, history window, spot policy.
    config: PredictiveSchedulerConfig,
    /// Registered hosts, keyed by host ID.
    hosts: HashMap<String, HostProfile>,
    /// Aggregate counters updated on each decision and recorded result.
    metrics: SchedulerMetrics,
    /// Historical execution times per host
    execution_history: HashMap<String, Vec<Duration>>,
    /// SLO violation history per host
    violation_history: HashMap<String, Vec<bool>>,
}
21
22impl PredictiveScheduler {
    /// Create a new predictive scheduler
    ///
    /// Starts with no registered hosts, empty per-host histories, and
    /// default (zeroed) metrics.
    pub fn new(config: PredictiveSchedulerConfig) -> Self {
        Self {
            config,
            hosts: HashMap::new(),
            metrics: SchedulerMetrics::default(),
            execution_history: HashMap::new(),
            violation_history: HashMap::new(),
        }
    }
33
34    /// Register a host with the scheduler
35    pub fn register_host(&mut self, profile: HostProfile) {
36        let host_id = profile.host_id.clone();
37        self.hosts.insert(host_id.clone(), profile);
38        self.execution_history.insert(host_id.clone(), Vec::new());
39        self.violation_history.insert(host_id, Vec::new());
40    }
41
    /// Remove a host from the scheduler
    ///
    /// Also discards the host's execution and violation history; an unknown
    /// host ID is a no-op.
    pub fn deregister_host(&mut self, host_id: &str) {
        self.hosts.remove(host_id);
        self.execution_history.remove(host_id);
        self.violation_history.remove(host_id);
    }
48
49    /// Update host load
50    pub fn update_host_load(&mut self, host_id: &str, load: f64) {
51        if let Some(host) = self.hosts.get_mut(host_id) {
52            host.current_load = load.clamp(0.0, 1.0);
53        }
54    }
55
    /// Update host preemption deadline
    ///
    /// `None` clears any pending deadline; unknown host IDs are ignored.
    pub fn update_preemption_deadline(&mut self, host_id: &str, deadline: Option<Instant>) {
        if let Some(host) = self.hosts.get_mut(host_id) {
            host.preemption_deadline = deadline;
        }
    }
62
63    /// Schedule a workload to optimal host
64    pub fn schedule(&mut self, workload: &WorkloadSpec) -> Option<SchedulingDecision> {
65        let start = Instant::now();
66
67        // Filter eligible hosts
68        let eligible_hosts: Vec<_> = self
69            .hosts
70            .values()
71            .filter(|h| self.is_host_eligible(h, workload))
72            .collect();
73
74        if eligible_hosts.is_empty() {
75            return None;
76        }
77
78        // Score each host
79        let mut best_decision: Option<SchedulingDecision> = None;
80        let mut best_score = f64::NEG_INFINITY;
81
82        for host in eligible_hosts {
83            let decision = self.evaluate_host(host, workload);
84            if decision.score > best_score {
85                best_score = decision.score;
86                best_decision = Some(decision);
87            }
88        }
89
90        // Update metrics
91        if let Some(ref decision) = best_decision {
92            self.metrics.total_decisions += 1;
93            let scheduling_time = start.elapsed().as_micros() as f64;
94            let n = self.metrics.total_decisions as f64;
95            self.metrics.avg_scheduling_latency_us =
96                self.metrics.avg_scheduling_latency_us * (n - 1.0) / n + scheduling_time / n;
97
98            // Track spot savings
99            if let Some(host) = self.hosts.get(&decision.host_id) {
100                if host.instance_type == InstanceType::Spot {
101                    let on_demand_cost =
102                        decision.predicted_cost / host.instance_type.cost_multiplier();
103                    self.metrics.spot_savings += on_demand_cost - decision.predicted_cost;
104                }
105            }
106        }
107
108        best_decision
109    }
110
111    /// Check if host is eligible for workload
112    fn is_host_eligible(&self, host: &HostProfile, workload: &WorkloadSpec) -> bool {
113        // Check capacity
114        if host.current_load >= self.config.min_capacity_threshold {
115            return false;
116        }
117
118        // Check memory
119        if host.memory_capacity < workload.memory_required {
120            return false;
121        }
122
123        // Check preemption safety
124        if !host.is_safe_for_scheduling(self.config.preemption_buffer) {
125            return false;
126        }
127
128        // Check spot instance policy
129        if host.instance_type == InstanceType::Spot && !self.config.enable_spot_instances {
130            return false;
131        }
132
133        true
134    }
135
136    /// Evaluate a host for workload placement
137    fn evaluate_host(&self, host: &HostProfile, workload: &WorkloadSpec) -> SchedulingDecision {
138        let predicted_time = self.predict_execution_time(host, workload);
139        let slo_compliance_prob = self.predict_slo_compliance(host, workload, predicted_time);
140        let predicted_cost = self.calculate_cost(host, workload, predicted_time);
141
142        // Multi-objective scoring
143        let score = self.calculate_score(host, slo_compliance_prob, predicted_cost, workload);
144
145        let reason = self.generate_reason(host, slo_compliance_prob, predicted_cost);
146
147        SchedulingDecision {
148            host_id: host.host_id.clone(),
149            predicted_time,
150            predicted_cost,
151            slo_compliance_prob,
152            score,
153            reason,
154        }
155    }
156
157    /// Predict execution time using historical data
158    fn predict_execution_time(&self, host: &HostProfile, workload: &WorkloadSpec) -> Duration {
159        let base_estimate = workload.estimated_execution_time(host);
160
161        // Adjust based on historical variance
162        if let Some(history) = self.execution_history.get(&host.host_id) {
163            if !history.is_empty() {
164                // Use exponential smoothing on historical data
165                let alpha = 0.3;
166                let mut smoothed = history[0].as_secs_f64();
167                for duration in history.iter().skip(1) {
168                    smoothed = alpha * duration.as_secs_f64() + (1.0 - alpha) * smoothed;
169                }
170
171                // Blend historical with estimate
172                let blended = 0.7 * base_estimate.as_secs_f64() + 0.3 * smoothed;
173                return Duration::from_secs_f64(blended);
174            }
175        }
176
177        // Add safety margin based on performance CV
178        let margin = 1.0 + host.performance_cv;
179        Duration::from_secs_f64(base_estimate.as_secs_f64() * margin)
180    }
181
182    /// Predict SLO compliance probability
183    pub(super) fn predict_slo_compliance(
184        &self,
185        host: &HostProfile,
186        workload: &WorkloadSpec,
187        predicted_time: Duration,
188    ) -> f64 {
189        // Base compliance from time vs deadline
190        let time_ratio = predicted_time.as_secs_f64() / workload.slo_deadline.as_secs_f64();
191
192        // Sigmoid function for compliance probability
193        // P(comply) = 1 / (1 + exp(k * (time_ratio - 1)))
194        let k = 10.0; // Steepness
195        let base_prob = 1.0 / (1.0 + (k * (time_ratio - 0.9)).exp());
196
197        // Adjust for host reliability
198        let reliability_factor = host.instance_type.reliability();
199
200        // Adjust for historical compliance
201        let historical_factor = if let Some(history) = self.violation_history.get(&host.host_id) {
202            if history.len() >= 10 {
203                let recent: Vec<_> = history.iter().rev().take(10).collect();
204                let violations = recent.iter().filter(|&&v| *v).count();
205                1.0 - (violations as f64 / 10.0)
206            } else {
207                host.historical_slo_compliance
208            }
209        } else {
210            host.historical_slo_compliance
211        };
212
213        base_prob * reliability_factor * historical_factor
214    }
215
216    /// Calculate execution cost
217    fn calculate_cost(
218        &self,
219        host: &HostProfile,
220        _workload: &WorkloadSpec,
221        predicted_time: Duration,
222    ) -> f64 {
223        let hours = predicted_time.as_secs_f64() / 3600.0;
224        let base_cost = host.hourly_cost * host.instance_type.cost_multiplier() * hours;
225
226        // Add network cost based on latency
227        let network_cost = host.network_latency_ms * 0.0001; // Small factor for latency
228
229        base_cost + network_cost
230    }
231
232    /// Calculate multi-objective score
233    fn calculate_score(
234        &self,
235        host: &HostProfile,
236        slo_compliance_prob: f64,
237        predicted_cost: f64,
238        workload: &WorkloadSpec,
239    ) -> f64 {
240        // Priority weighting
241        let priority_weight = 1.0 + (workload.priority as f64 * 0.1);
242
243        // SLO compliance score (heavily weighted)
244        let slo_score = if slo_compliance_prob >= self.config.target_slo_compliance {
245            slo_compliance_prob * 100.0
246        } else {
247            // Penalty for below-target compliance
248            slo_compliance_prob * 100.0
249                - self.config.slo_violation_penalty
250                    * (self.config.target_slo_compliance - slo_compliance_prob)
251                    * 100.0
252        };
253
254        // Cost score (inverse - lower is better)
255        let max_cost = self.config.max_cost_per_op;
256        let cost_score = if predicted_cost <= max_cost {
257            (1.0 - predicted_cost / max_cost) * 50.0
258        } else {
259            -((predicted_cost / max_cost) - 1.0) * 50.0
260        };
261
262        // Load balancing score (prefer less loaded hosts)
263        let load_score = (1.0 - host.current_load) * 20.0;
264
265        // Combine scores
266        (slo_score + cost_score + load_score) * priority_weight
267    }
268
269    /// Generate human-readable reason for selection
270    fn generate_reason(
271        &self,
272        host: &HostProfile,
273        slo_compliance_prob: f64,
274        predicted_cost: f64,
275    ) -> String {
276        let mut reasons = Vec::new();
277
278        if slo_compliance_prob >= 0.99 {
279            reasons.push("excellent SLO compliance");
280        } else if slo_compliance_prob >= 0.95 {
281            reasons.push("good SLO compliance");
282        }
283
284        if host.instance_type == InstanceType::Spot {
285            reasons.push("cost-effective spot instance");
286        } else if host.instance_type == InstanceType::Reserved {
287            reasons.push("reserved capacity");
288        }
289
290        if host.current_load < 0.3 {
291            reasons.push("low current load");
292        }
293
294        if predicted_cost < self.config.max_cost_per_op * 0.5 {
295            reasons.push("low cost");
296        }
297
298        if reasons.is_empty() {
299            "best available option".to_string()
300        } else {
301            reasons.join(", ")
302        }
303    }
304
305    /// Record execution result for learning
306    pub fn record_result(
307        &mut self,
308        host_id: &str,
309        actual_time: Duration,
310        slo_violated: bool,
311        actual_cost: f64,
312    ) {
313        // Update execution history
314        if let Some(history) = self.execution_history.get_mut(host_id) {
315            history.push(actual_time);
316            if history.len() > self.config.history_window {
317                history.remove(0);
318            }
319        }
320
321        // Update violation history
322        if let Some(history) = self.violation_history.get_mut(host_id) {
323            history.push(slo_violated);
324            if history.len() > self.config.history_window {
325                history.remove(0);
326            }
327        }
328
329        // Update metrics
330        if slo_violated {
331            self.metrics.slo_violations += 1;
332        }
333        self.metrics.total_cost += actual_cost;
334
335        // Update host utilization
336        if let Some(host) = self.hosts.get(host_id) {
337            self.metrics
338                .host_utilization
339                .insert(host_id.to_string(), host.current_load);
340        }
341
342        // Update host historical compliance
343        if let Some(host) = self.hosts.get_mut(host_id) {
344            if let Some(history) = self.violation_history.get(host_id) {
345                let recent_violations = history
346                    .iter()
347                    .rev()
348                    .take(self.config.history_window)
349                    .filter(|&&v| v)
350                    .count();
351                let total = history.len().min(self.config.history_window);
352                if total > 0 {
353                    host.historical_slo_compliance =
354                        1.0 - (recent_violations as f64 / total as f64);
355                }
356            }
357        }
358    }
359
    /// Get current scheduler metrics
    ///
    /// Borrowed view; counters reflect all decisions and results so far.
    pub fn metrics(&self) -> &SchedulerMetrics {
        &self.metrics
    }
364
    /// Get all registered hosts
    ///
    /// Iteration order is unspecified (backed by a `HashMap`).
    pub fn hosts(&self) -> impl Iterator<Item = &HostProfile> {
        self.hosts.values()
    }
369
    /// Get host by ID
    ///
    /// Returns `None` when no host with that ID is registered.
    pub fn get_host(&self, host_id: &str) -> Option<&HostProfile> {
        self.hosts.get(host_id)
    }
374
375    /// Rebalance workloads across hosts (returns migration suggestions)
376    pub fn suggest_rebalancing(&self) -> Vec<(String, String)> {
377        let mut migrations = Vec::new();
378
379        // Find overloaded and underloaded hosts
380        let mut overloaded: Vec<_> = self
381            .hosts
382            .values()
383            .filter(|h| h.current_load > 0.8)
384            .collect();
385        let mut underloaded: Vec<_> = self
386            .hosts
387            .values()
388            .filter(|h| {
389                h.current_load < 0.3 && h.is_safe_for_scheduling(self.config.preemption_buffer)
390            })
391            .collect();
392
393        overloaded.sort_by(|a, b| {
394            b.current_load
395                .partial_cmp(&a.current_load)
396                .expect("values should be comparable")
397        });
398        underloaded.sort_by(|a, b| {
399            a.current_load
400                .partial_cmp(&b.current_load)
401                .expect("values should be comparable")
402        });
403
404        // Suggest migrations from overloaded to underloaded
405        for (over, under) in overloaded.iter().zip(underloaded.iter()) {
406            migrations.push((over.host_id.clone(), under.host_id.clone()));
407        }
408
409        migrations
410    }
411}