// eshanized_polaris_core/scheduler.rs

//! Task scheduling implementations.
//!
//! This module provides pluggable schedulers for distributing tasks across nodes.

use crate::errors::{PolarisError, PolarisResult};
use crate::node::{Node, NodeId};
use crate::task::Task;
use async_trait::async_trait;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Trait for implementing custom schedulers
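///
/// # Examples
///
/// A minimal sketch of a custom implementation. `FirstFitScheduler` is a
/// hypothetical type shown only to illustrate the trait surface; it picks
/// the first node in the slice and reuses this module's error type:
///
/// ```ignore
/// struct FirstFitScheduler;
///
/// #[async_trait]
/// impl Scheduler for FirstFitScheduler {
///     async fn schedule(&self, _task: &Task, nodes: &[Node]) -> PolarisResult<NodeId> {
///         nodes
///             .first()
///             .map(|node| node.id())
///             .ok_or_else(|| PolarisError::scheduling_failed("No available nodes"))
///     }
///
///     fn name(&self) -> &str {
///         "first_fit"
///     }
/// }
/// ```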
#[async_trait]
pub trait Scheduler: Send + Sync {
    /// Schedule a task on one of the available nodes
    ///
    /// # Arguments
    ///
    /// * `task` - The task to schedule
    /// * `nodes` - Available nodes for scheduling
    ///
    /// # Returns
    ///
    /// Returns the selected node ID, or an error if scheduling fails
    async fn schedule(&self, task: &Task, nodes: &[Node]) -> PolarisResult<NodeId>;

    /// Get scheduler name
    fn name(&self) -> &str;
}

/// Round-robin scheduler
///
/// Distributes tasks evenly across nodes in a circular fashion.
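///
/// # Examples
///
/// A usage sketch; the `task` and `nodes` values mirror the test helpers at
/// the bottom of this module, and `schedule` must be awaited:
///
/// ```ignore
/// let scheduler = RoundRobinScheduler::new();
/// let node_id = scheduler.schedule(&task, &nodes).await?;
/// ```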
#[derive(Debug)]
pub struct RoundRobinScheduler {
    counter: Arc<AtomicUsize>,
}

impl RoundRobinScheduler {
    /// Create a new round-robin scheduler
    pub fn new() -> Self {
        Self {
            counter: Arc::new(AtomicUsize::new(0)),
        }
    }
}

impl Default for RoundRobinScheduler {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Scheduler for RoundRobinScheduler {
    async fn schedule(&self, task: &Task, nodes: &[Node]) -> PolarisResult<NodeId> {
        if nodes.is_empty() {
            return Err(PolarisError::scheduling_failed("No available nodes"));
        }

        let index = self.counter.fetch_add(1, Ordering::Relaxed) % nodes.len();
        let node = &nodes[index];

        tracing::debug!(
            task_id = %task.id,
            node_id = %node.id(),
            scheduler = "round_robin",
            "Task scheduled"
        );

        Ok(node.id())
    }

    fn name(&self) -> &str {
        "round_robin"
    }
}

/// Load-aware scheduler
///
/// Schedules tasks on the least loaded node based on current resource usage.
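///
/// The load score for a node is the weighted sum computed by
/// `calculate_load` below, lower being better:
///
/// ```text
/// load = cpu_weight * cpu_load + memory_weight * memory_load + task_weight * task_load
/// ```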
#[derive(Debug)]
pub struct LoadAwareScheduler {
    /// Weight for CPU usage in load calculation
    cpu_weight: f64,
    /// Weight for memory usage in load calculation
    memory_weight: f64,
    /// Weight for task count in load calculation
    task_weight: f64,
}

impl LoadAwareScheduler {
    /// Create a new load-aware scheduler with default weights
    pub fn new() -> Self {
        Self {
            cpu_weight: 0.4,
            memory_weight: 0.3,
            task_weight: 0.3,
        }
    }

    /// Create a load-aware scheduler with custom weights
    ///
    /// Weights are normalized to sum to 1.0, so any positive scale works.
    /// At least one weight must be non-zero; otherwise the normalization
    /// divides by zero and the weights become NaN.
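    ///
    /// # Examples
    ///
    /// ```ignore
    /// // A 2:1:1 split is normalized to 0.5 / 0.25 / 0.25.
    /// let scheduler = LoadAwareScheduler::with_weights(2.0, 1.0, 1.0);
    /// ```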
    pub fn with_weights(cpu_weight: f64, memory_weight: f64, task_weight: f64) -> Self {
        // Normalize so the stored weights always sum to 1.0.
        let total = cpu_weight + memory_weight + task_weight;
        Self {
            cpu_weight: cpu_weight / total,
            memory_weight: memory_weight / total,
            task_weight: task_weight / total,
        }
    }

    /// Calculate load score for a node (lower is better)
    fn calculate_load(&self, node: &Node) -> f64 {
        let info = node.info();
        let usage = &info.resource_usage;

        let cpu_load = usage.cpu_usage_percent / 100.0;
        let memory_load = if info.resource_limits.max_memory_bytes > 0 {
            usage.memory_usage_bytes as f64 / info.resource_limits.max_memory_bytes as f64
        } else {
            0.0
        };
        let task_load = node.load_factor();

        self.cpu_weight * cpu_load
            + self.memory_weight * memory_load
            + self.task_weight * task_load
    }
}

impl Default for LoadAwareScheduler {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Scheduler for LoadAwareScheduler {
    async fn schedule(&self, task: &Task, nodes: &[Node]) -> PolarisResult<NodeId> {
        if nodes.is_empty() {
            return Err(PolarisError::scheduling_failed("No available nodes"));
        }

        // Find the node with the lowest load score.
        let (best_node, load) = nodes
            .iter()
            .map(|node| (node, self.calculate_load(node)))
            .min_by(|(_, load_a), (_, load_b)| load_a.total_cmp(load_b))
            .unwrap(); // Safe: `nodes` is non-empty, so `min_by` yields a value.

        tracing::debug!(
            task_id = %task.id,
            node_id = %best_node.id(),
            load = %load,
            scheduler = "load_aware",
            "Task scheduled"
        );

        Ok(best_node.id())
    }

    fn name(&self) -> &str {
        "load_aware"
    }
}

/// Priority-based scheduler
///
/// Considers task priority when scheduling, favoring higher priority tasks.
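///
/// Note: placement currently delegates to [`LoadAwareScheduler`]; a full
/// implementation would maintain per-node priority queues (see the comment
/// in `schedule`).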
#[derive(Debug)]
pub struct PriorityScheduler {
    base_scheduler: LoadAwareScheduler,
}

impl PriorityScheduler {
    /// Create a new priority scheduler
    pub fn new() -> Self {
        Self {
            base_scheduler: LoadAwareScheduler::new(),
        }
    }
}

impl Default for PriorityScheduler {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Scheduler for PriorityScheduler {
    async fn schedule(&self, task: &Task, nodes: &[Node]) -> PolarisResult<NodeId> {
        // For now, delegate to the load-aware scheduler.
        // A full implementation would maintain priority queues per node.
        tracing::debug!(
            task_id = %task.id,
            priority = ?task.priority,
            scheduler = "priority",
            "Scheduling task by priority"
        );
        self.base_scheduler.schedule(task, nodes).await
    }

    fn name(&self) -> &str {
        "priority"
    }
}

/// Affinity-based scheduler
///
/// Schedules tasks with affinity constraints (e.g., locality, specific node labels).
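///
/// # Examples
///
/// A sketch of label-based affinity, mirroring the tests below; a task is
/// only placed on nodes whose labels contain every one of its tags:
///
/// ```ignore
/// let mut task = Task::new("job", Bytes::new());
/// task.metadata.tags.insert("zone".to_string(), "us-east".to_string());
/// let node_id = AffinityScheduler::new().schedule(&task, &nodes).await?;
/// ```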
#[derive(Debug)]
pub struct AffinityScheduler {
    base_scheduler: LoadAwareScheduler,
}

impl AffinityScheduler {
    /// Create a new affinity scheduler
    pub fn new() -> Self {
        Self {
            base_scheduler: LoadAwareScheduler::new(),
        }
    }

    /// Filter nodes by label selector
    fn filter_by_labels<'a>(
        &self,
        nodes: &'a [Node],
        required_labels: &std::collections::HashMap<String, String>,
    ) -> Vec<&'a Node> {
        if required_labels.is_empty() {
            return nodes.iter().collect();
        }

        nodes
            .iter()
            .filter(|node| {
                let node_labels = &node.info().metadata.labels;
                required_labels
                    .iter()
                    .all(|(k, v)| node_labels.get(k) == Some(v))
            })
            .collect()
    }
}

impl Default for AffinityScheduler {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Scheduler for AffinityScheduler {
    async fn schedule(&self, task: &Task, nodes: &[Node]) -> PolarisResult<NodeId> {
        // Check for affinity constraints in task metadata
        let required_labels = &task.metadata.tags;

        let filtered_nodes: Vec<&Node> = self.filter_by_labels(nodes, required_labels);

        if filtered_nodes.is_empty() {
            return Err(PolarisError::scheduling_failed(
                "No nodes match affinity constraints",
            ));
        }

        // Convert to owned nodes for scheduling
        let owned_nodes: Vec<Node> = filtered_nodes.into_iter().cloned().collect();

        tracing::debug!(
            task_id = %task.id,
            constraints = ?required_labels,
            scheduler = "affinity",
            "Scheduling with affinity constraints"
        );

        self.base_scheduler.schedule(task, &owned_nodes).await
    }

    fn name(&self) -> &str {
        "affinity"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::TaskPriority;
    use bytes::Bytes;

    fn create_test_nodes(count: usize) -> Vec<Node> {
        (0..count)
            .map(|i| {
                // Use arithmetic so counts above 9 still yield valid ports.
                let addr = format!("127.0.0.1:{}", 7000 + i).parse().unwrap();
                let node = Node::new(format!("node{}", i), addr);
                node.set_status(crate::node::NodeStatus::Ready);
                node
            })
            .collect()
    }

    #[tokio::test]
    async fn test_round_robin_scheduler() {
        let scheduler = RoundRobinScheduler::new();
        let nodes = create_test_nodes(3);
        let task = Task::new("test", Bytes::new());

        // Schedule multiple tasks and verify round-robin behavior
        let mut node_counts = std::collections::HashMap::new();
        for _ in 0..9 {
            let node_id = scheduler.schedule(&task, &nodes).await.unwrap();
            *node_counts.entry(node_id).or_insert(0) += 1;
        }

        // Each node should get 3 tasks
        assert_eq!(node_counts.len(), 3);
        for count in node_counts.values() {
            assert_eq!(*count, 3);
        }
    }

    #[tokio::test]
    async fn test_round_robin_no_nodes() {
        let scheduler = RoundRobinScheduler::new();
        let nodes: Vec<Node> = vec![];
        let task = Task::new("test", Bytes::new());

        let result = scheduler.schedule(&task, &nodes).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_load_aware_scheduler() {
        let scheduler = LoadAwareScheduler::new();
        let nodes = create_test_nodes(3);

        // Make one node heavily loaded
        for _ in 0..10 {
            nodes[0].increment_active_tasks();
        }

        let task = Task::new("test", Bytes::new());
        let node_id = scheduler.schedule(&task, &nodes).await.unwrap();

        // Should not select the heavily loaded node
        assert_ne!(node_id, nodes[0].id());
    }

    #[tokio::test]
    async fn test_load_aware_custom_weights() {
        let scheduler = LoadAwareScheduler::with_weights(1.0, 0.0, 0.0);
        let nodes = create_test_nodes(2);
        let task = Task::new("test", Bytes::new());

        let result = scheduler.schedule(&task, &nodes).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_priority_scheduler() {
        let scheduler = PriorityScheduler::new();
        let nodes = create_test_nodes(2);

        let task = Task::new("test", Bytes::new()).with_priority(TaskPriority::High);

        let result = scheduler.schedule(&task, &nodes).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_affinity_scheduler_with_labels() {
        let scheduler = AffinityScheduler::new();
        let nodes = create_test_nodes(2);

        // Add label to first node
        {
            let mut info = nodes[0].info();
            info.metadata.labels.insert("zone".to_string(), "us-east".to_string());
        }

        let mut task = Task::new("test", Bytes::new());
        task.metadata.tags.insert("zone".to_string(), "us-east".to_string());

        let node_id = scheduler.schedule(&task, &nodes).await.unwrap();
        assert_eq!(node_id, nodes[0].id());
    }

    #[tokio::test]
    async fn test_affinity_scheduler_no_match() {
        let scheduler = AffinityScheduler::new();
        let nodes = create_test_nodes(2);

        let mut task = Task::new("test", Bytes::new());
        task.metadata.tags.insert("zone".to_string(), "us-west".to_string());

        let result = scheduler.schedule(&task, &nodes).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_scheduler_names() {
        assert_eq!(RoundRobinScheduler::new().name(), "round_robin");
        assert_eq!(LoadAwareScheduler::new().name(), "load_aware");
        assert_eq!(PriorityScheduler::new().name(), "priority");
        assert_eq!(AffinityScheduler::new().name(), "affinity");
    }
}