1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::streaming::StreamingPromptRequest;
14use crate::client::builder::FinalCompletionResponse;
15use crate::completion::{
16 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17 Message, Usage,
18};
19use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction};
20use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
21use futures::stream::{AbortHandle, Abortable};
22use futures::{Stream, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::future::Future;
25use std::pin::Pin;
26use std::sync::atomic::AtomicBool;
27use std::task::{Context, Poll};
28use tokio::sync::watch;
29
/// Shared pause switch for a streaming completion response.
///
/// Backed by a tokio `watch` channel: `pause`/`resume` publish the flag
/// through the sender, and the retained receiver lets `is_paused` read the
/// latest value at any time.
pub struct PauseControl {
    // Writer half: publishes the new paused state.
    pub(crate) paused_tx: watch::Sender<bool>,
    // Reader half: `is_paused` borrows the latest value. Holding a receiver
    // here also keeps the channel open for the sender's lifetime.
    pub(crate) paused_rx: watch::Receiver<bool>,
}
35
36impl PauseControl {
37 pub fn new() -> Self {
38 let (paused_tx, paused_rx) = watch::channel(false);
39 Self {
40 paused_tx,
41 paused_rx,
42 }
43 }
44
45 pub fn pause(&self) {
46 self.paused_tx.send(true).unwrap();
47 }
48
49 pub fn resume(&self) {
50 self.paused_tx.send(false).unwrap();
51 }
52
53 pub fn is_paused(&self) -> bool {
54 *self.paused_rx.borrow()
55 }
56}
57
impl Default for PauseControl {
    /// Equivalent to [`PauseControl::new`]: starts in the un-paused state.
    fn default() -> Self {
        Self::new()
    }
}
63
/// A single raw chunk produced by a provider's streaming implementation,
/// before it is aggregated into [`StreamedAssistantContent`].
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A fragment of assistant text.
    Message(String),

    /// A tool invocation requested by the model.
    ToolCall {
        id: String,
        /// Provider-specific secondary identifier, when one exists.
        call_id: Option<String>,
        name: String,
        arguments: serde_json::Value,
    },
    /// A fragment of the model's reasoning/thinking output.
    Reasoning {
        id: Option<String>,
        reasoning: String,
    },

    /// The provider's final, provider-specific response payload.
    FinalResponse(R),
}
90
/// Boxed stream of raw streaming choices; `Send` is required on native
/// targets so the stream can cross task boundaries.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// Same alias without the `Send` bound for wasm targets.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
98
/// A cancellable, pausable stream of assistant content that also accumulates
/// the full text, reasoning, and tool calls as chunks arrive.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    // The raw provider stream, wrapped so `cancel` can abort it.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    pub(crate) abort_handle: AbortHandle,
    pub(crate) pause_control: PauseControl,
    // Running concatenation of all `Message` chunks seen so far.
    text: String,
    // Running concatenation of all `Reasoning` chunks seen so far.
    reasoning: String,
    // Every tool call emitted by the stream, in arrival order.
    tool_calls: Vec<ToolCall>,
    /// The aggregated assistant content; populated when the stream finishes.
    pub choice: OneOrMany<AssistantContent>,
    /// The provider's final response, if one was received.
    pub response: Option<R>,
    /// Guards against yielding more than one `Final` item to consumers.
    pub final_response_yielded: AtomicBool,
}
120
121impl<R> StreamingCompletionResponse<R>
122where
123 R: Clone + Unpin + GetTokenUsage,
124{
125 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
126 let (abort_handle, abort_registration) = AbortHandle::new_pair();
127 let abortable_stream = Abortable::new(inner, abort_registration);
128 let pause_control = PauseControl::new();
129 Self {
130 inner: abortable_stream,
131 abort_handle,
132 pause_control,
133 reasoning: String::new(),
134 text: "".to_string(),
135 tool_calls: vec![],
136 choice: OneOrMany::one(AssistantContent::text("")),
137 response: None,
138 final_response_yielded: AtomicBool::new(false),
139 }
140 }
141
142 pub fn cancel(&self) {
143 self.abort_handle.abort();
144 }
145
146 pub fn pause(&self) {
147 self.pause_control.pause();
148 }
149
150 pub fn resume(&self) {
151 self.pause_control.resume();
152 }
153
154 pub fn is_paused(&self) -> bool {
155 self.pause_control.is_paused()
156 }
157}
158
159impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
160where
161 R: Clone + Unpin + GetTokenUsage,
162{
163 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
164 CompletionResponse {
165 choice: value.choice,
166 usage: Usage::new(), raw_response: value.response,
168 }
169 }
170}
171
impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    /// Drives the inner provider stream, accumulating text, reasoning, and
    /// tool calls along the way, and converting raw chunks into
    /// [`StreamedAssistantContent`] items for the consumer.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        // While paused, stay Pending but immediately re-schedule ourselves.
        // NOTE(review): this is effectively a busy-wait loop while paused —
        // consider awaiting the watch channel's change notification instead.
        if stream.is_paused() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // Stream exhausted: assemble the aggregated `choice` from the
                // collected tool calls plus the concatenated text.
                let mut choice = vec![];

                stream.tool_calls.iter().for_each(|tc| {
                    choice.push(AssistantContent::ToolCall(tc.clone()));
                });

                // Prepend the text when there is any, or when there were no
                // tool calls at all (so `choice` is never empty — possibly an
                // empty text entry).
                if choice.is_empty() || !stream.text.is_empty() {
                    choice.insert(0, AssistantContent::text(stream.text.clone()));
                }

                stream.choice = OneOrMany::many(choice)
                    .expect("There should be at least one assistant message");

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                // An abort (via `cancel`) surfaces as a provider error whose
                // message contains "aborted"; treat it as a clean end-of-stream.
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None); }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    // Accumulate the full text; yield only the new fragment.
                    stream.text = format!("{}{}", stream.text, text.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::Reasoning { id, reasoning } => {
                    // Accumulate reasoning; note the yielded item carries the
                    // ENTIRE reasoning so far, not just the new fragment.
                    stream.reasoning = format!("{}{}", stream.reasoning, reasoning.clone());
                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
                        id,
                        reasoning: vec![stream.reasoning.clone()],
                    }))))
                }
                RawStreamingChoice::ToolCall {
                    id,
                    name,
                    arguments,
                    call_id,
                } => {
                    // Record the call for the final aggregated `choice`, then
                    // forward it to the consumer.
                    stream.tool_calls.push(ToolCall {
                        id: id.clone(),
                        call_id: call_id.clone(),
                        function: ToolFunction {
                            name: name.clone(),
                            arguments: arguments.clone(),
                        },
                    });
                    if let Some(call_id) = call_id {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call_with_call_id(
                            id, call_id, name, arguments,
                        ))))
                    } else {
                        Poll::Ready(Some(Ok(StreamedAssistantContent::tool_call(
                            id, name, arguments,
                        ))))
                    }
                }
                RawStreamingChoice::FinalResponse(response) => {
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        // A final response was already yielded; skip this one
                        // by polling for the next chunk instead.
                        stream.poll_next_unpin(cx)
                    } else {
                        // First final response: store it and yield it once.
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
            },
        }
    }
}
276
/// Entry point for one-shot streaming prompts (no chat history).
///
/// NOTE(review): the `R` type parameter is not referenced by the method
/// signature — presumably it constrains implementors; confirm it is needed.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a streaming request for `prompt` against the underlying model.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()>;
}
290
/// Entry point for streaming prompts that carry prior chat history.
///
/// NOTE(review): as with [`StreamingPrompt`], `R` does not appear in the
/// method signature — confirm the parameter is required.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// Builds a streaming request for `prompt` preceded by `chat_history`.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()>;
}
305
/// Low-level hook: produce a [`CompletionRequestBuilder`] for a streaming
/// completion, letting the caller customize the request before sending.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Builds (but does not send) a completion request from `prompt` and
    /// `chat_history`.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
315
/// Adapter that erases a provider-specific final response type `R` down to
/// [`FinalCompletionResponse`] (token usage only); see its `Stream` impl.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    pub(crate) inner: StreamingResult<R>,
}
319
320impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
321 type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
322
323 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
324 let stream = self.get_mut();
325
326 match stream.inner.as_mut().poll_next(cx) {
327 Poll::Pending => Poll::Pending,
328 Poll::Ready(None) => Poll::Ready(None),
329 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
330 Poll::Ready(Some(Ok(chunk))) => match chunk {
331 RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
332 RawStreamingChoice::FinalResponse(FinalCompletionResponse {
333 usage: res.token_usage(),
334 }),
335 ))),
336 RawStreamingChoice::Message(m) => {
337 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
338 }
339 RawStreamingChoice::Reasoning { id, reasoning } => {
340 Poll::Ready(Some(Ok(RawStreamingChoice::Reasoning { id, reasoning })))
341 }
342 RawStreamingChoice::ToolCall {
343 id,
344 name,
345 arguments,
346 call_id,
347 } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
348 id,
349 name,
350 arguments,
351 call_id,
352 }))),
353 },
354 }
355 }
356}
357
358pub async fn stream_to_stdout<M>(
360 agent: &'static Agent<M>,
361 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
362) -> Result<(), std::io::Error>
363where
364 M: CompletionModel,
365{
366 let mut is_reasoning = false;
367 print!("Response: ");
368 while let Some(chunk) = stream.next().await {
369 match chunk {
370 Ok(StreamedAssistantContent::Text(text)) => {
371 if is_reasoning {
372 is_reasoning = false;
373 println!("\n---\n");
374 }
375 print!("{}", text.text);
376 std::io::Write::flush(&mut std::io::stdout())?;
377 }
378 Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
379 let res = agent
380 .tool_server_handle
381 .call_tool(
382 &tool_call.function.name,
383 &tool_call.function.arguments.to_string(),
384 )
385 .await
386 .map_err(|x| std::io::Error::other(x.to_string()))?;
387 println!("\nResult: {res}");
388 }
389 Ok(StreamedAssistantContent::Final(res)) => {
390 let json_res = serde_json::to_string_pretty(&res).unwrap();
391 println!();
392 tracing::info!("Final result: {json_res}");
393 }
394 Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
395 if !is_reasoning {
396 is_reasoning = true;
397 println!();
398 println!("Thinking: ");
399 }
400 let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
401
402 print!("{reasoning}");
403 std::io::Write::flush(&mut std::io::stdout())?;
404 }
405 Err(e) => {
406 if e.to_string().contains("aborted") {
407 println!("\nStream cancelled.");
408 break;
409 }
410 eprintln!("Error: {e}");
411 break;
412 }
413 }
414 }
415
416 println!(); Ok(())
419}
420
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal stand-in for a provider's final streaming response.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    /// Builds a streaming response yielding three text chunks (with short
    /// delays between them) followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // The previous wasm32/non-wasm32 cfg branches were byte-identical, so
        // a single unconditional binding covers both targets.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// After `cancel()`, the stream must yield no further chunks.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, .. })) => {
                    let reasoning = reasoning.into_iter().collect::<Vec<String>>().join("");
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel after two content chunks to exercise mid-stream abort.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    /// Pause and resume must be reflected by `is_paused`.
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        stream.pause();
        assert!(stream.is_paused());

        stream.resume();
        assert!(!stream.is_paused());
    }
}
522
/// A single item yielded to consumers of [`StreamingCompletionResponse`]:
/// either incremental content or the provider's final response payload.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// An incremental text fragment.
    Text(Text),
    /// A complete tool call requested by the model.
    ToolCall(ToolCall),
    /// Reasoning content (the accumulated reasoning so far).
    Reasoning(Reasoning),
    /// The provider-specific final response.
    Final(R),
}
532
533impl<R> StreamedAssistantContent<R>
534where
535 R: Clone + Unpin,
536{
537 pub fn text(text: &str) -> Self {
538 Self::Text(Text {
539 text: text.to_string(),
540 })
541 }
542
543 pub fn tool_call(
545 id: impl Into<String>,
546 name: impl Into<String>,
547 arguments: serde_json::Value,
548 ) -> Self {
549 Self::ToolCall(ToolCall {
550 id: id.into(),
551 call_id: None,
552 function: ToolFunction {
553 name: name.into(),
554 arguments,
555 },
556 })
557 }
558
559 pub fn tool_call_with_call_id(
560 id: impl Into<String>,
561 call_id: String,
562 name: impl Into<String>,
563 arguments: serde_json::Value,
564 ) -> Self {
565 Self::ToolCall(ToolCall {
566 id: id.into(),
567 call_id: Some(call_id),
568 function: ToolFunction {
569 name: name.into(),
570 arguments,
571 },
572 })
573 }
574
575 pub fn final_response(res: R) -> Self {
576 Self::Final(res)
577 }
578}