1use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::completion::{
14 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, Message,
15};
16use crate::message::{AssistantContent, ToolCall, ToolFunction};
17use futures::stream::{AbortHandle, Abortable};
18use futures::{Stream, StreamExt};
19use std::boxed::Box;
20use std::future::Future;
21use std::pin::Pin;
22use std::task::{Context, Poll};
23
/// A single raw item produced by a streaming completion, before it is turned
/// into an [`AssistantContent`] by [`StreamingCompletionResponse`].
///
/// `R` is the type carried by the [`RawStreamingChoice::FinalResponse`] variant.
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R: Clone> {
    /// A chunk of message text.
    Message(String),

    /// A tool call emitted by the stream.
    ToolCall {
        /// Identifier of the tool call.
        id: String,
        /// Optional additional call identifier.
        /// NOTE(review): presumably provider-specific — confirm against the providers.
        call_id: Option<String>,
        /// Name of the function to invoke (stored in the resulting `ToolFunction`).
        name: String,
        /// Arguments of the function, as a JSON value.
        arguments: serde_json::Value,
    },

    /// The final response object. [`StreamingCompletionResponse`] stores it in its
    /// `response` field instead of yielding it to the consumer.
    FinalResponse(R),
}
42
/// Boxed, pinned stream of raw streaming items (or errors).
///
/// This definition is used on every target except `wasm32` and requires the
/// stream to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
46
/// Boxed, pinned stream of raw streaming items (or errors).
///
/// This definition is used on `wasm32` only and does not require the stream
/// to be `Send`.
#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
50
/// A streaming completion that yields [`AssistantContent`] items while
/// accumulating what it has seen, so that an aggregated result is available
/// once the stream has ended.
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
    /// The underlying raw stream, wrapped so that it can be aborted.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    /// Handle used by `cancel` to abort `inner`.
    pub(crate) abort_handle: AbortHandle,
    /// Concatenation of every text chunk received so far.
    text: String,
    /// Every tool call received so far, in arrival order.
    tool_calls: Vec<ToolCall>,
    /// Aggregated assistant content. Starts as a single empty text item and is
    /// rebuilt from the accumulated text and tool calls when the inner stream ends.
    pub choice: OneOrMany<AssistantContent>,
    /// The final response object. `None` until the inner stream produces a
    /// `RawStreamingChoice::FinalResponse`.
    pub response: Option<R>,
}
66
67impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
68 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
69 let (abort_handle, abort_registration) = AbortHandle::new_pair();
70 let abortable_stream = Abortable::new(inner, abort_registration);
71 Self {
72 inner: abortable_stream,
73 abort_handle,
74 text: "".to_string(),
75 tool_calls: vec![],
76 choice: OneOrMany::one(AssistantContent::text("")),
77 response: None,
78 }
79 }
80
81 pub fn cancel(&self) {
82 self.abort_handle.abort();
83 }
84}
85
86impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
87 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
88 CompletionResponse {
89 choice: value.choice,
90 raw_response: value.response,
91 }
92 }
93}
94
95impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
96 type Item = Result<AssistantContent, CompletionError>;
97
98 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
99 let stream = self.get_mut();
100
101 match Pin::new(&mut stream.inner).poll_next(cx) {
102 Poll::Pending => Poll::Pending,
103 Poll::Ready(None) => {
104 let mut choice = vec![];
107
108 stream.tool_calls.iter().for_each(|tc| {
109 choice.push(AssistantContent::ToolCall(tc.clone()));
110 });
111
112 if choice.is_empty() || !stream.text.is_empty() {
114 choice.insert(0, AssistantContent::text(stream.text.clone()));
115 }
116
117 stream.choice = OneOrMany::many(choice)
118 .expect("There should be at least one assistant message");
119
120 Poll::Ready(None)
121 }
122 Poll::Ready(Some(Err(err))) => {
123 if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
124 {
125 return Poll::Ready(None); }
127 Poll::Ready(Some(Err(err)))
128 }
129 Poll::Ready(Some(Ok(choice))) => match choice {
130 RawStreamingChoice::Message(text) => {
131 stream.text = format!("{}{}", stream.text, text.clone());
134 Poll::Ready(Some(Ok(AssistantContent::text(text))))
135 }
136 RawStreamingChoice::ToolCall {
137 id,
138 name,
139 arguments,
140 call_id,
141 } => {
142 stream.tool_calls.push(ToolCall {
145 id: id.clone(),
146 call_id,
147 function: ToolFunction {
148 name: name.clone(),
149 arguments: arguments.clone(),
150 },
151 });
152 Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, arguments))))
153 }
154 RawStreamingChoice::FinalResponse(response) => {
155 stream.response = Some(response);
157
158 stream.poll_next_unpin(cx)
159 }
160 },
161 }
162 }
163}
164
/// Implemented by types that can answer a single prompt with a streaming
/// completion whose final response type is `R`.
pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
    /// Starts a streaming completion for `prompt`.
    ///
    /// The returned future resolves to a [`StreamingCompletionResponse`] on
    /// success, or to a [`CompletionError`] on failure.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
173
/// Implemented by types that can answer a prompt, given a chat history, with
/// a streaming completion whose final response type is `R`.
pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
    /// Starts a streaming completion for `prompt` with the given `chat_history`.
    ///
    /// The returned future resolves to a [`StreamingCompletionResponse`] on
    /// success, or to a [`CompletionError`] on failure.
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
}
183
/// Implemented by types that can prepare a completion request for the
/// completion model `M` from a prompt and a chat history.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Builds a request for `prompt` and `chat_history`.
    ///
    /// The returned future resolves to a [`CompletionRequestBuilder`] on
    /// success, or to a [`CompletionError`] on failure.
    /// NOTE(review): presumably the caller customizes the builder and then
    /// starts the stream itself — confirm against the implementors.
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}
193
/// Wrapper around a [`StreamingResult`] whose `Stream` implementation yields
/// `RawStreamingChoice<()>`, i.e. with the final response payload erased.
pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
    /// The wrapped raw stream.
    pub(crate) inner: StreamingResult<R>,
}
197
198impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
199 type Item = Result<RawStreamingChoice<()>, CompletionError>;
200
201 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
202 let stream = self.get_mut();
203
204 match stream.inner.as_mut().poll_next(cx) {
205 Poll::Pending => Poll::Pending,
206 Poll::Ready(None) => Poll::Ready(None),
207 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
208 Poll::Ready(Some(Ok(chunk))) => match chunk {
209 RawStreamingChoice::FinalResponse(_) => {
210 Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
211 }
212 RawStreamingChoice::Message(m) => {
213 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
214 }
215 RawStreamingChoice::ToolCall {
216 id,
217 name,
218 arguments,
219 call_id,
220 } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
221 id,
222 name,
223 arguments,
224 call_id,
225 }))),
226 },
227 }
228 }
229}
230
231pub async fn stream_to_stdout<M: CompletionModel>(
233 agent: &Agent<M>,
234 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
235) -> Result<(), std::io::Error> {
236 print!("Response: ");
237 while let Some(chunk) = stream.next().await {
238 match chunk {
239 Ok(AssistantContent::Text(text)) => {
240 print!("{}", text.text);
241 std::io::Write::flush(&mut std::io::stdout())?;
242 }
243 Ok(AssistantContent::ToolCall(tool_call)) => {
244 let res = agent
245 .tools
246 .call(
247 &tool_call.function.name,
248 tool_call.function.arguments.to_string(),
249 )
250 .await
251 .map_err(|e| std::io::Error::other(e.to_string()))?;
252 println!("\nResult: {res}");
253 }
254 Err(e) => {
255 if e.to_string().contains("aborted") {
256 println!("\nStream cancelled.");
257 break;
258 }
259 eprintln!("Error: {e}");
260 break;
261 }
262 }
263 }
264
265 println!(); Ok(())
268}
269
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    /// Minimal stand-in for a provider's final response.
    #[derive(Debug, Clone)]
    pub struct MockResponse {
        // Only constructed by the mock stream, never read.
        #[allow(dead_code)]
        token_count: u32,
    }

    /// Builds a stream that yields three text chunks, pausing after each one,
    /// followed by a final response.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        // The `StreamingResult` alias already differs per target, so a single
        // statement covers both wasm32 and every other target.
        let pinned_stream: StreamingResult<MockResponse> = Box::pin(stream);

        StreamingCompletionResponse::stream(pinned_stream)
    }

    /// Cancelling after two chunks must stop the stream from yielding anything else.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(AssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(AssistantContent::ToolCall(tc)) => {
                    println!("\nTool Call: {tc:?}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        // Guard against a vacuous pass: an early error or an empty stream
        // would leave the loop without ever reaching the cancellation.
        assert_eq!(
            chunk_count, 2,
            "Expected exactly two chunks before cancellation"
        );

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }
}