// rig/agent/prompt_request/streaming.rs
use crate::{
2 OneOrMany,
3 completion::GetTokenUsage,
4 message::{AssistantContent, ToolResultContent, UserContent},
5 streaming::{StreamedAssistantContent, StreamingCompletion},
6};
7use futures::{Stream, StreamExt};
8use serde::{Deserialize, Serialize};
9use std::{pin::Pin, sync::Arc};
10use tokio::sync::RwLock;
11
12use crate::{
13 agent::Agent,
14 completion::{CompletionError, CompletionModel, PromptError},
15 message::{Message, Text},
16 tool::ToolSetError,
17};
18
/// Pinned, boxed stream of [`MultiTurnStreamItem`]s (or [`StreamingError`]s)
/// yielded while driving a streaming prompt; borrows from the agent for `'a`.
type StreamingResult<'a> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>> + 'a>>;
21
/// A single item produced by a multi-turn streaming prompt: either an
/// incremental chunk of assistant text or the terminal [`FinalResponse`].
///
/// Serialized with an internal `type` tag and camelCase variant names.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MultiTurnStreamItem {
    /// An incremental chunk of assistant text.
    Text(Text),
    /// The last item of the stream: full response text plus aggregated usage.
    FinalResponse(FinalResponse),
}
28
/// Terminal payload of a multi-turn stream: the complete response text and
/// the token usage aggregated across every turn of the conversation.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// Full text of the model's final answer (last text-producing turn).
    response: String,
    /// Token usage summed over all turns.
    aggregated_usage: crate::completion::Usage,
}
35
36impl FinalResponse {
37 pub fn empty() -> Self {
38 Self {
39 response: String::new(),
40 aggregated_usage: crate::completion::Usage::new(),
41 }
42 }
43
44 pub fn response(&self) -> &str {
45 &self.response
46 }
47
48 pub fn usage(&self) -> crate::completion::Usage {
49 self.aggregated_usage
50 }
51}
52
53impl MultiTurnStreamItem {
54 pub(crate) fn text(text: &str) -> Self {
55 Self::Text(Text {
56 text: text.to_string(),
57 })
58 }
59
60 pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
61 Self::FinalResponse(FinalResponse {
62 response: response.to_string(),
63 aggregated_usage,
64 })
65 }
66}
67
/// Errors that can surface while driving a streaming prompt: completion
/// failures, prompt-loop failures (e.g. max depth exceeded), or tool
/// invocation failures.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying completion request failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// The prompt loop failed.
    #[error("PromptError: {0}")]
    Prompt(#[from] PromptError),
    /// A tool call failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
77
/// Builder for a streaming (possibly multi-turn) prompt against an [`Agent`].
///
/// Construct with [`StreamingPromptRequest::new`], optionally configure via
/// the `with_*`/`multi_turn` builder methods, then `.await` it (it implements
/// `IntoFuture`) to obtain the item stream.
pub struct StreamingPromptRequest<'a, M>
where
    M: CompletionModel,
{
    /// The prompt message to send to the model.
    prompt: Message,
    /// Optional pre-existing chat history; the prompt is appended to it.
    chat_history: Option<Vec<Message>>,
    /// Maximum number of additional turns (tool-call round-trips) allowed.
    max_depth: usize,
    /// The agent performing the completion.
    agent: &'a Agent<M>,
    /// Optional per-request hook observing completion and tool activity.
    #[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
    #[cfg(feature = "hooks")]
    hook: Option<&'a dyn crate::agent::PromptHook<M>>,
}
104
105impl<'a, M> StreamingPromptRequest<'a, M>
106where
107 M: CompletionModel + 'static,
108 <M as CompletionModel>::StreamingResponse: Send + GetTokenUsage,
109{
110 pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
112 Self {
113 prompt: prompt.into(),
114 chat_history: None,
115 max_depth: 0,
116 agent,
117 #[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
118 #[cfg(feature = "hooks")]
119 hook: None,
120 }
121 }
122
123 pub fn multi_turn(mut self, depth: usize) -> Self {
126 self.max_depth = depth;
127 self
128 }
129
130 pub fn with_history(mut self, history: Vec<Message>) -> Self {
132 self.chat_history = Some(history);
133 self
134 }
135
136 #[cfg_attr(docsrs, doc(cfg(feature = "hooks")))]
138 #[cfg(feature = "hooks")]
139 pub fn with_hook(
140 mut self,
141 hook: &'a dyn crate::agent::PromptHook<M>,
142 ) -> StreamingPromptRequest<'a, M> {
143 self.hook = Some(hook);
144 self
145 }
146
147 #[cfg_attr(feature = "worker", worker::send)]
148 async fn send(self) -> StreamingResult<'a> {
149 let agent_name = self.agent.name_owned();
150
151 #[tracing::instrument(skip_all, fields(agent_name = agent_name))]
152 fn inner<'a, M>(
153 req: StreamingPromptRequest<'a, M>,
154 agent_name: String,
155 ) -> StreamingResult<'a>
156 where
157 M: CompletionModel + 'static,
158 <M as CompletionModel>::StreamingResponse: Send,
159 {
160 let prompt = req.prompt;
161 let agent = req.agent;
162
163 let chat_history = if let Some(mut history) = req.chat_history {
164 history.push(prompt.clone());
165 Arc::new(RwLock::new(history))
166 } else {
167 Arc::new(RwLock::new(vec![prompt.clone()]))
168 };
169
170 let mut current_max_depth = 0;
171 let mut last_prompt_error = String::new();
172
173 let mut last_text_response = String::new();
174 let mut is_text_response = false;
175 let mut max_depth_reached = false;
176
177 let mut aggregated_usage = crate::completion::Usage::new();
178
179 Box::pin(async_stream::stream! {
180 let mut current_prompt = prompt.clone();
181 let mut did_call_tool = false;
182
183 'outer: loop {
184 if current_max_depth > req.max_depth + 1 {
185 last_prompt_error = current_prompt.rag_text().unwrap_or_default();
186 max_depth_reached = true;
187 break;
188 }
189
190 current_max_depth += 1;
191
192 if req.max_depth > 1 {
193 tracing::info!(
194 "Current conversation depth: {}/{}",
195 current_max_depth,
196 req.max_depth
197 );
198 }
199
200 #[cfg(feature = "hooks")]
201 if let Some(hook) = req.hook.as_ref() {
202 let reader = chat_history.read().await;
203 let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
204 let chat_history_except_last = reader[..reader.len() - 1].to_vec();
205
206 hook.on_completion_call(&prompt, &chat_history_except_last)
207 .await;
208 }
209
210
211 let mut stream = agent
212 .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
213 .await?
214 .stream()
215 .await?;
216
217 chat_history.write().await.push(current_prompt.clone());
218
219 let mut tool_calls = vec![];
220 let mut tool_results = vec![];
221
222 while let Some(content) = stream.next().await {
223 match content {
224 Ok(StreamedAssistantContent::Text(text)) => {
225 if !is_text_response {
226 last_text_response = String::new();
227 is_text_response = true;
228 }
229 last_text_response.push_str(&text.text);
230 yield Ok(MultiTurnStreamItem::text(&text.text));
231 did_call_tool = false;
232 },
233 Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
234 #[cfg(feature = "hooks")]
235 if let Some(hook) = req.hook.as_ref() {
236 hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string()).await;
237 }
238 let tool_result =
239 agent.tools.call(&tool_call.function.name, tool_call.function.arguments.to_string()).await?;
240
241 #[cfg(feature = "hooks")]
242 if let Some(hook) = req.hook.as_ref() {
243 hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string())
244 .await;
245 }
246 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
247
248 tool_calls.push(tool_call_msg);
249 tool_results.push((tool_call.id, tool_call.call_id, tool_result));
250
251 did_call_tool = true;
252 },
254 Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, .. })) => {
255 let text = reasoning.into_iter().collect::<Vec<String>>().join("");
256 yield Ok(MultiTurnStreamItem::text(&text));
257 did_call_tool = false;
258 },
259 Ok(StreamedAssistantContent::Final(final_resp)) => {
260 if is_text_response {
261 #[cfg(feature = "hooks")]
262 if let Some(hook) = req.hook.as_ref() {
263 hook.on_stream_completion_response_finish(&prompt, &final_resp).await;
264 }
265 yield Ok(MultiTurnStreamItem::text("\n"));
266 is_text_response = false;
267 }
268 if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
269 }
273 Err(e) => {
274 yield Err(e.into());
275 break 'outer;
276 }
277 }
278 }
279
280 if !tool_calls.is_empty() {
282 chat_history.write().await.push(Message::Assistant {
283 id: None,
284 content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
285 });
286 }
287
288 for (id, call_id, tool_result) in tool_results {
290 if let Some(call_id) = call_id {
291 chat_history.write().await.push(Message::User {
292 content: OneOrMany::one(UserContent::tool_result_with_call_id(
293 &id,
294 call_id.clone(),
295 OneOrMany::one(ToolResultContent::text(&tool_result)),
296 )),
297 });
298 } else {
299 chat_history.write().await.push(Message::User {
300 content: OneOrMany::one(UserContent::tool_result(
301 &id,
302 OneOrMany::one(ToolResultContent::text(&tool_result)),
303 )),
304 });
305 }
306
307 }
308
309 current_prompt = match chat_history.write().await.pop() {
311 Some(prompt) => prompt,
312 None => unreachable!("Chat history should never be empty at this point"),
313 };
314
315 if !did_call_tool {
316 yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
317 break;
318 }
319 }
320
321 if max_depth_reached {
322 yield Err(PromptError::MaxDepthError {
323 max_depth: req.max_depth,
324 chat_history: (*chat_history.read().await).clone(),
325 prompt: last_prompt_error.into(),
326 }.into());
327 }
328
329 })
330 }
331
332 inner(self, agent_name)
333 }
334}
335
336impl<'a, M> IntoFuture for StreamingPromptRequest<'a, M>
337where
338 M: CompletionModel + 'static,
339 <M as CompletionModel>::StreamingResponse: Send,
340{
341 type Output = StreamingResult<'a>; type IntoFuture = Pin<Box<dyn futures::Future<Output = Self::Output> + 'a>>;
343
344 fn into_future(self) -> Self::IntoFuture {
345 Box::pin(async move { self.send().await })
347 }
348}
349
350pub async fn stream_to_stdout(
352 stream: &mut StreamingResult<'_>,
353) -> Result<FinalResponse, std::io::Error> {
354 let mut final_res = FinalResponse::empty();
355 print!("Response: ");
356 while let Some(content) = stream.next().await {
357 match content {
358 Ok(MultiTurnStreamItem::Text(Text { text })) => {
359 print!("{text}");
360 std::io::Write::flush(&mut std::io::stdout())?;
361 }
362 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
363 final_res = res;
364 }
365 Err(err) => {
366 eprintln!("Error: {err}");
367 }
368 }
369 }
370
371 Ok(final_res)
372}