1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::client::FinalCompletionResponse;
15use crate::completion::{
16 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17 Message, Usage,
18};
19use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction, ToolResult};
20use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
21use futures::stream::{AbortHandle, Abortable};
22use futures::{Stream, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::atomic::AtomicBool;
27use std::task::{Context, Poll};
28use tokio::sync::watch;
29
/// Pause flag backed by a `tokio::sync::watch` channel.
///
/// The sender publishes the paused state (`pause` / `resume`); the receiver is
/// read by `is_paused`. A new control starts in the non-paused state.
pub struct PauseControl {
    /// Publishes the current paused state.
    pub(crate) paused_tx: watch::Sender<bool>,
    /// Observes the current paused state.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
35
36impl PauseControl {
37 pub fn new() -> Self {
38 let (paused_tx, paused_rx) = watch::channel(false);
39 Self {
40 paused_tx,
41 paused_rx,
42 }
43 }
44
45 pub fn pause(&self) {
46 self.paused_tx.send(true).unwrap();
47 }
48
49 pub fn resume(&self) {
50 self.paused_tx.send(false).unwrap();
51 }
52
53 pub fn is_paused(&self) -> bool {
54 *self.paused_rx.borrow()
55 }
56}
57
58impl Default for PauseControl {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
/// Payload carried by a streamed tool-call delta.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum ToolCallDeltaContent {
    /// The tool's name. NOTE(review): presumably sent once per call; confirm against providers.
    Name(String),
    /// A fragment of the tool call — presumably part of the serialized arguments; confirm.
    Delta(String),
}
70
/// One item emitted by a provider's raw streaming response, before it is
/// surfaced to callers as [`StreamedAssistantContent`].
///
/// `R` is the provider-specific final response type.
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A fragment of assistant text.
    Message(String),

    /// A complete tool call.
    ToolCall(RawStreamingToolCall),
    /// An incremental piece of a tool call that is still being streamed.
    ToolCallDelta {
        /// Tool call id.
        id: String,
        /// Internal correlation id — presumably matches the `internal_call_id`
        /// of the eventual `RawStreamingToolCall`; confirm with providers.
        internal_call_id: String,
        /// The delta payload.
        content: ToolCallDeltaContent,
    },
    /// A complete reasoning block.
    Reasoning {
        /// Optional reasoning block id.
        id: Option<String>,
        /// The reasoning text.
        reasoning: String,
        /// Optional provider signature for the block.
        signature: Option<String>,
    },
    /// An incremental fragment of reasoning text.
    ReasoningDelta {
        /// Optional reasoning block id.
        id: Option<String>,
        /// The reasoning fragment.
        reasoning: String,
    },

    /// The provider's final response object; token usage is read from it via
    /// `GetTokenUsage` elsewhere in this module.
    FinalResponse(R),
}
106
/// A complete tool call produced by a provider stream, plus streaming-side
/// bookkeeping. Converts into [`ToolCall`] via `From`, which drops
/// `internal_call_id`.
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Tool call id.
    pub id: String,
    /// Locally generated id (a nanoid when built via `new`/`empty`) —
    /// presumably used to correlate the call with its deltas and results.
    pub internal_call_id: String,
    /// Optional secondary call id, if the provider supplies one.
    pub call_id: Option<String>,
    /// Name of the tool being called.
    pub name: String,
    /// Tool arguments as JSON.
    pub arguments: serde_json::Value,
    /// Optional provider signature.
    pub signature: Option<String>,
    /// Extra provider-specific parameters, passed through unchanged.
    pub additional_params: Option<serde_json::Value>,
}
120
121impl RawStreamingToolCall {
122 pub fn empty() -> Self {
123 Self {
124 id: String::new(),
125 internal_call_id: nanoid::nanoid!(),
126 call_id: None,
127 name: String::new(),
128 arguments: serde_json::Value::Null,
129 signature: None,
130 additional_params: None,
131 }
132 }
133
134 pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
135 Self {
136 id,
137 internal_call_id: nanoid::nanoid!(),
138 call_id: None,
139 name,
140 arguments,
141 signature: None,
142 additional_params: None,
143 }
144 }
145
146 pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
147 self.internal_call_id = internal_call_id;
148 self
149 }
150
151 pub fn with_call_id(mut self, call_id: String) -> Self {
152 self.call_id = Some(call_id);
153 self
154 }
155
156 pub fn with_signature(mut self, signature: Option<String>) -> Self {
157 self.signature = signature;
158 self
159 }
160
161 pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
162 self.additional_params = additional_params;
163 self
164 }
165}
166
167impl From<RawStreamingToolCall> for ToolCall {
168 fn from(tool_call: RawStreamingToolCall) -> Self {
169 ToolCall {
170 id: tool_call.id,
171 call_id: tool_call.call_id,
172 function: ToolFunction {
173 name: tool_call.name,
174 arguments: tool_call.arguments,
175 },
176 signature: tool_call.signature,
177 additional_params: tool_call.additional_params,
178 }
179 }
180}
181
/// Boxed, pinned stream of raw provider chunks.
///
/// Unless building for `wasm32` with the `wasm` feature, the stream must also
/// be `Send`.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed, pinned stream of raw provider chunks (wasm variant: no `Send` bound).
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
189
/// A provider stream wrapped with cancellation, pause control, and aggregation
/// of the content it yields.
///
/// As the stream is polled, text and reasoning deltas are appended to internal
/// buffers and complete tool calls are recorded. When the inner stream ends,
/// `choice` is rebuilt from them.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    /// The raw provider stream, made abortable via `abort_handle`.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle used by `cancel` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    /// Pause flag checked on every poll.
    pub(crate) pause_control: PauseControl,
    /// Concatenation of all streamed text fragments so far.
    text: String,
    /// Concatenation of all streamed reasoning deltas so far.
    reasoning: String,
    /// Complete tool calls seen so far.
    tool_calls: Vec<ToolCall>,
    /// The aggregated assistant message. Holds a single empty text item until
    /// the inner stream ends, at which point it is rebuilt from `text` and
    /// `tool_calls`.
    pub choice: OneOrMany<AssistantContent>,
    /// The first final provider response received, if any.
    pub response: Option<R>,
    /// Set once a final response has been yielded so that later ones are skipped.
    pub final_response_yielded: AtomicBool,
}
211
212impl<R> StreamingCompletionResponse<R>
213where
214 R: Clone + Unpin + GetTokenUsage,
215{
216 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
217 let (abort_handle, abort_registration) = AbortHandle::new_pair();
218 let abortable_stream = Abortable::new(inner, abort_registration);
219 let pause_control = PauseControl::new();
220 Self {
221 inner: abortable_stream,
222 abort_handle,
223 pause_control,
224 reasoning: String::new(),
225 text: "".to_string(),
226 tool_calls: vec![],
227 choice: OneOrMany::one(AssistantContent::text("")),
228 response: None,
229 final_response_yielded: AtomicBool::new(false),
230 }
231 }
232
233 pub fn cancel(&self) {
234 self.abort_handle.abort();
235 }
236
237 pub fn pause(&self) {
238 self.pause_control.pause();
239 }
240
241 pub fn resume(&self) {
242 self.pause_control.resume();
243 }
244
245 pub fn is_paused(&self) -> bool {
246 self.pause_control.is_paused()
247 }
248}
249
250impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
251where
252 R: Clone + Unpin + GetTokenUsage,
253{
254 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
255 CompletionResponse {
256 choice: value.choice,
257 usage: Usage::new(), raw_response: value.response,
259 }
260 }
261}
262
263impl<R> Stream for StreamingCompletionResponse<R>
264where
265 R: Clone + Unpin + GetTokenUsage,
266{
267 type Item = Result<StreamedAssistantContent<R>, CompletionError>;
268
269 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
270 let stream = self.get_mut();
271
272 if stream.is_paused() {
273 cx.waker().wake_by_ref();
274 return Poll::Pending;
275 }
276
277 match Pin::new(&mut stream.inner).poll_next(cx) {
278 Poll::Pending => Poll::Pending,
279 Poll::Ready(None) => {
280 let mut choice = vec![];
283
284 stream.tool_calls.iter().for_each(|tc| {
285 choice.push(AssistantContent::ToolCall(tc.clone()));
286 });
287
288 if choice.is_empty() || !stream.text.is_empty() {
290 choice.insert(0, AssistantContent::text(stream.text.clone()));
291 }
292
293 stream.choice = OneOrMany::many(choice)
294 .expect("There should be at least one assistant message");
295
296 Poll::Ready(None)
297 }
298 Poll::Ready(Some(Err(err))) => {
299 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
300 {
301 return Poll::Ready(None); }
303 Poll::Ready(Some(Err(err)))
304 }
305 Poll::Ready(Some(Ok(choice))) => match choice {
306 RawStreamingChoice::Message(text) => {
307 stream.text = format!("{}{}", stream.text, text);
310 Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
311 }
312 RawStreamingChoice::ToolCallDelta {
313 id,
314 internal_call_id,
315 content,
316 } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
317 id,
318 internal_call_id,
319 content,
320 }))),
321 RawStreamingChoice::Reasoning {
322 id,
323 reasoning,
324 signature,
325 } => Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
326 id,
327 reasoning: vec![reasoning],
328 signature,
329 })))),
330 RawStreamingChoice::ReasoningDelta { id, reasoning } => {
331 stream.reasoning = format!("{}{}", stream.reasoning, reasoning);
334 Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
335 id,
336 reasoning,
337 })))
338 }
339 RawStreamingChoice::ToolCall(raw_tool_call) => {
340 let internal_call_id = raw_tool_call.internal_call_id.clone();
343 let tool_call: ToolCall = raw_tool_call.into();
344 stream.tool_calls.push(tool_call.clone());
345 Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
346 tool_call,
347 internal_call_id,
348 })))
349 }
350 RawStreamingChoice::FinalResponse(response) => {
351 if stream
352 .final_response_yielded
353 .load(std::sync::atomic::Ordering::SeqCst)
354 {
355 stream.poll_next_unpin(cx)
356 } else {
357 stream.response = Some(response.clone());
359 stream
360 .final_response_yielded
361 .store(true, std::sync::atomic::Ordering::SeqCst);
362 let final_response = StreamedAssistantContent::final_response(response);
363 Poll::Ready(Some(Ok(final_response)))
364 }
365 }
366 },
367 }
368 }
369}
370
/// Types that can start a streaming prompt request from a single message.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a [`StreamingPromptRequest`] for `prompt`.
    // NOTE(review): `R` is constrained by the trait bounds but does not appear in
    // this signature; confirm whether it is still needed.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()>;
}
384
/// Types that can start a streaming chat request from a message plus prior
/// conversation history.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a [`StreamingPromptRequest`] for `prompt`, given `chat_history`.
    // NOTE(review): as with `StreamingPrompt`, `R` does not appear in this signature.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
399
/// Types that can prepare a streaming completion request for model `M`.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Returns a future resolving to a [`CompletionRequestBuilder`] for `prompt`
    /// and `chat_history`, or a [`CompletionError`] if the request cannot be built.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
409
/// Adapter over a provider stream whose `Stream` impl replaces the
/// provider-specific final response `R` with a [`FinalCompletionResponse`]
/// carrying only token usage.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    /// The wrapped provider stream.
    pub(crate) inner: StreamingResult<R>,
}
413
414impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
415 type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
416
417 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
418 let stream = self.get_mut();
419
420 match stream.inner.as_mut().poll_next(cx) {
421 Poll::Pending => Poll::Pending,
422 Poll::Ready(None) => Poll::Ready(None),
423 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
424 Poll::Ready(Some(Ok(chunk))) => match chunk {
425 RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
426 RawStreamingChoice::FinalResponse(FinalCompletionResponse {
427 usage: res.token_usage(),
428 }),
429 ))),
430 RawStreamingChoice::Message(m) => {
431 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
432 }
433 RawStreamingChoice::ToolCallDelta {
434 id,
435 internal_call_id,
436 content,
437 } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCallDelta {
438 id,
439 internal_call_id,
440 content,
441 }))),
442 RawStreamingChoice::Reasoning {
443 id,
444 reasoning,
445 signature,
446 } => Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning {
447 id,
448 reasoning,
449 signature,
450 }))),
451 RawStreamingChoice::ReasoningDelta { id, reasoning } => {
452 Poll::Ready(Some(Ok(RawStreamingChoice::ReasoningDelta {
453 id,
454 reasoning,
455 })))
456 }
457 RawStreamingChoice::ToolCall(tool_call) => {
458 Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall(tool_call))))
459 }
460 },
461 }
462 }
463}
464
465pub async fn stream_to_stdout<M>(
468 agent: &'static Agent<M>,
469 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
470) -> Result<(), std::io::Error>
471where
472 M: CompletionModel,
473{
474 let mut is_reasoning = false;
475 print!("Response: ");
476 while let Some(chunk) = stream.next().await {
477 match chunk {
478 Ok(StreamedAssistantContent::Text(text)) => {
479 if is_reasoning {
480 is_reasoning = false;
481 println!("\n---\n");
482 }
483 print!("{}", text.text);
484 std::io::Write::flush(&mut std::io::stdout())?;
485 }
486 Ok(StreamedAssistantContent::ToolCall {
487 tool_call,
488 internal_call_id: _,
489 }) => {
490 let res = agent
491 .tool_server_handle
492 .call_tool(
493 &tool_call.function.name,
494 &tool_call.function.arguments.to_string(),
495 )
496 .await
497 .map_err(|x| std::io::Error::other(x.to_string()))?;
498 println!("\nResult: {res}");
499 }
500 Ok(StreamedAssistantContent::Final(res)) => {
501 let json_res = serde_json::to_string_pretty(&res).unwrap();
502 println!();
503 tracing::info!("Final result: {json_res}");
504 }
505 Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
506 if !is_reasoning {
507 is_reasoning = true;
508 println!();
509 println!("Thinking: ");
510 }
511 let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
512
513 print!("{reasoning}");
514 std::io::Write::flush(&mut std::io::stdout())?;
515 }
516 Err(e) => {
517 if e.to_string().contains("aborted") {
518 println!("\nStream cancelled.");
519 break;
520 }
521 eprintln!("Error: {e}");
522 break;
523 }
524 _ => {}
525 }
526 }
527
528 println!(); Ok(())
531}
532
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal provider response used to terminate the mock stream.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        // Only ever observed through `Debug`, so rustc reports it as unread.
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a stream yielding three text chunks followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // `StreamingResult` already switches its `Send` bound per target, so a
        // single binding covers both wasm and non-wasm builds.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall {
                    tool_call,
                    internal_call_id,
                }) => {
                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }) => {
                    println!(
                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
                    );
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    #[tokio::test]
    async fn test_stream_accumulates_text_and_final_response() {
        let mut stream = create_mock_stream();

        let mut texts = Vec::new();
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => texts.push(text.text),
                Ok(_) => {}
                Err(e) => panic!("unexpected stream error: {e:?}"),
            }
        }

        assert_eq!(texts, vec!["hello 1", "hello 2", "hello 3"]);
        assert_eq!(stream.text, "hello 1hello 2hello 3");
        assert!(stream.response.is_some());
        assert!(
            stream
                .final_response_yielded
                .load(std::sync::atomic::Ordering::SeqCst)
        );
    }

    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();
        assert!(!stream.is_paused());

        stream.pause();
        assert!(stream.is_paused());

        stream.resume();
        assert!(!stream.is_paused());
    }
}
651
/// Assistant content surfaced to callers of a streaming completion.
///
/// Serialized `#[serde(untagged)]`: deserialization tries variants in
/// declaration order, so the order below is significant.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A fragment of assistant text.
    Text(Text),
    /// A complete tool call plus its internal correlation id.
    ToolCall {
        tool_call: ToolCall,
        internal_call_id: String,
    },
    /// An incremental piece of a tool call.
    ToolCallDelta {
        id: String,
        internal_call_id: String,
        content: ToolCallDeltaContent,
    },
    /// A complete reasoning block.
    Reasoning(Reasoning),
    /// An incremental fragment of reasoning text.
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },
    /// The provider's final response (surfaced at most once per stream).
    Final(R),
}
677
678impl<R> StreamedAssistantContent<R>
679where
680 R: Clone + Unpin,
681{
682 pub fn text(text: &str) -> Self {
683 Self::Text(Text {
684 text: text.to_string(),
685 })
686 }
687
688 pub fn final_response(res: R) -> Self {
689 Self::Final(res)
690 }
691}
692
/// User-side content emitted during a streaming exchange (serialized untagged).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of a tool call, with the internal id of the call it answers.
    ToolResult {
        tool_result: ToolResult,
        internal_call_id: String,
    },
}
705
706impl StreamedUserContent {
707 pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
708 Self::ToolResult {
709 tool_result,
710 internal_call_id,
711 }
712 }
713}