1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::hooks::PromptHook;
14use crate::agent::prompt_request::streaming::StreamingPromptRequest;
15use crate::client::FinalCompletionResponse;
16use crate::completion::{
17 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
18 Message, Usage,
19};
20use crate::message::{
21 AssistantContent, Reasoning, ReasoningContent, Text, ToolCall, ToolFunction, ToolResult,
22};
23use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
24use futures::stream::{AbortHandle, Abortable};
25use futures::{Stream, StreamExt};
26use serde::{Deserialize, Serialize};
27use std::future::Future;
28use std::pin::Pin;
29use std::sync::atomic::AtomicBool;
30use std::task::{Context, Poll};
31use tokio::sync::watch;
32
33pub struct PauseControl {
35 pub(crate) paused_tx: watch::Sender<bool>,
36 pub(crate) paused_rx: watch::Receiver<bool>,
37}
38
39impl PauseControl {
40 pub fn new() -> Self {
41 let (paused_tx, paused_rx) = watch::channel(false);
42 Self {
43 paused_tx,
44 paused_rx,
45 }
46 }
47
48 pub fn pause(&self) {
49 self.paused_tx.send(true).unwrap();
50 }
51
52 pub fn resume(&self) {
53 self.paused_tx.send(false).unwrap();
54 }
55
56 pub fn is_paused(&self) -> bool {
57 *self.paused_rx.borrow()
58 }
59}
60
61impl Default for PauseControl {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
69pub enum ToolCallDeltaContent {
70 Name(String),
71 Delta(String),
72}
73
74#[derive(Debug, Clone)]
76pub enum RawStreamingChoice<R>
77where
78 R: Clone,
79{
80 Message(String),
82
83 ToolCall(RawStreamingToolCall),
85 ToolCallDelta {
87 id: String,
89 internal_call_id: String,
91 content: ToolCallDeltaContent,
92 },
93 Reasoning {
95 id: Option<String>,
96 content: ReasoningContent,
97 },
98 ReasoningDelta {
100 id: Option<String>,
101 reasoning: String,
102 },
103
104 FinalResponse(R),
107
108 MessageId(String),
111}
112
113#[derive(Debug, Clone)]
115pub struct RawStreamingToolCall {
116 pub id: String,
118 pub internal_call_id: String,
120 pub call_id: Option<String>,
121 pub name: String,
122 pub arguments: serde_json::Value,
123 pub signature: Option<String>,
124 pub additional_params: Option<serde_json::Value>,
125}
126
127impl RawStreamingToolCall {
128 pub fn empty() -> Self {
129 Self {
130 id: String::new(),
131 internal_call_id: nanoid::nanoid!(),
132 call_id: None,
133 name: String::new(),
134 arguments: serde_json::Value::Null,
135 signature: None,
136 additional_params: None,
137 }
138 }
139
140 pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
141 Self {
142 id,
143 internal_call_id: nanoid::nanoid!(),
144 call_id: None,
145 name,
146 arguments,
147 signature: None,
148 additional_params: None,
149 }
150 }
151
152 pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
153 self.internal_call_id = internal_call_id;
154 self
155 }
156
157 pub fn with_call_id(mut self, call_id: String) -> Self {
158 self.call_id = Some(call_id);
159 self
160 }
161
162 pub fn with_signature(mut self, signature: Option<String>) -> Self {
163 self.signature = signature;
164 self
165 }
166
167 pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
168 self.additional_params = additional_params;
169 self
170 }
171}
172
173impl From<RawStreamingToolCall> for ToolCall {
174 fn from(tool_call: RawStreamingToolCall) -> Self {
175 ToolCall {
176 id: tool_call.id,
177 call_id: tool_call.call_id,
178 function: ToolFunction {
179 name: tool_call.name,
180 arguments: tool_call.arguments,
181 },
182 signature: tool_call.signature,
183 additional_params: tool_call.additional_params,
184 }
185 }
186}
187
188#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
189pub type StreamingResult<R> =
190 Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
191
192#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
193pub type StreamingResult<R> =
194 Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
195
196pub struct StreamingCompletionResponse<R>
200where
201 R: Clone + Unpin + GetTokenUsage,
202{
203 pub(crate) inner: Abortable<StreamingResult<R>>,
204 pub(crate) abort_handle: AbortHandle,
205 pub(crate) pause_control: PauseControl,
206 assistant_items: Vec<AssistantContent>,
207 text_item_index: Option<usize>,
208 reasoning_item_index: Option<usize>,
209 pub choice: OneOrMany<AssistantContent>,
212 pub response: Option<R>,
215 pub final_response_yielded: AtomicBool,
216 pub message_id: Option<String>,
218}
219
220impl<R> StreamingCompletionResponse<R>
221where
222 R: Clone + Unpin + GetTokenUsage,
223{
224 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
225 let (abort_handle, abort_registration) = AbortHandle::new_pair();
226 let abortable_stream = Abortable::new(inner, abort_registration);
227 let pause_control = PauseControl::new();
228 Self {
229 inner: abortable_stream,
230 abort_handle,
231 pause_control,
232 assistant_items: vec![],
233 text_item_index: None,
234 reasoning_item_index: None,
235 choice: OneOrMany::one(AssistantContent::text("")),
236 response: None,
237 final_response_yielded: AtomicBool::new(false),
238 message_id: None,
239 }
240 }
241
242 pub fn cancel(&self) {
243 self.abort_handle.abort();
244 }
245
246 pub fn pause(&self) {
247 self.pause_control.pause();
248 }
249
250 pub fn resume(&self) {
251 self.pause_control.resume();
252 }
253
254 pub fn is_paused(&self) -> bool {
255 self.pause_control.is_paused()
256 }
257
258 fn append_text_chunk(&mut self, text: &str) {
259 if let Some(index) = self.text_item_index
260 && let Some(AssistantContent::Text(existing_text)) = self.assistant_items.get_mut(index)
261 {
262 existing_text.text.push_str(text);
263 return;
264 }
265
266 self.assistant_items
267 .push(AssistantContent::text(text.to_owned()));
268 self.text_item_index = Some(self.assistant_items.len() - 1);
269 }
270
271 fn append_reasoning_chunk(&mut self, id: &Option<String>, text: &str) {
275 if let Some(index) = self.reasoning_item_index
276 && let Some(AssistantContent::Reasoning(existing)) = self.assistant_items.get_mut(index)
277 && let Some(ReasoningContent::Text {
278 text: existing_text,
279 ..
280 }) = existing.content.last_mut()
281 {
282 existing_text.push_str(text);
283 return;
284 }
285
286 self.assistant_items
287 .push(AssistantContent::Reasoning(Reasoning {
288 id: id.clone(),
289 content: vec![ReasoningContent::Text {
290 text: text.to_string(),
291 signature: None,
292 }],
293 }));
294 self.reasoning_item_index = Some(self.assistant_items.len() - 1);
295 }
296}
297
298impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
299where
300 R: Clone + Unpin + GetTokenUsage,
301{
302 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
303 CompletionResponse {
304 choice: value.choice,
305 usage: Usage::new(), raw_response: value.response,
307 message_id: value.message_id,
308 }
309 }
310}
311
312impl<R> Stream for StreamingCompletionResponse<R>
313where
314 R: Clone + Unpin + GetTokenUsage,
315{
316 type Item = Result<StreamedAssistantContent<R>, CompletionError>;
317
318 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
319 let stream = self.get_mut();
320
321 if stream.is_paused() {
322 cx.waker().wake_by_ref();
323 return Poll::Pending;
324 }
325
326 match Pin::new(&mut stream.inner).poll_next(cx) {
327 Poll::Pending => Poll::Pending,
328 Poll::Ready(None) => {
329 if stream.assistant_items.is_empty() {
332 stream.assistant_items.push(AssistantContent::text(""));
333 }
334
335 stream.choice = OneOrMany::many(std::mem::take(&mut stream.assistant_items))
336 .expect("There should be at least one assistant message");
337
338 Poll::Ready(None)
339 }
340 Poll::Ready(Some(Err(err))) => {
341 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
342 {
343 return Poll::Ready(None); }
345 Poll::Ready(Some(Err(err)))
346 }
347 Poll::Ready(Some(Ok(choice))) => match choice {
348 RawStreamingChoice::Message(text) => {
349 stream.reasoning_item_index = None;
350 stream.append_text_chunk(&text);
351 Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
352 }
353 RawStreamingChoice::ToolCallDelta {
354 id,
355 internal_call_id,
356 content,
357 } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
358 id,
359 internal_call_id,
360 content,
361 }))),
362 RawStreamingChoice::Reasoning { id, content } => {
363 let reasoning = Reasoning {
364 id,
365 content: vec![content],
366 };
367 stream.text_item_index = None;
368 stream.reasoning_item_index = None;
370 stream
371 .assistant_items
372 .push(AssistantContent::Reasoning(reasoning.clone()));
373 Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(reasoning))))
374 }
375 RawStreamingChoice::ReasoningDelta { id, reasoning } => {
376 stream.text_item_index = None;
377 stream.append_reasoning_chunk(&id, &reasoning);
378 Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
379 id,
380 reasoning,
381 })))
382 }
383 RawStreamingChoice::ToolCall(raw_tool_call) => {
384 let internal_call_id = raw_tool_call.internal_call_id.clone();
385 let tool_call: ToolCall = raw_tool_call.into();
386 stream.text_item_index = None;
387 stream.reasoning_item_index = None;
388 stream
389 .assistant_items
390 .push(AssistantContent::ToolCall(tool_call.clone()));
391 Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
392 tool_call,
393 internal_call_id,
394 })))
395 }
396 RawStreamingChoice::FinalResponse(response) => {
397 if stream
398 .final_response_yielded
399 .load(std::sync::atomic::Ordering::SeqCst)
400 {
401 stream.poll_next_unpin(cx)
402 } else {
403 stream.response = Some(response.clone());
405 stream
406 .final_response_yielded
407 .store(true, std::sync::atomic::Ordering::SeqCst);
408 let final_response = StreamedAssistantContent::final_response(response);
409 Poll::Ready(Some(Ok(final_response)))
410 }
411 }
412 RawStreamingChoice::MessageId(id) => {
413 stream.message_id = Some(id);
414 stream.poll_next_unpin(cx)
415 }
416 },
417 }
418 }
419}
420
421pub trait StreamingPrompt<M, R>
427where
428 M: CompletionModel + 'static,
429 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
430 R: Clone + Unpin + GetTokenUsage,
431{
432 type Hook: PromptHook<M>;
450
451 fn stream_prompt(
453 &self,
454 prompt: impl Into<Message> + WasmCompatSend,
455 ) -> StreamingPromptRequest<M, Self::Hook>;
456}
457
458pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
464where
465 M: CompletionModel + 'static,
466 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
467 R: Clone + Unpin + GetTokenUsage,
468{
469 type Hook: PromptHook<M>;
491
492 fn stream_chat(
497 &self,
498 prompt: impl Into<Message> + WasmCompatSend,
499 chat_history: Vec<Message>,
500 ) -> StreamingPromptRequest<M, Self::Hook>;
501}
502
503pub trait StreamingCompletion<M: CompletionModel> {
505 fn stream_completion(
507 &self,
508 prompt: impl Into<Message> + WasmCompatSend,
509 chat_history: Vec<Message>,
510 ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
511}
512
513pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
514 pub(crate) inner: StreamingResult<R>,
515}
516
517fn map_raw_streaming_choice<R>(
518 chunk: RawStreamingChoice<R>,
519) -> RawStreamingChoice<FinalCompletionResponse>
520where
521 R: Clone + Unpin + GetTokenUsage,
522{
523 match chunk {
524 RawStreamingChoice::FinalResponse(res) => {
525 RawStreamingChoice::FinalResponse(FinalCompletionResponse {
526 usage: res.token_usage(),
527 })
528 }
529 RawStreamingChoice::Message(m) => RawStreamingChoice::Message(m),
530 RawStreamingChoice::ToolCallDelta {
531 id,
532 internal_call_id,
533 content,
534 } => RawStreamingChoice::ToolCallDelta {
535 id,
536 internal_call_id,
537 content,
538 },
539 RawStreamingChoice::Reasoning { id, content } => {
540 RawStreamingChoice::Reasoning { id, content }
541 }
542 RawStreamingChoice::ReasoningDelta { id, reasoning } => {
543 RawStreamingChoice::ReasoningDelta { id, reasoning }
544 }
545 RawStreamingChoice::ToolCall(tool_call) => RawStreamingChoice::ToolCall(tool_call),
546 RawStreamingChoice::MessageId(id) => RawStreamingChoice::MessageId(id),
547 }
548}
549
550impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
551 type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
552
553 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
554 let stream = self.get_mut();
555
556 match stream.inner.as_mut().poll_next(cx) {
557 Poll::Pending => Poll::Pending,
558 Poll::Ready(None) => Poll::Ready(None),
559 Poll::Ready(Some(item)) => Poll::Ready(Some(item.map(map_raw_streaming_choice::<R>))),
560 }
561 }
562}
563
564pub async fn stream_to_stdout<M>(
567 agent: &'static Agent<M>,
568 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
569) -> Result<(), std::io::Error>
570where
571 M: CompletionModel,
572{
573 let mut is_reasoning = false;
574 print!("Response: ");
575 while let Some(chunk) = stream.next().await {
576 match chunk {
577 Ok(StreamedAssistantContent::Text(text)) => {
578 if is_reasoning {
579 is_reasoning = false;
580 println!("\n---\n");
581 }
582 print!("{}", text.text);
583 std::io::Write::flush(&mut std::io::stdout())?;
584 }
585 Ok(StreamedAssistantContent::ToolCall {
586 tool_call,
587 internal_call_id: _,
588 }) => {
589 let res = agent
590 .tool_server_handle
591 .call_tool(
592 &tool_call.function.name,
593 &tool_call.function.arguments.to_string(),
594 )
595 .await
596 .map_err(|x| std::io::Error::other(x.to_string()))?;
597 println!("\nResult: {res}");
598 }
599 Ok(StreamedAssistantContent::Final(res)) => {
600 let json_res = serde_json::to_string_pretty(&res).unwrap();
601 println!();
602 tracing::info!("Final result: {json_res}");
603 }
604 Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
605 if !is_reasoning {
606 is_reasoning = true;
607 println!();
608 println!("Thinking: ");
609 }
610 let reasoning = reasoning.display_text();
611
612 print!("{reasoning}");
613 std::io::Write::flush(&mut std::io::stdout())?;
614 }
615 Err(e) => {
616 if e.to_string().contains("aborted") {
617 println!("\nStream cancelled.");
618 break;
619 }
620 eprintln!("Error: {e}");
621 break;
622 }
623 _ => {}
624 }
625 }
626
627 println!(); Ok(())
630}
631
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal stand-in for a provider response.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        // Only ever written by the fixtures below, never read.
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        // Always reports 15 total tokens, regardless of `token_count`.
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Boxes a mock stream into a `StreamingResult` (non-wasm: requires `Send`).
    #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + Send
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    /// Boxes a mock stream into a `StreamingResult` (wasm: no `Send` bound).
    #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    /// Drains `stream` to its end and returns the aggregated `choice` items.
    async fn drain_choice(
        mut stream: StreamingCompletionResponse<MockResponse>,
    ) -> Vec<AssistantContent> {
        while stream.next().await.is_some() {}
        stream.choice.clone().into_iter().collect()
    }

    /// Three text chunks, 100 ms apart, followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let source = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(source))
    }

    /// One signed reasoning block, one text chunk, then a final response.
    fn create_reasoning_stream() -> StreamingCompletionResponse<MockResponse> {
        let source = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_1".to_string()),
                content: ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig_1".to_string()),
                },
            });
            yield Ok(RawStreamingChoice::Message("final answer".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 5 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(source))
    }

    /// A single reasoning summary and a final response; no text at all.
    fn create_reasoning_only_stream() -> StreamingCompletionResponse<MockResponse> {
        let source = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_only".to_string()),
                content: ReasoningContent::Summary("hidden summary".to_string()),
            });
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 2 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(source))
    }

    /// Reasoning, then text, then a tool call, then a final response.
    fn create_interleaved_stream() -> StreamingCompletionResponse<MockResponse> {
        let source = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_interleaved".to_string()),
                content: ReasoningContent::Text {
                    text: "chain-of-thought".to_string(),
                    signature: None,
                },
            });
            yield Ok(RawStreamingChoice::Message("final-text".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_1".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": 1}),
                ),
            ));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(source))
    }

    /// Text, a tool call, more text, then a final response.
    fn create_text_tool_text_stream() -> StreamingCompletionResponse<MockResponse> {
        let source = stream! {
            yield Ok(RawStreamingChoice::Message("first".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_split".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": "x"}),
                ),
            ));
            yield Ok(RawStreamingChoice::Message("second".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(source))
    }

    /// After `cancel`, the stream must not yield any further chunk.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            // `counted` says whether this chunk contributes to `chunk_count`.
            let counted = match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    true
                }
                Ok(StreamedAssistantContent::ToolCall {
                    tool_call,
                    internal_call_id,
                }) => {
                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
                    true
                }
                Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }) => {
                    println!(
                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
                    );
                    true
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                    false
                }
                Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                    let reasoning = reasoning.display_text();
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    false
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    true
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            };

            if counted {
                chunk_count += 1;
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// `pause` and `resume` toggle what `is_paused` reports.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        stream.pause();
        assert!(stream.is_paused());

        stream.resume();
        assert!(!stream.is_paused());
    }

    /// A full reasoning block keeps its id, text and signature in `choice`.
    #[tokio::test]
    async fn test_stream_aggregates_reasoning_content() {
        let choice_items = drain_choice(create_reasoning_stream()).await;

        assert!(choice_items.iter().any(|item| matches!(
            item,
            AssistantContent::Reasoning(Reasoning {
                id: Some(id),
                content
            }) if id == "rs_1"
                && matches!(
                    content.first(),
                    Some(ReasoningContent::Text {
                        text,
                        signature: Some(signature)
                    }) if text == "step one" && signature == "sig_1"
                )
        )));
    }

    /// A stream without text yields only the reasoning item, no empty text.
    #[tokio::test]
    async fn test_stream_reasoning_only_does_not_inject_empty_text() {
        let choice_items = drain_choice(create_reasoning_only_stream()).await;

        assert_eq!(choice_items.len(), 1);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_only"
        ));
    }

    /// Items end up in `choice` in the order they arrived.
    #[tokio::test]
    async fn test_stream_aggregates_assistant_items_in_arrival_order() {
        let choice_items = drain_choice(create_interleaved_stream()).await;

        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_interleaved"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::Text(Text { text })) if text == "final-text"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_1"
        ));
    }

    /// Text chunks separated by a tool call stay separate items.
    #[tokio::test]
    async fn test_stream_keeps_non_contiguous_text_chunks_split_by_tool_call() {
        let choice_items = drain_choice(create_text_tool_text_stream()).await;

        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Text(Text { text })) if text == "first"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_split"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::Text(Text { text })) if text == "second"
        ));
    }
}
908
909#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
911#[serde(untagged)]
912pub enum StreamedAssistantContent<R> {
913 Text(Text),
914 ToolCall {
915 tool_call: ToolCall,
916 internal_call_id: String,
919 },
920 ToolCallDelta {
921 id: String,
923 internal_call_id: String,
925 content: ToolCallDeltaContent,
926 },
927 Reasoning(Reasoning),
928 ReasoningDelta {
929 id: Option<String>,
930 reasoning: String,
931 },
932 Final(R),
933}
934
935impl<R> StreamedAssistantContent<R>
936where
937 R: Clone + Unpin,
938{
939 pub fn text(text: &str) -> Self {
940 Self::Text(Text {
941 text: text.to_string(),
942 })
943 }
944
945 pub fn final_response(res: R) -> Self {
946 Self::Final(res)
947 }
948}
949
950#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
952#[serde(untagged)]
953pub enum StreamedUserContent {
954 ToolResult {
955 tool_result: ToolResult,
956 internal_call_id: String,
960 },
961}
962
963impl StreamedUserContent {
964 pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
965 Self::ToolResult {
966 tool_result,
967 internal_call_id,
968 }
969 }
970}