1use crate::error::{DistributedError, Result};
7use crate::task::{Task, TaskContext, TaskId, TaskOperation, TaskResult};
8use arrow::record_batch::RecordBatch;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, RwLock};
12use std::time::Instant;
13use tokio::sync::mpsc;
14use tracing::{debug, error, info, warn};
15
/// Configuration for a single worker: identity, resource limits, and
/// heartbeat cadence.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Unique identifier for this worker.
    pub worker_id: String,
    /// Maximum number of tasks allowed to run concurrently.
    pub max_concurrent_tasks: usize,
    /// Memory budget passed to task contexts, in bytes.
    pub memory_limit: u64,
    /// Number of CPU cores advertised to task contexts.
    pub num_cores: usize,
    /// Interval between heartbeat messages, in seconds.
    pub heartbeat_interval_secs: u64,
}

impl WorkerConfig {
    /// Creates a configuration with defaults: concurrency and core count
    /// taken from available parallelism, a 4 GiB memory limit, and a
    /// 30-second heartbeat interval.
    pub fn new(worker_id: String) -> Self {
        // Fall back to a single core when parallelism cannot be queried.
        let num_cores = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);

        Self {
            worker_id,
            max_concurrent_tasks: num_cores,
            memory_limit: 4 * 1024 * 1024 * 1024, // 4 GiB
            num_cores,
            heartbeat_interval_secs: 30,
        }
    }

    /// Sets the maximum number of concurrently executing tasks.
    pub fn with_max_concurrent_tasks(mut self, max: usize) -> Self {
        self.max_concurrent_tasks = max;
        self
    }

    /// Sets the memory limit (in bytes) passed to task contexts.
    pub fn with_memory_limit(mut self, limit: u64) -> Self {
        self.memory_limit = limit;
        self
    }

    /// Sets the number of cores passed to task contexts.
    pub fn with_num_cores(mut self, cores: usize) -> Self {
        self.num_cores = cores;
        self
    }

    /// Sets the heartbeat interval, in seconds.
    ///
    /// Added for consistency with the other `with_*` builders so the
    /// heartbeat cadence can be configured fluently instead of via direct
    /// field access.
    pub fn with_heartbeat_interval_secs(mut self, secs: u64) -> Self {
        self.heartbeat_interval_secs = secs;
        self
    }
}
65
/// Lifecycle state of a worker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkerStatus {
    /// Ready to accept new tasks.
    Idle,
    /// Currently executing at least one task.
    Busy,
    /// Draining in-flight tasks; new work is rejected.
    ShuttingDown,
    /// No longer serving; also reported by `Worker::status()` when the
    /// status lock is poisoned.
    Offline,
}
78
/// Cumulative execution statistics for a worker.
#[derive(Debug, Clone, Default)]
pub struct WorkerMetrics {
    /// Total number of finished tasks (successes plus failures).
    pub tasks_executed: u64,
    /// Number of tasks that completed successfully.
    pub tasks_succeeded: u64,
    /// Number of tasks that ended in an error.
    pub tasks_failed: u64,
    /// Wall-clock time spent executing tasks, in milliseconds.
    pub total_execution_time_ms: u64,
    /// Memory usage in bytes (not updated anywhere in this module —
    /// TODO confirm the producer).
    pub memory_usage: u64,
    /// In-flight task count (not updated anywhere in this module —
    /// TODO confirm the producer).
    pub active_tasks: u64,
}

impl WorkerMetrics {
    /// Records one successful task completion and its duration.
    pub fn record_success(&mut self, execution_time_ms: u64) {
        self.tasks_succeeded += 1;
        self.finish(execution_time_ms);
    }

    /// Records one failed task completion and its duration.
    pub fn record_failure(&mut self, execution_time_ms: u64) {
        self.tasks_failed += 1;
        self.finish(execution_time_ms);
    }

    /// Bookkeeping shared by every finished task, whatever the outcome.
    fn finish(&mut self, execution_time_ms: u64) {
        self.tasks_executed += 1;
        self.total_execution_time_ms += execution_time_ms;
    }

    /// Fraction of executed tasks that succeeded; `0.0` before any task ran.
    pub fn success_rate(&self) -> f64 {
        match self.tasks_executed {
            0 => 0.0,
            n => self.tasks_succeeded as f64 / n as f64,
        }
    }

    /// Mean task duration in milliseconds; `0.0` before any task ran.
    pub fn avg_execution_time_ms(&self) -> f64 {
        match self.tasks_executed {
            0 => 0.0,
            n => self.total_execution_time_ms as f64 / n as f64,
        }
    }
}
129
/// A worker that executes tasks against Arrow record batches while tracking
/// its own status, metrics, and in-flight tasks.
pub struct Worker {
    /// Static configuration: identity, resource limits, heartbeat cadence.
    config: WorkerConfig,
    /// Current lifecycle state; shared so background loops can observe it.
    status: Arc<RwLock<WorkerStatus>>,
    /// Cumulative execution statistics.
    metrics: Arc<RwLock<WorkerMetrics>>,
    /// Start times of currently executing tasks, keyed by task id.
    running_tasks: Arc<RwLock<HashMap<TaskId, Instant>>>,
    /// Set once by `shutdown()`; polled by task execution and the heartbeat loop.
    shutdown: Arc<AtomicBool>,
}
143
impl Worker {
    /// Creates an idle worker from the given configuration.
    pub fn new(config: WorkerConfig) -> Self {
        Self {
            config,
            status: Arc::new(RwLock::new(WorkerStatus::Idle)),
            metrics: Arc::new(RwLock::new(WorkerMetrics::default())),
            running_tasks: Arc::new(RwLock::new(HashMap::new())),
            shutdown: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Returns this worker's identifier.
    pub fn worker_id(&self) -> &str {
        &self.config.worker_id
    }

    /// Returns the current status; a poisoned status lock is reported as
    /// `Offline`.
    pub fn status(&self) -> WorkerStatus {
        self.status.read().map_or(WorkerStatus::Offline, |s| *s)
    }

    /// Returns a snapshot of the metrics; a poisoned lock yields defaults.
    pub fn metrics(&self) -> WorkerMetrics {
        self.metrics
            .read()
            .map_or_else(|_| WorkerMetrics::default(), |m| m.clone())
    }

    /// Whether this worker can accept a new task.
    ///
    /// NOTE(review): this requires `status() == Idle`, but `execute_task`
    /// sets the status to `Busy` for the duration of any task, so a worker
    /// below its `max_concurrent_tasks` limit still reports unavailable
    /// while a single task runs — confirm this serialization is intended.
    pub fn is_available(&self) -> bool {
        let running_count = self.running_tasks.read().map_or(0, |r| r.len());
        running_count < self.config.max_concurrent_tasks
            && self.status() == WorkerStatus::Idle
            && !self.shutdown.load(Ordering::SeqCst)
    }

    /// Executes `task` against `data`, recording timing and metrics.
    ///
    /// Operation failures are reported inside the returned `TaskResult`
    /// (via `TaskResult::failure`), not as `Err`; `Err` is reserved for
    /// worker-level problems (shutdown in progress, poisoned locks).
    pub async fn execute_task(&self, task: Task, data: Arc<RecordBatch>) -> Result<TaskResult> {
        // Refuse new work once shutdown has begun.
        if self.shutdown.load(Ordering::SeqCst) {
            return Err(DistributedError::worker_task_failure(
                "Worker is shutting down",
            ));
        }

        // Mark the worker busy for the duration of the task.
        {
            let mut status = self.status.write().map_err(|_| {
                DistributedError::worker_task_failure("Failed to acquire status lock")
            })?;
            *status = WorkerStatus::Busy;
        }

        // Track the task's start time while it is in flight.
        {
            let mut running = self.running_tasks.write().map_err(|_| {
                DistributedError::worker_task_failure("Failed to acquire running tasks lock")
            })?;
            running.insert(task.id, Instant::now());
        }

        // The execution context carries this worker's resource limits.
        let context = TaskContext::new(task.id, self.config.worker_id.clone())
            .with_memory_limit(self.config.memory_limit)
            .with_num_cores(self.config.num_cores);

        info!(
            "Worker {} executing task {:?}",
            self.config.worker_id, task.id
        );

        let start = Instant::now();

        let result = self
            .execute_operation(&task.operation, data, &context)
            .await;

        let execution_time_ms = start.elapsed().as_millis() as u64;

        // The task is no longer in flight, whatever the outcome.
        {
            let mut running = self.running_tasks.write().map_err(|_| {
                DistributedError::worker_task_failure("Failed to acquire running tasks lock")
            })?;
            running.remove(&task.id);
        }

        // NOTE(review): the metrics write guard is held across the whole
        // match below, including the running_tasks/status lock acquisitions.
        // No deadlock with the current acquisition order, but keep that
        // order consistent if this is refactored.
        {
            let mut metrics = self.metrics.write().map_err(|_| {
                DistributedError::worker_task_failure("Failed to acquire metrics lock")
            })?;

            match &result {
                Ok(batch) => {
                    metrics.record_success(execution_time_ms);
                    info!(
                        "Worker {} completed task {:?} in {}ms",
                        self.config.worker_id, task.id, execution_time_ms
                    );

                    let task_result =
                        TaskResult::success(task.id, batch.clone(), execution_time_ms);

                    // Return to Idle once no tasks remain. A poisoned
                    // running_tasks lock is treated as "empty" here
                    // (map_or(true, ..)); a poisoned status lock is ignored.
                    if self.running_tasks.read().map_or(true, |r| r.is_empty()) {
                        if let Ok(mut status) = self.status.write() {
                            *status = WorkerStatus::Idle;
                        }
                    }

                    Ok(task_result)
                }
                Err(e) => {
                    metrics.record_failure(execution_time_ms);
                    error!(
                        "Worker {} failed task {:?}: {}",
                        self.config.worker_id, task.id, e
                    );

                    let task_result =
                        TaskResult::failure(task.id, e.to_string(), execution_time_ms);

                    // Same Idle transition as the success path.
                    if self.running_tasks.read().map_or(true, |r| r.is_empty()) {
                        if let Ok(mut status) = self.status.write() {
                            *status = WorkerStatus::Idle;
                        }
                    }

                    Ok(task_result)
                }
            }
        }
    }

    /// Applies `operation` to `data`.
    ///
    /// NOTE(review): every arm is currently a placeholder that logs the
    /// requested operation and returns the input batch unchanged; the
    /// actual transformations are not implemented yet.
    async fn execute_operation(
        &self,
        operation: &TaskOperation,
        data: Arc<RecordBatch>,
        _context: &TaskContext,
    ) -> Result<Arc<RecordBatch>> {
        match operation {
            TaskOperation::Filter { expression } => {
                debug!("Applying filter: {}", expression);
                Ok(data)
            }
            TaskOperation::CalculateIndex { index_type, bands } => {
                debug!("Calculating index: {} with bands {:?}", index_type, bands);
                Ok(data)
            }
            TaskOperation::Reproject { target_epsg } => {
                debug!("Reprojecting to EPSG:{}", target_epsg);
                Ok(data)
            }
            TaskOperation::Resample {
                width,
                height,
                method,
            } => {
                debug!("Resampling to {}x{} using {}", width, height, method);
                Ok(data)
            }
            TaskOperation::Clip {
                min_x,
                min_y,
                max_x,
                max_y,
            } => {
                debug!(
                    "Clipping to bbox: [{}, {}, {}, {}]",
                    min_x, min_y, max_x, max_y
                );
                Ok(data)
            }
            TaskOperation::Convolve {
                kernel,
                kernel_width,
                kernel_height,
            } => {
                debug!(
                    "Applying convolution with {}x{} kernel",
                    kernel_width, kernel_height
                );
                // Kernel values are unused until convolution is implemented.
                let _ = kernel;
                Ok(data)
            }
            TaskOperation::Custom { name, params } => {
                debug!(
                    "Executing custom operation: {} with params: {}",
                    name, params
                );
                Ok(data)
            }
        }
    }

    /// Spawns a background loop that sends this worker's id on
    /// `heartbeat_tx` every `heartbeat_interval_secs` seconds, stopping
    /// when shutdown is flagged or the receiver is dropped.
    ///
    /// NOTE(review): `tokio::time::interval`'s first tick completes
    /// immediately, so one heartbeat is sent right away.
    pub async fn start_heartbeat(&self, heartbeat_tx: mpsc::Sender<String>) -> Result<()> {
        let worker_id = self.config.worker_id.clone();
        let interval = self.config.heartbeat_interval_secs;
        let shutdown = self.shutdown.clone();

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(interval));

            loop {
                interval.tick().await;

                // Exit quietly once shutdown has been requested.
                if shutdown.load(Ordering::SeqCst) {
                    debug!("Worker {} heartbeat loop shutting down", worker_id);
                    break;
                }

                // A send error means the receiver is gone; stop the loop.
                if let Err(e) = heartbeat_tx.send(worker_id.clone()).await {
                    warn!("Failed to send heartbeat for worker {}: {}", worker_id, e);
                    break;
                }

                debug!("Worker {} sent heartbeat", worker_id);
            }
        });

        Ok(())
    }

    /// Initiates a graceful shutdown: sets the shutdown flag, waits up to
    /// 30 seconds for in-flight tasks to drain (polling every 100 ms),
    /// then marks the worker `Offline`.
    ///
    /// Returns an error only if the status lock is poisoned.
    pub async fn shutdown(&self) -> Result<()> {
        info!("Worker {} initiating shutdown", self.config.worker_id);

        self.shutdown.store(true, Ordering::SeqCst);

        {
            let mut status = self.status.write().map_err(|_| {
                DistributedError::worker_task_failure("Failed to acquire status lock")
            })?;
            *status = WorkerStatus::ShuttingDown;
        }

        // NOTE(review): the 30-second grace period is hard-coded; consider
        // lifting it into WorkerConfig.
        let timeout = tokio::time::Duration::from_secs(30);
        let start = Instant::now();

        while start.elapsed() < timeout {
            let running_count = self.running_tasks.read().map_or(0, |r| r.len());
            if running_count == 0 {
                break;
            }
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        }

        {
            let mut status = self.status.write().map_err(|_| {
                DistributedError::worker_task_failure("Failed to acquire status lock")
            })?;
            *status = WorkerStatus::Offline;
        }

        info!("Worker {} shutdown complete", self.config.worker_id);
        Ok(())
    }

    /// Builds a point-in-time health snapshot; "healthy" means any status
    /// other than `Offline`.
    pub fn health_check(&self) -> WorkerHealthCheck {
        let metrics = self.metrics();
        let status = self.status();
        let running_count = self.running_tasks.read().map_or(0, |r| r.len());

        WorkerHealthCheck {
            worker_id: self.config.worker_id.clone(),
            status,
            is_healthy: status != WorkerStatus::Offline,
            active_tasks: running_count,
            total_tasks_executed: metrics.tasks_executed,
            success_rate: metrics.success_rate(),
            avg_execution_time_ms: metrics.avg_execution_time_ms(),
            memory_usage: metrics.memory_usage,
        }
    }
}
436
/// Point-in-time health snapshot produced by `Worker::health_check`.
#[derive(Debug, Clone)]
pub struct WorkerHealthCheck {
    /// Identifier of the worker that produced this snapshot.
    pub worker_id: String,
    /// Status at snapshot time.
    pub status: WorkerStatus,
    /// `true` for any status other than `Offline`.
    pub is_healthy: bool,
    /// Number of tasks in flight at snapshot time.
    pub active_tasks: usize,
    /// Total tasks finished so far (successes plus failures).
    pub total_tasks_executed: u64,
    /// Fraction of executed tasks that succeeded (0.0 when none ran).
    pub success_rate: f64,
    /// Mean task duration in milliseconds (0.0 when none ran).
    pub avg_execution_time_ms: f64,
    /// Memory usage in bytes, copied from the metrics snapshot.
    pub memory_usage: u64,
}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::PartitionId;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};

    /// Builds a single-column ("value": non-nullable Int32) batch of 1..=5.
    fn create_test_batch() -> std::result::Result<Arc<RecordBatch>, Box<dyn std::error::Error>> {
        let field = Field::new("value", DataType::Int32, false);
        let schema = Arc::new(Schema::new(vec![field]));
        let column = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(column)])?;
        Ok(Arc::new(batch))
    }

    #[test]
    fn test_worker_config() {
        const EIGHT_GIB: u64 = 8 * 1024 * 1024 * 1024;

        let config = WorkerConfig::new("worker-1".to_string())
            .with_max_concurrent_tasks(8)
            .with_memory_limit(EIGHT_GIB);

        assert_eq!(config.worker_id, "worker-1");
        assert_eq!(config.max_concurrent_tasks, 8);
        assert_eq!(config.memory_limit, EIGHT_GIB);
    }

    #[test]
    fn test_worker_metrics() {
        let mut metrics = WorkerMetrics::default();
        for duration_ms in [100, 200] {
            metrics.record_success(duration_ms);
        }
        metrics.record_failure(150);

        assert_eq!(metrics.tasks_executed, 3);
        assert_eq!(metrics.tasks_succeeded, 2);
        assert_eq!(metrics.tasks_failed, 1);
        assert_eq!(metrics.total_execution_time_ms, 450);
        assert_eq!(metrics.success_rate(), 2.0 / 3.0);
        assert_eq!(metrics.avg_execution_time_ms(), 150.0);
    }

    #[tokio::test]
    async fn test_worker_creation() {
        let worker = Worker::new(WorkerConfig::new("worker-test".to_string()));

        assert_eq!(worker.worker_id(), "worker-test");
        assert_eq!(worker.status(), WorkerStatus::Idle);
        assert!(worker.is_available());
    }

    #[tokio::test]
    async fn test_worker_execute_task() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let worker = Worker::new(WorkerConfig::new("worker-test".to_string()));
        let operation = TaskOperation::Filter {
            expression: "value > 2".to_string(),
        };
        let task = Task::new(TaskId(1), PartitionId(0), operation);

        let outcome = worker.execute_task(task, create_test_batch()?).await;

        assert!(outcome.is_ok());
        assert!(outcome?.is_success());
        Ok(())
    }

    #[tokio::test]
    async fn test_worker_health_check() {
        let worker = Worker::new(WorkerConfig::new("worker-test".to_string()));
        let health = worker.health_check();

        assert_eq!(health.worker_id, "worker-test");
        assert!(health.is_healthy);
        assert_eq!(health.active_tasks, 0);
        assert_eq!(health.total_tasks_executed, 0);
    }
}