// rig/agent/prompt_request/streaming.rs
use crate::{
2 OneOrMany,
3 agent::prompt_request::PromptHook,
4 completion::GetTokenUsage,
5 message::{AssistantContent, Reasoning, ToolResultContent, UserContent},
6 streaming::{StreamedAssistantContent, StreamingCompletion},
7};
8use futures::{Stream, StreamExt};
9use serde::{Deserialize, Serialize};
10use std::{pin::Pin, sync::Arc};
11use tokio::sync::RwLock;
12
13use crate::{
14 agent::Agent,
15 completion::{CompletionError, CompletionModel, PromptError},
16 message::{Message, Text},
17 tool::ToolSetError,
18};
19
20#[cfg(not(target_arch = "wasm32"))]
21pub type StreamingResult<R> =
22 Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;
23
24#[cfg(target_arch = "wasm32")]
25pub type StreamingResult<R> =
26 Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
27
28#[derive(Deserialize, Serialize, Debug, Clone)]
29#[serde(tag = "type", rename_all = "camelCase")]
30#[non_exhaustive]
31pub enum MultiTurnStreamItem<R> {
32 StreamItem(StreamedAssistantContent<R>),
33 FinalResponse(FinalResponse),
34}
35
36#[derive(Deserialize, Serialize, Debug, Clone)]
37#[serde(rename_all = "camelCase")]
38pub struct FinalResponse {
39 response: String,
40 aggregated_usage: crate::completion::Usage,
41}
42
43impl FinalResponse {
44 pub fn empty() -> Self {
45 Self {
46 response: String::new(),
47 aggregated_usage: crate::completion::Usage::new(),
48 }
49 }
50
51 pub fn response(&self) -> &str {
52 &self.response
53 }
54
55 pub fn usage(&self) -> crate::completion::Usage {
56 self.aggregated_usage
57 }
58}
59
60impl<R> MultiTurnStreamItem<R> {
61 pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
62 Self::StreamItem(item)
63 }
64
65 pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
66 Self::FinalResponse(FinalResponse {
67 response: response.to_string(),
68 aggregated_usage,
69 })
70 }
71}
72
73#[derive(Debug, thiserror::Error)]
74pub enum StreamingError {
75 #[error("CompletionError: {0}")]
76 Completion(#[from] CompletionError),
77 #[error("PromptError: {0}")]
78 Prompt(#[from] Box<PromptError>),
79 #[error("ToolSetError: {0}")]
80 Tool(#[from] ToolSetError),
81}
82
83pub struct StreamingPromptRequest<M, P>
92where
93 M: CompletionModel,
94 P: PromptHook<M> + 'static,
95{
96 prompt: Message,
98 chat_history: Option<Vec<Message>>,
101 max_depth: usize,
103 agent: Arc<Agent<M>>,
105 hook: Option<P>,
107}
108
109impl<M, P> StreamingPromptRequest<M, P>
110where
111 M: CompletionModel + 'static,
112 <M as CompletionModel>::StreamingResponse: Send + GetTokenUsage,
113 P: PromptHook<M>,
114{
115 pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
117 Self {
118 prompt: prompt.into(),
119 chat_history: None,
120 max_depth: 0,
121 agent,
122 hook: None,
123 }
124 }
125
126 pub fn multi_turn(mut self, depth: usize) -> Self {
129 self.max_depth = depth;
130 self
131 }
132
133 pub fn with_history(mut self, history: Vec<Message>) -> Self {
135 self.chat_history = Some(history);
136 self
137 }
138
139 pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
141 where
142 P2: PromptHook<M>,
143 {
144 StreamingPromptRequest {
145 prompt: self.prompt,
146 chat_history: self.chat_history,
147 max_depth: self.max_depth,
148 agent: self.agent,
149 hook: Some(hook),
150 }
151 }
152
153 #[cfg_attr(feature = "worker", worker::send)]
154 async fn send(self) -> StreamingResult<M::StreamingResponse> {
155 let agent_name = self.agent.name_owned();
156
157 #[tracing::instrument(skip_all, fields(agent_name = agent_name))]
158 fn inner<M, P>(
159 req: StreamingPromptRequest<M, P>,
160 agent_name: String,
161 ) -> StreamingResult<M::StreamingResponse>
162 where
163 M: CompletionModel + 'static,
164 <M as CompletionModel>::StreamingResponse: Send,
165 P: PromptHook<M> + 'static,
166 {
167 let prompt = req.prompt;
168 let agent = req.agent;
169
170 let chat_history = if let Some(mut history) = req.chat_history {
171 history.push(prompt.clone());
172 Arc::new(RwLock::new(history))
173 } else {
174 Arc::new(RwLock::new(vec![prompt.clone()]))
175 };
176
177 let mut current_max_depth = 0;
178 let mut last_prompt_error = String::new();
179
180 let mut last_text_response = String::new();
181 let mut is_text_response = false;
182 let mut max_depth_reached = false;
183
184 let mut aggregated_usage = crate::completion::Usage::new();
185
186 Box::pin(async_stream::stream! {
187 let mut current_prompt = prompt.clone();
188 let mut did_call_tool = false;
189
190 'outer: loop {
191 if current_max_depth > req.max_depth + 1 {
192 last_prompt_error = current_prompt.rag_text().unwrap_or_default();
193 max_depth_reached = true;
194 break;
195 }
196
197 current_max_depth += 1;
198
199 if req.max_depth > 1 {
200 tracing::info!(
201 "Current conversation depth: {}/{}",
202 current_max_depth,
203 req.max_depth
204 );
205 }
206
207 if let Some(ref hook) = req.hook {
208 let reader = chat_history.read().await;
209 let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
210 let chat_history_except_last = reader[..reader.len() - 1].to_vec();
211
212 hook.on_completion_call(&prompt, &chat_history_except_last)
213 .await;
214 }
215
216
217 let mut stream = agent
218 .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
219 .await?
220 .stream()
221 .await?;
222
223 chat_history.write().await.push(current_prompt.clone());
224
225 let mut tool_calls = vec![];
226 let mut tool_results = vec![];
227
228 while let Some(content) = stream.next().await {
229 match content {
230 Ok(StreamedAssistantContent::Text(text)) => {
231 if !is_text_response {
232 last_text_response = String::new();
233 is_text_response = true;
234 }
235 last_text_response.push_str(&text.text);
236 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
237 did_call_tool = false;
238 },
239 Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
240 if let Some(ref hook) = req.hook {
241 hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string()).await;
242 }
243 let tool_result =
244 agent.tools.call(&tool_call.function.name, tool_call.function.arguments.to_string()).await?;
245
246 if let Some(ref hook) = req.hook {
247 hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string())
248 .await;
249 }
250 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
251
252 tool_calls.push(tool_call_msg);
253 tool_results.push((tool_call.id, tool_call.call_id, tool_result));
254
255 did_call_tool = true;
256 },
258 Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })) => {
259 chat_history.write().await.push(rig::message::Message::Assistant {
260 id: None,
261 content: OneOrMany::one(AssistantContent::Reasoning(Reasoning {
262 reasoning: reasoning.clone(), id: id.clone()
263 }))
264 });
265 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })));
266 did_call_tool = false;
267 },
268 Ok(StreamedAssistantContent::Final(final_resp)) => {
269 if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
270 if is_text_response {
271 if let Some(ref hook) = req.hook {
272 hook.on_stream_completion_response_finish(&prompt, &final_resp).await;
273 }
274 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
275 is_text_response = false;
276 }
277 }
278 Err(e) => {
279 yield Err(e.into());
280 break 'outer;
281 }
282 }
283 }
284
285 if !tool_calls.is_empty() {
287 chat_history.write().await.push(Message::Assistant {
288 id: None,
289 content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
290 });
291 }
292
293 for (id, call_id, tool_result) in tool_results {
295 if let Some(call_id) = call_id {
296 chat_history.write().await.push(Message::User {
297 content: OneOrMany::one(UserContent::tool_result_with_call_id(
298 &id,
299 call_id.clone(),
300 OneOrMany::one(ToolResultContent::text(&tool_result)),
301 )),
302 });
303 } else {
304 chat_history.write().await.push(Message::User {
305 content: OneOrMany::one(UserContent::tool_result(
306 &id,
307 OneOrMany::one(ToolResultContent::text(&tool_result)),
308 )),
309 });
310 }
311
312 }
313
314 current_prompt = match chat_history.write().await.pop() {
316 Some(prompt) => prompt,
317 None => unreachable!("Chat history should never be empty at this point"),
318 };
319
320 if !did_call_tool {
321 yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
322 break;
323 }
324 }
325
326 if max_depth_reached {
327 yield Err(Box::new(PromptError::MaxDepthError {
328 max_depth: req.max_depth,
329 chat_history: Box::new((*chat_history.read().await).clone()),
330 prompt: last_prompt_error.into(),
331 }).into());
332 }
333
334 })
335 }
336
337 inner(self, agent_name)
338 }
339}
340
341impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
342where
343 M: CompletionModel + 'static,
344 <M as CompletionModel>::StreamingResponse: Send,
345 P: PromptHook<M> + 'static,
346{
347 type Output = StreamingResult<M::StreamingResponse>; type IntoFuture = Pin<Box<dyn futures::Future<Output = Self::Output> + Send>>;
349
350 fn into_future(self) -> Self::IntoFuture {
351 Box::pin(async move { self.send().await })
353 }
354}
355
356pub async fn stream_to_stdout<R>(
358 stream: &mut StreamingResult<R>,
359) -> Result<FinalResponse, std::io::Error> {
360 let mut final_res = FinalResponse::empty();
361 print!("Response: ");
362 while let Some(content) = stream.next().await {
363 match content {
364 Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text { text }))) => {
365 print!("{text}");
366 std::io::Write::flush(&mut std::io::stdout()).unwrap();
367 }
368 Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Reasoning(
369 Reasoning { reasoning, .. },
370 ))) => {
371 let reasoning = reasoning.join("\n");
372 print!("{reasoning}");
373 std::io::Write::flush(&mut std::io::stdout()).unwrap();
374 }
375 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
376 final_res = res;
377 }
378 Err(err) => {
379 eprintln!("Error: {err}");
380 }
381 _ => {}
382 }
383 }
384
385 Ok(final_res)
386}