1use crate::state::State;
2use futures::StreamExt;
3use std::collections::HashMap;
4use std::sync::Mutex;
5
/// Selects which [`StreamEvent`] categories an emitter forwards.
///
/// The filtering rules themselves live in [`EventEmitter::should_emit`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum StreamMode {
    /// Default mode. Accepts `Values`, `FilteredValues` and `End` events.
    #[default]
    Values,

    /// Accepts `Updates`, `FilteredUpdates` and `End` events.
    Updates,

    /// Accepts `Messages` and `End` events; see [`EventEmitter::should_emit`]
    /// for how the `"nostream"` tag affects `Messages` events.
    Messages,

    /// Accepts `Custom` and `End` events.
    Custom,

    /// Accepts every event.
    Debug,

    /// Accepts `Tools` and `End` events.
    Tools,

    /// Accepts `CheckpointSaved` and `End` events.
    Checkpoints,

    /// Accepts `TaskDetail` and `End` events.
    Tasks,

    /// Accepts an event when at least one of the listed modes accepts it.
    /// A `Multi` nested inside this list accepts nothing.
    Multi(Vec<StreamMode>),
}
37
/// A single event sent over the stream channel, generic over the graph state `S`.
///
/// Which variants reach a consumer is decided by [`EventEmitter::should_emit`]
/// based on the active [`StreamMode`]. The producers of most variants are
/// outside this file, so the per-variant notes below describe the payload only.
#[derive(Clone, Debug)]
pub enum StreamEvent<S: State> {
    /// A complete state value together with a step number.
    Values { state: S, step: usize },

    /// JSON `data` together with a step number.
    // NOTE(review): presumably the state restricted to configured output keys — confirm against the producer.
    FilteredValues {
        data: serde_json::Value,
        step: usize,
    },

    /// A typed state update (`S::Update`) attributed to `node` at `step`.
    Updates {
        node: String,
        update: S::Update,
        step: usize,
    },

    /// JSON `data` attributed to `node` at `step`.
    // NOTE(review): presumably the JSON/filtered counterpart of `Updates` — confirm against the producer.
    FilteredUpdates {
        node: String,
        data: serde_json::Value,
        step: usize,
    },

    /// One streamed LLM chunk plus metadata; built by `call_llm_streaming`.
    Messages {
        chunk: MessageChunk,
        metadata: MessageStreamMetadata,
    },

    /// Arbitrary JSON sent by a node through [`StreamWriter::send`].
    Custom {
        node: String,
        data: serde_json::Value,
        ns: Vec<String>,
    },

    /// Task start marker carrying node name, task id and step.
    TaskStart {
        node: String,
        task_id: String,
        step: usize,
    },

    /// Task end marker carrying node name, task id, step and a duration in milliseconds.
    TaskEnd {
        node: String,
        task_id: String,
        step: usize,
        duration_ms: u64,
    },

    /// Interrupt raised by `node` with a JSON payload, a `resumable` flag and a namespace.
    Interrupt {
        node: String,
        payload: serde_json::Value,
        resumable: bool,
        ns: Vec<String>,
    },

    /// A budget limit was exceeded; carries the reason and a usage snapshot.
    BudgetExceeded {
        reason: crate::pregel::BudgetExceededReason,
        usage: BudgetUsage,
    },

    /// Terminal event carrying the output state. Accepted by every single
    /// (non-`Multi`) stream mode.
    End { output: S },

    /// Cancellation marker carrying a step number.
    Cancelled { step: usize },

    /// Wrapper around a [`DebugEvent`].
    Debug(DebugEvent),

    /// Wrapper around a [`ToolsEvent`].
    Tools(ToolsEvent),

    /// A checkpoint was saved; carries its id, metadata and step.
    CheckpointSaved {
        checkpoint_id: String,
        metadata: crate::checkpoint::CheckpointMetadata,
        step: usize,
    },

    /// Per-task detail: task id, node, step, attempt number and a [`TaskEventType`].
    TaskDetail {
        task_id: String,
        node: String,
        step: usize,
        attempt: usize,
        event: TaskEventType,
    },
}
146
147impl<S: State> StreamEvent<S> {
148 #[must_use]
155 #[allow(
156 clippy::match_same_arms,
157 reason = "each arm is explicit for clarity even when some return the same value"
158 )]
159 pub fn namespace(&self) -> &[String] {
160 match self {
161 Self::Custom { ns, .. } => ns,
162 Self::Messages { metadata, .. } => &metadata.ns,
163 Self::Interrupt { ns, .. } => ns,
164 Self::Values { .. }
165 | Self::FilteredValues { .. }
166 | Self::Updates { .. }
167 | Self::FilteredUpdates { .. }
168 | Self::TaskStart { .. }
169 | Self::TaskEnd { .. }
170 | Self::BudgetExceeded { .. }
171 | Self::End { .. }
172 | Self::Cancelled { .. }
173 | Self::Debug(_)
174 | Self::Tools(_)
175 | Self::CheckpointSaved { .. }
176 | Self::TaskDetail { .. } => &[],
177 }
178 }
179}
180
/// A [`StreamEvent`] bundled with a namespace, an event label and optional metadata.
// NOTE(review): nothing in this file constructs a `StreamPart`; field meanings below are
// inferred from names and types only — confirm against the producer.
#[derive(Clone)]
pub struct StreamPart<S: State> {
    /// Namespace path for this part.
    pub ns: Vec<String>,

    /// Static label for the event.
    pub event: &'static str,

    /// The wrapped event.
    pub data: StreamEvent<S>,

    /// Optional free-form JSON metadata keyed by string.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
237
238impl<S: State> std::fmt::Debug for StreamPart<S> {
239 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 f.debug_struct("StreamPart")
241 .field("ns", &self.ns)
242 .field("event", &self.event)
243 .field("data", &"<StreamEvent>")
244 .field("metadata", &self.metadata)
245 .finish()
246 }
247}
248
/// One streamed piece of an LLM response, as carried by [`StreamEvent::Messages`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MessageChunk {
    /// Text content of this chunk.
    pub content: String,
    /// Partial tool-call data contained in this chunk.
    pub tool_call_chunks: Vec<ToolCallChunk>,
    /// Token usage reported with this chunk, if any.
    pub usage_delta: Option<crate::state::TokenUsage>,
}
256
/// A fragment of a tool call delivered inside a [`MessageChunk`].
///
/// `call_llm_streaming` reassembles fragments by `index`: `id` and `name`
/// overwrite the accumulated values when present, and `args_delta` is appended
/// to the accumulated argument text.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallChunk {
    /// Tool-call id, when this fragment carries it.
    pub id: Option<String>,
    /// Tool name, when this fragment carries it.
    pub name: Option<String>,
    /// Next piece of the argument text (may be empty).
    pub args_delta: String,
    /// Position of the tool call this fragment belongs to.
    pub index: usize,
}
265
/// Metadata attached to every [`StreamEvent::Messages`] event.
#[derive(Clone, Debug)]
pub struct MessageStreamMetadata {
    /// Name of the node that made the LLM call.
    pub node: String,
    /// Model name as reported by the model.
    pub model: String,
    /// Tags copied from the call options; a `"nostream"` tag causes
    /// `EventEmitter::should_emit` to suppress the event in `Messages` mode.
    pub tags: Vec<String>,
    /// Namespace of the emitter that produced the event.
    pub ns: Vec<String>,
}
274
/// Payload carried by [`StreamEvent::Debug`].
///
/// Serializable but not deserializable (only `Serialize` is derived). The
/// producers of these events are outside this file; variant and field names
/// are the only description of their meaning available here.
#[derive(Clone, Debug, serde::Serialize)]
pub enum DebugEvent {
    /// Graph run start: thread id and JSON input.
    GraphStart {
        thread_id: String,
        input: serde_json::Value,
    },
    /// Superstep start: step number and the nodes pending in it.
    SuperstepStart {
        step: usize,
        pending_nodes: Vec<String>,
    },
    /// Superstep end: step number and duration in milliseconds.
    SuperstepEnd { step: usize, duration_ms: u64 },
    /// Node start: node name and step.
    NodeStart { node: String, step: usize },
    /// Node end: node name, step, duration and a textual output type.
    NodeEnd {
        node: String,
        step: usize,
        duration_ms: u64,
        output_type: String,
    },
    /// Node failure: node name, step and error text.
    NodeError {
        node: String,
        step: usize,
        error: String,
    },
    /// Channel write: channel, writing node and a textual value summary.
    ChannelWrite {
        channel: String,
        node: String,
        value_summary: String,
    },
    /// Channel update: channel and its new version number.
    ChannelUpdate { channel: String, new_version: u64 },
    /// Merge: step and the channels updated by it.
    Merge {
        step: usize,
        channels_updated: Vec<String>,
    },
    /// Edge traversal: source, target and a textual edge type.
    EdgeTraversed {
        from: String,
        to: String,
        edge_type: String,
    },
    /// Checkpoint saved: id, step and a textual source.
    CheckpointSaved {
        checkpoint_id: String,
        step: usize,
        source: String,
    },
    /// Budget check: tokens used, cost in USD and remaining budget percentage.
    BudgetCheck {
        tokens_used: u64,
        cost_usd: f64,
        budget_remaining_pct: f32,
    },
    /// Graph run end: total steps and total duration in milliseconds.
    GraphEnd {
        total_steps: usize,
        total_duration_ms: u64,
    },
}
342
/// Payload carried by [`StreamEvent::Tools`]; accepted by `StreamMode::Tools`.
///
/// All variants are correlated through `tool_call_id`.
#[derive(Clone, Debug)]
pub enum ToolsEvent {
    /// A tool invocation began: tool name, call id, node, JSON input and a UTC timestamp.
    ToolStarted {
        tool_name: String,
        tool_call_id: String,
        node: String,
        input: serde_json::Value,
        timestamp: chrono::DateTime<chrono::Utc>,
    },
    /// A piece of textual output for an invocation.
    ToolOutputDelta {
        tool_call_id: String,
        delta: String,
    },
    /// An invocation finished: JSON output, duration in milliseconds and a success flag.
    ToolFinished {
        tool_call_id: String,
        output: serde_json::Value,
        duration_ms: u64,
        success: bool,
    },
    /// An invocation failed with the given error text.
    ToolError {
        tool_call_id: String,
        error: String,
    },
}
368
/// Usage snapshot carried by [`StreamEvent::BudgetExceeded`].
#[derive(Clone, Debug)]
pub struct BudgetUsage {
    /// Tokens consumed.
    pub tokens_used: u64,
    /// Cost in US dollars.
    pub cost_usd: f64,
    /// Elapsed time in milliseconds.
    pub duration_ms: u64,
    /// Number of steps completed.
    pub steps_completed: usize,
}
377
/// Kind of task lifecycle event, carried in the `event` field of
/// [`StreamEvent::TaskDetail`].
#[derive(Clone, Debug)]
pub enum TaskEventType {
    /// The task started.
    Started,
    /// The task completed; carries a duration in milliseconds.
    Completed { duration_ms: u64 },
    /// The task failed; carries the error text.
    Failed { error: String },
    /// The task is being retried; carries an attempt number.
    Retrying { attempt: usize },
}
386
/// Transforms JSON stream payloads one value at a time.
///
/// Implementations take `&self`, so any state they keep needs interior
/// mutability (as `BatchTransformer` does with a `Mutex`).
pub trait StreamTransformer: Send + Sync + 'static {
    /// Transforms `data`.
    ///
    /// Returns `Some(value)` to produce output, or `None` to produce nothing
    /// for this input (for example, `BatchTransformer` returns `None` while it
    /// is still filling its buffer).
    #[must_use]
    fn transform(&self, data: serde_json::Value) -> Option<serde_json::Value>;
}
397
/// Handle for sending [`StreamEvent`]s over a tokio mpsc channel.
///
/// `emit` itself does not filter; callers are expected to consult
/// `should_emit` first (as `call_llm_streaming` does).
#[derive(Clone)]
pub struct EventEmitter<S: State> {
    /// Channel the events are sent on.
    pub tx: tokio::sync::mpsc::Sender<StreamEvent<S>>,
    /// Stream mode consulted by `should_emit`.
    pub mode: StreamMode,
    /// Namespace path: empty after `new`, extended by `with_subgraph_ns`.
    ns: Vec<String>,
    /// Zero-sized marker for the state type.
    _phantom: std::marker::PhantomData<S>,
}
406
407impl<S: State> std::fmt::Debug for EventEmitter<S> {
408 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409 f.debug_struct("EventEmitter")
410 .field("tx", &"<mpsc::Sender>")
411 .field("mode", &self.mode)
412 .field("ns", &self.ns)
413 .finish()
414 }
415}
416
417impl<S: State> EventEmitter<S> {
418 #[must_use]
419 pub const fn new(tx: tokio::sync::mpsc::Sender<StreamEvent<S>>, mode: StreamMode) -> Self {
420 Self {
421 tx,
422 mode,
423 ns: Vec::new(),
424 _phantom: std::marker::PhantomData,
425 }
426 }
427
428 #[must_use]
433 pub fn with_subgraph_ns(&self, ns_segment: String) -> Self {
434 let mut new_ns = self.ns.clone();
435 new_ns.push(ns_segment);
436 Self {
437 tx: self.tx.clone(),
438 mode: self.mode.clone(),
439 ns: new_ns,
440 _phantom: std::marker::PhantomData,
441 }
442 }
443
444 #[must_use]
446 pub fn ns(&self) -> &[String] {
447 &self.ns
448 }
449
450 #[must_use]
452 pub const fn mode(&self) -> &StreamMode {
453 &self.mode
454 }
455
456 pub async fn emit(&self, event: StreamEvent<S>) {
462 let _ = self.tx.send(event).await;
463 }
464
465 #[must_use]
466 pub fn stream_writer(&self, node: String) -> StreamWriter<S> {
467 StreamWriter::new(self.tx.clone(), node, self.mode.clone())
468 }
469
470 #[must_use]
471 #[allow(clippy::match_same_arms, reason = "each arm is explicit for clarity")]
472 pub fn should_emit(&self, event: &StreamEvent<S>) -> bool {
473 match (&self.mode, event) {
474 (
475 StreamMode::Values,
476 StreamEvent::Values { .. }
477 | StreamEvent::FilteredValues { .. }
478 | StreamEvent::End { .. },
479 ) => true,
480 (
481 StreamMode::Updates,
482 StreamEvent::Updates { .. }
483 | StreamEvent::FilteredUpdates { .. }
484 | StreamEvent::End { .. },
485 ) => true,
486 (StreamMode::Messages, StreamEvent::Messages { .. } | StreamEvent::End { .. }) => {
487 if let StreamEvent::Messages { metadata, .. } = event {
489 !Self::has_nostream_tag_in_metadata(metadata)
490 } else {
491 true
492 }
493 }
494 (StreamMode::Custom, StreamEvent::Custom { .. } | StreamEvent::End { .. }) => true,
495 (StreamMode::Debug, _) => true, (StreamMode::Tools, StreamEvent::Tools(_) | StreamEvent::End { .. }) => true,
497 (
498 StreamMode::Checkpoints,
499 StreamEvent::CheckpointSaved { .. } | StreamEvent::End { .. },
500 ) => true,
501 (StreamMode::Tasks, StreamEvent::TaskDetail { .. } | StreamEvent::End { .. }) => true,
502 (StreamMode::Multi(modes), _) => {
503 Self::mode_matches_multi(modes, event)
505 }
506 _ => false,
507 }
508 }
509
510 #[must_use]
515 fn has_nostream_tag_in_metadata(metadata: &MessageStreamMetadata) -> bool {
516 metadata.tags.iter().any(|tag| tag == "nostream")
517 }
518
519 #[must_use]
551 pub fn has_nostream_tag(&self, options: Option<&crate::llm::CallOptions>) -> bool {
552 options.is_some_and(|opts| opts.tags.iter().any(|tag| tag == "nostream"))
553 }
554
555 #[must_use]
557 fn mode_matches_multi(modes: &[StreamMode], event: &StreamEvent<S>) -> bool {
558 modes.iter().any(|m| Self::mode_matches_single(m, event))
559 }
560
561 #[must_use]
563 #[allow(
564 clippy::match_same_arms,
565 clippy::missing_const_for_fn,
566 reason = "each arm is explicit for clarity; non-const for multi-mode filtering"
567 )]
568 fn mode_matches_single(mode: &StreamMode, event: &StreamEvent<S>) -> bool {
569 match (mode, event) {
570 (
571 StreamMode::Values,
572 StreamEvent::Values { .. }
573 | StreamEvent::FilteredValues { .. }
574 | StreamEvent::End { .. },
575 ) => true,
576 (
577 StreamMode::Updates,
578 StreamEvent::Updates { .. }
579 | StreamEvent::FilteredUpdates { .. }
580 | StreamEvent::End { .. },
581 ) => true,
582 (StreamMode::Messages, StreamEvent::Messages { .. } | StreamEvent::End { .. }) => true,
583 (StreamMode::Custom, StreamEvent::Custom { .. } | StreamEvent::End { .. }) => true,
584 (StreamMode::Debug, _) => true,
585 (StreamMode::Tools, StreamEvent::Tools(_) | StreamEvent::End { .. }) => true,
586 (
587 StreamMode::Checkpoints,
588 StreamEvent::CheckpointSaved { .. } | StreamEvent::End { .. },
589 ) => true,
590 (StreamMode::Tasks, StreamEvent::TaskDetail { .. } | StreamEvent::End { .. }) => true,
591 (StreamMode::Multi(_), _) => false,
592 _ => false,
593 }
594 }
595}
596
/// Per-node handle for sending [`StreamEvent::Custom`] events.
#[derive(Clone)]
pub struct StreamWriter<S: State> {
    /// Channel to send on; `None` for a disconnected writer, whose `send` is a no-op.
    tx: Option<tokio::sync::mpsc::Sender<StreamEvent<S>>>,
    /// Node name stamped on every event.
    node: String,
    /// Stream mode used to decide whether a custom event is sent.
    mode: StreamMode,
    /// Namespace stamped on every event; extended by `with_ns`.
    ns: Vec<String>,
}
609
610impl<S: State> StreamWriter<S> {
611 #[must_use]
613 pub const fn new(
614 tx: tokio::sync::mpsc::Sender<StreamEvent<S>>,
615 node: String,
616 mode: StreamMode,
617 ) -> Self {
618 Self {
619 tx: Some(tx),
620 node,
621 mode,
622 ns: Vec::new(),
623 }
624 }
625
626 #[must_use]
630 pub const fn disconnected(node: String, mode: StreamMode) -> Self {
631 Self {
632 tx: None,
633 node,
634 mode,
635 ns: Vec::new(),
636 }
637 }
638
639 #[must_use]
641 pub fn with_ns(&self, ns_segment: String) -> Self {
642 let mut new_ns = self.ns.clone();
643 new_ns.push(ns_segment);
644 Self {
645 tx: self.tx.clone(),
646 node: self.node.clone(),
647 mode: self.mode.clone(),
648 ns: new_ns,
649 }
650 }
651
652 pub async fn send(&self, data: serde_json::Value) {
657 let Some(ref tx) = self.tx else {
658 return;
659 };
660
661 let event = StreamEvent::Custom {
662 node: self.node.clone(),
663 data,
664 ns: self.ns.clone(),
665 };
666
667 let emitter = EventEmitter::new(tx.clone(), self.mode.clone());
668 if emitter.should_emit(&event) {
669 let _ = tx.send(event).await;
670 }
671 }
672}
673
674impl<S: State> std::fmt::Debug for StreamWriter<S> {
675 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
676 f.debug_struct("StreamWriter")
677 .field("tx", &self.tx.is_some())
678 .field("node", &self.node)
679 .field("mode", &self.mode)
680 .field("ns", &self.ns)
681 .finish()
682 }
683}
684
685pub async fn call_llm_streaming<S: State, M: crate::llm::ChatModel>(
697 model: &M,
698 messages: &[crate::state::Message],
699 options: Option<&crate::llm::CallOptions>,
700 emitter: &EventEmitter<S>,
701 node_name: &str,
702) -> Result<crate::state::Message, crate::llm::LlmError> {
703 let mut stream = model.stream(messages, options).await?;
704 let mut full_content = String::new();
705 let mut tool_calls: Vec<crate::state::ToolCall> = Vec::new();
706 let mut total_usage = crate::state::TokenUsage::default();
707
708 #[allow(clippy::option_if_let_else, reason = "explicit match is clearer")]
710 let tags: Vec<String> = match options {
711 Some(opts) => opts.tags.clone(),
712 None => Vec::new(),
713 };
714
715 while let Some(chunk_result) = stream.next().await {
716 let chunk = chunk_result?;
717
718 full_content.push_str(&chunk.content);
719
720 for tc_chunk in &chunk.tool_call_chunks {
721 while tool_calls.len() <= tc_chunk.index {
722 tool_calls.push(crate::state::ToolCall {
723 id: String::new(),
724 name: String::new(),
725 arguments: serde_json::Value::Null,
726 });
727 }
728 let tc = &mut tool_calls[tc_chunk.index];
729 if let Some(ref id) = tc_chunk.id {
730 id.clone_into(&mut tc.id);
731 }
732 if let Some(ref name) = tc_chunk.name {
733 name.clone_into(&mut tc.name);
734 }
735 if !tc_chunk.args_delta.is_empty() {
736 match &mut tc.arguments {
737 serde_json::Value::String(s) => s.push_str(&tc_chunk.args_delta),
738 serde_json::Value::Null => {
739 tc.arguments = serde_json::Value::String(tc_chunk.args_delta.clone());
740 }
741 other => {
742 let mut s = match std::mem::replace(other, serde_json::Value::Null) {
743 serde_json::Value::String(existing) => existing,
744 _ => String::new(),
745 };
746 s.push_str(&tc_chunk.args_delta);
747 *other = serde_json::Value::String(s);
748 }
749 }
750 }
751 }
752
753 if let Some(ref usage) = chunk.usage {
754 total_usage.input_tokens += usage.input_tokens;
755 total_usage.output_tokens += usage.output_tokens;
756 total_usage.total_tokens += usage.total_tokens;
757 }
758
759 let stream_chunk = MessageChunk {
760 content: chunk.content,
761 tool_call_chunks: chunk.tool_call_chunks,
762 usage_delta: chunk.usage,
763 };
764
765 let event = StreamEvent::Messages {
766 chunk: stream_chunk,
767 metadata: MessageStreamMetadata {
768 node: node_name.to_string(),
769 model: model.model_name().to_string(),
770 tags: tags.clone(),
771 ns: emitter.ns().to_vec(),
772 },
773 };
774
775 if emitter.should_emit(&event) {
776 emitter.emit(event).await;
777 }
778 }
779
780 for tc in &mut tool_calls {
782 if let serde_json::Value::String(s) = &tc.arguments {
783 tc.arguments = serde_json::from_str(s).unwrap_or_else(|_| {
784 serde_json::Value::String(std::mem::take(&mut tc.arguments).to_string())
785 });
786 }
787 }
788
789 total_usage.total_tokens = total_usage.input_tokens + total_usage.output_tokens;
790
791 Ok(crate::state::Message {
792 id: uuid::Uuid::new_v4().to_string(),
793 role: crate::state::Role::Ai,
794 content: crate::state::Content::Text(full_content),
795 tool_calls,
796 tool_call_id: None,
797 name: None,
798 usage: Some(total_usage),
799 })
800}
801
/// Batching parameters for streamed message chunks.
// NOTE(review): nothing in this file reads these fields; the field notes below are
// inferred from names — confirm against the consumer.
#[derive(Clone, Debug)]
pub struct MessageBatchConfig {
    /// Chunk-count limit per batch (`10` by default, `1` for no batching).
    pub max_chunks: usize,

    /// Flush interval in milliseconds (`Some(100)` by default, `None` for no batching).
    pub flush_interval_ms: Option<u64>,
}

impl Default for MessageBatchConfig {
    /// Ten chunks per batch with a 100 ms flush interval.
    fn default() -> Self {
        Self::new(10, Some(100))
    }
}

impl MessageBatchConfig {
    /// Builds a configuration from explicit values.
    #[must_use]
    pub const fn new(max_chunks: usize, flush_interval_ms: Option<u64>) -> Self {
        Self {
            max_chunks,
            flush_interval_ms,
        }
    }

    /// Configuration with a chunk limit of one and no flush interval.
    #[must_use]
    pub const fn no_batching() -> Self {
        Self::new(1, None)
    }
}
851
852pub(crate) fn filter_json_by_keys(value: serde_json::Value, keys: &[String]) -> serde_json::Value {
857 if keys.is_empty() {
858 return value;
859 }
860 match value {
861 serde_json::Value::Object(mut map) => {
862 let keep: std::collections::HashSet<&String> = keys.iter().collect();
863 map.retain(|k, _| keep.contains(k));
864 serde_json::Value::Object(map)
865 }
866 other => other,
867 }
868}
869
/// Options for a streaming run, built with `StreamConfig::new` and the
/// `with_*` builder methods.
// NOTE(review): the consumers of these fields are outside this file; the field notes are
// inferred from names and types — confirm against the runtime.
#[derive(Clone, Debug, Default)]
pub struct StreamConfig {
    /// Stream mode to use.
    pub mode: StreamMode,
    /// Whether subgraph events are included (`false` from `new`).
    pub include_subgraphs: bool,
    /// Optional list of names restricting which subgraphs are included.
    pub subgraph_filter: Option<Vec<String>>,
    /// Optional list of output keys.
    pub output_keys: Option<Vec<String>>,
    /// Message batching parameters.
    pub message_batch_config: MessageBatchConfig,
    /// Optional resumption information.
    pub resumption: Option<StreamResumption>,
}
886
887impl StreamConfig {
888 #[must_use]
889 pub const fn new(mode: StreamMode) -> Self {
890 Self {
891 mode,
892 include_subgraphs: false,
893 subgraph_filter: None,
894 output_keys: None,
895 message_batch_config: MessageBatchConfig {
896 max_chunks: 10,
897 flush_interval_ms: Some(100),
898 },
899 resumption: None,
900 }
901 }
902
903 #[must_use]
904 pub const fn with_subgraphs(mut self, include: bool) -> Self {
905 self.include_subgraphs = include;
906 self
907 }
908
909 #[must_use]
910 pub fn with_subgraph_filter(mut self, filter: Vec<String>) -> Self {
911 self.subgraph_filter = Some(filter);
912 self
913 }
914
915 #[must_use]
917 pub fn with_output_keys(mut self, keys: Vec<String>) -> Self {
918 self.output_keys = Some(keys);
919 self
920 }
921
922 #[must_use]
924 pub const fn with_message_batch_config(mut self, config: MessageBatchConfig) -> Self {
925 self.message_batch_config = config;
926 self
927 }
928
929 #[must_use]
935 pub fn with_resumption(mut self, resumption: StreamResumption) -> Self {
936 self.resumption = Some(resumption);
937 self
938 }
939}
940
/// Information for resuming a stream of a previous run.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct StreamResumption {
    /// Identifier of the run.
    pub run_id: String,
    /// Identifier of the last checkpoint, if known.
    pub last_checkpoint_id: Option<String>,
    /// Last step number; `should_skip` treats steps up to and including it as
    /// skippable. `None` means nothing is skipped.
    pub last_step: Option<usize>,
}
948
949impl StreamResumption {
950 #[must_use]
951 pub const fn new(
952 run_id: String,
953 last_checkpoint_id: Option<String>,
954 last_step: Option<usize>,
955 ) -> Self {
956 Self {
957 run_id,
958 last_checkpoint_id,
959 last_step,
960 }
961 }
962
963 #[must_use]
964 pub const fn should_skip(&self, current_step: usize) -> bool {
965 match self.last_step {
966 Some(last_step) => current_step <= last_step,
967 None => false,
968 }
969 }
970}
971
/// Stateless [`StreamTransformer`] that parses JSON text held in string values.
#[derive(Clone, Debug, Default)]
pub struct JsonParseTransformer;

impl JsonParseTransformer {
    /// Creates the transformer; equivalent to `Default::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
982
983impl StreamTransformer for JsonParseTransformer {
984 #[allow(
985 clippy::option_if_let_else,
986 reason = "project rules prohibit map_or with unwrap; match is explicit and readable"
987 )]
988 fn transform(&self, data: serde_json::Value) -> Option<serde_json::Value> {
989 match data {
990 serde_json::Value::String(s) => match serde_json::from_str(&s) {
991 Ok(v) => Some(v),
992 Err(_) => Some(serde_json::Value::Null),
993 },
994 _ => Some(data),
995 }
996 }
997}
998
/// [`StreamTransformer`] that keeps only selected top-level fields of JSON objects.
#[derive(Clone, Debug)]
pub struct FilterFieldsTransformer {
    /// Names of the top-level fields to keep.
    pub fields: Vec<String>,
}

impl FilterFieldsTransformer {
    /// Creates a transformer that keeps the given top-level fields.
    #[must_use]
    pub const fn new(fields: Vec<String>) -> Self {
        Self { fields }
    }
}
1011
1012impl StreamTransformer for FilterFieldsTransformer {
1013 fn transform(&self, data: serde_json::Value) -> Option<serde_json::Value> {
1014 match data {
1015 serde_json::Value::Object(mut map) => {
1016 let keys_to_keep: std::collections::HashSet<_> = self.fields.iter().collect();
1017 map.retain(|k, _| keys_to_keep.contains(k));
1018 Some(serde_json::Value::Object(map))
1019 }
1020 _ => Some(data),
1021 }
1022 }
1023}
1024
/// [`StreamTransformer`] that collects values and releases them as JSON arrays.
///
/// `transform` buffers each value and returns an array once the buffer holds
/// at least `size` items; `flush` releases whatever is buffered.
#[derive(Debug)]
pub struct BatchTransformer {
    /// Number of buffered items that triggers a batch; `new` clamps it to at least 1.
    pub size: usize,

    /// Items received since the last emitted batch, behind a mutex because
    /// `transform` takes `&self`.
    buffer: Mutex<Vec<serde_json::Value>>,
}
1042
1043impl BatchTransformer {
1044 #[must_use]
1048 pub fn new(size: usize) -> Self {
1049 Self {
1050 size: size.max(1),
1051 buffer: Mutex::new(Vec::new()),
1052 }
1053 }
1054
1055 #[must_use]
1064 pub fn flush(&self) -> Option<serde_json::Value> {
1065 let mut buffer = self.buffer.lock().expect("BatchTransformer buffer lock");
1066 if buffer.is_empty() {
1067 return None;
1068 }
1069 let items = std::mem::take(&mut *buffer);
1070 drop(buffer);
1071 Some(serde_json::Value::Array(items))
1072 }
1073}
1074
1075impl Clone for BatchTransformer {
1076 fn clone(&self) -> Self {
1077 Self {
1078 size: self.size,
1079 buffer: Mutex::new(Vec::new()),
1081 }
1082 }
1083}
1084
1085impl StreamTransformer for BatchTransformer {
1086 fn transform(&self, data: serde_json::Value) -> Option<serde_json::Value> {
1087 let mut buffer = self.buffer.lock().expect("BatchTransformer buffer lock");
1088 buffer.push(data);
1089 let items = (buffer.len() >= self.size).then(|| std::mem::take(&mut *buffer));
1090 drop(buffer);
1091 items.map(serde_json::Value::Array)
1092 }
1093}
1094
#[cfg(test)]
mod tests {
    use super::{
        BatchTransformer, EventEmitter, MessageBatchConfig, MessageChunk, MessageStreamMetadata,
        StreamConfig, StreamEvent, StreamMode, StreamResumption, StreamTransformer, ToolsEvent,
    };
    use crate::state::{FieldVersions, FieldsChanged, State};

    /// Minimal state used to instantiate the generic streaming types.
    #[derive(Clone, Debug, Default)]
    struct TestState;

    impl State for TestState {
        type Update = TestStateUpdate;
        type FieldVersions = FieldVersions;

        fn apply(&mut self, _update: Self::Update) -> FieldsChanged {
            FieldsChanged(0)
        }

        fn reset_ephemeral(&mut self) {}
    }

    #[derive(Clone, Debug, Default)]
    struct TestStateUpdate;

    // ---- shared fixtures -------------------------------------------------

    /// Emitter in the given mode; `should_emit` never touches the channel, so
    /// the receiver does not need to outlive this helper.
    fn emitter_in(mode: StreamMode) -> EventEmitter<TestState> {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        EventEmitter::<TestState>::new(tx, mode)
    }

    /// A `Messages` event with fixed content and the given tags.
    fn hello_message(tags: Vec<String>) -> StreamEvent<TestState> {
        StreamEvent::Messages {
            chunk: MessageChunk {
                content: "hello".to_string(),
                tool_call_chunks: Vec::new(),
                usage_delta: None,
            },
            metadata: MessageStreamMetadata {
                node: "agent".to_string(),
                model: "test".to_string(),
                tags,
                ns: Vec::new(),
            },
        }
    }

    /// Resumption for run `"run1"` without a checkpoint id.
    fn resumption_with_last_step(last_step: Option<usize>) -> StreamResumption {
        StreamResumption::new("run1".to_string(), None, last_step)
    }

    // ---- MessageBatchConfig ----------------------------------------------

    #[test]
    fn message_batch_config_default() {
        let MessageBatchConfig {
            max_chunks,
            flush_interval_ms,
        } = MessageBatchConfig::default();
        assert_eq!(max_chunks, 10);
        assert_eq!(flush_interval_ms, Some(100));
    }

    #[test]
    fn message_batch_config_no_batching() {
        let MessageBatchConfig {
            max_chunks,
            flush_interval_ms,
        } = MessageBatchConfig::no_batching();
        assert_eq!(max_chunks, 1);
        assert_eq!(flush_interval_ms, None);
    }

    #[test]
    fn message_batch_config_new_custom() {
        let MessageBatchConfig {
            max_chunks,
            flush_interval_ms,
        } = MessageBatchConfig::new(50, Some(200));
        assert_eq!(max_chunks, 50);
        assert_eq!(flush_interval_ms, Some(200));
    }

    // ---- StreamResumption::should_skip -----------------------------------

    #[test]
    fn resumption_should_skip_returns_true_when_step_at_last_step() {
        let r = resumption_with_last_step(Some(3));
        assert!(r.should_skip(3));
    }

    #[test]
    fn resumption_should_skip_returns_true_when_step_before_last_step() {
        let r = resumption_with_last_step(Some(3));
        for step in [2, 0] {
            assert!(r.should_skip(step));
        }
    }

    #[test]
    fn resumption_should_skip_returns_false_when_step_after_last_step() {
        let r = resumption_with_last_step(Some(3));
        for step in [4, 100] {
            assert!(!r.should_skip(step));
        }
    }

    #[test]
    fn resumption_should_skip_returns_false_when_last_step_is_none() {
        let r = resumption_with_last_step(None);
        for step in [0, 100] {
            assert!(!r.should_skip(step));
        }
    }

    // ---- StreamConfig resumption -----------------------------------------

    #[test]
    fn stream_config_default_has_no_resumption() {
        assert!(StreamConfig::default().resumption.is_none());
    }

    #[test]
    fn stream_config_new_has_no_resumption() {
        assert!(StreamConfig::new(StreamMode::Values).resumption.is_none());
    }

    #[test]
    fn stream_config_with_resumption_sets_field() {
        let r = StreamResumption::new("run1".to_string(), Some("cp-5".to_string()), Some(5));
        let config = StreamConfig::new(StreamMode::Values).with_resumption(r);
        assert!(config.resumption.is_some());
        let resumption = config.resumption.expect("resumption should be set");
        assert_eq!(resumption.run_id, "run1");
        assert_eq!(resumption.last_checkpoint_id, Some("cp-5".to_string()));
        assert_eq!(resumption.last_step, Some(5));
    }

    // ---- should_emit: Messages mode --------------------------------------

    #[test]
    fn should_emit_messages_event_without_nostream() {
        let emitter = emitter_in(StreamMode::Messages);
        let event = hello_message(vec![]);
        assert!(emitter.should_emit(&event));
    }

    #[test]
    fn should_emit_messages_event_with_nostream_suppressed() {
        let emitter = emitter_in(StreamMode::Messages);
        let event = hello_message(vec!["nostream".to_string()]);
        assert!(!emitter.should_emit(&event));
    }

    #[test]
    fn should_emit_messages_event_with_other_tags_not_suppressed() {
        let emitter = emitter_in(StreamMode::Messages);
        let event = hello_message(vec!["fast".to_string(), "stream".to_string()]);
        assert!(emitter.should_emit(&event));
    }

    #[test]
    fn should_emit_end_event_always_in_messages_mode() {
        let emitter = emitter_in(StreamMode::Messages);
        let event = StreamEvent::End { output: TestState };
        assert!(emitter.should_emit(&event));
    }

    // ---- should_emit: Tools mode -----------------------------------------

    #[test]
    fn should_emit_tools_event_in_tools_mode() {
        let emitter = emitter_in(StreamMode::Tools);
        let event = StreamEvent::Tools(ToolsEvent::ToolStarted {
            tool_name: "search".to_string(),
            tool_call_id: "call_1".to_string(),
            node: "tools".to_string(),
            input: serde_json::json!({}),
            timestamp: chrono::Utc::now(),
        });
        assert!(emitter.should_emit(&event));
    }

    #[test]
    fn should_emit_tool_output_delta_in_tools_mode() {
        let emitter = emitter_in(StreamMode::Tools);
        let event = StreamEvent::Tools(ToolsEvent::ToolOutputDelta {
            tool_call_id: "call_1".to_string(),
            delta: "partial".to_string(),
        });
        assert!(emitter.should_emit(&event));
    }

    #[test]
    fn should_emit_tool_finished_in_tools_mode() {
        let emitter = emitter_in(StreamMode::Tools);
        let event = StreamEvent::Tools(ToolsEvent::ToolFinished {
            tool_call_id: "call_1".to_string(),
            output: serde_json::json!({"result": "ok"}),
            duration_ms: 100,
            success: true,
        });
        assert!(emitter.should_emit(&event));
    }

    // ---- BatchTransformer ------------------------------------------------

    #[test]
    fn batch_transformer_emits_batch_when_max_size_reached() {
        let transformer = BatchTransformer::new(3);
        let item = serde_json::json!({"token": "hello"});

        // The first two items only fill the buffer.
        for _ in 0..2 {
            assert!(transformer.transform(item.clone()).is_none());
        }

        // The third item completes the batch.
        let result = transformer.transform(item);
        assert!(result.is_some());
        let batch = result.expect("batch should be emitted");
        assert!(batch.is_array());
        assert_eq!(batch.as_array().expect("batch should be an array").len(), 3);
    }

    #[test]
    fn batch_transformer_returns_none_below_threshold() {
        let transformer = BatchTransformer::new(5);
        let outcome = transformer.transform(serde_json::json!("test"));
        assert!(outcome.is_none());
    }

    #[test]
    fn batch_transformer_flush_returns_remaining() {
        let transformer = BatchTransformer::new(10);
        let item = serde_json::json!("data");

        for _ in 0..3 {
            let _ = transformer.transform(item.clone());
        }

        let flushed = transformer.flush();
        assert!(flushed.is_some());
        let batch = flushed.expect("flush should return items");
        let items = batch.as_array().expect("flush should return array");
        assert_eq!(items.len(), 3);
    }

    #[test]
    fn batch_transformer_flush_empty_returns_none() {
        let transformer = BatchTransformer::new(10);
        let flushed = transformer.flush();
        assert!(flushed.is_none());
    }

    #[test]
    fn batch_transformer_size_one_emits_immediately() {
        let transformer = BatchTransformer::new(1);
        let result = transformer.transform(serde_json::json!("single"));
        assert!(result.is_some());
        let batch = result.expect("batch should be emitted");
        let items = batch.as_array().expect("batch should be array");
        assert_eq!(items.len(), 1);
    }

    #[test]
    fn batch_transformer_size_zero_clamped_to_one() {
        let transformer = BatchTransformer::new(0);
        let result = transformer.transform(serde_json::json!("clamped"));
        assert!(result.is_some());
        let batch = result.expect("batch should be emitted immediately");
        let items = batch.as_array().expect("batch should be array");
        assert_eq!(items.len(), 1);
    }

    #[test]
    fn batch_transformer_multiple_batches() {
        let transformer = BatchTransformer::new(2);
        let item = serde_json::json!("x");

        // First batch: one buffered item, then the emitting one.
        assert!(transformer.transform(item.clone()).is_none());
        let batch1 = transformer.transform(item.clone());
        assert!(batch1.is_some());
        let batch1 = batch1.expect("batch1");
        assert_eq!(batch1.as_array().expect("batch1 array").len(), 2);

        // Second batch starts from an empty buffer again.
        assert!(transformer.transform(item.clone()).is_none());
        let batch2 = transformer.transform(item);
        assert!(batch2.is_some());
        let batch2 = batch2.expect("batch2");
        assert_eq!(batch2.as_array().expect("batch2 array").len(), 2);
    }

    #[test]
    fn batch_transformer_clone_maintains_independent_buffer() {
        let transformer = BatchTransformer::new(3);
        let item = serde_json::json!("x");

        let _ = transformer.transform(item);
        let cloned = transformer.clone();

        // The original still holds its single buffered item...
        let flushed_original = transformer.flush();
        assert!(flushed_original.is_some());
        let original_batch = flushed_original.expect("original flush");
        assert_eq!(
            original_batch
                .as_array()
                .expect("original flush array")
                .len(),
            1
        );

        // ...while the clone started out empty.
        assert!(cloned.flush().is_none());
    }
}