1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::hooks::PromptHook;
14use crate::agent::prompt_request::streaming::StreamingPromptRequest;
15use crate::completion::{
16 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17 Message, Usage,
18};
19use crate::message::{
20 AssistantContent, Reasoning, ReasoningContent, Text, ToolCall, ToolFunction, ToolResult,
21};
22use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
23use futures::stream::{AbortHandle, Abortable};
24use futures::{Stream, StreamExt};
25use serde::{Deserialize, Serialize};
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::atomic::AtomicBool;
29use std::task::{Context, Poll};
30use tokio::sync::watch;
31
/// Pause/resume switch backed by a [`tokio::sync::watch`] channel carrying a `bool`.
///
/// `true` means "paused"; `false` means "running".
pub struct PauseControl {
    /// Sending half of the channel; used to publish a new paused state.
    pub(crate) paused_tx: watch::Sender<bool>,
    /// Receiving half of the channel; used to read the current paused state.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
37
38impl PauseControl {
39 pub fn new() -> Self {
41 let (paused_tx, paused_rx) = watch::channel(false);
42 Self {
43 paused_tx,
44 paused_rx,
45 }
46 }
47
48 pub fn pause(&self) {
50 let _ = self.paused_tx.send(true);
51 }
52
53 pub fn resume(&self) {
55 let _ = self.paused_tx.send(false);
56 }
57
58 pub fn is_paused(&self) -> bool {
60 *self.paused_rx.borrow()
61 }
62}
63
64impl Default for PauseControl {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
/// Payload of an incremental tool-call update.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum ToolCallDeltaContent {
    /// Tool name text carried by this delta.
    Name(String),
    /// Delta text carried by this update (presumably a fragment of the call
    /// arguments; confirm against the provider implementations).
    Delta(String),
}
78
/// A single raw item carried by a [`StreamingResult`] stream.
///
/// [`StreamingCompletionResponse`] consumes these items, folds them into its
/// aggregated assistant content, and re-emits most of them as
/// [`StreamedAssistantContent`].
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A chunk of assistant text.
    Message(String),

    /// Marks the start of a new text item: text received afterwards is not
    /// merged into a previously open text item.
    TextStart {
        /// Optional metadata to attach to the text item being started.
        additional_params: Option<serde_json::Value>,
    },

    /// Metadata to merge into the current text item.
    TextAdditionalParams(serde_json::Value),

    /// A tool call.
    ToolCall(RawStreamingToolCall),
    /// An incremental update to a tool call.
    ToolCallDelta {
        /// Identifier of the tool call this delta belongs to.
        id: String,
        /// Internal identifier for the tool call.
        internal_call_id: String,
        /// The delta payload.
        content: ToolCallDeltaContent,
    },
    /// A reasoning content item.
    Reasoning {
        /// Optional identifier of the reasoning item.
        id: Option<String>,
        /// The reasoning content itself.
        content: ReasoningContent,
    },
    /// An incremental chunk of reasoning text.
    ReasoningDelta {
        /// Optional identifier of the reasoning item.
        id: Option<String>,
        /// The reasoning text fragment.
        reasoning: String,
    },

    /// The final response object of type `R`.
    FinalResponse(R),

    /// An identifier for the message being streamed.
    MessageId(String),
}
137
/// A tool call as carried by [`RawStreamingChoice::ToolCall`].
///
/// Converts into a [`ToolCall`]; `internal_call_id` is not part of that conversion.
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Tool call identifier.
    pub id: String,
    /// Internal identifier; the constructors fill it with a freshly generated nanoid.
    pub internal_call_id: String,
    /// Optional call identifier (`None` unless set explicitly).
    pub call_id: Option<String>,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool, as JSON.
    pub arguments: serde_json::Value,
    /// Optional signature, forwarded unchanged to the resulting [`ToolCall`].
    pub signature: Option<String>,
    /// Optional extra parameters, forwarded unchanged to the resulting [`ToolCall`].
    pub additional_params: Option<serde_json::Value>,
}
156
157impl RawStreamingToolCall {
158 pub fn empty() -> Self {
160 Self {
161 id: String::new(),
162 internal_call_id: nanoid::nanoid!(),
163 call_id: None,
164 name: String::new(),
165 arguments: serde_json::Value::Null,
166 signature: None,
167 additional_params: None,
168 }
169 }
170
171 pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
173 Self {
174 id,
175 internal_call_id: nanoid::nanoid!(),
176 call_id: None,
177 name,
178 arguments,
179 signature: None,
180 additional_params: None,
181 }
182 }
183
184 pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
186 self.internal_call_id = internal_call_id;
187 self
188 }
189
190 pub fn with_call_id(mut self, call_id: String) -> Self {
192 self.call_id = Some(call_id);
193 self
194 }
195
196 pub fn with_signature(mut self, signature: Option<String>) -> Self {
198 self.signature = signature;
199 self
200 }
201
202 pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
204 self.additional_params = additional_params;
205 self
206 }
207}
208
209impl From<RawStreamingToolCall> for ToolCall {
210 fn from(tool_call: RawStreamingToolCall) -> Self {
211 ToolCall {
212 id: tool_call.id,
213 call_id: tool_call.call_id,
214 function: ToolFunction {
215 name: tool_call.name,
216 arguments: tool_call.arguments,
217 },
218 signature: tool_call.signature,
219 additional_params: tool_call.additional_params,
220 }
221 }
222}
223
/// Boxed, pinned stream of raw streaming items or completion errors.
///
/// Used when not building for wasm32 with the `wasm` feature; the stream must be `Send`.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed, pinned stream of raw streaming items or completion errors.
///
/// Used on wasm32 with the `wasm` feature; there is no `Send` requirement here.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
233
/// A streaming completion response.
///
/// Wraps a [`StreamingResult`] so it can be aborted and paused, and aggregates the
/// assistant content that passes through it while it is being polled.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    /// The underlying raw stream, wrapped so it can be aborted.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle used to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    /// Pause/resume switch consulted on every poll.
    pub(crate) pause_control: PauseControl,
    /// Assistant content collected so far, in arrival order.
    assistant_items: Vec<AssistantContent>,
    /// Index in `assistant_items` of the text item currently being appended to, if any.
    text_item_index: Option<usize>,
    /// Index in `assistant_items` of the reasoning item currently being appended to, if any.
    reasoning_item_index: Option<usize>,
    /// Aggregated assistant content. Starts as a single empty text item and is
    /// replaced with the collected items when the underlying stream ends.
    pub choice: OneOrMany<AssistantContent>,
    /// The final response, set once a `FinalResponse` item has been received.
    pub response: Option<R>,
    /// Whether a final response has already been yielded; later ones are skipped.
    pub final_response_yielded: AtomicBool,
    /// Message identifier, set when a `MessageId` item is received.
    pub message_id: Option<String>,
}
257
258impl<R> StreamingCompletionResponse<R>
259where
260 R: Clone + Unpin + GetTokenUsage,
261{
262 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
264 let (abort_handle, abort_registration) = AbortHandle::new_pair();
265 let abortable_stream = Abortable::new(inner, abort_registration);
266 let pause_control = PauseControl::new();
267 Self {
268 inner: abortable_stream,
269 abort_handle,
270 pause_control,
271 assistant_items: vec![],
272 text_item_index: None,
273 reasoning_item_index: None,
274 choice: OneOrMany::one(AssistantContent::text("")),
275 response: None,
276 final_response_yielded: AtomicBool::new(false),
277 message_id: None,
278 }
279 }
280
281 pub fn cancel(&self) {
283 self.abort_handle.abort();
284 }
285
286 pub fn pause(&self) {
288 self.pause_control.pause();
289 }
290
291 pub fn resume(&self) {
293 self.pause_control.resume();
294 }
295
296 pub fn is_paused(&self) -> bool {
298 self.pause_control.is_paused()
299 }
300
301 fn append_text_chunk(&mut self, text: &str) {
302 if let Some(index) = self.text_item_index
303 && let Some(AssistantContent::Text(existing_text)) = self.assistant_items.get_mut(index)
304 {
305 existing_text.text.push_str(text);
306 return;
307 }
308
309 self.assistant_items
310 .push(AssistantContent::text(text.to_owned()));
311 self.text_item_index = Some(self.assistant_items.len() - 1);
312 }
313
314 fn append_text_additional_params(&mut self, additional_params: serde_json::Value) {
315 if additional_params.is_null() {
316 return;
317 }
318
319 let index = if let Some(index) = self.text_item_index
320 && matches!(
321 self.assistant_items.get(index),
322 Some(AssistantContent::Text(_))
323 ) {
324 index
325 } else {
326 self.assistant_items.push(AssistantContent::text(""));
327 let index = self.assistant_items.len() - 1;
328 self.text_item_index = Some(index);
329 index
330 };
331
332 let Some(AssistantContent::Text(text)) = self.assistant_items.get_mut(index) else {
333 return;
334 };
335
336 match text.additional_params.as_mut() {
337 Some(existing) => merge_text_additional_params(existing, additional_params),
338 None => text.additional_params = Some(additional_params),
339 }
340 }
341
342 fn append_reasoning_chunk(&mut self, id: &Option<String>, text: &str) {
346 if let Some(index) = self.reasoning_item_index
347 && let Some(AssistantContent::Reasoning(existing)) = self.assistant_items.get_mut(index)
348 && let Some(ReasoningContent::Text {
349 text: existing_text,
350 ..
351 }) = existing.content.last_mut()
352 {
353 existing_text.push_str(text);
354 return;
355 }
356
357 self.assistant_items
358 .push(AssistantContent::Reasoning(Reasoning {
359 id: id.clone(),
360 content: vec![ReasoningContent::Text {
361 text: text.to_string(),
362 signature: None,
363 }],
364 }));
365 self.reasoning_item_index = Some(self.assistant_items.len() - 1);
366 }
367}
368
369fn merge_text_additional_params(existing: &mut serde_json::Value, incoming: serde_json::Value) {
370 match (existing, incoming) {
371 (serde_json::Value::Object(existing_map), serde_json::Value::Object(incoming_map)) => {
372 for (key, incoming_value) in incoming_map {
373 match existing_map.get_mut(&key) {
374 Some(existing_value) => match (existing_value, incoming_value) {
375 (
376 serde_json::Value::Array(existing_array),
377 serde_json::Value::Array(mut incoming_array),
378 ) => existing_array.append(&mut incoming_array),
379 (existing_value, incoming_value) => {
380 merge_text_additional_params(existing_value, incoming_value);
381 }
382 },
383 None => {
384 existing_map.insert(key, incoming_value);
385 }
386 }
387 }
388 }
389 (existing, incoming) => {
390 *existing = incoming;
391 }
392 }
393}
394
395impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
396where
397 R: Clone + Unpin + GetTokenUsage,
398{
399 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
400 CompletionResponse {
401 choice: value.choice,
402 usage: Usage::new(), raw_response: value.response,
404 message_id: value.message_id,
405 }
406 }
407}
408
409impl<R> Stream for StreamingCompletionResponse<R>
410where
411 R: Clone + Unpin + GetTokenUsage,
412{
413 type Item = Result<StreamedAssistantContent<R>, CompletionError>;
414
415 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
416 let stream = self.get_mut();
417
418 if stream.is_paused() {
419 cx.waker().wake_by_ref();
420 return Poll::Pending;
421 }
422
423 match Pin::new(&mut stream.inner).poll_next(cx) {
424 Poll::Pending => Poll::Pending,
425 Poll::Ready(None) => {
426 if stream.assistant_items.is_empty() {
429 stream.assistant_items.push(AssistantContent::text(""));
430 }
431
432 if let Some(choice) =
433 OneOrMany::from_iter_optional(std::mem::take(&mut stream.assistant_items))
434 {
435 stream.choice = choice;
436 }
437
438 Poll::Ready(None)
439 }
440 Poll::Ready(Some(Err(err))) => {
441 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
442 {
443 return Poll::Ready(None); }
445 Poll::Ready(Some(Err(err)))
446 }
447 Poll::Ready(Some(Ok(choice))) => match choice {
448 RawStreamingChoice::Message(text) => {
449 stream.reasoning_item_index = None;
450 stream.append_text_chunk(&text);
451 Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
452 }
453 RawStreamingChoice::TextStart { additional_params } => {
454 stream.reasoning_item_index = None;
455 stream.text_item_index = None;
456 if let Some(additional_params) = additional_params {
457 stream.append_text_additional_params(additional_params);
458 }
459 stream.poll_next_unpin(cx)
460 }
461 RawStreamingChoice::TextAdditionalParams(additional_params) => {
462 stream.append_text_additional_params(additional_params);
463 stream.poll_next_unpin(cx)
464 }
465 RawStreamingChoice::ToolCallDelta {
466 id,
467 internal_call_id,
468 content,
469 } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
470 id,
471 internal_call_id,
472 content,
473 }))),
474 RawStreamingChoice::Reasoning { id, content } => {
475 let reasoning = Reasoning {
476 id,
477 content: vec![content],
478 };
479 stream.text_item_index = None;
480 stream.reasoning_item_index = None;
482 stream
483 .assistant_items
484 .push(AssistantContent::Reasoning(reasoning.clone()));
485 Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(reasoning))))
486 }
487 RawStreamingChoice::ReasoningDelta { id, reasoning } => {
488 stream.text_item_index = None;
489 stream.append_reasoning_chunk(&id, &reasoning);
490 Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
491 id,
492 reasoning,
493 })))
494 }
495 RawStreamingChoice::ToolCall(raw_tool_call) => {
496 let internal_call_id = raw_tool_call.internal_call_id.clone();
497 let tool_call: ToolCall = raw_tool_call.into();
498 stream.text_item_index = None;
499 stream.reasoning_item_index = None;
500 stream
501 .assistant_items
502 .push(AssistantContent::ToolCall(tool_call.clone()));
503 Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
504 tool_call,
505 internal_call_id,
506 })))
507 }
508 RawStreamingChoice::FinalResponse(response) => {
509 if stream
510 .final_response_yielded
511 .load(std::sync::atomic::Ordering::SeqCst)
512 {
513 stream.poll_next_unpin(cx)
514 } else {
515 stream.response = Some(response.clone());
517 stream
518 .final_response_yielded
519 .store(true, std::sync::atomic::Ordering::SeqCst);
520 let final_response = StreamedAssistantContent::final_response(response);
521 Poll::Ready(Some(Ok(final_response)))
522 }
523 }
524 RawStreamingChoice::MessageId(id) => {
525 stream.message_id = Some(id);
526 stream.poll_next_unpin(cx)
527 }
528 },
529 }
530 }
531}
532
/// Types that can turn a single prompt into a streaming prompt request.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Hook type carried by the returned request.
    type Hook: PromptHook<M>;

    /// Builds a [`StreamingPromptRequest`] for `prompt`.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, Self::Hook>;
}
569
/// Types that can turn a prompt plus chat history into a streaming prompt request.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Hook type carried by the returned request.
    type Hook: PromptHook<M>;

    /// Builds a [`StreamingPromptRequest`] for `prompt`, given `chat_history`.
    ///
    /// `chat_history` may be any iterable whose items convert into [`Message`].
    fn stream_chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> StreamingPromptRequest<M, Self::Hook>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
637
/// Types that can prepare a completion request builder for streaming.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Asynchronously produces a [`CompletionRequestBuilder`] for `prompt` and
    /// `chat_history`, or a [`CompletionError`] on failure.
    ///
    /// `chat_history` may be any iterable whose items convert into [`Message`].
    fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
650
651pub async fn stream_to_stdout<M>(
654 agent: &'static Agent<M>,
655 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
656) -> Result<(), std::io::Error>
657where
658 M: CompletionModel,
659{
660 let mut is_reasoning = false;
661 print!("Response: ");
662 while let Some(chunk) = stream.next().await {
663 match chunk {
664 Ok(StreamedAssistantContent::Text(text)) => {
665 if is_reasoning {
666 is_reasoning = false;
667 println!("\n---\n");
668 }
669 print!("{}", text.text);
670 std::io::Write::flush(&mut std::io::stdout())?;
671 }
672 Ok(StreamedAssistantContent::ToolCall {
673 tool_call,
674 internal_call_id: _,
675 }) => {
676 let res = agent
677 .tool_server_handle
678 .call_tool(
679 &tool_call.function.name,
680 &tool_call.function.arguments.to_string(),
681 )
682 .await
683 .map_err(|x| std::io::Error::other(x.to_string()))?;
684 println!("\nResult: {res}");
685 }
686 Ok(StreamedAssistantContent::Final(res)) => {
687 if let Ok(json_res) = serde_json::to_string_pretty(&res) {
688 println!();
689 tracing::info!("Final result: {json_res}");
690 }
691 }
692 Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
693 if !is_reasoning {
694 is_reasoning = true;
695 println!();
696 println!("Thinking: ");
697 }
698 let reasoning = reasoning.display_text();
699
700 print!("{reasoning}");
701 std::io::Write::flush(&mut std::io::stdout())?;
702 }
703 Err(e) => {
704 if e.to_string().contains("aborted") {
705 println!("\nStream cancelled.");
706 break;
707 }
708 eprintln!("Error: {e}");
709 break;
710 }
711 _ => {}
712 }
713 }
714
715 println!(); Ok(())
718}
719
// Unit tests for the streaming response wrapper: cancellation, pause/resume and
// aggregation of streamed items into `choice`.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::test_utils::MockResponse;
    use async_stream::stream;
    use tokio::time::sleep;

    // Boxes and pins a mock stream into a `StreamingResult` (non-wasm: requires `Send`).
    #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + Send
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    // Boxes and pins a mock stream into a `StreamingResult` (wasm: no `Send` bound).
    #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    // Three text chunks separated by 100 ms sleeps, followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(15)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    // One signed reasoning item, then text, then a final response.
    fn create_reasoning_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_1".to_string()),
                content: ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig_1".to_string()),
                },
            });
            yield Ok(RawStreamingChoice::Message("final answer".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(5)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    // A single reasoning summary and a final response, with no text at all.
    fn create_reasoning_only_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_only".to_string()),
                content: ReasoningContent::Summary("hidden summary".to_string()),
            });
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(2)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    // Reasoning, text and a tool call, in that order, then a final response.
    fn create_interleaved_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_interleaved".to_string()),
                content: ReasoningContent::Text {
                    text: "chain-of-thought".to_string(),
                    signature: None,
                },
            });
            yield Ok(RawStreamingChoice::Message("final-text".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_1".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": 1}),
                ),
            ));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(3)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    // Two text chunks separated by a tool call.
    fn create_text_tool_text_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("first".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_split".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": "x"}),
                ),
            ));
            yield Ok(RawStreamingChoice::Message("second".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(3)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    // Two text blocks: the first receives two citation updates after its text,
    // the second carries its metadata on `TextStart`.
    fn create_text_metadata_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::TextStart {
                additional_params: None,
            });
            yield Ok(RawStreamingChoice::Message("first".to_string()));
            yield Ok(RawStreamingChoice::TextAdditionalParams(serde_json::json!({
                "citations": [{
                    "type": "char_location",
                    "cited_text": "First citation.",
                    "document_index": 0,
                    "start_char_index": 0,
                    "end_char_index": 15
                }]
            })));
            yield Ok(RawStreamingChoice::TextAdditionalParams(serde_json::json!({
                "citations": [{
                    "type": "char_location",
                    "cited_text": "Second citation.",
                    "document_index": 0,
                    "start_char_index": 16,
                    "end_char_index": 32
                }]
            })));
            yield Ok(RawStreamingChoice::TextStart {
                additional_params: Some(serde_json::json!({
                    "block": 2
                })),
            });
            yield Ok(RawStreamingChoice::Message("second".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(3)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    // Cancels after two counted chunks and checks that nothing further is yielded.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall {
                    tool_call,
                    internal_call_id,
                }) => {
                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }) => {
                    println!(
                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
                    );
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                    let reasoning = reasoning.display_text();
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Stop consuming once two chunks have been counted.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        // After cancellation the stream must report completion.
        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    // Toggles the pause flag and checks that `is_paused` reflects it.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        stream.pause();
        assert!(stream.is_paused());

        stream.resume();
        assert!(!stream.is_paused());
    }

    // A complete reasoning item (id, text and signature) must end up in `choice`.
    #[tokio::test]
    async fn test_stream_aggregates_reasoning_content() {
        let mut stream = create_reasoning_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();

        assert!(choice_items.iter().any(|item| matches!(
            item,
            AssistantContent::Reasoning(Reasoning {
                id: Some(id),
                content
            }) if id == "rs_1"
                && matches!(
                    content.first(),
                    Some(ReasoningContent::Text {
                        text,
                        signature: Some(signature)
                    }) if text == "step one" && signature == "sig_1"
                )
        )));
    }

    // With reasoning but no text, `choice` must contain only the reasoning item.
    #[tokio::test]
    async fn test_stream_reasoning_only_does_not_inject_empty_text() {
        let mut stream = create_reasoning_only_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 1);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_only"
        ));
    }

    // Aggregated items must keep the order in which they arrived.
    #[tokio::test]
    async fn test_stream_aggregates_assistant_items_in_arrival_order() {
        let mut stream = create_interleaved_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_interleaved"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::Text(Text { text, .. })) if text == "final-text"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_1"
        ));
    }

    // Text on either side of a tool call must stay as two separate text items.
    #[tokio::test]
    async fn test_stream_keeps_non_contiguous_text_chunks_split_by_tool_call() {
        let mut stream = create_text_tool_text_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Text(Text { text, .. })) if text == "first"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_split"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::Text(Text { text, .. })) if text == "second"
        ));
    }

    // Metadata must be attached to the right text item, with citation arrays
    // from successive updates concatenated.
    #[tokio::test]
    async fn test_stream_preserves_text_additional_params() {
        let mut stream = create_text_metadata_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 2);

        let Some(AssistantContent::Text(Text {
            text,
            additional_params: Some(additional_params),
        })) = choice_items.first()
        else {
            panic!("expected first text item with metadata");
        };
        assert_eq!(text, "first");
        assert_eq!(
            additional_params["citations"]
                .as_array()
                .expect("citations should be an array")
                .len(),
            2
        );

        let Some(AssistantContent::Text(Text {
            text,
            additional_params: Some(additional_params),
        })) = choice_items.get(1)
        else {
            panic!("expected second text item with metadata");
        };
        assert_eq!(text, "second");
        assert_eq!(additional_params["block"], 2);
    }
}
1054
/// Assistant-side items yielded by a [`StreamingCompletionResponse`].
///
/// Serialized untagged: the variant name does not appear in the serialized form.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of assistant text.
    Text(Text),
    /// A tool call.
    ToolCall {
        /// The tool call itself.
        tool_call: ToolCall,
        /// Internal identifier for the tool call.
        internal_call_id: String,
    },
    /// An incremental update to a tool call.
    ToolCallDelta {
        /// Identifier of the tool call this delta belongs to.
        id: String,
        /// Internal identifier for the tool call.
        internal_call_id: String,
        /// The delta payload.
        content: ToolCallDeltaContent,
    },
    /// A reasoning item.
    Reasoning(Reasoning),
    /// An incremental chunk of reasoning text.
    ReasoningDelta {
        /// Optional identifier of the reasoning item.
        id: Option<String>,
        /// The reasoning text fragment.
        reasoning: String,
    },
    /// The final response object.
    Final(R),
}
1088
1089impl<R> StreamedAssistantContent<R>
1090where
1091 R: Clone + Unpin,
1092{
1093 pub fn text(text: &str) -> Self {
1095 Self::Text(Text::new(text.to_string()))
1096 }
1097
1098 pub fn final_response(res: R) -> Self {
1100 Self::Final(res)
1101 }
1102}
1103
/// User-side streamed items.
///
/// Serialized untagged: the variant name does not appear in the serialized form.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of a tool call.
    ToolResult {
        /// The tool result itself.
        tool_result: ToolResult,
        /// Internal identifier paired with the result (presumably that of the
        /// originating tool call; confirm against callers).
        internal_call_id: String,
    },
}
1117
1118impl StreamedUserContent {
1119 pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
1121 Self::ToolResult {
1122 tool_result,
1123 internal_call_id,
1124 }
1125 }
1126}