use crate::OneOrMany;
use crate::agent::Agent;
use crate::agent::prompt_request::streaming::StreamingPromptRequest;
use crate::client::FinalCompletionResponse;
use crate::completion::{
    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
    Message, Usage,
};
use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction, ToolResult};
use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
use futures::stream::{AbortHandle, Abortable};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::task::{Context, Poll};
use tokio::sync::watch;
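
/// Shared pause switch for a [`StreamingCompletionResponse`], backed by a
/// [`tokio::sync::watch`] channel. While the flag is `true`, the response's
/// `Stream` implementation returns `Poll::Pending` instead of polling the
/// underlying provider stream.
///
/// A minimal sketch of how the control behaves on its own (illustrative only;
/// in practice it is driven through `StreamingCompletionResponse::pause` and
/// `resume`):
///
/// ```rust,ignore
/// let control = PauseControl::new();
/// assert!(!control.is_paused());
///
/// control.pause();
/// assert!(control.is_paused());
///
/// control.resume();
/// assert!(!control.is_paused());
/// ```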
pub struct PauseControl {
    pub(crate) paused_tx: watch::Sender<bool>,
    pub(crate) paused_rx: watch::Receiver<bool>,
}

impl PauseControl {
    pub fn new() -> Self {
        let (paused_tx, paused_rx) = watch::channel(false);
        Self {
            paused_tx,
            paused_rx,
        }
    }

    pub fn pause(&self) {
        // The receiver half is stored on `self`, so the channel cannot be closed
        // while this control is alive and `send` cannot fail here.
        self.paused_tx.send(true).unwrap();
    }

    pub fn resume(&self) {
        self.paused_tx.send(false).unwrap();
    }

    pub fn is_paused(&self) -> bool {
        *self.paused_rx.borrow()
    }
}

impl Default for PauseControl {
    fn default() -> Self {
        Self::new()
    }
}
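
/// A single chunk as emitted by a provider-specific streaming adapter, before it
/// is aggregated by a [`StreamingCompletionResponse`].
///
/// Provider integrations typically build a [`StreamingResult`] that yields these
/// choices. A minimal sketch using `async_stream` (mirroring the mock stream in
/// this module's tests; the provider wiring and `my_provider_response` value are
/// assumptions, not shown here):
///
/// ```rust,ignore
/// let raw = async_stream::stream! {
///     yield Ok(RawStreamingChoice::Message("Hello ".to_string()));
///     yield Ok(RawStreamingChoice::Message("world!".to_string()));
///     // A `FinalResponse` is expected once the provider is done; any later
///     // final responses are ignored by the consumer.
///     yield Ok(RawStreamingChoice::FinalResponse(my_provider_response));
/// };
/// let response = StreamingCompletionResponse::stream(Box::pin(raw));
/// ```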
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A chunk of streamed message text.
    Message(String),

    /// A fully assembled tool call.
    ToolCall(RawStreamingToolCall),
    /// A partial (delta) update to a tool call's arguments.
    ToolCallDelta { id: String, delta: String },
    /// A complete reasoning block.
    Reasoning {
        id: Option<String>,
        reasoning: String,
        signature: Option<String>,
    },
    /// A partial (delta) update to a reasoning block.
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },

    /// The provider's final response, emitted once the stream has finished.
    FinalResponse(R),
}
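
/// Builder-style representation of a streamed tool call, convertible into a
/// [`ToolCall`] via `From`.
///
/// A hedged sketch of constructing one by hand (providers normally assemble this
/// from accumulated deltas; the id, name, and arguments below are made up):
///
/// ```rust,ignore
/// let raw = RawStreamingToolCall::new(
///     "call_123".to_string(),
///     "add".to_string(),
///     serde_json::json!({ "x": 1, "y": 2 }),
/// )
/// .with_call_id("resp_abc".to_string())
/// .with_signature(None)
/// .with_additional_params(None);
///
/// let tool_call: ToolCall = raw.into();
/// assert_eq!(tool_call.function.name, "add");
/// ```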
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    pub id: String,
    pub call_id: Option<String>,
    pub name: String,
    pub arguments: serde_json::Value,
    pub signature: Option<String>,
    pub additional_params: Option<serde_json::Value>,
}

impl RawStreamingToolCall {
    pub fn empty() -> Self {
        Self {
            id: String::new(),
            call_id: None,
            name: String::new(),
            arguments: serde_json::Value::Null,
            signature: None,
            additional_params: None,
        }
    }

    pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
        Self {
            id,
            call_id: None,
            name,
            arguments,
            signature: None,
            additional_params: None,
        }
    }

    pub fn with_call_id(mut self, call_id: String) -> Self {
        self.call_id = Some(call_id);
        self
    }

    pub fn with_signature(mut self, signature: Option<String>) -> Self {
        self.signature = signature;
        self
    }

    pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }
}

impl From<RawStreamingToolCall> for ToolCall {
    fn from(tool_call: RawStreamingToolCall) -> Self {
        ToolCall {
            id: tool_call.id,
            call_id: tool_call.call_id,
            function: ToolFunction {
                name: tool_call.name,
                arguments: tool_call.arguments,
            },
            signature: tool_call.signature,
            additional_params: tool_call.additional_params,
        }
    }
}

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
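
/// Wrapper around a provider [`StreamingResult`] that implements [`Stream`],
/// supports cancellation and pause/resume, and accumulates text and tool calls
/// so that `choice` and `response` are populated once the stream has finished.
///
/// A minimal consumption sketch (assumes an already-built `raw_stream`, e.g.
/// from a provider client or the mock used in this module's tests):
///
/// ```rust,ignore
/// use futures::StreamExt;
///
/// let mut response = StreamingCompletionResponse::stream(raw_stream);
/// while let Some(chunk) = response.next().await {
///     match chunk {
///         Ok(StreamedAssistantContent::Text(text)) => print!("{}", text.text),
///         Ok(StreamedAssistantContent::ToolCall(tc)) => println!("tool call: {}", tc.function.name),
///         Ok(StreamedAssistantContent::Final(res)) => println!("final: {res:?}"),
///         Ok(_) => {}
///         Err(e) => eprintln!("stream error: {e}"),
///     }
/// }
/// // After completion, the aggregated content is available on the wrapper.
/// let aggregated = response.choice.clone();
/// ```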
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    pub(crate) inner: Abortable<StreamingResult<R>>,
    pub(crate) abort_handle: AbortHandle,
    pub(crate) pause_control: PauseControl,
    text: String,
    reasoning: String,
    tool_calls: Vec<ToolCall>,
    /// The aggregated assistant content (text and tool calls), populated once
    /// the underlying stream has finished.
    pub choice: OneOrMany<AssistantContent>,
    /// The provider-specific final response, if one was received.
    pub response: Option<R>,
    /// Whether the final response has already been yielded to the consumer.
    pub final_response_yielded: AtomicBool,
}

impl<R> StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        let abortable_stream = Abortable::new(inner, abort_registration);
        let pause_control = PauseControl::new();
        Self {
            inner: abortable_stream,
            abort_handle,
            pause_control,
            reasoning: String::new(),
            text: "".to_string(),
            tool_calls: vec![],
            choice: OneOrMany::one(AssistantContent::text("")),
            response: None,
            final_response_yielded: AtomicBool::new(false),
        }
    }

    pub fn cancel(&self) {
        self.abort_handle.abort();
    }

    pub fn pause(&self) {
        self.pause_control.pause();
    }

    pub fn resume(&self) {
        self.pause_control.resume();
    }

    pub fn is_paused(&self) -> bool {
        self.pause_control.is_paused()
    }
}

impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
where
    R: Clone + Unpin + GetTokenUsage,
{
    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
        CompletionResponse {
            choice: value.choice,
            usage: Usage::new(),
            raw_response: value.response,
        }
    }
}

impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        if stream.is_paused() {
            // While paused, keep the task scheduled so polling resumes promptly
            // after `resume()` is called.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // The underlying stream has finished: aggregate the collected
                // tool calls and text into `choice` before signalling the end.
                let mut choice = vec![];

                stream.tool_calls.iter().for_each(|tc| {
                    choice.push(AssistantContent::ToolCall(tc.clone()));
                });

                if choice.is_empty() || !stream.text.is_empty() {
                    choice.insert(0, AssistantContent::text(stream.text.clone()));
                }

                stream.choice = OneOrMany::many(choice)
                    .expect("There should be at least one assistant message");

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                // Treat an abort (triggered by `cancel`) as a normal end of stream.
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None);
                }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    // Accumulate text deltas so the full message is available in
                    // `choice` once the stream completes.
                    stream.text = format!("{}{}", stream.text, text);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::ToolCallDelta { id, delta } => {
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
                        id,
                        delta,
                    })))
                }
                RawStreamingChoice::Reasoning {
                    id,
                    reasoning,
                    signature,
                } => Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
                    id,
                    reasoning: vec![reasoning],
                    signature,
                })))),
                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
                        id,
                        reasoning,
                    })))
                }
                RawStreamingChoice::ToolCall(tool_call) => {
                    let tool_call: ToolCall = tool_call.into();
                    stream.tool_calls.push(tool_call.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall(tool_call))))
                }
                RawStreamingChoice::FinalResponse(response) => {
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        // A final response was already yielded; skip this one.
                        stream.poll_next_unpin(cx)
                    } else {
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
            },
        }
    }
}
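
/// Trait for anything (typically an [`Agent`]) that can stream a completion for
/// a single prompt, yielding a [`StreamingPromptRequest`] that drives the stream.
///
/// A hedged usage sketch, assuming `agent` implements this trait and that the
/// returned request is awaited into a streamable response (as in the crate's
/// streaming examples):
///
/// ```rust,ignore
/// let mut stream = agent.stream_prompt("What is the capital of France?").await;
/// while let Some(chunk) = stream.next().await {
///     println!("{chunk:?}");
/// }
/// ```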
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()>;
}

pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}

pub trait StreamingCompletion<M: CompletionModel> {
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}

/// Adapter that erases a provider-specific streaming response, forwarding every
/// chunk unchanged but reducing the final response to a [`FinalCompletionResponse`]
/// carrying only token usage.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    pub(crate) inner: StreamingResult<R>,
}

impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
    type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        match stream.inner.as_mut().poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
            Poll::Ready(Some(Ok(chunk))) => match chunk {
                RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
                    RawStreamingChoice::FinalResponse(FinalCompletionResponse {
                        usage: res.token_usage(),
                    }),
                ))),
                RawStreamingChoice::Message(m) => {
                    Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
                }
                RawStreamingChoice::ToolCallDelta { id, delta } => {
                    Poll::Ready(Some(Ok(RawStreamingChoice::ToolCallDelta { id, delta })))
                }
                RawStreamingChoice::Reasoning {
                    id,
                    reasoning,
                    signature,
                } => Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning {
                    id,
                    reasoning,
                    signature,
                }))),
                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                    Poll::Ready(Some(Ok(RawStreamingChoice::ReasoningDelta {
                        id,
                        reasoning,
                    })))
                }
                RawStreamingChoice::ToolCall(tool_call) => {
                    Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall(tool_call))))
                }
            },
        }
    }
}
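
/// Helper that drains a streaming response to stdout, printing text and
/// reasoning as they arrive and executing tool calls through the agent's tool
/// server handle.
///
/// A hedged usage sketch, assuming an agent with a `'static` lifetime (e.g. a
/// leaked or globally stored agent) and a stream obtained via one of the
/// streaming traits above:
///
/// ```rust,ignore
/// let mut stream = agent.stream_prompt("Please calculate 2 + 5").await;
/// stream_to_stdout(agent, &mut stream).await?;
/// ```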
pub async fn stream_to_stdout<M>(
    agent: &'static Agent<M>,
    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error>
where
    M: CompletionModel,
{
    let mut is_reasoning = false;
    print!("Response: ");
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(StreamedAssistantContent::Text(text)) => {
                if is_reasoning {
                    is_reasoning = false;
                    println!("\n---\n");
                }
                print!("{}", text.text);
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                let res = agent
                    .tool_server_handle
                    .call_tool(
                        &tool_call.function.name,
                        &tool_call.function.arguments.to_string(),
                    )
                    .await
                    .map_err(|x| std::io::Error::other(x.to_string()))?;
                println!("\nResult: {res}");
            }
            Ok(StreamedAssistantContent::Final(res)) => {
                let json_res = serde_json::to_string_pretty(&res).unwrap();
                println!();
                tracing::info!("Final result: {json_res}");
            }
            Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                if !is_reasoning {
                    is_reasoning = true;
                    println!();
                    println!("Thinking: ");
                }
                let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");

                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Err(e) => {
                if e.to_string().contains("aborted") {
                    println!("\nStream cancelled.");
                    break;
                }
                eprintln!("Error: {e}");
                break;
            }
            _ => {}
        }
    }

    println!();

    Ok(())
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);
        #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta { delta, .. }) => {
                    println!("\nTool Call delta: {delta:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        stream.pause();
        assert!(stream.is_paused());

        stream.resume();
        assert!(!stream.is_paused());
    }
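
    // Illustrative test added for documentation purposes: drives the mock stream
    // to completion and checks that the wrapper accumulates streamed text and
    // stores the final response, mirroring the aggregation logic in the `Stream`
    // impl above. The expected concatenated string comes from `create_mock_stream`.
    #[tokio::test]
    async fn test_stream_aggregation_after_completion() {
        let mut stream = create_mock_stream();

        let mut saw_final = false;
        while let Some(chunk) = stream.next().await {
            if matches!(chunk, Ok(StreamedAssistantContent::Final(_))) {
                saw_final = true;
            }
        }

        assert!(saw_final, "expected a final response chunk");
        assert!(stream.response.is_some(), "final response should be stored");
        // The accumulated text is what gets folded into `choice` when the
        // underlying stream finishes.
        assert_eq!(stream.text, "hello 1hello 2hello 3");
    }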
}
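
/// A chunk of assistant content as surfaced to consumers of a
/// [`StreamingCompletionResponse`].
///
/// A small matching sketch covering every variant (illustrative; `chunk` is
/// assumed to come from iterating a streaming response, as in the tests above):
///
/// ```rust,ignore
/// match chunk {
///     StreamedAssistantContent::Text(text) => print!("{}", text.text),
///     StreamedAssistantContent::ToolCall(tc) => println!("tool: {}", tc.function.name),
///     StreamedAssistantContent::ToolCallDelta { id, delta } => println!("{id}: {delta}"),
///     StreamedAssistantContent::Reasoning(reasoning) => println!("{reasoning:?}"),
///     StreamedAssistantContent::ReasoningDelta { reasoning, .. } => print!("{reasoning}"),
///     StreamedAssistantContent::Final(res) => println!("final: {res:?}"),
/// }
/// ```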
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    Text(Text),
    ToolCall(ToolCall),
    ToolCallDelta {
        id: String,
        delta: String,
    },
    Reasoning(Reasoning),
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },
    Final(R),
}

impl<R> StreamedAssistantContent<R>
where
    R: Clone + Unpin,
{
    pub fn text(text: &str) -> Self {
        Self::Text(Text {
            text: text.to_string(),
        })
    }

    pub fn final_response(res: R) -> Self {
        Self::Final(res)
    }
}

/// User-side content that can be produced while streaming, currently limited to
/// tool results.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    ToolResult(ToolResult),
}

impl StreamedUserContent {
    pub fn tool_result(tool_result: ToolResult) -> Self {
        Self::ToolResult(tool_result)
    }
}