rig/agent/prompt_request/streaming.rs

use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    message::{AssistantContent, Reasoning, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    StreamItem(StreamedAssistantContent<R>),
    FinalResponse(FinalResponse),
}

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

/// A builder for creating streaming prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, make sure to use `.multi_turn()`
/// to allow additional turns; by default it is 0, which permits only a single tool
/// round-trip. Otherwise, awaiting the request (which sends the prompt) can return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to
/// call tools back to back.
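///
/// # Example
///
/// A minimal usage sketch. The agent construction and the `stream_prompt` entry point
/// shown here are assumptions for illustration; adapt them to the provider client you use.
///
/// ```ignore
/// use futures::StreamExt;
///
/// // Hypothetical agent built from a provider client.
/// let agent = client
///     .agent("model-name")
///     .preamble("You are a helpful assistant.")
///     .build();
///
/// // Awaiting the request sends the prompt and yields the stream of items.
/// let mut stream = agent
///     .stream_prompt("What is the weather in Tokyo?")
///     .multi_turn(5)
///     .await;
///
/// while let Some(item) = stream.next().await {
///     match item {
///         Ok(MultiTurnStreamItem::StreamItem(content)) => { /* handle streamed content */ }
///         Ok(MultiTurnStreamItem::FinalResponse(final_response)) => {
///             println!("{}", final_response.response());
///         }
///         Ok(_) => {}
///         Err(e) => eprintln!("error: {e}"),
///     }
/// }
/// ```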
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: Arc<Agent<M>>,
    /// Optional per-request hook for events
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    /// Create a new `StreamingPromptRequest` with the given agent and prompt
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of tool-calling turns the model may take before it must produce a text response).
    /// If the maximum number of turns is exceeded, the stream will return a [`crate::completion::request::PromptError::MaxDepthError`].
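    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the agent's `stream_prompt` entry point:
    ///
    /// ```ignore
    /// // Allow up to 5 tool-calling turns before the agent must answer with text.
    /// let mut stream = agent.stream_prompt("Plan a three-day trip").multi_turn(5).await;
    /// ```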
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Add chat history to the prompt request
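    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the agent's `stream_prompt` entry point and the
    /// `Message::user`/`Message::assistant` convenience constructors:
    ///
    /// ```ignore
    /// let history = vec![
    ///     Message::user("My name is Alice."),
    ///     Message::assistant("Nice to meet you, Alice!"),
    /// ];
    /// let mut stream = agent
    ///     .stream_prompt("What is my name?")
    ///     .with_history(history)
    ///     .await;
    /// ```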
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attach a per-request hook to observe streaming and tool call events.
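    ///
    /// # Example
    ///
    /// A minimal sketch; `LoggingHook` is a hypothetical type implementing
    /// [`StreamingPromptHook`], and the agent's `stream_prompt` entry point is assumed:
    ///
    /// ```ignore
    /// let mut stream = agent
    ///     .stream_prompt("Summarize this document")
    ///     .with_hook(LoggingHook)
    ///     .await;
    /// ```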
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        Box::pin(async_stream::stream! {
            let _guard = agent_span.enter();
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
                    let chat_history_except_last = reader[..reader.len() - 1].to_vec();

                    hook.on_completion_call(&prompt, &chat_history_except_last, cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            let res = async {
                                let tool_span = tracing::Span::current();
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string(), cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", tool_call.function.arguments.to_string());

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_call.function.arguments.to_string())
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id, tool_call.call_id, tool_result));

                                did_call_tool = true;
                                Ok(())
                            }.instrument(tool_span).await;

                            if let Err(e) = res {
                                yield Err(e);
                            }
                        },
                        Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })) => {
                            chat_history.write().await.push(rig::message::Message::Assistant {
                                id: None,
                                content: OneOrMany::one(AssistantContent::Reasoning(Reasoning {
                                    reasoning: reasoning.clone(), id: id.clone()
                                }))
                            });
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // Add (parallel) tool calls to chat history
                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                // Add tool results to chat history
                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                // Set the current prompt to the last message in the chat history
                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: last_prompt_error.clone().into(),
                }).into());
            }

        })
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // send() is async, so box the call; awaiting the request resolves to the stream.
        Box::pin(async move { self.send().await })
    }
}

/// Helper function to stream a completion request to stdout.
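///
/// # Example
///
/// A minimal sketch, assuming the agent's `stream_prompt` entry point:
///
/// ```ignore
/// let mut stream = agent.stream_prompt("Tell me a story").multi_turn(3).await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\nToken usage: {:?}", final_response.usage());
/// ```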
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text { text }))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

// `unused_variables` is allowed on the methods below because their default bodies are
// intentionally left empty, so users do not have to implement every single hook.
/// Trait for per-request hooks to observe streaming and tool call events.
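///
/// Every method has an empty default body, so implementors only need to override the
/// events they care about.
///
/// # Example
///
/// A minimal sketch of a hook that logs tool calls; `LoggingHook` is an illustrative
/// name, not part of the library:
///
/// ```ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for LoggingHook {
///     fn on_tool_call(
///         &self,
///         tool_name: &str,
///         args: &str,
///         _cancel_sig: CancelSignal,
///     ) -> impl Future<Output = ()> + Send {
///         println!("calling tool `{tool_name}` with args: {args}");
///         async {}
///     }
/// }
/// ```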
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    #[allow(unused_variables)]
    /// Called before the prompt is sent to the model
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called when receiving a text delta
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after the model provider has finished streaming a text response from their completion API to the client.
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called before a tool is invoked.
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after a tool is invoked (and a result has been returned).
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}