1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::hooks::PromptHook;
14use crate::agent::prompt_request::streaming::StreamingPromptRequest;
15use crate::client::FinalCompletionResponse;
16use crate::completion::{
17 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
18 Message, Usage,
19};
20use crate::message::{
21 AssistantContent, Reasoning, ReasoningContent, Text, ToolCall, ToolFunction, ToolResult,
22};
23use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
24use futures::stream::{AbortHandle, Abortable};
25use futures::{Stream, StreamExt};
26use serde::{Deserialize, Serialize};
27use std::future::Future;
28use std::pin::Pin;
29use std::sync::atomic::AtomicBool;
30use std::task::{Context, Poll};
31use tokio::sync::watch;
32
33pub struct PauseControl {
35 pub(crate) paused_tx: watch::Sender<bool>,
36 pub(crate) paused_rx: watch::Receiver<bool>,
37}
38
39impl PauseControl {
40 pub fn new() -> Self {
41 let (paused_tx, paused_rx) = watch::channel(false);
42 Self {
43 paused_tx,
44 paused_rx,
45 }
46 }
47
48 pub fn pause(&self) {
49 self.paused_tx.send(true).unwrap();
50 }
51
52 pub fn resume(&self) {
53 self.paused_tx.send(false).unwrap();
54 }
55
56 pub fn is_paused(&self) -> bool {
57 *self.paused_rx.borrow()
58 }
59}
60
61impl Default for PauseControl {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
/// Payload of a partial (streamed) tool call update.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum ToolCallDeltaContent {
    /// The tool's name.
    Name(String),
    /// An incremental fragment of the tool call
    /// (presumably a piece of the JSON arguments — confirm per provider).
    Delta(String),
}
73
/// A single raw chunk emitted by a provider's streaming completion, before it is
/// aggregated by [`StreamingCompletionResponse`].
///
/// `R` is the provider-specific type carried by [`RawStreamingChoice::FinalResponse`].
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A chunk of assistant text. `StreamingCompletionResponse` appends it to the
    /// text item currently being built, or starts a new text item.
    Message(String),

    /// A complete tool call.
    ToolCall(RawStreamingToolCall),
    /// A partial tool call update. It is forwarded to consumers unchanged and is
    /// not added to the aggregated message.
    ToolCallDelta {
        /// Tool call id as reported by the provider.
        id: String,
        /// Correlation id for the call this delta belongs to
        /// (presumably matches `RawStreamingToolCall::internal_call_id` — confirm per provider).
        internal_call_id: String,
        /// The delta payload.
        content: ToolCallDeltaContent,
    },
    /// A complete reasoning block; it is aggregated as its own reasoning item.
    Reasoning {
        /// Optional provider id of the reasoning block.
        id: Option<String>,
        /// The reasoning content.
        content: ReasoningContent,
    },
    /// An incremental reasoning text fragment. `StreamingCompletionResponse`
    /// appends it to the reasoning item currently being built when possible,
    /// otherwise it starts a new reasoning item.
    ReasoningDelta {
        /// Optional provider id of the reasoning block.
        id: Option<String>,
        /// The text fragment.
        reasoning: String,
    },

    /// The provider's final response. Only the first one is yielded to consumers.
    FinalResponse(R),

    /// Provider-assigned message id. It is stored on the response and not yielded as an item.
    MessageId(String),
}
112
/// A complete tool call as produced by a provider stream.
///
/// It converts into [`ToolCall`] via `From`; `internal_call_id` is dropped by that conversion.
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Tool call id as reported by the provider.
    pub id: String,
    /// Locally generated id (a `nanoid` by default) that lets consumers correlate
    /// stream events for the same call.
    pub internal_call_id: String,
    /// Optional secondary call id reported by the provider.
    pub call_id: Option<String>,
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments as JSON.
    pub arguments: serde_json::Value,
    /// Optional provider signature attached to the call.
    pub signature: Option<String>,
    /// Optional provider-specific extra parameters.
    pub additional_params: Option<serde_json::Value>,
}
126
127impl RawStreamingToolCall {
128 pub fn empty() -> Self {
129 Self {
130 id: String::new(),
131 internal_call_id: nanoid::nanoid!(),
132 call_id: None,
133 name: String::new(),
134 arguments: serde_json::Value::Null,
135 signature: None,
136 additional_params: None,
137 }
138 }
139
140 pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
141 Self {
142 id,
143 internal_call_id: nanoid::nanoid!(),
144 call_id: None,
145 name,
146 arguments,
147 signature: None,
148 additional_params: None,
149 }
150 }
151
152 pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
153 self.internal_call_id = internal_call_id;
154 self
155 }
156
157 pub fn with_call_id(mut self, call_id: String) -> Self {
158 self.call_id = Some(call_id);
159 self
160 }
161
162 pub fn with_signature(mut self, signature: Option<String>) -> Self {
163 self.signature = signature;
164 self
165 }
166
167 pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
168 self.additional_params = additional_params;
169 self
170 }
171}
172
173impl From<RawStreamingToolCall> for ToolCall {
174 fn from(tool_call: RawStreamingToolCall) -> Self {
175 ToolCall {
176 id: tool_call.id,
177 call_id: tool_call.call_id,
178 function: ToolFunction {
179 name: tool_call.name,
180 arguments: tool_call.arguments,
181 },
182 signature: tool_call.signature,
183 additional_params: tool_call.additional_params,
184 }
185 }
186}
187
/// Boxed, pinned stream of raw provider chunks.
///
/// On non-wasm targets the stream must be `Send`.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed, pinned stream of raw provider chunks.
///
/// With the `wasm` feature on `wasm32`, the `Send` bound is dropped.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
195
/// A provider's streaming completion, exposed as a [`Stream`] of
/// [`StreamedAssistantContent`] items.
///
/// While being polled it also aggregates the chunks into the complete assistant
/// message, which is stored in `choice` when the inner stream ends. The stream can
/// be cancelled, paused and resumed.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    /// The provider stream, wrapped so it can be aborted through `abort_handle`.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle that aborts `inner` (used by `cancel`).
    pub(crate) abort_handle: AbortHandle,
    /// Pause flag checked on every poll.
    pub(crate) pause_control: PauseControl,
    /// Assistant content accumulated so far, in arrival order.
    assistant_items: Vec<AssistantContent>,
    /// Index into `assistant_items` of the text item currently receiving text chunks.
    text_item_index: Option<usize>,
    /// Index into `assistant_items` of the reasoning item currently receiving reasoning deltas.
    reasoning_item_index: Option<usize>,
    /// The aggregated assistant message. `stream()` initializes it to a single empty
    /// text item; it is replaced with the accumulated items when the inner stream ends.
    pub choice: OneOrMany<AssistantContent>,
    /// The provider's final response, once one has been yielded.
    pub response: Option<R>,
    /// Set once a final response has been yielded; later final responses are skipped.
    pub final_response_yielded: AtomicBool,
    /// Provider-assigned message id, if the provider sent one.
    pub message_id: Option<String>,
}
219
220impl<R> StreamingCompletionResponse<R>
221where
222 R: Clone + Unpin + GetTokenUsage,
223{
224 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
225 let (abort_handle, abort_registration) = AbortHandle::new_pair();
226 let abortable_stream = Abortable::new(inner, abort_registration);
227 let pause_control = PauseControl::new();
228 Self {
229 inner: abortable_stream,
230 abort_handle,
231 pause_control,
232 assistant_items: vec![],
233 text_item_index: None,
234 reasoning_item_index: None,
235 choice: OneOrMany::one(AssistantContent::text("")),
236 response: None,
237 final_response_yielded: AtomicBool::new(false),
238 message_id: None,
239 }
240 }
241
242 pub fn cancel(&self) {
243 self.abort_handle.abort();
244 }
245
246 pub fn pause(&self) {
247 self.pause_control.pause();
248 }
249
250 pub fn resume(&self) {
251 self.pause_control.resume();
252 }
253
254 pub fn is_paused(&self) -> bool {
255 self.pause_control.is_paused()
256 }
257
258 fn append_text_chunk(&mut self, text: &str) {
259 if let Some(index) = self.text_item_index
260 && let Some(AssistantContent::Text(existing_text)) = self.assistant_items.get_mut(index)
261 {
262 existing_text.text.push_str(text);
263 return;
264 }
265
266 self.assistant_items
267 .push(AssistantContent::text(text.to_owned()));
268 self.text_item_index = Some(self.assistant_items.len() - 1);
269 }
270
271 fn append_reasoning_chunk(&mut self, id: &Option<String>, text: &str) {
275 if let Some(index) = self.reasoning_item_index
276 && let Some(AssistantContent::Reasoning(existing)) = self.assistant_items.get_mut(index)
277 && let Some(ReasoningContent::Text {
278 text: existing_text,
279 ..
280 }) = existing.content.last_mut()
281 {
282 existing_text.push_str(text);
283 return;
284 }
285
286 self.assistant_items
287 .push(AssistantContent::Reasoning(Reasoning {
288 id: id.clone(),
289 content: vec![ReasoningContent::Text {
290 text: text.to_string(),
291 signature: None,
292 }],
293 }));
294 self.reasoning_item_index = Some(self.assistant_items.len() - 1);
295 }
296}
297
298impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
299where
300 R: Clone + Unpin + GetTokenUsage,
301{
302 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
303 CompletionResponse {
304 choice: value.choice,
305 usage: Usage::new(), raw_response: value.response,
307 message_id: value.message_id,
308 }
309 }
310}
311
312impl<R> Stream for StreamingCompletionResponse<R>
313where
314 R: Clone + Unpin + GetTokenUsage,
315{
316 type Item = Result<StreamedAssistantContent<R>, CompletionError>;
317
318 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
319 let stream = self.get_mut();
320
321 if stream.is_paused() {
322 cx.waker().wake_by_ref();
323 return Poll::Pending;
324 }
325
326 match Pin::new(&mut stream.inner).poll_next(cx) {
327 Poll::Pending => Poll::Pending,
328 Poll::Ready(None) => {
329 if stream.assistant_items.is_empty() {
332 stream.assistant_items.push(AssistantContent::text(""));
333 }
334
335 stream.choice = OneOrMany::many(std::mem::take(&mut stream.assistant_items))
336 .expect("There should be at least one assistant message");
337
338 Poll::Ready(None)
339 }
340 Poll::Ready(Some(Err(err))) => {
341 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
342 {
343 return Poll::Ready(None); }
345 Poll::Ready(Some(Err(err)))
346 }
347 Poll::Ready(Some(Ok(choice))) => match choice {
348 RawStreamingChoice::Message(text) => {
349 stream.reasoning_item_index = None;
350 stream.append_text_chunk(&text);
351 Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
352 }
353 RawStreamingChoice::ToolCallDelta {
354 id,
355 internal_call_id,
356 content,
357 } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
358 id,
359 internal_call_id,
360 content,
361 }))),
362 RawStreamingChoice::Reasoning { id, content } => {
363 let reasoning = Reasoning {
364 id,
365 content: vec![content],
366 };
367 stream.text_item_index = None;
368 stream.reasoning_item_index = None;
370 stream
371 .assistant_items
372 .push(AssistantContent::Reasoning(reasoning.clone()));
373 Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(reasoning))))
374 }
375 RawStreamingChoice::ReasoningDelta { id, reasoning } => {
376 stream.text_item_index = None;
377 stream.append_reasoning_chunk(&id, &reasoning);
378 Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
379 id,
380 reasoning,
381 })))
382 }
383 RawStreamingChoice::ToolCall(raw_tool_call) => {
384 let internal_call_id = raw_tool_call.internal_call_id.clone();
385 let tool_call: ToolCall = raw_tool_call.into();
386 stream.text_item_index = None;
387 stream.reasoning_item_index = None;
388 stream
389 .assistant_items
390 .push(AssistantContent::ToolCall(tool_call.clone()));
391 Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
392 tool_call,
393 internal_call_id,
394 })))
395 }
396 RawStreamingChoice::FinalResponse(response) => {
397 if stream
398 .final_response_yielded
399 .load(std::sync::atomic::Ordering::SeqCst)
400 {
401 stream.poll_next_unpin(cx)
402 } else {
403 stream.response = Some(response.clone());
405 stream
406 .final_response_yielded
407 .store(true, std::sync::atomic::Ordering::SeqCst);
408 let final_response = StreamedAssistantContent::final_response(response);
409 Poll::Ready(Some(Ok(final_response)))
410 }
411 }
412 RawStreamingChoice::MessageId(id) => {
413 stream.message_id = Some(id);
414 stream.poll_next_unpin(cx)
415 }
416 },
417 }
418 }
419}
420
/// Types that can start a streaming prompt against the completion model `M`.
///
/// NOTE(review): `R` is only constrained in the `where` clause and does not appear
/// in `stream_prompt`'s signature — presumably the provider's streaming response
/// type; confirm with implementors.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Hook type attached to the returned [`StreamingPromptRequest`].
    type Hook: PromptHook<M>;

    /// Builds a [`StreamingPromptRequest`] for `prompt`.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, Self::Hook>;
}
457
/// Types that can start a streaming chat (prompt plus prior history) against the
/// completion model `M`.
///
/// NOTE(review): as with `StreamingPrompt`, `R` only appears in the `where` clause;
/// confirm its intended role with implementors.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Hook type attached to the returned [`StreamingPromptRequest`].
    type Hook: PromptHook<M>;

    /// Builds a [`StreamingPromptRequest`] for `prompt`, given the earlier
    /// conversation in `chat_history`.
    fn stream_chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> StreamingPromptRequest<M, Self::Hook>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
525
/// Types that can prepare a streaming completion request for the model `M`.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Asynchronously produces a [`CompletionRequestBuilder`] for `prompt` and
    /// `chat_history`, or a [`CompletionError`] if it cannot be built.
    fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
538
/// Adapter over a provider-typed [`StreamingResult`]. Its `Stream` impl rewrites
/// each `FinalResponse(R)` into a provider-agnostic [`FinalCompletionResponse`]
/// carrying only the token usage.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    /// The provider stream being adapted.
    pub(crate) inner: StreamingResult<R>,
}
542
543fn map_raw_streaming_choice<R>(
544 chunk: RawStreamingChoice<R>,
545) -> RawStreamingChoice<FinalCompletionResponse>
546where
547 R: Clone + Unpin + GetTokenUsage,
548{
549 match chunk {
550 RawStreamingChoice::FinalResponse(res) => {
551 RawStreamingChoice::FinalResponse(FinalCompletionResponse {
552 usage: res.token_usage(),
553 })
554 }
555 RawStreamingChoice::Message(m) => RawStreamingChoice::Message(m),
556 RawStreamingChoice::ToolCallDelta {
557 id,
558 internal_call_id,
559 content,
560 } => RawStreamingChoice::ToolCallDelta {
561 id,
562 internal_call_id,
563 content,
564 },
565 RawStreamingChoice::Reasoning { id, content } => {
566 RawStreamingChoice::Reasoning { id, content }
567 }
568 RawStreamingChoice::ReasoningDelta { id, reasoning } => {
569 RawStreamingChoice::ReasoningDelta { id, reasoning }
570 }
571 RawStreamingChoice::ToolCall(tool_call) => RawStreamingChoice::ToolCall(tool_call),
572 RawStreamingChoice::MessageId(id) => RawStreamingChoice::MessageId(id),
573 }
574}
575
576impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
577 type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
578
579 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
580 let stream = self.get_mut();
581
582 match stream.inner.as_mut().poll_next(cx) {
583 Poll::Pending => Poll::Pending,
584 Poll::Ready(None) => Poll::Ready(None),
585 Poll::Ready(Some(item)) => Poll::Ready(Some(item.map(map_raw_streaming_choice::<R>))),
586 }
587 }
588}
589
590pub async fn stream_to_stdout<M>(
593 agent: &'static Agent<M>,
594 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
595) -> Result<(), std::io::Error>
596where
597 M: CompletionModel,
598{
599 let mut is_reasoning = false;
600 print!("Response: ");
601 while let Some(chunk) = stream.next().await {
602 match chunk {
603 Ok(StreamedAssistantContent::Text(text)) => {
604 if is_reasoning {
605 is_reasoning = false;
606 println!("\n---\n");
607 }
608 print!("{}", text.text);
609 std::io::Write::flush(&mut std::io::stdout())?;
610 }
611 Ok(StreamedAssistantContent::ToolCall {
612 tool_call,
613 internal_call_id: _,
614 }) => {
615 let res = agent
616 .tool_server_handle
617 .call_tool(
618 &tool_call.function.name,
619 &tool_call.function.arguments.to_string(),
620 )
621 .await
622 .map_err(|x| std::io::Error::other(x.to_string()))?;
623 println!("\nResult: {res}");
624 }
625 Ok(StreamedAssistantContent::Final(res)) => {
626 let json_res = serde_json::to_string_pretty(&res).unwrap();
627 println!();
628 tracing::info!("Final result: {json_res}");
629 }
630 Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
631 if !is_reasoning {
632 is_reasoning = true;
633 println!();
634 println!("Thinking: ");
635 }
636 let reasoning = reasoning.display_text();
637
638 print!("{reasoning}");
639 std::io::Write::flush(&mut std::io::stdout())?;
640 }
641 Err(e) => {
642 if e.to_string().contains("aborted") {
643 println!("\nStream cancelled.");
644 break;
645 }
646 eprintln!("Error: {e}");
647 break;
648 }
649 _ => {}
650 }
651 }
652
653 println!(); Ok(())
656}
657
/// Unit tests for `StreamingCompletionResponse`: cancellation, pause/resume and
/// aggregation of streamed chunks into `choice`.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal provider response used as the `R` type in these tests.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        // Never read: `token_usage` below reports a fixed value.
        #[allow(dead_code)]
        token_count: u32,
    }

    // Reports a fixed usage (`total_tokens = 15`) regardless of `token_count`.
    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Boxes a test stream into a `StreamingResult` (non-wasm variant, requires `Send`).
    #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + Send
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    /// Boxes a test stream into a `StreamingResult` (wasm variant, no `Send` bound).
    #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    /// Three text chunks 100 ms apart, then a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    /// A signed reasoning block followed by text and a final response.
    fn create_reasoning_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_1".to_string()),
                content: ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig_1".to_string()),
                },
            });
            yield Ok(RawStreamingChoice::Message("final answer".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 5 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    /// Only a reasoning summary (no text) followed by a final response.
    fn create_reasoning_only_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_only".to_string()),
                content: ReasoningContent::Summary("hidden summary".to_string()),
            });
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 2 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    /// Reasoning, then text, then a tool call, then a final response.
    fn create_interleaved_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_interleaved".to_string()),
                content: ReasoningContent::Text {
                    text: "chain-of-thought".to_string(),
                    signature: None,
                },
            });
            yield Ok(RawStreamingChoice::Message("final-text".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_1".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": 1}),
                ),
            ));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    /// Text, a tool call, then more text: the two text runs must stay separate.
    fn create_text_tool_text_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("first".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_split".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": "x"}),
                ),
            ));
            yield Ok(RawStreamingChoice::Message("second".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    /// After `cancel()` (triggered once two countable chunks were seen), the next
    /// poll must end the stream.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall {
                    tool_call,
                    internal_call_id,
                }) => {
                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }) => {
                    println!(
                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
                    );
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                    let reasoning = reasoning.display_text();
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// `pause`/`resume` toggle the flag reported by `is_paused`.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        stream.pause();
        assert!(stream.is_paused());

        stream.resume();
        assert!(!stream.is_paused());
    }

    /// A streamed reasoning block (id and signature included) ends up in `choice`.
    #[tokio::test]
    async fn test_stream_aggregates_reasoning_content() {
        let mut stream = create_reasoning_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();

        assert!(choice_items.iter().any(|item| matches!(
            item,
            AssistantContent::Reasoning(Reasoning {
                id: Some(id),
                content
            }) if id == "rs_1"
                && matches!(
                    content.first(),
                    Some(ReasoningContent::Text {
                        text,
                        signature: Some(signature)
                    }) if text == "step one" && signature == "sig_1"
                )
        )));
    }

    /// A reasoning-only stream must not get an extra empty text item in `choice`.
    #[tokio::test]
    async fn test_stream_reasoning_only_does_not_inject_empty_text() {
        let mut stream = create_reasoning_only_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 1);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_only"
        ));
    }

    /// `choice` preserves the order in which reasoning, text and tool calls arrived.
    #[tokio::test]
    async fn test_stream_aggregates_assistant_items_in_arrival_order() {
        let mut stream = create_interleaved_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_interleaved"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::Text(Text { text })) if text == "final-text"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_1"
        ));
    }

    /// Text separated by a tool call is kept as two distinct text items.
    #[tokio::test]
    async fn test_stream_keeps_non_contiguous_text_chunks_split_by_tool_call() {
        let mut stream = create_text_tool_text_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Text(Text { text })) if text == "first"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_split"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::Text(Text { text })) if text == "second"
        ));
    }
}
934
/// An item yielded by [`StreamingCompletionResponse`] to consumers.
///
/// `R` is the type of the final response carried by [`StreamedAssistantContent::Final`].
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration
/// order when deserializing; do not reorder them.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of assistant text.
    Text(Text),
    /// A complete tool call.
    ToolCall {
        /// The tool call.
        tool_call: ToolCall,
        /// Stream-local correlation id taken from the raw tool call.
        internal_call_id: String,
    },
    /// A partial tool call update, forwarded as received from the provider.
    ToolCallDelta {
        /// Tool call id as reported by the provider.
        id: String,
        /// Stream-local correlation id.
        internal_call_id: String,
        /// The delta payload.
        content: ToolCallDeltaContent,
    },
    /// A complete reasoning block.
    Reasoning(Reasoning),
    /// An incremental reasoning text fragment.
    ReasoningDelta {
        /// Optional provider id of the reasoning block.
        id: Option<String>,
        /// The text fragment.
        reasoning: String,
    },
    /// The provider's final response.
    Final(R),
}
960
961impl<R> StreamedAssistantContent<R>
962where
963 R: Clone + Unpin,
964{
965 pub fn text(text: &str) -> Self {
966 Self::Text(Text {
967 text: text.to_string(),
968 })
969 }
970
971 pub fn final_response(res: R) -> Self {
972 Self::Final(res)
973 }
974}
975
/// User-side content produced while streaming (currently only tool results).
///
/// Serialized `#[serde(untagged)]`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of executing a tool call.
    ToolResult {
        /// The tool's result.
        tool_result: ToolResult,
        /// Stream-local correlation id of the tool call this result answers
        /// (presumably the same id carried by `StreamedAssistantContent::ToolCall` — confirm).
        internal_call_id: String,
    },
}
988
989impl StreamedUserContent {
990 pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
991 Self::ToolResult {
992 tool_result,
993 internal_call_id,
994 }
995 }
996}