1use std::collections::HashMap;
7use std::sync::Arc;
8use std::fmt;
9
10use serde::{Deserialize, Serialize};
11use tracing::info;
12
13use crate::workflow::{
14 BackoffStrategy, Condition, ErrorHandler, RollbackStrategy, StageId, Version,
15 WorkflowAction, WorkflowError, WorkflowEvent,
16};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct WorkflowDefinition {
21 pub id: String,
23 pub name: String,
25 pub version: Version,
27 pub description: String,
29 pub stages: Vec<WorkflowStage>,
31 pub transitions: HashMap<(StageId, WorkflowEvent), StageId>,
33 pub timeouts: HashMap<StageId, Duration>,
35 pub error_handlers: HashMap<StageId, ErrorHandler>,
37 pub initial_stage: StageId,
39 pub final_stages: Vec<StageId>,
41 pub global_timeout: Option<Duration>,
43 pub metadata: HashMap<String, String>,
45}
46
47use std::time::Duration;
48
49#[derive(Clone, Serialize, Deserialize)]
51pub struct WorkflowStage {
52 pub id: StageId,
54 pub name: String,
56 pub description: String,
58 #[serde(skip)]
60 pub actions: Vec<Arc<dyn WorkflowAction>>,
61 pub action_names: Vec<String>,
63 #[serde(skip)]
65 pub preconditions: Vec<Arc<dyn Condition>>,
66 pub precondition_descriptions: Vec<String>,
68 #[serde(skip)]
70 pub postconditions: Vec<Arc<dyn Condition>>,
71 pub postcondition_descriptions: Vec<String>,
73 pub rollback: Option<RollbackStrategy>,
75 pub skippable: bool,
77 pub max_duration: Option<Duration>,
79}
80
81impl fmt::Debug for WorkflowStage {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 f.debug_struct("WorkflowStage")
84 .field("id", &self.id)
85 .field("name", &self.name)
86 .field("description", &self.description)
87 .field("action_names", &self.action_names)
88 .field("precondition_descriptions", &self.precondition_descriptions)
89 .field("postcondition_descriptions", &self.postcondition_descriptions)
90 .field("rollback", &self.rollback)
91 .field("skippable", &self.skippable)
92 .field("max_duration", &self.max_duration)
93 .finish()
94 }
95}
96
97pub struct WorkflowDefinitionBuilder {
99 definition: WorkflowDefinition,
100}
101
102impl WorkflowDefinitionBuilder {
103 pub fn new(id: String, name: String, version: Version) -> Self {
105 Self {
106 definition: WorkflowDefinition {
107 id,
108 name,
109 version,
110 description: String::new(),
111 stages: Vec::new(),
112 transitions: HashMap::new(),
113 timeouts: HashMap::new(),
114 error_handlers: HashMap::new(),
115 initial_stage: StageId("start".to_string()),
116 final_stages: vec![StageId("complete".to_string())],
117 global_timeout: None,
118 metadata: HashMap::new(),
119 },
120 }
121 }
122
123 pub fn description(mut self, desc: String) -> Self {
125 self.definition.description = desc;
126 self
127 }
128
129 pub fn add_stage(mut self, stage: WorkflowStage) -> Self {
131 self.definition.stages.push(stage);
132 self
133 }
134
135 pub fn add_transition(
137 mut self,
138 from_stage: StageId,
139 event: WorkflowEvent,
140 to_stage: StageId,
141 ) -> Self {
142 self.definition.transitions.insert((from_stage, event), to_stage);
143 self
144 }
145
146 pub fn set_stage_timeout(mut self, stage: StageId, timeout: Duration) -> Self {
148 self.definition.timeouts.insert(stage, timeout);
149 self
150 }
151
152 pub fn set_error_handler(mut self, stage: StageId, handler: ErrorHandler) -> Self {
154 self.definition.error_handlers.insert(stage, handler);
155 self
156 }
157
158 pub fn initial_stage(mut self, stage: StageId) -> Self {
160 self.definition.initial_stage = stage;
161 self
162 }
163
164 pub fn add_final_stage(mut self, stage: StageId) -> Self {
166 self.definition.final_stages.push(stage);
167 self
168 }
169
170 pub fn global_timeout(mut self, timeout: Duration) -> Self {
172 self.definition.global_timeout = Some(timeout);
173 self
174 }
175
176 pub fn add_metadata(mut self, key: String, value: String) -> Self {
178 self.definition.metadata.insert(key, value);
179 self
180 }
181
182 pub fn build(self) -> WorkflowDefinition {
184 self.definition
185 }
186}
187
/// Namespace type for the predefined workflow definitions; it carries no data.
pub struct WorkflowTemplates;
190
191impl WorkflowTemplates {
192 pub fn basic_nat_traversal() -> WorkflowDefinition {
194 WorkflowDefinitionBuilder::new(
195 "nat_traversal_basic".to_string(),
196 "Basic NAT Traversal".to_string(),
197 Version { major: 1, minor: 0, patch: 0 },
198 )
199 .description("Standard NAT traversal workflow with candidate discovery and hole punching".to_string())
200 .add_stage(WorkflowStage {
201 id: StageId("discover_candidates".to_string()),
202 name: "Discover Candidates".to_string(),
203 description: "Discover local and server-reflexive candidates".to_string(),
204 actions: vec![],
205 action_names: vec!["discover_local_candidates".to_string(), "query_stun_servers".to_string()],
206 preconditions: vec![],
207 precondition_descriptions: vec!["network_available".to_string()],
208 postconditions: vec![],
209 postcondition_descriptions: vec!["candidates_discovered".to_string()],
210 rollback: None,
211 skippable: false,
212 max_duration: Some(Duration::from_secs(5)),
213 })
214 .add_stage(WorkflowStage {
215 id: StageId("coordinate_with_peer".to_string()),
216 name: "Coordinate with Peer".to_string(),
217 description: "Exchange candidates and coordinate hole punching".to_string(),
218 actions: vec![],
219 action_names: vec!["exchange_candidates".to_string(), "synchronize_timing".to_string()],
220 preconditions: vec![],
221 precondition_descriptions: vec!["candidates_available".to_string()],
222 postconditions: vec![],
223 postcondition_descriptions: vec!["coordination_complete".to_string()],
224 rollback: Some(RollbackStrategy::JumpToStage {
225 stage_id: StageId("discover_candidates".to_string())
226 }),
227 skippable: false,
228 max_duration: Some(Duration::from_secs(10)),
229 })
230 .add_stage(WorkflowStage {
231 id: StageId("hole_punching".to_string()),
232 name: "Hole Punching".to_string(),
233 description: "Execute synchronized hole punching".to_string(),
234 actions: vec![],
235 action_names: vec!["execute_hole_punch".to_string(), "verify_connectivity".to_string()],
236 preconditions: vec![],
237 precondition_descriptions: vec!["coordination_complete".to_string()],
238 postconditions: vec![],
239 postcondition_descriptions: vec!["connection_established".to_string()],
240 rollback: Some(RollbackStrategy::Compensate {
241 actions: vec!["cleanup_failed_attempts".to_string()]
242 }),
243 skippable: false,
244 max_duration: Some(Duration::from_secs(15)),
245 })
246 .add_stage(WorkflowStage {
247 id: StageId("connection_established".to_string()),
248 name: "Connection Established".to_string(),
249 description: "Connection successfully established".to_string(),
250 actions: vec![],
251 action_names: vec!["finalize_connection".to_string()],
252 preconditions: vec![],
253 precondition_descriptions: vec!["connection_verified".to_string()],
254 postconditions: vec![],
255 postcondition_descriptions: vec![],
256 rollback: None,
257 skippable: false,
258 max_duration: Some(Duration::from_secs(2)),
259 })
260 .initial_stage(StageId("discover_candidates".to_string()))
261 .add_final_stage(StageId("connection_established".to_string()))
262 .add_transition(
263 StageId("discover_candidates".to_string()),
264 WorkflowEvent::StageCompleted { stage_id: StageId("discover_candidates".to_string()) },
265 StageId("coordinate_with_peer".to_string()),
266 )
267 .add_transition(
268 StageId("coordinate_with_peer".to_string()),
269 WorkflowEvent::StageCompleted { stage_id: StageId("coordinate_with_peer".to_string()) },
270 StageId("hole_punching".to_string()),
271 )
272 .add_transition(
273 StageId("hole_punching".to_string()),
274 WorkflowEvent::StageCompleted { stage_id: StageId("hole_punching".to_string()) },
275 StageId("connection_established".to_string()),
276 )
277 .set_stage_timeout(StageId("discover_candidates".to_string()), Duration::from_secs(10))
278 .set_stage_timeout(StageId("coordinate_with_peer".to_string()), Duration::from_secs(20))
279 .set_stage_timeout(StageId("hole_punching".to_string()), Duration::from_secs(30))
280 .set_error_handler(
281 StageId("hole_punching".to_string()),
282 ErrorHandler {
283 max_retries: 3,
284 backoff: BackoffStrategy::Exponential {
285 initial: Duration::from_millis(500),
286 max: Duration::from_secs(5),
287 factor: 2.0,
288 },
289 fallback_stage: Some(StageId("coordinate_with_peer".to_string())),
290 propagate: false,
291 },
292 )
293 .global_timeout(Duration::from_secs(60))
294 .build()
295 }
296
297 pub fn advanced_nat_traversal() -> WorkflowDefinition {
299 let mut basic = Self::basic_nat_traversal();
300 basic.id = "nat_traversal_advanced".to_string();
301 basic.name = "Advanced NAT Traversal with Relay".to_string();
302 basic.version = Version { major: 1, minor: 0, patch: 0 };
303
304 basic.stages.push(WorkflowStage {
306 id: StageId("relay_fallback".to_string()),
307 name: "Relay Fallback".to_string(),
308 description: "Establish connection through relay server".to_string(),
309 actions: vec![],
310 action_names: vec!["connect_to_relay".to_string(), "establish_relay_path".to_string()],
311 preconditions: vec![],
312 precondition_descriptions: vec!["relay_available".to_string()],
313 postconditions: vec![],
314 postcondition_descriptions: vec!["relay_connection_established".to_string()],
315 rollback: None,
316 skippable: false,
317 max_duration: Some(Duration::from_secs(10)),
318 });
319
320 basic.transitions.insert(
322 (
323 StageId("hole_punching".to_string()),
324 WorkflowEvent::StageFailed {
325 stage_id: StageId("hole_punching".to_string()),
326 error: "max_retries_exceeded".to_string(),
327 },
328 ),
329 StageId("relay_fallback".to_string()),
330 );
331
332 basic.transitions.insert(
334 (
335 StageId("relay_fallback".to_string()),
336 WorkflowEvent::StageCompleted { stage_id: StageId("relay_fallback".to_string()) },
337 ),
338 StageId("connection_established".to_string()),
339 );
340
341 basic
342 }
343
344 pub fn multi_peer_coordination() -> WorkflowDefinition {
346 WorkflowDefinitionBuilder::new(
347 "multi_peer_coordination".to_string(),
348 "Multi-Peer Coordination".to_string(),
349 Version { major: 1, minor: 0, patch: 0 },
350 )
351 .description("Coordinate NAT traversal among multiple peers".to_string())
352 .add_stage(WorkflowStage {
353 id: StageId("peer_discovery".to_string()),
354 name: "Peer Discovery".to_string(),
355 description: "Discover available peers".to_string(),
356 actions: vec![],
357 action_names: vec!["query_bootstrap_nodes".to_string(), "exchange_peer_lists".to_string()],
358 preconditions: vec![],
359 precondition_descriptions: vec!["bootstrap_available".to_string()],
360 postconditions: vec![],
361 postcondition_descriptions: vec!["peers_discovered".to_string()],
362 rollback: None,
363 skippable: false,
364 max_duration: Some(Duration::from_secs(10)),
365 })
366 .add_stage(WorkflowStage {
367 id: StageId("establish_coordinator".to_string()),
368 name: "Establish Coordinator".to_string(),
369 description: "Select and establish connection to coordination node".to_string(),
370 actions: vec![],
371 action_names: vec!["select_coordinator".to_string(), "connect_to_coordinator".to_string()],
372 preconditions: vec![],
373 precondition_descriptions: vec!["peers_available".to_string()],
374 postconditions: vec![],
375 postcondition_descriptions: vec!["coordinator_connected".to_string()],
376 rollback: None,
377 skippable: false,
378 max_duration: Some(Duration::from_secs(15)),
379 })
380 .add_stage(WorkflowStage {
381 id: StageId("coordinate_connections".to_string()),
382 name: "Coordinate Connections".to_string(),
383 description: "Coordinate NAT traversal for all peer connections".to_string(),
384 actions: vec![],
385 action_names: vec!["plan_connection_order".to_string(), "execute_coordinated_traversal".to_string()],
386 preconditions: vec![],
387 precondition_descriptions: vec!["coordinator_ready".to_string()],
388 postconditions: vec![],
389 postcondition_descriptions: vec!["all_connections_established".to_string()],
390 rollback: Some(RollbackStrategy::Compensate {
391 actions: vec!["cleanup_partial_connections".to_string()],
392 }),
393 skippable: false,
394 max_duration: Some(Duration::from_secs(60)),
395 })
396 .add_stage(WorkflowStage {
397 id: StageId("mesh_established".to_string()),
398 name: "Mesh Established".to_string(),
399 description: "Peer mesh successfully established".to_string(),
400 actions: vec![],
401 action_names: vec!["verify_mesh_connectivity".to_string(), "optimize_routing".to_string()],
402 preconditions: vec![],
403 precondition_descriptions: vec!["minimum_peers_connected".to_string()],
404 postconditions: vec![],
405 postcondition_descriptions: vec![],
406 rollback: None,
407 skippable: false,
408 max_duration: Some(Duration::from_secs(5)),
409 })
410 .initial_stage(StageId("peer_discovery".to_string()))
411 .add_final_stage(StageId("mesh_established".to_string()))
412 .global_timeout(Duration::from_secs(120))
413 .build()
414 }
415}
416
417pub struct WorkflowRegistry {
419 definitions: RwLock<HashMap<String, WorkflowDefinition>>,
420}
421
422use tokio::sync::RwLock;
423
424impl WorkflowRegistry {
425 pub fn new() -> Self {
427 Self {
428 definitions: RwLock::new(HashMap::new()),
429 }
430 }
431
432 pub async fn register(&self, definition: WorkflowDefinition) -> Result<(), WorkflowError> {
434 let mut definitions = self.definitions.write().await;
435
436 let key = format!("{}:{}", definition.id, definition.version);
437 if definitions.contains_key(&key) {
438 return Err(WorkflowError {
439 code: "ALREADY_EXISTS".to_string(),
440 message: format!("Workflow {} version {} already registered", definition.id, definition.version),
441 stage: None,
442 trace: None,
443 recovery_hints: vec!["Use a different version number".to_string()],
444 });
445 }
446
447 info!("Registered workflow: {} v{}", definition.id, definition.version);
448 definitions.insert(key, definition);
449 Ok(())
450 }
451
452 pub async fn get(&self, id: &str, version: &Version) -> Option<WorkflowDefinition> {
454 let definitions = self.definitions.read().await;
455 let key = format!("{}:{}", id, version);
456 definitions.get(&key).cloned()
457 }
458
459 pub async fn get_latest(&self, id: &str) -> Option<WorkflowDefinition> {
461 let definitions = self.definitions.read().await;
462
463 definitions.iter()
464 .filter(|(k, _)| k.starts_with(&format!("{}:", id)))
465 .max_by_key(|(_, def)| &def.version)
466 .map(|(_, def)| def.clone())
467 }
468
469 pub async fn list(&self) -> Vec<(String, Version)> {
471 let definitions = self.definitions.read().await;
472
473 definitions.values()
474 .map(|def| (def.id.clone(), def.version.clone()))
475 .collect()
476 }
477
478 pub async fn load_defaults(&self) -> Result<(), WorkflowError> {
480 self.register(WorkflowTemplates::basic_nat_traversal()).await?;
481 self.register(WorkflowTemplates::advanced_nat_traversal()).await?;
482 self.register(WorkflowTemplates::multi_peer_coordination()).await?;
483
484 info!("Loaded {} default workflow templates", 3);
485 Ok(())
486 }
487}
488
489impl Default for WorkflowRegistry {
490 fn default() -> Self {
491 Self::new()
492 }
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder carries id, stages and initial stage through to the built
    /// definition.
    #[tokio::test]
    async fn test_workflow_builder() {
        let only_stage = WorkflowStage {
            id: StageId("stage1".to_string()),
            name: "Stage 1".to_string(),
            description: "First stage".to_string(),
            actions: vec![],
            action_names: vec!["action1".to_string()],
            preconditions: vec![],
            precondition_descriptions: vec![],
            postconditions: vec![],
            postcondition_descriptions: vec![],
            rollback: None,
            skippable: false,
            max_duration: None,
        };

        let builder = WorkflowDefinitionBuilder::new(
            "test_workflow".to_string(),
            "Test Workflow".to_string(),
            Version { major: 1, minor: 0, patch: 0 },
        );
        let workflow = builder
            .description("Test workflow description".to_string())
            .add_stage(only_stage)
            .initial_stage(StageId("stage1".to_string()))
            .build();

        assert_eq!(workflow.id, "test_workflow");
        assert_eq!(workflow.stages.len(), 1);
        assert_eq!(workflow.initial_stage, StageId("stage1".to_string()));
    }

    /// A registered definition is retrievable by exact version, as the latest
    /// version, and through the listing.
    #[tokio::test]
    async fn test_workflow_registry() {
        let registry = WorkflowRegistry::new();
        let workflow = WorkflowTemplates::basic_nat_traversal();

        let registration = registry.register(workflow.clone()).await;
        registration.unwrap();

        let by_version = registry.get(&workflow.id, &workflow.version).await;
        assert!(by_version.is_some());

        let newest = registry.get_latest(&workflow.id).await;
        assert!(newest.is_some());

        let entries = registry.list().await;
        assert_eq!(entries.len(), 1);
    }

    /// The advanced template has its own id and more stages than the basic one.
    #[test]
    fn test_workflow_templates() {
        let basic = WorkflowTemplates::basic_nat_traversal();
        let advanced = WorkflowTemplates::advanced_nat_traversal();

        assert_eq!(basic.id, "nat_traversal_basic");
        assert!(!basic.stages.is_empty());

        assert_eq!(advanced.id, "nat_traversal_advanced");
        assert!(advanced.stages.len() > basic.stages.len());
    }
}