1use crate::errors::{PolarisError, PolarisResult};
6use crate::node::{Node, NodeId};
7use crate::task::Task;
8use async_trait::async_trait;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11
12#[async_trait]
14pub trait Scheduler: Send + Sync {
15 async fn schedule(&self, task: &Task, nodes: &[Node]) -> PolarisResult<NodeId>;
26
27 fn name(&self) -> &str;
29}
30
/// Scheduler that rotates through the candidate nodes using a running counter.
///
/// The derived `Default` yields a counter of zero, which is the same state
/// [`RoundRobinScheduler::new`] produces.
#[derive(Debug, Default)]
pub struct RoundRobinScheduler {
    /// Running count used to derive the position of the next node to pick.
    counter: Arc<AtomicUsize>,
}

impl RoundRobinScheduler {
    /// Creates a scheduler whose rotation counter starts at zero.
    pub fn new() -> Self {
        Self::default()
    }
}
53
54#[async_trait]
55impl Scheduler for RoundRobinScheduler {
56 async fn schedule(&self, _task: &Task, nodes: &[Node]) -> PolarisResult<NodeId> {
57 if nodes.is_empty() {
58 return Err(PolarisError::scheduling_failed("No available nodes"));
59 }
60
61 let index = self.counter.fetch_add(1, Ordering::Relaxed) % nodes.len();
62 let node = &nodes[index];
63
64 tracing::debug!(
65 task_id = %_task.id,
66 node_id = %node.id(),
67 scheduler = "round_robin",
68 "Task scheduled"
69 );
70
71 Ok(node.id())
72 }
73
74 fn name(&self) -> &str {
75 "round_robin"
76 }
77}
78
79#[derive(Debug)]
83pub struct LoadAwareScheduler {
84 cpu_weight: f64,
86 memory_weight: f64,
88 task_weight: f64,
90}
91
92impl LoadAwareScheduler {
93 pub fn new() -> Self {
95 Self {
96 cpu_weight: 0.4,
97 memory_weight: 0.3,
98 task_weight: 0.3,
99 }
100 }
101
102 pub fn with_weights(cpu_weight: f64, memory_weight: f64, task_weight: f64) -> Self {
104 let total = cpu_weight + memory_weight + task_weight;
105 Self {
106 cpu_weight: cpu_weight / total,
107 memory_weight: memory_weight / total,
108 task_weight: task_weight / total,
109 }
110 }
111
112 fn calculate_load(&self, node: &Node) -> f64 {
114 let info = node.info();
115 let usage = &info.resource_usage;
116
117 let cpu_load = usage.cpu_usage_percent / 100.0;
118 let memory_load = if info.resource_limits.max_memory_bytes > 0 {
119 usage.memory_usage_bytes as f64 / info.resource_limits.max_memory_bytes as f64
120 } else {
121 0.0
122 };
123 let task_load = node.load_factor();
124
125 self.cpu_weight * cpu_load
126 + self.memory_weight * memory_load
127 + self.task_weight * task_load
128 }
129}
130
131impl Default for LoadAwareScheduler {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137#[async_trait]
138impl Scheduler for LoadAwareScheduler {
139 async fn schedule(&self, task: &Task, nodes: &[Node]) -> PolarisResult<NodeId> {
140 if nodes.is_empty() {
141 return Err(PolarisError::scheduling_failed("No available nodes"));
142 }
143
144 let (best_node, load) = nodes
146 .iter()
147 .map(|node| (node, self.calculate_load(node)))
148 .min_by(|(_, load_a), (_, load_b)| {
149 load_a.partial_cmp(load_b).unwrap_or(std::cmp::Ordering::Equal)
150 })
151 .unwrap();
152
153 tracing::debug!(
154 task_id = %task.id,
155 node_id = %best_node.id(),
156 load = %load,
157 scheduler = "load_aware",
158 "Task scheduled"
159 );
160
161 Ok(best_node.id())
162 }
163
164 fn name(&self) -> &str {
165 "load_aware"
166 }
167}
168
169#[derive(Debug)]
173pub struct PriorityScheduler {
174 base_scheduler: LoadAwareScheduler,
175}
176
177impl PriorityScheduler {
178 pub fn new() -> Self {
180 Self {
181 base_scheduler: LoadAwareScheduler::new(),
182 }
183 }
184}
185
186impl Default for PriorityScheduler {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192#[async_trait]
193impl Scheduler for PriorityScheduler {
194 async fn schedule(&self, task: &Task, nodes: &[Node]) -> PolarisResult<NodeId> {
195 tracing::debug!(
198 task_id = %task.id,
199 priority = ?task.priority,
200 scheduler = "priority",
201 "Scheduling high-priority task"
202 );
203 self.base_scheduler.schedule(task, nodes).await
204 }
205
206 fn name(&self) -> &str {
207 "priority"
208 }
209}
210
211#[derive(Debug)]
215pub struct AffinityScheduler {
216 base_scheduler: LoadAwareScheduler,
217}
218
219impl AffinityScheduler {
220 pub fn new() -> Self {
222 Self {
223 base_scheduler: LoadAwareScheduler::new(),
224 }
225 }
226
227 fn filter_by_labels<'a>(
229 &self,
230 nodes: &'a [Node],
231 required_labels: &std::collections::HashMap<String, String>,
232 ) -> Vec<&'a Node> {
233 if required_labels.is_empty() {
234 return nodes.iter().collect();
235 }
236
237 nodes
238 .iter()
239 .filter(|node| {
240 let node_labels = &node.info().metadata.labels;
241 required_labels
242 .iter()
243 .all(|(k, v)| node_labels.get(k).map(|nv| nv == v).unwrap_or(false))
244 })
245 .collect()
246 }
247}
248
249impl Default for AffinityScheduler {
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
255#[async_trait]
256impl Scheduler for AffinityScheduler {
257 async fn schedule(&self, task: &Task, nodes: &[Node]) -> PolarisResult<NodeId> {
258 let required_labels = &task.metadata.tags;
260
261 let filtered_nodes: Vec<&Node> = self.filter_by_labels(nodes, required_labels);
262
263 if filtered_nodes.is_empty() {
264 return Err(PolarisError::scheduling_failed(
265 "No nodes match affinity constraints",
266 ));
267 }
268
269 let owned_nodes: Vec<Node> = filtered_nodes.into_iter().cloned().collect();
271
272 tracing::debug!(
273 task_id = %task.id,
274 constraints = ?required_labels,
275 scheduler = "affinity",
276 "Scheduling with affinity constraints"
277 );
278
279 self.base_scheduler.schedule(task, &owned_nodes).await
280 }
281
282 fn name(&self) -> &str {
283 "affinity"
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::{TaskPriority, TaskStatus};
    use bytes::Bytes;

    /// Builds `count` nodes named `node0`, `node1`, ... on consecutive local
    /// ports and marks each of them `Ready`.
    fn create_test_nodes(count: usize) -> Vec<Node> {
        let mut nodes = Vec::with_capacity(count);
        for i in 0..count {
            let addr = format!("127.0.0.1:700{}", i).parse().unwrap();
            let node = Node::new(format!("node{}", i), addr);
            node.set_status(crate::node::NodeStatus::Ready);
            nodes.push(node);
        }
        nodes
    }

    /// Nine scheduling calls over three nodes must hit each node three times.
    #[tokio::test]
    async fn test_round_robin_scheduler() {
        let nodes = create_test_nodes(3);
        let task = Task::new("test", Bytes::new());
        let scheduler = RoundRobinScheduler::new();

        let mut hits = std::collections::HashMap::new();
        for _ in 0..9 {
            let chosen = scheduler.schedule(&task, &nodes).await.unwrap();
            *hits.entry(chosen).or_insert(0) += 1;
        }

        assert_eq!(hits.len(), 3);
        assert!(hits.values().all(|&times| times == 3));
    }

    /// An empty candidate list is reported as an error.
    #[tokio::test]
    async fn test_round_robin_no_nodes() {
        let nodes: Vec<Node> = Vec::new();
        let task = Task::new("test", Bytes::new());
        let scheduler = RoundRobinScheduler::new();

        assert!(scheduler.schedule(&task, &nodes).await.is_err());
    }

    /// A node with ten active tasks must not be picked over idle nodes.
    #[tokio::test]
    async fn test_load_aware_scheduler() {
        let nodes = create_test_nodes(3);
        let busy = &nodes[0];
        (0..10).for_each(|_| {
            busy.increment_active_tasks();
        });

        let scheduler = LoadAwareScheduler::new();
        let task = Task::new("test", Bytes::new());
        let chosen = scheduler.schedule(&task, &nodes).await.unwrap();

        assert_ne!(chosen, busy.id());
    }

    /// Custom weights still produce a successful scheduling decision.
    #[tokio::test]
    async fn test_load_aware_custom_weights() {
        let nodes = create_test_nodes(2);
        let task = Task::new("test", Bytes::new());
        let scheduler = LoadAwareScheduler::with_weights(1.0, 0.0, 0.0);

        assert!(scheduler.schedule(&task, &nodes).await.is_ok());
    }

    /// A high-priority task is scheduled successfully.
    #[tokio::test]
    async fn test_priority_scheduler() {
        let nodes = create_test_nodes(2);
        let task = Task::new("test", Bytes::new()).with_priority(TaskPriority::High);
        let scheduler = PriorityScheduler::new();

        assert!(scheduler.schedule(&task, &nodes).await.is_ok());
    }

    /// A task tagged `zone=us-east` must land on the node labelled the same way.
    #[tokio::test]
    async fn test_affinity_scheduler_with_labels() {
        let scheduler = AffinityScheduler::new();
        let nodes = create_test_nodes(2);

        {
            // Scoped so the value returned by `info()` is dropped before scheduling.
            let mut first_info = nodes[0].info();
            first_info
                .metadata
                .labels
                .insert("zone".to_string(), "us-east".to_string());
        }

        let mut task = Task::new("test", Bytes::new());
        task.metadata
            .tags
            .insert("zone".to_string(), "us-east".to_string());

        let chosen = scheduler.schedule(&task, &nodes).await.unwrap();
        assert_eq!(chosen, nodes[0].id());
    }

    /// A tag that no node carries makes scheduling fail.
    #[tokio::test]
    async fn test_affinity_scheduler_no_match() {
        let nodes = create_test_nodes(2);
        let scheduler = AffinityScheduler::new();

        let mut task = Task::new("test", Bytes::new());
        task.metadata
            .tags
            .insert("zone".to_string(), "us-west".to_string());

        assert!(scheduler.schedule(&task, &nodes).await.is_err());
    }

    /// Every scheduler reports its expected identifier.
    #[test]
    fn test_scheduler_names() {
        let round_robin = RoundRobinScheduler::new();
        let load_aware = LoadAwareScheduler::new();
        let priority = PriorityScheduler::new();
        let affinity = AffinityScheduler::new();

        assert_eq!(round_robin.name(), "round_robin");
        assert_eq!(load_aware.name(), "load_aware");
        assert_eq!(priority.name(), "priority");
        assert_eq!(affinity.name(), "affinity");
    }
}