// oxigdal_distributed/worker.rs

//! Worker node implementation for distributed processing.
//!
//! This module implements worker nodes that execute geospatial processing tasks
//! assigned by the coordinator.

use crate::error::{DistributedError, Result};
use crate::task::{Task, TaskContext, TaskId, TaskOperation, TaskResult};
use arrow::record_batch::RecordBatch;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Instant;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};

/// Worker node configuration.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Unique worker identifier; also sent as the heartbeat payload.
    pub worker_id: String,
    /// Maximum number of tasks this worker will run concurrently.
    pub max_concurrent_tasks: usize,
    /// Memory limit in bytes, forwarded to each task's `TaskContext`.
    pub memory_limit: u64,
    /// Number of CPU cores advertised to each task's `TaskContext`.
    pub num_cores: usize,
    /// Interval between heartbeat messages, in seconds.
    pub heartbeat_interval_secs: u64,
}

31impl WorkerConfig {
32    /// Create a new worker configuration.
33    pub fn new(worker_id: String) -> Self {
34        let num_cores = std::thread::available_parallelism()
35            .map(|n| n.get())
36            .unwrap_or(1);
37
38        Self {
39            worker_id,
40            max_concurrent_tasks: num_cores,
41            memory_limit: 4 * 1024 * 1024 * 1024, // 4 GB default
42            num_cores,
43            heartbeat_interval_secs: 30,
44        }
45    }
46
47    /// Set the maximum number of concurrent tasks.
48    pub fn with_max_concurrent_tasks(mut self, max: usize) -> Self {
49        self.max_concurrent_tasks = max;
50        self
51    }
52
53    /// Set the memory limit.
54    pub fn with_memory_limit(mut self, limit: u64) -> Self {
55        self.memory_limit = limit;
56        self
57    }
58
59    /// Set the number of cores.
60    pub fn with_num_cores(mut self, cores: usize) -> Self {
61        self.num_cores = cores;
62        self
63    }
64}
65
/// Worker node status.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkerStatus {
    /// Worker is idle and ready to accept tasks.
    Idle,
    /// Worker is currently executing at least one task.
    Busy,
    /// Worker has begun graceful shutdown.
    ShuttingDown,
    /// Worker is offline; also reported by `Worker::status` when the
    /// status lock is poisoned.
    Offline,
}

/// Worker resource metrics.
///
/// Counters are cumulative over the lifetime of the worker.
#[derive(Debug, Clone, Default)]
pub struct WorkerMetrics {
    /// Total tasks executed (successes + failures).
    pub tasks_executed: u64,
    /// Total tasks that completed successfully.
    pub tasks_succeeded: u64,
    /// Total tasks that failed.
    pub tasks_failed: u64,
    /// Total execution time across all tasks, in milliseconds.
    pub total_execution_time_ms: u64,
    /// Current memory usage in bytes.
    // NOTE(review): never written by this module — presumably populated
    // elsewhere; confirm before relying on it.
    pub memory_usage: u64,
    /// Number of active tasks.
    // NOTE(review): also never written here; `Worker` tracks active tasks
    // via its `running_tasks` map instead — confirm intended use.
    pub active_tasks: u64,
}

96impl WorkerMetrics {
97    /// Record a successful task execution.
98    pub fn record_success(&mut self, execution_time_ms: u64) {
99        self.tasks_executed += 1;
100        self.tasks_succeeded += 1;
101        self.total_execution_time_ms += execution_time_ms;
102    }
103
104    /// Record a failed task execution.
105    pub fn record_failure(&mut self, execution_time_ms: u64) {
106        self.tasks_executed += 1;
107        self.tasks_failed += 1;
108        self.total_execution_time_ms += execution_time_ms;
109    }
110
111    /// Get the success rate.
112    pub fn success_rate(&self) -> f64 {
113        if self.tasks_executed == 0 {
114            0.0
115        } else {
116            self.tasks_succeeded as f64 / self.tasks_executed as f64
117        }
118    }
119
120    /// Get the average execution time.
121    pub fn avg_execution_time_ms(&self) -> f64 {
122        if self.tasks_executed == 0 {
123            0.0
124        } else {
125            self.total_execution_time_ms as f64 / self.tasks_executed as f64
126        }
127    }
128}
129
/// Worker node for executing distributed tasks.
///
/// Shared state lives behind `Arc` so it can be handed to spawned background
/// tasks (e.g. the heartbeat loop).
pub struct Worker {
    /// Worker configuration (immutable after construction).
    config: WorkerConfig,
    /// Current lifecycle status.
    status: Arc<RwLock<WorkerStatus>>,
    /// Cumulative execution metrics.
    metrics: Arc<RwLock<WorkerMetrics>>,
    /// Start times of currently running tasks, keyed by task id.
    running_tasks: Arc<RwLock<HashMap<TaskId, Instant>>>,
    /// Cooperative shutdown flag, checked by task execution and heartbeat.
    shutdown: Arc<AtomicBool>,
}

144impl Worker {
145    /// Create a new worker.
146    pub fn new(config: WorkerConfig) -> Self {
147        Self {
148            config,
149            status: Arc::new(RwLock::new(WorkerStatus::Idle)),
150            metrics: Arc::new(RwLock::new(WorkerMetrics::default())),
151            running_tasks: Arc::new(RwLock::new(HashMap::new())),
152            shutdown: Arc::new(AtomicBool::new(false)),
153        }
154    }
155
156    /// Get the worker ID.
157    pub fn worker_id(&self) -> &str {
158        &self.config.worker_id
159    }
160
161    /// Get the current status.
162    pub fn status(&self) -> WorkerStatus {
163        self.status.read().map_or(WorkerStatus::Offline, |s| *s)
164    }
165
166    /// Get the current metrics.
167    pub fn metrics(&self) -> WorkerMetrics {
168        self.metrics
169            .read()
170            .map_or_else(|_| WorkerMetrics::default(), |m| m.clone())
171    }
172
173    /// Check if the worker is available for new tasks.
174    pub fn is_available(&self) -> bool {
175        let running_count = self.running_tasks.read().map_or(0, |r| r.len());
176        running_count < self.config.max_concurrent_tasks
177            && self.status() == WorkerStatus::Idle
178            && !self.shutdown.load(Ordering::SeqCst)
179    }
180
181    /// Execute a task.
182    pub async fn execute_task(&self, task: Task, data: Arc<RecordBatch>) -> Result<TaskResult> {
183        // Check if shutdown was requested
184        if self.shutdown.load(Ordering::SeqCst) {
185            return Err(DistributedError::worker_task_failure(
186                "Worker is shutting down",
187            ));
188        }
189
190        // Update status
191        {
192            let mut status = self.status.write().map_err(|_| {
193                DistributedError::worker_task_failure("Failed to acquire status lock")
194            })?;
195            *status = WorkerStatus::Busy;
196        }
197
198        // Record task start
199        {
200            let mut running = self.running_tasks.write().map_err(|_| {
201                DistributedError::worker_task_failure("Failed to acquire running tasks lock")
202            })?;
203            running.insert(task.id, Instant::now());
204        }
205
206        // Create task context
207        let context = TaskContext::new(task.id, self.config.worker_id.clone())
208            .with_memory_limit(self.config.memory_limit)
209            .with_num_cores(self.config.num_cores);
210
211        info!(
212            "Worker {} executing task {:?}",
213            self.config.worker_id, task.id
214        );
215
216        let start = Instant::now();
217
218        // Execute the task operation
219        let result = self
220            .execute_operation(&task.operation, data, &context)
221            .await;
222
223        let execution_time_ms = start.elapsed().as_millis() as u64;
224
225        // Remove from running tasks
226        {
227            let mut running = self.running_tasks.write().map_err(|_| {
228                DistributedError::worker_task_failure("Failed to acquire running tasks lock")
229            })?;
230            running.remove(&task.id);
231        }
232
233        // Update metrics and status
234        {
235            let mut metrics = self.metrics.write().map_err(|_| {
236                DistributedError::worker_task_failure("Failed to acquire metrics lock")
237            })?;
238
239            match &result {
240                Ok(batch) => {
241                    metrics.record_success(execution_time_ms);
242                    info!(
243                        "Worker {} completed task {:?} in {}ms",
244                        self.config.worker_id, task.id, execution_time_ms
245                    );
246
247                    let task_result =
248                        TaskResult::success(task.id, batch.clone(), execution_time_ms);
249
250                    // Update status back to idle if no more tasks
251                    if self.running_tasks.read().map_or(true, |r| r.is_empty()) {
252                        if let Ok(mut status) = self.status.write() {
253                            *status = WorkerStatus::Idle;
254                        }
255                    }
256
257                    Ok(task_result)
258                }
259                Err(e) => {
260                    metrics.record_failure(execution_time_ms);
261                    error!(
262                        "Worker {} failed task {:?}: {}",
263                        self.config.worker_id, task.id, e
264                    );
265
266                    let task_result =
267                        TaskResult::failure(task.id, e.to_string(), execution_time_ms);
268
269                    // Update status back to idle if no more tasks
270                    if self.running_tasks.read().map_or(true, |r| r.is_empty()) {
271                        if let Ok(mut status) = self.status.write() {
272                            *status = WorkerStatus::Idle;
273                        }
274                    }
275
276                    Ok(task_result)
277                }
278            }
279        }
280    }
281
282    /// Execute a specific operation.
283    async fn execute_operation(
284        &self,
285        operation: &TaskOperation,
286        data: Arc<RecordBatch>,
287        _context: &TaskContext,
288    ) -> Result<Arc<RecordBatch>> {
289        match operation {
290            TaskOperation::Filter { expression } => {
291                debug!("Applying filter: {}", expression);
292                // Placeholder: In real implementation, apply filter using Arrow compute
293                Ok(data)
294            }
295            TaskOperation::CalculateIndex { index_type, bands } => {
296                debug!("Calculating index: {} with bands {:?}", index_type, bands);
297                // Placeholder: In real implementation, calculate the index
298                Ok(data)
299            }
300            TaskOperation::Reproject { target_epsg } => {
301                debug!("Reprojecting to EPSG:{}", target_epsg);
302                // Placeholder: In real implementation, reproject using oxigdal-proj
303                Ok(data)
304            }
305            TaskOperation::Resample {
306                width,
307                height,
308                method,
309            } => {
310                debug!("Resampling to {}x{} using {}", width, height, method);
311                // Placeholder: In real implementation, resample the raster
312                Ok(data)
313            }
314            TaskOperation::Clip {
315                min_x,
316                min_y,
317                max_x,
318                max_y,
319            } => {
320                debug!(
321                    "Clipping to bbox: [{}, {}, {}, {}]",
322                    min_x, min_y, max_x, max_y
323                );
324                // Placeholder: In real implementation, clip to bbox
325                Ok(data)
326            }
327            TaskOperation::Convolve {
328                kernel,
329                kernel_width,
330                kernel_height,
331            } => {
332                debug!(
333                    "Applying convolution with {}x{} kernel",
334                    kernel_width, kernel_height
335                );
336                // Placeholder: In real implementation, apply convolution
337                let _ = kernel; // Suppress unused warning
338                Ok(data)
339            }
340            TaskOperation::Custom { name, params } => {
341                debug!(
342                    "Executing custom operation: {} with params: {}",
343                    name, params
344                );
345                // Placeholder: In real implementation, execute custom operation
346                Ok(data)
347            }
348        }
349    }
350
351    /// Start the worker's heartbeat loop.
352    pub async fn start_heartbeat(&self, heartbeat_tx: mpsc::Sender<String>) -> Result<()> {
353        let worker_id = self.config.worker_id.clone();
354        let interval = self.config.heartbeat_interval_secs;
355        let shutdown = self.shutdown.clone();
356
357        tokio::spawn(async move {
358            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(interval));
359
360            loop {
361                interval.tick().await;
362
363                if shutdown.load(Ordering::SeqCst) {
364                    debug!("Worker {} heartbeat loop shutting down", worker_id);
365                    break;
366                }
367
368                if let Err(e) = heartbeat_tx.send(worker_id.clone()).await {
369                    warn!("Failed to send heartbeat for worker {}: {}", worker_id, e);
370                    break;
371                }
372
373                debug!("Worker {} sent heartbeat", worker_id);
374            }
375        });
376
377        Ok(())
378    }
379
380    /// Initiate graceful shutdown.
381    pub async fn shutdown(&self) -> Result<()> {
382        info!("Worker {} initiating shutdown", self.config.worker_id);
383
384        self.shutdown.store(true, Ordering::SeqCst);
385
386        // Update status
387        {
388            let mut status = self.status.write().map_err(|_| {
389                DistributedError::worker_task_failure("Failed to acquire status lock")
390            })?;
391            *status = WorkerStatus::ShuttingDown;
392        }
393
394        // Wait for running tasks to complete (with timeout)
395        let timeout = tokio::time::Duration::from_secs(30);
396        let start = Instant::now();
397
398        while start.elapsed() < timeout {
399            let running_count = self.running_tasks.read().map_or(0, |r| r.len());
400            if running_count == 0 {
401                break;
402            }
403            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
404        }
405
406        // Final status update
407        {
408            let mut status = self.status.write().map_err(|_| {
409                DistributedError::worker_task_failure("Failed to acquire status lock")
410            })?;
411            *status = WorkerStatus::Offline;
412        }
413
414        info!("Worker {} shutdown complete", self.config.worker_id);
415        Ok(())
416    }
417
418    /// Get health check information.
419    pub fn health_check(&self) -> WorkerHealthCheck {
420        let metrics = self.metrics();
421        let status = self.status();
422        let running_count = self.running_tasks.read().map_or(0, |r| r.len());
423
424        WorkerHealthCheck {
425            worker_id: self.config.worker_id.clone(),
426            status,
427            is_healthy: status != WorkerStatus::Offline,
428            active_tasks: running_count,
429            total_tasks_executed: metrics.tasks_executed,
430            success_rate: metrics.success_rate(),
431            avg_execution_time_ms: metrics.avg_execution_time_ms(),
432            memory_usage: metrics.memory_usage,
433        }
434    }
435}
436
/// Health check information for a worker.
///
/// A point-in-time snapshot produced by `Worker::health_check`; fields are
/// not updated after construction.
#[derive(Debug, Clone)]
pub struct WorkerHealthCheck {
    /// Worker identifier.
    pub worker_id: String,
    /// Status at the time of the check.
    pub status: WorkerStatus,
    /// Whether the worker is healthy (any status other than `Offline`).
    pub is_healthy: bool,
    /// Number of tasks running at the time of the check.
    pub active_tasks: usize,
    /// Total tasks executed over the worker's lifetime.
    pub total_tasks_executed: u64,
    /// Success rate (0.0 to 1.0; 0.0 when no tasks have executed).
    pub success_rate: f64,
    /// Average execution time in milliseconds (0.0 when no tasks executed).
    pub avg_execution_time_ms: f64,
    /// Current memory usage in bytes (copied from `WorkerMetrics`).
    pub memory_usage: u64,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::PartitionId;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};

    /// Build a one-column (`value: Int32`) batch holding the rows 1..=5.
    fn create_test_batch() -> std::result::Result<Arc<RecordBatch>, Box<dyn std::error::Error>> {
        let field = Field::new("value", DataType::Int32, false);
        let schema = Arc::new(Schema::new(vec![field]));
        let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;
        Ok(Arc::new(batch))
    }

    #[test]
    fn test_worker_config() {
        // Builder overrides stick; the worker id is taken verbatim.
        let cfg = WorkerConfig::new("worker-1".to_string())
            .with_max_concurrent_tasks(8)
            .with_memory_limit(8 * 1024 * 1024 * 1024);

        assert_eq!(cfg.worker_id, "worker-1");
        assert_eq!(cfg.max_concurrent_tasks, 8);
        assert_eq!(cfg.memory_limit, 8 * 1024 * 1024 * 1024);
    }

    #[test]
    fn test_worker_metrics() {
        // Two successes and one failure drive all counters and derived stats.
        let mut m = WorkerMetrics::default();
        m.record_success(100);
        m.record_success(200);
        m.record_failure(150);

        assert_eq!(m.tasks_executed, 3);
        assert_eq!(m.tasks_succeeded, 2);
        assert_eq!(m.tasks_failed, 1);
        assert_eq!(m.total_execution_time_ms, 450);
        assert_eq!(m.success_rate(), 2.0 / 3.0);
        assert_eq!(m.avg_execution_time_ms(), 150.0);
    }

    #[tokio::test]
    async fn test_worker_creation() {
        // A fresh worker starts idle and ready to accept tasks.
        let worker = Worker::new(WorkerConfig::new("worker-test".to_string()));

        assert_eq!(worker.worker_id(), "worker-test");
        assert_eq!(worker.status(), WorkerStatus::Idle);
        assert!(worker.is_available());
    }

    #[tokio::test]
    async fn test_worker_execute_task() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let worker = Worker::new(WorkerConfig::new("worker-test".to_string()));
        let filter = TaskOperation::Filter {
            expression: "value > 2".to_string(),
        };
        let task = Task::new(TaskId(1), PartitionId(0), filter);

        // Filter is a placeholder op, so execution must report success.
        let outcome = worker.execute_task(task, create_test_batch()?).await;
        assert!(outcome.is_ok());
        assert!(outcome?.is_success());
        Ok(())
    }

    #[tokio::test]
    async fn test_worker_health_check() {
        // A fresh worker reports healthy with no activity recorded.
        let worker = Worker::new(WorkerConfig::new("worker-test".to_string()));
        let health = worker.health_check();

        assert_eq!(health.worker_id, "worker-test");
        assert!(health.is_healthy);
        assert_eq!(health.active_tasks, 0);
        assert_eq!(health.total_tasks_executed, 0);
    }
}