Skip to main content

cortexai_crew/
human_loop.rs

1//! Human-in-the-Loop Support
2//!
3//! Provides pause/resume, approval gates, and interactive breakpoints
4//! for graph-based workflows.
5//!
6//! ## Features
7//!
8//! - **Approval Gates**: Pause execution for human approval before continuing
9//! - **Breakpoints**: Define nodes where execution pauses for inspection/input
10//! - **Input Collection**: Request structured input from humans during execution
11//! - **Interrupt/Resume**: Pause and resume execution at any point
12//!
13//! ## Example
14//!
15//! ```rust,ignore
16//! use cortexai_crew::human_loop::{HumanLoop, ApprovalGate, InteractiveGraph};
17//!
18//! // Create a graph with approval gates
19//! let graph = GraphBuilder::new("approval_workflow")
20//!     .add_node("generate", generate_content)
21//!     .add_node("review", review_node)
22//!     .add_node("publish", publish_content)
23//!     .add_edge("generate", "review")
24//!     .add_edge("review", "publish")
25//!     .set_entry("generate")
26//!     .build()?;
27//!
28//! // Wrap with human-in-the-loop
29//! let interactive = InteractiveGraph::new(graph)
30//!     .with_approval_gate("review", ApprovalGate::new("Review the generated content"))
31//!     .with_breakpoint("publish");
32//!
33//! // Run interactively
34//! let mut session = interactive.start(initial_state).await?;
35//!
36//! loop {
37//!     match session.next().await? {
38//!         HumanLoopAction::Continue(state) => { /* auto-continues */ }
//!         HumanLoopAction::AwaitApproval { gate, state, .. } => {
40//!             // Show content to user, get approval
41//!             if user_approves() {
42//!                 session.approve().await?;
43//!             } else {
44//!                 session.reject("Needs more work").await?;
45//!             }
46//!         }
47//!         HumanLoopAction::AwaitInput { request, state } => {
48//!             let input = get_user_input(&request);
49//!             session.provide_input(input).await?;
50//!         }
51//!         HumanLoopAction::Breakpoint { node_id, state } => {
52//!             // Inspect state, optionally modify
53//!             session.resume().await?;
54//!         }
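//!         HumanLoopAction::Complete(result) => break,
//!         // Terminal outcomes that are not a successful completion
//!         HumanLoopAction::Interrupted { .. } | HumanLoopAction::Failed(_) => break,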
56//!     }
57//! }
58//! ```
59
60use crate::graph::{
61    Checkpoint, CheckpointStore, Graph, GraphBuilder, GraphResult, GraphState, GraphStatus,
62    InMemoryCheckpointStore, END,
63};
64use chrono::{DateTime, Utc};
65use cortexai_core::errors::CrewError;
66use serde::{Deserialize, Serialize};
67use std::collections::{HashMap, HashSet};
68use std::sync::Arc;
69use tokio::sync::oneshot;
70
71/// Action required from human
72#[derive(Debug, Clone)]
73pub enum HumanLoopAction {
74    /// Execution continues automatically
75    Continue(GraphState),
76
77    /// Awaiting human approval to proceed
78    AwaitApproval {
79        /// The approval gate configuration
80        gate: ApprovalGate,
81        /// Current state
82        state: GraphState,
83        /// Node that triggered the gate
84        node_id: String,
85    },
86
87    /// Awaiting human input
88    AwaitInput {
89        /// Input request details
90        request: HumanInputRequest,
91        /// Current state
92        state: GraphState,
93    },
94
95    /// Hit a breakpoint - execution paused
96    Breakpoint {
97        /// Node where breakpoint was hit
98        node_id: String,
99        /// Current state
100        state: GraphState,
101    },
102
103    /// Execution interrupted by user
104    Interrupted {
105        /// State at interruption
106        state: GraphState,
107        /// Reason for interruption
108        reason: String,
109    },
110
111    /// Execution completed
112    Complete(GraphResult),
113
114    /// Execution failed
115    Failed(String),
116}
117
118/// Approval gate configuration
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct ApprovalGate {
121    /// Human-readable description of what needs approval
122    pub description: String,
123    /// Fields from state to show for review
124    pub show_fields: Vec<String>,
125    /// Whether to allow modification of state
126    pub allow_edit: bool,
127    /// Timeout in seconds (None = wait indefinitely)
128    pub timeout_secs: Option<u64>,
129    /// Custom metadata
130    pub metadata: HashMap<String, serde_json::Value>,
131}
132
133impl ApprovalGate {
134    /// Create a new approval gate
135    pub fn new(description: impl Into<String>) -> Self {
136        Self {
137            description: description.into(),
138            show_fields: Vec::new(),
139            allow_edit: false,
140            timeout_secs: None,
141            metadata: HashMap::new(),
142        }
143    }
144
145    /// Specify fields to show for review
146    pub fn show_fields(mut self, fields: Vec<String>) -> Self {
147        self.show_fields = fields;
148        self
149    }
150
151    /// Allow editing state during approval
152    pub fn allow_edit(mut self) -> Self {
153        self.allow_edit = true;
154        self
155    }
156
157    /// Set timeout
158    pub fn with_timeout(mut self, secs: u64) -> Self {
159        self.timeout_secs = Some(secs);
160        self
161    }
162
163    /// Add custom metadata
164    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
165        self.metadata.insert(key.into(), value);
166        self
167    }
168}
169
170/// Input request for human interaction
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct HumanInputRequest {
173    /// Unique request ID
174    pub id: String,
175    /// Human-readable prompt
176    pub prompt: String,
177    /// Type of input expected
178    pub input_type: HumanInputType,
179    /// Field name to store result
180    pub field_name: String,
181    /// Whether input is required
182    pub required: bool,
183    /// Default value
184    pub default: Option<serde_json::Value>,
185    /// Validation schema (JSON Schema format)
186    pub validation: Option<serde_json::Value>,
187    /// Timeout in seconds
188    pub timeout_secs: Option<u64>,
189}
190
191impl HumanInputRequest {
192    /// Create a new input request
193    pub fn new(prompt: impl Into<String>, field_name: impl Into<String>) -> Self {
194        Self {
195            id: uuid::Uuid::new_v4().to_string(),
196            prompt: prompt.into(),
197            input_type: HumanInputType::Text,
198            field_name: field_name.into(),
199            required: true,
200            default: None,
201            validation: None,
202            timeout_secs: None,
203        }
204    }
205
206    /// Set input type
207    pub fn input_type(mut self, t: HumanInputType) -> Self {
208        self.input_type = t;
209        self
210    }
211
212    /// Set as optional
213    pub fn optional(mut self) -> Self {
214        self.required = false;
215        self
216    }
217
218    /// Set default value
219    pub fn with_default(mut self, value: serde_json::Value) -> Self {
220        self.default = Some(value);
221        self.required = false;
222        self
223    }
224
225    /// Set timeout
226    pub fn with_timeout(mut self, secs: u64) -> Self {
227        self.timeout_secs = Some(secs);
228        self
229    }
230}
231
232/// Type of input expected for human interaction
233#[derive(Debug, Clone, Serialize, Deserialize)]
234pub enum HumanInputType {
235    /// Free-form text
236    Text,
237    /// Multi-line text
238    TextArea,
239    /// Yes/No confirmation
240    Boolean,
241    /// Number input
242    Number,
243    /// Selection from options
244    Select(Vec<SelectOption>),
245    /// Multiple selection
246    MultiSelect(Vec<SelectOption>),
247    /// Date/time input
248    DateTime,
249    /// File upload (returns base64 or path)
250    File,
251    /// Structured JSON
252    Json,
253}
254
255/// Option for select inputs
256#[derive(Debug, Clone, Serialize, Deserialize)]
257pub struct SelectOption {
258    /// Display label
259    pub label: String,
260    /// Value
261    pub value: serde_json::Value,
262    /// Description
263    pub description: Option<String>,
264}
265
266impl SelectOption {
267    /// Create a new select option
268    pub fn new(label: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
269        Self {
270            label: label.into(),
271            value: value.into(),
272            description: None,
273        }
274    }
275
276    /// Add description
277    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
278        self.description = Some(desc.into());
279        self
280    }
281}
282
283/// Human response to an approval gate
284#[derive(Debug, Clone, Serialize, Deserialize)]
285pub enum ApprovalResponse {
286    /// Approved to continue
287    Approved,
288    /// Approved with modified state
289    ApprovedWithChanges(serde_json::Value),
290    /// Rejected with reason
291    Rejected(String),
292    /// Request to retry the previous node
293    Retry,
294}
295
296/// Interactive graph wrapper with human-in-the-loop support
297pub struct InteractiveGraph {
298    /// Underlying graph
299    graph: Graph,
300    /// Approval gates by node ID
301    approval_gates: HashMap<String, ApprovalGate>,
302    /// Breakpoint nodes
303    breakpoints: HashSet<String>,
304    /// Input requests by node ID
305    input_requests: HashMap<String, HumanInputRequest>,
306    /// Checkpoint store for persistence
307    checkpoint_store: Arc<dyn CheckpointStore>,
308}
309
310impl InteractiveGraph {
311    /// Create an interactive graph wrapper
312    pub fn new(graph: Graph) -> Self {
313        Self {
314            graph,
315            approval_gates: HashMap::new(),
316            breakpoints: HashSet::new(),
317            input_requests: HashMap::new(),
318            checkpoint_store: Arc::new(InMemoryCheckpointStore::default()),
319        }
320    }
321
322    /// Add an approval gate after a node
323    pub fn with_approval_gate(mut self, node_id: impl Into<String>, gate: ApprovalGate) -> Self {
324        self.approval_gates.insert(node_id.into(), gate);
325        self
326    }
327
328    /// Add a breakpoint at a node
329    pub fn with_breakpoint(mut self, node_id: impl Into<String>) -> Self {
330        self.breakpoints.insert(node_id.into());
331        self
332    }
333
334    /// Add an input request before a node
335    pub fn with_input_request(
336        mut self,
337        node_id: impl Into<String>,
338        request: HumanInputRequest,
339    ) -> Self {
340        self.input_requests.insert(node_id.into(), request);
341        self
342    }
343
344    /// Set custom checkpoint store
345    pub fn with_checkpoint_store(mut self, store: Arc<dyn CheckpointStore>) -> Self {
346        self.checkpoint_store = store;
347        self
348    }
349
350    /// Start an interactive session
351    pub async fn start(
352        &self,
353        initial_state: GraphState,
354    ) -> Result<InteractiveSession<'_>, CrewError> {
355        Ok(InteractiveSession::new(
356            &self.graph,
357            initial_state,
358            self.approval_gates.clone(),
359            self.breakpoints.clone(),
360            self.input_requests.clone(),
361            self.checkpoint_store.clone(),
362        ))
363    }
364
365    /// Resume from a checkpoint
366    pub async fn resume(&self, checkpoint_id: &str) -> Result<InteractiveSession<'_>, CrewError> {
367        let checkpoint = self
368            .checkpoint_store
369            .load(checkpoint_id)
370            .await?
371            .ok_or_else(|| {
372                CrewError::TaskNotFound(format!("Checkpoint not found: {}", checkpoint_id))
373            })?;
374
375        Ok(InteractiveSession::from_checkpoint(
376            &self.graph,
377            checkpoint,
378            self.approval_gates.clone(),
379            self.breakpoints.clone(),
380            self.input_requests.clone(),
381            self.checkpoint_store.clone(),
382        ))
383    }
384}
385
386/// Interactive execution session
387pub struct InteractiveSession<'a> {
388    /// Reference to graph
389    graph: &'a Graph,
390    /// Current state
391    state: GraphState,
392    /// Current node
393    current_node: String,
394    /// Session status
395    status: SessionStatus,
396    /// Approval gates
397    approval_gates: HashMap<String, ApprovalGate>,
398    /// Breakpoints
399    breakpoints: HashSet<String>,
400    /// Input requests
401    input_requests: HashMap<String, HumanInputRequest>,
402    /// Checkpoint store
403    checkpoint_store: Arc<dyn CheckpointStore>,
404    /// Pending approval response (reserved for future async approval handling)
405    #[allow(dead_code)]
406    pending_approval: Option<oneshot::Sender<ApprovalResponse>>,
407    /// Pending input response (reserved for future async input handling)
408    #[allow(dead_code)]
409    pending_input: Option<oneshot::Sender<serde_json::Value>>,
410    /// Session ID
411    session_id: String,
412    /// Started at
413    started_at: DateTime<Utc>,
414}
415
/// Lifecycle state of an interactive session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionStatus {
    /// Stepping through nodes normally
    Running,
    /// Blocked on an approval decision
    AwaitingApproval,
    /// Blocked on a human-supplied value
    AwaitingInput,
    /// Stopped at a breakpoint
    Paused,
    /// Finished
    Completed,
    /// Stopped because of a failure
    Failed,
    /// Stopped by an interruption
    Interrupted,
}
434
435impl<'a> InteractiveSession<'a> {
436    /// Create a new session
437    fn new(
438        graph: &'a Graph,
439        initial_state: GraphState,
440        approval_gates: HashMap<String, ApprovalGate>,
441        breakpoints: HashSet<String>,
442        input_requests: HashMap<String, HumanInputRequest>,
443        checkpoint_store: Arc<dyn CheckpointStore>,
444    ) -> Self {
445        Self {
446            graph,
447            state: initial_state,
448            current_node: graph.entry_node.clone(),
449            status: SessionStatus::Running,
450            approval_gates,
451            breakpoints,
452            input_requests,
453            checkpoint_store,
454            pending_approval: None,
455            pending_input: None,
456            session_id: uuid::Uuid::new_v4().to_string(),
457            started_at: Utc::now(),
458        }
459    }
460
461    /// Create from checkpoint
462    fn from_checkpoint(
463        graph: &'a Graph,
464        checkpoint: Checkpoint,
465        approval_gates: HashMap<String, ApprovalGate>,
466        breakpoints: HashSet<String>,
467        input_requests: HashMap<String, HumanInputRequest>,
468        checkpoint_store: Arc<dyn CheckpointStore>,
469    ) -> Self {
470        Self {
471            graph,
472            state: checkpoint.state,
473            current_node: checkpoint.next_node,
474            status: SessionStatus::Running,
475            approval_gates,
476            breakpoints,
477            input_requests,
478            checkpoint_store,
479            pending_approval: None,
480            pending_input: None,
481            session_id: uuid::Uuid::new_v4().to_string(),
482            started_at: Utc::now(),
483        }
484    }
485
486    /// Get current state
487    pub fn state(&self) -> &GraphState {
488        &self.state
489    }
490
491    /// Get current node
492    pub fn current_node(&self) -> &str {
493        &self.current_node
494    }
495
496    /// Get session status
497    pub fn status(&self) -> SessionStatus {
498        self.status
499    }
500
501    /// Get session ID
502    pub fn session_id(&self) -> &str {
503        &self.session_id
504    }
505
506    /// Execute next step
507    pub async fn next(&mut self) -> Result<HumanLoopAction, CrewError> {
508        // Check if completed
509        if self.current_node == END {
510            self.status = SessionStatus::Completed;
511            let duration = Utc::now()
512                .signed_duration_since(self.started_at)
513                .num_milliseconds() as u64;
514            self.state.metadata.execution_time_ms = duration;
515
516            return Ok(HumanLoopAction::Complete(GraphResult {
517                state: self.state.clone(),
518                status: GraphStatus::Success,
519                error: None,
520            }));
521        }
522
523        // Check for max iterations
524        if self.state.metadata.iterations >= 100 {
525            self.status = SessionStatus::Failed;
526            return Ok(HumanLoopAction::Failed(
527                "Max iterations reached".to_string(),
528            ));
529        }
530
531        // Check for input request before node
532        if let Some(request) = self.input_requests.get(&self.current_node).cloned() {
533            self.status = SessionStatus::AwaitingInput;
534            return Ok(HumanLoopAction::AwaitInput {
535                request,
536                state: self.state.clone(),
537            });
538        }
539
540        // Check for breakpoint
541        if self.breakpoints.contains(&self.current_node) {
542            self.status = SessionStatus::Paused;
543            // Save checkpoint
544            self.save_checkpoint().await?;
545            return Ok(HumanLoopAction::Breakpoint {
546                node_id: self.current_node.clone(),
547                state: self.state.clone(),
548            });
549        }
550
551        // Execute the current node
552        let node = self
553            .graph
554            .nodes
555            .get(&self.current_node)
556            .ok_or_else(|| CrewError::TaskNotFound(self.current_node.clone()))?;
557
558        self.state
559            .metadata
560            .visited_nodes
561            .push(self.current_node.clone());
562        self.state.metadata.iterations += 1;
563
564        self.state = node.executor.call(self.state.clone()).await?;
565
566        // Check for approval gate after node
567        if let Some(gate) = self.approval_gates.get(&self.current_node).cloned() {
568            self.status = SessionStatus::AwaitingApproval;
569            // Save checkpoint before waiting
570            self.save_checkpoint().await?;
571            return Ok(HumanLoopAction::AwaitApproval {
572                gate,
573                state: self.state.clone(),
574                node_id: self.current_node.clone(),
575            });
576        }
577
578        // Find next node
579        self.current_node = self.find_next_node()?;
580        self.status = SessionStatus::Running;
581
582        Ok(HumanLoopAction::Continue(self.state.clone()))
583    }
584
585    /// Approve and continue
586    pub async fn approve(&mut self) -> Result<(), CrewError> {
587        if self.status != SessionStatus::AwaitingApproval {
588            return Err(CrewError::ExecutionFailed(
589                "Not awaiting approval".to_string(),
590            ));
591        }
592
593        // Move to next node
594        self.current_node = self.find_next_node()?;
595        self.status = SessionStatus::Running;
596        Ok(())
597    }
598
599    /// Approve with modified state
600    pub async fn approve_with_changes(
601        &mut self,
602        changes: serde_json::Value,
603    ) -> Result<(), CrewError> {
604        if self.status != SessionStatus::AwaitingApproval {
605            return Err(CrewError::ExecutionFailed(
606                "Not awaiting approval".to_string(),
607            ));
608        }
609
610        // Apply changes to state
611        if let Some(obj) = changes.as_object() {
612            for (k, v) in obj {
613                self.state.set(k, v.clone());
614            }
615        }
616
617        // Move to next node
618        self.current_node = self.find_next_node()?;
619        self.status = SessionStatus::Running;
620        Ok(())
621    }
622
623    /// Reject and stop or go back
624    pub async fn reject(&mut self, reason: impl Into<String>) -> Result<(), CrewError> {
625        if self.status != SessionStatus::AwaitingApproval {
626            return Err(CrewError::ExecutionFailed(
627                "Not awaiting approval".to_string(),
628            ));
629        }
630
631        self.status = SessionStatus::Interrupted;
632        self.state.set("_rejection_reason", reason.into());
633        Ok(())
634    }
635
636    /// Provide input and continue
637    pub async fn provide_input(&mut self, value: serde_json::Value) -> Result<(), CrewError> {
638        if self.status != SessionStatus::AwaitingInput {
639            return Err(CrewError::ExecutionFailed("Not awaiting input".to_string()));
640        }
641
642        // Get the field name for this input and remove it so we don't ask again
643        if let Some(request) = self.input_requests.remove(&self.current_node) {
644            self.state.set(&request.field_name, value);
645        }
646
647        self.status = SessionStatus::Running;
648        Ok(())
649    }
650
651    /// Resume from breakpoint
652    pub async fn resume(&mut self) -> Result<(), CrewError> {
653        if self.status != SessionStatus::Paused {
654            return Err(CrewError::ExecutionFailed("Not paused".to_string()));
655        }
656
657        // Remove the breakpoint that was hit (so we don't hit it again)
658        self.breakpoints.remove(&self.current_node);
659        self.status = SessionStatus::Running;
660        Ok(())
661    }
662
663    /// Resume with modified state
664    pub async fn resume_with_state(&mut self, new_state: GraphState) -> Result<(), CrewError> {
665        if self.status != SessionStatus::Paused {
666            return Err(CrewError::ExecutionFailed("Not paused".to_string()));
667        }
668
669        self.state = new_state;
670        self.breakpoints.remove(&self.current_node);
671        self.status = SessionStatus::Running;
672        Ok(())
673    }
674
675    /// Interrupt execution
676    pub async fn interrupt(&mut self, reason: impl Into<String>) -> Result<(), CrewError> {
677        self.status = SessionStatus::Interrupted;
678        self.state.set("_interrupt_reason", reason.into());
679        // Save checkpoint for potential resume
680        self.save_checkpoint().await?;
681        Ok(())
682    }
683
684    /// Get checkpoint ID for this session
685    pub fn checkpoint_id(&self) -> String {
686        format!("{}_{}", self.session_id, self.state.metadata.iterations)
687    }
688
689    /// Save current state as checkpoint
690    async fn save_checkpoint(&self) -> Result<(), CrewError> {
691        let checkpoint = Checkpoint {
692            id: self.checkpoint_id(),
693            state: self.state.clone(),
694            next_node: self.current_node.clone(),
695            created_at: Utc::now(),
696        };
697        self.checkpoint_store.save(checkpoint).await
698    }
699
700    /// Find next node based on edges
701    fn find_next_node(&self) -> Result<String, CrewError> {
702        for edge in &self.graph.edges {
703            match edge {
704                crate::graph::GraphEdge::Direct { from, to } if *from == self.current_node => {
705                    return Ok(to.clone());
706                }
707                crate::graph::GraphEdge::Conditional { from, router }
708                    if *from == self.current_node =>
709                {
710                    return Ok(router.route(&self.state));
711                }
712                _ => continue,
713            }
714        }
715        Ok(END.to_string())
716    }
717
718    /// Run to completion or next human action
719    pub async fn run_until_human_action(&mut self) -> Result<HumanLoopAction, CrewError> {
720        loop {
721            let action = self.next().await?;
722            match &action {
723                HumanLoopAction::Continue(_) => continue,
724                _ => return Ok(action),
725            }
726        }
727    }
728}
729
730/// Builder for creating interactive workflows
731pub struct InteractiveWorkflowBuilder {
732    graph_builder: GraphBuilder,
733    approval_gates: HashMap<String, ApprovalGate>,
734    breakpoints: HashSet<String>,
735    input_requests: HashMap<String, HumanInputRequest>,
736}
737
738impl InteractiveWorkflowBuilder {
739    /// Create a new builder
740    pub fn new(id: impl Into<String>) -> Self {
741        Self {
742            graph_builder: GraphBuilder::new(id),
743            approval_gates: HashMap::new(),
744            breakpoints: HashSet::new(),
745            input_requests: HashMap::new(),
746        }
747    }
748
749    /// Add a node
750    pub fn add_node<F, Fut>(mut self, id: impl Into<String>, func: F) -> Self
751    where
752        F: Fn(GraphState) -> Fut + Send + Sync + 'static,
753        Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
754    {
755        self.graph_builder = self.graph_builder.add_node(id, func);
756        self
757    }
758
759    /// Add a node with approval gate
760    pub fn add_node_with_approval<F, Fut>(
761        mut self,
762        id: impl Into<String>,
763        func: F,
764        gate: ApprovalGate,
765    ) -> Self
766    where
767        F: Fn(GraphState) -> Fut + Send + Sync + 'static,
768        Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
769    {
770        let id = id.into();
771        self.graph_builder = self.graph_builder.add_node(id.clone(), func);
772        self.approval_gates.insert(id, gate);
773        self
774    }
775
776    /// Add a node with input request
777    pub fn add_node_with_input<F, Fut>(
778        mut self,
779        id: impl Into<String>,
780        func: F,
781        request: HumanInputRequest,
782    ) -> Self
783    where
784        F: Fn(GraphState) -> Fut + Send + Sync + 'static,
785        Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
786    {
787        let id = id.into();
788        self.graph_builder = self.graph_builder.add_node(id.clone(), func);
789        self.input_requests.insert(id, request);
790        self
791    }
792
793    /// Add a breakpoint node
794    pub fn add_breakpoint_node<F, Fut>(mut self, id: impl Into<String>, func: F) -> Self
795    where
796        F: Fn(GraphState) -> Fut + Send + Sync + 'static,
797        Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
798    {
799        let id = id.into();
800        self.graph_builder = self.graph_builder.add_node(id.clone(), func);
801        self.breakpoints.insert(id);
802        self
803    }
804
805    /// Add edge
806    pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
807        self.graph_builder = self.graph_builder.add_edge(from, to);
808        self
809    }
810
811    /// Add conditional edge
812    pub fn add_conditional_edge<F>(mut self, from: impl Into<String>, router: F) -> Self
813    where
814        F: Fn(&GraphState) -> String + Send + Sync + 'static,
815    {
816        self.graph_builder = self.graph_builder.add_conditional_edge(from, router);
817        self
818    }
819
820    /// Set entry node
821    pub fn set_entry(mut self, node_id: impl Into<String>) -> Self {
822        self.graph_builder = self.graph_builder.set_entry(node_id);
823        self
824    }
825
826    /// Build the interactive graph
827    pub fn build(self) -> Result<InteractiveGraph, CrewError> {
828        let graph = self.graph_builder.build()?;
829
830        let mut interactive = InteractiveGraph::new(graph);
831        interactive.approval_gates = self.approval_gates;
832        interactive.breakpoints = self.breakpoints;
833        interactive.input_requests = self.input_requests;
834
835        Ok(interactive)
836    }
837}
838
839#[cfg(test)]
840mod tests {
841    use super::*;
842
843    #[tokio::test]
844    async fn test_simple_approval_gate() {
845        let graph = GraphBuilder::new("approval_test")
846            .add_node("generate", |mut state: GraphState| async move {
847                state.set("content", "Generated content");
848                Ok(state)
849            })
850            .add_node("publish", |mut state: GraphState| async move {
851                state.set("published", true);
852                Ok(state)
853            })
854            .add_edge("generate", "publish")
855            .add_edge("publish", END)
856            .set_entry("generate")
857            .build()
858            .unwrap();
859
860        let interactive = InteractiveGraph::new(graph)
861            .with_approval_gate("generate", ApprovalGate::new("Review content"));
862
863        let mut session = interactive.start(GraphState::new()).await.unwrap();
864
865        // First step executes generate
866        let action = session.next().await.unwrap();
867        match action {
868            HumanLoopAction::AwaitApproval { gate, state, .. } => {
869                assert_eq!(gate.description, "Review content");
870                assert_eq!(
871                    state.get::<String>("content"),
872                    Some("Generated content".to_string())
873                );
874            }
875            _ => panic!("Expected AwaitApproval"),
876        }
877
878        // Approve
879        session.approve().await.unwrap();
880
881        // Continue to publish
882        let action = session.next().await.unwrap();
883        assert!(matches!(action, HumanLoopAction::Continue(_)));
884
885        // Complete
886        let action = session.next().await.unwrap();
887        match action {
888            HumanLoopAction::Complete(result) => {
889                assert_eq!(result.status, GraphStatus::Success);
890                assert_eq!(result.state.get::<bool>("published"), Some(true));
891            }
892            _ => panic!("Expected Complete"),
893        }
894    }
895
896    #[tokio::test]
897    async fn test_breakpoint() {
898        let graph = GraphBuilder::new("breakpoint_test")
899            .add_node("step1", |mut state: GraphState| async move {
900                state.set("step", 1);
901                Ok(state)
902            })
903            .add_node("step2", |mut state: GraphState| async move {
904                state.set("step", 2);
905                Ok(state)
906            })
907            .add_edge("step1", "step2")
908            .add_edge("step2", END)
909            .set_entry("step1")
910            .build()
911            .unwrap();
912
913        let interactive = InteractiveGraph::new(graph).with_breakpoint("step2");
914
915        let mut session = interactive.start(GraphState::new()).await.unwrap();
916
917        // Execute step1
918        let action = session.next().await.unwrap();
919        assert!(matches!(action, HumanLoopAction::Continue(_)));
920
921        // Hit breakpoint at step2
922        let action = session.next().await.unwrap();
923        match action {
924            HumanLoopAction::Breakpoint { node_id, state } => {
925                assert_eq!(node_id, "step2");
926                assert_eq!(state.get::<i32>("step"), Some(1));
927            }
928            _ => panic!("Expected Breakpoint"),
929        }
930
931        // Resume
932        session.resume().await.unwrap();
933
934        // Execute step2
935        let action = session.next().await.unwrap();
936        assert!(matches!(action, HumanLoopAction::Continue(_)));
937
938        // Complete
939        let action = session.next().await.unwrap();
940        assert!(matches!(action, HumanLoopAction::Complete(_)));
941    }
942
943    #[tokio::test]
944    async fn test_input_request() {
945        let graph = GraphBuilder::new("input_test")
946            .add_node("process", |state: GraphState| async move {
947                // Just pass through - input should already be set
948                Ok(state)
949            })
950            .add_edge("process", END)
951            .set_entry("process")
952            .build()
953            .unwrap();
954
955        let interactive = InteractiveGraph::new(graph).with_input_request(
956            "process",
957            HumanInputRequest::new("Enter your name", "user_name"),
958        );
959
960        let mut session = interactive.start(GraphState::new()).await.unwrap();
961
962        // Should request input before process
963        let action = session.next().await.unwrap();
964        match action {
965            HumanLoopAction::AwaitInput { request, .. } => {
966                assert_eq!(request.prompt, "Enter your name");
967                assert_eq!(request.field_name, "user_name");
968            }
969            _ => panic!("Expected AwaitInput"),
970        }
971
972        // Provide input
973        session
974            .provide_input(serde_json::json!("Alice"))
975            .await
976            .unwrap();
977
978        // Execute process
979        let action = session.next().await.unwrap();
980        assert!(matches!(action, HumanLoopAction::Continue(_)));
981
982        // Complete
983        let action = session.next().await.unwrap();
984        match action {
985            HumanLoopAction::Complete(result) => {
986                assert_eq!(
987                    result.state.get::<String>("user_name"),
988                    Some("Alice".to_string())
989                );
990            }
991            _ => panic!("Expected Complete"),
992        }
993    }
994
995    #[tokio::test]
996    async fn test_rejection() {
997        let graph = GraphBuilder::new("reject_test")
998            .add_node("generate", |mut state: GraphState| async move {
999                state.set("content", "Bad content");
1000                Ok(state)
1001            })
1002            .add_edge("generate", END)
1003            .set_entry("generate")
1004            .build()
1005            .unwrap();
1006
1007        let interactive = InteractiveGraph::new(graph)
1008            .with_approval_gate("generate", ApprovalGate::new("Review content"));
1009
1010        let mut session = interactive.start(GraphState::new()).await.unwrap();
1011
1012        // Execute generate
1013        let action = session.next().await.unwrap();
1014        assert!(matches!(action, HumanLoopAction::AwaitApproval { .. }));
1015
1016        // Reject
1017        session.reject("Content not good enough").await.unwrap();
1018
1019        assert_eq!(session.status(), SessionStatus::Interrupted);
1020    }
1021
1022    #[tokio::test]
1023    async fn test_run_until_human_action() {
1024        let graph = GraphBuilder::new("run_test")
1025            .add_node("auto1", |mut state: GraphState| async move {
1026                state.set("auto1", true);
1027                Ok(state)
1028            })
1029            .add_node("auto2", |mut state: GraphState| async move {
1030                state.set("auto2", true);
1031                Ok(state)
1032            })
1033            .add_node("manual", |mut state: GraphState| async move {
1034                state.set("manual", true);
1035                Ok(state)
1036            })
1037            .add_edge("auto1", "auto2")
1038            .add_edge("auto2", "manual")
1039            .add_edge("manual", END)
1040            .set_entry("auto1")
1041            .build()
1042            .unwrap();
1043
1044        let interactive =
1045            InteractiveGraph::new(graph).with_approval_gate("manual", ApprovalGate::new("Review"));
1046
1047        let mut session = interactive.start(GraphState::new()).await.unwrap();
1048
1049        // Run until human action (should skip auto1, auto2 and stop at manual approval)
1050        let action = session.run_until_human_action().await.unwrap();
1051
1052        match action {
1053            HumanLoopAction::AwaitApproval { state, .. } => {
1054                // Both auto nodes should have run
1055                assert_eq!(state.get::<bool>("auto1"), Some(true));
1056                assert_eq!(state.get::<bool>("auto2"), Some(true));
1057                assert_eq!(state.get::<bool>("manual"), Some(true));
1058            }
1059            _ => panic!("Expected AwaitApproval"),
1060        }
1061    }
1062
/// Checks that each `ApprovalGate` builder method stores its argument in the
/// corresponding field.
#[test]
fn test_approval_gate_builder() {
    let visible_fields = vec![String::from("title"), String::from("content")];

    let gate = ApprovalGate::new("Review document")
        .show_fields(visible_fields)
        .allow_edit()
        .with_timeout(300)
        .with_metadata("priority", serde_json::json!("high"));

    // Description comes from the constructor; everything else from the builder calls.
    assert_eq!(gate.description, "Review document");
    assert_eq!(gate.show_fields.len(), 2);
    assert!(gate.allow_edit);
    assert_eq!(gate.timeout_secs, Some(300));

    let expected_priority = serde_json::json!("high");
    assert_eq!(gate.metadata.get("priority"), Some(&expected_priority));
}
1080
/// Checks the fields of a `HumanInputRequest` built with a select input
/// type, a default value, and a timeout.
#[test]
fn test_input_request_builder() {
    let choices = vec![
        SelectOption::new("Low", "low"),
        SelectOption::new("Medium", "medium").with_description("Default"),
        SelectOption::new("High", "high"),
    ];

    let request = HumanInputRequest::new("Select priority", "priority")
        .input_type(HumanInputType::Select(choices))
        .with_default(serde_json::json!("medium"))
        .with_timeout(60);

    assert_eq!(request.prompt, "Select priority");
    assert_eq!(request.field_name, "priority");
    // The test expects a request with a default to be optional.
    assert!(!request.required);
    assert_eq!(request.timeout_secs, Some(60));
}
1097}