1use crate::error::Result;
8use crate::partition::{PartitionConfig, PartitionId, Partitioner};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
/// Identifier of a queue instance / worker, e.g. `"local-<uuid>"` for [`LocalDistributedQueue`].
pub type WorkerId = std::string::String;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CommandEnvelope {
20 pub id: String,
22 pub command_type: String,
24 pub lane_id: String,
26 pub partition_id: PartitionId,
28 pub payload: serde_json::Value,
30 pub retry_count: u32,
32 pub created_at: chrono::DateTime<chrono::Utc>,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct CommandResult {
39 pub command_id: String,
41 pub result: std::result::Result<serde_json::Value, String>,
43 pub worker_id: WorkerId,
45 pub duration_ms: u64,
47}
48
49#[async_trait]
54pub trait DistributedQueue: Send + Sync {
55 async fn enqueue(&self, envelope: CommandEnvelope) -> Result<()>;
57
58 async fn dequeue(&self, partition_id: PartitionId) -> Result<Option<CommandEnvelope>>;
60
61 async fn complete(&self, result: CommandResult) -> Result<()>;
63
64 fn num_partitions(&self) -> usize;
66
67 fn worker_id(&self) -> &WorkerId;
69
70 fn is_coordinator(&self) -> bool;
72
73 fn is_worker(&self) -> bool;
75}
76
77pub struct LocalDistributedQueue {
83 worker_id: WorkerId,
84 partition_config: PartitionConfig,
85 partitioner: Arc<dyn Partitioner>,
86 partition_senders: Vec<mpsc::Sender<CommandEnvelope>>,
88 partition_receivers: Vec<Arc<tokio::sync::Mutex<mpsc::Receiver<CommandEnvelope>>>>,
90 result_sender: mpsc::Sender<CommandResult>,
92 result_receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<CommandResult>>>,
93}
94
95impl LocalDistributedQueue {
96 pub fn new(partition_config: PartitionConfig) -> Self {
98 let num_partitions = partition_config.num_partitions;
99 let partitioner = partition_config.create_partitioner();
100
101 let mut partition_senders = Vec::with_capacity(num_partitions);
102 let mut partition_receivers = Vec::with_capacity(num_partitions);
103
104 for _ in 0..num_partitions {
106 let (tx, rx) = mpsc::channel(1000); partition_senders.push(tx);
108 partition_receivers.push(Arc::new(tokio::sync::Mutex::new(rx)));
109 }
110
111 let (result_tx, result_rx) = mpsc::channel(1000);
113
114 Self {
115 worker_id: format!("local-{}", uuid::Uuid::new_v4()),
116 partition_config,
117 partitioner,
118 partition_senders,
119 partition_receivers,
120 result_sender: result_tx,
121 result_receiver: Arc::new(tokio::sync::Mutex::new(result_rx)),
122 }
123 }
124
125 pub fn auto() -> Self {
127 Self::new(PartitionConfig::auto())
128 }
129
130 pub fn partitioner(&self) -> &Arc<dyn Partitioner> {
132 &self.partitioner
133 }
134
135 pub fn partition_config(&self) -> &PartitionConfig {
137 &self.partition_config
138 }
139
140 pub fn partition_receiver(
142 &self,
143 partition_id: PartitionId,
144 ) -> Option<Arc<tokio::sync::Mutex<mpsc::Receiver<CommandEnvelope>>>> {
145 self.partition_receivers.get(partition_id).cloned()
146 }
147
148 pub fn result_receiver(&self) -> Arc<tokio::sync::Mutex<mpsc::Receiver<CommandResult>>> {
150 Arc::clone(&self.result_receiver)
151 }
152
153 pub fn result_sender(&self) -> mpsc::Sender<CommandResult> {
155 self.result_sender.clone()
156 }
157}
158
159#[async_trait]
160impl DistributedQueue for LocalDistributedQueue {
161 async fn enqueue(&self, envelope: CommandEnvelope) -> Result<()> {
162 let partition_id = envelope.partition_id;
163 if partition_id >= self.partition_senders.len() {
164 return Err(crate::error::LaneError::Other(format!(
165 "Invalid partition ID: {}",
166 partition_id
167 )));
168 }
169
170 self.partition_senders[partition_id]
171 .send(envelope)
172 .await
173 .map_err(|e| crate::error::LaneError::Other(format!("Failed to enqueue: {}", e)))?;
174
175 Ok(())
176 }
177
178 async fn dequeue(&self, partition_id: PartitionId) -> Result<Option<CommandEnvelope>> {
179 if partition_id >= self.partition_receivers.len() {
180 return Err(crate::error::LaneError::Other(format!(
181 "Invalid partition ID: {}",
182 partition_id
183 )));
184 }
185
186 let mut receiver = self.partition_receivers[partition_id].lock().await;
187 match receiver.try_recv() {
188 Ok(envelope) => Ok(Some(envelope)),
189 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
190 Err(mpsc::error::TryRecvError::Disconnected) => Ok(None),
191 }
192 }
193
194 async fn complete(&self, result: CommandResult) -> Result<()> {
195 self.result_sender
196 .send(result)
197 .await
198 .map_err(|e| crate::error::LaneError::Other(format!("Failed to send result: {}", e)))?;
199 Ok(())
200 }
201
202 fn num_partitions(&self) -> usize {
203 self.partition_config.num_partitions
204 }
205
206 fn worker_id(&self) -> &WorkerId {
207 &self.worker_id
208 }
209
210 fn is_coordinator(&self) -> bool {
211 true }
213
214 fn is_worker(&self) -> bool {
215 true }
217}
218
219pub struct WorkerPool {
221 distributed_queue: Arc<dyn DistributedQueue>,
222 worker_handles: Vec<tokio::task::JoinHandle<()>>,
223 shutdown: Arc<std::sync::atomic::AtomicBool>,
224}
225
226impl WorkerPool {
227 pub fn new(distributed_queue: Arc<dyn DistributedQueue>) -> Self {
229 Self {
230 distributed_queue,
231 worker_handles: Vec::new(),
232 shutdown: Arc::new(std::sync::atomic::AtomicBool::new(false)),
233 }
234 }
235
236 pub fn start<F, Fut>(&mut self, command_executor: F)
240 where
241 F: Fn(CommandEnvelope) -> Fut + Send + Sync + Clone + 'static,
242 Fut: std::future::Future<Output = std::result::Result<serde_json::Value, String>>
243 + Send
244 + 'static,
245 {
246 let num_partitions = self.distributed_queue.num_partitions();
247
248 for partition_id in 0..num_partitions {
249 let queue = Arc::clone(&self.distributed_queue);
250 let shutdown = Arc::clone(&self.shutdown);
251 let executor = command_executor.clone();
252
253 let handle = tokio::spawn(async move {
254 loop {
255 if shutdown.load(std::sync::atomic::Ordering::Relaxed) {
256 break;
257 }
258
259 match queue.dequeue(partition_id).await {
260 Ok(Some(envelope)) => {
261 let command_id = envelope.id.clone();
262 let start = std::time::Instant::now();
263
264 let result = executor(envelope).await;
265 let duration_ms = start.elapsed().as_millis() as u64;
266
267 let command_result = CommandResult {
268 command_id,
269 result,
270 worker_id: queue.worker_id().clone(),
271 duration_ms,
272 };
273
274 let _ = queue.complete(command_result).await;
275 }
276 Ok(None) => {
277 tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
279 }
280 Err(_) => {
281 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
283 }
284 }
285 }
286 });
287
288 self.worker_handles.push(handle);
289 }
290 }
291
292 pub async fn shutdown(&mut self) {
294 self.shutdown
295 .store(true, std::sync::atomic::Ordering::Relaxed);
296
297 for handle in self.worker_handles.drain(..) {
298 let _ = handle.await;
299 }
300 }
301
302 pub fn is_running(&self) -> bool {
304 !self.worker_handles.is_empty() && !self.shutdown.load(std::sync::atomic::Ordering::Relaxed)
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_local_distributed_queue_creation() {
        let queue = LocalDistributedQueue::auto();
        assert!(queue.num_partitions() > 0);
        assert!(queue.is_coordinator());
        assert!(queue.is_worker());
    }

    #[tokio::test]
    async fn test_local_distributed_queue_enqueue_dequeue() {
        let queue = LocalDistributedQueue::new(PartitionConfig::new(
            2,
            crate::partition::PartitionStrategy::RoundRobin,
        ));

        let envelope = CommandEnvelope {
            id: "cmd1".to_string(),
            command_type: "test".to_string(),
            lane_id: "query".to_string(),
            partition_id: 0,
            payload: serde_json::json!({"data": "test"}),
            retry_count: 0,
            created_at: chrono::Utc::now(),
        };

        queue.enqueue(envelope.clone()).await.unwrap();

        let dequeued = queue.dequeue(0).await.unwrap();
        assert!(dequeued.is_some());
        let dequeued = dequeued.unwrap();
        assert_eq!(dequeued.id, "cmd1");
        assert_eq!(dequeued.command_type, "test");

        // The single enqueued command has been consumed, so the partition is now empty.
        let dequeued = queue.dequeue(0).await.unwrap();
        assert!(dequeued.is_none());
    }

    #[tokio::test]
    async fn test_local_distributed_queue_rejects_invalid_partition() {
        let queue = LocalDistributedQueue::new(PartitionConfig::new(
            2,
            crate::partition::PartitionStrategy::RoundRobin,
        ));

        let envelope = CommandEnvelope {
            id: "cmd-out-of-range".to_string(),
            command_type: "test".to_string(),
            lane_id: "query".to_string(),
            partition_id: 2,
            payload: serde_json::json!({}),
            retry_count: 0,
            created_at: chrono::Utc::now(),
        };

        // Only partitions 0 and 1 exist.
        assert!(queue.enqueue(envelope).await.is_err());
        assert!(queue.dequeue(2).await.is_err());
        assert!(queue.partition_receiver(2).is_none());
        assert!(queue.partition_receiver(1).is_some());
    }

    #[tokio::test]
    async fn test_local_distributed_queue_complete() {
        let queue = LocalDistributedQueue::new(PartitionConfig::new(
            2,
            crate::partition::PartitionStrategy::RoundRobin,
        ));

        let result = CommandResult {
            command_id: "cmd1".to_string(),
            result: Ok(serde_json::json!({"success": true})),
            worker_id: "worker1".to_string(),
            duration_ms: 100,
        };

        queue.complete(result).await.unwrap();

        let receiver_arc = queue.result_receiver();
        let mut receiver = receiver_arc.lock().await;
        let received = receiver.try_recv();
        assert!(received.is_ok());
        let received = received.unwrap();
        assert_eq!(received.command_id, "cmd1");
    }

    #[tokio::test]
    async fn test_worker_pool() {
        let queue = Arc::new(LocalDistributedQueue::new(PartitionConfig::new(
            2,
            crate::partition::PartitionStrategy::RoundRobin,
        )));

        let mut pool = WorkerPool::new(queue.clone());

        pool.start(|envelope| async move { Ok(serde_json::json!({"processed": envelope.id})) });

        assert!(pool.is_running());

        let envelope = CommandEnvelope {
            id: "cmd1".to_string(),
            command_type: "test".to_string(),
            lane_id: "query".to_string(),
            partition_id: 0,
            payload: serde_json::json!({}),
            retry_count: 0,
            created_at: chrono::Utc::now(),
        };
        queue.enqueue(envelope).await.unwrap();

        // Wait for the result instead of sleeping a fixed interval and polling once, which is
        // flaky on slow or heavily loaded hosts.
        let receiver_arc = queue.result_receiver();
        let mut receiver = receiver_arc.lock().await;
        let result = tokio::time::timeout(tokio::time::Duration::from_secs(5), receiver.recv())
            .await
            .expect("worker pool did not report a result within 5 seconds")
            .expect("result channel closed unexpectedly");
        drop(receiver);

        assert_eq!(result.command_id, "cmd1");
        assert_eq!(result.result, Ok(serde_json::json!({"processed": "cmd1"})));
        assert_eq!(&result.worker_id, queue.worker_id());

        pool.shutdown().await;
        assert!(!pool.is_running());
    }

    #[test]
    fn test_command_envelope_serialization() {
        let envelope = CommandEnvelope {
            id: "cmd1".to_string(),
            command_type: "test".to_string(),
            lane_id: "query".to_string(),
            partition_id: 0,
            payload: serde_json::json!({"key": "value"}),
            retry_count: 2,
            created_at: chrono::Utc::now(),
        };

        let json = serde_json::to_string(&envelope).unwrap();
        let parsed: CommandEnvelope = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, "cmd1");
        assert_eq!(parsed.command_type, "test");
        assert_eq!(parsed.partition_id, 0);
        assert_eq!(parsed.retry_count, 2);
    }

    #[test]
    fn test_command_result_serialization() {
        let result = CommandResult {
            command_id: "cmd1".to_string(),
            result: Ok(serde_json::json!({"success": true})),
            worker_id: "worker1".to_string(),
            duration_ms: 150,
        };

        let json = serde_json::to_string(&result).unwrap();
        let parsed: CommandResult = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.command_id, "cmd1");
        assert_eq!(parsed.worker_id, "worker1");
        assert_eq!(parsed.duration_ms, 150);
    }
}