1use chrono::Timelike;
7use dashmap::DashMap;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
10use std::sync::RwLock;
11
12use super::classifier::WorkloadType;
13use super::config::SchedulingPolicy;
14use super::DistribCacheConfig;
15
16#[derive(Debug, Clone)]
18pub struct ScheduledQuery {
19 pub id: u64,
21 pub workload_type: WorkloadType,
23 pub timestamp: std::time::Instant,
25}
26
27#[derive(Debug, Clone)]
29pub enum ScheduleResult {
30 Execute { priority: QueryPriority },
32 Queued { position: usize },
34 Rejected { reason: String },
36}
37
/// Priority class assigned to an admitted query.
///
/// Variants are declared from highest to lowest. `Hash` is derived so the
/// priority can key maps/sets (e.g. per-priority counters). `Ord` is
/// intentionally *not* derived: the derived order would follow declaration
/// order and rank `High` below `Low`, which reads backwards.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryPriority {
    /// Highest of the three classes.
    High,
    /// Middle class.
    Normal,
    /// Lowest of the three classes.
    Low,
}
45
46#[derive(Debug, Clone)]
48pub struct WorkloadDistribution {
49 pub oltp: WorkloadSlot,
51 pub olap: WorkloadSlot,
53 pub vector: WorkloadSlot,
55 pub ai_agent: WorkloadSlot,
57 pub rag: WorkloadSlot,
59}
60
/// Load figures for one workload type at the moment a distribution was taken.
#[derive(Clone, Debug)]
pub struct WorkloadSlot {
    /// This workload's active queries as a percentage of all active queries
    /// (0.0 when nothing is active).
    pub current_pct: f64,
    /// Desired share, in percent, derived from the workload's priority weight.
    pub target_pct: f64,
    /// Number of queries waiting in this workload's pending queue.
    pub queued: u32,
    /// Number of queries of this workload currently admitted.
    pub active: u32,
}
73
74struct WorkloadQueue {
76 pending: std::collections::VecDeque<ScheduledQuery>,
78 active: AtomicU32,
80 total_processed: AtomicU64,
82}
83
84impl WorkloadQueue {
85 fn new() -> Self {
86 Self {
87 pending: std::collections::VecDeque::new(),
88 active: AtomicU32::new(0),
89 total_processed: AtomicU64::new(0),
90 }
91 }
92}
93
94pub struct WorkloadScheduler {
96 #[allow(dead_code)]
98 config: DistribCacheConfig,
99
100 queues: DashMap<WorkloadType, RwLock<WorkloadQueue>>,
102
103 limits: HashMap<WorkloadType, ResourceLimit>,
105
106 policy: SchedulingPolicy,
108
109 stats: SchedulerStats,
111}
112
/// Resource budget for one workload type.
#[derive(Clone, Debug)]
pub struct ResourceLimit {
    /// Maximum number of simultaneously admitted queries; further queries
    /// are queued.
    pub max_concurrent: u32,
    /// Cache budget in MiB. Not consulted by the scheduler in this file.
    pub max_cache_mb: usize,
    /// Weight used by the weighted-fair and adaptive policies.
    pub priority_weight: f64,
}
123
124impl Default for ResourceLimit {
125 fn default() -> Self {
126 Self {
127 max_concurrent: 100,
128 max_cache_mb: 64,
129 priority_weight: 0.5,
130 }
131 }
132}
133
/// Lock-free global counters kept by [`WorkloadScheduler`].
#[derive(Default, Debug)]
struct SchedulerStats {
    /// Every call to `schedule`, whatever its outcome.
    total_scheduled: AtomicU64,
    /// Calls that ended with the query being queued.
    total_queued: AtomicU64,
    /// Rejections. NOTE(review): never incremented in this file.
    total_rejected: AtomicU64,
    /// Queries currently admitted across all workload types.
    current_active: AtomicU32,
}
142
143impl WorkloadScheduler {
144 pub fn new(config: DistribCacheConfig) -> Self {
146 let mut limits = HashMap::new();
147
148 limits.insert(
149 WorkloadType::OLTP,
150 ResourceLimit {
151 max_concurrent: config.max_concurrent_oltp,
152 max_cache_mb: 64,
153 priority_weight: config.oltp_priority,
154 },
155 );
156
157 limits.insert(
158 WorkloadType::OLAP,
159 ResourceLimit {
160 max_concurrent: config.max_concurrent_olap,
161 max_cache_mb: 128,
162 priority_weight: config.olap_priority,
163 },
164 );
165
166 limits.insert(
167 WorkloadType::Vector,
168 ResourceLimit {
169 max_concurrent: config.max_concurrent_vector,
170 max_cache_mb: 96,
171 priority_weight: config.vector_priority,
172 },
173 );
174
175 limits.insert(
176 WorkloadType::AIAgent,
177 ResourceLimit {
178 max_concurrent: config.max_concurrent_ai,
179 max_cache_mb: 64,
180 priority_weight: config.ai_agent_priority,
181 },
182 );
183
184 limits.insert(
185 WorkloadType::RAG,
186 ResourceLimit {
187 max_concurrent: config.max_concurrent_ai,
188 max_cache_mb: 64,
189 priority_weight: config.ai_agent_priority,
190 },
191 );
192
193 limits.insert(WorkloadType::Mixed, ResourceLimit::default());
194
195 let queues = DashMap::new();
196 for wt in [
197 WorkloadType::OLTP,
198 WorkloadType::OLAP,
199 WorkloadType::Vector,
200 WorkloadType::AIAgent,
201 WorkloadType::RAG,
202 WorkloadType::Mixed,
203 ] {
204 queues.insert(wt, RwLock::new(WorkloadQueue::new()));
205 }
206
207 Self {
208 policy: config.scheduling_policy,
209 config,
210 queues,
211 limits,
212 stats: SchedulerStats::default(),
213 }
214 }
215
216 pub fn schedule(&self, query: ScheduledQuery) -> ScheduleResult {
218 self.stats.total_scheduled.fetch_add(1, Ordering::Relaxed);
219
220 let workload = query.workload_type;
221 let default_limit = ResourceLimit::default();
222 let limit = self.limits.get(&workload).unwrap_or(&default_limit);
223
224 let current = self.get_current_concurrency(&workload);
226 if current >= limit.max_concurrent {
227 self.enqueue(query.clone());
229 self.stats.total_queued.fetch_add(1, Ordering::Relaxed);
230 return ScheduleResult::Queued {
231 position: self.queue_position(&workload),
232 };
233 }
234
235 match self.policy {
237 SchedulingPolicy::StrictPriority => self.schedule_strict_priority(query),
238 SchedulingPolicy::WeightedFair => self.schedule_weighted_fair(query),
239 SchedulingPolicy::TimeBased => self.schedule_time_based(query),
240 SchedulingPolicy::Adaptive => self.schedule_adaptive(query),
241 }
242 }
243
244 fn schedule_strict_priority(&self, query: ScheduledQuery) -> ScheduleResult {
246 let priority = match query.workload_type {
247 WorkloadType::OLTP => QueryPriority::High,
248 WorkloadType::AIAgent | WorkloadType::RAG => QueryPriority::Normal,
249 WorkloadType::Vector => QueryPriority::Normal,
250 WorkloadType::OLAP => QueryPriority::Low,
251 WorkloadType::Mixed => QueryPriority::Normal,
252 };
253
254 self.mark_active(&query.workload_type);
255 ScheduleResult::Execute { priority }
256 }
257
258 fn schedule_weighted_fair(&self, query: ScheduledQuery) -> ScheduleResult {
260 let limit = self.limits.get(&query.workload_type).unwrap();
261 let weight = limit.priority_weight;
262
263 let priority = if weight >= 0.8 {
264 QueryPriority::High
265 } else if weight >= 0.4 {
266 QueryPriority::Normal
267 } else {
268 QueryPriority::Low
269 };
270
271 self.mark_active(&query.workload_type);
272 ScheduleResult::Execute { priority }
273 }
274
275 fn schedule_time_based(&self, query: ScheduledQuery) -> ScheduleResult {
277 let hour = chrono::Utc::now().hour();
278
279 let priority = if (9..18).contains(&hour) {
281 match query.workload_type {
282 WorkloadType::OLTP | WorkloadType::AIAgent => QueryPriority::High,
283 WorkloadType::OLAP => QueryPriority::Low,
284 _ => QueryPriority::Normal,
285 }
286 } else {
287 match query.workload_type {
289 WorkloadType::OLAP => QueryPriority::High,
290 WorkloadType::OLTP => QueryPriority::Normal,
291 _ => QueryPriority::Normal,
292 }
293 };
294
295 self.mark_active(&query.workload_type);
296 ScheduleResult::Execute { priority }
297 }
298
299 fn schedule_adaptive(&self, query: ScheduledQuery) -> ScheduleResult {
301 let distribution = self.get_distribution();
303 let workload = query.workload_type;
304
305 let slot = match workload {
306 WorkloadType::OLTP => &distribution.oltp,
307 WorkloadType::OLAP => &distribution.olap,
308 WorkloadType::Vector => &distribution.vector,
309 WorkloadType::AIAgent => &distribution.ai_agent,
310 WorkloadType::RAG => &distribution.rag,
311 WorkloadType::Mixed => &distribution.oltp, };
313
314 let priority = if slot.current_pct < slot.target_pct {
315 QueryPriority::High } else if slot.current_pct > slot.target_pct * 1.2 {
317 QueryPriority::Low } else {
319 QueryPriority::Normal
320 };
321
322 self.mark_active(&query.workload_type);
323 ScheduleResult::Execute { priority }
324 }
325
326 fn get_current_concurrency(&self, workload: &WorkloadType) -> u32 {
328 self.queues
329 .get(workload)
330 .map(|q| q.read().unwrap().active.load(Ordering::Relaxed))
331 .unwrap_or(0)
332 }
333
334 fn queue_position(&self, workload: &WorkloadType) -> usize {
336 self.queues
337 .get(workload)
338 .map(|q| q.read().unwrap().pending.len())
339 .unwrap_or(0)
340 }
341
342 fn enqueue(&self, query: ScheduledQuery) {
344 if let Some(queue) = self.queues.get(&query.workload_type) {
345 queue.write().unwrap().pending.push_back(query);
346 }
347 }
348
349 fn mark_active(&self, workload: &WorkloadType) {
351 if let Some(queue) = self.queues.get(workload) {
352 queue.read().unwrap().active.fetch_add(1, Ordering::Relaxed);
353 }
354 self.stats.current_active.fetch_add(1, Ordering::Relaxed);
355 }
356
357 pub fn mark_complete(&self, workload: WorkloadType) {
359 if let Some(queue) = self.queues.get(&workload) {
360 let q = queue.read().unwrap();
361 q.active.fetch_sub(1, Ordering::Relaxed);
362 q.total_processed.fetch_add(1, Ordering::Relaxed);
363 }
364 self.stats.current_active.fetch_sub(1, Ordering::Relaxed);
365 }
366
367 pub fn get_distribution(&self) -> WorkloadDistribution {
369 let total_active = self.stats.current_active.load(Ordering::Relaxed) as f64;
370
371 let get_slot = |wt: WorkloadType| -> WorkloadSlot {
372 let queue = self.queues.get(&wt).unwrap();
373 let q = queue.read().unwrap();
374 let active = q.active.load(Ordering::Relaxed);
375 let limit = self.limits.get(&wt).unwrap();
376
377 WorkloadSlot {
378 current_pct: if total_active > 0.0 {
379 active as f64 / total_active * 100.0
380 } else {
381 0.0
382 },
383 target_pct: limit.priority_weight * 100.0 / 2.5, queued: q.pending.len() as u32,
385 active,
386 }
387 };
388
389 WorkloadDistribution {
390 oltp: get_slot(WorkloadType::OLTP),
391 olap: get_slot(WorkloadType::OLAP),
392 vector: get_slot(WorkloadType::Vector),
393 ai_agent: get_slot(WorkloadType::AIAgent),
394 rag: get_slot(WorkloadType::RAG),
395 }
396 }
397
398 pub fn stats(&self) -> SchedulerStatsSnapshot {
400 SchedulerStatsSnapshot {
401 total_scheduled: self.stats.total_scheduled.load(Ordering::Relaxed),
402 total_queued: self.stats.total_queued.load(Ordering::Relaxed),
403 total_rejected: self.stats.total_rejected.load(Ordering::Relaxed),
404 current_active: self.stats.current_active.load(Ordering::Relaxed),
405 policy: self.policy,
406 }
407 }
408}
409
410#[derive(Debug, Clone)]
412pub struct SchedulerStatsSnapshot {
413 pub total_scheduled: u64,
414 pub total_queued: u64,
415 pub total_rejected: u64,
416 pub current_active: u32,
417 pub policy: SchedulingPolicy,
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a query of the given workload stamped with the current instant.
    fn make_query(id: u64, workload_type: WorkloadType) -> ScheduledQuery {
        ScheduledQuery {
            id,
            workload_type,
            timestamp: std::time::Instant::now(),
        }
    }

    #[test]
    fn test_schedule_oltp() {
        let scheduler = WorkloadScheduler::new(DistribCacheConfig::default());

        let outcome = scheduler.schedule(make_query(1, WorkloadType::OLTP));
        assert!(matches!(outcome, ScheduleResult::Execute { .. }));
    }

    #[test]
    fn test_schedule_with_concurrency_limit() {
        let mut config = DistribCacheConfig::default();
        config.max_concurrent_oltp = 1;
        let scheduler = WorkloadScheduler::new(config);

        // The first query takes the only slot...
        let first = scheduler.schedule(make_query(1, WorkloadType::OLTP));
        assert!(matches!(first, ScheduleResult::Execute { .. }));

        // ...so the second one has to wait.
        let second = scheduler.schedule(make_query(2, WorkloadType::OLTP));
        assert!(matches!(second, ScheduleResult::Queued { .. }));
    }

    #[test]
    fn test_mark_complete() {
        let scheduler = WorkloadScheduler::new(DistribCacheConfig::default());

        scheduler.schedule(make_query(1, WorkloadType::OLTP));
        assert_eq!(scheduler.stats().current_active, 1);

        scheduler.mark_complete(WorkloadType::OLTP);
        assert_eq!(scheduler.stats().current_active, 0);
    }

    #[test]
    fn test_get_distribution() {
        let scheduler = WorkloadScheduler::new(DistribCacheConfig::default());

        for id in 0..5 {
            scheduler.schedule(make_query(id, WorkloadType::OLTP));
        }
        for id in 10..13 {
            scheduler.schedule(make_query(id, WorkloadType::OLAP));
        }

        let dist = scheduler.get_distribution();
        assert!(dist.oltp.active > 0);
        assert!(dist.olap.active > 0);
    }
}