use crate::error::{DistributedError, Result};
use arrow::record_batch::RecordBatch;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
/// Identifier for a [`Task`]; a newtype over `u64`.
///
/// Displays as `Task(<n>)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TaskId(pub u64);
impl fmt::Display for TaskId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Task({})", self.0)
}
}
/// Identifier for a partition; a newtype over `u64`.
///
/// Each [`Task`] carries one as its `partition_id`. Displays as `Partition(<n>)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PartitionId(pub u64);
impl fmt::Display for PartitionId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Partition({})", self.0)
}
}
/// Lifecycle state of a [`Task`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    /// Initial state set by `Task::new`, and the state restored by
    /// `Task::reset_for_retry`.
    Pending,
    /// Set by `Task::mark_running` together with the worker id.
    Running,
    /// Set by `Task::mark_completed`.
    Completed,
    /// Set by `Task::mark_failed`.
    Failed,
    /// Set by `Task::mark_cancelled`.
    Cancelled,
}
impl fmt::Display for TaskStatus {
    /// Writes the variant name verbatim (for example `Pending`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first so the formatter is touched in exactly one place.
        let label = match *self {
            Self::Pending => "Pending",
            Self::Running => "Running",
            Self::Completed => "Completed",
            Self::Failed => "Failed",
            Self::Cancelled => "Cancelled",
        };
        f.write_str(label)
    }
}
/// The operation a [`Task`] is asked to perform, with its parameters.
///
/// This file only stores and (de)serializes these values; nothing here
/// validates or interprets the fields, so the per-field notes below are
/// assumptions based on naming unless stated otherwise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskOperation {
    /// Filter using a textual expression (syntax not defined in this file).
    Filter { expression: String },
    /// Calculate an index identified by `index_type`.
    CalculateIndex {
        index_type: String,
        /// Presumably the band positions the index is computed from — confirm
        /// against the executor.
        bands: Vec<usize>,
    },
    /// Reproject to the coordinate system named by `target_epsg`
    /// (presumably an EPSG code — confirm).
    Reproject { target_epsg: i32 },
    /// Resample to `width` x `height` using the named `method`.
    Resample {
        width: usize,
        height: usize,
        method: String,
    },
    /// Clip to the rectangle described by the min/max corner coordinates.
    Clip {
        min_x: f64,
        min_y: f64,
        max_x: f64,
        max_y: f64,
    },
    /// Convolve with a kernel.
    ///
    /// NOTE(review): `kernel` presumably holds `kernel_width * kernel_height`
    /// values; the length is not checked here.
    Convolve {
        kernel: Vec<f64>,
        kernel_width: usize,
        kernel_height: usize,
    },
    /// Operation identified by `name`, with opaque string `params`.
    Custom { name: String, params: String },
}
/// A unit of work: one [`TaskOperation`] bound to one partition, plus the
/// bookkeeping needed to track its progress and retries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Identifier of this task.
    pub id: TaskId,
    /// Identifier of the partition this task is associated with.
    pub partition_id: PartitionId,
    /// Operation to perform.
    pub operation: TaskOperation,
    /// Current lifecycle state; starts as [`TaskStatus::Pending`].
    pub status: TaskStatus,
    /// Worker recorded by `mark_running`; `None` until then and again after
    /// `reset_for_retry`.
    pub worker_id: Option<String>,
    /// Number of times `mark_failed` has been called on this task.
    pub retry_count: u32,
    /// Limit that `can_retry` compares `retry_count` against (`Task::new` uses 3).
    pub max_retries: u32,
}
impl Task {
    /// Retry budget given to tasks created through [`Task::new`].
    const DEFAULT_MAX_RETRIES: u32 = 3;

    /// Creates a pending task with no assigned worker, zero recorded
    /// failures and the default retry budget.
    pub fn new(id: TaskId, partition_id: PartitionId, operation: TaskOperation) -> Self {
        Self {
            status: TaskStatus::Pending,
            worker_id: None,
            retry_count: 0,
            max_retries: Self::DEFAULT_MAX_RETRIES,
            id,
            partition_id,
            operation,
        }
    }

    /// Returns `true` while fewer failures have been recorded than
    /// `max_retries` allows.
    pub fn can_retry(&self) -> bool {
        self.max_retries > self.retry_count
    }

    /// Moves the task to [`TaskStatus::Running`] and records which worker took it.
    pub fn mark_running(&mut self, worker_id: String) {
        self.worker_id = Some(worker_id);
        self.status = TaskStatus::Running;
    }

    /// Moves the task to [`TaskStatus::Completed`].
    pub fn mark_completed(&mut self) {
        self.status = TaskStatus::Completed;
    }

    /// Moves the task to [`TaskStatus::Failed`] and counts the failure
    /// against the retry budget.
    pub fn mark_failed(&mut self) {
        self.status = TaskStatus::Failed;
        self.retry_count += 1;
    }

    /// Moves the task to [`TaskStatus::Cancelled`].
    pub fn mark_cancelled(&mut self) {
        self.status = TaskStatus::Cancelled;
    }

    /// Returns the task to [`TaskStatus::Pending`] and forgets the previous
    /// worker. `retry_count` is left untouched so the budget keeps shrinking.
    pub fn reset_for_retry(&mut self) {
        self.worker_id = None;
        self.status = TaskStatus::Pending;
    }
}
/// Outcome reported for a single task execution.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Task this result belongs to.
    pub task_id: TaskId,
    /// Output batch; populated by `TaskResult::success`, `None` for
    /// `TaskResult::failure`.
    pub data: Option<Arc<RecordBatch>>,
    /// Execution time in milliseconds, as supplied by whoever built the result.
    pub execution_time_ms: u64,
    /// Error description; populated by `TaskResult::failure`, `None` for
    /// `TaskResult::success`.
    pub error: Option<String>,
}
impl TaskResult {
    /// Builds a successful result carrying `data` and no error.
    pub fn success(task_id: TaskId, data: Arc<RecordBatch>, execution_time_ms: u64) -> Self {
        let data = Some(data);
        Self {
            task_id,
            data,
            execution_time_ms,
            error: None,
        }
    }

    /// Builds a failed result carrying `error` and no data.
    pub fn failure(task_id: TaskId, error: String, execution_time_ms: u64) -> Self {
        let error = Some(error);
        Self {
            task_id,
            data: None,
            execution_time_ms,
            error,
        }
    }

    /// Returns `true` when no error is recorded. Only `error` is inspected;
    /// `data` plays no part in the decision.
    pub fn is_success(&self) -> bool {
        matches!(self.error, None)
    }

    /// Returns `true` when an error is recorded; always the opposite of
    /// [`TaskResult::is_success`].
    pub fn is_failure(&self) -> bool {
        !self.is_success()
    }
}
/// Execution settings for one task on one worker.
#[derive(Debug, Clone)]
pub struct TaskContext {
    /// Task these settings apply to.
    pub task_id: TaskId,
    /// Worker these settings apply to.
    pub worker_id: String,
    /// Memory limit; `TaskContext::new` sets `1024 * 1024 * 1024`
    /// (presumably bytes, i.e. 1 GiB — confirm with consumers of this field).
    pub memory_limit: u64,
    /// Number of cores; `TaskContext::new` takes it from the machine's
    /// available parallelism.
    pub num_cores: usize,
}
impl TaskContext {
    /// Memory limit used until `with_memory_limit` overrides it.
    const DEFAULT_MEMORY_LIMIT: u64 = 1024 * 1024 * 1024;

    /// Creates a context with the default memory limit and a core count
    /// taken from `num_cpus()`.
    pub fn new(task_id: TaskId, worker_id: String) -> Self {
        let num_cores = num_cpus();
        Self {
            task_id,
            worker_id,
            memory_limit: Self::DEFAULT_MEMORY_LIMIT,
            num_cores,
        }
    }

    /// Builder-style setter: returns the context with `memory_limit` replaced.
    pub fn with_memory_limit(self, limit: u64) -> Self {
        Self {
            memory_limit: limit,
            ..self
        }
    }

    /// Builder-style setter: returns the context with `num_cores` replaced.
    pub fn with_num_cores(self, cores: usize) -> Self {
        Self {
            num_cores: cores,
            ..self
        }
    }
}
/// Returns the parallelism reported by the standard library, or `1` when
/// that query fails.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        // The query can fail; fall back to a single core.
        Err(_) => 1,
    }
}
/// In-memory bookkeeping of tasks, grouped by where they are in their lifecycle.
#[derive(Debug)]
pub struct TaskScheduler {
    /// Tasks waiting to be handed out; `next_task` pops from the back.
    pending: Vec<Task>,
    /// Tasks registered through `mark_running`.
    running: Vec<Task>,
    /// Tasks moved out of `running` by `mark_completed`.
    completed: Vec<Task>,
    /// Tasks that failed with no retries left.
    failed: Vec<Task>,
}
impl TaskScheduler {
    /// Creates a scheduler with no tasks in any queue.
    pub fn new() -> Self {
        Self {
            pending: Vec::new(),
            running: Vec::new(),
            completed: Vec::new(),
            failed: Vec::new(),
        }
    }

    /// Queues `task` at the back of the pending list.
    pub fn add_task(&mut self, task: Task) {
        self.pending.push(task);
    }

    /// Removes and returns the most recently queued pending task, or `None`
    /// when nothing is pending.
    ///
    /// The pending list is used as a stack (LIFO): the task added last is
    /// handed out first, and a task re-queued by [`TaskScheduler::mark_failed`]
    /// is therefore the next one returned.
    pub fn next_task(&mut self) -> Option<Task> {
        self.pending.pop()
    }

    /// Marks `task` as running on `worker_id` and tracks it in the running list.
    ///
    /// The task is accepted as given; it is not checked against the pending list.
    pub fn mark_running(&mut self, mut task: Task, worker_id: String) {
        task.mark_running(worker_id);
        self.running.push(task);
    }

    /// Moves the running task `task_id` to the completed list.
    ///
    /// # Errors
    ///
    /// Returns a coordinator error when no running task has that id.
    pub fn mark_completed(&mut self, task_id: TaskId) -> Result<()> {
        let mut task = self.take_running(task_id)?;
        task.mark_completed();
        self.completed.push(task);
        Ok(())
    }

    /// Records a failure for the running task `task_id`.
    ///
    /// If the task still has retries left after counting this failure it is
    /// reset and pushed back onto the pending list; otherwise it is moved to
    /// the failed list.
    ///
    /// # Errors
    ///
    /// Returns a coordinator error when no running task has that id.
    pub fn mark_failed(&mut self, task_id: TaskId) -> Result<()> {
        let mut task = self.take_running(task_id)?;
        task.mark_failed();
        if task.can_retry() {
            task.reset_for_retry();
            self.pending.push(task);
        } else {
            self.failed.push(task);
        }
        Ok(())
    }

    /// Number of tasks waiting to be handed out.
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Number of tasks currently tracked as running.
    pub fn running_count(&self) -> usize {
        self.running.len()
    }

    /// Number of tasks that completed.
    pub fn completed_count(&self) -> usize {
        self.completed.len()
    }

    /// Number of tasks that failed with no retries left.
    pub fn failed_count(&self) -> usize {
        self.failed.len()
    }

    /// Returns `true` when nothing is pending and nothing is running.
    /// A scheduler that never received a task also counts as complete.
    pub fn is_complete(&self) -> bool {
        self.pending.is_empty() && self.running.is_empty()
    }

    /// Removes the running task with `task_id` and returns it, keeping the
    /// order of the remaining running tasks.
    fn take_running(&mut self, task_id: TaskId) -> Result<Task> {
        match self.running.iter().position(|task| task.id == task_id) {
            Some(index) => Ok(self.running.remove(index)),
            // `TaskId` already displays as `Task(<n>)`, so the message must not
            // prefix it with another "Task".
            None => Err(DistributedError::coordinator(format!(
                "{} not found in running tasks",
                task_id
            ))),
        }
    }
}
impl Default for TaskScheduler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh `Filter` task; the tests only need some valid operation.
    fn filter_task(id: u64, partition: u64, expression: &str) -> Task {
        Task::new(
            TaskId(id),
            PartitionId(partition),
            TaskOperation::Filter {
                expression: expression.to_string(),
            },
        )
    }

    /// A new task starts pending, unassigned, and keeps the ids it was given.
    #[test]
    fn test_task_creation() {
        let task = filter_task(1, 0, "value > 10");

        assert_eq!(task.id, TaskId(1));
        assert_eq!(task.partition_id, PartitionId(0));
        assert_eq!(task.status, TaskStatus::Pending);
        assert!(task.worker_id.is_none());
    }

    /// Pending -> Running (with worker recorded) -> Completed.
    #[test]
    fn test_task_lifecycle() {
        let mut task = filter_task(1, 0, "value > 10");

        task.mark_running("worker-1".to_string());
        assert_eq!(task.status, TaskStatus::Running);
        assert_eq!(task.worker_id, Some("worker-1".to_string()));

        task.mark_completed();
        assert_eq!(task.status, TaskStatus::Completed);
    }

    /// Each failure consumes one retry until the budget is exhausted.
    #[test]
    fn test_task_retry() {
        let mut task = filter_task(1, 0, "value > 10");
        task.max_retries = 2;
        assert!(task.can_retry());

        task.mark_failed();
        assert_eq!(task.retry_count, 1);
        assert!(task.can_retry());

        task.mark_failed();
        assert_eq!(task.retry_count, 2);
        assert!(!task.can_retry());
    }

    /// Tasks move pending -> running -> completed through the scheduler.
    #[test]
    fn test_task_scheduler() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let mut scheduler = TaskScheduler::new();
        scheduler.add_task(filter_task(1, 0, "value > 10"));
        scheduler.add_task(filter_task(2, 1, "value < 100"));
        assert_eq!(scheduler.pending_count(), 2);
        assert_eq!(scheduler.running_count(), 0);

        let task = scheduler.next_task().ok_or("should have task")?;
        scheduler.mark_running(task, "worker-1".to_string());
        assert_eq!(scheduler.pending_count(), 1);
        assert_eq!(scheduler.running_count(), 1);

        // `next_task` pops the most recently added task, so task 2 is the one running.
        scheduler.mark_completed(TaskId(2))?;
        assert_eq!(scheduler.running_count(), 0);
        assert_eq!(scheduler.completed_count(), 1);
        Ok(())
    }

    /// Builder setters override the defaults chosen by `TaskContext::new`.
    #[test]
    fn test_task_context() {
        let two_gib = 2 * 1024 * 1024 * 1024;
        let ctx = TaskContext::new(TaskId(1), "worker-1".to_string())
            .with_memory_limit(two_gib)
            .with_num_cores(4);

        assert_eq!(ctx.task_id, TaskId(1));
        assert_eq!(ctx.worker_id, "worker-1");
        assert_eq!(ctx.memory_limit, two_gib);
        assert_eq!(ctx.num_cores, 4);
    }
}