1use crate::agent::Agent;
12use crate::completion::{
13 CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, Message,
14};
15use crate::message::{AssistantContent, ToolCall, ToolFunction};
16use crate::OneOrMany;
17use futures::{Stream, StreamExt};
18use std::boxed::Box;
19use std::future::Future;
20use std::pin::Pin;
21use std::task::{Context, Poll};
22
23#[derive(Debug, Clone)]
25pub enum RawStreamingChoice<R: Clone> {
26 Message(String),
28
29 ToolCall {
31 id: String,
32 name: String,
33 arguments: serde_json::Value,
34 },
35
36 FinalResponse(R),
39}
40
41#[cfg(not(target_arch = "wasm32"))]
42pub type StreamingResult<R> =
43 Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
44
45#[cfg(target_arch = "wasm32")]
46pub type StreamingResult<R> =
47 Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
48
49pub struct StreamingCompletionResponse<R: Clone + Unpin> {
53 pub(crate) inner: StreamingResult<R>,
54 text: String,
55 tool_calls: Vec<ToolCall>,
56 pub choice: OneOrMany<AssistantContent>,
59 pub response: Option<R>,
62}
63
64impl<R: Clone + Unpin> StreamingCompletionResponse<R> {
65 pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
66 Self {
67 inner,
68 text: "".to_string(),
69 tool_calls: vec![],
70 choice: OneOrMany::one(AssistantContent::text("")),
71 response: None,
72 }
73 }
74}
75
76impl<R: Clone + Unpin> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>> {
77 fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
78 CompletionResponse {
79 choice: value.choice,
80 raw_response: value.response,
81 }
82 }
83}
84
85impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
86 type Item = Result<AssistantContent, CompletionError>;
87
88 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89 let stream = self.get_mut();
90
91 match stream.inner.as_mut().poll_next(cx) {
92 Poll::Pending => Poll::Pending,
93 Poll::Ready(None) => {
94 let mut choice = vec![];
97
98 stream.tool_calls.iter().for_each(|tc| {
99 choice.push(AssistantContent::ToolCall(tc.clone()));
100 });
101
102 if choice.is_empty() || !stream.text.is_empty() {
104 choice.insert(0, AssistantContent::text(stream.text.clone()));
105 }
106
107 stream.choice = OneOrMany::many(choice)
108 .expect("There should be at least one assistant message");
109
110 Poll::Ready(None)
111 }
112 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
113 Poll::Ready(Some(Ok(choice))) => match choice {
114 RawStreamingChoice::Message(text) => {
115 stream.text = format!("{}{}", stream.text, text.clone());
118 Poll::Ready(Some(Ok(AssistantContent::text(text))))
119 }
120 RawStreamingChoice::ToolCall {
121 id,
122 name,
123 arguments,
124 } => {
125 stream.tool_calls.push(ToolCall {
128 id: id.clone(),
129 function: ToolFunction {
130 name: name.clone(),
131 arguments: arguments.clone(),
132 },
133 });
134 Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, arguments))))
135 }
136 RawStreamingChoice::FinalResponse(response) => {
137 stream.response = Some(response);
139
140 stream.poll_next_unpin(cx)
141 }
142 },
143 }
144 }
145}
146
147pub trait StreamingPrompt<R: Clone + Unpin>: Send + Sync {
149 fn stream_prompt(
151 &self,
152 prompt: impl Into<Message> + Send,
153 ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
154}
155
156pub trait StreamingChat<R: Clone + Unpin>: Send + Sync {
158 fn stream_chat(
160 &self,
161 prompt: impl Into<Message> + Send,
162 chat_history: Vec<Message>,
163 ) -> impl Future<Output = Result<StreamingCompletionResponse<R>, CompletionError>>;
164}
165
166pub trait StreamingCompletion<M: CompletionModel> {
168 fn stream_completion(
170 &self,
171 prompt: impl Into<Message> + Send,
172 chat_history: Vec<Message>,
173 ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
174}
175
176pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
177 pub(crate) inner: StreamingResult<R>,
178}
179
180impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
181 type Item = Result<RawStreamingChoice<()>, CompletionError>;
182
183 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
184 let stream = self.get_mut();
185
186 match stream.inner.as_mut().poll_next(cx) {
187 Poll::Pending => Poll::Pending,
188 Poll::Ready(None) => Poll::Ready(None),
189 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
190 Poll::Ready(Some(Ok(chunk))) => match chunk {
191 RawStreamingChoice::FinalResponse(_) => {
192 Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
193 }
194 RawStreamingChoice::Message(m) => {
195 Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
196 }
197 RawStreamingChoice::ToolCall {
198 id,
199 name,
200 arguments,
201 } => Poll::Ready(Some(Ok(RawStreamingChoice::ToolCall {
202 id,
203 name,
204 arguments,
205 }))),
206 },
207 }
208 }
209}
210
211pub async fn stream_to_stdout<M: CompletionModel>(
213 agent: &Agent<M>,
214 stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
215) -> Result<(), std::io::Error> {
216 print!("Response: ");
217 while let Some(chunk) = stream.next().await {
218 match chunk {
219 Ok(AssistantContent::Text(text)) => {
220 print!("{}", text.text);
221 std::io::Write::flush(&mut std::io::stdout())?;
222 }
223 Ok(AssistantContent::ToolCall(tool_call)) => {
224 let res = agent
225 .tools
226 .call(
227 &tool_call.function.name,
228 tool_call.function.arguments.to_string(),
229 )
230 .await
231 .map_err(|e| std::io::Error::other(e.to_string()))?;
232 println!("\nResult: {res}");
233 }
234 Err(e) => {
235 eprintln!("Error: {e}");
236 break;
237 }
238 }
239 }
240
241 println!(); Ok(())
244}