1use crate::error::{DistributedError, Result};
7use arrow::record_batch::RecordBatch;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use std::sync::Arc;
11
/// Unique identifier for a single task within a distributed job.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TaskId(pub u64);
15
16impl fmt::Display for TaskId {
17 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18 write!(f, "Task({})", self.0)
19 }
20}
21
/// Unique identifier for a data partition that a task operates on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PartitionId(pub u64);
25
26impl fmt::Display for PartitionId {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 write!(f, "Partition({})", self.0)
29 }
30}
31
/// Lifecycle state of a [`Task`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    /// Queued and waiting to be dispatched to a worker.
    Pending,
    /// Currently executing on a worker.
    Running,
    /// Finished successfully.
    Completed,
    /// Execution failed on a worker (the scheduler may re-queue it while
    /// the retry budget lasts; see `TaskScheduler::mark_failed`).
    Failed,
    /// Cancelled before completion.
    Cancelled,
}
46
47impl fmt::Display for TaskStatus {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 match self {
50 Self::Pending => write!(f, "Pending"),
51 Self::Running => write!(f, "Running"),
52 Self::Completed => write!(f, "Completed"),
53 Self::Failed => write!(f, "Failed"),
54 Self::Cancelled => write!(f, "Cancelled"),
55 }
56 }
57}
58
/// Operation a task applies to its partition's data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskOperation {
    /// Keep only rows matching a filter expression.
    Filter {
        // Expression text is stringly-typed; interpreted by the executor.
        expression: String,
    },
    /// Calculate an index over selected bands.
    // NOTE(review): presumably a spectral/band index given the raster-style
    // siblings (Reproject/Resample/Clip) — confirm against the executor.
    CalculateIndex {
        // Name of the index to compute.
        index_type: String,
        // Band positions used as inputs.
        bands: Vec<usize>,
    },
    /// Reproject data into a target coordinate reference system.
    Reproject {
        // EPSG code of the target CRS.
        target_epsg: i32,
    },
    /// Resample to the given output dimensions.
    Resample {
        width: usize,
        height: usize,
        // Resampling method name; interpreted by the executor.
        method: String,
    },
    /// Clip to an axis-aligned bounding box.
    Clip {
        min_x: f64,
        min_y: f64,
        max_x: f64,
        max_y: f64,
    },
    /// Convolve with a kernel given in row-major order.
    // NOTE(review): row-major layout is assumed from the width/height pair —
    // confirm against the executor's kernel indexing.
    Convolve {
        kernel: Vec<f64>,
        kernel_width: usize,
        kernel_height: usize,
    },
    /// Escape hatch for operations not modeled above.
    Custom {
        // Operation name used for dispatch.
        name: String,
        // Opaque, executor-defined parameter payload.
        params: String,
    },
}
116
/// A schedulable unit of work: one operation applied to one partition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    // Unique task identifier.
    pub id: TaskId,
    // Partition this task operates on.
    pub partition_id: PartitionId,
    // What the task does when executed.
    pub operation: TaskOperation,
    // Current lifecycle state.
    pub status: TaskStatus,
    // Worker currently (or last) assigned, if any.
    pub worker_id: Option<String>,
    // Number of failures recorded so far.
    pub retry_count: u32,
    // Maximum failures allowed before the task is considered dead.
    pub max_retries: u32,
}
135
136impl Task {
137 pub fn new(id: TaskId, partition_id: PartitionId, operation: TaskOperation) -> Self {
139 Self {
140 id,
141 partition_id,
142 operation,
143 status: TaskStatus::Pending,
144 worker_id: None,
145 retry_count: 0,
146 max_retries: 3,
147 }
148 }
149
150 pub fn can_retry(&self) -> bool {
152 self.retry_count < self.max_retries
153 }
154
155 pub fn mark_running(&mut self, worker_id: String) {
157 self.status = TaskStatus::Running;
158 self.worker_id = Some(worker_id);
159 }
160
161 pub fn mark_completed(&mut self) {
163 self.status = TaskStatus::Completed;
164 }
165
166 pub fn mark_failed(&mut self) {
168 self.status = TaskStatus::Failed;
169 self.retry_count += 1;
170 }
171
172 pub fn mark_cancelled(&mut self) {
174 self.status = TaskStatus::Cancelled;
175 }
176
177 pub fn reset_for_retry(&mut self) {
179 self.status = TaskStatus::Pending;
180 self.worker_id = None;
181 }
182}
183
/// Outcome of executing a single task on a worker.
#[derive(Debug, Clone)]
pub struct TaskResult {
    // Task this result belongs to.
    pub task_id: TaskId,
    // Produced data on success; `None` on failure.
    pub data: Option<Arc<RecordBatch>>,
    // Wall-clock execution time in milliseconds.
    pub execution_time_ms: u64,
    // Error description on failure; `None` on success.
    pub error: Option<String>,
}
196
197impl TaskResult {
198 pub fn success(task_id: TaskId, data: Arc<RecordBatch>, execution_time_ms: u64) -> Self {
200 Self {
201 task_id,
202 data: Some(data),
203 execution_time_ms,
204 error: None,
205 }
206 }
207
208 pub fn failure(task_id: TaskId, error: String, execution_time_ms: u64) -> Self {
210 Self {
211 task_id,
212 data: None,
213 execution_time_ms,
214 error: Some(error),
215 }
216 }
217
218 pub fn is_success(&self) -> bool {
220 self.error.is_none()
221 }
222
223 pub fn is_failure(&self) -> bool {
225 self.error.is_some()
226 }
227}
228
/// Per-task execution environment handed to a worker.
#[derive(Debug, Clone)]
pub struct TaskContext {
    // Task being executed.
    pub task_id: TaskId,
    // Worker executing the task.
    pub worker_id: String,
    // Memory budget in bytes (defaults to 1 GiB in `TaskContext::new`).
    pub memory_limit: u64,
    // CPU cores available to the task.
    pub num_cores: usize,
}
241
242impl TaskContext {
243 pub fn new(task_id: TaskId, worker_id: String) -> Self {
245 Self {
246 task_id,
247 worker_id,
248 memory_limit: 1024 * 1024 * 1024, num_cores: num_cpus(),
250 }
251 }
252
253 pub fn with_memory_limit(mut self, limit: u64) -> Self {
255 self.memory_limit = limit;
256 self
257 }
258
259 pub fn with_num_cores(mut self, cores: usize) -> Self {
261 self.num_cores = cores;
262 self
263 }
264}
265
/// Number of CPUs the host reports as available, falling back to 1 when
/// the query fails (e.g. unsupported platform).
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(n) => n.get(),
        Err(_) => 1,
    }
}
272
/// In-memory scheduler tracking tasks through their lifecycle queues.
#[derive(Debug)]
pub struct TaskScheduler {
    // Tasks waiting for dispatch. `next_task` pops from the back,
    // so dispatch order is LIFO.
    pending: Vec<Task>,
    // Tasks currently assigned to workers.
    running: Vec<Task>,
    // Tasks that finished successfully.
    completed: Vec<Task>,
    // Tasks that failed with no retries remaining.
    failed: Vec<Task>,
}
285
286impl TaskScheduler {
287 pub fn new() -> Self {
289 Self {
290 pending: Vec::new(),
291 running: Vec::new(),
292 completed: Vec::new(),
293 failed: Vec::new(),
294 }
295 }
296
297 pub fn add_task(&mut self, task: Task) {
299 self.pending.push(task);
300 }
301
302 pub fn next_task(&mut self) -> Option<Task> {
304 self.pending.pop()
305 }
306
307 pub fn mark_running(&mut self, mut task: Task, worker_id: String) {
309 task.mark_running(worker_id);
310 self.running.push(task);
311 }
312
313 pub fn mark_completed(&mut self, task_id: TaskId) -> Result<()> {
315 if let Some(pos) = self.running.iter().position(|t| t.id == task_id) {
316 let mut task = self.running.remove(pos);
317 task.mark_completed();
318 self.completed.push(task);
319 Ok(())
320 } else {
321 Err(DistributedError::coordinator(format!(
322 "Task {} not found in running tasks",
323 task_id
324 )))
325 }
326 }
327
328 pub fn mark_failed(&mut self, task_id: TaskId) -> Result<()> {
330 if let Some(pos) = self.running.iter().position(|t| t.id == task_id) {
331 let mut task = self.running.remove(pos);
332 task.mark_failed();
333
334 if task.can_retry() {
335 task.reset_for_retry();
336 self.pending.push(task);
337 } else {
338 self.failed.push(task);
339 }
340 Ok(())
341 } else {
342 Err(DistributedError::coordinator(format!(
343 "Task {} not found in running tasks",
344 task_id
345 )))
346 }
347 }
348
349 pub fn pending_count(&self) -> usize {
351 self.pending.len()
352 }
353
354 pub fn running_count(&self) -> usize {
356 self.running.len()
357 }
358
359 pub fn completed_count(&self) -> usize {
361 self.completed.len()
362 }
363
364 pub fn failed_count(&self) -> usize {
366 self.failed.len()
367 }
368
369 pub fn is_complete(&self) -> bool {
371 self.pending.is_empty() && self.running.is_empty()
372 }
373}
374
// Allow `TaskScheduler::default()` as a synonym for `new()`.
impl Default for TaskScheduler {
    fn default() -> Self {
        Self::new()
    }
}
380
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh task starts Pending with no worker assigned.
    #[test]
    fn test_task_creation() {
        let task = Task::new(
            TaskId(1),
            PartitionId(0),
            TaskOperation::Filter {
                expression: "value > 10".to_string(),
            },
        );

        assert_eq!(task.id, TaskId(1));
        assert_eq!(task.partition_id, PartitionId(0));
        assert_eq!(task.status, TaskStatus::Pending);
        assert!(task.worker_id.is_none());
    }

    // Pending -> Running (with worker recorded) -> Completed.
    #[test]
    fn test_task_lifecycle() {
        let mut task = Task::new(
            TaskId(1),
            PartitionId(0),
            TaskOperation::Filter {
                expression: "value > 10".to_string(),
            },
        );

        task.mark_running("worker-1".to_string());
        assert_eq!(task.status, TaskStatus::Running);
        assert_eq!(task.worker_id, Some("worker-1".to_string()));

        task.mark_completed();
        assert_eq!(task.status, TaskStatus::Completed);
    }

    // Each mark_failed consumes one retry; can_retry flips once the
    // budget is spent.
    #[test]
    fn test_task_retry() {
        let mut task = Task::new(
            TaskId(1),
            PartitionId(0),
            TaskOperation::Filter {
                expression: "value > 10".to_string(),
            },
        );

        task.max_retries = 2;

        assert!(task.can_retry());
        task.mark_failed();
        assert_eq!(task.retry_count, 1);
        assert!(task.can_retry());

        task.mark_failed();
        assert_eq!(task.retry_count, 2);
        assert!(!task.can_retry());
    }

    // End-to-end scheduler flow. Note: next_task is LIFO, so the task
    // popped here is task2 (TaskId(2)).
    #[test]
    fn test_task_scheduler() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let mut scheduler = TaskScheduler::new();

        let task1 = Task::new(
            TaskId(1),
            PartitionId(0),
            TaskOperation::Filter {
                expression: "value > 10".to_string(),
            },
        );
        let task2 = Task::new(
            TaskId(2),
            PartitionId(1),
            TaskOperation::Filter {
                expression: "value < 100".to_string(),
            },
        );

        scheduler.add_task(task1);
        scheduler.add_task(task2);

        assert_eq!(scheduler.pending_count(), 2);
        assert_eq!(scheduler.running_count(), 0);

        let task = scheduler
            .next_task()
            .ok_or_else(|| Box::<dyn std::error::Error>::from("should have task"))?;
        scheduler.mark_running(task, "worker-1".to_string());

        assert_eq!(scheduler.pending_count(), 1);
        assert_eq!(scheduler.running_count(), 1);

        scheduler.mark_completed(TaskId(2))?;

        assert_eq!(scheduler.running_count(), 0);
        assert_eq!(scheduler.completed_count(), 1);
        Ok(())
    }

    // Builder overrides replace the defaults set by TaskContext::new.
    #[test]
    fn test_task_context() {
        let ctx = TaskContext::new(TaskId(1), "worker-1".to_string())
            .with_memory_limit(2 * 1024 * 1024 * 1024)
            .with_num_cores(4);

        assert_eq!(ctx.task_id, TaskId(1));
        assert_eq!(ctx.worker_id, "worker-1");
        assert_eq!(ctx.memory_limit, 2 * 1024 * 1024 * 1024);
        assert_eq!(ctx.num_cores, 4);
    }
}