rig/agent/prompt_request/streaming.rs

use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

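/// A single item yielded by the multi-turn prompt stream.
///
/// The sketch below shows one way a caller might consume the stream; it is illustrative
/// only and assumes a `stream` value obtained by awaiting a [`StreamingPromptRequest`].
///
/// ```rust,ignore
/// use futures::StreamExt;
///
/// // `stream` is assumed to be a `StreamingResult` returned by awaiting the request.
/// while let Some(item) = stream.next().await {
///     match item? {
///         MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text)) => {
///             print!("{}", text.text);
///         }
///         MultiTurnStreamItem::FinalResponse(final_response) => {
///             println!("\nToken usage: {:?}", final_response.usage());
///         }
///         _ => {}
///     }
/// }
/// ```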
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed assistant content item.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// A streamed user content item (mostly for tool results).
    StreamUserItem(StreamedUserContent),
    /// The final result from the stream.
    FinalResponse(FinalResponse),
}

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, make sure you use the `.multi_turn()`
/// method to allow additional turns; by default it is 0, meaning only one tool round-trip.
/// Otherwise, awaiting the request (which sends the prompt) can return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
/// back to back.
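///
/// # Example
///
/// A minimal sketch of a multi-turn streaming request. It assumes an [`Agent`] has already
/// been built from a provider client and that the agent exposes a `stream_prompt` method
/// returning this request type; adapt it to your own setup.
///
/// ```rust,ignore
/// use futures::StreamExt;
///
/// // `agent` is assumed to be an already-built `Agent` (illustrative only).
/// let mut stream = agent
///     .stream_prompt("What tools do you have access to?")
///     .multi_turn(5)
///     .await;
///
/// while let Some(item) = stream.next().await {
///     println!("{:?}", item?);
/// }
/// ```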
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: Arc<Agent<M>>,
    /// Optional per-request hook for events
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    /// Create a new StreamingPromptRequest with the given agent and prompt
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of turns an LLM can have calling tools before writing a text response).
    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxDepthError`].
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Add chat history to the prompt request
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attach a per-request hook for streaming events (tool calls, text deltas and completion responses)
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        Box::pin(async_stream::stream! {
            let _guard = agent_span.enter();
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
                    let chat_history_except_last = reader[..reader.len() - 1].to_vec();

                    hook.on_completion_call(&prompt, &chat_history_except_last, cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall(tool_call.clone())));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string(), cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", tool_call.function.arguments.to_string());

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_call.function.arguments.to_string())
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                did_call_tool = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: OneOrMany::one(ToolResultContent::Text(Text { text })) };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tr)));
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, delta }) => {
                            if let Some(ref hook) = self.hook {
                                hook.on_tool_call_delta(&id, &delta, cancel_signal.clone())
                                    .await;

                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // Add (parallel) tool calls to chat history
                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                // Add tool results to chat history
                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                // Set the current prompt to the last message in the chat history
                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: last_prompt_error.clone().into(),
                }).into());
            }

        })
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Wrap send() in a future, because send() returns a stream immediately
        Box::pin(async move { self.send().await })
    }
}

/// Helper function to stream a completion request to stdout.
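///
/// A minimal usage sketch (illustrative only; it assumes `stream` was produced by awaiting a
/// [`StreamingPromptRequest`], e.g. via an agent's `stream_prompt` method):
///
/// ```rust,ignore
/// // `agent` is assumed to be an already-built `Agent` (illustrative only).
/// let mut stream = agent.stream_prompt("Tell me a joke.").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\nToken usage: {:?}", final_response.usage());
/// ```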
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

// Unused variables are allowed because the default method bodies are left empty so that users don't have to implement every single function.
/// Trait for per-request hooks to observe streaming and tool call events.
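///
/// Every method has an empty default body, so implementors only override the events they
/// care about. The sketch below is illustrative only: `LoggingHook` is a hypothetical type,
/// and it assumes an `agent` whose `stream_prompt` method returns a [`StreamingPromptRequest`].
///
/// ```rust,ignore
/// // A hypothetical hook that logs tool activity (illustrative only).
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M> StreamingPromptHook<M> for LoggingHook
/// where
///     M: CompletionModel,
/// {
///     async fn on_tool_call(&self, tool_name: &str, args: &str, _cancel_sig: CancelSignal) {
///         println!("calling tool `{tool_name}` with args: {args}");
///     }
///
///     async fn on_tool_result(&self, tool_name: &str, _args: &str, result: &str, _cancel_sig: CancelSignal) {
///         println!("tool `{tool_name}` returned: {result}");
///     }
/// }
///
/// // Attach the hook to a request before awaiting it.
/// let mut stream = agent
///     .stream_prompt("What is the weather in Paris?")
///     .with_hook(LoggingHook)
///     .multi_turn(3)
///     .await;
/// ```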
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    #[allow(unused_variables)]
    /// Called before the prompt is sent to the model
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called when receiving a text delta
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called when receiving a tool call delta
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after the model provider has finished streaming a text response from their completion API to the client.
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called before a tool is invoked.
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after a tool is invoked (and a result has been returned).
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}