// a3s_lane/distributed.rs
1//! Distributed queue support for multi-machine parallel processing
2//!
3//! This module provides traits and implementations for distributed queue processing.
4//! The default implementation uses local multi-core parallelism, but users can
5//! implement the `DistributedQueue` trait for multi-machine distributed processing.
6
7use crate::error::Result;
8use crate::partition::{PartitionConfig, PartitionId, Partitioner};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
/// Worker identifier.
///
/// Opaque string; `LocalDistributedQueue` generates `"local-<uuid>"` values,
/// but distributed implementations may use any stable unique string.
pub type WorkerId = String;
16
/// Distributed command envelope for serialization across workers.
///
/// Wraps a command's JSON payload with the routing metadata (lane, partition)
/// and bookkeeping (retry count, creation time) needed to ship it between
/// processes or machines.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandEnvelope {
    /// Unique command ID
    pub id: String,
    /// Command type identifier (executors dispatch on this to interpret `payload`)
    pub command_type: String,
    /// Lane ID
    pub lane_id: String,
    /// Partition ID; must be < the queue's `num_partitions()` or `enqueue` errors
    pub partition_id: PartitionId,
    /// Serialized command payload
    pub payload: serde_json::Value,
    /// Retry count
    pub retry_count: u32,
    /// Created timestamp (UTC)
    pub created_at: chrono::DateTime<chrono::Utc>,
}
35
/// Result of command execution from a worker.
///
/// `std::result::Result` has built-in serde impls, so this struct round-trips
/// through JSON as-is (see the serialization test at the bottom of this file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandResult {
    /// Command ID (matches `CommandEnvelope::id`)
    pub command_id: String,
    /// Success payload, or an error message string on failure
    pub result: std::result::Result<serde_json::Value, String>,
    /// Worker that executed the command
    pub worker_id: WorkerId,
    /// Execution duration in milliseconds
    pub duration_ms: u64,
}
48
/// Distributed queue trait for multi-machine parallel processing
///
/// Implement this trait to enable distributed queue processing across multiple machines.
/// The default implementation (`LocalDistributedQueue`) uses local multi-core parallelism.
#[async_trait]
pub trait DistributedQueue: Send + Sync {
    /// Enqueue a command to be processed by a worker.
    ///
    /// Errors if `envelope.partition_id` is out of range (at least in the
    /// local implementation; remote impls should mirror that contract).
    async fn enqueue(&self, envelope: CommandEnvelope) -> Result<()>;

    /// Dequeue a command for processing (called by workers).
    ///
    /// `Ok(None)` means "nothing available right now"; callers (see
    /// `WorkerPool::start`) poll this in a loop, so implementations are
    /// expected to be non-blocking.
    async fn dequeue(&self, partition_id: PartitionId) -> Result<Option<CommandEnvelope>>;

    /// Report command completion back to the coordinator.
    async fn complete(&self, result: CommandResult) -> Result<()>;

    /// Get the number of partitions
    fn num_partitions(&self) -> usize;

    /// Get the worker ID for this instance
    fn worker_id(&self) -> &WorkerId;

    /// Check if this instance is a coordinator (can enqueue commands)
    fn is_coordinator(&self) -> bool;

    /// Check if this instance is a worker (can dequeue and execute commands)
    fn is_worker(&self) -> bool;
}
76
/// Local distributed queue implementation using multi-core parallelism
///
/// This is the default implementation that uses local channels for communication
/// between partitions. Each partition runs on a separate tokio task, enabling
/// efficient multi-core utilization.
pub struct LocalDistributedQueue {
    // Generated as "local-<uuid>" in `new`.
    worker_id: WorkerId,
    partition_config: PartitionConfig,
    partitioner: Arc<dyn Partitioner>,
    /// Channels for each partition (sender side); index == PartitionId
    partition_senders: Vec<mpsc::Sender<CommandEnvelope>>,
    /// Channels for each partition (receiver side, wrapped in mutex so the
    /// receiver handle can be shared across tasks)
    partition_receivers: Vec<Arc<tokio::sync::Mutex<mpsc::Receiver<CommandEnvelope>>>>,
    /// Channel for completed results (workers send, coordinator receives)
    result_sender: mpsc::Sender<CommandResult>,
    result_receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<CommandResult>>>,
}
94
95impl LocalDistributedQueue {
96    /// Create a new local distributed queue with the specified partition configuration
97    pub fn new(partition_config: PartitionConfig) -> Self {
98        let num_partitions = partition_config.num_partitions;
99        let partitioner = partition_config.create_partitioner();
100
101        let mut partition_senders = Vec::with_capacity(num_partitions);
102        let mut partition_receivers = Vec::with_capacity(num_partitions);
103
104        // Create channels for each partition
105        for _ in 0..num_partitions {
106            let (tx, rx) = mpsc::channel(1000); // Buffer size per partition
107            partition_senders.push(tx);
108            partition_receivers.push(Arc::new(tokio::sync::Mutex::new(rx)));
109        }
110
111        // Create result channel
112        let (result_tx, result_rx) = mpsc::channel(1000);
113
114        Self {
115            worker_id: format!("local-{}", uuid::Uuid::new_v4()),
116            partition_config,
117            partitioner,
118            partition_senders,
119            partition_receivers,
120            result_sender: result_tx,
121            result_receiver: Arc::new(tokio::sync::Mutex::new(result_rx)),
122        }
123    }
124
125    /// Create a local distributed queue that automatically uses all CPU cores
126    pub fn auto() -> Self {
127        Self::new(PartitionConfig::auto())
128    }
129
130    /// Get the partitioner
131    pub fn partitioner(&self) -> &Arc<dyn Partitioner> {
132        &self.partitioner
133    }
134
135    /// Get the partition configuration
136    pub fn partition_config(&self) -> &PartitionConfig {
137        &self.partition_config
138    }
139
140    /// Get the receiver for a specific partition (for worker tasks)
141    pub fn partition_receiver(
142        &self,
143        partition_id: PartitionId,
144    ) -> Option<Arc<tokio::sync::Mutex<mpsc::Receiver<CommandEnvelope>>>> {
145        self.partition_receivers.get(partition_id).cloned()
146    }
147
148    /// Get the result receiver (for coordinator to collect results)
149    pub fn result_receiver(&self) -> Arc<tokio::sync::Mutex<mpsc::Receiver<CommandResult>>> {
150        Arc::clone(&self.result_receiver)
151    }
152
153    /// Get the result sender (for workers to send results)
154    pub fn result_sender(&self) -> mpsc::Sender<CommandResult> {
155        self.result_sender.clone()
156    }
157}
158
159#[async_trait]
160impl DistributedQueue for LocalDistributedQueue {
161    async fn enqueue(&self, envelope: CommandEnvelope) -> Result<()> {
162        let partition_id = envelope.partition_id;
163        if partition_id >= self.partition_senders.len() {
164            return Err(crate::error::LaneError::Other(format!(
165                "Invalid partition ID: {}",
166                partition_id
167            )));
168        }
169
170        self.partition_senders[partition_id]
171            .send(envelope)
172            .await
173            .map_err(|e| crate::error::LaneError::Other(format!("Failed to enqueue: {}", e)))?;
174
175        Ok(())
176    }
177
178    async fn dequeue(&self, partition_id: PartitionId) -> Result<Option<CommandEnvelope>> {
179        if partition_id >= self.partition_receivers.len() {
180            return Err(crate::error::LaneError::Other(format!(
181                "Invalid partition ID: {}",
182                partition_id
183            )));
184        }
185
186        let mut receiver = self.partition_receivers[partition_id].lock().await;
187        match receiver.try_recv() {
188            Ok(envelope) => Ok(Some(envelope)),
189            Err(mpsc::error::TryRecvError::Empty) => Ok(None),
190            Err(mpsc::error::TryRecvError::Disconnected) => Ok(None),
191        }
192    }
193
194    async fn complete(&self, result: CommandResult) -> Result<()> {
195        self.result_sender
196            .send(result)
197            .await
198            .map_err(|e| crate::error::LaneError::Other(format!("Failed to send result: {}", e)))?;
199        Ok(())
200    }
201
202    fn num_partitions(&self) -> usize {
203        self.partition_config.num_partitions
204    }
205
206    fn worker_id(&self) -> &WorkerId {
207        &self.worker_id
208    }
209
210    fn is_coordinator(&self) -> bool {
211        true // Local queue is always both coordinator and worker
212    }
213
214    fn is_worker(&self) -> bool {
215        true // Local queue is always both coordinator and worker
216    }
217}
218
/// Worker pool for processing commands across multiple partitions.
///
/// Spawns one tokio task per partition (see `start`) and signals them to
/// stop via a shared atomic flag (see `shutdown`).
pub struct WorkerPool {
    distributed_queue: Arc<dyn DistributedQueue>,
    // Join handles of the spawned per-partition worker tasks; drained on shutdown.
    worker_handles: Vec<tokio::task::JoinHandle<()>>,
    // Cooperative stop flag checked by every worker loop iteration.
    shutdown: Arc<std::sync::atomic::AtomicBool>,
}
225
226impl WorkerPool {
227    /// Create a new worker pool with the given distributed queue
228    pub fn new(distributed_queue: Arc<dyn DistributedQueue>) -> Self {
229        Self {
230            distributed_queue,
231            worker_handles: Vec::new(),
232            shutdown: Arc::new(std::sync::atomic::AtomicBool::new(false)),
233        }
234    }
235
236    /// Start worker tasks for all partitions
237    ///
238    /// The `command_executor` function is called for each command to execute it.
239    pub fn start<F, Fut>(&mut self, command_executor: F)
240    where
241        F: Fn(CommandEnvelope) -> Fut + Send + Sync + Clone + 'static,
242        Fut: std::future::Future<Output = std::result::Result<serde_json::Value, String>>
243            + Send
244            + 'static,
245    {
246        let num_partitions = self.distributed_queue.num_partitions();
247
248        for partition_id in 0..num_partitions {
249            let queue = Arc::clone(&self.distributed_queue);
250            let shutdown = Arc::clone(&self.shutdown);
251            let executor = command_executor.clone();
252
253            let handle = tokio::spawn(async move {
254                loop {
255                    if shutdown.load(std::sync::atomic::Ordering::Relaxed) {
256                        break;
257                    }
258
259                    match queue.dequeue(partition_id).await {
260                        Ok(Some(envelope)) => {
261                            let command_id = envelope.id.clone();
262                            let start = std::time::Instant::now();
263
264                            let result = executor(envelope).await;
265                            let duration_ms = start.elapsed().as_millis() as u64;
266
267                            let command_result = CommandResult {
268                                command_id,
269                                result,
270                                worker_id: queue.worker_id().clone(),
271                                duration_ms,
272                            };
273
274                            let _ = queue.complete(command_result).await;
275                        }
276                        Ok(None) => {
277                            // No command available, sleep briefly
278                            tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
279                        }
280                        Err(_) => {
281                            // Error dequeuing, sleep and retry
282                            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
283                        }
284                    }
285                }
286            });
287
288            self.worker_handles.push(handle);
289        }
290    }
291
292    /// Shutdown all workers
293    pub async fn shutdown(&mut self) {
294        self.shutdown
295            .store(true, std::sync::atomic::Ordering::Relaxed);
296
297        for handle in self.worker_handles.drain(..) {
298            let _ = handle.await;
299        }
300    }
301
302    /// Check if workers are running
303    pub fn is_running(&self) -> bool {
304        !self.worker_handles.is_empty() && !self.shutdown.load(std::sync::atomic::Ordering::Relaxed)
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    // Auto-configured queue: at least one partition, plays both roles locally.
    #[tokio::test]
    async fn test_local_distributed_queue_creation() {
        let queue = LocalDistributedQueue::auto();
        assert!(queue.num_partitions() > 0);
        assert!(queue.is_coordinator());
        assert!(queue.is_worker());
    }

    // Round-trip through a partition channel: enqueue then dequeue returns the
    // same envelope; a second dequeue finds the partition empty.
    #[tokio::test]
    async fn test_local_distributed_queue_enqueue_dequeue() {
        let queue = LocalDistributedQueue::new(PartitionConfig::new(
            2,
            crate::partition::PartitionStrategy::RoundRobin,
        ));

        let envelope = CommandEnvelope {
            id: "cmd1".to_string(),
            command_type: "test".to_string(),
            lane_id: "query".to_string(),
            partition_id: 0,
            payload: serde_json::json!({"data": "test"}),
            retry_count: 0,
            created_at: chrono::Utc::now(),
        };

        // Enqueue
        queue.enqueue(envelope.clone()).await.unwrap();

        // Dequeue
        let dequeued = queue.dequeue(0).await.unwrap();
        assert!(dequeued.is_some());
        let dequeued = dequeued.unwrap();
        assert_eq!(dequeued.id, "cmd1");
        assert_eq!(dequeued.command_type, "test");

        // Dequeue again should be empty
        let dequeued = queue.dequeue(0).await.unwrap();
        assert!(dequeued.is_none());
    }

    // complete() forwards the result onto the result channel where the
    // coordinator can pick it up.
    #[tokio::test]
    async fn test_local_distributed_queue_complete() {
        let queue = LocalDistributedQueue::new(PartitionConfig::new(
            2,
            crate::partition::PartitionStrategy::RoundRobin,
        ));

        let result = CommandResult {
            command_id: "cmd1".to_string(),
            result: Ok(serde_json::json!({"success": true})),
            worker_id: "worker1".to_string(),
            duration_ms: 100,
        };

        queue.complete(result).await.unwrap();

        // Check result was received
        let receiver_arc = queue.result_receiver();
        let mut receiver = receiver_arc.lock().await;
        let received = receiver.try_recv();
        assert!(received.is_ok());
        let received = received.unwrap();
        assert_eq!(received.command_id, "cmd1");
    }

    // End-to-end: start workers, enqueue a command, and observe its result on
    // the result channel, then shut the pool down.
    #[tokio::test]
    async fn test_worker_pool() {
        let queue = Arc::new(LocalDistributedQueue::new(PartitionConfig::new(
            2,
            crate::partition::PartitionStrategy::RoundRobin,
        )));

        let mut pool = WorkerPool::new(queue.clone());

        // Start workers with a simple executor
        pool.start(|envelope| async move { Ok(serde_json::json!({"processed": envelope.id})) });

        assert!(pool.is_running());

        // Enqueue a command
        let envelope = CommandEnvelope {
            id: "cmd1".to_string(),
            command_type: "test".to_string(),
            lane_id: "query".to_string(),
            partition_id: 0,
            payload: serde_json::json!({}),
            retry_count: 0,
            created_at: chrono::Utc::now(),
        };
        queue.enqueue(envelope).await.unwrap();

        // Wait for processing. NOTE(review): a fixed 50 ms sleep is a
        // timing assumption — could flake on a heavily loaded machine.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        // Check result
        let receiver_arc = queue.result_receiver();
        let mut receiver = receiver_arc.lock().await;
        let result = receiver.try_recv();
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.command_id, "cmd1");
        assert!(result.result.is_ok());

        // Shutdown
        pool.shutdown().await;
        assert!(!pool.is_running());
    }

    // CommandEnvelope survives a JSON round-trip with all fields intact.
    #[test]
    fn test_command_envelope_serialization() {
        let envelope = CommandEnvelope {
            id: "cmd1".to_string(),
            command_type: "test".to_string(),
            lane_id: "query".to_string(),
            partition_id: 0,
            payload: serde_json::json!({"key": "value"}),
            retry_count: 2,
            created_at: chrono::Utc::now(),
        };

        let json = serde_json::to_string(&envelope).unwrap();
        let parsed: CommandEnvelope = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, "cmd1");
        assert_eq!(parsed.command_type, "test");
        assert_eq!(parsed.partition_id, 0);
        assert_eq!(parsed.retry_count, 2);
    }

    // CommandResult (including its std Result field) round-trips through JSON.
    #[test]
    fn test_command_result_serialization() {
        let result = CommandResult {
            command_id: "cmd1".to_string(),
            result: Ok(serde_json::json!({"success": true})),
            worker_id: "worker1".to_string(),
            duration_ms: 150,
        };

        let json = serde_json::to_string(&result).unwrap();
        let parsed: CommandResult = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.command_id, "cmd1");
        assert_eq!(parsed.worker_id, "worker1");
        assert_eq!(parsed.duration_ms, 150);
    }
}