1use crate::{Edge, Node, NodeId, NodeKind};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use uuid::Uuid;
5
6#[cfg(feature = "openapi")]
7use utoipa::ToSchema;
8
/// Identifier type used for workflows; an alias for [`Uuid`].
pub type WorkflowId = Uuid;
11
/// Descriptive, versioning and scheduling information attached to a [`Workflow`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct WorkflowMetadata {
    /// Identifier of this workflow (documented as a string in the OpenAPI schema).
    #[cfg_attr(feature = "openapi", schema(value_type = String))]
    pub id: WorkflowId,

    /// Name of the workflow.
    pub name: String,

    /// Optional free-form description.
    pub description: Option<String>,

    /// Version string; `parse_version` expects the form `major.minor.patch`.
    pub version: String,

    /// Timestamp recorded when this metadata was created.
    pub created_at: DateTime<Utc>,

    /// Timestamp of the last recorded change; refreshed by the `bump_*` methods.
    pub updated_at: DateTime<Utc>,

    /// Tags attached to the workflow; an empty list when absent from serialized input.
    #[serde(default)]
    pub tags: Vec<String>,

    /// Id of the workflow this one was derived from, as set by
    /// `Workflow::create_new_version`; `None` when absent from serialized input.
    #[cfg_attr(feature = "openapi", schema(value_type = Option<String>))]
    #[serde(default)]
    pub parent_id: Option<WorkflowId>,

    /// Description of what changed in this version, as set by
    /// `Workflow::create_new_version`; `None` when absent from serialized input.
    #[serde(default)]
    pub change_description: Option<String>,

    /// Optional schedule for the workflow; `None` when absent from serialized input.
    #[serde(default)]
    pub schedule: Option<WorkflowSchedule>,
}
52
53impl WorkflowMetadata {
54 pub fn new(name: String) -> Self {
55 let now = Utc::now();
56 Self {
57 id: Uuid::new_v4(),
58 name,
59 description: None,
60 version: "0.1.0".to_string(),
61 created_at: now,
62 updated_at: now,
63 tags: Vec::new(),
64 parent_id: None,
65 change_description: None,
66 schedule: None,
67 }
68 }
69
70 pub fn parse_version(&self) -> Result<(u32, u32, u32), String> {
72 let parts: Vec<&str> = self.version.split('.').collect();
73 if parts.len() != 3 {
74 return Err(format!("Invalid version format: {}", self.version));
75 }
76
77 let major = parts[0]
78 .parse::<u32>()
79 .map_err(|_| format!("Invalid major version: {}", parts[0]))?;
80 let minor = parts[1]
81 .parse::<u32>()
82 .map_err(|_| format!("Invalid minor version: {}", parts[1]))?;
83 let patch = parts[2]
84 .parse::<u32>()
85 .map_err(|_| format!("Invalid patch version: {}", parts[2]))?;
86
87 Ok((major, minor, patch))
88 }
89
90 pub fn bump_major(&mut self) {
92 if let Ok((major, _, _)) = self.parse_version() {
93 self.version = format!("{}.0.0", major + 1);
94 self.updated_at = Utc::now();
95 }
96 }
97
98 pub fn bump_minor(&mut self) {
100 if let Ok((major, minor, _)) = self.parse_version() {
101 self.version = format!("{}.{}.0", major, minor + 1);
102 self.updated_at = Utc::now();
103 }
104 }
105
106 pub fn bump_patch(&mut self) {
108 if let Ok((major, minor, patch)) = self.parse_version() {
109 self.version = format!("{}.{}.{}", major, minor, patch + 1);
110 self.updated_at = Utc::now();
111 }
112 }
113}
114
/// A workflow: its metadata plus the nodes and the edges that connect them.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct Workflow {
    /// Name, version, timestamps and other descriptive data.
    pub metadata: WorkflowMetadata,

    /// Nodes of the workflow, in insertion order.
    pub nodes: Vec<Node>,

    /// Edges between nodes, in insertion order; each edge refers to nodes by
    /// id through its `from` and `to` fields.
    pub edges: Vec<Edge>,
}
128
129impl Workflow {
130 pub fn new(name: String) -> Self {
131 Self {
132 metadata: WorkflowMetadata::new(name),
133 nodes: Vec::new(),
134 edges: Vec::new(),
135 }
136 }
137
138 pub fn add_node(&mut self, node: Node) {
140 self.nodes.push(node);
141 }
142
143 pub fn add_edge(&mut self, edge: Edge) {
145 self.edges.push(edge);
146 }
147
148 pub fn get_node(&self, id: &NodeId) -> Option<&Node> {
150 self.nodes.iter().find(|n| &n.id == id)
151 }
152
153 pub fn get_node_mut(&mut self, id: &NodeId) -> Option<&mut Node> {
155 self.nodes.iter_mut().find(|n| &n.id == id)
156 }
157
158 pub fn find_nodes_by_kind(&self, kind: &NodeKind) -> Vec<&Node> {
160 self.nodes
161 .iter()
162 .filter(|n| std::mem::discriminant(&n.kind) == std::mem::discriminant(kind))
163 .collect()
164 }
165
166 pub fn get_start_node(&self) -> Option<&Node> {
168 self.nodes
169 .iter()
170 .find(|n| matches!(n.kind, NodeKind::Start))
171 }
172
173 pub fn get_end_nodes(&self) -> Vec<&Node> {
175 self.nodes
176 .iter()
177 .filter(|n| matches!(n.kind, NodeKind::End))
178 .collect()
179 }
180
181 pub fn remove_node(&mut self, id: &NodeId) -> bool {
183 let node_existed = self.nodes.iter().any(|n| &n.id == id);
184 if node_existed {
185 self.nodes.retain(|n| &n.id != id);
186 self.edges.retain(|e| &e.from != id && &e.to != id);
187 }
188 node_existed
189 }
190
191 pub fn remove_edge(&mut self, from: &NodeId, to: &NodeId) -> bool {
193 let edge_count = self.edges.len();
194 self.edges.retain(|e| &e.from != from || &e.to != to);
195 self.edges.len() < edge_count
196 }
197
198 pub fn node_count(&self) -> usize {
200 self.nodes.len()
201 }
202
203 pub fn edge_count(&self) -> usize {
205 self.edges.len()
206 }
207
208 pub fn get_outgoing_edges(&self, node_id: &NodeId) -> Vec<&Edge> {
210 self.edges.iter().filter(|e| &e.from == node_id).collect()
211 }
212
213 pub fn get_incoming_edges(&self, node_id: &NodeId) -> Vec<&Edge> {
215 self.edges.iter().filter(|e| &e.to == node_id).collect()
216 }
217
218 pub fn validate(&self) -> Result<(), String> {
220 use crate::validation::WorkflowValidator;
221
222 match WorkflowValidator::validate(self) {
223 Ok(_report) => Ok(()),
224 Err(e) => Err(e.to_string()),
225 }
226 }
227
228 pub fn to_json(&self) -> Result<String, String> {
230 serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization error: {}", e))
231 }
232
233 pub fn to_json_file(&self, path: &str) -> Result<(), String> {
235 let json = self.to_json()?;
236 std::fs::write(path, json).map_err(|e| format!("File write error: {}", e))
237 }
238
239 pub fn from_json(json: &str) -> Result<Self, String> {
241 serde_json::from_str(json).map_err(|e| format!("JSON deserialization error: {}", e))
242 }
243
244 pub fn from_json_file(path: &str) -> Result<Self, String> {
246 let json = std::fs::read_to_string(path).map_err(|e| format!("File read error: {}", e))?;
247 Self::from_json(&json)
248 }
249
250 pub fn to_yaml(&self) -> Result<String, String> {
252 serde_yaml::to_string(self).map_err(|e| format!("YAML serialization error: {}", e))
253 }
254
255 pub fn to_yaml_file(&self, path: &str) -> Result<(), String> {
257 let yaml = self.to_yaml()?;
258 std::fs::write(path, yaml).map_err(|e| format!("File write error: {}", e))
259 }
260
261 pub fn from_yaml(yaml: &str) -> Result<Self, String> {
263 serde_yaml::from_str(yaml).map_err(|e| format!("YAML deserialization error: {}", e))
264 }
265
266 pub fn from_yaml_file(path: &str) -> Result<Self, String> {
268 let yaml = std::fs::read_to_string(path).map_err(|e| format!("File read error: {}", e))?;
269 Self::from_yaml(&yaml)
270 }
271
272 pub fn create_new_version(
274 &self,
275 change_description: String,
276 version_type: VersionBump,
277 ) -> Self {
278 let mut new_workflow = self.clone();
279
280 new_workflow.metadata.id = Uuid::new_v4();
282 new_workflow.metadata.parent_id = Some(self.metadata.id);
283 new_workflow.metadata.change_description = Some(change_description);
284 new_workflow.metadata.created_at = Utc::now();
285 new_workflow.metadata.updated_at = Utc::now();
286
287 match version_type {
289 VersionBump::Major => new_workflow.metadata.bump_major(),
290 VersionBump::Minor => new_workflow.metadata.bump_minor(),
291 VersionBump::Patch => new_workflow.metadata.bump_patch(),
292 }
293
294 new_workflow
295 }
296
297 pub fn is_newer_than(&self, other: &Workflow) -> Result<bool, String> {
299 let (self_major, self_minor, self_patch) = self.metadata.parse_version()?;
300 let (other_major, other_minor, other_patch) = other.metadata.parse_version()?;
301
302 Ok(self_major > other_major
303 || (self_major == other_major && self_minor > other_minor)
304 || (self_major == other_major && self_minor == other_minor && self_patch > other_patch))
305 }
306
307 pub fn version_tuple(&self) -> Result<(u32, u32, u32), String> {
309 self.metadata.parse_version()
310 }
311}
312
/// Which version component `Workflow::create_new_version` should increment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionBump {
    /// Increment the major component (dispatches to `bump_major`).
    Major,
    /// Increment the minor component (dispatches to `bump_minor`).
    Minor,
    /// Increment the patch component (dispatches to `bump_patch`).
    Patch,
}
323
/// Scheduling settings that can be attached to a workflow's metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct WorkflowSchedule {
    /// Schedule expression. NOTE(review): presumably cron syntax, going by
    /// the field name; it is stored as-is and not parsed in this file.
    pub cron: String,

    /// Timezone name; `"UTC"` when absent from serialized input. Stored
    /// as-is and not validated in this file.
    #[serde(default = "default_timezone")]
    pub timezone: String,

    /// Whether the schedule is active; `true` when absent from serialized
    /// input. `is_valid_now` returns `false` whenever this is `false`.
    #[serde(default = "default_enabled")]
    pub enabled: bool,

    /// Optional limit on concurrent runs; `None` when absent from serialized
    /// input. Not enforced anywhere in this file.
    #[serde(default)]
    pub max_concurrent_runs: Option<u32>,

    /// Retry flag; `false` when absent from serialized input. Not
    /// interpreted anywhere in this file.
    #[serde(default)]
    pub retry_on_failure: bool,

    /// Optional earliest instant at which `is_valid_now` returns `true`.
    #[serde(default)]
    pub start_date: Option<DateTime<Utc>>,

    /// Optional latest instant at which `is_valid_now` returns `true`.
    #[serde(default)]
    pub end_date: Option<DateTime<Utc>>,
}
355
/// Timezone used when a schedule does not specify one: `"UTC"`.
fn default_timezone() -> String {
    String::from("UTC")
}
359
/// Serde default for `WorkflowSchedule::enabled`: schedules are enabled
/// unless the serialized input says otherwise.
fn default_enabled() -> bool {
    true
}
363
364impl WorkflowSchedule {
365 pub fn new(cron: String) -> Self {
367 Self {
368 cron,
369 timezone: default_timezone(),
370 enabled: true,
371 max_concurrent_runs: None,
372 retry_on_failure: false,
373 start_date: None,
374 end_date: None,
375 }
376 }
377
378 pub fn with_timezone(mut self, timezone: String) -> Self {
380 self.timezone = timezone;
381 self
382 }
383
384 pub fn set_enabled(mut self, enabled: bool) -> Self {
386 self.enabled = enabled;
387 self
388 }
389
390 pub fn with_max_concurrent_runs(mut self, max: u32) -> Self {
392 self.max_concurrent_runs = Some(max);
393 self
394 }
395
396 pub fn with_date_range(mut self, start: DateTime<Utc>, end: DateTime<Utc>) -> Self {
398 self.start_date = Some(start);
399 self.end_date = Some(end);
400 self
401 }
402
403 pub fn is_valid_now(&self) -> bool {
405 if !self.enabled {
406 return false;
407 }
408
409 let now = Utc::now();
410
411 if let Some(start) = self.start_date {
412 if now < start {
413 return false;
414 }
415 }
416
417 if let Some(end) = self.end_date {
418 if now > end {
419 return false;
420 }
421 }
422
423 true
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Edge, LlmConfig, Node, NodeKind};
    use chrono::Duration;

    /// Schedule expression shared by the schedule tests.
    const DAILY: &str = "0 0 * * *";

    /// Builds a workflow named `name` holding a Start node, an End node and
    /// the edge between them; returns it together with both node ids.
    fn start_to_end(name: &str) -> (Workflow, NodeId, NodeId) {
        let mut workflow = Workflow::new(name.to_string());

        let start = Node::new("Start".to_string(), NodeKind::Start);
        let start_id = start.id;
        workflow.add_node(start);

        let end = Node::new("End".to_string(), NodeKind::End);
        let end_id = end.id;
        workflow.add_node(end);

        workflow.add_edge(Edge::new(start_id, end_id));
        (workflow, start_id, end_id)
    }

    /// Metadata named "Test" whose version string is `version`.
    fn metadata_at(version: &str) -> WorkflowMetadata {
        let mut metadata = WorkflowMetadata::new("Test".to_string());
        metadata.version = version.to_string();
        metadata
    }

    /// Workflow named "Test" whose version string is `version`.
    fn workflow_at(version: &str) -> Workflow {
        let mut workflow = Workflow::new("Test".to_string());
        workflow.metadata.version = version.to_string();
        workflow
    }

    /// Fresh schedule built from [`DAILY`].
    fn daily() -> WorkflowSchedule {
        WorkflowSchedule::new(DAILY.to_string())
    }

    #[test]
    fn test_workflow_validation() {
        let (workflow, _, _) = start_to_end("Test Workflow");

        assert!(workflow.validate().is_ok());
    }

    #[test]
    fn test_workflow_missing_start_node() {
        let empty = Workflow::new("Test Workflow".to_string());

        assert!(empty.validate().is_err());
    }

    #[test]
    fn test_workflow_json_serialization() {
        let (workflow, _, _) = start_to_end("Test Workflow");

        let json = workflow.to_json().expect("workflow should serialize to JSON");
        let restored = Workflow::from_json(&json).expect("JSON should deserialize");

        assert_eq!(restored.nodes.len(), 2);
        assert_eq!(restored.edges.len(), 1);
        assert_eq!(restored.metadata.name, "Test Workflow");
    }

    #[test]
    fn test_workflow_metadata_new() {
        let WorkflowMetadata {
            name,
            version,
            description,
            tags,
            parent_id,
            change_description,
            schedule,
            ..
        } = WorkflowMetadata::new("Test Workflow".to_string());

        assert_eq!(name, "Test Workflow");
        assert_eq!(version, "0.1.0");
        assert!(description.is_none());
        assert!(tags.is_empty());
        assert!(parent_id.is_none());
        assert!(change_description.is_none());
        assert!(schedule.is_none());
    }

    #[test]
    fn test_workflow_metadata_parse_version() {
        let metadata = WorkflowMetadata::new("Test".to_string());

        let parsed = metadata.parse_version().unwrap();

        assert_eq!(parsed, (0, 1, 0));
    }

    #[test]
    fn test_workflow_metadata_parse_version_invalid() {
        let metadata = metadata_at("invalid");

        assert!(metadata.parse_version().is_err());
    }

    #[test]
    fn test_workflow_metadata_bump_major() {
        let mut metadata = metadata_at("1.2.3");

        metadata.bump_major();

        assert_eq!(metadata.version, "2.0.0");
    }

    #[test]
    fn test_workflow_metadata_bump_minor() {
        let mut metadata = metadata_at("1.2.3");

        metadata.bump_minor();

        assert_eq!(metadata.version, "1.3.0");
    }

    #[test]
    fn test_workflow_metadata_bump_patch() {
        let mut metadata = metadata_at("1.2.3");

        metadata.bump_patch();

        assert_eq!(metadata.version, "1.2.4");
    }

    #[test]
    fn test_workflow_new() {
        let workflow = Workflow::new("Test Workflow".to_string());

        assert_eq!(workflow.metadata.name, "Test Workflow");
        assert!(workflow.nodes.is_empty());
        assert!(workflow.edges.is_empty());
    }

    #[test]
    fn test_workflow_add_node() {
        let mut workflow = Workflow::new("Test".to_string());

        workflow.add_node(Node::new("Test Node".to_string(), NodeKind::Start));

        assert_eq!(workflow.nodes.len(), 1);
        assert_eq!(workflow.nodes[0].name, "Test Node");
    }

    #[test]
    fn test_workflow_add_edge() {
        let mut workflow = Workflow::new("Test".to_string());
        let (from, to) = (Uuid::new_v4(), Uuid::new_v4());

        workflow.add_edge(Edge::new(from, to));

        assert_eq!(workflow.edges.len(), 1);
        let stored = &workflow.edges[0];
        assert_eq!(stored.from, from);
        assert_eq!(stored.to, to);
    }

    #[test]
    fn test_workflow_get_node() {
        let mut workflow = Workflow::new("Test".to_string());
        let node = Node::new("Test Node".to_string(), NodeKind::Start);
        let node_id = node.id;
        workflow.add_node(node);

        let found = workflow
            .get_node(&node_id)
            .expect("node that was just added should be found");

        assert_eq!(found.name, "Test Node");
    }

    #[test]
    fn test_workflow_get_node_not_found() {
        let workflow = Workflow::new("Test".to_string());
        let unknown_id = Uuid::new_v4();

        assert!(workflow.get_node(&unknown_id).is_none());
    }

    #[test]
    fn test_workflow_get_outgoing_edges() {
        let mut workflow = Workflow::new("Test".to_string());
        let (from, to1, to2) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());

        for (source, target) in [(from, to1), (from, to2), (to1, to2)] {
            workflow.add_edge(Edge::new(source, target));
        }

        assert_eq!(workflow.get_outgoing_edges(&from).len(), 2);
    }

    #[test]
    fn test_workflow_get_incoming_edges() {
        let mut workflow = Workflow::new("Test".to_string());
        let (from1, from2, to) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());

        for (source, target) in [(from1, to), (from2, to), (from1, from2)] {
            workflow.add_edge(Edge::new(source, target));
        }

        assert_eq!(workflow.get_incoming_edges(&to).len(), 2);
    }

    #[test]
    fn test_workflow_yaml_serialization() {
        let (workflow, _, _) = start_to_end("Test Workflow");

        let yaml = workflow.to_yaml().expect("workflow should serialize to YAML");
        let restored = Workflow::from_yaml(&yaml).expect("YAML should deserialize");

        assert_eq!(restored.nodes.len(), 2);
        assert_eq!(restored.edges.len(), 1);
        assert_eq!(restored.metadata.name, "Test Workflow");
    }

    #[test]
    fn test_workflow_create_new_version_major() {
        let original = Workflow::new("Test".to_string());

        let derived =
            original.create_new_version("Breaking changes".to_string(), VersionBump::Major);

        assert_ne!(derived.metadata.id, original.metadata.id);
        assert_eq!(derived.metadata.parent_id, Some(original.metadata.id));
        assert_eq!(derived.metadata.version, "1.0.0");
        assert_eq!(
            derived.metadata.change_description,
            Some("Breaking changes".to_string())
        );
    }

    #[test]
    fn test_workflow_create_new_version_minor() {
        let derived = workflow_at("1.0.0")
            .create_new_version("New features".to_string(), VersionBump::Minor);

        assert_eq!(derived.metadata.version, "1.1.0");
    }

    #[test]
    fn test_workflow_create_new_version_patch() {
        let derived =
            workflow_at("1.0.0").create_new_version("Bug fixes".to_string(), VersionBump::Patch);

        assert_eq!(derived.metadata.version, "1.0.1");
    }

    #[test]
    fn test_workflow_is_newer_than() {
        let older = workflow_at("1.0.0");
        let newer = workflow_at("2.0.0");

        assert!(newer.is_newer_than(&older).unwrap());
        assert!(!older.is_newer_than(&newer).unwrap());
    }

    #[test]
    fn test_workflow_version_tuple() {
        let workflow = workflow_at("3.2.1");

        assert_eq!(workflow.version_tuple().unwrap(), (3, 2, 1));
    }

    #[test]
    fn test_workflow_schedule_new() {
        let schedule = daily();

        assert_eq!(schedule.cron, "0 0 * * *");
        assert_eq!(schedule.timezone, "UTC");
        assert!(schedule.enabled);
        assert!(schedule.max_concurrent_runs.is_none());
        assert!(!schedule.retry_on_failure);
        assert!(schedule.start_date.is_none());
        assert!(schedule.end_date.is_none());
    }

    #[test]
    fn test_workflow_schedule_with_timezone() {
        let schedule = daily().with_timezone("America/New_York".to_string());

        assert_eq!(schedule.timezone, "America/New_York");
    }

    #[test]
    fn test_workflow_schedule_set_enabled() {
        let schedule = daily().set_enabled(false);

        assert!(!schedule.enabled);
    }

    #[test]
    fn test_workflow_schedule_with_max_concurrent_runs() {
        let schedule = daily().with_max_concurrent_runs(5);

        assert_eq!(schedule.max_concurrent_runs, Some(5));
    }

    #[test]
    fn test_workflow_schedule_with_date_range() {
        let start = Utc::now();
        let end = start + Duration::days(7);

        let schedule = daily().with_date_range(start, end);

        assert_eq!(schedule.start_date, Some(start));
        assert_eq!(schedule.end_date, Some(end));
    }

    #[test]
    fn test_workflow_schedule_is_valid_now_enabled() {
        assert!(daily().is_valid_now());
    }

    #[test]
    fn test_workflow_schedule_is_valid_now_disabled() {
        assert!(!daily().set_enabled(false).is_valid_now());
    }

    #[test]
    fn test_workflow_schedule_is_valid_now_before_start() {
        let start = Utc::now() + Duration::days(1);
        let end = start + Duration::days(7);

        let schedule = daily().with_date_range(start, end);

        assert!(!schedule.is_valid_now());
    }

    #[test]
    fn test_workflow_schedule_is_valid_now_after_end() {
        let now = Utc::now();

        let schedule = daily().with_date_range(now - Duration::days(7), now - Duration::days(1));

        assert!(!schedule.is_valid_now());
    }

    #[test]
    fn test_workflow_schedule_is_valid_now_within_range() {
        let now = Utc::now();

        let schedule = daily().with_date_range(now - Duration::days(1), now + Duration::days(1));

        assert!(schedule.is_valid_now());
    }

    #[test]
    fn test_version_bump_enum() {
        assert_eq!(VersionBump::Major, VersionBump::Major);
        assert_ne!(VersionBump::Major, VersionBump::Minor);
        assert_ne!(VersionBump::Minor, VersionBump::Patch);
    }

    #[test]
    fn test_workflow_get_node_mut() {
        let mut workflow = Workflow::new("test".to_string());
        let node = Node::new("Start".to_string(), NodeKind::Start);
        let node_id = node.id;
        workflow.add_node(node);

        workflow
            .get_node_mut(&node_id)
            .expect("node that was just added should be found")
            .name = "Modified".to_string();

        let reread = workflow.get_node(&node_id).unwrap();
        assert_eq!(reread.name, "Modified");
    }

    #[test]
    fn test_workflow_find_nodes_by_kind() {
        let mut workflow = Workflow::new("test".to_string());
        workflow.add_node(Node::new("Start".to_string(), NodeKind::Start));

        let llm_config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
            extra_params: serde_json::json!({}),
        };

        for name in ["LLM1", "LLM2"] {
            workflow.add_node(Node::new(
                name.to_string(),
                NodeKind::LLM(llm_config.clone()),
            ));
        }
        workflow.add_node(Node::new("End".to_string(), NodeKind::End));

        let llm_nodes = workflow.find_nodes_by_kind(&NodeKind::LLM(llm_config));
        assert_eq!(llm_nodes.len(), 2);

        let start_nodes = workflow.find_nodes_by_kind(&NodeKind::Start);
        assert_eq!(start_nodes.len(), 1);
    }

    #[test]
    fn test_workflow_get_start_node() {
        let mut workflow = Workflow::new("test".to_string());
        assert!(workflow.get_start_node().is_none());

        workflow.add_node(Node::new("Start".to_string(), NodeKind::Start));

        let found = workflow
            .get_start_node()
            .expect("start node should be found once added");
        assert!(matches!(found.kind, NodeKind::Start));
    }

    #[test]
    fn test_workflow_get_end_nodes() {
        let mut workflow = Workflow::new("test".to_string());
        assert!(workflow.get_end_nodes().is_empty());

        for name in ["End1", "End2"] {
            workflow.add_node(Node::new(name.to_string(), NodeKind::End));
        }

        assert_eq!(workflow.get_end_nodes().len(), 2);
    }

    #[test]
    fn test_workflow_remove_node() {
        let (mut workflow, start_id, _) = start_to_end("test");
        assert_eq!(workflow.nodes.len(), 2);
        assert_eq!(workflow.edges.len(), 1);

        // First removal takes the node and the edge attached to it.
        assert!(workflow.remove_node(&start_id));
        assert_eq!(workflow.nodes.len(), 1);
        assert!(workflow.edges.is_empty());

        // Removing the same id again finds nothing.
        assert!(!workflow.remove_node(&start_id));
    }

    #[test]
    fn test_workflow_remove_edge() {
        let (mut workflow, start_id, end_id) = start_to_end("test");
        assert_eq!(workflow.edges.len(), 1);

        assert!(workflow.remove_edge(&start_id, &end_id));
        assert!(workflow.edges.is_empty());

        // The edge is already gone, so a second removal reports false.
        assert!(!workflow.remove_edge(&start_id, &end_id));
    }

    #[test]
    fn test_workflow_node_count() {
        let mut workflow = Workflow::new("test".to_string());
        assert_eq!(workflow.node_count(), 0);

        let additions = [("Start", NodeKind::Start), ("End", NodeKind::End)];
        for (expected, (name, kind)) in (1..).zip(additions) {
            workflow.add_node(Node::new(name.to_string(), kind));
            assert_eq!(workflow.node_count(), expected);
        }
    }

    #[test]
    fn test_workflow_edge_count() {
        let mut workflow = Workflow::new("test".to_string());
        let start = Node::new("Start".to_string(), NodeKind::Start);
        let end = Node::new("End".to_string(), NodeKind::End);
        let (start_id, end_id) = (start.id, end.id);
        workflow.add_node(start);
        workflow.add_node(end);

        assert_eq!(workflow.edge_count(), 0);

        workflow.add_edge(Edge::new(start_id, end_id));
        assert_eq!(workflow.edge_count(), 1);
    }
}