1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::client::FinalCompletionResponse;
15use crate::completion::{
16 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17 Message, Usage,
18};
19use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction, ToolResult};
20use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
21use futures::stream::{AbortHandle, Abortable};
22use futures::{Stream, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::atomic::AtomicBool;
27use std::task::{Context, Poll};
28use tokio::sync::watch;
29
/// Pause flag for a streaming response, backed by a `tokio::sync::watch` channel
/// of `bool` (`true` = paused).
///
/// The sender half is used to change the flag and the receiver half to read it.
/// Because the receiver is stored alongside the sender, the channel has at least
/// one live receiver for as long as this struct exists.
pub struct PauseControl {
    /// Writes the paused flag.
    pub(crate) paused_tx: watch::Sender<bool>,
    /// Reads the current paused flag.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
35
36impl PauseControl {
37 pub fn new() -> Self {
38 let (paused_tx, paused_rx) = watch::channel(false);
39 Self {
40 paused_tx,
41 paused_rx,
42 }
43 }
44
45 pub fn pause(&self) {
46 self.paused_tx.send(true).unwrap();
47 }
48
49 pub fn resume(&self) {
50 self.paused_tx.send(false).unwrap();
51 }
52
53 pub fn is_paused(&self) -> bool {
54 *self.paused_rx.borrow()
55 }
56}
57
58impl Default for PauseControl {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
/// Incremental content for a tool call that is still being streamed.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum ToolCallDeltaContent {
    /// The tool/function name (presumably sent once it is known — confirm per provider).
    Name(String),
    /// A fragment of the tool call's arguments (presumably raw argument text to be
    /// concatenated by the consumer — confirm per provider).
    Delta(String),
}
70
/// One chunk produced by a provider's raw streaming implementation, before it is
/// forwarded and aggregated by [`StreamingCompletionResponse`].
///
/// `R` is the provider-specific final response type.
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A fragment of assistant text; fragments are concatenated into the
    /// aggregated text of the response.
    Message(String),

    /// A complete tool call; it is recorded and included in the aggregated choice.
    ToolCall(RawStreamingToolCall),
    /// A partial update to a tool call that is still being streamed.
    ToolCallDelta {
        /// Identifier of the tool call this delta belongs to.
        id: String,
        /// The name or argument fragment carried by this delta.
        content: ToolCallDeltaContent,
    },
    /// A reasoning block.
    Reasoning {
        /// Optional identifier of the reasoning block.
        id: Option<String>,
        /// The reasoning text.
        reasoning: String,
        /// Optional signature, passed through unchanged.
        signature: Option<String>,
    },
    /// An incremental reasoning fragment.
    ReasoningDelta {
        /// Optional identifier of the reasoning block this fragment belongs to.
        id: Option<String>,
        /// The reasoning text fragment.
        reasoning: String,
    },

    /// The provider's final response. Only the first one is stored and yielded
    /// by [`StreamingCompletionResponse`].
    FinalResponse(R),
}
103
/// A complete tool call as emitted by a provider stream; converts into a
/// message-level [`ToolCall`].
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Identifier of the tool call (becomes `ToolCall::id`).
    pub id: String,
    /// Optional secondary call identifier (becomes `ToolCall::call_id`); its exact
    /// meaning is provider-specific — TODO confirm.
    pub call_id: Option<String>,
    /// Name of the tool/function to invoke.
    pub name: String,
    /// Tool arguments as a JSON value.
    pub arguments: serde_json::Value,
    /// Optional signature carried over to `ToolCall::signature`.
    pub signature: Option<String>,
    /// Optional extra parameters carried over to `ToolCall::additional_params`
    /// (presumably provider-specific).
    pub additional_params: Option<serde_json::Value>,
}
114
115impl RawStreamingToolCall {
116 pub fn empty() -> Self {
117 Self {
118 id: String::new(),
119 call_id: None,
120 name: String::new(),
121 arguments: serde_json::Value::Null,
122 signature: None,
123 additional_params: None,
124 }
125 }
126
127 pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
128 Self {
129 id,
130 call_id: None,
131 name,
132 arguments,
133 signature: None,
134 additional_params: None,
135 }
136 }
137
138 pub fn with_call_id(mut self, call_id: String) -> Self {
139 self.call_id = Some(call_id);
140 self
141 }
142
143 pub fn with_signature(mut self, signature: Option<String>) -> Self {
144 self.signature = signature;
145 self
146 }
147
148 pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
149 self.additional_params = additional_params;
150 self
151 }
152}
153
154impl From<RawStreamingToolCall> for ToolCall {
155 fn from(tool_call: RawStreamingToolCall) -> Self {
156 ToolCall {
157 id: tool_call.id,
158 call_id: tool_call.call_id,
159 function: ToolFunction {
160 name: tool_call.name,
161 arguments: tool_call.arguments,
162 },
163 signature: tool_call.signature,
164 additional_params: tool_call.additional_params,
165 }
166 }
167}
168
/// Boxed, pinned stream of raw chunks (or errors) produced by a provider.
///
/// On every target except `wasm32` with the `wasm` feature, the stream must be
/// `Send`.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed, pinned stream of raw chunks (or errors) produced by a provider.
///
/// On `wasm32` with the `wasm` feature, the `Send` bound is dropped.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
176
/// A streaming completion that forwards provider chunks to the caller while
/// accumulating them into one aggregated response.
///
/// The stream can be cancelled with [`StreamingCompletionResponse::cancel`] and
/// paused/resumed with [`StreamingCompletionResponse::pause`] /
/// [`StreamingCompletionResponse::resume`].
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    /// The provider stream, wrapped so it can be aborted via `abort_handle`.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle used by `cancel` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    /// Pause flag checked at the start of every poll.
    pub(crate) pause_control: PauseControl,
    /// Concatenation of every `Message` chunk seen so far.
    text: String,
    /// Concatenation of every `ReasoningDelta` chunk seen so far (not included in `choice`).
    reasoning: String,
    /// Every complete tool call seen so far.
    tool_calls: Vec<ToolCall>,
    /// Aggregated assistant content. Holds a single empty text item until the inner
    /// stream ends; it is then rebuilt from `text` and `tool_calls`.
    pub choice: OneOrMany<AssistantContent>,
    /// The first final response received from the provider, if any.
    pub response: Option<R>,
    /// Set once a final response has been yielded; later `FinalResponse` chunks
    /// are skipped.
    pub final_response_yielded: AtomicBool,
}
198
199impl<R> StreamingCompletionResponse<R>
200where
201 R: Clone + Unpin + GetTokenUsage,
202{
203 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
204 let (abort_handle, abort_registration) = AbortHandle::new_pair();
205 let abortable_stream = Abortable::new(inner, abort_registration);
206 let pause_control = PauseControl::new();
207 Self {
208 inner: abortable_stream,
209 abort_handle,
210 pause_control,
211 reasoning: String::new(),
212 text: "".to_string(),
213 tool_calls: vec![],
214 choice: OneOrMany::one(AssistantContent::text("")),
215 response: None,
216 final_response_yielded: AtomicBool::new(false),
217 }
218 }
219
220 pub fn cancel(&self) {
221 self.abort_handle.abort();
222 }
223
224 pub fn pause(&self) {
225 self.pause_control.pause();
226 }
227
228 pub fn resume(&self) {
229 self.pause_control.resume();
230 }
231
232 pub fn is_paused(&self) -> bool {
233 self.pause_control.is_paused()
234 }
235}
236
237impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
238where
239 R: Clone + Unpin + GetTokenUsage,
240{
241 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
242 CompletionResponse {
243 choice: value.choice,
244 usage: Usage::new(), raw_response: value.response,
246 }
247 }
248}
249
250impl<R> Stream for StreamingCompletionResponse<R>
251where
252 R: Clone + Unpin + GetTokenUsage,
253{
254 type Item = Result<StreamedAssistantContent<R>, CompletionError>;
255
256 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
257 let stream = self.get_mut();
258
259 if stream.is_paused() {
260 cx.waker().wake_by_ref();
261 return Poll::Pending;
262 }
263
264 match Pin::new(&mut stream.inner).poll_next(cx) {
265 Poll::Pending => Poll::Pending,
266 Poll::Ready(None) => {
267 let mut choice = vec![];
270
271 stream.tool_calls.iter().for_each(|tc| {
272 choice.push(AssistantContent::ToolCall(tc.clone()));
273 });
274
275 if choice.is_empty() || !stream.text.is_empty() {
277 choice.insert(0, AssistantContent::text(stream.text.clone()));
278 }
279
280 stream.choice = OneOrMany::many(choice)
281 .expect("There should be at least one assistant message");
282
283 Poll::Ready(None)
284 }
285 Poll::Ready(Some(Err(err))) => {
286 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
287 {
288 return Poll::Ready(None); }
290 Poll::Ready(Some(Err(err)))
291 }
292 Poll::Ready(Some(Ok(choice))) => match choice {
293 RawStreamingChoice::Message(text) => {
294 stream.text = format!("{}{}", stream.text, text);
297 Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
298 }
299 RawStreamingChoice::ToolCallDelta { id, content } => {
300 Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
301 id,
302 content,
303 })))
304 }
305 RawStreamingChoice::Reasoning {
306 id,
307 reasoning,
308 signature,
309 } => Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
310 id,
311 reasoning: vec![reasoning],
312 signature,
313 })))),
314 RawStreamingChoice::ReasoningDelta { id, reasoning } => {
315 stream.reasoning = format!("{}{}", stream.reasoning, reasoning);
318 Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
319 id,
320 reasoning,
321 })))
322 }
323 RawStreamingChoice::ToolCall(tool_call) => {
324 let tool_call: ToolCall = tool_call.into();
327 stream.tool_calls.push(tool_call.clone());
328 Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall(tool_call))))
329 }
330 RawStreamingChoice::FinalResponse(response) => {
331 if stream
332 .final_response_yielded
333 .load(std::sync::atomic::Ordering::SeqCst)
334 {
335 stream.poll_next_unpin(cx)
336 } else {
337 stream.response = Some(response.clone());
339 stream
340 .final_response_yielded
341 .store(true, std::sync::atomic::Ordering::SeqCst);
342 let final_response = StreamedAssistantContent::final_response(response);
343 Poll::Ready(Some(Ok(final_response)))
344 }
345 }
346 },
347 }
348 }
349}
350
/// Types that can start a streaming request for a single prompt.
///
/// The `R` type parameter is constrained here but not referenced by
/// `stream_prompt`'s signature.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a [`StreamingPromptRequest`] for `prompt`.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()>;
}
364
/// Types that can start a streaming request for a prompt together with existing
/// chat history. Implementors must be `WasmCompatSend + WasmCompatSync`.
///
/// The `R` type parameter is constrained here but not referenced by
/// `stream_chat`'s signature.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a [`StreamingPromptRequest`] for `prompt`, passing `chat_history`
    /// along with it (presumably as the preceding conversation — confirm in impls).
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
379
/// Types that can prepare a completion request for streaming.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Returns a future resolving to a [`CompletionRequestBuilder`] for `prompt` and
    /// `chat_history`, or a [`CompletionError`] if the request cannot be prepared.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
389
/// Adapter over a provider's [`StreamingResult`] whose `Stream` implementation
/// replaces the provider-specific final response `R` with a
/// [`FinalCompletionResponse`] holding only its token usage.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    /// The wrapped provider stream.
    pub(crate) inner: StreamingResult<R>,
}
393
394impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
395 type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
396
397 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398 let stream = self.get_mut();
399
400 match stream.inner.as_mut().poll_next(cx) {
401 Poll::Pending => Poll::Pending,
402 Poll::Ready(None) => Poll::Ready(None),
403 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
404 Poll::Ready(Some(Ok(chunk))) => match chunk {
405 RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
406 RawStreamingChoice::FinalResponse(FinalCompletionResponse {
407 usage: res.token_usage(),
408 }),
409 ))),
410 RawStreamingChoice::Message(m) => {
411 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
412 }
413 RawStreamingChoice::ToolCallDelta { id, content } => {
414 Poll::Ready(Some(Ok(RawStreamingChoice::ToolCallDelta { id, content })))
415 }
416 RawStreamingChoice::Reasoning {
417 id,
418 reasoning,
419 signature,
420 } => Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning {
421 id,
422 reasoning,
423 signature,
424 }))),
425 RawStreamingChoice::ReasoningDelta { id, reasoning } => {
426 Poll::Ready(Some(Ok(RawStreamingChoice::ReasoningDelta {
427 id,
428 reasoning,
429 })))
430 }
431 RawStreamingChoice::ToolCall(tool_call) => {
432 Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall(tool_call))))
433 }
434 },
435 }
436 }
437}
438
439pub async fn stream_to_stdout<M>(
442 agent: &'static Agent<M>,
443 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
444) -> Result<(), std::io::Error>
445where
446 M: CompletionModel,
447{
448 let mut is_reasoning = false;
449 print!("Response: ");
450 while let Some(chunk) = stream.next().await {
451 match chunk {
452 Ok(StreamedAssistantContent::Text(text)) => {
453 if is_reasoning {
454 is_reasoning = false;
455 println!("\n---\n");
456 }
457 print!("{}", text.text);
458 std::io::Write::flush(&mut std::io::stdout())?;
459 }
460 Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
461 let res = agent
462 .tool_server_handle
463 .call_tool(
464 &tool_call.function.name,
465 &tool_call.function.arguments.to_string(),
466 )
467 .await
468 .map_err(|x| std::io::Error::other(x.to_string()))?;
469 println!("\nResult: {res}");
470 }
471 Ok(StreamedAssistantContent::Final(res)) => {
472 let json_res = serde_json::to_string_pretty(&res).unwrap();
473 println!();
474 tracing::info!("Final result: {json_res}");
475 }
476 Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
477 if !is_reasoning {
478 is_reasoning = true;
479 println!();
480 println!("Thinking: ");
481 }
482 let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
483
484 print!("{reasoning}");
485 std::io::Write::flush(&mut std::io::stdout())?;
486 }
487 Err(e) => {
488 if e.to_string().contains("aborted") {
489 println!("\nStream cancelled.");
490 break;
491 }
492 eprintln!("Error: {e}");
493 break;
494 }
495 _ => {}
496 }
497 }
498
499 println!(); Ok(())
502}
503
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Stand-in for a provider's final response type.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)] // set when building the mock, never read by the tests
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a response over a fake provider stream: three text chunks, 100 ms
    /// apart, followed by one final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // `StreamingResult` itself is cfg-gated for wasm, so one binding covers both
        // configurations (the former per-cfg branches were textually identical).
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta { id, content }) => {
                    println!("\nTool Call delta: id={id:?}, content={content:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    print!("{}", reasoning.concat());
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// Draining the whole stream concatenates every text chunk and yields exactly
    /// one final response, which is also stored on the response.
    #[tokio::test]
    async fn test_stream_accumulates_text_and_single_final_response() {
        let mut stream = create_mock_stream();

        let mut final_responses = 0;
        while let Some(chunk) = stream.next().await {
            if matches!(chunk, Ok(StreamedAssistantContent::Final(_))) {
                final_responses += 1;
            }
        }

        assert_eq!(stream.text, "hello 1hello 2hello 3");
        assert_eq!(final_responses, 1);
        assert!(stream.response.is_some());
    }

    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        stream.pause();
        assert!(stream.is_paused());

        stream.resume();
        assert!(!stream.is_paused());
    }
}
613
/// A chunk yielded to callers by [`StreamingCompletionResponse`]'s `Stream` impl.
///
/// `R` is the provider's final response type. The enum uses
/// `#[serde(untagged)]`, so deserialization tries the variants in declaration
/// order; do not reorder them casually.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A fragment of assistant text.
    Text(Text),
    /// A complete tool call.
    ToolCall(ToolCall),
    /// A partial update to the tool call identified by `id`.
    ToolCallDelta {
        id: String,
        content: ToolCallDeltaContent,
    },
    /// A reasoning block.
    Reasoning(Reasoning),
    /// An incremental reasoning fragment.
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },
    /// The provider's final response.
    Final(R),
}
631
632impl<R> StreamedAssistantContent<R>
633where
634 R: Clone + Unpin,
635{
636 pub fn text(text: &str) -> Self {
637 Self::Text(Text {
638 text: text.to_string(),
639 })
640 }
641
642 pub fn final_response(res: R) -> Self {
643 Self::Final(res)
644 }
645}
646
/// User-side content that can appear in a stream (currently only tool results).
///
/// Serialized with `#[serde(untagged)]`, so a value is (de)serialized as the bare
/// inner type.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of a tool call.
    ToolResult(ToolResult),
}
653
654impl StreamedUserContent {
655 pub fn tool_result(tool_result: ToolResult) -> Self {
656 Self::ToolResult(tool_result)
657 }
658}