// cbtop/predictive_scheduler/scheduler.rs
use std::collections::HashMap;
use std::time::{Duration, Instant};

use super::types::{
    HostProfile, InstanceType, PredictiveSchedulerConfig, SchedulerMetrics, SchedulingDecision,
    WorkloadSpec,
};
10
/// SLO- and cost-aware scheduler: places workloads onto registered hosts
/// using predicted execution time, SLO-compliance probability, and cost.
pub struct PredictiveScheduler {
    // Tunable thresholds, penalties, and window sizes for scheduling.
    config: PredictiveSchedulerConfig,
    // Registered hosts, keyed by `HostProfile::host_id`.
    hosts: HashMap<String, HostProfile>,
    // Running counters (decisions, violations, cost, savings, latency).
    metrics: SchedulerMetrics,
    // Per-host observed execution times, bounded by `config.history_window`.
    execution_history: HashMap<String, Vec<Duration>>,
    // Per-host SLO-violation flags (true = violated), same bounding.
    violation_history: HashMap<String, Vec<bool>>,
}
21
22impl PredictiveScheduler {
23 pub fn new(config: PredictiveSchedulerConfig) -> Self {
25 Self {
26 config,
27 hosts: HashMap::new(),
28 metrics: SchedulerMetrics::default(),
29 execution_history: HashMap::new(),
30 violation_history: HashMap::new(),
31 }
32 }
33
34 pub fn register_host(&mut self, profile: HostProfile) {
36 let host_id = profile.host_id.clone();
37 self.hosts.insert(host_id.clone(), profile);
38 self.execution_history.insert(host_id.clone(), Vec::new());
39 self.violation_history.insert(host_id, Vec::new());
40 }
41
42 pub fn deregister_host(&mut self, host_id: &str) {
44 self.hosts.remove(host_id);
45 self.execution_history.remove(host_id);
46 self.violation_history.remove(host_id);
47 }
48
49 pub fn update_host_load(&mut self, host_id: &str, load: f64) {
51 if let Some(host) = self.hosts.get_mut(host_id) {
52 host.current_load = load.clamp(0.0, 1.0);
53 }
54 }
55
56 pub fn update_preemption_deadline(&mut self, host_id: &str, deadline: Option<Instant>) {
58 if let Some(host) = self.hosts.get_mut(host_id) {
59 host.preemption_deadline = deadline;
60 }
61 }
62
63 pub fn schedule(&mut self, workload: &WorkloadSpec) -> Option<SchedulingDecision> {
65 let start = Instant::now();
66
67 let eligible_hosts: Vec<_> = self
69 .hosts
70 .values()
71 .filter(|h| self.is_host_eligible(h, workload))
72 .collect();
73
74 if eligible_hosts.is_empty() {
75 return None;
76 }
77
78 let mut best_decision: Option<SchedulingDecision> = None;
80 let mut best_score = f64::NEG_INFINITY;
81
82 for host in eligible_hosts {
83 let decision = self.evaluate_host(host, workload);
84 if decision.score > best_score {
85 best_score = decision.score;
86 best_decision = Some(decision);
87 }
88 }
89
90 if let Some(ref decision) = best_decision {
92 self.metrics.total_decisions += 1;
93 let scheduling_time = start.elapsed().as_micros() as f64;
94 let n = self.metrics.total_decisions as f64;
95 self.metrics.avg_scheduling_latency_us =
96 self.metrics.avg_scheduling_latency_us * (n - 1.0) / n + scheduling_time / n;
97
98 if let Some(host) = self.hosts.get(&decision.host_id) {
100 if host.instance_type == InstanceType::Spot {
101 let on_demand_cost =
102 decision.predicted_cost / host.instance_type.cost_multiplier();
103 self.metrics.spot_savings += on_demand_cost - decision.predicted_cost;
104 }
105 }
106 }
107
108 best_decision
109 }
110
111 fn is_host_eligible(&self, host: &HostProfile, workload: &WorkloadSpec) -> bool {
113 if host.current_load >= self.config.min_capacity_threshold {
115 return false;
116 }
117
118 if host.memory_capacity < workload.memory_required {
120 return false;
121 }
122
123 if !host.is_safe_for_scheduling(self.config.preemption_buffer) {
125 return false;
126 }
127
128 if host.instance_type == InstanceType::Spot && !self.config.enable_spot_instances {
130 return false;
131 }
132
133 true
134 }
135
136 fn evaluate_host(&self, host: &HostProfile, workload: &WorkloadSpec) -> SchedulingDecision {
138 let predicted_time = self.predict_execution_time(host, workload);
139 let slo_compliance_prob = self.predict_slo_compliance(host, workload, predicted_time);
140 let predicted_cost = self.calculate_cost(host, workload, predicted_time);
141
142 let score = self.calculate_score(host, slo_compliance_prob, predicted_cost, workload);
144
145 let reason = self.generate_reason(host, slo_compliance_prob, predicted_cost);
146
147 SchedulingDecision {
148 host_id: host.host_id.clone(),
149 predicted_time,
150 predicted_cost,
151 slo_compliance_prob,
152 score,
153 reason,
154 }
155 }
156
157 fn predict_execution_time(&self, host: &HostProfile, workload: &WorkloadSpec) -> Duration {
159 let base_estimate = workload.estimated_execution_time(host);
160
161 if let Some(history) = self.execution_history.get(&host.host_id) {
163 if !history.is_empty() {
164 let alpha = 0.3;
166 let mut smoothed = history[0].as_secs_f64();
167 for duration in history.iter().skip(1) {
168 smoothed = alpha * duration.as_secs_f64() + (1.0 - alpha) * smoothed;
169 }
170
171 let blended = 0.7 * base_estimate.as_secs_f64() + 0.3 * smoothed;
173 return Duration::from_secs_f64(blended);
174 }
175 }
176
177 let margin = 1.0 + host.performance_cv;
179 Duration::from_secs_f64(base_estimate.as_secs_f64() * margin)
180 }
181
182 pub(super) fn predict_slo_compliance(
184 &self,
185 host: &HostProfile,
186 workload: &WorkloadSpec,
187 predicted_time: Duration,
188 ) -> f64 {
189 let time_ratio = predicted_time.as_secs_f64() / workload.slo_deadline.as_secs_f64();
191
192 let k = 10.0; let base_prob = 1.0 / (1.0 + (k * (time_ratio - 0.9)).exp());
196
197 let reliability_factor = host.instance_type.reliability();
199
200 let historical_factor = if let Some(history) = self.violation_history.get(&host.host_id) {
202 if history.len() >= 10 {
203 let recent: Vec<_> = history.iter().rev().take(10).collect();
204 let violations = recent.iter().filter(|&&v| *v).count();
205 1.0 - (violations as f64 / 10.0)
206 } else {
207 host.historical_slo_compliance
208 }
209 } else {
210 host.historical_slo_compliance
211 };
212
213 base_prob * reliability_factor * historical_factor
214 }
215
216 fn calculate_cost(
218 &self,
219 host: &HostProfile,
220 _workload: &WorkloadSpec,
221 predicted_time: Duration,
222 ) -> f64 {
223 let hours = predicted_time.as_secs_f64() / 3600.0;
224 let base_cost = host.hourly_cost * host.instance_type.cost_multiplier() * hours;
225
226 let network_cost = host.network_latency_ms * 0.0001; base_cost + network_cost
230 }
231
232 fn calculate_score(
234 &self,
235 host: &HostProfile,
236 slo_compliance_prob: f64,
237 predicted_cost: f64,
238 workload: &WorkloadSpec,
239 ) -> f64 {
240 let priority_weight = 1.0 + (workload.priority as f64 * 0.1);
242
243 let slo_score = if slo_compliance_prob >= self.config.target_slo_compliance {
245 slo_compliance_prob * 100.0
246 } else {
247 slo_compliance_prob * 100.0
249 - self.config.slo_violation_penalty
250 * (self.config.target_slo_compliance - slo_compliance_prob)
251 * 100.0
252 };
253
254 let max_cost = self.config.max_cost_per_op;
256 let cost_score = if predicted_cost <= max_cost {
257 (1.0 - predicted_cost / max_cost) * 50.0
258 } else {
259 -((predicted_cost / max_cost) - 1.0) * 50.0
260 };
261
262 let load_score = (1.0 - host.current_load) * 20.0;
264
265 (slo_score + cost_score + load_score) * priority_weight
267 }
268
269 fn generate_reason(
271 &self,
272 host: &HostProfile,
273 slo_compliance_prob: f64,
274 predicted_cost: f64,
275 ) -> String {
276 let mut reasons = Vec::new();
277
278 if slo_compliance_prob >= 0.99 {
279 reasons.push("excellent SLO compliance");
280 } else if slo_compliance_prob >= 0.95 {
281 reasons.push("good SLO compliance");
282 }
283
284 if host.instance_type == InstanceType::Spot {
285 reasons.push("cost-effective spot instance");
286 } else if host.instance_type == InstanceType::Reserved {
287 reasons.push("reserved capacity");
288 }
289
290 if host.current_load < 0.3 {
291 reasons.push("low current load");
292 }
293
294 if predicted_cost < self.config.max_cost_per_op * 0.5 {
295 reasons.push("low cost");
296 }
297
298 if reasons.is_empty() {
299 "best available option".to_string()
300 } else {
301 reasons.join(", ")
302 }
303 }
304
305 pub fn record_result(
307 &mut self,
308 host_id: &str,
309 actual_time: Duration,
310 slo_violated: bool,
311 actual_cost: f64,
312 ) {
313 if let Some(history) = self.execution_history.get_mut(host_id) {
315 history.push(actual_time);
316 if history.len() > self.config.history_window {
317 history.remove(0);
318 }
319 }
320
321 if let Some(history) = self.violation_history.get_mut(host_id) {
323 history.push(slo_violated);
324 if history.len() > self.config.history_window {
325 history.remove(0);
326 }
327 }
328
329 if slo_violated {
331 self.metrics.slo_violations += 1;
332 }
333 self.metrics.total_cost += actual_cost;
334
335 if let Some(host) = self.hosts.get(host_id) {
337 self.metrics
338 .host_utilization
339 .insert(host_id.to_string(), host.current_load);
340 }
341
342 if let Some(host) = self.hosts.get_mut(host_id) {
344 if let Some(history) = self.violation_history.get(host_id) {
345 let recent_violations = history
346 .iter()
347 .rev()
348 .take(self.config.history_window)
349 .filter(|&&v| v)
350 .count();
351 let total = history.len().min(self.config.history_window);
352 if total > 0 {
353 host.historical_slo_compliance =
354 1.0 - (recent_violations as f64 / total as f64);
355 }
356 }
357 }
358 }
359
360 pub fn metrics(&self) -> &SchedulerMetrics {
362 &self.metrics
363 }
364
365 pub fn hosts(&self) -> impl Iterator<Item = &HostProfile> {
367 self.hosts.values()
368 }
369
370 pub fn get_host(&self, host_id: &str) -> Option<&HostProfile> {
372 self.hosts.get(host_id)
373 }
374
375 pub fn suggest_rebalancing(&self) -> Vec<(String, String)> {
377 let mut migrations = Vec::new();
378
379 let mut overloaded: Vec<_> = self
381 .hosts
382 .values()
383 .filter(|h| h.current_load > 0.8)
384 .collect();
385 let mut underloaded: Vec<_> = self
386 .hosts
387 .values()
388 .filter(|h| {
389 h.current_load < 0.3 && h.is_safe_for_scheduling(self.config.preemption_buffer)
390 })
391 .collect();
392
393 overloaded.sort_by(|a, b| {
394 b.current_load
395 .partial_cmp(&a.current_load)
396 .expect("values should be comparable")
397 });
398 underloaded.sort_by(|a, b| {
399 a.current_load
400 .partial_cmp(&b.current_load)
401 .expect("values should be comparable")
402 });
403
404 for (over, under) in overloaded.iter().zip(underloaded.iter()) {
406 migrations.push((over.host_id.clone(), under.host_id.clone()));
407 }
408
409 migrations
410 }
411}