1use crate::graph::{
61 Checkpoint, CheckpointStore, Graph, GraphBuilder, GraphResult, GraphState, GraphStatus,
62 InMemoryCheckpointStore, END,
63};
64use chrono::{DateTime, Utc};
65use cortexai_core::errors::CrewError;
66use serde::{Deserialize, Serialize};
67use std::collections::{HashMap, HashSet};
68use std::sync::Arc;
69use tokio::sync::oneshot;
70
/// Outcome of a single `InteractiveSession::next` step (also returned by
/// `InteractiveSession::run_until_human_action`). It tells the caller whether
/// execution can proceed or a human has to act first.
#[derive(Debug, Clone)]
pub enum HumanLoopAction {
    /// A node ran without needing a human; carries the state after that node.
    Continue(GraphState),

    /// A node with an approval gate has run, and the session now waits for
    /// `approve`, `approve_with_changes` or `reject`.
    AwaitApproval {
        /// The gate configured for `node_id`.
        gate: ApprovalGate,
        /// State after the gated node ran.
        state: GraphState,
        /// Id of the node that just ran (the session has not advanced past it).
        node_id: String,
    },

    /// The current node has an input request. The node has not run yet, and
    /// the session waits for `provide_input`.
    AwaitInput {
        /// The request configured for the current node.
        request: HumanInputRequest,
        /// State before the node runs.
        state: GraphState,
    },

    /// Execution stopped before running a node that carries a breakpoint.
    Breakpoint {
        /// Id of the node that has not run yet.
        node_id: String,
        /// State before the node runs.
        state: GraphState,
    },

    /// The session was stopped before reaching the end of the graph.
    Interrupted {
        /// State at the moment the session was stopped.
        state: GraphState,
        /// Human-readable explanation of why the session stopped.
        reason: String,
    },

    /// The graph reached `END`; carries the final result.
    Complete(GraphResult),

    /// Execution cannot continue (for example, the iteration limit was hit);
    /// carries a description of the failure.
    Failed(String),
}
117
/// Configuration for a human approval step that runs after a node executes.
///
/// Built with `ApprovalGate::new` and the chained setters in its `impl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalGate {
    /// Text shown to the reviewer describing what is being approved.
    pub description: String,
    /// Names of state fields to present to the reviewer
    /// (presumably a display hint for the front end).
    pub show_fields: Vec<String>,
    /// Whether the reviewer is meant to be able to edit the state when
    /// approving. `false` unless `allow_edit()` is called.
    pub allow_edit: bool,
    /// Optional timeout in seconds.
    /// NOTE(review): not read by the session logic shown here; confirm where it is enforced.
    pub timeout_secs: Option<u64>,
    /// Free-form extra data attached to the gate.
    pub metadata: HashMap<String, serde_json::Value>,
}
132
133impl ApprovalGate {
134 pub fn new(description: impl Into<String>) -> Self {
136 Self {
137 description: description.into(),
138 show_fields: Vec::new(),
139 allow_edit: false,
140 timeout_secs: None,
141 metadata: HashMap::new(),
142 }
143 }
144
145 pub fn show_fields(mut self, fields: Vec<String>) -> Self {
147 self.show_fields = fields;
148 self
149 }
150
151 pub fn allow_edit(mut self) -> Self {
153 self.allow_edit = true;
154 self
155 }
156
157 pub fn with_timeout(mut self, secs: u64) -> Self {
159 self.timeout_secs = Some(secs);
160 self
161 }
162
163 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
165 self.metadata.insert(key.into(), value);
166 self
167 }
168}
169
/// A request for a value from a human before a node runs.
///
/// `InteractiveSession::provide_input` stores the supplied value in the graph
/// state under `field_name`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HumanInputRequest {
    /// Identifier of the request; `new` assigns a random UUID v4 string.
    pub id: String,
    /// Question or instruction shown to the human.
    pub prompt: String,
    /// Kind of value being asked for; `Text` unless changed with `input_type`.
    pub input_type: HumanInputType,
    /// State key under which the provided value is stored.
    pub field_name: String,
    /// Whether a value must be supplied. `true` from `new`; cleared by
    /// `optional()` and `with_default()`.
    pub required: bool,
    /// Value to fall back to when none is supplied.
    pub default: Option<serde_json::Value>,
    /// Opaque validation rules.
    /// NOTE(review): not interpreted by the code shown here; confirm the expected schema.
    pub validation: Option<serde_json::Value>,
    /// Optional timeout in seconds.
    /// NOTE(review): not read by the session logic shown here.
    pub timeout_secs: Option<u64>,
}
190
191impl HumanInputRequest {
192 pub fn new(prompt: impl Into<String>, field_name: impl Into<String>) -> Self {
194 Self {
195 id: uuid::Uuid::new_v4().to_string(),
196 prompt: prompt.into(),
197 input_type: HumanInputType::Text,
198 field_name: field_name.into(),
199 required: true,
200 default: None,
201 validation: None,
202 timeout_secs: None,
203 }
204 }
205
206 pub fn input_type(mut self, t: HumanInputType) -> Self {
208 self.input_type = t;
209 self
210 }
211
212 pub fn optional(mut self) -> Self {
214 self.required = false;
215 self
216 }
217
218 pub fn with_default(mut self, value: serde_json::Value) -> Self {
220 self.default = Some(value);
221 self.required = false;
222 self
223 }
224
225 pub fn with_timeout(mut self, secs: u64) -> Self {
227 self.timeout_secs = Some(secs);
228 self
229 }
230}
231
/// Kind of value a `HumanInputRequest` asks for.
///
/// Presumably a rendering hint for the front end; nothing in this file
/// validates provided values against it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HumanInputType {
    /// Single-line text (the default for `HumanInputRequest::new`).
    Text,
    /// Multi-line text.
    TextArea,
    /// Yes/no value.
    Boolean,
    /// Numeric value.
    Number,
    /// Exactly one of the listed options.
    Select(Vec<SelectOption>),
    /// Any subset of the listed options.
    MultiSelect(Vec<SelectOption>),
    /// Date and/or time value.
    DateTime,
    /// File upload.
    File,
    /// Arbitrary JSON value.
    Json,
}
254
/// One choice in a `Select` / `MultiSelect` input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SelectOption {
    /// Text shown for the option.
    pub label: String,
    /// Value recorded when the option is chosen.
    pub value: serde_json::Value,
    /// Optional longer explanation of the option.
    pub description: Option<String>,
}
265
266impl SelectOption {
267 pub fn new(label: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
269 Self {
270 label: label.into(),
271 value: value.into(),
272 description: None,
273 }
274 }
275
276 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
278 self.description = Some(desc.into());
279 self
280 }
281}
282
/// A reviewer's answer to an approval gate.
///
/// NOTE(review): in this file it is only referenced by the unused
/// `pending_approval` channel of `InteractiveSession`. The session itself
/// exposes `approve` / `approve_with_changes` / `reject` methods instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ApprovalResponse {
    /// Approve the node's output as-is.
    Approved,
    /// Approve with edits; presumably a JSON object of state fields, mirroring
    /// `InteractiveSession::approve_with_changes`.
    ApprovedWithChanges(serde_json::Value),
    /// Reject, with a reason.
    Rejected(String),
    /// Presumably asks for the node to be run again (not handled in this file).
    Retry,
}
295
/// A `Graph` plus the human-in-the-loop hooks attached to its nodes.
///
/// Sessions are created with `start` or `resume`. Each session receives its
/// own copies of the hook maps.
pub struct InteractiveGraph {
    /// The underlying executable graph.
    graph: Graph,
    /// Approval gates, keyed by the node id they follow.
    approval_gates: HashMap<String, ApprovalGate>,
    /// Node ids before which execution pauses.
    breakpoints: HashSet<String>,
    /// Input requests, keyed by the node id they precede.
    input_requests: HashMap<String, HumanInputRequest>,
    /// Where sessions save and load checkpoints (in-memory unless replaced).
    checkpoint_store: Arc<dyn CheckpointStore>,
}
309
310impl InteractiveGraph {
311 pub fn new(graph: Graph) -> Self {
313 Self {
314 graph,
315 approval_gates: HashMap::new(),
316 breakpoints: HashSet::new(),
317 input_requests: HashMap::new(),
318 checkpoint_store: Arc::new(InMemoryCheckpointStore::default()),
319 }
320 }
321
322 pub fn with_approval_gate(mut self, node_id: impl Into<String>, gate: ApprovalGate) -> Self {
324 self.approval_gates.insert(node_id.into(), gate);
325 self
326 }
327
328 pub fn with_breakpoint(mut self, node_id: impl Into<String>) -> Self {
330 self.breakpoints.insert(node_id.into());
331 self
332 }
333
334 pub fn with_input_request(
336 mut self,
337 node_id: impl Into<String>,
338 request: HumanInputRequest,
339 ) -> Self {
340 self.input_requests.insert(node_id.into(), request);
341 self
342 }
343
344 pub fn with_checkpoint_store(mut self, store: Arc<dyn CheckpointStore>) -> Self {
346 self.checkpoint_store = store;
347 self
348 }
349
350 pub async fn start(
352 &self,
353 initial_state: GraphState,
354 ) -> Result<InteractiveSession<'_>, CrewError> {
355 Ok(InteractiveSession::new(
356 &self.graph,
357 initial_state,
358 self.approval_gates.clone(),
359 self.breakpoints.clone(),
360 self.input_requests.clone(),
361 self.checkpoint_store.clone(),
362 ))
363 }
364
365 pub async fn resume(&self, checkpoint_id: &str) -> Result<InteractiveSession<'_>, CrewError> {
367 let checkpoint = self
368 .checkpoint_store
369 .load(checkpoint_id)
370 .await?
371 .ok_or_else(|| {
372 CrewError::TaskNotFound(format!("Checkpoint not found: {}", checkpoint_id))
373 })?;
374
375 Ok(InteractiveSession::from_checkpoint(
376 &self.graph,
377 checkpoint,
378 self.approval_gates.clone(),
379 self.breakpoints.clone(),
380 self.input_requests.clone(),
381 self.checkpoint_store.clone(),
382 ))
383 }
384}
385
/// One step-by-step run of an `InteractiveGraph`.
///
/// The session borrows the graph for `'a` and owns copies of the hook maps
/// (cloned in `InteractiveGraph::start` / `resume`). Changes made during the
/// session, such as `resume` clearing a breakpoint, therefore do not affect
/// the `InteractiveGraph` or other sessions.
pub struct InteractiveSession<'a> {
    /// Graph being executed.
    graph: &'a Graph,
    /// Current graph state.
    state: GraphState,
    /// Id of the node the session will handle next (`END` once finished).
    current_node: String,
    /// Where the session is in its lifecycle.
    status: SessionStatus,
    /// Session-local copy of the approval gates, keyed by node id.
    approval_gates: HashMap<String, ApprovalGate>,
    /// Session-local copy of the breakpoints.
    breakpoints: HashSet<String>,
    /// Session-local copy of the input requests, keyed by node id.
    input_requests: HashMap<String, HumanInputRequest>,
    /// Store used when saving checkpoints.
    checkpoint_store: Arc<dyn CheckpointStore>,
    // Not read anywhere yet: the constructors always set this to `None`.
    // Presumably reserved for a channel-based approval flow.
    #[allow(dead_code)]
    pending_approval: Option<oneshot::Sender<ApprovalResponse>>,
    // Not read anywhere yet: the constructors always set this to `None`.
    // Presumably reserved for a channel-based input flow.
    #[allow(dead_code)]
    pending_input: Option<oneshot::Sender<serde_json::Value>>,
    /// Random UUID v4 string; prefix of this session's checkpoint ids.
    session_id: String,
    /// When this session object was created. Sessions restored from a
    /// checkpoint also use their own creation time.
    started_at: DateTime<Utc>,
}
415
/// Lifecycle state of an `InteractiveSession`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionStatus {
    /// Ready to execute the next node.
    Running,
    /// A gated node has run; waiting for `approve`, `approve_with_changes`
    /// or `reject`.
    AwaitingApproval,
    /// Waiting for `provide_input` before the current node runs.
    AwaitingInput,
    /// Stopped at a breakpoint; continue with `resume` or `resume_with_state`.
    Paused,
    /// The graph reached `END`.
    Completed,
    /// The session failed (for example, the iteration limit was reached).
    Failed,
    /// Stopped by `interrupt` or by rejecting an approval.
    Interrupted,
}
434
435impl<'a> InteractiveSession<'a> {
436 fn new(
438 graph: &'a Graph,
439 initial_state: GraphState,
440 approval_gates: HashMap<String, ApprovalGate>,
441 breakpoints: HashSet<String>,
442 input_requests: HashMap<String, HumanInputRequest>,
443 checkpoint_store: Arc<dyn CheckpointStore>,
444 ) -> Self {
445 Self {
446 graph,
447 state: initial_state,
448 current_node: graph.entry_node.clone(),
449 status: SessionStatus::Running,
450 approval_gates,
451 breakpoints,
452 input_requests,
453 checkpoint_store,
454 pending_approval: None,
455 pending_input: None,
456 session_id: uuid::Uuid::new_v4().to_string(),
457 started_at: Utc::now(),
458 }
459 }
460
461 fn from_checkpoint(
463 graph: &'a Graph,
464 checkpoint: Checkpoint,
465 approval_gates: HashMap<String, ApprovalGate>,
466 breakpoints: HashSet<String>,
467 input_requests: HashMap<String, HumanInputRequest>,
468 checkpoint_store: Arc<dyn CheckpointStore>,
469 ) -> Self {
470 Self {
471 graph,
472 state: checkpoint.state,
473 current_node: checkpoint.next_node,
474 status: SessionStatus::Running,
475 approval_gates,
476 breakpoints,
477 input_requests,
478 checkpoint_store,
479 pending_approval: None,
480 pending_input: None,
481 session_id: uuid::Uuid::new_v4().to_string(),
482 started_at: Utc::now(),
483 }
484 }
485
486 pub fn state(&self) -> &GraphState {
488 &self.state
489 }
490
491 pub fn current_node(&self) -> &str {
493 &self.current_node
494 }
495
496 pub fn status(&self) -> SessionStatus {
498 self.status
499 }
500
501 pub fn session_id(&self) -> &str {
503 &self.session_id
504 }
505
506 pub async fn next(&mut self) -> Result<HumanLoopAction, CrewError> {
508 if self.current_node == END {
510 self.status = SessionStatus::Completed;
511 let duration = Utc::now()
512 .signed_duration_since(self.started_at)
513 .num_milliseconds() as u64;
514 self.state.metadata.execution_time_ms = duration;
515
516 return Ok(HumanLoopAction::Complete(GraphResult {
517 state: self.state.clone(),
518 status: GraphStatus::Success,
519 error: None,
520 }));
521 }
522
523 if self.state.metadata.iterations >= 100 {
525 self.status = SessionStatus::Failed;
526 return Ok(HumanLoopAction::Failed(
527 "Max iterations reached".to_string(),
528 ));
529 }
530
531 if let Some(request) = self.input_requests.get(&self.current_node).cloned() {
533 self.status = SessionStatus::AwaitingInput;
534 return Ok(HumanLoopAction::AwaitInput {
535 request,
536 state: self.state.clone(),
537 });
538 }
539
540 if self.breakpoints.contains(&self.current_node) {
542 self.status = SessionStatus::Paused;
543 self.save_checkpoint().await?;
545 return Ok(HumanLoopAction::Breakpoint {
546 node_id: self.current_node.clone(),
547 state: self.state.clone(),
548 });
549 }
550
551 let node = self
553 .graph
554 .nodes
555 .get(&self.current_node)
556 .ok_or_else(|| CrewError::TaskNotFound(self.current_node.clone()))?;
557
558 self.state
559 .metadata
560 .visited_nodes
561 .push(self.current_node.clone());
562 self.state.metadata.iterations += 1;
563
564 self.state = node.executor.call(self.state.clone()).await?;
565
566 if let Some(gate) = self.approval_gates.get(&self.current_node).cloned() {
568 self.status = SessionStatus::AwaitingApproval;
569 self.save_checkpoint().await?;
571 return Ok(HumanLoopAction::AwaitApproval {
572 gate,
573 state: self.state.clone(),
574 node_id: self.current_node.clone(),
575 });
576 }
577
578 self.current_node = self.find_next_node()?;
580 self.status = SessionStatus::Running;
581
582 Ok(HumanLoopAction::Continue(self.state.clone()))
583 }
584
585 pub async fn approve(&mut self) -> Result<(), CrewError> {
587 if self.status != SessionStatus::AwaitingApproval {
588 return Err(CrewError::ExecutionFailed(
589 "Not awaiting approval".to_string(),
590 ));
591 }
592
593 self.current_node = self.find_next_node()?;
595 self.status = SessionStatus::Running;
596 Ok(())
597 }
598
599 pub async fn approve_with_changes(
601 &mut self,
602 changes: serde_json::Value,
603 ) -> Result<(), CrewError> {
604 if self.status != SessionStatus::AwaitingApproval {
605 return Err(CrewError::ExecutionFailed(
606 "Not awaiting approval".to_string(),
607 ));
608 }
609
610 if let Some(obj) = changes.as_object() {
612 for (k, v) in obj {
613 self.state.set(k, v.clone());
614 }
615 }
616
617 self.current_node = self.find_next_node()?;
619 self.status = SessionStatus::Running;
620 Ok(())
621 }
622
623 pub async fn reject(&mut self, reason: impl Into<String>) -> Result<(), CrewError> {
625 if self.status != SessionStatus::AwaitingApproval {
626 return Err(CrewError::ExecutionFailed(
627 "Not awaiting approval".to_string(),
628 ));
629 }
630
631 self.status = SessionStatus::Interrupted;
632 self.state.set("_rejection_reason", reason.into());
633 Ok(())
634 }
635
636 pub async fn provide_input(&mut self, value: serde_json::Value) -> Result<(), CrewError> {
638 if self.status != SessionStatus::AwaitingInput {
639 return Err(CrewError::ExecutionFailed("Not awaiting input".to_string()));
640 }
641
642 if let Some(request) = self.input_requests.remove(&self.current_node) {
644 self.state.set(&request.field_name, value);
645 }
646
647 self.status = SessionStatus::Running;
648 Ok(())
649 }
650
651 pub async fn resume(&mut self) -> Result<(), CrewError> {
653 if self.status != SessionStatus::Paused {
654 return Err(CrewError::ExecutionFailed("Not paused".to_string()));
655 }
656
657 self.breakpoints.remove(&self.current_node);
659 self.status = SessionStatus::Running;
660 Ok(())
661 }
662
663 pub async fn resume_with_state(&mut self, new_state: GraphState) -> Result<(), CrewError> {
665 if self.status != SessionStatus::Paused {
666 return Err(CrewError::ExecutionFailed("Not paused".to_string()));
667 }
668
669 self.state = new_state;
670 self.breakpoints.remove(&self.current_node);
671 self.status = SessionStatus::Running;
672 Ok(())
673 }
674
675 pub async fn interrupt(&mut self, reason: impl Into<String>) -> Result<(), CrewError> {
677 self.status = SessionStatus::Interrupted;
678 self.state.set("_interrupt_reason", reason.into());
679 self.save_checkpoint().await?;
681 Ok(())
682 }
683
684 pub fn checkpoint_id(&self) -> String {
686 format!("{}_{}", self.session_id, self.state.metadata.iterations)
687 }
688
689 async fn save_checkpoint(&self) -> Result<(), CrewError> {
691 let checkpoint = Checkpoint {
692 id: self.checkpoint_id(),
693 state: self.state.clone(),
694 next_node: self.current_node.clone(),
695 created_at: Utc::now(),
696 };
697 self.checkpoint_store.save(checkpoint).await
698 }
699
700 fn find_next_node(&self) -> Result<String, CrewError> {
702 for edge in &self.graph.edges {
703 match edge {
704 crate::graph::GraphEdge::Direct { from, to } if *from == self.current_node => {
705 return Ok(to.clone());
706 }
707 crate::graph::GraphEdge::Conditional { from, router }
708 if *from == self.current_node =>
709 {
710 return Ok(router.route(&self.state));
711 }
712 _ => continue,
713 }
714 }
715 Ok(END.to_string())
716 }
717
718 pub async fn run_until_human_action(&mut self) -> Result<HumanLoopAction, CrewError> {
720 loop {
721 let action = self.next().await?;
722 match &action {
723 HumanLoopAction::Continue(_) => continue,
724 _ => return Ok(action),
725 }
726 }
727 }
728}
729
/// Builder that assembles a `Graph` (via `GraphBuilder`) and records which
/// nodes get approval gates, breakpoints or input requests.
///
/// `build()` turns it into an `InteractiveGraph`.
pub struct InteractiveWorkflowBuilder {
    /// Builder for the underlying graph structure.
    graph_builder: GraphBuilder,
    /// Approval gates recorded so far, keyed by node id.
    approval_gates: HashMap<String, ApprovalGate>,
    /// Breakpoint node ids recorded so far.
    breakpoints: HashSet<String>,
    /// Input requests recorded so far, keyed by node id.
    input_requests: HashMap<String, HumanInputRequest>,
}
737
738impl InteractiveWorkflowBuilder {
739 pub fn new(id: impl Into<String>) -> Self {
741 Self {
742 graph_builder: GraphBuilder::new(id),
743 approval_gates: HashMap::new(),
744 breakpoints: HashSet::new(),
745 input_requests: HashMap::new(),
746 }
747 }
748
749 pub fn add_node<F, Fut>(mut self, id: impl Into<String>, func: F) -> Self
751 where
752 F: Fn(GraphState) -> Fut + Send + Sync + 'static,
753 Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
754 {
755 self.graph_builder = self.graph_builder.add_node(id, func);
756 self
757 }
758
759 pub fn add_node_with_approval<F, Fut>(
761 mut self,
762 id: impl Into<String>,
763 func: F,
764 gate: ApprovalGate,
765 ) -> Self
766 where
767 F: Fn(GraphState) -> Fut + Send + Sync + 'static,
768 Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
769 {
770 let id = id.into();
771 self.graph_builder = self.graph_builder.add_node(id.clone(), func);
772 self.approval_gates.insert(id, gate);
773 self
774 }
775
776 pub fn add_node_with_input<F, Fut>(
778 mut self,
779 id: impl Into<String>,
780 func: F,
781 request: HumanInputRequest,
782 ) -> Self
783 where
784 F: Fn(GraphState) -> Fut + Send + Sync + 'static,
785 Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
786 {
787 let id = id.into();
788 self.graph_builder = self.graph_builder.add_node(id.clone(), func);
789 self.input_requests.insert(id, request);
790 self
791 }
792
793 pub fn add_breakpoint_node<F, Fut>(mut self, id: impl Into<String>, func: F) -> Self
795 where
796 F: Fn(GraphState) -> Fut + Send + Sync + 'static,
797 Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
798 {
799 let id = id.into();
800 self.graph_builder = self.graph_builder.add_node(id.clone(), func);
801 self.breakpoints.insert(id);
802 self
803 }
804
805 pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
807 self.graph_builder = self.graph_builder.add_edge(from, to);
808 self
809 }
810
811 pub fn add_conditional_edge<F>(mut self, from: impl Into<String>, router: F) -> Self
813 where
814 F: Fn(&GraphState) -> String + Send + Sync + 'static,
815 {
816 self.graph_builder = self.graph_builder.add_conditional_edge(from, router);
817 self
818 }
819
820 pub fn set_entry(mut self, node_id: impl Into<String>) -> Self {
822 self.graph_builder = self.graph_builder.set_entry(node_id);
823 self
824 }
825
826 pub fn build(self) -> Result<InteractiveGraph, CrewError> {
828 let graph = self.graph_builder.build()?;
829
830 let mut interactive = InteractiveGraph::new(graph);
831 interactive.approval_gates = self.approval_gates;
832 interactive.breakpoints = self.breakpoints;
833 interactive.input_requests = self.input_requests;
834
835 Ok(interactive)
836 }
837}
838
#[cfg(test)]
mod tests {
    use super::*;

    /// Approval gate: the gated node runs, waits, and continues after `approve`.
    #[tokio::test]
    async fn test_simple_approval_gate() {
        let graph = GraphBuilder::new("approval_test")
            .add_node("generate", |mut state: GraphState| async move {
                state.set("content", "Generated content");
                Ok(state)
            })
            .add_node("publish", |mut state: GraphState| async move {
                state.set("published", true);
                Ok(state)
            })
            .add_edge("generate", "publish")
            .add_edge("publish", END)
            .set_entry("generate")
            .build()
            .unwrap();

        let interactive = InteractiveGraph::new(graph)
            .with_approval_gate("generate", ApprovalGate::new("Review content"));

        let mut session = interactive.start(GraphState::new()).await.unwrap();

        // The gated node has already run when approval is requested.
        let action = session.next().await.unwrap();
        match action {
            HumanLoopAction::AwaitApproval { gate, state, .. } => {
                assert_eq!(gate.description, "Review content");
                assert_eq!(
                    state.get::<String>("content"),
                    Some("Generated content".to_string())
                );
            }
            _ => panic!("Expected AwaitApproval"),
        }

        session.approve().await.unwrap();

        let action = session.next().await.unwrap();
        assert!(matches!(action, HumanLoopAction::Continue(_)));

        let action = session.next().await.unwrap();
        match action {
            HumanLoopAction::Complete(result) => {
                assert_eq!(result.status, GraphStatus::Success);
                assert_eq!(result.state.get::<bool>("published"), Some(true));
            }
            _ => panic!("Expected Complete"),
        }
    }

    /// Breakpoint: execution pauses before the node and continues after `resume`.
    #[tokio::test]
    async fn test_breakpoint() {
        let graph = GraphBuilder::new("breakpoint_test")
            .add_node("step1", |mut state: GraphState| async move {
                state.set("step", 1);
                Ok(state)
            })
            .add_node("step2", |mut state: GraphState| async move {
                state.set("step", 2);
                Ok(state)
            })
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .set_entry("step1")
            .build()
            .unwrap();

        let interactive = InteractiveGraph::new(graph).with_breakpoint("step2");

        let mut session = interactive.start(GraphState::new()).await.unwrap();

        let action = session.next().await.unwrap();
        assert!(matches!(action, HumanLoopAction::Continue(_)));

        // The breakpoint fires before step2 runs, so the state still reflects step1.
        let action = session.next().await.unwrap();
        match action {
            HumanLoopAction::Breakpoint { node_id, state } => {
                assert_eq!(node_id, "step2");
                assert_eq!(state.get::<i32>("step"), Some(1));
            }
            _ => panic!("Expected Breakpoint"),
        }

        session.resume().await.unwrap();

        let action = session.next().await.unwrap();
        assert!(matches!(action, HumanLoopAction::Continue(_)));

        let action = session.next().await.unwrap();
        assert!(matches!(action, HumanLoopAction::Complete(_)));
    }

    /// Input request: the value provided is stored under the request's field name.
    #[tokio::test]
    async fn test_input_request() {
        let graph = GraphBuilder::new("input_test")
            .add_node("process", |state: GraphState| async move { Ok(state) })
            .add_edge("process", END)
            .set_entry("process")
            .build()
            .unwrap();

        let interactive = InteractiveGraph::new(graph).with_input_request(
            "process",
            HumanInputRequest::new("Enter your name", "user_name"),
        );

        let mut session = interactive.start(GraphState::new()).await.unwrap();

        let action = session.next().await.unwrap();
        match action {
            HumanLoopAction::AwaitInput { request, .. } => {
                assert_eq!(request.prompt, "Enter your name");
                assert_eq!(request.field_name, "user_name");
            }
            _ => panic!("Expected AwaitInput"),
        }

        session
            .provide_input(serde_json::json!("Alice"))
            .await
            .unwrap();

        let action = session.next().await.unwrap();
        assert!(matches!(action, HumanLoopAction::Continue(_)));

        let action = session.next().await.unwrap();
        match action {
            HumanLoopAction::Complete(result) => {
                assert_eq!(
                    result.state.get::<String>("user_name"),
                    Some("Alice".to_string())
                );
            }
            _ => panic!("Expected Complete"),
        }
    }

    /// Rejecting an approval interrupts the session.
    #[tokio::test]
    async fn test_rejection() {
        let graph = GraphBuilder::new("reject_test")
            .add_node("generate", |mut state: GraphState| async move {
                state.set("content", "Bad content");
                Ok(state)
            })
            .add_edge("generate", END)
            .set_entry("generate")
            .build()
            .unwrap();

        let interactive = InteractiveGraph::new(graph)
            .with_approval_gate("generate", ApprovalGate::new("Review content"));

        let mut session = interactive.start(GraphState::new()).await.unwrap();

        let action = session.next().await.unwrap();
        assert!(matches!(action, HumanLoopAction::AwaitApproval { .. }));

        session.reject("Content not good enough").await.unwrap();

        assert_eq!(session.status(), SessionStatus::Interrupted);
    }

    /// Editable gate: approved changes are written into the state before moving on.
    #[tokio::test]
    async fn test_approve_with_changes() {
        let graph = GraphBuilder::new("edit_test")
            .add_node("draft", |mut state: GraphState| async move {
                state.set("title", "Draft title");
                Ok(state)
            })
            .add_edge("draft", END)
            .set_entry("draft")
            .build()
            .unwrap();

        let interactive = InteractiveGraph::new(graph)
            .with_approval_gate("draft", ApprovalGate::new("Edit title").allow_edit());

        let mut session = interactive.start(GraphState::new()).await.unwrap();

        let action = session.next().await.unwrap();
        assert!(matches!(action, HumanLoopAction::AwaitApproval { .. }));

        session
            .approve_with_changes(serde_json::json!({ "title": "Final title" }))
            .await
            .unwrap();
        assert_eq!(session.status(), SessionStatus::Running);
        assert_eq!(
            session.state().get::<String>("title"),
            Some("Final title".to_string())
        );

        let action = session.next().await.unwrap();
        assert!(matches!(action, HumanLoopAction::Complete(_)));
    }

    /// `run_until_human_action` runs automatic nodes until the first gate.
    #[tokio::test]
    async fn test_run_until_human_action() {
        let graph = GraphBuilder::new("run_test")
            .add_node("auto1", |mut state: GraphState| async move {
                state.set("auto1", true);
                Ok(state)
            })
            .add_node("auto2", |mut state: GraphState| async move {
                state.set("auto2", true);
                Ok(state)
            })
            .add_node("manual", |mut state: GraphState| async move {
                state.set("manual", true);
                Ok(state)
            })
            .add_edge("auto1", "auto2")
            .add_edge("auto2", "manual")
            .add_edge("manual", END)
            .set_entry("auto1")
            .build()
            .unwrap();

        let interactive =
            InteractiveGraph::new(graph).with_approval_gate("manual", ApprovalGate::new("Review"));

        let mut session = interactive.start(GraphState::new()).await.unwrap();

        let action = session.run_until_human_action().await.unwrap();

        match action {
            HumanLoopAction::AwaitApproval { state, .. } => {
                // The gate fires after "manual" runs, so all three nodes have executed.
                assert_eq!(state.get::<bool>("auto1"), Some(true));
                assert_eq!(state.get::<bool>("auto2"), Some(true));
                assert_eq!(state.get::<bool>("manual"), Some(true));
            }
            _ => panic!("Expected AwaitApproval"),
        }
    }

    #[test]
    fn test_approval_gate_builder() {
        let gate = ApprovalGate::new("Review document")
            .show_fields(vec!["title".to_string(), "content".to_string()])
            .allow_edit()
            .with_timeout(300)
            .with_metadata("priority", serde_json::json!("high"));

        assert_eq!(gate.description, "Review document");
        assert_eq!(gate.show_fields.len(), 2);
        assert!(gate.allow_edit);
        assert_eq!(gate.timeout_secs, Some(300));
        assert_eq!(
            gate.metadata.get("priority"),
            Some(&serde_json::json!("high"))
        );
    }

    #[test]
    fn test_input_request_builder() {
        let request = HumanInputRequest::new("Select priority", "priority")
            .input_type(HumanInputType::Select(vec![
                SelectOption::new("Low", "low"),
                SelectOption::new("Medium", "medium").with_description("Default"),
                SelectOption::new("High", "high"),
            ]))
            .with_default(serde_json::json!("medium"))
            .with_timeout(60);

        assert_eq!(request.prompt, "Select priority");
        assert_eq!(request.field_name, "priority");
        assert!(matches!(
            &request.input_type,
            HumanInputType::Select(options) if options.len() == 3
        ));
        assert_eq!(request.default, Some(serde_json::json!("medium")));
        // `with_default` makes the request optional.
        assert!(!request.required);
        assert_eq!(request.timeout_secs, Some(60));
    }

    #[test]
    fn test_select_option_builder() {
        let option = SelectOption::new("High", "high").with_description("Urgent work");

        assert_eq!(option.label, "High");
        assert_eq!(option.value, serde_json::json!("high"));
        assert_eq!(option.description.as_deref(), Some("Urgent work"));
    }
}