// oxigdal_distributed/task.rs
1//! Task definitions and management for distributed processing.
2//!
3//! This module defines the task types and execution logic for distributed
4//! geospatial processing operations.
5
6use crate::error::{DistributedError, Result};
7use arrow::record_batch::RecordBatch;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use std::sync::Arc;
11
12/// Unique identifier for a task.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct TaskId(pub u64);
15
16impl fmt::Display for TaskId {
17    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18        write!(f, "Task({})", self.0)
19    }
20}
21
22/// Unique identifier for a partition.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
24pub struct PartitionId(pub u64);
25
26impl fmt::Display for PartitionId {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        write!(f, "Partition({})", self.0)
29    }
30}
31
/// Status of a task execution.
///
/// Transitions are driven by the `Task::mark_*` methods:
/// `Pending` -> `Running` -> (`Completed` | `Failed` | `Cancelled`);
/// a `Failed` task may return to `Pending` via `Task::reset_for_retry`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    /// Task is pending execution (initial state; also the state after a retry reset).
    Pending,
    /// Task is currently being executed.
    Running,
    /// Task completed successfully.
    Completed,
    /// Task failed with an error.
    Failed,
    /// Task was cancelled.
    Cancelled,
}
46
47impl fmt::Display for TaskStatus {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        match self {
50            Self::Pending => write!(f, "Pending"),
51            Self::Running => write!(f, "Running"),
52            Self::Completed => write!(f, "Completed"),
53            Self::Failed => write!(f, "Failed"),
54            Self::Cancelled => write!(f, "Cancelled"),
55        }
56    }
57}
58
/// Type of geospatial operation to perform.
///
/// Serializable (serde) so an operation can travel with its [`Task`] to a
/// remote worker. This enum only describes the work; execution happens
/// elsewhere, so no field values are validated here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskOperation {
    /// Apply a filter to data.
    Filter {
        /// Filter expression (free-form; syntax is interpreted by the executor).
        expression: String,
    },
    /// Calculate a raster index (NDVI, NDWI, etc.).
    CalculateIndex {
        /// Index type.
        index_type: String,
        /// Band indices for calculation.
        // NOTE(review): 0- vs 1-based band indexing is not enforced here —
        // confirm the convention against the executor.
        bands: Vec<usize>,
    },
    /// Reproject data to a different CRS.
    Reproject {
        /// Target EPSG code.
        target_epsg: i32,
    },
    /// Resample raster data.
    Resample {
        /// Target width.
        width: usize,
        /// Target height.
        height: usize,
        /// Resampling method (free-form name; presumably matched by the
        /// executor — not validated here).
        method: String,
    },
    /// Clip data to a bounding box.
    Clip {
        /// Minimum X coordinate.
        min_x: f64,
        /// Minimum Y coordinate.
        min_y: f64,
        /// Maximum X coordinate.
        max_x: f64,
        /// Maximum Y coordinate.
        max_y: f64,
    },
    /// Apply a convolution kernel.
    Convolve {
        /// Kernel values.
        // Assumed row-major with len == kernel_width * kernel_height;
        // not checked at construction — TODO confirm in the executor.
        kernel: Vec<f64>,
        /// Kernel width.
        kernel_width: usize,
        /// Kernel height.
        kernel_height: usize,
    },
    /// Custom user-defined operation.
    Custom {
        /// Operation name.
        name: String,
        /// JSON-serialized parameters.
        params: String,
    },
}
116
/// A task to be executed by a worker.
///
/// Created in the `Pending` state by [`Task::new`] (which also sets a
/// default retry budget of 3) and mutated through the `mark_*` helpers
/// as it moves through its lifecycle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Unique task identifier.
    pub id: TaskId,
    /// Partition to process.
    pub partition_id: PartitionId,
    /// Operation to perform.
    pub operation: TaskOperation,
    /// Current status.
    pub status: TaskStatus,
    /// Worker ID assigned to this task (`None` until `mark_running`,
    /// cleared again by `reset_for_retry`).
    pub worker_id: Option<String>,
    /// Number of retry attempts so far (incremented by `mark_failed`).
    pub retry_count: u32,
    /// Maximum number of retries allowed (3 by default; see `Task::new`).
    pub max_retries: u32,
}
135
136impl Task {
137    /// Create a new task.
138    pub fn new(id: TaskId, partition_id: PartitionId, operation: TaskOperation) -> Self {
139        Self {
140            id,
141            partition_id,
142            operation,
143            status: TaskStatus::Pending,
144            worker_id: None,
145            retry_count: 0,
146            max_retries: 3,
147        }
148    }
149
150    /// Check if the task can be retried.
151    pub fn can_retry(&self) -> bool {
152        self.retry_count < self.max_retries
153    }
154
155    /// Mark the task as running on a specific worker.
156    pub fn mark_running(&mut self, worker_id: String) {
157        self.status = TaskStatus::Running;
158        self.worker_id = Some(worker_id);
159    }
160
161    /// Mark the task as completed.
162    pub fn mark_completed(&mut self) {
163        self.status = TaskStatus::Completed;
164    }
165
166    /// Mark the task as failed and increment retry count.
167    pub fn mark_failed(&mut self) {
168        self.status = TaskStatus::Failed;
169        self.retry_count += 1;
170    }
171
172    /// Mark the task as cancelled.
173    pub fn mark_cancelled(&mut self) {
174        self.status = TaskStatus::Cancelled;
175    }
176
177    /// Reset the task for retry.
178    pub fn reset_for_retry(&mut self) {
179        self.status = TaskStatus::Pending;
180        self.worker_id = None;
181    }
182}
183
/// Result of a task execution.
///
/// Exactly one of `data` / `error` is populated by the
/// [`TaskResult::success`] and [`TaskResult::failure`] constructors;
/// `is_success` / `is_failure` key off `error` alone.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Task identifier.
    pub task_id: TaskId,
    /// Resulting data as Arrow RecordBatch (`None` on failure).
    pub data: Option<Arc<RecordBatch>>,
    /// Execution time in milliseconds (recorded for both outcomes).
    pub execution_time_ms: u64,
    /// Error message if task failed (`None` on success).
    pub error: Option<String>,
}
196
197impl TaskResult {
198    /// Create a successful task result.
199    pub fn success(task_id: TaskId, data: Arc<RecordBatch>, execution_time_ms: u64) -> Self {
200        Self {
201            task_id,
202            data: Some(data),
203            execution_time_ms,
204            error: None,
205        }
206    }
207
208    /// Create a failed task result.
209    pub fn failure(task_id: TaskId, error: String, execution_time_ms: u64) -> Self {
210        Self {
211            task_id,
212            data: None,
213            execution_time_ms,
214            error: Some(error),
215        }
216    }
217
218    /// Check if the result indicates success.
219    pub fn is_success(&self) -> bool {
220        self.error.is_none()
221    }
222
223    /// Check if the result indicates failure.
224    pub fn is_failure(&self) -> bool {
225        self.error.is_some()
226    }
227}
228
/// Task execution context with metadata.
///
/// [`TaskContext::new`] defaults `memory_limit` to 1 GiB and `num_cores`
/// to the OS-reported parallelism; both can be overridden through the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct TaskContext {
    /// Task identifier.
    pub task_id: TaskId,
    /// Worker identifier executing this task.
    pub worker_id: String,
    /// Total memory available (bytes).
    pub memory_limit: u64,
    /// Number of CPU cores available.
    pub num_cores: usize,
}
241
242impl TaskContext {
243    /// Create a new task context.
244    pub fn new(task_id: TaskId, worker_id: String) -> Self {
245        Self {
246            task_id,
247            worker_id,
248            memory_limit: 1024 * 1024 * 1024, // 1 GB default
249            num_cores: num_cpus(),
250        }
251    }
252
253    /// Set the memory limit.
254    pub fn with_memory_limit(mut self, limit: u64) -> Self {
255        self.memory_limit = limit;
256        self
257    }
258
259    /// Set the number of cores.
260    pub fn with_num_cores(mut self, cores: usize) -> Self {
261        self.num_cores = cores;
262        self
263    }
264}
265
/// Get the number of available CPU cores, falling back to 1 when the
/// OS query fails (e.g. unsupported platform).
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(n) => n.get(),
        Err(_) => 1,
    }
}
272
/// Task scheduler for managing task execution order.
///
/// NOTE: `pending` is consumed LIFO — `next_task` pops the most recently
/// added (or most recently requeued-for-retry) task first.
#[derive(Debug)]
pub struct TaskScheduler {
    /// Queue of pending tasks (popped from the back by `next_task`).
    pending: Vec<Task>,
    /// Currently running tasks.
    running: Vec<Task>,
    /// Completed tasks.
    completed: Vec<Task>,
    /// Failed tasks that exhausted their retry budget.
    failed: Vec<Task>,
}
285
286impl TaskScheduler {
287    /// Create a new task scheduler.
288    pub fn new() -> Self {
289        Self {
290            pending: Vec::new(),
291            running: Vec::new(),
292            completed: Vec::new(),
293            failed: Vec::new(),
294        }
295    }
296
297    /// Add a task to the scheduler.
298    pub fn add_task(&mut self, task: Task) {
299        self.pending.push(task);
300    }
301
302    /// Get the next pending task.
303    pub fn next_task(&mut self) -> Option<Task> {
304        self.pending.pop()
305    }
306
307    /// Mark a task as running.
308    pub fn mark_running(&mut self, mut task: Task, worker_id: String) {
309        task.mark_running(worker_id);
310        self.running.push(task);
311    }
312
313    /// Mark a task as completed.
314    pub fn mark_completed(&mut self, task_id: TaskId) -> Result<()> {
315        if let Some(pos) = self.running.iter().position(|t| t.id == task_id) {
316            let mut task = self.running.remove(pos);
317            task.mark_completed();
318            self.completed.push(task);
319            Ok(())
320        } else {
321            Err(DistributedError::coordinator(format!(
322                "Task {} not found in running tasks",
323                task_id
324            )))
325        }
326    }
327
328    /// Mark a task as failed and potentially retry.
329    pub fn mark_failed(&mut self, task_id: TaskId) -> Result<()> {
330        if let Some(pos) = self.running.iter().position(|t| t.id == task_id) {
331            let mut task = self.running.remove(pos);
332            task.mark_failed();
333
334            if task.can_retry() {
335                task.reset_for_retry();
336                self.pending.push(task);
337            } else {
338                self.failed.push(task);
339            }
340            Ok(())
341        } else {
342            Err(DistributedError::coordinator(format!(
343                "Task {} not found in running tasks",
344                task_id
345            )))
346        }
347    }
348
349    /// Get the number of pending tasks.
350    pub fn pending_count(&self) -> usize {
351        self.pending.len()
352    }
353
354    /// Get the number of running tasks.
355    pub fn running_count(&self) -> usize {
356        self.running.len()
357    }
358
359    /// Get the number of completed tasks.
360    pub fn completed_count(&self) -> usize {
361        self.completed.len()
362    }
363
364    /// Get the number of failed tasks.
365    pub fn failed_count(&self) -> usize {
366        self.failed.len()
367    }
368
369    /// Check if all tasks are complete.
370    pub fn is_complete(&self) -> bool {
371        self.pending.is_empty() && self.running.is_empty()
372    }
373}
374
375impl Default for TaskScheduler {
376    fn default() -> Self {
377        Self::new()
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for the filter tasks used throughout the tests.
    fn filter_task(id: u64, partition: u64, expression: &str) -> Task {
        Task::new(
            TaskId(id),
            PartitionId(partition),
            TaskOperation::Filter {
                expression: expression.to_string(),
            },
        )
    }

    #[test]
    fn test_task_creation() {
        let task = filter_task(1, 0, "value > 10");

        assert_eq!(task.id, TaskId(1));
        assert_eq!(task.partition_id, PartitionId(0));
        assert_eq!(task.status, TaskStatus::Pending);
        assert!(task.worker_id.is_none());
    }

    #[test]
    fn test_task_lifecycle() {
        let mut task = filter_task(1, 0, "value > 10");

        // Pending -> Running assigns the worker.
        task.mark_running("worker-1".to_string());
        assert_eq!(task.status, TaskStatus::Running);
        assert_eq!(task.worker_id, Some("worker-1".to_string()));

        // Running -> Completed.
        task.mark_completed();
        assert_eq!(task.status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_retry() {
        let mut task = filter_task(1, 0, "value > 10");
        task.max_retries = 2;

        // Each failure consumes one retry until the budget is exhausted.
        assert!(task.can_retry());
        task.mark_failed();
        assert_eq!(task.retry_count, 1);
        assert!(task.can_retry());

        task.mark_failed();
        assert_eq!(task.retry_count, 2);
        assert!(!task.can_retry());
    }

    #[test]
    fn test_task_scheduler() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let mut scheduler = TaskScheduler::new();

        scheduler.add_task(filter_task(1, 0, "value > 10"));
        scheduler.add_task(filter_task(2, 1, "value < 100"));

        assert_eq!(scheduler.pending_count(), 2);
        assert_eq!(scheduler.running_count(), 0);

        // `next_task` pops LIFO, so this yields TaskId(2).
        let task = scheduler
            .next_task()
            .ok_or_else(|| Box::<dyn std::error::Error>::from("should have task"))?;
        scheduler.mark_running(task, "worker-1".to_string());

        assert_eq!(scheduler.pending_count(), 1);
        assert_eq!(scheduler.running_count(), 1);

        scheduler.mark_completed(TaskId(2))?;

        assert_eq!(scheduler.running_count(), 0);
        assert_eq!(scheduler.completed_count(), 1);
        Ok(())
    }

    #[test]
    fn test_task_context() {
        let ctx = TaskContext::new(TaskId(1), "worker-1".to_string())
            .with_memory_limit(2 * 1024 * 1024 * 1024)
            .with_num_cores(4);

        assert_eq!(ctx.task_id, TaskId(1));
        assert_eq!(ctx.worker_id, "worker-1");
        assert_eq!(ctx.memory_limit, 2 * 1024 * 1024 * 1024);
        assert_eq!(ctx.num_cores, 4);
    }
}