1use crate::error::{ClusterError, Result};
8use dashmap::DashMap;
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12use std::hash::Hash;
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use uuid::Uuid;
16
/// A single unit of work in the task graph, including its dependency
/// links, resource needs, lifecycle timestamps, and eventual outcome.
#[derive(Debug, Clone)]
pub struct Task {
    /// Unique identifier for this task.
    pub id: TaskId,

    /// Human-readable name; non-empty names are also used as the key
    /// into the graph's result cache.
    pub name: String,

    /// Category tag; tasks must share a `task_type` to be fused.
    pub task_type: String,

    /// Scheduling priority (higher presumably runs first — semantics
    /// defined by the scheduler, not used inside this module).
    pub priority: i32,

    /// Opaque serialized input handed to the executing worker.
    pub payload: Vec<u8>,

    /// Tasks that must complete before this one becomes `Ready`.
    pub dependencies: Vec<TaskId>,

    /// Optional runtime estimate, used for planning and critical-path
    /// analysis; `None` means unknown.
    pub estimated_duration: Option<Duration>,

    /// Resources the task requires in order to be scheduled.
    pub resources: ResourceRequirements,

    /// Placement preferences (e.g. node labels — TODO confirm semantics
    /// with the scheduler; not consumed inside this module).
    pub locality_hints: Vec<String>,

    /// When the task object was created.
    pub created_at: Instant,

    /// When the task was scheduled, if it has been.
    pub scheduled_at: Option<Instant>,

    /// When execution started, if it has.
    pub started_at: Option<Instant>,

    /// When the task reached a terminal status, if it has.
    pub completed_at: Option<Instant>,

    /// Current lifecycle state.
    pub status: TaskStatus,

    /// Result payload, set on successful completion.
    pub result: Option<TaskResult>,

    /// Error message, set on failure.
    pub error: Option<String>,

    /// Number of retries attempted so far.
    pub retry_count: u32,

    /// Optional serialized checkpoint for resumable tasks.
    pub checkpoint: Option<Vec<u8>>,
}
74
/// Unique identifier for a task, backed by a UUID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TaskId(pub Uuid);
78
79impl TaskId {
80 pub fn new() -> Self {
82 Self(Uuid::new_v4())
83 }
84
85 pub fn from_uuid(uuid: Uuid) -> Self {
87 Self(uuid)
88 }
89}
90
91impl Default for TaskId {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl std::fmt::Display for TaskId {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(f, "{}", self.0)
100 }
101}
102
/// Resources a task requires in order to be scheduled onto a worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceRequirements {
    /// CPU cores required; fractional values allowed.
    pub cpu_cores: f64,

    /// Memory required, in bytes.
    pub memory_bytes: u64,

    /// Whether a GPU is required; tasks must match on this to be fused.
    pub gpu: bool,

    /// Storage required, in bytes.
    pub storage_bytes: u64,
}
118
119impl Default for ResourceRequirements {
120 fn default() -> Self {
121 Self {
122 cpu_cores: 1.0,
123 memory_bytes: 1024 * 1024 * 1024, gpu: false,
125 storage_bytes: 0,
126 }
127 }
128}
129
/// Lifecycle states of a task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TaskStatus {
    /// Waiting on unfinished dependencies.
    Pending,

    /// All dependencies satisfied; eligible for scheduling.
    Ready,

    /// Assigned to a worker but not yet running.
    Scheduled,

    /// Currently executing.
    Running,

    /// Finished successfully (terminal).
    Completed,

    /// Finished with an error (terminal).
    Failed,

    /// Cancelled before completion (terminal).
    Cancelled,
}
154
/// Output produced by a completed task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
    /// Serialized result payload.
    pub data: Vec<u8>,

    /// Wall-clock execution time of the task.
    pub duration: Duration,

    /// Identifier of the worker that produced this result.
    pub worker_id: String,
}
167
/// Concurrent DAG of tasks with dependency tracking, a name-keyed
/// result cache, and a memoized execution plan.
pub struct TaskGraph {
    /// All tasks, keyed by id; each task is individually lockable.
    tasks: DashMap<TaskId, Arc<RwLock<Task>>>,

    /// Forward edges: task -> the set of tasks it depends on.
    dependencies: DashMap<TaskId, HashSet<TaskId>>,

    /// Reverse edges: task -> the set of tasks that depend on it.
    dependents: DashMap<TaskId, HashSet<TaskId>>,

    /// Completed results keyed by task name (only non-empty names cached).
    result_cache: DashMap<String, Arc<TaskResult>>,

    /// Memoized execution plan; reset to `None` whenever the graph changes.
    plan_cache: RwLock<Option<ExecutionPlan>>,
}
185
/// Topologically ordered plan for executing the graph.
#[derive(Debug, Clone)]
pub struct ExecutionPlan {
    /// Tasks grouped by level; tasks within one level have no mutual
    /// dependencies and can run in parallel.
    pub levels: Vec<Vec<TaskId>>,

    /// Estimated total duration: sum over levels of each level's
    /// longest estimated task.
    pub estimated_duration: Duration,

    /// Tasks on the estimated longest dependency chain.
    pub critical_path: Vec<TaskId>,

    /// Number of tasks in each level (same order as `levels`).
    pub parallelism: Vec<usize>,
}
201
202impl TaskGraph {
203 pub fn new() -> Self {
205 Self {
206 tasks: DashMap::new(),
207 dependencies: DashMap::new(),
208 dependents: DashMap::new(),
209 result_cache: DashMap::new(),
210 plan_cache: RwLock::new(None),
211 }
212 }
213
214 pub fn add_task(&self, mut task: Task) -> Result<TaskId> {
216 let task_id = task.id;
217
218 for dep_id in &task.dependencies {
220 if self.would_create_cycle(task_id, *dep_id)? {
221 return Err(ClusterError::DependencyCycle(format!(
222 "Adding task {} would create a cycle",
223 task_id
224 )));
225 }
226 }
227
228 if task.dependencies.is_empty() {
230 task.status = TaskStatus::Ready;
231 } else {
232 task.status = TaskStatus::Pending;
233 }
234
235 self.tasks
237 .insert(task_id, Arc::new(RwLock::new(task.clone())));
238
239 let deps: HashSet<TaskId> = task.dependencies.iter().copied().collect();
241 self.dependencies.insert(task_id, deps.clone());
242
243 for dep_id in deps {
244 self.dependents.entry(dep_id).or_default().insert(task_id);
245 }
246
247 *self.plan_cache.write() = None;
249
250 Ok(task_id)
251 }
252
253 pub fn remove_task(&self, task_id: TaskId) -> Result<()> {
255 self.tasks.remove(&task_id);
257
258 if let Some((_, deps)) = self.dependencies.remove(&task_id) {
260 for dep_id in deps {
261 if let Some(mut dependents) = self.dependents.get_mut(&dep_id) {
262 dependents.remove(&task_id);
263 }
264 }
265 }
266
267 if let Some((_, dependents)) = self.dependents.remove(&task_id) {
269 for dependent_id in dependents {
270 if let Some(mut deps) = self.dependencies.get_mut(&dependent_id) {
271 deps.remove(&task_id);
272 }
273 }
274 }
275
276 *self.plan_cache.write() = None;
278
279 Ok(())
280 }
281
282 pub fn get_task(&self, task_id: TaskId) -> Result<Arc<RwLock<Task>>> {
284 self.tasks
285 .get(&task_id)
286 .map(|entry| Arc::clone(entry.value()))
287 .ok_or_else(|| ClusterError::TaskNotFound(task_id.to_string()))
288 }
289
290 pub fn get_all_tasks(&self) -> Vec<Arc<RwLock<Task>>> {
292 self.tasks
293 .iter()
294 .map(|entry| Arc::clone(entry.value()))
295 .collect()
296 }
297
298 pub fn get_tasks_by_status(&self, status: TaskStatus) -> Vec<Arc<RwLock<Task>>> {
300 self.tasks
301 .iter()
302 .filter(|entry| entry.value().read().status == status)
303 .map(|entry| Arc::clone(entry.value()))
304 .collect()
305 }
306
307 pub fn update_task_status(&self, task_id: TaskId, status: TaskStatus) -> Result<()> {
309 let task = self.get_task(task_id)?;
310 let mut task = task.write();
311
312 let old_status = task.status;
313 task.status = status;
314
315 match status {
316 TaskStatus::Scheduled => {
317 task.scheduled_at = Some(Instant::now());
318 }
319 TaskStatus::Running => {
320 task.started_at = Some(Instant::now());
321 }
322 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
323 task.completed_at = Some(Instant::now());
324
325 if status == TaskStatus::Completed {
327 drop(task); self.update_dependents(task_id)?;
329 }
330 }
331 _ => {}
332 }
333
334 if old_status != status {
336 *self.plan_cache.write() = None;
337 }
338
339 Ok(())
340 }
341
342 pub fn set_task_result(&self, task_id: TaskId, result: TaskResult) -> Result<()> {
344 let task = self.get_task(task_id)?;
345 let mut task = task.write();
346
347 task.result = Some(result.clone());
348 task.status = TaskStatus::Completed;
349 task.completed_at = Some(Instant::now());
350
351 if !task.name.is_empty() {
353 self.result_cache
354 .insert(task.name.clone(), Arc::new(result));
355 }
356
357 Ok(())
358 }
359
360 pub fn set_task_error(&self, task_id: TaskId, error: String) -> Result<()> {
362 let task = self.get_task(task_id)?;
363 let mut task = task.write();
364
365 task.error = Some(error);
366 task.status = TaskStatus::Failed;
367 task.completed_at = Some(Instant::now());
368
369 Ok(())
370 }
371
372 pub fn get_cached_result(&self, name: &str) -> Option<Arc<TaskResult>> {
374 self.result_cache
375 .get(name)
376 .map(|entry| Arc::clone(entry.value()))
377 }
378
379 pub fn clear_result_cache(&self) {
381 self.result_cache.clear();
382 }
383
384 fn would_create_cycle(&self, from: TaskId, to: TaskId) -> Result<bool> {
386 if from == to {
387 return Ok(true);
388 }
389
390 let mut visited = HashSet::new();
392 let mut queue = VecDeque::new();
393 queue.push_back(to);
394 visited.insert(to);
395
396 while let Some(current) = queue.pop_front() {
397 if current == from {
398 return Ok(true);
399 }
400
401 if let Some(deps) = self.dependencies.get(¤t) {
402 for dep in deps.iter() {
403 if visited.insert(*dep) {
404 queue.push_back(*dep);
405 }
406 }
407 }
408 }
409
410 Ok(false)
411 }
412
413 fn update_dependents(&self, completed_task_id: TaskId) -> Result<()> {
415 if let Some(dependents) = self.dependents.get(&completed_task_id) {
416 for dependent_id in dependents.iter() {
417 let dependent_task = self.get_task(*dependent_id)?;
418 let mut dependent_task = dependent_task.write();
419
420 let all_deps_completed = dependent_task.dependencies.iter().all(|dep_id| {
422 self.tasks
423 .get(dep_id)
424 .map(|t| t.read().status == TaskStatus::Completed)
425 .unwrap_or(false)
426 });
427
428 if all_deps_completed && dependent_task.status == TaskStatus::Pending {
429 dependent_task.status = TaskStatus::Ready;
430 }
431 }
432 }
433
434 Ok(())
435 }
436
437 pub fn build_execution_plan(&self) -> Result<ExecutionPlan> {
439 {
441 let cache = self.plan_cache.read();
442 if let Some(plan) = cache.as_ref() {
443 return Ok(plan.clone());
444 }
445 }
446
447 let mut in_degrees = HashMap::new();
449 for task_entry in self.tasks.iter() {
450 let task_id = *task_entry.key();
451 let task = task_entry.value().read();
452
453 if task.status == TaskStatus::Completed
454 || task.status == TaskStatus::Failed
455 || task.status == TaskStatus::Cancelled
456 {
457 continue;
458 }
459
460 let deps = self
461 .dependencies
462 .get(&task_id)
463 .map(|d| d.len())
464 .unwrap_or(0);
465 in_degrees.insert(task_id, deps);
466 }
467
468 let mut levels = Vec::new();
470 let mut current_level = Vec::new();
471 let mut task_levels = HashMap::new();
472
473 for (task_id, degree) in &in_degrees {
475 if *degree == 0 {
476 current_level.push(*task_id);
477 task_levels.insert(*task_id, 0);
478 }
479 }
480
481 let mut level_idx = 0;
482 while !current_level.is_empty() {
483 levels.push(current_level.clone());
484
485 let mut next_level = Vec::new();
486
487 for task_id in ¤t_level {
488 if let Some(dependents) = self.dependents.get(task_id) {
489 for dependent_id in dependents.iter() {
490 if !in_degrees.contains_key(dependent_id) {
491 continue;
492 }
493
494 let new_degree = in_degrees
495 .get(dependent_id)
496 .copied()
497 .unwrap_or(0)
498 .saturating_sub(1);
499 in_degrees.insert(*dependent_id, new_degree);
500
501 if new_degree == 0 {
502 next_level.push(*dependent_id);
503 task_levels.insert(*dependent_id, level_idx + 1);
504 }
505 }
506 }
507 }
508
509 current_level = next_level;
510 level_idx += 1;
511 }
512
513 let remaining: Vec<_> = in_degrees
515 .iter()
516 .filter(|&(_, °ree)| degree > 0)
517 .map(|(id, _)| *id)
518 .collect();
519
520 if !remaining.is_empty() {
521 return Err(ClusterError::DependencyCycle(format!(
522 "Cycle detected involving tasks: {:?}",
523 remaining
524 )));
525 }
526
527 let parallelism: Vec<usize> = levels.iter().map(|level| level.len()).collect();
529
530 let critical_path = self.compute_critical_path(&task_levels);
532
533 let estimated_duration = self.estimate_total_duration(&levels);
535
536 let plan = ExecutionPlan {
537 levels,
538 estimated_duration,
539 critical_path,
540 parallelism,
541 };
542
543 *self.plan_cache.write() = Some(plan.clone());
545
546 Ok(plan)
547 }
548
549 fn compute_critical_path(&self, task_levels: &HashMap<TaskId, usize>) -> Vec<TaskId> {
551 let mut longest_path = Vec::new();
552 let mut max_duration = Duration::from_secs(0);
553
554 let max_level = task_levels.values().max().copied().unwrap_or(0);
556
557 for (task_id, level) in task_levels {
559 if *level == max_level {
560 let path = self.trace_longest_path(*task_id);
561 let path_duration: Duration = path
562 .iter()
563 .filter_map(|id| self.tasks.get(id).and_then(|t| t.read().estimated_duration))
564 .sum();
565
566 if path_duration > max_duration {
567 max_duration = path_duration;
568 longest_path = path;
569 }
570 }
571 }
572
573 longest_path
574 }
575
576 fn trace_longest_path(&self, task_id: TaskId) -> Vec<TaskId> {
578 let mut path = vec![task_id];
579 let mut current = task_id;
580
581 loop {
582 let deps = self.dependencies.get(¤t);
583 if deps.is_none() || deps.as_ref().map(|d| d.is_empty()).unwrap_or(true) {
584 break;
585 }
586
587 let longest_dep = deps.as_ref().and_then(|deps| {
589 deps.iter()
590 .max_by_key(|dep_id| {
591 self.tasks
592 .get(dep_id)
593 .and_then(|t| t.read().estimated_duration)
594 .unwrap_or(Duration::from_secs(0))
595 })
596 .copied()
597 });
598
599 match longest_dep {
600 Some(dep_id) => {
601 path.push(dep_id);
602 current = dep_id;
603 }
604 None => break,
605 }
606 }
607
608 path.reverse();
609 path
610 }
611
612 fn estimate_total_duration(&self, levels: &[Vec<TaskId>]) -> Duration {
614 levels
615 .iter()
616 .map(|level| {
617 level
618 .iter()
619 .filter_map(|id| self.tasks.get(id).and_then(|t| t.read().estimated_duration))
620 .max()
621 .unwrap_or(Duration::from_secs(0))
622 })
623 .sum()
624 }
625
626 pub fn optimize_fusion(&self) -> Result<Vec<(TaskId, TaskId)>> {
628 let mut fused_pairs = Vec::new();
629
630 for task_entry in self.tasks.iter() {
632 let task_id = *task_entry.key();
633 let task = task_entry.value().read();
634
635 if let Some(dependents) = self.dependents.get(&task_id) {
636 if dependents.len() == 1 {
638 let dependent_id = *dependents.iter().next().ok_or_else(|| {
639 ClusterError::InvalidState("Empty dependents set".to_string())
640 })?;
641
642 let dependent = self.get_task(dependent_id)?;
643 let dependent = dependent.read();
644
645 if self.can_fuse_tasks(&task, &dependent) {
647 fused_pairs.push((task_id, dependent_id));
648 }
649 }
650 }
651 }
652
653 Ok(fused_pairs)
654 }
655
656 fn can_fuse_tasks(&self, task1: &Task, task2: &Task) -> bool {
658 if task1.task_type != task2.task_type {
660 return false;
661 }
662
663 if task1.resources.gpu != task2.resources.gpu {
665 return false;
666 }
667
668 let task1_dur = task1.estimated_duration.unwrap_or(Duration::from_secs(1));
670 let task2_dur = task2.estimated_duration.unwrap_or(Duration::from_secs(1));
671
672 if task1_dur + task2_dur > Duration::from_secs(60) {
673 return false;
674 }
675
676 true
677 }
678
679 pub fn prune_completed(&self) -> Result<usize> {
681 let completed_tasks: Vec<TaskId> = self
682 .tasks
683 .iter()
684 .filter(|entry| {
685 matches!(
686 entry.value().read().status,
687 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
688 )
689 })
690 .map(|entry| *entry.key())
691 .collect();
692
693 let count = completed_tasks.len();
694
695 for task_id in completed_tasks {
696 self.remove_task(task_id)?;
697 }
698
699 Ok(count)
700 }
701
702 pub fn get_statistics(&self) -> TaskGraphStatistics {
704 let total_tasks = self.tasks.len();
705 let mut status_counts = HashMap::new();
706
707 for entry in self.tasks.iter() {
708 let status = entry.value().read().status;
709 *status_counts.entry(status).or_insert(0) += 1;
710 }
711
712 let total_edges = self.dependencies.iter().map(|e| e.value().len()).sum();
713
714 TaskGraphStatistics {
715 total_tasks,
716 status_counts,
717 total_edges,
718 cached_results: self.result_cache.len(),
719 }
720 }
721}
722
723impl Default for TaskGraph {
724 fn default() -> Self {
725 Self::new()
726 }
727}
728
/// Snapshot of aggregate graph metrics, as produced by
/// `TaskGraph::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskGraphStatistics {
    /// Total number of tasks currently in the graph.
    pub total_tasks: usize,

    /// Count of tasks per lifecycle status.
    pub status_counts: HashMap<TaskStatus, usize>,

    /// Total number of dependency edges.
    pub total_edges: usize,

    /// Number of entries in the name-keyed result cache.
    pub cached_results: usize,
}
744
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a minimal 1-second test task with the given name and
    /// dependency list; every other field is a neutral default.
    fn create_test_task(name: &str, dependencies: Vec<TaskId>) -> Task {
        Task {
            id: TaskId::new(),
            name: name.to_string(),
            task_type: "test".to_string(),
            priority: 0,
            payload: vec![],
            dependencies,
            estimated_duration: Some(Duration::from_secs(1)),
            resources: ResourceRequirements::default(),
            locality_hints: vec![],
            created_at: Instant::now(),
            scheduled_at: None,
            started_at: None,
            completed_at: None,
            status: TaskStatus::Pending,
            result: None,
            error: None,
            retry_count: 0,
            checkpoint: None,
        }
    }

    // A freshly constructed graph holds no tasks.
    #[test]
    fn test_task_graph_creation() {
        let graph = TaskGraph::new();
        assert_eq!(graph.tasks.len(), 0);
    }

    // Adding a dependency-free task succeeds and stores it.
    #[test]
    fn test_add_task() {
        let graph = TaskGraph::new();
        let task = create_test_task("task1", vec![]);

        let result = graph.add_task(task);
        assert!(result.is_ok());
        assert_eq!(graph.tasks.len(), 1);
    }

    // A task depending on an existing task is accepted.
    #[test]
    fn test_task_dependencies() {
        let graph = TaskGraph::new();

        let task1 = create_test_task("task1", vec![]);
        let task1_id = graph.add_task(task1).ok().unwrap_or_default();

        let task2 = create_test_task("task2", vec![task1_id]);
        let result = graph.add_task(task2);

        assert!(result.is_ok());
        assert_eq!(graph.tasks.len(), 2);
    }

    // With task2 depending on task1, an edge task1 -> task2 would
    // close a cycle, so would_create_cycle must report true.
    #[test]
    fn test_cycle_detection() {
        let graph = TaskGraph::new();

        let task1 = create_test_task("task1", vec![]);
        let task1_id = task1.id;
        let _ = graph.add_task(task1);

        let task2 = create_test_task("task2", vec![task1_id]);
        let task2_id = task2.id;
        let _ = graph.add_task(task2);

        let result = graph.would_create_cycle(task1_id, task2_id);

        assert!(result.is_ok());
        assert!(result.ok().unwrap_or(false));
    }

    // A two-task chain plans as two sequential levels of one task each.
    #[test]
    fn test_execution_plan() {
        let graph = TaskGraph::new();

        let task1 = create_test_task("task1", vec![]);
        let task1_id = graph.add_task(task1).ok().unwrap_or_default();

        let task2 = create_test_task("task2", vec![task1_id]);
        graph.add_task(task2).ok();

        let plan = graph.build_execution_plan();
        assert!(plan.is_ok());

        let plan = plan.ok();
        if let Some(plan) = plan {
            assert_eq!(plan.levels.len(), 2);
            assert_eq!(plan.levels[0].len(), 1);
            assert_eq!(plan.levels[1].len(), 1);
        }
    }

    // update_task_status persists the new status on the stored task.
    #[test]
    fn test_task_status_update() {
        let graph = TaskGraph::new();
        let task = create_test_task("task1", vec![]);
        let task_id = graph.add_task(task).ok().unwrap_or_default();

        let result = graph.update_task_status(task_id, TaskStatus::Running);
        assert!(result.is_ok());

        let task = graph.get_task(task_id);
        assert!(task.is_ok());
        if let Ok(task) = task {
            assert_eq!(task.read().status, TaskStatus::Running);
        }
    }
}