rig/agent/prompt_request/streaming.rs

use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    message::{AssistantContent, Reasoning, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

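/// An item yielded by a multi-turn agent stream: either a streamed piece of assistant
/// content (text, tool calls, reasoning, provider-final messages) or the final aggregated
/// response once the agent stops calling tools.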
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    StreamItem(StreamedAssistantContent<R>),
    FinalResponse(FinalResponse),
}

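/// The final result of a multi-turn stream: the last text response produced by the agent
/// along with the token usage aggregated across every turn.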
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

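/// Errors that can be yielded by a [`StreamingResult`] stream.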
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

/// A builder for creating streaming prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, use [`Self::multi_turn`] to allow more
/// turns; the default is 0, which permits only a single tool round-trip. Otherwise, awaiting
/// the request (which sends it) can return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent tries to call tools
/// back to back.
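///
/// A rough usage sketch (hypothetical names throughout: `client`, the model name and the
/// prompt are placeholders, and `stream_prompt` is assumed as the agent's streaming entry
/// point):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let agent = client
///     .agent("some-model")
///     .preamble("You are a helpful assistant.")
///     .build();
///
/// // Stream the prompt, allowing up to 3 additional tool-calling turns.
/// let mut stream = agent
///     .stream_prompt("What is the weather in Berlin?")
///     .multi_turn(3)
///     .await;
///
/// while let Some(item) = stream.next().await {
///     match item {
///         Ok(MultiTurnStreamItem::StreamItem(content)) => { /* handle streamed content */ }
///         Ok(MultiTurnStreamItem::FinalResponse(res)) => println!("\n{}", res.response()),
///         Ok(_) => {}
///         Err(e) => eprintln!("stream error: {e}"),
///     }
/// }
/// ```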
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: Arc<Agent<M>>,
    /// Optional per-request hook for events
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    /// Create a new StreamingPromptRequest with the given agent and prompt
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of turns the
    /// LLM can spend calling tools before it must write a text response).
    /// If the maximum turn number is exceeded, the stream returns a
    /// [`crate::completion::request::PromptError::MaxDepthError`].
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Add chat history to the prompt request
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attach a per-request hook for tool call events
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

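        // Hooks receive clones of this signal; if a hook cancels it, the stream yields a
        // `PromptError` cancellation error carrying the chat history gathered so far.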
        let cancel_signal = CancelSignal::new();

        Box::pin(async_stream::stream! {
            let _guard = agent_span.enter();
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

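            // One iteration per completion turn: stream the model's response, execute any
            // requested tools, then loop with the tool results until the model replies with
            // plain text (no tool calls) or the depth limit is exceeded.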
            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history");
                    let chat_history_except_last = reader[..reader.len() - 1].to_vec();

                    hook.on_completion_call(&prompt, &chat_history_except_last, cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            let res = async {
                                let tool_span = tracing::Span::current();
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string(), cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", tool_call.function.arguments.to_string());

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_call.function.arguments.to_string())
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id, tool_call.call_id, tool_result));

                                did_call_tool = true;
                                Ok(())
                                // break;
                            }.instrument(tool_span).await;

                            if let Err(e) = res {
                                yield Err(e);
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, delta }) => {
                            if let Some(ref hook) = self.hook {
                                hook.on_tool_call_delta(&id, &delta, cancel_signal.clone()).await;

                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })) => {
                            chat_history.write().await.push(Message::Assistant {
                                id: None,
                                content: OneOrMany::one(AssistantContent::Reasoning(Reasoning {
                                    reasoning: reasoning.clone(), id: id.clone(), signature: signature.clone()
                                }))
                            });
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // Add (parallel) tool calls to chat history
                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                // Add tool results to chat history
                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                // Set the current prompt to the last message in the chat history
                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: last_prompt_error.clone().into(),
                }).into());
            }

        })
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Wrap send() in a future, because send() returns a stream immediately
        Box::pin(async move { self.send().await })
    }
}

/// Helper function to stream a completion request to stdout.
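/// Text and reasoning deltas are printed as they arrive; the returned [`FinalResponse`]
/// carries the final response text plus the token usage aggregated across all turns.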
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text { text }))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

// Unused variables are allowed on the methods below because they are intentionally left as
// empty defaults, so users only have to implement the callbacks they care about.
/// Trait for per-request hooks to observe prompt, streaming, and tool call events.
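///
/// Every method has an empty default implementation, so implementors only override the
/// events they need. A minimal sketch (`PrintHook` is purely illustrative):
///
/// ```ignore
/// #[derive(Clone)]
/// struct PrintHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for PrintHook {
///     async fn on_tool_call(&self, tool_name: &str, args: &str, _cancel_sig: CancelSignal) {
///         println!("calling tool {tool_name} with args {args}");
///     }
/// }
///
/// // Attached to a request via `.with_hook(PrintHook)`.
/// ```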
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    #[allow(unused_variables)]
    /// Called before the prompt is sent to the model
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called when receiving a text delta
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called when receiving a tool call delta
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after the model provider has finished streaming a text response from their completion API to the client.
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called before a tool is invoked.
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after a tool is invoked (and a result has been returned).
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}