// rig/agent/prompt_request/streaming.rs

use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    json_utils,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

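/// An item yielded by the multi-turn prompt stream.
///
/// A minimal sketch of consuming the stream, assuming `stream` was obtained by awaiting a
/// [`StreamingPromptRequest`]; everything around the `match` is illustrative:
///
/// ```ignore
/// use futures::StreamExt;
///
/// while let Some(item) = stream.next().await {
///     match item? {
///         MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text)) => {
///             // Incremental assistant text.
///             print!("{}", text.text);
///         }
///         MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(result)) => {
///             // A tool result that is fed back to the model on the next turn.
///             println!("tool result for call {}", result.id);
///         }
///         MultiTurnStreamItem::FinalResponse(final_response) => {
///             println!("\nfinal: {}", final_response.response());
///         }
///         _ => {}
///     }
/// }
/// ```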
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed assistant content item.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// A streamed user content item (mostly for tool results).
    StreamUserItem(StreamedUserContent),
    /// The final result from the stream.
    FinalResponse(FinalResponse),
}

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, make sure to raise the turn limit with
/// [`Self::multi_turn`]: it defaults to 0, which allows only a single tool round-trip. Otherwise,
/// the stream produced by awaiting this request may yield a
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
/// back to back.
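///
/// A minimal usage sketch, assuming an `agent` built elsewhere and an agent method such as
/// `Agent::stream_prompt` that returns this builder (both are illustrative here):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut stream = agent
///     .stream_prompt("What is the weather in Berlin?")
///     .multi_turn(3) // allow multiple tool round-trips
///     .await;
///
/// while let Some(item) = stream.next().await {
///     if let Ok(MultiTurnStreamItem::FinalResponse(res)) = item {
///         println!("{}", res.response());
///     }
/// }
/// ```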
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: Arc<Agent<M>>,
    /// Optional per-request hook for events
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    /// Create a new StreamingPromptRequest with the given agent and prompt
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of turns the LLM can spend calling tools before writing a text response).
    /// If the maximum number of turns is exceeded, the stream will yield a [`crate::completion::request::PromptError::MaxDepthError`].
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Add chat history to the prompt request
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attach a per-request hook for tool call events
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        Box::pin(async_stream::stream! {
            let _guard = agent_span.enter();
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

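            // Multi-turn loop: each iteration streams one completion. Tool calls are executed
            // and their results are appended to the chat history so the next turn can use them;
            // the loop ends once the model answers without calling a tool, an error occurs,
            // or the turn limit is exceeded.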
            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    // `current_prompt` has not been pushed to the chat history yet, so the
                    // snapshot below is the history *without* the prompt that is about to be sent.
                    let history_snapshot = (*chat_history.read().await).clone();

                    hook.on_completion_call(&current_prompt, &history_snapshot, cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                        return;
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    return;
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
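                        // A tool call: run the tool inside its own "execute_tool" span, notify
                        // any hook, record the result, and queue both the call and its result so
                        // they can be appended to the chat history for the next turn.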
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall(tool_call.clone())));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, &tool_args, cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_args)
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, &tool_args, &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                did_call_tool = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult {
                                        id: tool_call.id,
                                        call_id: tool_call.call_id,
                                        content: OneOrMany::one(ToolResultContent::Text(Text { text })),
                                    };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tr)));
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, delta }) => {
                            if let Some(ref hook) = self.hook {
                                hook.on_tool_call_delta(&id, &delta, cancel_signal.clone())
                                    .await;

                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    return;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; }
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                        return;
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // Add (parallel) tool calls to chat history
                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                // Add tool results to chat history
                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                // Set the current prompt to the last message in the chat history
                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: last_prompt_error.clone().into(),
                }).into());
            }
        })
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Wrap send() in a future, because send() returns a stream immediately
        Box::pin(async move { self.send().await })
    }
}

/// Helper function to stream a completion request to stdout
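///
/// A minimal sketch, assuming `stream` was obtained by awaiting a [`StreamingPromptRequest`]
/// (the `agent.stream_prompt` call is illustrative):
///
/// ```ignore
/// let mut stream = agent.stream_prompt("Hello!").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\ntoken usage: {:?}", final_response.usage());
/// ```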
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

// Unused variables are allowed below: the default hook implementations are intentionally left
// empty so that users only need to implement the events they care about.
/// Trait for per-request hooks to observe events during streaming (completion calls, text deltas,
/// tool calls and tool results).
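///
/// A minimal sketch of a hook that prints text deltas and tool calls; the struct name and the
/// `agent.stream_prompt` call are illustrative:
///
/// ```ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for LoggingHook {
///     async fn on_text_delta(&self, text_delta: &str, _aggregated_text: &str, _cancel_sig: CancelSignal) {
///         print!("{text_delta}");
///     }
///
///     async fn on_tool_call(&self, tool_name: &str, args: &str, _cancel_sig: CancelSignal) {
///         println!("calling tool `{tool_name}` with args: {args}");
///     }
/// }
///
/// let mut stream = agent.stream_prompt("Hello!").with_hook(LoggingHook).await;
/// ```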
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    #[allow(unused_variables)]
    /// Called before the prompt is sent to the model
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called when receiving a text delta
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called when receiving a tool call delta
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after the model provider has finished streaming a text response from their completion API to the client.
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called before a tool is invoked.
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after a tool is invoked (and a result has been returned).
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}