1use std::{
7 collections::{HashMap, HashSet},
8 sync::Arc,
9 time::{Duration, Instant},
10};
11
12use serde::{Deserialize, Serialize};
13use tokio::{
14 sync::{mpsc, RwLock, Mutex},
15 time::{interval, timeout},
16};
17use tracing::{debug, error, info, instrument};
18
19use crate::{
20 workflow::{WorkflowId, WorkflowError, StageId},
21 nat_traversal_api::NatTraversalEndpoint,
22};
23
/// Identifier of a peer taking part in workflow coordination, represented as an owned string.
type PeerId = String;
26
/// Messages exchanged between peers while coordinating a workflow.
///
/// Every variant carries the `workflow_id` it refers to. Only some variants are produced or
/// handled by the coordinator code in this file; the others are noted as such below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CoordinationMessage {
    /// Sent by `coordinate_workflow` to every participant other than the local peer to ask it
    /// to take part in the workflow.
    CoordinationRequest {
        workflow_id: WorkflowId,
        /// Peer that initiated the coordination.
        requester: PeerId,
        /// Full participant list supplied by the requester.
        participants: Vec<PeerId>,
        /// Timeout supplied by the requester (the `coordination_timeout` argument).
        timeout: Duration,
    },
    /// Sent by `join_workflow` to accept a request; carries the sender's capabilities.
    CoordinationAccept {
        workflow_id: WorkflowId,
        participant: PeerId,
        capabilities: NodeCapabilities,
    },
    /// Rejection of a coordination request with a textual reason.
    /// NOTE(review): neither sent nor handled by the code visible in this file.
    CoordinationReject {
        workflow_id: WorkflowId,
        participant: PeerId,
        reason: String,
    },
    /// Sent by `coordinate_workflow` to all participants once acceptances have been collected.
    WorkflowStart {
        workflow_id: WorkflowId,
        /// Which peer is assigned to each stage.
        stage_assignments: HashMap<StageId, PeerId>,
    },
    /// Assignment of a single stage to a peer.
    /// NOTE(review): neither sent nor handled by the code visible in this file.
    StageAssignment {
        workflow_id: WorkflowId,
        stage_id: StageId,
        assigned_to: PeerId,
    },
    /// Sent by `update_stage_status`; the receiving side records it as the stage's latest result.
    StageStatusUpdate {
        workflow_id: WorkflowId,
        stage_id: StageId,
        status: StageStatus,
        metrics: StageMetrics,
    },
    /// Announcement of a synchronization barrier and the peers expected at it.
    /// NOTE(review): neither sent nor handled by the code visible in this file.
    SyncBarrier {
        workflow_id: WorkflowId,
        barrier_id: String,
        participants: Vec<PeerId>,
    },
    /// Sent by `signal_barrier_ready`; the receiving side adds `participant` to the set of
    /// peers ready at `barrier_id`.
    BarrierReady {
        workflow_id: WorkflowId,
        barrier_id: String,
        participant: PeerId,
    },
    /// Final result of a workflow.
    /// NOTE(review): neither sent nor handled by the code visible in this file.
    WorkflowComplete {
        workflow_id: WorkflowId,
        result: WorkflowCoordinationResult,
    },
    /// Periodic liveness message emitted by `heartbeat_loop` towards a session's coordinator.
    Heartbeat {
        workflow_id: WorkflowId,
        participant: PeerId,
        /// Timestamp in milliseconds, as filled in by the sender.
        timestamp_ms: u64,
    },
    /// Error report from a participant.
    /// NOTE(review): neither sent nor handled by the code visible in this file.
    ErrorNotification {
        workflow_id: WorkflowId,
        participant: PeerId,
        error: String,
    },
}
97
/// Resources a node advertises when it accepts a coordination request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCapabilities {
    /// Number of CPU cores.
    pub cpu_cores: u32,
    /// Memory size; presumably megabytes, going by the field name.
    pub memory_mb: u64,
    /// Network bandwidth; presumably megabits per second, going by the field name.
    pub bandwidth_mbps: u32,
    /// Names of the workflows this node supports.
    pub supported_workflows: Vec<String>,
    /// Current load of the node.
    /// NOTE(review): looks like a percentage (0-100) — confirm with the producer of this value.
    pub current_load: u8,
}
112
/// Execution state of a single workflow stage as reported in `StageStatusUpdate` messages.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum StageStatus {
    /// The stage has not started yet.
    Pending,
    /// The stage is currently executing.
    Running,
    /// The stage finished successfully.
    Completed,
    /// The stage finished with a failure.
    Failed,
    /// The stage was cancelled.
    Cancelled,
}
122
/// Resource and timing metrics for one stage, or an aggregate over several stages.
///
/// `WorkflowCoordinator::aggregate_metrics` combines these per field: earliest start, latest
/// end, mean CPU usage, maximum memory usage and summed byte counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageMetrics {
    /// Start time in milliseconds, if known.
    pub start_time_ms: Option<u64>,
    /// End time in milliseconds, if known.
    pub end_time_ms: Option<u64>,
    /// CPU usage. NOTE(review): unit/scale (fraction vs. percent) is not visible here — confirm.
    pub cpu_usage: f32,
    /// Memory usage. NOTE(review): unit is not visible here — confirm.
    pub memory_usage: u64,
    /// Number of bytes sent.
    pub bytes_sent: u64,
    /// Number of bytes received.
    pub bytes_received: u64,
}
139
/// Outcome of a coordinated workflow as returned by `WorkflowCoordinator::coordinate_workflow`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowCoordinationResult {
    /// Whether the workflow as a whole succeeded.
    pub success: bool,
    /// Time spent monitoring the workflow until it finished.
    pub duration: Duration,
    /// Last recorded result for each stage that reported a status.
    pub stage_results: HashMap<StageId, StageResult>,
    /// Metrics aggregated over all entries of `stage_results`.
    pub total_metrics: StageMetrics,
}
152
/// Latest known result of a single stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageResult {
    /// Peer that reported the status (the sender of the `StageStatusUpdate` message).
    pub executor: PeerId,
    /// Reported status of the stage.
    pub status: StageStatus,
    /// Metrics reported together with the status.
    pub metrics: StageMetrics,
    /// Optional error description; the message handler in this file always stores `None`.
    pub error: Option<String>,
}
165
/// Coordinates workflows between peers and tracks one session per active workflow.
///
/// Cloning yields a handle that shares the session table, the message queue and the endpoint
/// with the original (see the manual `Clone` implementation).
pub struct WorkflowCoordinator {
    /// Identifier of the local peer.
    local_peer_id: PeerId,
    /// NAT traversal endpoint; stored and cloned but not otherwise used in the code visible here.
    endpoint: Arc<NatTraversalEndpoint>,
    /// Active coordination sessions keyed by workflow id.
    coordinations: Arc<RwLock<HashMap<WorkflowId, CoordinationSession>>>,
    /// Receiving half of the message queue, consumed by `message_processing_loop`.
    message_handler: Arc<Mutex<mpsc::Receiver<(PeerId, CoordinationMessage)>>>,
    /// Sending half of the message queue; each entry is `(sender, message)`.
    message_tx: mpsc::Sender<(PeerId, CoordinationMessage)>,
    /// Capabilities advertised when accepting a coordination request.
    capabilities: NodeCapabilities,
}
181
182impl WorkflowCoordinator {
183 pub fn new(
185 local_peer_id: PeerId,
186 endpoint: Arc<NatTraversalEndpoint>,
187 capabilities: NodeCapabilities,
188 ) -> Self {
189 let (message_tx, message_rx) = mpsc::channel(1000);
190
191 Self {
192 local_peer_id,
193 endpoint,
194 coordinations: Arc::new(RwLock::new(HashMap::new())),
195 message_handler: Arc::new(Mutex::new(message_rx)),
196 message_tx,
197 capabilities,
198 }
199 }
200
201 pub async fn start(&self) -> Result<(), WorkflowError> {
203 info!("Starting workflow coordinator for peer {}", self.local_peer_id);
204
205 let coordinator = self.clone();
207 tokio::spawn(async move {
208 coordinator.message_processing_loop().await;
209 });
210
211 let coordinator = self.clone();
213 tokio::spawn(async move {
214 coordinator.heartbeat_loop().await;
215 });
216
217 Ok(())
218 }
219
220 #[instrument(skip(self))]
222 pub async fn coordinate_workflow(
223 &self,
224 workflow_id: WorkflowId,
225 participants: Vec<PeerId>,
226 stage_assignments: HashMap<StageId, PeerId>,
227 coordination_timeout: Duration,
228 ) -> Result<WorkflowCoordinationResult, WorkflowError> {
229 info!("Coordinating workflow {} with {} participants", workflow_id, participants.len());
230
231 let session = CoordinationSession::new(
233 workflow_id,
234 self.local_peer_id.clone(),
235 participants.clone(),
236 stage_assignments.clone(),
237 );
238
239 {
241 let mut coordinations = self.coordinations.write().await;
242 coordinations.insert(workflow_id, session);
243 }
244
245 for participant in &participants {
247 if participant != &self.local_peer_id {
248 self.send_message(
249 participant.clone(),
250 CoordinationMessage::CoordinationRequest {
251 workflow_id,
252 requester: self.local_peer_id.clone(),
253 participants: participants.clone(),
254 timeout: coordination_timeout,
255 },
256 ).await?;
257 }
258 }
259
260 let accept_timeout = Duration::from_secs(30);
262 let accept_result = timeout(accept_timeout, self.wait_for_acceptances(workflow_id, &participants)).await;
263
264 if accept_result.is_err() {
265 self.cleanup_coordination(workflow_id).await;
266 return Err(WorkflowError {
267 code: "COORDINATION_TIMEOUT".to_string(),
268 message: "Timeout waiting for participant acceptances".to_string(),
269 stage: None,
270 trace: None,
271 recovery_hints: vec!["Check network connectivity".to_string()],
272 });
273 }
274
275 for participant in &participants {
277 self.send_message(
278 participant.clone(),
279 CoordinationMessage::WorkflowStart {
280 workflow_id,
281 stage_assignments: stage_assignments.clone(),
282 },
283 ).await?;
284 }
285
286 let result = timeout(
288 coordination_timeout,
289 self.monitor_workflow_execution(workflow_id),
290 ).await;
291
292 self.cleanup_coordination(workflow_id).await;
294
295 match result {
296 Ok(Ok(result)) => Ok(result),
297 Ok(Err(e)) => Err(e),
298 Err(_) => Err(WorkflowError {
299 code: "WORKFLOW_TIMEOUT".to_string(),
300 message: "Workflow execution timed out".to_string(),
301 stage: None,
302 trace: None,
303 recovery_hints: vec!["Increase timeout or optimize workflow".to_string()],
304 }),
305 }
306 }
307
308 pub async fn join_workflow(
310 &self,
311 workflow_id: WorkflowId,
312 coordinator: PeerId,
313 ) -> Result<(), WorkflowError> {
314 info!("Joining workflow {} coordinated by {}", workflow_id, coordinator);
315
316 self.send_message(
318 coordinator,
319 CoordinationMessage::CoordinationAccept {
320 workflow_id,
321 participant: self.local_peer_id.clone(),
322 capabilities: self.capabilities.clone(),
323 },
324 ).await?;
325
326 Ok(())
327 }
328
329 pub async fn update_stage_status(
331 &self,
332 workflow_id: WorkflowId,
333 stage_id: StageId,
334 status: StageStatus,
335 metrics: StageMetrics,
336 ) -> Result<(), WorkflowError> {
337 let coordinator = {
339 let coordinations = self.coordinations.read().await;
340 coordinations.get(&workflow_id)
341 .map(|session| session.coordinator.clone())
342 };
343
344 if let Some(coordinator) = coordinator {
345 self.send_message(
346 coordinator,
347 CoordinationMessage::StageStatusUpdate {
348 workflow_id,
349 stage_id,
350 status,
351 metrics,
352 },
353 ).await?;
354 }
355
356 Ok(())
357 }
358
359 pub async fn signal_barrier_ready(
361 &self,
362 workflow_id: WorkflowId,
363 barrier_id: String,
364 ) -> Result<(), WorkflowError> {
365 let coordinator = {
367 let coordinations = self.coordinations.read().await;
368 coordinations.get(&workflow_id)
369 .map(|session| session.coordinator.clone())
370 };
371
372 if let Some(coordinator) = coordinator {
373 self.send_message(
374 coordinator,
375 CoordinationMessage::BarrierReady {
376 workflow_id,
377 barrier_id,
378 participant: self.local_peer_id.clone(),
379 },
380 ).await?;
381 }
382
383 Ok(())
384 }
385
386 async fn wait_for_acceptances(
388 &self,
389 workflow_id: WorkflowId,
390 participants: &[PeerId],
391 ) -> Result<(), WorkflowError> {
392 let expected_count = participants.len() - 1; let mut accepted_count = 0;
394
395 let start_time = Instant::now();
396 let check_interval = Duration::from_millis(100);
397
398 loop {
399 let coordinations = self.coordinations.read().await;
400 if let Some(session) = coordinations.get(&workflow_id) {
401 accepted_count = session.accepted_participants.len();
402 if accepted_count >= expected_count {
403 return Ok(());
404 }
405 }
406 drop(coordinations);
407
408 if start_time.elapsed() > Duration::from_secs(30) {
409 return Err(WorkflowError {
410 code: "ACCEPTANCE_TIMEOUT".to_string(),
411 message: format!("Only {}/{} participants accepted", accepted_count, expected_count),
412 stage: None,
413 trace: None,
414 recovery_hints: vec!["Check participant availability".to_string()],
415 });
416 }
417
418 tokio::time::sleep(check_interval).await;
419 }
420 }
421
422 async fn monitor_workflow_execution(
424 &self,
425 workflow_id: WorkflowId,
426 ) -> Result<WorkflowCoordinationResult, WorkflowError> {
427 let start_time = Instant::now();
428
429 loop {
430 let coordinations = self.coordinations.read().await;
431 if let Some(session) = coordinations.get(&workflow_id) {
432 let all_complete = session.stage_status.iter()
434 .all(|(_, status)| matches!(status.status, StageStatus::Completed | StageStatus::Failed));
435
436 if all_complete {
437 let success = session.stage_status.iter()
439 .all(|(_, status)| status.status == StageStatus::Completed);
440
441 let total_metrics = self.aggregate_metrics(&session.stage_status);
442
443 return Ok(WorkflowCoordinationResult {
444 success,
445 duration: start_time.elapsed(),
446 stage_results: session.stage_status.clone(),
447 total_metrics,
448 });
449 }
450 }
451 drop(coordinations);
452
453 tokio::time::sleep(Duration::from_millis(100)).await;
454 }
455 }
456
457 fn aggregate_metrics(&self, stage_results: &HashMap<StageId, StageResult>) -> StageMetrics {
459 let mut total = StageMetrics {
460 start_time_ms: None,
461 end_time_ms: None,
462 cpu_usage: 0.0,
463 memory_usage: 0,
464 bytes_sent: 0,
465 bytes_received: 0,
466 };
467
468 let mut cpu_sum = 0.0;
469 let mut cpu_count = 0;
470
471 for (_, result) in stage_results {
472 if let Some(start) = result.metrics.start_time_ms {
473 total.start_time_ms = Some(total.start_time_ms.map_or(start, |t| t.min(start)));
474 }
475 if let Some(end) = result.metrics.end_time_ms {
476 total.end_time_ms = Some(total.end_time_ms.map_or(end, |t| t.max(end)));
477 }
478
479 cpu_sum += result.metrics.cpu_usage;
480 cpu_count += 1;
481
482 total.memory_usage = total.memory_usage.max(result.metrics.memory_usage);
483 total.bytes_sent += result.metrics.bytes_sent;
484 total.bytes_received += result.metrics.bytes_received;
485 }
486
487 if cpu_count > 0 {
488 total.cpu_usage = cpu_sum / cpu_count as f32;
489 }
490
491 total
492 }
493
494 async fn cleanup_coordination(&self, workflow_id: WorkflowId) {
496 let mut coordinations = self.coordinations.write().await;
497 coordinations.remove(&workflow_id);
498 debug!("Cleaned up coordination session for workflow {}", workflow_id);
499 }
500
501 async fn send_message(
503 &self,
504 peer: PeerId,
505 message: CoordinationMessage,
506 ) -> Result<(), WorkflowError> {
507 debug!("Sending {:?} to peer {}", message, peer);
510
511 if peer == self.local_peer_id {
513 self.message_tx.send((self.local_peer_id.clone(), message)).await
514 .map_err(|_| WorkflowError {
515 code: "SEND_ERROR".to_string(),
516 message: "Failed to send message".to_string(),
517 stage: None,
518 trace: None,
519 recovery_hints: vec![],
520 })?;
521 }
522
523 Ok(())
524 }
525
526 async fn message_processing_loop(&self) {
528 let mut receiver = self.message_handler.lock().await;
529
530 while let Some((sender, message)) = receiver.recv().await {
531 if let Err(e) = self.handle_message(sender, message).await {
532 error!("Error handling coordination message: {:?}", e);
533 }
534 }
535 }
536
537 async fn handle_message(
539 &self,
540 sender: PeerId,
541 message: CoordinationMessage,
542 ) -> Result<(), WorkflowError> {
543 match message {
544 CoordinationMessage::CoordinationRequest { workflow_id, requester, participants: _, timeout: _ } => {
545 self.join_workflow(workflow_id, requester).await?;
547 }
548 CoordinationMessage::CoordinationAccept { workflow_id, participant, capabilities } => {
549 let mut coordinations = self.coordinations.write().await;
550 if let Some(session) = coordinations.get_mut(&workflow_id) {
551 session.accepted_participants.insert(participant.clone());
552 session.participant_capabilities.insert(participant, capabilities);
553 }
554 }
555 CoordinationMessage::StageStatusUpdate { workflow_id, stage_id, status, metrics } => {
556 let mut coordinations = self.coordinations.write().await;
557 if let Some(session) = coordinations.get_mut(&workflow_id) {
558 session.stage_status.insert(stage_id, StageResult {
559 executor: sender,
560 status,
561 metrics,
562 error: None,
563 });
564 }
565 }
566 CoordinationMessage::BarrierReady { workflow_id, barrier_id, participant } => {
567 let mut coordinations = self.coordinations.write().await;
568 if let Some(session) = coordinations.get_mut(&workflow_id) {
569 session.barrier_ready
570 .entry(barrier_id)
571 .or_insert_with(HashSet::new)
572 .insert(participant);
573 }
574 }
575 _ => {}
576 }
577
578 Ok(())
579 }
580
581 async fn heartbeat_loop(&self) {
583 let mut interval = interval(Duration::from_secs(5));
584
585 loop {
586 interval.tick().await;
587
588 let coordinations = self.coordinations.read().await;
589 for (workflow_id, session) in coordinations.iter() {
590 if session.coordinator != self.local_peer_id {
591 let _ = self.send_message(
593 session.coordinator.clone(),
594 CoordinationMessage::Heartbeat {
595 workflow_id: *workflow_id,
596 participant: self.local_peer_id.clone(),
597 timestamp_ms: Instant::now().elapsed().as_millis() as u64,
598 },
599 ).await;
600 }
601 }
602 }
603 }
604}
605
606impl Clone for WorkflowCoordinator {
607 fn clone(&self) -> Self {
608 Self {
609 local_peer_id: self.local_peer_id.clone(),
610 endpoint: self.endpoint.clone(),
611 coordinations: self.coordinations.clone(),
612 message_handler: self.message_handler.clone(),
613 message_tx: self.message_tx.clone(),
614 capabilities: self.capabilities.clone(),
615 }
616 }
617}
618
/// Bookkeeping for one workflow that is being coordinated.
struct CoordinationSession {
    /// Workflow this session belongs to.
    workflow_id: WorkflowId,
    /// Peer acting as coordinator of the workflow.
    coordinator: PeerId,
    /// Participants supplied when the session was created.
    participants: Vec<PeerId>,
    /// Peers from which a `CoordinationAccept` has been received.
    accepted_participants: HashSet<PeerId>,
    /// Capabilities reported by each accepting peer.
    participant_capabilities: HashMap<PeerId, NodeCapabilities>,
    /// Which peer is assigned to each stage.
    stage_assignments: HashMap<StageId, PeerId>,
    /// Latest result received for each stage via `StageStatusUpdate`.
    stage_status: HashMap<StageId, StageResult>,
    /// For each barrier id, the peers that have signalled `BarrierReady`.
    barrier_ready: HashMap<String, HashSet<PeerId>>,
    /// Instant at which the session was created.
    start_time: Instant,
}
640
641impl CoordinationSession {
642 fn new(
643 workflow_id: WorkflowId,
644 coordinator: PeerId,
645 participants: Vec<PeerId>,
646 stage_assignments: HashMap<StageId, PeerId>,
647 ) -> Self {
648 Self {
649 workflow_id,
650 coordinator,
651 participants,
652 accepted_participants: HashSet::new(),
653 participant_capabilities: HashMap::new(),
654 stage_assignments,
655 stage_status: HashMap::new(),
656 barrier_ready: HashMap::new(),
657 start_time: Instant::now(),
658 }
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a coordinator whose only participant is itself, with no stage assignments,
    /// completes `coordinate_workflow` successfully.
    #[tokio::test]
    async fn test_workflow_coordinator() {
        let peer_id = "test_peer_id".to_string();

        // The endpoint is built from the default config plus one bootstrap node address.
        let mut config = crate::nat_traversal_api::NatTraversalConfig::default();
        config.bootstrap_nodes.push("127.0.0.1:9000".parse().unwrap());

        let endpoint = Arc::new(NatTraversalEndpoint::new(
            config,
            None,
        ).await.unwrap());

        let capabilities = NodeCapabilities {
            cpu_cores: 4,
            memory_mb: 8192,
            bandwidth_mbps: 100,
            supported_workflows: vec!["test_workflow".to_string()],
            current_load: 20,
        };

        let coordinator = WorkflowCoordinator::new(peer_id.clone(), endpoint, capabilities);
        coordinator.start().await.unwrap();

        let workflow_id = WorkflowId::generate();
        // Only the local peer participates, so no remote acceptance is awaited, and with no
        // stages assigned the workflow is considered finished right away.
        let participants = vec![peer_id.clone()];
        let stage_assignments = HashMap::new();

        let result = coordinator.coordinate_workflow(
            workflow_id,
            participants,
            stage_assignments,
            Duration::from_secs(60),
        ).await;

        assert!(result.is_ok());
    }
}