1use sklears_core::error::{Result as SklResult, SklearsError};
7use std::cmp::Ordering;
8use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
9use std::sync::{Arc, Condvar, Mutex, RwLock};
10use std::thread::{self, JoinHandle};
11use std::time::{Duration, SystemTime};
12
13use crate::distributed::{NodeId, ResourceRequirements, TaskId, TaskPriority};
14
/// A schedulable unit of work plus everything the scheduler needs to place
/// it: dependency edges, resource demands, priority class, and retry policy.
#[derive(Debug, Clone)]
pub struct ScheduledTask {
    /// Unique identifier of the task.
    pub id: TaskId,
    /// Human-readable task name.
    pub name: String,
    /// Kind of pipeline component this task runs.
    pub component_type: ComponentType,
    /// Ids of tasks that must complete before this one becomes ready.
    pub dependencies: Vec<TaskId>,
    /// Resources the task asks the scheduler to reserve.
    pub resource_requirements: ResourceRequirements,
    /// Scheduling priority class (feeds the priority score).
    pub priority: TaskPriority,
    /// Expected wall-clock run time supplied by the submitter.
    pub estimated_duration: Duration,
    /// When the task was handed to the scheduler (used for aging).
    pub submitted_at: SystemTime,
    /// Optional hard deadline; nearer deadlines raise the priority score.
    pub deadline: Option<SystemTime>,
    /// Free-form key/value metadata attached by the submitter.
    pub metadata: HashMap<String, String>,
    /// How (and how often) to retry the task on failure.
    pub retry_config: RetryConfig,
}
41
42#[derive(Debug, Clone)]
44pub enum ComponentType {
45 Transformer,
47 Predictor,
49 DataProcessor,
51 CustomFunction,
53}
54
/// Policy controlling automatic retries of failed tasks.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts before the task is failed for good.
    pub max_retries: usize,
    /// How the delay between attempts is computed.
    pub delay_strategy: RetryDelayStrategy,
    /// Multiplier applied per attempt (meaningful for back-off strategies).
    pub backoff_multiplier: f64,
    /// Upper bound on any computed retry delay.
    pub max_delay: Duration,
}
67
/// Strategy for computing the delay before the next retry attempt.
#[derive(Debug, Clone)]
pub enum RetryDelayStrategy {
    /// The same delay for every attempt.
    Fixed(Duration),
    /// Delay grows linearly from the given base delay.
    Linear(Duration),
    /// Delay grows exponentially from the given base delay.
    Exponential(Duration),
    /// Caller-provided function mapping attempt number to delay.
    Custom(fn(usize) -> Duration),
}
80
81impl Default for RetryConfig {
82 fn default() -> Self {
83 Self {
84 max_retries: 3,
85 delay_strategy: RetryDelayStrategy::Exponential(Duration::from_millis(100)),
86 backoff_multiplier: 2.0,
87 max_delay: Duration::from_secs(60),
88 }
89 }
90}
91
/// Lifecycle state of a task as tracked by the scheduler.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskState {
    /// Submitted; dependencies not yet satisfied.
    Pending,
    /// Dependencies satisfied; waiting to be started.
    Ready,
    /// Currently executing.
    Running {
        /// When execution began.
        started_at: SystemTime,
        /// Node the task was placed on, if known.
        node_id: Option<NodeId>,
    },
    /// Finished successfully.
    Completed {
        completed_at: SystemTime,
        execution_time: Duration,
    },
    /// Terminated with an error.
    Failed {
        failed_at: SystemTime,
        /// Human-readable failure reason.
        error: String,
        /// Number of retries already attempted.
        retry_count: usize,
    },
    /// Cancelled on request before or during execution.
    Cancelled { cancelled_at: SystemTime },
    /// Waiting until `next_retry_at` for another attempt.
    Retrying {
        next_retry_at: SystemTime,
        retry_count: usize,
    },
}
123
/// Internal wrapper pairing a task with its precomputed priority score so
/// it can be ordered inside the scheduler's `BinaryHeap` (a max-heap).
#[derive(Debug)]
struct PriorityTask {
    task: ScheduledTask,
    /// Higher score means scheduled sooner; see `calculate_priority_score`.
    priority_score: i64,
}
130
131impl PartialEq for PriorityTask {
132 fn eq(&self, other: &Self) -> bool {
133 self.priority_score == other.priority_score
134 }
135}
136
137impl Eq for PriorityTask {}
138
139impl PartialOrd for PriorityTask {
140 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
141 Some(self.cmp(other))
142 }
143}
144
145impl Ord for PriorityTask {
146 fn cmp(&self, other: &Self) -> Ordering {
147 self.priority_score.cmp(&other.priority_score)
149 }
150}
151
/// Built-in scheduling disciplines the `TaskScheduler` can be created with.
#[derive(Debug, Clone)]
pub enum SchedulingStrategy {
    /// First in, first out.
    FIFO,
    /// Highest priority score first.
    Priority,
    /// Shortest estimated duration first.
    ShortestJobFirst,
    /// Task with the nearest deadline first.
    EarliestDeadlineFirst,
    /// Round-robin style fair sharing.
    FairShare {
        /// Time slice granted per turn.
        time_quantum: Duration,
    },
    /// Placement driven by resource availability.
    ResourceAware,
    /// Caller-supplied selection function.
    Custom {
        /// Picks the next task id from the candidates, or `None` to idle.
        schedule_fn: fn(&[ScheduledTask], &ResourcePool) -> Option<TaskId>,
    },
}
175
/// Currently available resources plus a rolling utilization history.
///
/// NOTE(review): units are implied by field names (memory/disk presumably
/// MB) — confirm against `ResourceRequirements` in `crate::distributed`.
#[derive(Debug, Clone)]
pub struct ResourcePool {
    /// CPU cores currently free.
    pub available_cpu: u32,
    /// Memory currently free.
    pub available_memory: u64,
    /// Disk space currently free.
    pub available_disk: u64,
    /// GPUs currently free.
    pub available_gpu: u32,
    /// Recent utilization samples (capped at 100 entries by the scheduler).
    pub utilization_history: Vec<ResourceUtilization>,
}
190
/// A single point-in-time utilization sample (each figure in [0, 1]).
#[derive(Debug, Clone)]
pub struct ResourceUtilization {
    /// When the sample was taken.
    pub timestamp: SystemTime,
    /// Fraction of CPU in use.
    pub cpu_usage: f64,
    /// Fraction of memory in use.
    pub memory_usage: f64,
    /// Fraction of disk in use.
    pub disk_usage: f64,
    /// Fraction of GPU in use.
    pub gpu_usage: f64,
}
205
206impl Default for ResourcePool {
207 fn default() -> Self {
208 Self {
209 available_cpu: 4,
210 available_memory: 8192,
211 available_disk: 100_000,
212 available_gpu: 0,
213 utilization_history: Vec::new(),
214 }
215 }
216}
217
/// Interface for swappable scheduling back-ends that can be registered with
/// the `TaskScheduler`. Implementations must be thread-safe.
pub trait PluggableScheduler: Send + Sync + std::fmt::Debug {
    /// Short identifier of this scheduler.
    fn name(&self) -> &str;

    /// Human-readable description of the scheduling policy.
    fn description(&self) -> &str;

    /// One-time setup with the scheduler configuration.
    fn initialize(&mut self, config: &SchedulerConfig) -> SklResult<()>;

    /// Picks the next task to run from the ready set, or `None` to idle.
    fn select_next_task(
        &self,
        available_tasks: &[ScheduledTask],
        resource_pool: &ResourcePool,
        current_time: SystemTime,
    ) -> Option<TaskId>;

    /// Computes an ordering score for a task given the scheduling context.
    fn calculate_priority(&self, task: &ScheduledTask, context: &SchedulingContext) -> i64;

    /// Whether the pool currently has room for this task.
    fn can_schedule_task(&self, task: &ScheduledTask, resource_pool: &ResourcePool) -> bool;

    /// Snapshot of this scheduler's performance metrics.
    fn get_metrics(&self) -> SchedulerMetrics;

    /// Callback invoked after a task finishes successfully.
    fn on_task_completed(&mut self, task_id: &TaskId, execution_time: Duration) -> SklResult<()>;

    /// Callback invoked after a task fails.
    fn on_task_failed(&mut self, task_id: &TaskId, error: &str) -> SklResult<()>;

    /// Releases any resources held by the scheduler.
    fn cleanup(&mut self) -> SklResult<()>;
}
255
/// Ambient information a pluggable scheduler may consult when scoring tasks.
#[derive(Debug, Clone, Default)]
pub struct SchedulingContext {
    /// Current system load figures.
    pub system_load: SystemLoad,
    /// Past task executions, for learning/estimation.
    pub execution_history: Vec<TaskExecutionHistory>,
    /// Hard per-task resource limits.
    pub resource_constraints: ResourceConstraints,
    /// Time-of-day context (business hours, maintenance, peaks).
    pub temporal_context: TemporalContext,
    /// Free-form extension data for custom schedulers.
    pub custom_data: HashMap<String, String>,
}
270
/// Host-level load figures (fractions in [0, 1] plus Unix-style load averages).
#[derive(Debug, Clone)]
pub struct SystemLoad {
    /// Fraction of CPU in use.
    pub cpu_utilization: f64,
    /// Fraction of memory in use.
    pub memory_utilization: f64,
    /// Fraction of time spent waiting on I/O.
    pub io_wait: f64,
    /// Fraction of network capacity in use.
    pub network_utilization: f64,
    /// 1/5/15-minute load averages.
    pub load_average: (f64, f64, f64),
}
285
286impl Default for SystemLoad {
287 fn default() -> Self {
288 Self {
289 cpu_utilization: 0.0,
290 memory_utilization: 0.0,
291 io_wait: 0.0,
292 network_utilization: 0.0,
293 load_average: (0.0, 0.0, 0.0),
294 }
295 }
296}
297
/// Record of past executions of a component type, for informed scheduling.
#[derive(Debug, Clone)]
pub struct TaskExecutionHistory {
    /// Component type the record describes.
    pub task_type: ComponentType,
    /// Observed wall-clock execution time.
    pub execution_time: Duration,
    /// Observed resource consumption.
    pub resource_usage: ResourceUsage,
    /// Fraction of runs that succeeded.
    pub success_rate: f64,
    /// When the record was taken.
    pub timestamp: SystemTime,
}
312
/// Observed resource consumption of a task run.
#[derive(Debug, Clone)]
pub struct ResourceUsage {
    /// Peak CPU utilization reached.
    pub peak_cpu: f64,
    /// Peak memory used.
    pub peak_memory: u64,
    /// Total I/O operations issued.
    pub io_operations: u64,
    /// Total bytes sent/received over the network.
    pub network_bytes: u64,
}
325
/// Hard per-task resource limits a scheduler should respect.
#[derive(Debug, Clone)]
pub struct ResourceConstraints {
    /// Maximum CPU share a single task may consume.
    pub max_cpu_per_task: f64,
    /// Maximum memory a single task may consume.
    pub max_memory_per_task: u64,
    /// Maximum concurrent I/O operations.
    pub max_concurrent_io: u32,
    /// Network bandwidth cap.
    pub network_bandwidth_limit: u64,
}
338
339impl Default for ResourceConstraints {
340 fn default() -> Self {
341 Self {
342 max_cpu_per_task: 1.0, max_memory_per_task: 1024, max_concurrent_io: 10, network_bandwidth_limit: 100_000_000, }
347 }
348}
349
/// Time-of-day context for schedulers that care about calendars.
#[derive(Debug, Clone)]
pub struct TemporalContext {
    /// The "now" this context was built for.
    pub current_time: SystemTime,
    /// Optional business-hours definition.
    pub business_hours: Option<BusinessHours>,
    /// Planned maintenance windows to avoid.
    pub maintenance_windows: Vec<MaintenanceWindow>,
    /// Known high-load periods.
    pub peak_periods: Vec<PeakPeriod>,
}
362
363impl Default for TemporalContext {
364 fn default() -> Self {
365 Self {
366 current_time: SystemTime::now(),
367 business_hours: None,
368 maintenance_windows: Vec::new(),
369 peak_periods: Vec::new(),
370 }
371 }
372}
373
/// A weekly business-hours window.
#[derive(Debug, Clone)]
pub struct BusinessHours {
    /// Opening time as (hour, minute).
    pub start: (u8, u8),
    /// Closing time as (hour, minute).
    pub end: (u8, u8),
    /// Days considered business days (presumably 0–6; confirm convention).
    pub business_days: Vec<u8>,
    /// Offset from UTC in hours.
    pub timezone_offset: i8,
}
386
/// A scheduled maintenance interval during which scheduling may be restricted.
#[derive(Debug, Clone)]
pub struct MaintenanceWindow {
    /// Label for the window.
    pub name: String,
    /// Window start.
    pub start: SystemTime,
    /// Window end.
    pub end: SystemTime,
    /// How disruptive the maintenance is.
    pub severity: MaintenanceSeverity,
}
399
/// Severity of a maintenance window.
#[derive(Debug, Clone)]
pub enum MaintenanceSeverity {
    /// Routine maintenance.
    Normal,
    /// High-impact maintenance.
    Critical,
    /// Unplanned emergency maintenance.
    Emergency,
}
410
/// A recurring daily high-load period.
#[derive(Debug, Clone)]
pub struct PeakPeriod {
    /// Label for the period.
    pub name: String,
    /// Start time as (hour, minute).
    pub start: (u8, u8),
    /// End time as (hour, minute).
    pub end: (u8, u8),
    /// Load multiplier expected during the period.
    pub peak_factor: f64,
}
423
/// Performance metrics reported by a `PluggableScheduler`.
#[derive(Debug, Clone)]
pub struct SchedulerMetrics {
    /// Total tasks scheduled so far.
    pub tasks_scheduled: u64,
    /// Mean latency between readiness and dispatch.
    pub avg_scheduling_latency: Duration,
    /// Overall resource-efficiency figure.
    pub resource_efficiency: f64,
    /// Fraction of tasks that missed their deadline.
    pub deadline_miss_rate: f64,
    /// Fairness indicator across users/tasks.
    pub fairness_index: f64,
    /// Scheduler-specific named metrics.
    pub custom_metrics: HashMap<String, f64>,
}
440
/// Experimental scheduling strategies.
///
/// NOTE(review): no implementations of these variants are visible in this
/// file; they appear to be declared for future back-ends.
#[derive(Debug, Clone)]
pub enum AdvancedSchedulingStrategy {
    /// Scheduling driven by a trained ML model.
    MLAdaptive {
        /// Path to the serialized model.
        model_path: String,
        /// Names of the feature extractors to use.
        feature_extractors: Vec<String>,
    },
    /// Schedule search via a genetic algorithm.
    GeneticOptimization {
        population_size: usize,
        generations: usize,
        mutation_rate: f64,
    },
    /// Weighted combination of several objectives.
    MultiObjective {
        objectives: Vec<SchedulingObjective>,
        /// One weight per objective.
        weights: Vec<f64>,
    },
    /// RL agent that learns a scheduling policy.
    ReinforcementLearning {
        agent_type: String,
        learning_rate: f64,
        exploration_rate: f64,
    },
    /// Game-theoretic task/resource negotiation.
    GameTheory {
        strategy_type: GameTheoryStrategy,
        coalition_formation: bool,
    },
    /// Quantum-inspired heuristic search.
    QuantumInspired {
        quantum_operators: Vec<String>,
        entanglement_depth: usize,
    },
}
477
/// Objectives a multi-objective scheduler can optimize for.
#[derive(Debug, Clone)]
pub enum SchedulingObjective {
    /// Minimize total completion time of the task set.
    MinimizeMakespan,
    /// Minimize aggregate resource consumption.
    MinimizeResourceUsage,
    /// Maximize tasks completed per unit time.
    MaximizeThroughput,
    /// Minimize energy consumption.
    MinimizeEnergy,
    /// Maximize fairness across users/tasks.
    MaximizeFairness,
    /// Minimize the number of missed deadlines.
    MinimizeDeadlineViolations,
    /// Caller-defined objective function (higher/lower semantics are the
    /// caller's convention).
    Custom {
        name: String,
        objective_fn: fn(&[ScheduledTask], &ResourcePool) -> f64,
    },
}
499
/// Game-theoretic negotiation styles for the `GameTheory` strategy.
#[derive(Debug, Clone)]
pub enum GameTheoryStrategy {
    /// Converge to a Nash equilibrium.
    NashEquilibrium,
    /// Leader/follower (Stackelberg) game.
    Stackelberg,
    /// Cooperative game with shared payoff.
    Cooperative,
    /// Auction-based resource allocation.
    Auction,
}
512
/// Multi-level feedback-queue scheduler state.
///
/// NOTE(review): no constructor or `PluggableScheduler` impl for this type
/// is visible in this chunk.
pub struct MultiLevelFeedbackScheduler {
    name: String,
    /// One queue per priority level.
    queues: Vec<PriorityQueue>,
    /// Time quantum per level.
    time_quantum: Vec<Duration>,
    /// Wait counts that trigger promotion per level.
    promotion_threshold: Vec<u32>,
    /// Overrun counts that trigger demotion per level.
    demotion_threshold: Vec<u32>,
    /// Rate at which waiting tasks gain priority.
    aging_factor: f64,
    metrics: SchedulerMetrics,
}
523
/// A single level of the multi-level feedback queue.
#[derive(Debug)]
struct PriorityQueue {
    /// Tasks waiting at this level, in FIFO order.
    tasks: VecDeque<ScheduledTask>,
    /// Priority level this queue represents.
    priority_level: u8,
    /// Time slice granted to tasks at this level.
    time_slice: Duration,
}
531
/// Fair-share scheduler state (per-user/group allocations).
///
/// NOTE(review): no constructor or `PluggableScheduler` impl is visible in
/// this chunk.
pub struct FairShareScheduler {
    name: String,
    /// Target share per user.
    user_shares: HashMap<String, f64>,
    /// Target share per group.
    group_shares: HashMap<String, f64>,
    /// Historical usage per principal.
    usage_history: HashMap<String, Vec<ResourceUsage>>,
    /// How quickly old usage stops counting against a principal.
    decay_factor: f64,
    metrics: SchedulerMetrics,
}
541
/// Deadline-aware scheduler state.
///
/// NOTE(review): no constructor or `PluggableScheduler` impl is visible in
/// this chunk.
pub struct DeadlineAwareScheduler {
    name: String,
    /// Weight of deadline proximity in the priority score.
    deadline_weight: f64,
    /// Extra boost for nearly-due tasks.
    urgency_factor: f64,
    /// Whether running tasks may be preempted for urgent ones.
    preemption_enabled: bool,
    /// Slack allowed past a deadline before it counts as missed.
    grace_period: Duration,
    metrics: SchedulerMetrics,
}
551
/// Resource-aware scheduler state.
///
/// NOTE(review): no constructor or `PluggableScheduler` impl is visible in
/// this chunk.
pub struct ResourceAwareScheduler {
    name: String,
    /// Relative importance of each resource dimension, by name.
    resource_weights: HashMap<String, f64>,
    /// How work is spread across nodes.
    load_balancing_strategy: LoadBalancingStrategy,
    /// Horizon over which resource demand is predicted.
    prediction_window: Duration,
    /// Minimum acceptable efficiency before placement is reconsidered.
    efficiency_threshold: f64,
    metrics: SchedulerMetrics,
}
561
/// Node-selection policies for spreading tasks across workers.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Cycle through nodes in order.
    RoundRobin,
    /// Pick the node with the least current load.
    LeastLoaded,
    /// Round-robin weighted per node name.
    WeightedRoundRobin { weights: HashMap<String, f64> },
    /// Uniformly random node choice.
    Random,
    /// Consistent hashing with the given number of virtual nodes.
    ConsistentHashing { virtual_nodes: usize },
}
576
/// ML-driven scheduler state.
///
/// NOTE(review): no constructor or `PluggableScheduler` impl is visible in
/// this chunk.
pub struct MLAdaptiveScheduler {
    name: String,
    /// Kind of model backing the scheduling policy.
    model_type: MLModelType,
    /// Feature pipelines feeding the model.
    feature_extractors: Vec<Box<dyn FeatureExtractor>>,
    /// Accumulated (features, decision, outcome) samples.
    training_data: Vec<SchedulingDecision>,
    /// Last measured prediction accuracy.
    prediction_accuracy: f64,
    /// Number of new samples that triggers retraining.
    retraining_threshold: usize,
    metrics: SchedulerMetrics,
}
587
/// Model families usable by `MLAdaptiveScheduler`.
#[derive(Debug, Clone)]
pub enum MLModelType {
    /// Single decision tree.
    DecisionTree,
    /// Ensemble of `n_trees` decision trees.
    RandomForest { n_trees: usize },
    /// Feed-forward network with the given layer sizes.
    NeuralNetwork { layers: Vec<usize> },
    /// Support-vector machine with the named kernel.
    SVM { kernel: String },
    /// RL algorithm identified by name.
    ReinforcementLearning { algorithm: String },
}
602
/// Turns a `SchedulingContext` into a numeric feature vector for ML models.
pub trait FeatureExtractor: Send + Sync {
    /// Extracts features; the vector must align with `feature_names()`.
    fn extract_features(&self, context: &SchedulingContext) -> Vec<f64>;

    /// Names of the features, in the same order as `extract_features`.
    fn feature_names(&self) -> Vec<String>;
}
611
/// One recorded scheduling decision, used as an ML training sample.
#[derive(Debug, Clone)]
pub struct SchedulingDecision {
    /// Feature vector at decision time.
    pub features: Vec<f64>,
    /// Task that was selected.
    pub chosen_task: TaskId,
    /// What happened as a result.
    pub outcome: DecisionOutcome,
    /// When the decision was made.
    pub timestamp: SystemTime,
}
624
/// Observed outcome of a scheduling decision.
#[derive(Debug, Clone)]
pub struct DecisionOutcome {
    /// How long the chosen task actually took.
    pub completion_time: Duration,
    /// Resource utilization achieved.
    pub resource_utilization: f64,
    /// Whether the task met its deadline.
    pub deadline_met: bool,
    /// Aggregate quality score for the decision.
    pub satisfaction_score: f64,
}
637
/// Thread-safe task scheduler: tasks are submitted from any thread and a
/// background thread (started via `start`) dispatches them as dependencies
/// complete and resources allow.
#[derive(Debug)]
pub struct TaskScheduler {
    /// Built-in strategy selected at construction.
    /// NOTE(review): the visible scheduler loop does not consult this yet.
    strategy: SchedulingStrategy,
    /// Registered pluggable back-ends, by name.
    pluggable_schedulers: HashMap<String, Box<dyn PluggableScheduler>>,
    /// Name of the back-end currently in use, if any.
    active_scheduler: Option<String>,
    /// Max-heap of submitted tasks ordered by priority score; its mutex also
    /// serves as the condvar guard for the scheduler loop.
    task_queue: Arc<Mutex<BinaryHeap<PriorityTask>>>,
    /// Lifecycle state of every known task.
    task_states: Arc<RwLock<HashMap<TaskId, TaskState>>>,
    /// Currently available resources plus utilization history.
    resource_pool: Arc<RwLock<ResourcePool>>,
    /// task id -> set of ids it depends on.
    dependency_graph: Arc<RwLock<HashMap<TaskId, HashSet<TaskId>>>>,
    /// Tunable limits and intervals.
    config: SchedulerConfig,
    /// Shared context for pluggable schedulers.
    context: Arc<RwLock<SchedulingContext>>,
    /// Wakes the scheduler loop when new work arrives or on shutdown.
    task_notification: Arc<Condvar>,
    /// Handle of the background thread, when running.
    scheduler_thread: Option<JoinHandle<()>>,
    /// Shutdown flag checked by the scheduler loop.
    is_running: Arc<Mutex<bool>>,
}
666
/// Tunable limits and intervals for the `TaskScheduler`.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Maximum tasks allowed in the `Running` state at once.
    pub max_concurrent_tasks: usize,
    /// How long the scheduler loop sleeps between passes when idle.
    pub scheduling_interval: Duration,
    /// Interval for resource monitoring (not read in the visible code).
    pub monitoring_interval: Duration,
    /// Default timeout for tasks (not read in the visible code).
    pub default_task_timeout: Duration,
    /// Age after which terminal task records are purged.
    pub cleanup_interval: Duration,
    /// Soft cap on retained task state records.
    pub max_task_history: usize,
}
683
684impl Default for SchedulerConfig {
685 fn default() -> Self {
686 Self {
687 max_concurrent_tasks: 10,
688 scheduling_interval: Duration::from_millis(100),
689 monitoring_interval: Duration::from_secs(1),
690 default_task_timeout: Duration::from_secs(3600),
691 cleanup_interval: Duration::from_secs(300),
692 max_task_history: 10000,
693 }
694 }
695}
696
impl TaskScheduler {
    /// Creates an idle scheduler with the given strategy and configuration;
    /// nothing runs until [`TaskScheduler::start`] is called.
    #[must_use]
    pub fn new(strategy: SchedulingStrategy, config: SchedulerConfig) -> Self {
        Self {
            strategy,
            pluggable_schedulers: HashMap::new(),
            active_scheduler: None,
            task_queue: Arc::new(Mutex::new(BinaryHeap::new())),
            task_states: Arc::new(RwLock::new(HashMap::new())),
            resource_pool: Arc::new(RwLock::new(ResourcePool::default())),
            dependency_graph: Arc::new(RwLock::new(HashMap::new())),
            config,
            context: Arc::new(RwLock::new(SchedulingContext::default())),
            task_notification: Arc::new(Condvar::new()),
            scheduler_thread: None,
            is_running: Arc::new(Mutex::new(false)),
        }
    }

    /// Submits a task: records its dependency set and `Pending` state,
    /// pushes it onto the priority queue, and wakes the scheduler thread.
    ///
    /// Throughout this impl, `unwrap_or_else(|e| e.into_inner())` recovers
    /// the guard from a poisoned lock instead of panicking.
    pub fn submit_task(&self, task: ScheduledTask) -> SklResult<()> {
        let task_id = task.id.clone();

        // Register dependencies (an empty set means ready immediately).
        {
            let mut graph = self
                .dependency_graph
                .write()
                .unwrap_or_else(|e| e.into_inner());
            graph.insert(task_id.clone(), task.dependencies.iter().cloned().collect());
        }

        // Every new task starts out Pending.
        {
            let mut states = self.task_states.write().unwrap_or_else(|e| e.into_inner());
            states.insert(task_id, TaskState::Pending);
        }

        let priority_score = self.calculate_priority_score(&task);

        {
            let mut queue = self.task_queue.lock().unwrap_or_else(|e| e.into_inner());
            queue.push(PriorityTask {
                task,
                priority_score,
            });
        }

        // Wake the scheduler loop so the new task is considered promptly.
        self.task_notification.notify_one();

        Ok(())
    }

    /// Folds priority class, deadline proximity, and queue age into a single
    /// ordering score (higher = scheduled sooner). Aging prevents starvation.
    fn calculate_priority_score(&self, task: &ScheduledTask) -> i64 {
        // Base score from the priority class.
        let mut score = match task.priority {
            TaskPriority::Low => 1,
            TaskPriority::Normal => 10,
            TaskPriority::High => 100,
            TaskPriority::Critical => 1000,
        };

        // Deadline urgency: the closer the deadline, the larger the bonus.
        if let Some(deadline) = task.deadline {
            let time_to_deadline = deadline
                .duration_since(SystemTime::now())
                .unwrap_or(Duration::ZERO)
                .as_secs() as i64;
            // `+ 1` guards against division by zero for past-due deadlines.
            score += 1_000_000 / (time_to_deadline + 1);
        }

        // Aging: one extra point per minute spent waiting.
        let age = SystemTime::now()
            .duration_since(task.submitted_at)
            .unwrap_or(Duration::ZERO)
            .as_secs() as i64;
        score += age / 60;

        score
    }

    /// Sets the running flag and spawns the background scheduling thread.
    pub fn start(&mut self) -> SklResult<()> {
        {
            let mut running = self.is_running.lock().unwrap_or_else(|e| e.into_inner());
            *running = true;
        }

        // Clone the shared handles the background thread needs.
        let task_queue = Arc::clone(&self.task_queue);
        let task_states = Arc::clone(&self.task_states);
        let resource_pool = Arc::clone(&self.resource_pool);
        let dependency_graph = Arc::clone(&self.dependency_graph);
        let task_notification = Arc::clone(&self.task_notification);
        let is_running = Arc::clone(&self.is_running);
        let config = self.config.clone();
        let strategy = self.strategy.clone();

        let handle = thread::spawn(move || {
            Self::scheduler_loop(
                task_queue,
                task_states,
                resource_pool,
                dependency_graph,
                task_notification,
                is_running,
                config,
                strategy,
            );
        });

        self.scheduler_thread = Some(handle);
        Ok(())
    }

    /// Clears the running flag, wakes the scheduler thread, and joins it.
    ///
    /// # Errors
    /// Returns an error if the scheduler thread panicked and cannot be joined.
    pub fn stop(&mut self) -> SklResult<()> {
        {
            let mut running = self.is_running.lock().unwrap_or_else(|e| e.into_inner());
            *running = false;
        }

        // Wake the loop so it observes the cleared flag without waiting out
        // the scheduling interval.
        self.task_notification.notify_all();

        if let Some(handle) = self.scheduler_thread.take() {
            handle.join().map_err(|_| SklearsError::InvalidData {
                reason: "Failed to join scheduler thread".to_string(),
            })?;
        }

        Ok(())
    }

    /// Body of the background scheduling thread: each pass starts as many
    /// dependency-free `Pending` tasks as limits allow, does housekeeping,
    /// then blocks on the condvar until notified or the interval elapses.
    ///
    /// NOTE(review): `strategy` is accepted but not consulted here, and the
    /// priority queue is never popped — dispatch is driven purely by task
    /// states. Presumably the queue/strategy integration is planned; verify.
    fn scheduler_loop(
        task_queue: Arc<Mutex<BinaryHeap<PriorityTask>>>,
        task_states: Arc<RwLock<HashMap<TaskId, TaskState>>>,
        resource_pool: Arc<RwLock<ResourcePool>>,
        dependency_graph: Arc<RwLock<HashMap<TaskId, HashSet<TaskId>>>>,
        task_notification: Arc<Condvar>,
        is_running: Arc<Mutex<bool>>,
        config: SchedulerConfig,
        strategy: SchedulingStrategy,
    ) {
        // The queue mutex doubles as the condvar guard; it is held for the
        // whole pass and released only while blocked in `wait_timeout`.
        let mut lock = task_queue.lock().unwrap_or_else(|e| e.into_inner());

        while *is_running.lock().unwrap_or_else(|e| e.into_inner()) {
            let ready_tasks = Self::find_ready_tasks(&task_queue, &task_states, &dependency_graph);

            for task_id in ready_tasks {
                // Respect the global concurrency cap.
                if Self::count_running_tasks(&task_states) >= config.max_concurrent_tasks {
                    break;
                }

                if Self::can_allocate_resources(&task_id, &task_states, &resource_pool) {
                    Self::start_task_execution(&task_id, &task_states, &resource_pool);
                }
            }

            Self::cleanup_tasks(&task_states, &config);

            Self::update_resource_monitoring(&resource_pool);

            // Sleep until new work arrives or the scheduling interval passes;
            // the guard is handed back so the next pass reuses the lock.
            let _guard = task_notification
                .wait_timeout(lock, config.scheduling_interval)
                .unwrap_or_else(|e| e.into_inner());
            lock = _guard.0;
        }
    }

    /// Returns ids of `Pending` tasks whose dependencies are all `Completed`.
    ///
    /// NOTE(review): `task_queue` is unused here (selection works off states
    /// only), and a pending task with no entry in the dependency graph can
    /// never become ready — `submit_task` always inserts one, so this only
    /// matters for externally mutated state.
    fn find_ready_tasks(
        task_queue: &Arc<Mutex<BinaryHeap<PriorityTask>>>,
        task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>,
        dependency_graph: &Arc<RwLock<HashMap<TaskId, HashSet<TaskId>>>>,
    ) -> Vec<TaskId> {
        let mut ready_tasks = Vec::new();
        let states = task_states.read().unwrap_or_else(|e| e.into_inner());
        let graph = dependency_graph.read().unwrap_or_else(|e| e.into_inner());

        for (task_id, state) in states.iter() {
            if *state == TaskState::Pending {
                if let Some(dependencies) = graph.get(task_id) {
                    // A dependency in any non-Completed state (or unknown)
                    // keeps this task Pending.
                    let all_deps_completed = dependencies.iter().all(|dep_id| {
                        if let Some(dep_state) = states.get(dep_id) {
                            matches!(dep_state, TaskState::Completed { .. })
                        } else {
                            false
                        }
                    });

                    if all_deps_completed {
                        ready_tasks.push(task_id.clone());
                    }
                }
            }
        }

        ready_tasks
    }

    /// Counts tasks currently in the `Running` state.
    fn count_running_tasks(task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>) -> usize {
        let states = task_states.read().unwrap_or_else(|e| e.into_inner());
        states
            .values()
            .filter(|state| matches!(state, TaskState::Running { .. }))
            .count()
    }

    /// Whether the pool can accommodate another task.
    ///
    /// NOTE(review): placeholder check — the task's actual
    /// `resource_requirements` are ignored (both parameters besides the pool
    /// are unused); a fixed 1 CPU / >100 memory threshold is applied instead.
    fn can_allocate_resources(
        task_id: &TaskId,
        task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>,
        resource_pool: &Arc<RwLock<ResourcePool>>,
    ) -> bool {
        let pool = resource_pool.read().unwrap_or_else(|e| e.into_inner());
        pool.available_cpu > 0 && pool.available_memory > 100
    }

    /// Marks the task `Running` and reserves a fixed 1 CPU / 100 memory
    /// units from the pool (matching `can_allocate_resources`).
    ///
    /// NOTE(review): the node id is hardcoded to "local"; nothing in this
    /// file later transitions the task to Completed/Failed — presumably an
    /// executor elsewhere does. Verify.
    fn start_task_execution(
        task_id: &TaskId,
        task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>,
        resource_pool: &Arc<RwLock<ResourcePool>>,
    ) {
        let mut states = task_states.write().unwrap_or_else(|e| e.into_inner());
        states.insert(
            task_id.clone(),
            TaskState::Running {
                started_at: SystemTime::now(),
                node_id: Some("local".to_string()),
            },
        );

        // saturating_sub avoids underflow if the pool was already drained.
        let mut pool = resource_pool.write().unwrap_or_else(|e| e.into_inner());
        pool.available_cpu = pool.available_cpu.saturating_sub(1);
        pool.available_memory = pool.available_memory.saturating_sub(100);
    }

    /// Purges terminal task records (Completed/Failed/Cancelled) older than
    /// `cleanup_interval`.
    ///
    /// NOTE(review): the `max_task_history` trimming below only drains ids
    /// that are already in `to_remove`, so when no record has expired it is
    /// a no-op and the history cap is not actually enforced. The later
    /// second removal of the same ids is harmless.
    fn cleanup_tasks(
        task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>,
        config: &SchedulerConfig,
    ) {
        let mut states = task_states.write().unwrap_or_else(|e| e.into_inner());

        let cutoff_time = SystemTime::now() - config.cleanup_interval;
        let mut to_remove = Vec::new();

        // Only terminal states are eligible for cleanup.
        for (task_id, state) in states.iter() {
            let should_remove = match state {
                TaskState::Completed { completed_at, .. } => *completed_at < cutoff_time,
                TaskState::Failed { failed_at, .. } => *failed_at < cutoff_time,
                TaskState::Cancelled { cancelled_at } => *cancelled_at < cutoff_time,
                _ => false,
            };

            if should_remove {
                to_remove.push(task_id.clone());
            }
        }

        if states.len() > config.max_task_history {
            let excess = states.len() - config.max_task_history;
            for _ in 0..excess {
                if let Some(oldest_id) = to_remove.first().cloned() {
                    to_remove.remove(0);
                    states.remove(&oldest_id);
                }
            }
        }

        for task_id in to_remove {
            states.remove(&task_id);
        }
    }

    /// Appends a utilization sample derived from current availability.
    ///
    /// NOTE(review): capacities are hardcoded to the defaults (4 CPUs, 8192
    /// memory), so the figures are wrong if the pool was resized; the
    /// history is capped at 100 samples via O(n) `remove(0)`.
    fn update_resource_monitoring(resource_pool: &Arc<RwLock<ResourcePool>>) {
        let mut pool = resource_pool.write().unwrap_or_else(|e| e.into_inner());

        let utilization = ResourceUtilization {
            timestamp: SystemTime::now(),
            cpu_usage: 1.0 - (f64::from(pool.available_cpu) / 4.0),
            memory_usage: 1.0 - (pool.available_memory as f64 / 8192.0),
            disk_usage: 0.5, // placeholder: disk is not tracked yet
            gpu_usage: 0.0,  // placeholder: no GPUs in the default pool
        };

        pool.utilization_history.push(utilization);

        if pool.utilization_history.len() > 100 {
            pool.utilization_history.remove(0);
        }
    }

    /// Returns the current state of a task, if known.
    #[must_use]
    pub fn get_task_state(&self, task_id: &TaskId) -> Option<TaskState> {
        let states = self.task_states.read().unwrap_or_else(|e| e.into_inner());
        states.get(task_id).cloned()
    }

    /// Snapshots aggregate counts across all known tasks plus the latest
    /// utilization sample.
    #[must_use]
    pub fn get_statistics(&self) -> SchedulerStatistics {
        let states = self.task_states.read().unwrap_or_else(|e| e.into_inner());
        let queue = self.task_queue.lock().unwrap_or_else(|e| e.into_inner());
        let pool = self.resource_pool.read().unwrap_or_else(|e| e.into_inner());

        let pending_count = states
            .values()
            .filter(|s| matches!(s, TaskState::Pending))
            .count();
        let running_count = states
            .values()
            .filter(|s| matches!(s, TaskState::Running { .. }))
            .count();
        let completed_count = states
            .values()
            .filter(|s| matches!(s, TaskState::Completed { .. }))
            .count();
        let failed_count = states
            .values()
            .filter(|s| matches!(s, TaskState::Failed { .. }))
            .count();

        SchedulerStatistics {
            total_tasks: states.len(),
            pending_tasks: pending_count,
            running_tasks: running_count,
            completed_tasks: completed_count,
            failed_tasks: failed_count,
            queued_tasks: queue.len(),
            resource_utilization: pool.utilization_history.last().cloned(),
        }
    }

    /// Marks a Pending/Ready/Running task `Cancelled`.
    ///
    /// NOTE(review): cancelling a Running task only flips its state — no
    /// preemption signal is sent to whatever is executing it.
    ///
    /// # Errors
    /// Returns an error if the task is unknown or already terminal.
    pub fn cancel_task(&self, task_id: &TaskId) -> SklResult<()> {
        let mut states = self.task_states.write().unwrap_or_else(|e| e.into_inner());

        if let Some(current_state) = states.get(task_id) {
            match current_state {
                TaskState::Pending | TaskState::Ready => {
                    states.insert(
                        task_id.clone(),
                        TaskState::Cancelled {
                            cancelled_at: SystemTime::now(),
                        },
                    );
                    Ok(())
                }
                TaskState::Running { .. } => {
                    states.insert(
                        task_id.clone(),
                        TaskState::Cancelled {
                            cancelled_at: SystemTime::now(),
                        },
                    );
                    Ok(())
                }
                _ => Err(SklearsError::InvalidInput(format!(
                    "Cannot cancel task {task_id} in state {current_state:?}"
                ))),
            }
        } else {
            Err(SklearsError::InvalidInput(format!(
                "Task {task_id} not found"
            )))
        }
    }

    /// Returns a snapshot of every known task's state.
    #[must_use]
    pub fn list_tasks(&self) -> HashMap<TaskId, TaskState> {
        let states = self.task_states.read().unwrap_or_else(|e| e.into_inner());
        states.clone()
    }

    /// Returns the most recent utilization sample, or an all-zero sample if
    /// monitoring has not run yet.
    #[must_use]
    pub fn get_resource_utilization(&self) -> ResourceUtilization {
        let pool = self.resource_pool.read().unwrap_or_else(|e| e.into_inner());
        pool.utilization_history
            .last()
            .cloned()
            .unwrap_or_else(|| ResourceUtilization {
                timestamp: SystemTime::now(),
                cpu_usage: 0.0,
                memory_usage: 0.0,
                disk_usage: 0.0,
                gpu_usage: 0.0,
            })
    }
}
1108
/// Aggregate counters returned by `TaskScheduler::get_statistics`.
#[derive(Debug, Clone)]
pub struct SchedulerStatistics {
    /// Total tasks with a known state.
    pub total_tasks: usize,
    /// Tasks waiting on dependencies.
    pub pending_tasks: usize,
    /// Tasks currently executing.
    pub running_tasks: usize,
    /// Tasks finished successfully.
    pub completed_tasks: usize,
    /// Tasks that failed.
    pub failed_tasks: usize,
    /// Entries sitting in the priority queue.
    pub queued_tasks: usize,
    /// Latest utilization sample, if monitoring has run.
    pub resource_utilization: Option<ResourceUtilization>,
}
1127
/// Orchestrates multi-task workflows on top of a `TaskScheduler`:
/// registered workflow definitions are instantiated and their tasks
/// submitted for execution.
#[derive(Debug)]
pub struct WorkflowManager {
    /// Underlying scheduler tasks are submitted to.
    scheduler: TaskScheduler,
    /// Registered workflow definitions, by workflow id.
    workflows: Arc<RwLock<HashMap<String, Workflow>>>,
    /// Live workflow instances, by instance id.
    workflow_instances: Arc<RwLock<HashMap<String, WorkflowInstance>>>,
}
1138
/// A reusable workflow definition: a named DAG of task templates.
#[derive(Debug, Clone)]
pub struct Workflow {
    /// Unique workflow identifier.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Tasks making up the workflow.
    pub tasks: Vec<WorkflowTask>,
    /// Execution policy for instances of this workflow.
    pub config: WorkflowConfig,
}
1151
/// One node of a workflow DAG.
#[derive(Debug, Clone)]
pub struct WorkflowTask {
    /// Id unique within the workflow.
    pub id: String,
    /// Template the concrete task is materialized from.
    pub template: TaskTemplate,
    /// Ids of workflow tasks that must finish first.
    pub depends_on: Vec<String>,
    /// Per-task overrides of the template's default configuration.
    pub config_overrides: HashMap<String, String>,
}
1164
/// Blueprint from which concrete `ScheduledTask`s are created.
#[derive(Debug, Clone)]
pub struct TaskTemplate {
    /// Task name carried into the scheduled task.
    pub name: String,
    /// Component kind the task runs.
    pub component_type: ComponentType,
    /// Default resource reservation.
    pub default_resources: ResourceRequirements,
    /// Default key/value configuration.
    pub default_config: HashMap<String, String>,
}
1177
/// Execution policy for a workflow.
#[derive(Debug, Clone)]
pub struct WorkflowConfig {
    /// Maximum tasks of the workflow that may run concurrently.
    pub max_parallelism: usize,
    /// Overall workflow timeout.
    pub timeout: Duration,
    /// What to do when a task fails.
    pub failure_strategy: WorkflowFailureStrategy,
    /// Retry policy applied to workflow tasks.
    pub retry_config: RetryConfig,
}
1190
/// How a workflow reacts to a failing task.
#[derive(Debug, Clone)]
pub enum WorkflowFailureStrategy {
    /// Abort the whole workflow on the first failure.
    StopOnFailure,
    /// Keep running unaffected branches.
    ContinueOnFailure,
    /// Retry failed tasks per the retry policy.
    RetryFailedTasks,
    /// Substitute configured fallback tasks.
    UseFallbackTasks,
}
1203
/// A running (or finished) instantiation of a workflow definition.
#[derive(Debug, Clone)]
pub struct WorkflowInstance {
    /// Unique instance id (workflow id + start timestamp).
    pub id: String,
    /// Id of the workflow definition this instantiates.
    pub workflow_id: String,
    /// Current lifecycle state.
    pub state: WorkflowState,
    /// workflow task id -> scheduler task id, for tasks already submitted.
    pub task_instances: HashMap<String, TaskId>,
    /// When the instance was started.
    pub started_at: SystemTime,
    /// When the instance reached a terminal state, if it has.
    pub ended_at: Option<SystemTime>,
    /// Caller-provided key/value context for this run.
    pub context: HashMap<String, String>,
}
1222
/// Lifecycle state of a workflow instance.
#[derive(Debug, Clone, PartialEq)]
pub enum WorkflowState {
    /// Instance created; initial tasks being submitted.
    Starting,
    /// Tasks executing.
    Running,
    /// All tasks finished successfully.
    Completed,
    /// Aborted with the given error.
    Failed { error: String },
    /// Cancelled on request.
    Cancelled,
    /// Execution suspended.
    Paused,
}
1239
1240impl WorkflowManager {
1241 #[must_use]
1243 pub fn new(scheduler: TaskScheduler) -> Self {
1244 Self {
1245 scheduler,
1246 workflows: Arc::new(RwLock::new(HashMap::new())),
1247 workflow_instances: Arc::new(RwLock::new(HashMap::new())),
1248 }
1249 }
1250
1251 pub fn register_workflow(&self, workflow: Workflow) -> SklResult<()> {
1253 let mut workflows = self.workflows.write().unwrap_or_else(|e| e.into_inner());
1254 workflows.insert(workflow.id.clone(), workflow);
1255 Ok(())
1256 }
1257
1258 pub fn start_workflow(
1260 &self,
1261 workflow_id: &str,
1262 context: HashMap<String, String>,
1263 ) -> SklResult<String> {
1264 let workflows = self.workflows.read().unwrap_or_else(|e| e.into_inner());
1265 let workflow = workflows.get(workflow_id).ok_or_else(|| {
1266 SklearsError::InvalidInput(format!("Workflow {workflow_id} not found"))
1267 })?;
1268
1269 let instance_id = format!(
1270 "{}_{}",
1271 workflow_id,
1272 SystemTime::now()
1273 .duration_since(SystemTime::UNIX_EPOCH)
1274 .unwrap_or_default()
1275 .as_millis()
1276 );
1277
1278 let instance = WorkflowInstance {
1279 id: instance_id.clone(),
1280 workflow_id: workflow_id.to_string(),
1281 state: WorkflowState::Starting,
1282 task_instances: HashMap::new(),
1283 started_at: SystemTime::now(),
1284 ended_at: None,
1285 context,
1286 };
1287
1288 {
1289 let mut instances = self
1290 .workflow_instances
1291 .write()
1292 .unwrap_or_else(|e| e.into_inner());
1293 instances.insert(instance_id.clone(), instance);
1294 }
1295
1296 self.submit_ready_tasks(&instance_id, workflow)?;
1298
1299 Ok(instance_id)
1300 }
1301
1302 fn submit_ready_tasks(&self, instance_id: &str, workflow: &Workflow) -> SklResult<()> {
1304 let ready_tasks: Vec<_> = workflow
1305 .tasks
1306 .iter()
1307 .filter(|task| task.depends_on.is_empty())
1308 .collect();
1309
1310 for task in ready_tasks {
1311 let scheduled_task = self.create_scheduled_task(instance_id, task)?;
1312 self.scheduler.submit_task(scheduled_task)?;
1313 }
1314
1315 Ok(())
1316 }
1317
1318 fn create_scheduled_task(
1320 &self,
1321 instance_id: &str,
1322 workflow_task: &WorkflowTask,
1323 ) -> SklResult<ScheduledTask> {
1324 let task_id = format!("{}_{}", instance_id, workflow_task.id);
1325
1326 Ok(ScheduledTask {
1327 id: task_id,
1328 name: workflow_task.template.name.clone(),
1329 component_type: workflow_task.template.component_type.clone(),
1330 dependencies: workflow_task
1331 .depends_on
1332 .iter()
1333 .map(|dep| format!("{instance_id}_{dep}"))
1334 .collect(),
1335 resource_requirements: workflow_task.template.default_resources.clone(),
1336 priority: TaskPriority::Normal,
1337 estimated_duration: Duration::from_secs(60),
1338 submitted_at: SystemTime::now(),
1339 deadline: None,
1340 metadata: HashMap::new(),
1341 retry_config: RetryConfig::default(),
1342 })
1343 }
1344
1345 #[must_use]
1347 pub fn get_workflow_status(&self, instance_id: &str) -> Option<WorkflowInstance> {
1348 let instances = self
1349 .workflow_instances
1350 .read()
1351 .unwrap_or_else(|e| e.into_inner());
1352 instances.get(instance_id).cloned()
1353 }
1354
1355 pub fn cancel_workflow(&self, instance_id: &str) -> SklResult<()> {
1357 let mut instances = self
1358 .workflow_instances
1359 .write()
1360 .unwrap_or_else(|e| e.into_inner());
1361
1362 if let Some(instance) = instances.get_mut(instance_id) {
1363 instance.state = WorkflowState::Cancelled;
1364 instance.ended_at = Some(SystemTime::now());
1365
1366 for task_id in instance.task_instances.values() {
1368 let _ = self.scheduler.cancel_task(task_id);
1369 }
1370
1371 Ok(())
1372 } else {
1373 Err(SklearsError::InvalidInput(format!(
1374 "Workflow instance {instance_id} not found"
1375 )))
1376 }
1377 }
1378
1379 #[must_use]
1381 pub fn list_workflow_instances(&self) -> HashMap<String, WorkflowInstance> {
1382 let instances = self
1383 .workflow_instances
1384 .read()
1385 .unwrap_or_else(|e| e.into_inner());
1386 instances.clone()
1387 }
1388}
1389
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully-populated ScheduledTask keeps the fields it was built with.
    #[test]
    fn test_scheduled_task_creation() {
        let task = ScheduledTask {
            id: "test_task".to_string(),
            name: "Test Task".to_string(),
            component_type: ComponentType::Transformer,
            dependencies: vec!["dep1".to_string()],
            resource_requirements: ResourceRequirements {
                cpu_cores: 1,
                memory_mb: 512,
                disk_mb: 100,
                gpu_required: false,
                estimated_duration: Duration::from_secs(60),
                priority: TaskPriority::Normal,
            },
            priority: TaskPriority::Normal,
            estimated_duration: Duration::from_secs(60),
            submitted_at: SystemTime::now(),
            deadline: None,
            metadata: HashMap::new(),
            retry_config: RetryConfig::default(),
        };

        assert_eq!(task.id, "test_task");
        assert_eq!(task.dependencies.len(), 1);
        assert_eq!(task.priority, TaskPriority::Normal);
    }

    /// A freshly built scheduler reports zero tasks in every state.
    #[test]
    fn test_task_scheduler_creation() {
        let config = SchedulerConfig::default();
        let scheduler = TaskScheduler::new(SchedulingStrategy::FIFO, config);

        let stats = scheduler.get_statistics();
        assert_eq!(stats.total_tasks, 0);
        assert_eq!(stats.pending_tasks, 0);
    }

    /// PriorityTask orders by priority_score: a higher score compares
    /// greater, so the max-heap pops it first.
    #[test]
    fn test_priority_task_ordering() {
        let task1 = PriorityTask {
            task: ScheduledTask {
                id: "task1".to_string(),
                name: "Task 1".to_string(),
                component_type: ComponentType::Transformer,
                dependencies: Vec::new(),
                resource_requirements: ResourceRequirements {
                    cpu_cores: 1,
                    memory_mb: 512,
                    disk_mb: 100,
                    gpu_required: false,
                    estimated_duration: Duration::from_secs(60),
                    priority: TaskPriority::Normal,
                },
                priority: TaskPriority::Normal,
                estimated_duration: Duration::from_secs(60),
                submitted_at: SystemTime::now(),
                deadline: None,
                metadata: HashMap::new(),
                retry_config: RetryConfig::default(),
            },
            priority_score: 10,
        };

        let task2 = PriorityTask {
            task: ScheduledTask {
                id: "task2".to_string(),
                name: "Task 2".to_string(),
                component_type: ComponentType::Transformer,
                dependencies: Vec::new(),
                resource_requirements: ResourceRequirements {
                    cpu_cores: 1,
                    memory_mb: 512,
                    disk_mb: 100,
                    gpu_required: false,
                    estimated_duration: Duration::from_secs(60),
                    priority: TaskPriority::High,
                },
                priority: TaskPriority::High,
                estimated_duration: Duration::from_secs(60),
                submitted_at: SystemTime::now(),
                deadline: None,
                metadata: HashMap::new(),
                retry_config: RetryConfig::default(),
            },
            priority_score: 100,
        };

        assert!(task2 > task1); // higher score orders first
    }

    /// A one-task workflow definition round-trips its fields.
    #[test]
    fn test_workflow_creation() {
        let workflow = Workflow {
            id: "test_workflow".to_string(),
            name: "Test Workflow".to_string(),
            tasks: vec![WorkflowTask {
                id: "task1".to_string(),
                template: TaskTemplate {
                    name: "Task 1".to_string(),
                    component_type: ComponentType::Transformer,
                    default_resources: ResourceRequirements {
                        cpu_cores: 1,
                        memory_mb: 512,
                        disk_mb: 100,
                        gpu_required: false,
                        estimated_duration: Duration::from_secs(60),
                        priority: TaskPriority::Normal,
                    },
                    default_config: HashMap::new(),
                },
                depends_on: Vec::new(),
                config_overrides: HashMap::new(),
            }],
            config: WorkflowConfig {
                max_parallelism: 5,
                timeout: Duration::from_secs(3600),
                failure_strategy: WorkflowFailureStrategy::StopOnFailure,
                retry_config: RetryConfig::default(),
            },
        };

        assert_eq!(workflow.id, "test_workflow");
        assert_eq!(workflow.tasks.len(), 1);
        assert_eq!(workflow.config.max_parallelism, 5);
    }

    /// ResourceUtilization stores the exact fractions it is given.
    #[test]
    fn test_resource_utilization() {
        let utilization = ResourceUtilization {
            timestamp: SystemTime::now(),
            cpu_usage: 0.5,
            memory_usage: 0.7,
            disk_usage: 0.3,
            gpu_usage: 0.0,
        };

        assert_eq!(utilization.cpu_usage, 0.5);
        assert_eq!(utilization.memory_usage, 0.7);
    }
}