Skip to main content

rig/agent/prompt_request/
mod.rs

1pub mod streaming;
2
3pub use streaming::StreamingPromptHook;
4
5use std::{
6    future::IntoFuture,
7    marker::PhantomData,
8    sync::atomic::{AtomicU64, Ordering},
9};
10use tracing::{Instrument, span::Id};
11
12use futures::{StreamExt, stream};
13use tracing::info_span;
14
15use crate::{
16    OneOrMany,
17    completion::{Completion, CompletionModel, Message, PromptError, Usage},
18    json_utils,
19    message::{AssistantContent, ToolResultContent, UserContent},
20    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
21};
22
23use super::Agent;
24
/// Typestate marker trait for [`PromptRequest`]: tracks whether awaiting the
/// request yields a plain `String` or a [`PromptResponse`] with usage details.
pub trait PromptType {}
/// Typestate: the request resolves to a plain `String` (the default).
pub struct Standard;
/// Typestate: the request resolves to a [`PromptResponse`] including
/// aggregated token usage (selected via `extended_details`).
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
31
/// Control flow action for tool call hooks. This is different from the regular [`HookAction`] in that tool call executions may be skipped for one or more reasons.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallHookAction {
    /// Continue tool execution as normal.
    Continue,
    /// Skip tool execution and return the provided reason as the tool result.
    Skip { reason: String },
    /// Terminate agent loop early
    Terminate { reason: String },
}

impl ToolCallHookAction {
    /// Proceed with the tool call and the agent loop unchanged.
    pub fn cont() -> Self {
        Self::Continue
    }

    /// Skip the pending tool call; `reason` is handed back to the model in
    /// place of a real tool result.
    pub fn skip(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Skip { reason }
    }

    /// Abort the entire agent loop with the supplied reason.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Terminate { reason }
    }
}
63
/// Control flow action for hooks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Continue agentic loop execution as normal.
    Continue,
    /// Terminate agent loop early
    Terminate { reason: String },
}

impl HookAction {
    /// Proceed with the agentic loop unchanged.
    pub fn cont() -> Self {
        Self::Continue
    }

    /// Abort the entire agent loop with the supplied reason.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Terminate { reason }
    }
}
86
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn).
    /// Defaults to the agent's `default_max_turns` (or 0) in [`PromptRequest::new`].
    max_turns: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    /// Phantom data to track the type of the request (Standard vs Extended typestate)
    state: PhantomData<S>,
    /// Optional per-request hook for completion/tool events
    hook: Option<P>,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
}
117
118impl<'a, M> PromptRequest<'a, Standard, M, ()>
119where
120    M: CompletionModel,
121{
122    /// Create a new PromptRequest with the given prompt and model
123    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
124        Self {
125            prompt: prompt.into(),
126            chat_history: None,
127            max_turns: agent.default_max_turns.unwrap_or_default(),
128            agent,
129            state: PhantomData,
130            hook: None,
131            concurrency: 1,
132        }
133    }
134}
135
136impl<'a, S, M, P> PromptRequest<'a, S, M, P>
137where
138    S: PromptType,
139    M: CompletionModel,
140    P: PromptHook<M>,
141{
142    /// Enable returning extended details for responses (includes aggregated token usage)
143    ///
144    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
145    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
146    /// of conversation.
147    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
148        PromptRequest {
149            prompt: self.prompt,
150            chat_history: self.chat_history,
151            max_turns: self.max_turns,
152            agent: self.agent,
153            state: PhantomData,
154            hook: self.hook,
155            concurrency: self.concurrency,
156        }
157    }
158    /// Set the maximum number of turns for multi-turn conversations. A given agent may require multiple turns for tool-calling before giving an answer.
159    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
160    pub fn max_turns(self, depth: usize) -> PromptRequest<'a, S, M, P> {
161        PromptRequest {
162            prompt: self.prompt,
163            chat_history: self.chat_history,
164            max_turns: depth,
165            agent: self.agent,
166            state: PhantomData,
167            hook: self.hook,
168            concurrency: self.concurrency,
169        }
170    }
171
172    /// Add concurrency to the prompt request.
173    /// This will cause the agent to execute tools concurrently.
174    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
175        self.concurrency = concurrency;
176        self
177    }
178
179    /// Add chat history to the prompt request
180    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
181        PromptRequest {
182            prompt: self.prompt,
183            chat_history: Some(history),
184            max_turns: self.max_turns,
185            agent: self.agent,
186            state: PhantomData,
187            hook: self.hook,
188            concurrency: self.concurrency,
189        }
190    }
191
192    /// Attach a per-request hook for tool call events
193    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
194    where
195        P2: PromptHook<M>,
196    {
197        PromptRequest {
198            prompt: self.prompt,
199            chat_history: self.chat_history,
200            max_turns: self.max_turns,
201            agent: self.agent,
202            state: PhantomData,
203            hook: Some(hook),
204            concurrency: self.concurrency,
205        }
206    }
207}
208
// dead code allowed because of functions being left empty to allow for users to not have to implement every single function
/// Trait for per-request hooks to observe tool call events.
///
/// Every method has a no-op default returning `Continue`, so implementors only
/// override the events they care about. Hooks are cloned per tool call when
/// tools execute concurrently, hence the `Clone` bound.
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before the prompt is sent to the model.
    ///
    /// `_history` is the conversation so far, excluding `_prompt` itself.
    /// Returning [`HookAction::Terminate`] cancels the whole request.
    fn on_completion_call(
        &self,
        _prompt: &Message,
        _history: &[Message],
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        async { HookAction::cont() }
    }

    /// Called after the prompt is sent to the model and a response is received.
    ///
    /// Returning [`HookAction::Terminate`] cancels the whole request.
    fn on_completion_response(
        &self,
        _prompt: &Message,
        _response: &crate::completion::CompletionResponse<M::Response>,
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        async { HookAction::cont() }
    }

    /// Called before a tool is invoked.
    ///
    /// `_internal_call_id` is a rig-generated id correlating this call with the
    /// matching `on_tool_result`; `_tool_call_id` is the provider-supplied id, if any.
    ///
    /// # Returns
    /// - `ToolCallHookAction::Continue` - Allow tool execution to proceed
    /// - `ToolCallHookAction::Skip { reason }` - Reject tool execution; `reason` will be returned to the LLM as the tool result
    /// - `ToolCallHookAction::Terminate { reason }` - Cancel the whole request
    fn on_tool_call(
        &self,
        _tool_name: &str,
        _tool_call_id: Option<String>,
        _internal_call_id: &str,
        _args: &str,
    ) -> impl Future<Output = ToolCallHookAction> + WasmCompatSend {
        async { ToolCallHookAction::cont() }
    }

    /// Called after a tool is invoked (and a result has been returned).
    ///
    /// Returning [`HookAction::Terminate`] cancels the whole request.
    fn on_tool_result(
        &self,
        _tool_name: &str,
        _tool_call_id: Option<String>,
        _internal_call_id: &str,
        _args: &str,
        _result: &str,
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        async { HookAction::cont() }
    }
}
260
261impl<M> PromptHook<M> for () where M: CompletionModel {}
262
/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
///  directly via the associated type.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Awaiting a standard request yields only the final response text.
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        // Boxing is required because the associated type cannot name an
        // anonymous `async fn` future (see RFC 2515 note above).
        Box::pin(self.send())
    }
}
278
impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Awaiting an extended request yields the text plus aggregated usage.
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        // Boxed for the same RFC 2515 reason as the `Standard` impl.
        Box::pin(self.send())
    }
}
291
292impl<M, P> PromptRequest<'_, Standard, M, P>
293where
294    M: CompletionModel,
295    P: PromptHook<M>,
296{
297    async fn send(self) -> Result<String, PromptError> {
298        self.extended_details().send().await.map(|resp| resp.output)
299    }
300}
301
/// Result of an extended-details prompt request: the model's final merged text
/// output plus token usage aggregated across every completion call in the loop.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    // Merged text of the final assistant message.
    pub output: String,
    // Token usage summed over all turns of the request.
    pub total_usage: Usage,
}
307
308impl PromptResponse {
309    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
310        Self {
311            output: output.into(),
312            total_usage,
313        }
314    }
315}
316
impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Drive the full agentic loop: send the prompt, execute any tool calls the
    /// model requests (up to `concurrency` at a time), append the results to the
    /// chat history, and repeat until the model replies with plain text or the
    /// turn budget is exhausted.
    ///
    /// # Errors
    /// - [`PromptError::MaxTurnsError`] if the model keeps requesting tools
    ///   past `max_turns`.
    /// - A prompt-cancelled error when any hook returns a `Terminate` action.
    /// - Any error propagated from the underlying completion calls.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Only open a fresh `invoke_agent` span when the caller hasn't already
        // established one; otherwise reuse the current span so nested agent
        // invocations land in a single trace.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        // Use the caller-supplied history when present (the prompt is appended
        // to it and mutations persist for the caller); otherwise work on a
        // fresh, request-local history containing only the prompt.
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // Number of completion round-trips performed so far.
        let mut current_max_turns = 0;
        // Token usage aggregated across every completion response.
        let mut usage = Usage::new();
        // Id of the most recently created chat/tool span, used to chain
        // successive spans via `follows_from`; 0 means "no span recorded yet".
        // Atomic because tool spans may update it from concurrent futures.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
        let last_prompt = loop {
            // The most recent message is what gets (re)sent as the prompt:
            // either the original user prompt or the latest tool results.
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            // `max_turns + 1` allows the final tool results to be sent back to
            // the model once before giving up; breaking yields the unresolved
            // prompt for the MaxTurnsError below.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            // Pre-completion hook: pass the history *excluding* the prompt
            // itself (it is the last element). Terminate cancels the request.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } = hook
                    .on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
                    .await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            let span = tracing::Span::current();
            // Per-turn chat span; most fields are recorded later by the
            // provider layer, hence `Empty` placeholders.
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this turn's span to the previous turn's span (if any) so the
            // sequence of turns is visible in tracing backends.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            // Build and send the completion request; history again excludes the
            // prompt message itself, which is passed separately.
            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            // Post-completion hook; Terminate cancels the request.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            // Split the assistant's content into tool calls vs everything else
            // (text); tool calls decide whether the loop continues.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                // No tool calls: the model answered with text, so merge all
                // text fragments into the final output.
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                // If there are no tool calls, depth is not relevant, we can just return the merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();

            // Execute all requested tool calls, at most `self.concurrency` at a
            // time, collecting one tool-result content item per call.
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    // Separate clones for the pre-call and post-call hook uses,
                    // since each is moved into the async block independently.
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Chain the tool span after the previous span, mirroring
                    // the chat-span linking above.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    // Snapshot of the history for error reporting inside the
                    // async block.
                    // NOTE(review): `.clone().to_vec()` copies the Vec twice;
                    // `.to_vec()` alone would suffice.
                    let cloned_chat_history = chat_history.clone().to_vec();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            // Rig-generated id correlating on_tool_call with
                            // its on_tool_result, independent of provider ids.
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                // Terminate aborts the whole request (the error
                                // propagates out through the collected results).
                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_chat_history,
                                        reason,
                                    ));
                                }

                                if let ToolCallHookAction::Skip { reason } = action {
                                    // Tool execution rejected, return rejection message as tool result
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Tool errors are not fatal: the error text is fed
                            // back to the model as the tool result instead.
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_chat_history,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            // Preserve the provider call id when one exists so
                            // the result can be matched back to its call.
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                // NOTE(review): `buffer_unordered(0)` never polls any future;
                // callers must ensure `concurrency >= 1`.
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                // First tool/hook error aborts the whole request.
                .collect::<Result<Vec<_>, _>>()?;

            // Feed the tool results back to the model on the next iteration.
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
            });
        };

        // If we reach here, we never resolved the final tool call. We need to do ... something.
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: Box::new(chat_history.clone()),
            prompt: Box::new(last_prompt),
        })
    }
}