1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::client::builder::FinalCompletionResponse;
15use crate::completion::{
16 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17 Message, Usage,
18};
19use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction, ToolResult};
20use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
21use futures::stream::{AbortHandle, Abortable};
22use futures::{Stream, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::atomic::AtomicBool;
27use std::task::{Context, Poll};
28use tokio::sync::watch;
29
/// Pause/resume switch backed by a `tokio::sync::watch` channel carrying a `bool`.
///
/// Both halves of the channel are stored so the flag can be written and read
/// from the same value.
pub struct PauseControl {
    /// Sending half of the watch channel; used to publish a new paused state.
    pub(crate) paused_tx: watch::Sender<bool>,
    /// Receiving half of the watch channel; read to obtain the current paused state.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
35
36impl PauseControl {
37 pub fn new() -> Self {
38 let (paused_tx, paused_rx) = watch::channel(false);
39 Self {
40 paused_tx,
41 paused_rx,
42 }
43 }
44
45 pub fn pause(&self) {
46 self.paused_tx.send(true).unwrap();
47 }
48
49 pub fn resume(&self) {
50 self.paused_tx.send(false).unwrap();
51 }
52
53 pub fn is_paused(&self) -> bool {
54 *self.paused_rx.borrow()
55 }
56}
57
58impl Default for PauseControl {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64#[derive(Debug, Clone)]
66pub enum RawStreamingChoice<R>
67where
68 R: Clone,
69{
70 Message(String),
72
73 ToolCall {
75 id: String,
76 call_id: Option<String>,
77 name: String,
78 arguments: serde_json::Value,
79 },
80 ToolCallDelta { id: String, delta: String },
82 Reasoning {
84 id: Option<String>,
85 reasoning: String,
86 signature: Option<String>,
87 },
88
89 FinalResponse(R),
92}
93
/// Boxed, pinned stream of [`RawStreamingChoice`] items or [`CompletionError`]s.
///
/// On non-wasm32 targets the boxed stream is additionally required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Boxed, pinned stream of [`RawStreamingChoice`] items or [`CompletionError`]s.
///
/// On wasm32 the boxed stream carries no `Send` requirement.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
101
102pub struct StreamingCompletionResponse<R>
106where
107 R: Clone + Unpin + GetTokenUsage,
108{
109 pub(crate) inner: Abortable<StreamingResult<R>>,
110 pub(crate) abort_handle: AbortHandle,
111 pub(crate) pause_control: PauseControl,
112 text: String,
113 reasoning: String,
114 tool_calls: Vec<ToolCall>,
115 pub choice: OneOrMany<AssistantContent>,
118 pub response: Option<R>,
121 pub final_response_yielded: AtomicBool,
122}
123
124impl<R> StreamingCompletionResponse<R>
125where
126 R: Clone + Unpin + GetTokenUsage,
127{
128 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
129 let (abort_handle, abort_registration) = AbortHandle::new_pair();
130 let abortable_stream = Abortable::new(inner, abort_registration);
131 let pause_control = PauseControl::new();
132 Self {
133 inner: abortable_stream,
134 abort_handle,
135 pause_control,
136 reasoning: String::new(),
137 text: "".to_string(),
138 tool_calls: vec![],
139 choice: OneOrMany::one(AssistantContent::text("")),
140 response: None,
141 final_response_yielded: AtomicBool::new(false),
142 }
143 }
144
145 pub fn cancel(&self) {
146 self.abort_handle.abort();
147 }
148
149 pub fn pause(&self) {
150 self.pause_control.pause();
151 }
152
153 pub fn resume(&self) {
154 self.pause_control.resume();
155 }
156
157 pub fn is_paused(&self) -> bool {
158 self.pause_control.is_paused()
159 }
160}
161
162impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
163where
164 R: Clone + Unpin + GetTokenUsage,
165{
166 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
167 CompletionResponse {
168 choice: value.choice,
169 usage: Usage::new(), raw_response: value.response,
171 }
172 }
173}
174
175impl<R> Stream for StreamingCompletionResponse<R>
176where
177 R: Clone + Unpin + GetTokenUsage,
178{
179 type Item = Result<StreamedAssistantContent<R>, CompletionError>;
180
181 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182 let stream = self.get_mut();
183
184 if stream.is_paused() {
185 cx.waker().wake_by_ref();
186 return Poll::Pending;
187 }
188
189 match Pin::new(&mut stream.inner).poll_next(cx) {
190 Poll::Pending => Poll::Pending,
191 Poll::Ready(None) => {
192 let mut choice = vec![];
195
196 stream.tool_calls.iter().for_each(|tc| {
197 choice.push(AssistantContent::ToolCall(tc.clone()));
198 });
199
200 if choice.is_empty() || !stream.text.is_empty() {
202 choice.insert(0, AssistantContent::text(stream.text.clone()));
203 }
204
205 stream.choice = OneOrMany::many(choice)
206 .expect("There should be at least one assistant message");
207
208 Poll::Ready(None)
209 }
210 Poll::Ready(Some(Err(err))) => {
211 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
212 {
213 return Poll::Ready(None); }
215 Poll::Ready(Some(Err(err)))
216 }
217 Poll::Ready(Some(Ok(choice))) => match choice {
218 RawStreamingChoice::Message(text) => {
219 stream.text = format!("{}{}", stream.text, text);
222 Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
223 }
224 RawStreamingChoice::ToolCallDelta { id, delta } => {
225 Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
226 id,
227 delta,
228 })))
229 }
230 RawStreamingChoice::Reasoning {
231 id,
232 reasoning,
233 signature,
234 } => {
235 stream.reasoning = format!("{}{}", stream.reasoning, reasoning);
238 Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
239 id,
240 reasoning: vec![reasoning],
241 signature,
242 }))))
243 }
244 RawStreamingChoice::ToolCall {
245 id,
246 name,
247 arguments,
248 call_id,
249 } => {
250 stream.tool_calls.push(ToolCall {
253 id: id.clone(),
254 call_id: call_id.clone(),
255 function: ToolFunction {
256 name: name.clone(),
257 arguments: arguments.clone(),
258 },
259 });
260 if let Some(call_id) = call_id {
261 Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
262 id, call_id, name, arguments,
263 ))))
264 } else {
265 Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
266 id, name, arguments,
267 ))))
268 }
269 }
270 RawStreamingChoice::FinalResponse(response) => {
271 if stream
272 .final_response_yielded
273 .load(std::sync::atomic::Ordering::SeqCst)
274 {
275 stream.poll_next_unpin(cx)
276 } else {
277 stream.response = Some(response.clone());
279 stream
280 .final_response_yielded
281 .store(true, std::sync::atomic::Ordering::SeqCst);
282 let final_response = StreamedAssistantContent::final_response(response);
283 Poll::Ready(Some(Ok(final_response)))
284 }
285 }
286 },
287 }
288 }
289}
290
291pub trait StreamingPrompt<M, R>
293where
294 M: CompletionModel + 'static,
295 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
296 R: Clone + Unpin + GetTokenUsage,
297{
298 fn stream_prompt(
300 &self,
301 prompt: impl Into<Message> + WasmCompatSend,
302 ) -> StreamingPromptRequest<M, ()>;
303}
304
305pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
307where
308 M: CompletionModel + 'static,
309 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
310 R: Clone + Unpin + GetTokenUsage,
311{
312 fn stream_chat(
314 &self,
315 prompt: impl Into<Message> + WasmCompatSend,
316 chat_history: Vec<Message>,
317 ) -> StreamingPromptRequest<M, ()>;
318}
319
320pub trait StreamingCompletion<M: CompletionModel> {
322 fn stream_completion(
324 &self,
325 prompt: impl Into<Message> + WasmCompatSend,
326 chat_history: Vec<Message>,
327 ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
328}
329
330pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
331 pub(crate) inner: StreamingResult<R>,
332}
333
334impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
335 type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
336
337 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
338 let stream = self.get_mut();
339
340 match stream.inner.as_mut().poll_next(cx) {
341 Poll::Pending => Poll::Pending,
342 Poll::Ready(None) => Poll::Ready(None),
343 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
344 Poll::Ready(Some(Ok(chunk))) => match chunk {
345 RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
346 RawStreamingChoice::FinalResponse(FinalCompletionResponse {
347 usage: res.token_usage(),
348 }),
349 ))),
350 RawStreamingChoice::Message(m) => {
351 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
352 }
353 RawStreamingChoice::ToolCallDelta { id, delta } => {
354 Poll::Ready(Some(Ok(RawStreamingChoice::ToolCallDelta { id, delta })))
355 }
356 RawStreamingChoice::Reasoning {
357 id,
358 reasoning,
359 signature,
360 } => Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning {
361 id,
362 reasoning,
363 signature,
364 }))),
365 RawStreamingChoice::ToolCall {
366 id,
367 name,
368 arguments,
369 call_id,
370 } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
371 id,
372 name,
373 arguments,
374 call_id,
375 }))),
376 },
377 }
378 }
379}
380
381pub async fn stream_to_stdout<M>(
384 agent: &'static Agent<M>,
385 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
386) -> Result<(), std::io::Error>
387where
388 M: CompletionModel,
389{
390 let mut is_reasoning = false;
391 print!("Response: ");
392 while let Some(chunk) = stream.next().await {
393 match chunk {
394 Ok(StreamedAssistantContent::Text(text)) => {
395 if is_reasoning {
396 is_reasoning = false;
397 println!("\n---\n");
398 }
399 print!("{}", text.text);
400 std::io::Write::flush(&mut std::io::stdout())?;
401 }
402 Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
403 let res = agent
404 .tool_server_handle
405 .call_tool(
406 &tool_call.function.name,
407 &tool_call.function.arguments.to_string(),
408 )
409 .await
410 .map_err(|x| std::io::Error::other(x.to_string()))?;
411 println!("\nResult: {res}");
412 }
413 Ok(StreamedAssistantContent::Final(res)) => {
414 let json_res = serde_json::to_string_pretty(&res).unwrap();
415 println!();
416 tracing::info!("Final result: {json_res}");
417 }
418 Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
419 if !is_reasoning {
420 is_reasoning = true;
421 println!();
422 println!("Thinking: ");
423 }
424 let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
425
426 print!("{reasoning}");
427 std::io::Write::flush(&mut std::io::stdout())?;
428 }
429 Err(e) => {
430 if e.to_string().contains("aborted") {
431 println!("\nStream cancelled.");
432 break;
433 }
434 eprintln!("Error: {e}");
435 break;
436 }
437 _ => {}
438 }
439 }
440
441 println!(); Ok(())
444}
445
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal final-response type for the mock stream.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        /// Always reports a usage whose `total_tokens` is 15.
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a response that yields three text chunks, each followed by a
    /// 100 ms sleep, and then a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            for text in ["hello 1", "hello 2", "hello 3"] {
                yield Ok(RawStreamingChoice::Message(text.to_string()));
                sleep(Duration::from_millis(100)).await;
            }
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // The boxing expression is the same on every target; only the alias
        // it is checked against differs between wasm32 and other targets.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// After two counted chunks the stream is cancelled and must yield nothing more.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            // `counted` says whether this chunk contributes to `chunk_count`.
            let counted = match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    true
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    true
                }
                Ok(StreamedAssistantContent::ToolCallDelta { delta, .. }) => {
                    println!("\nTool Call delta: {delta:?}");
                    true
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                    false
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    false
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            };

            if counted {
                chunk_count += 1;
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// `pause` and `resume` toggle what `is_paused` reports.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        stream.pause();
        assert!(stream.is_paused());

        stream.resume();
        assert!(!stream.is_paused());
    }
}
551
/// Assistant-side content yielded to consumers of a streaming completion,
/// generic over the final response type `R`.
///
/// Serialized untagged (`#[serde(untagged)]`), i.e. without a variant tag.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of text.
    Text(Text),
    /// A tool call.
    ToolCall(ToolCall),
    /// A `delta` string associated with the tool call identified by `id`.
    ToolCallDelta { id: String, delta: String },
    /// A chunk of reasoning.
    Reasoning(Reasoning),
    /// The final response value.
    Final(R),
}
562
563impl<R> StreamedAssistantContent<R>
564where
565 R: Clone + Unpin,
566{
567 pub fn text(text: &str) -> Self {
568 Self::Text(Text {
569 text: text.to_string(),
570 })
571 }
572
573 pub fn tool_call(
575 id: impl Into<String>,
576 name: impl Into<String>,
577 arguments: serde_json::Value,
578 ) -> Self {
579 Self::ToolCall(ToolCall {
580 id: id.into(),
581 call_id: None,
582 function: ToolFunction {
583 name: name.into(),
584 arguments,
585 },
586 })
587 }
588
589 pub fn tool_call_with_call_id(
590 id: impl Into<String>,
591 call_id: String,
592 name: impl Into<String>,
593 arguments: serde_json::Value,
594 ) -> Self {
595 Self::ToolCall(ToolCall {
596 id: id.into(),
597 call_id: Some(call_id),
598 function: ToolFunction {
599 name: name.into(),
600 arguments,
601 },
602 })
603 }
604
605 pub fn final_response(res: R) -> Self {
606 Self::Final(res)
607 }
608}
609
/// User-side streamed content; the only variant carries a tool result.
///
/// Serialized untagged (`#[serde(untagged)]`), i.e. without a variant tag.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of a tool invocation.
    ToolResult(ToolResult),
}
616
617impl StreamedUserContent {
618 pub fn tool_result(tool_result: ToolResult) -> Self {
619 Self::ToolResult(tool_result)
620 }
621}