rig/agent/prompt_request/
streaming.rs1use crate::{
2 OneOrMany,
3 agent::prompt_request::PromptHook,
4 completion::GetTokenUsage,
5 message::{AssistantContent, Reasoning, ToolResultContent, UserContent},
6 streaming::{StreamedAssistantContent, StreamingCompletion},
7};
8use futures::{Stream, StreamExt};
9use serde::{Deserialize, Serialize};
10use std::{pin::Pin, sync::Arc};
11use tokio::sync::RwLock;
12
13use crate::{
14 agent::Agent,
15 completion::{CompletionError, CompletionModel, PromptError},
16 message::{Message, Text},
17 tool::ToolSetError,
18};
19
/// Boxed stream of [`MultiTurnStreamItem`]s produced by a streaming prompt request.
/// On native targets the stream must be `Send` so it can be moved across threads.
#[cfg(not(target_arch = "wasm32"))]
type StreamingResult =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>> + Send>>;

/// Boxed stream of [`MultiTurnStreamItem`]s produced by a streaming prompt request.
/// wasm32 has no threads, so the `Send` bound is dropped there.
#[cfg(target_arch = "wasm32")]
type StreamingResult = Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>>>>;
26
/// A single item yielded by a multi-turn streaming response.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MultiTurnStreamItem {
    /// An incremental chunk of assistant output text.
    Text(Text),
    /// The terminal item: the full response text plus aggregated token usage.
    FinalResponse(FinalResponse),
}
33
/// Terminal payload of a streaming prompt: the complete response text together
/// with token usage aggregated over every completion call in the request.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// The complete text of the final response.
    response: String,
    /// Token usage summed across all turns of the request.
    aggregated_usage: crate::completion::Usage,
}
40
41impl FinalResponse {
42 pub fn empty() -> Self {
43 Self {
44 response: String::new(),
45 aggregated_usage: crate::completion::Usage::new(),
46 }
47 }
48
49 pub fn response(&self) -> &str {
50 &self.response
51 }
52
53 pub fn usage(&self) -> crate::completion::Usage {
54 self.aggregated_usage
55 }
56}
57
58impl MultiTurnStreamItem {
59 pub(crate) fn text(text: &str) -> Self {
60 Self::Text(Text {
61 text: text.to_string(),
62 })
63 }
64
65 pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
66 Self::FinalResponse(FinalResponse {
67 response: response.to_string(),
68 aggregated_usage,
69 })
70 }
71}
72
/// Errors that can surface while driving a multi-turn streaming prompt.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying model completion call failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// Prompt-level failure (e.g. maximum conversation depth exceeded).
    #[error("PromptError: {0}")]
    Prompt(#[from] PromptError),
    /// A tool invocation requested by the model failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
82
/// Builder for a streaming (optionally multi-turn) prompt request against an agent.
///
/// Construct with [`StreamingPromptRequest::new`], configure with the builder
/// methods, then `.await` it (via `IntoFuture`) to obtain the item stream.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send to the model.
    prompt: Message,
    /// Optional pre-existing chat history; the prompt is appended to it.
    chat_history: Option<Vec<Message>>,
    /// Maximum number of additional turns allowed for tool use (0 = single turn).
    max_depth: usize,
    /// The agent that executes the completions and owns the tool set.
    agent: Arc<Agent<M>>,
    /// Optional hook observing completion calls and tool invocations.
    hook: Option<P>,
}
108
109impl<M, P> StreamingPromptRequest<M, P>
110where
111 M: CompletionModel + 'static,
112 <M as CompletionModel>::StreamingResponse: Send + GetTokenUsage,
113 P: PromptHook<M>,
114{
115 pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
117 Self {
118 prompt: prompt.into(),
119 chat_history: None,
120 max_depth: 0,
121 agent,
122 hook: None,
123 }
124 }
125
126 pub fn multi_turn(mut self, depth: usize) -> Self {
129 self.max_depth = depth;
130 self
131 }
132
133 pub fn with_history(mut self, history: Vec<Message>) -> Self {
135 self.chat_history = Some(history);
136 self
137 }
138
139 pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
141 where
142 P2: PromptHook<M>,
143 {
144 StreamingPromptRequest {
145 prompt: self.prompt,
146 chat_history: self.chat_history,
147 max_depth: self.max_depth,
148 agent: self.agent,
149 hook: Some(hook),
150 }
151 }
152
153 #[cfg_attr(feature = "worker", worker::send)]
154 async fn send(self) -> StreamingResult {
155 let agent_name = self.agent.name_owned();
156
157 #[tracing::instrument(skip_all, fields(agent_name = agent_name))]
158 fn inner<M, P>(req: StreamingPromptRequest<M, P>, agent_name: String) -> StreamingResult
159 where
160 M: CompletionModel + 'static,
161 <M as CompletionModel>::StreamingResponse: Send,
162 P: PromptHook<M> + 'static,
163 {
164 let prompt = req.prompt;
165 let agent = req.agent;
166
167 let chat_history = if let Some(mut history) = req.chat_history {
168 history.push(prompt.clone());
169 Arc::new(RwLock::new(history))
170 } else {
171 Arc::new(RwLock::new(vec![prompt.clone()]))
172 };
173
174 let mut current_max_depth = 0;
175 let mut last_prompt_error = String::new();
176
177 let mut last_text_response = String::new();
178 let mut is_text_response = false;
179 let mut max_depth_reached = false;
180
181 let mut aggregated_usage = crate::completion::Usage::new();
182
183 Box::pin(async_stream::stream! {
184 let mut current_prompt = prompt.clone();
185 let mut did_call_tool = false;
186
187 'outer: loop {
188 if current_max_depth > req.max_depth + 1 {
189 last_prompt_error = current_prompt.rag_text().unwrap_or_default();
190 max_depth_reached = true;
191 break;
192 }
193
194 current_max_depth += 1;
195
196 if req.max_depth > 1 {
197 tracing::info!(
198 "Current conversation depth: {}/{}",
199 current_max_depth,
200 req.max_depth
201 );
202 }
203
204 if let Some(ref hook) = req.hook {
205 let reader = chat_history.read().await;
206 let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
207 let chat_history_except_last = reader[..reader.len() - 1].to_vec();
208
209 hook.on_completion_call(&prompt, &chat_history_except_last)
210 .await;
211 }
212
213
214 let mut stream = agent
215 .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
216 .await?
217 .stream()
218 .await?;
219
220 chat_history.write().await.push(current_prompt.clone());
221
222 let mut tool_calls = vec![];
223 let mut tool_results = vec![];
224
225 while let Some(content) = stream.next().await {
226 match content {
227 Ok(StreamedAssistantContent::Text(text)) => {
228 if !is_text_response {
229 last_text_response = String::new();
230 is_text_response = true;
231 }
232 last_text_response.push_str(&text.text);
233 yield Ok(MultiTurnStreamItem::text(&text.text));
234 did_call_tool = false;
235 },
236 Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
237 if let Some(ref hook) = req.hook {
238 hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string()).await;
239 }
240 let tool_result =
241 agent.tools.call(&tool_call.function.name, tool_call.function.arguments.to_string()).await?;
242
243 if let Some(ref hook) = req.hook {
244 hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string())
245 .await;
246 }
247 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
248
249 tool_calls.push(tool_call_msg);
250 tool_results.push((tool_call.id, tool_call.call_id, tool_result));
251
252 did_call_tool = true;
253 },
255 Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })) => {
256 chat_history.write().await.push(rig::message::Message::Assistant {
257 id: None,
258 content: OneOrMany::one(AssistantContent::Reasoning(Reasoning {
259 reasoning: reasoning.clone(), id
260 }))
261 });
262 let text = reasoning.into_iter().collect::<Vec<String>>().join("");
263 yield Ok(MultiTurnStreamItem::text(&text));
264 did_call_tool = false;
265 },
266 Ok(StreamedAssistantContent::Final(final_resp)) => {
267 if is_text_response {
268 if let Some(ref hook) = req.hook {
269 hook.on_stream_completion_response_finish(&prompt, &final_resp).await;
270 }
271 yield Ok(MultiTurnStreamItem::text("\n"));
272 is_text_response = false;
273 }
274 if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
275 }
279 Err(e) => {
280 yield Err(e.into());
281 break 'outer;
282 }
283 }
284 }
285
286 if !tool_calls.is_empty() {
288 chat_history.write().await.push(Message::Assistant {
289 id: None,
290 content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
291 });
292 }
293
294 for (id, call_id, tool_result) in tool_results {
296 if let Some(call_id) = call_id {
297 chat_history.write().await.push(Message::User {
298 content: OneOrMany::one(UserContent::tool_result_with_call_id(
299 &id,
300 call_id.clone(),
301 OneOrMany::one(ToolResultContent::text(&tool_result)),
302 )),
303 });
304 } else {
305 chat_history.write().await.push(Message::User {
306 content: OneOrMany::one(UserContent::tool_result(
307 &id,
308 OneOrMany::one(ToolResultContent::text(&tool_result)),
309 )),
310 });
311 }
312
313 }
314
315 current_prompt = match chat_history.write().await.pop() {
317 Some(prompt) => prompt,
318 None => unreachable!("Chat history should never be empty at this point"),
319 };
320
321 if !did_call_tool {
322 yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
323 break;
324 }
325 }
326
327 if max_depth_reached {
328 yield Err(PromptError::MaxDepthError {
329 max_depth: req.max_depth,
330 chat_history: (*chat_history.read().await).clone(),
331 prompt: last_prompt_error.into(),
332 }.into());
333 }
334
335 })
336 }
337
338 inner(self, agent_name)
339 }
340}
341
342impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
343where
344 M: CompletionModel + 'static,
345 <M as CompletionModel>::StreamingResponse: Send,
346 P: PromptHook<M> + 'static,
347{
348 type Output = StreamingResult; type IntoFuture = Pin<Box<dyn futures::Future<Output = Self::Output> + Send>>;
350
351 fn into_future(self) -> Self::IntoFuture {
352 Box::pin(async move { self.send().await })
354 }
355}
356
357pub async fn stream_to_stdout(
359 stream: &mut StreamingResult,
360) -> Result<FinalResponse, std::io::Error> {
361 let mut final_res = FinalResponse::empty();
362 print!("Response: ");
363 while let Some(content) = stream.next().await {
364 match content {
365 Ok(MultiTurnStreamItem::Text(Text { text })) => {
366 print!("{text}");
367 std::io::Write::flush(&mut std::io::stdout())?;
368 }
369 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
370 final_res = res;
371 }
372 Err(err) => {
373 eprintln!("Error: {err}");
374 }
375 }
376 }
377
378 Ok(final_res)
379}