Skip to main content

cortexai_crew/
streaming.rs

1//! Unified Streaming for Workflows and Agents
2//!
3//! Provides token-level and event-level streaming across all workflow types:
4//! - Graph node execution with LLM token streaming
5//! - Handoff conversation streaming
6//! - Workflow step streaming
7//! - Real-time event callbacks
8//!
9//! ## Example
10//!
//! ```rust,ignore
//! use cortexai_crew::streaming::{EventCallbackBuilder, StreamingGraph, WorkflowEvent};
//!
//! // Register a callback that sees every event
//! let handler = EventCallbackBuilder::new()
//!     .on_any(|event| match event {
//!         WorkflowEvent::NodeStarted { node_id, .. } => println!("Starting: {}", node_id),
//!         WorkflowEvent::TokenDelta { token, .. } => print!("{}", token),
//!         WorkflowEvent::NodeCompleted { node_id, .. } => println!("\nDone: {}", node_id),
//!         WorkflowEvent::WorkflowCompleted { status, .. } => println!("Finished: {:?}", status),
//!         _ => {}
//!     })
//!     .build();
//!
//! // Wrap the graph; the callback fires while the graph executes
//! let streaming = StreamingGraph::with_handler(graph, handler);
//!
//! // Execute; the final state and status are returned once the run ends
//! let (final_state, status) = streaming.execute_streaming(initial_state).await?;
//! ```
30
31use crate::graph::{Graph, GraphState, GraphStatus, END};
32use crate::handoff::{HandoffContext, HandoffResult, HandoffRouter, HandoffStatus};
33use chrono::{DateTime, Utc};
34use futures::Stream;
35use cortexai_core::errors::CrewError;
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38use std::pin::Pin;
39use std::sync::Arc;
40use tokio::sync::mpsc;
41
/// Unified workflow event for streaming.
///
/// A single enum describes everything an observer can be told about a run:
/// workflow lifecycle, per-node lifecycle, LLM token and tool-call deltas,
/// agent handoffs, state changes, progress and free-form custom payloads.
/// Events are cloneable and (de)serializable with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WorkflowEvent {
    // === Workflow-level events ===
    /// Workflow execution started.
    WorkflowStarted {
        /// Identifier of the run. `StreamingGraph` generates a fresh UUID v4
        /// per run; `StreamingHandoff` reuses the conversation id.
        workflow_id: String,
        /// Which kind of workflow produced the event.
        workflow_type: WorkflowType,
        /// When the event was created.
        timestamp: DateTime<Utc>,
    },

    /// Workflow execution completed.
    ///
    /// This covers every run that ended without an error, including runs that
    /// stopped because a limit was reached; inspect `status` to tell them apart.
    WorkflowCompleted {
        /// Identifier of the run, matching the one in `WorkflowStarted`.
        workflow_id: String,
        /// How the run ended.
        status: StreamingStatus,
        /// Wall-clock duration of the run in milliseconds.
        duration_ms: u64,
        /// When the event was created.
        timestamp: DateTime<Utc>,
    },

    /// Workflow execution failed with an error.
    WorkflowFailed {
        /// Identifier of the run, matching the one in `WorkflowStarted`.
        workflow_id: String,
        /// Human-readable error message.
        error: String,
        /// When the event was created.
        timestamp: DateTime<Utc>,
    },

    // === Node-level events ===
    /// Node execution started.
    NodeStarted {
        /// Id of the node about to run.
        node_id: String,
        /// Kind of node. `StreamingGraph` currently reports every node as
        /// `StreamingNodeType::Regular`.
        node_type: StreamingNodeType,
        /// When the event was created.
        timestamp: DateTime<Utc>,
    },

    /// Node execution completed.
    NodeCompleted {
        /// Id of the node that finished.
        node_id: String,
        /// Time spent in the node, in milliseconds.
        duration_ms: u64,
        /// State keys affected by the node. `StreamingGraph` lists only the
        /// top-level keys that exist after the node ran and did not exist
        /// before; keys whose value merely changed are not included.
        state_changes: Vec<String>,
        /// When the event was created.
        timestamp: DateTime<Utc>,
    },

    /// Node execution failed.
    NodeFailed {
        /// Id of the node that failed.
        node_id: String,
        /// Human-readable error message.
        error: String,
        /// When the event was created.
        timestamp: DateTime<Utc>,
    },

    // === Token-level events (LLM streaming) ===
    /// Token generated by LLM.
    TokenDelta {
        /// The generated text fragment.
        token: String,
        /// Node the token belongs to, when known.
        node_id: Option<String>,
        /// Agent the token belongs to, when known.
        agent_id: Option<String>,
    },

    /// Reasoning/thinking token (for models that expose this).
    ReasoningDelta {
        /// The reasoning text fragment.
        token: String,
        /// Node the token belongs to, when known.
        node_id: Option<String>,
    },

    /// Tool call started.
    ToolCallStarted {
        /// Identifier of this tool call; later tool-call events carry the
        /// same `tool_id` field.
        tool_id: String,
        /// Name of the tool being invoked.
        tool_name: String,
        /// Node the call belongs to, when known.
        node_id: Option<String>,
    },

    /// Tool call arguments streaming.
    ToolCallDelta {
        /// Identifier of the tool call this fragment belongs to.
        tool_id: String,
        /// Next fragment of the call's arguments.
        arguments_delta: String,
    },

    /// Tool call completed.
    ToolCallCompleted {
        /// Identifier of the tool call that finished.
        tool_id: String,
        /// Result of the call, if one is available.
        result: Option<String>,
    },

    // === Handoff events ===
    /// Agent handoff occurred.
    ///
    /// `StreamingHandoff` emits these from the recorded handoff history once
    /// the router has finished, reusing each record's timestamp.
    AgentHandoff {
        /// Agent handing control over.
        from_agent: String,
        /// Agent receiving control.
        to_agent: String,
        /// Why the handoff happened.
        reason: String,
        /// When the handoff was recorded.
        timestamp: DateTime<Utc>,
    },

    /// Agent returned to caller.
    ///
    /// Emitted by `StreamingHandoff` for history records flagged as a return.
    AgentReturn {
        /// Agent giving control back.
        from_agent: String,
        /// Agent receiving control back.
        to_agent: String,
        /// When the return was recorded.
        timestamp: DateTime<Utc>,
    },

    /// Message added to conversation.
    MessageAdded {
        /// Role of the message author.
        role: String,
        /// Agent that produced the message, when known.
        agent_id: Option<String>,
        /// Preview of the message content.
        content_preview: String,
        /// When the event was created.
        timestamp: DateTime<Utc>,
    },

    // === State events ===
    /// State checkpoint created.
    CheckpointCreated {
        /// Identifier of the checkpoint.
        checkpoint_id: String,
        /// Node associated with the checkpoint.
        node_id: String,
        /// When the event was created.
        timestamp: DateTime<Utc>,
    },

    /// State was modified.
    StateUpdated {
        /// Keys that changed.
        keys_changed: Vec<String>,
        /// Node that made the change, when known.
        node_id: Option<String>,
    },

    // === Progress events ===
    /// Progress update.
    Progress {
        /// Step currently being worked on.
        current_step: usize,
        /// Total number of steps, when known.
        total_steps: Option<usize>,
        /// Human-readable progress message.
        message: String,
    },

    /// Custom event for extensibility.
    Custom {
        /// Caller-defined discriminator for the payload.
        event_type: String,
        /// Arbitrary JSON payload.
        data: serde_json::Value,
    },
}
176
/// Type of workflow that produced a [`WorkflowEvent::WorkflowStarted`] event.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkflowType {
    /// Graph execution (reported by `StreamingGraph`).
    Graph,
    /// Agent handoff conversation (reported by `StreamingHandoff`).
    Handoff,
    /// Step-based workflow.
    Workflow,
    /// Graph nested inside another graph.
    Subgraph,
}
185
/// Status of streaming workflow completion.
///
/// This is the common status vocabulary carried by
/// [`WorkflowEvent::WorkflowCompleted`]; `GraphStatus` and `HandoffStatus`
/// both convert into it via `From`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StreamingStatus {
    /// The workflow ran to completion.
    Success,
    /// The workflow ended in a failed state.
    Failed,
    /// The iteration limit was hit (also used for a handoff run that reached
    /// its maximum number of turns).
    MaxIterations,
    /// The handoff limit was hit.
    MaxHandoffs,
    /// The workflow was interrupted (also used for a paused graph).
    Interrupted,
}
195
196impl From<GraphStatus> for StreamingStatus {
197    fn from(status: GraphStatus) -> Self {
198        match status {
199            GraphStatus::Success => StreamingStatus::Success,
200            GraphStatus::Failed => StreamingStatus::Failed,
201            GraphStatus::MaxIterations => StreamingStatus::MaxIterations,
202            GraphStatus::Interrupted => StreamingStatus::Interrupted,
203            GraphStatus::Paused => StreamingStatus::Interrupted,
204        }
205    }
206}
207
208impl From<HandoffStatus> for StreamingStatus {
209    fn from(status: HandoffStatus) -> Self {
210        match status {
211            HandoffStatus::Completed => StreamingStatus::Success,
212            HandoffStatus::MaxTurnsReached => StreamingStatus::MaxIterations,
213            HandoffStatus::MaxHandoffsReached => StreamingStatus::MaxHandoffs,
214            HandoffStatus::Interrupted => StreamingStatus::Interrupted,
215        }
216    }
217}
218
219/// Type of streaming node
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub enum StreamingNodeType {
222    Regular,
223    Subgraph,
224    LLMCall,
225    ToolExecution,
226    Conditional,
227    Human,
228}
229
/// Stream of workflow events.
///
/// A pinned, boxed, `Send` stream whose items are either a [`WorkflowEvent`]
/// or the `CrewError` that interrupted the stream.
/// NOTE(review): nothing in the part of the file visible here produces this
/// type; the executors below report events through `EventEmitter` instead.
pub type EventStream = Pin<Box<dyn Stream<Item = Result<WorkflowEvent, CrewError>> + Send>>;
232
/// Event callback function type.
///
/// A shared, thread-safe closure that receives each event by value.
pub type EventCallback = Arc<dyn Fn(WorkflowEvent) + Send + Sync>;
235
236/// Builder for event callbacks
237pub struct EventCallbackBuilder {
238    on_workflow_started: Option<EventCallback>,
239    on_workflow_completed: Option<EventCallback>,
240    on_node_started: Option<EventCallback>,
241    on_node_completed: Option<EventCallback>,
242    on_token: Option<EventCallback>,
243    on_handoff: Option<EventCallback>,
244    on_any: Option<EventCallback>,
245}
246
247impl Default for EventCallbackBuilder {
248    fn default() -> Self {
249        Self::new()
250    }
251}
252
253impl EventCallbackBuilder {
254    /// Create a new callback builder
255    pub fn new() -> Self {
256        Self {
257            on_workflow_started: None,
258            on_workflow_completed: None,
259            on_node_started: None,
260            on_node_completed: None,
261            on_token: None,
262            on_handoff: None,
263            on_any: None,
264        }
265    }
266
267    /// Set callback for workflow started
268    pub fn on_workflow_started<F>(mut self, f: F) -> Self
269    where
270        F: Fn(WorkflowEvent) + Send + Sync + 'static,
271    {
272        self.on_workflow_started = Some(Arc::new(f));
273        self
274    }
275
276    /// Set callback for workflow completed
277    pub fn on_workflow_completed<F>(mut self, f: F) -> Self
278    where
279        F: Fn(WorkflowEvent) + Send + Sync + 'static,
280    {
281        self.on_workflow_completed = Some(Arc::new(f));
282        self
283    }
284
285    /// Set callback for node started
286    pub fn on_node_started<F>(mut self, f: F) -> Self
287    where
288        F: Fn(WorkflowEvent) + Send + Sync + 'static,
289    {
290        self.on_node_started = Some(Arc::new(f));
291        self
292    }
293
294    /// Set callback for node completed
295    pub fn on_node_completed<F>(mut self, f: F) -> Self
296    where
297        F: Fn(WorkflowEvent) + Send + Sync + 'static,
298    {
299        self.on_node_completed = Some(Arc::new(f));
300        self
301    }
302
303    /// Set callback for tokens
304    pub fn on_token<F>(mut self, f: F) -> Self
305    where
306        F: Fn(WorkflowEvent) + Send + Sync + 'static,
307    {
308        self.on_token = Some(Arc::new(f));
309        self
310    }
311
312    /// Set callback for handoffs
313    pub fn on_handoff<F>(mut self, f: F) -> Self
314    where
315        F: Fn(WorkflowEvent) + Send + Sync + 'static,
316    {
317        self.on_handoff = Some(Arc::new(f));
318        self
319    }
320
321    /// Set callback for any event
322    pub fn on_any<F>(mut self, f: F) -> Self
323    where
324        F: Fn(WorkflowEvent) + Send + Sync + 'static,
325    {
326        self.on_any = Some(Arc::new(f));
327        self
328    }
329
330    /// Build the event handler
331    pub fn build(self) -> EventHandler {
332        EventHandler {
333            on_workflow_started: self.on_workflow_started,
334            on_workflow_completed: self.on_workflow_completed,
335            on_node_started: self.on_node_started,
336            on_node_completed: self.on_node_completed,
337            on_token: self.on_token,
338            on_handoff: self.on_handoff,
339            on_any: self.on_any,
340        }
341    }
342}
343
344/// Handler for workflow events
345pub struct EventHandler {
346    on_workflow_started: Option<EventCallback>,
347    on_workflow_completed: Option<EventCallback>,
348    on_node_started: Option<EventCallback>,
349    on_node_completed: Option<EventCallback>,
350    on_token: Option<EventCallback>,
351    on_handoff: Option<EventCallback>,
352    on_any: Option<EventCallback>,
353}
354
355impl EventHandler {
356    /// Handle an event
357    pub fn handle(&self, event: &WorkflowEvent) {
358        // Call specific handler
359        match event {
360            WorkflowEvent::WorkflowStarted { .. } => {
361                if let Some(cb) = &self.on_workflow_started {
362                    cb(event.clone());
363                }
364            }
365            WorkflowEvent::WorkflowCompleted { .. } | WorkflowEvent::WorkflowFailed { .. } => {
366                if let Some(cb) = &self.on_workflow_completed {
367                    cb(event.clone());
368                }
369            }
370            WorkflowEvent::NodeStarted { .. } => {
371                if let Some(cb) = &self.on_node_started {
372                    cb(event.clone());
373                }
374            }
375            WorkflowEvent::NodeCompleted { .. } | WorkflowEvent::NodeFailed { .. } => {
376                if let Some(cb) = &self.on_node_completed {
377                    cb(event.clone());
378                }
379            }
380            WorkflowEvent::TokenDelta { .. } | WorkflowEvent::ReasoningDelta { .. } => {
381                if let Some(cb) = &self.on_token {
382                    cb(event.clone());
383                }
384            }
385            WorkflowEvent::AgentHandoff { .. } | WorkflowEvent::AgentReturn { .. } => {
386                if let Some(cb) = &self.on_handoff {
387                    cb(event.clone());
388                }
389            }
390            _ => {}
391        }
392
393        // Call any handler
394        if let Some(cb) = &self.on_any {
395            cb(event.clone());
396        }
397    }
398}
399
400/// Event emitter for sending events to multiple receivers
401pub struct EventEmitter {
402    senders: Vec<mpsc::UnboundedSender<WorkflowEvent>>,
403    handler: Option<EventHandler>,
404}
405
406impl Default for EventEmitter {
407    fn default() -> Self {
408        Self::new()
409    }
410}
411
412impl EventEmitter {
413    /// Create a new emitter
414    pub fn new() -> Self {
415        Self {
416            senders: Vec::new(),
417            handler: None,
418        }
419    }
420
421    /// Create with an event handler
422    pub fn with_handler(handler: EventHandler) -> Self {
423        Self {
424            senders: Vec::new(),
425            handler: Some(handler),
426        }
427    }
428
429    /// Subscribe to events
430    pub fn subscribe(&mut self) -> mpsc::UnboundedReceiver<WorkflowEvent> {
431        let (tx, rx) = mpsc::unbounded_channel();
432        self.senders.push(tx);
433        rx
434    }
435
436    /// Emit an event
437    pub fn emit(&self, event: WorkflowEvent) {
438        // Send to all subscribers
439        for sender in &self.senders {
440            let _ = sender.send(event.clone());
441        }
442
443        // Call handler
444        if let Some(handler) = &self.handler {
445            handler.handle(&event);
446        }
447    }
448
449    /// Emit workflow started event
450    pub fn workflow_started(&self, workflow_id: &str, workflow_type: WorkflowType) {
451        self.emit(WorkflowEvent::WorkflowStarted {
452            workflow_id: workflow_id.to_string(),
453            workflow_type,
454            timestamp: Utc::now(),
455        });
456    }
457
458    /// Emit workflow completed event
459    pub fn workflow_completed(&self, workflow_id: &str, status: StreamingStatus, duration_ms: u64) {
460        self.emit(WorkflowEvent::WorkflowCompleted {
461            workflow_id: workflow_id.to_string(),
462            status,
463            duration_ms,
464            timestamp: Utc::now(),
465        });
466    }
467
468    /// Emit workflow failed event
469    pub fn workflow_failed(&self, workflow_id: &str, error: &str) {
470        self.emit(WorkflowEvent::WorkflowFailed {
471            workflow_id: workflow_id.to_string(),
472            error: error.to_string(),
473            timestamp: Utc::now(),
474        });
475    }
476
477    /// Emit node started event
478    pub fn node_started(&self, node_id: &str, node_type: StreamingNodeType) {
479        self.emit(WorkflowEvent::NodeStarted {
480            node_id: node_id.to_string(),
481            node_type,
482            timestamp: Utc::now(),
483        });
484    }
485
486    /// Emit node completed event
487    pub fn node_completed(&self, node_id: &str, duration_ms: u64, state_changes: Vec<String>) {
488        self.emit(WorkflowEvent::NodeCompleted {
489            node_id: node_id.to_string(),
490            duration_ms,
491            state_changes,
492            timestamp: Utc::now(),
493        });
494    }
495
496    /// Emit token delta event
497    pub fn token_delta(&self, token: &str, node_id: Option<&str>, agent_id: Option<&str>) {
498        self.emit(WorkflowEvent::TokenDelta {
499            token: token.to_string(),
500            node_id: node_id.map(|s| s.to_string()),
501            agent_id: agent_id.map(|s| s.to_string()),
502        });
503    }
504
505    /// Emit handoff event
506    pub fn agent_handoff(&self, from: &str, to: &str, reason: &str) {
507        self.emit(WorkflowEvent::AgentHandoff {
508            from_agent: from.to_string(),
509            to_agent: to.to_string(),
510            reason: reason.to_string(),
511            timestamp: Utc::now(),
512        });
513    }
514}
515
516/// Streaming wrapper for Graph execution
517pub struct StreamingGraph {
518    graph: Arc<Graph>,
519    emitter: Arc<EventEmitter>,
520}
521
522impl StreamingGraph {
523    /// Create a new streaming graph
524    pub fn new(graph: Graph) -> Self {
525        Self {
526            graph: Arc::new(graph),
527            emitter: Arc::new(EventEmitter::new()),
528        }
529    }
530
531    /// Create with event handler
532    pub fn with_handler(graph: Graph, handler: EventHandler) -> Self {
533        Self {
534            graph: Arc::new(graph),
535            emitter: Arc::new(EventEmitter::with_handler(handler)),
536        }
537    }
538
539    /// Get the emitter for subscribing to events
540    pub fn emitter(&self) -> &EventEmitter {
541        &self.emitter
542    }
543
544    /// Execute with streaming events
545    pub async fn execute_streaming(
546        &self,
547        initial_state: GraphState,
548    ) -> Result<(GraphState, GraphStatus), CrewError> {
549        let workflow_id = uuid::Uuid::new_v4().to_string();
550        let start_time = std::time::Instant::now();
551
552        self.emitter
553            .workflow_started(&workflow_id, WorkflowType::Graph);
554
555        let mut state = initial_state;
556        state.metadata.started_at = Some(Utc::now());
557        state.metadata.iterations = 0;
558
559        let mut current_node = self.graph.entry_node.clone();
560
561        let result = loop {
562            // Check max iterations
563            if state.metadata.iterations >= self.graph.config.max_iterations {
564                break (state, GraphStatus::MaxIterations);
565            }
566
567            // Check for END node
568            if current_node == END {
569                break (state, GraphStatus::Success);
570            }
571
572            // Get node
573            let node = match self.graph.nodes.get(&current_node) {
574                Some(n) => n,
575                None => {
576                    self.emitter.workflow_failed(
577                        &workflow_id,
578                        &format!("Node not found: {}", current_node),
579                    );
580                    return Err(CrewError::TaskNotFound(format!(
581                        "Node not found: {}",
582                        current_node
583                    )));
584                }
585            };
586
587            // Emit node started
588            self.emitter
589                .node_started(&current_node, StreamingNodeType::Regular);
590            let node_start = std::time::Instant::now();
591
592            // Track state keys before execution
593            let keys_before: Vec<String> = state
594                .data
595                .as_object()
596                .map(|o| o.keys().cloned().collect())
597                .unwrap_or_default();
598
599            // Execute node
600            state.metadata.visited_nodes.push(current_node.clone());
601            state.metadata.iterations += 1;
602
603            state = match node.executor.call(state).await {
604                Ok(s) => s,
605                Err(e) => {
606                    self.emitter.emit(WorkflowEvent::NodeFailed {
607                        node_id: current_node.clone(),
608                        error: e.to_string(),
609                        timestamp: Utc::now(),
610                    });
611                    self.emitter.workflow_failed(&workflow_id, &e.to_string());
612                    return Err(e);
613                }
614            };
615
616            // Detect state changes
617            let keys_after: Vec<String> = state
618                .data
619                .as_object()
620                .map(|o| o.keys().cloned().collect())
621                .unwrap_or_default();
622
623            let state_changes: Vec<String> = keys_after
624                .iter()
625                .filter(|k| !keys_before.contains(k))
626                .cloned()
627                .collect();
628
629            // Emit node completed
630            self.emitter.node_completed(
631                &current_node,
632                node_start.elapsed().as_millis() as u64,
633                state_changes,
634            );
635
636            // Find next node
637            current_node = self.graph.find_next_node(&current_node, &state)?;
638        };
639
640        let duration_ms = start_time.elapsed().as_millis() as u64;
641        self.emitter
642            .workflow_completed(&workflow_id, result.1.into(), duration_ms);
643
644        Ok(result)
645    }
646}
647
648/// Streaming wrapper for Handoff execution
649pub struct StreamingHandoff {
650    router: Arc<HandoffRouter>,
651    emitter: Arc<EventEmitter>,
652}
653
654impl StreamingHandoff {
655    /// Create a new streaming handoff router
656    pub fn new(router: HandoffRouter) -> Self {
657        Self {
658            router: Arc::new(router),
659            emitter: Arc::new(EventEmitter::new()),
660        }
661    }
662
663    /// Create with event handler
664    pub fn with_handler(router: HandoffRouter, handler: EventHandler) -> Self {
665        Self {
666            router: Arc::new(router),
667            emitter: Arc::new(EventEmitter::with_handler(handler)),
668        }
669    }
670
671    /// Get the emitter
672    pub fn emitter(&self) -> &EventEmitter {
673        &self.emitter
674    }
675
676    /// Execute with streaming events
677    pub async fn execute_streaming(
678        &self,
679        context: HandoffContext,
680    ) -> Result<HandoffResult, CrewError> {
681        let workflow_id = context.conversation_id.clone();
682        let start_time = std::time::Instant::now();
683
684        self.emitter
685            .workflow_started(&workflow_id, WorkflowType::Handoff);
686
687        // Execute the handoff router
688        let result = self.router.run(context).await?;
689
690        let duration_ms = start_time.elapsed().as_millis() as u64;
691
692        // Emit events for the handoff history
693        for record in &result.context.handoff_history {
694            if record.is_return {
695                self.emitter.emit(WorkflowEvent::AgentReturn {
696                    from_agent: record.from_agent.clone(),
697                    to_agent: record.to_agent.clone(),
698                    timestamp: record.timestamp,
699                });
700            } else {
701                self.emitter.emit(WorkflowEvent::AgentHandoff {
702                    from_agent: record.from_agent.clone(),
703                    to_agent: record.to_agent.clone(),
704                    reason: record.reason.clone(),
705                    timestamp: record.timestamp,
706                });
707            }
708        }
709
710        self.emitter
711            .workflow_completed(&workflow_id, result.status.into(), duration_ms);
712
713        Ok(result)
714    }
715}
716
/// Collector for aggregating streamed tokens.
///
/// Tokens are kept in arrival order and, while a current node is set, are
/// additionally grouped under that node's id.
#[derive(Debug, Clone, Default)]
pub struct TokenCollector {
    tokens: Vec<String>,
    node_tokens: HashMap<String, Vec<String>>,
    current_node: Option<String>,
}

impl TokenCollector {
    /// Create a new, empty collector with no current node.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set current node context.
    ///
    /// Tokens added afterwards are attributed to `node_id`; pass `None` to
    /// stop attributing tokens to any node.
    pub fn set_node(&mut self, node_id: Option<String>) {
        self.current_node = node_id;
    }

    /// Add a token to the overall sequence and, if a current node is set, to
    /// that node's sequence as well.
    pub fn add_token(&mut self, token: &str) {
        if let Some(node_id) = &self.current_node {
            // Look the bucket up by reference first so the node id is cloned
            // once per node rather than once per token.
            match self.node_tokens.get_mut(node_id) {
                Some(bucket) => bucket.push(token.to_string()),
                None => {
                    self.node_tokens
                        .insert(node_id.clone(), vec![token.to_string()]);
                }
            }
        }

        self.tokens.push(token.to_string());
    }

    /// Get all collected tokens concatenated into one string.
    pub fn collected(&self) -> String {
        self.tokens.concat()
    }

    /// Get the concatenated tokens attributed to `node_id`, or an empty
    /// string when nothing was collected for that node.
    pub fn tokens_for_node(&self, node_id: &str) -> String {
        self.node_tokens
            .get(node_id)
            .map(|tokens| tokens.concat())
            .unwrap_or_default()
    }

    /// Get all tokens in arrival order.
    pub fn all_tokens(&self) -> &[String] {
        &self.tokens
    }

    /// Clear the collector: drops every token and resets the current node.
    pub fn clear(&mut self) {
        self.tokens.clear();
        self.node_tokens.clear();
        self.current_node = None;
    }
}
782
/// Statistics about a streaming execution.
///
/// A plain, serializable summary record.
/// NOTE(review): nothing in the part of the file visible here populates this
/// struct; the field descriptions below restate the field names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingStats {
    /// Total tokens streamed
    pub total_tokens: usize,
    /// Tokens per node, keyed by a string (presumably the node id; confirm with the producer)
    pub tokens_per_node: HashMap<String, usize>,
    /// Total duration in milliseconds
    pub duration_ms: u64,
    /// Number of nodes executed
    pub nodes_executed: usize,
    /// Number of handoffs
    pub handoffs: usize,
    /// Events emitted
    pub events_emitted: usize,
}
799
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A subscriber receives the event that was emitted.
    #[tokio::test]
    async fn test_event_emitter() {
        let mut emitter = EventEmitter::new();
        let mut rx = emitter.subscribe();

        emitter.workflow_started("test-1", WorkflowType::Graph);

        let event = rx.recv().await.unwrap();
        match event {
            WorkflowEvent::WorkflowStarted { workflow_id, .. } => {
                assert_eq!(workflow_id, "test-1");
            }
            _ => panic!("Wrong event type"),
        }
    }

    /// Event-specific callbacks fire exactly once for their event kind.
    /// Nothing here awaits, so a plain synchronous test is sufficient.
    #[test]
    fn test_event_callback_builder() {
        let started_count = Arc::new(AtomicUsize::new(0));
        let completed_count = Arc::new(AtomicUsize::new(0));

        let started_clone = started_count.clone();
        let completed_clone = completed_count.clone();

        let handler = EventCallbackBuilder::new()
            .on_workflow_started(move |_| {
                started_clone.fetch_add(1, Ordering::SeqCst);
            })
            .on_workflow_completed(move |_| {
                completed_clone.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let emitter = EventEmitter::with_handler(handler);

        emitter.workflow_started("test", WorkflowType::Graph);
        emitter.workflow_completed("test", StreamingStatus::Success, 100);

        assert_eq!(started_count.load(Ordering::SeqCst), 1);
        assert_eq!(completed_count.load(Ordering::SeqCst), 1);
    }

    /// A two-node graph produces the full, ordered lifecycle of events.
    #[tokio::test]
    async fn test_streaming_graph() {
        let graph = GraphBuilder::new("test")
            .add_node("step1", |mut state: GraphState| async move {
                state.set("value", 1);
                Ok(state)
            })
            .add_node("step2", |mut state: GraphState| async move {
                let v: i32 = state.get("value").unwrap_or(0);
                state.set("value", v + 1);
                Ok(state)
            })
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .set_entry("step1")
            .build()
            .unwrap();

        let events = Arc::new(std::sync::Mutex::new(Vec::new()));
        let events_clone = events.clone();

        let handler = EventCallbackBuilder::new()
            .on_any(move |event| {
                events_clone.lock().unwrap().push(event);
            })
            .build();

        let streaming = StreamingGraph::with_handler(graph, handler);
        let (state, status) = streaming
            .execute_streaming(GraphState::new())
            .await
            .unwrap();

        assert_eq!(status, GraphStatus::Success);
        assert_eq!(state.get::<i32>("value"), Some(2));

        let captured = events.lock().unwrap();
        // started + 2*(node_started + node_completed) + completed
        assert_eq!(captured.len(), 6);
        assert!(matches!(
            captured.first(),
            Some(WorkflowEvent::WorkflowStarted { .. })
        ));
        assert!(matches!(
            captured.last(),
            Some(WorkflowEvent::WorkflowCompleted {
                status: StreamingStatus::Success,
                ..
            })
        ));
    }

    /// Tokens are aggregated both globally and per node.
    #[test]
    fn test_token_collector() {
        let mut collector = TokenCollector::new();

        collector.set_node(Some("node1".to_string()));
        collector.add_token("Hello");
        collector.add_token(" ");
        collector.add_token("World");

        collector.set_node(Some("node2".to_string()));
        collector.add_token("!");

        assert_eq!(collector.collected(), "Hello World!");
        assert_eq!(collector.tokens_for_node("node1"), "Hello World");
        assert_eq!(collector.tokens_for_node("node2"), "!");
    }

    /// Events survive a JSON round trip.
    #[test]
    fn test_workflow_event_serialization() {
        let event = WorkflowEvent::TokenDelta {
            token: "test".to_string(),
            node_id: Some("node1".to_string()),
            agent_id: None,
        };

        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("TokenDelta"));
        assert!(json.contains("test"));

        let parsed: WorkflowEvent = serde_json::from_str(&json).unwrap();
        match parsed {
            WorkflowEvent::TokenDelta { token, .. } => assert_eq!(token, "test"),
            _ => panic!("Wrong type"),
        }
    }

    /// Every subscriber gets its own copy of each event.
    #[tokio::test]
    async fn test_multiple_subscribers() {
        let mut emitter = EventEmitter::new();
        let mut rx1 = emitter.subscribe();
        let mut rx2 = emitter.subscribe();

        emitter.node_started("node1", StreamingNodeType::Regular);

        // Both should receive
        let e1 = rx1.recv().await.unwrap();
        let e2 = rx2.recv().await.unwrap();

        match (e1, e2) {
            (
                WorkflowEvent::NodeStarted { node_id: n1, .. },
                WorkflowEvent::NodeStarted { node_id: n2, .. },
            ) => {
                assert_eq!(n1, "node1");
                assert_eq!(n2, "node1");
            }
            _ => panic!("Wrong events"),
        }
    }

    /// Every graph and handoff status maps to the expected streaming status.
    #[test]
    fn test_streaming_status_conversions() {
        assert_eq!(
            StreamingStatus::from(GraphStatus::Success),
            StreamingStatus::Success
        );
        assert_eq!(
            StreamingStatus::from(GraphStatus::Failed),
            StreamingStatus::Failed
        );
        assert_eq!(
            StreamingStatus::from(GraphStatus::MaxIterations),
            StreamingStatus::MaxIterations
        );
        assert_eq!(
            StreamingStatus::from(GraphStatus::Interrupted),
            StreamingStatus::Interrupted
        );
        assert_eq!(
            StreamingStatus::from(GraphStatus::Paused),
            StreamingStatus::Interrupted
        );
        assert_eq!(
            StreamingStatus::from(HandoffStatus::Completed),
            StreamingStatus::Success
        );
        assert_eq!(
            StreamingStatus::from(HandoffStatus::MaxTurnsReached),
            StreamingStatus::MaxIterations
        );
        assert_eq!(
            StreamingStatus::from(HandoffStatus::MaxHandoffsReached),
            StreamingStatus::MaxHandoffs
        );
        assert_eq!(
            StreamingStatus::from(HandoffStatus::Interrupted),
            StreamingStatus::Interrupted
        );
    }
}