rig/agent/prompt_request/mod.rs

pub(crate) mod streaming;

use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
};
use tracing::{Instrument, span::Id};

use futures::{StreamExt, stream};
use tracing::info_span;

use crate::{
    OneOrMany,
    completion::{Completion, CompletionModel, Message, PromptError, Usage},
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};

use super::Agent;

pub trait PromptType {}
pub struct Standard;
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, use `.multi_turn()` to allow
/// additional turns; the default is 0, which permits only a single tool round-trip.
/// Otherwise, awaiting the request (which sends the prompt) can return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides
/// to call tools back to back.
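///
/// # Example
///
/// A minimal usage sketch, assuming an `agent` built elsewhere exposes this builder
/// through its `prompt` method (construction and error handling omitted):
///
/// ```rust,ignore
/// let mut history = Vec::new();
///
/// // Allow up to 5 tool round-trips before a text answer is required.
/// let answer: String = agent
///     .prompt("What is the weather in Berlin?")
///     .with_history(&mut history)
///     .multi_turn(5)
///     .await?;
///
/// println!("{answer}");
/// ```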
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    /// Phantom data to track the type of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
}

impl<'a, M> PromptRequest<'a, Standard, M, ()>
where
    M: CompletionModel,
{
    /// Create a new PromptRequest with the given prompt and model
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            state: PhantomData,
            hook: None,
        }
    }
}

impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage).
    ///
    /// Note: this changes the return type of `.send` (and of awaiting the request) from a
    /// plain `String` to a [`PromptResponse`] struct, which is useful for tracking token
    /// usage across multiple turns of conversation.
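    ///
    /// A hedged sketch of the extended flow (agent construction omitted):
    ///
    /// ```rust,ignore
    /// let response = agent
    ///     .prompt("Summarize this repository")
    ///     .extended_details()
    ///     .await?;
    ///
    /// // A `PromptResponse` rather than a plain `String`:
    /// println!("{}", response.output);
    /// println!(
    ///     "tokens in/out: {}/{}",
    ///     response.total_usage.input_tokens, response.total_usage.output_tokens
    /// );
    /// ```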
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of
    /// turns the LLM can spend calling tools before it must write a text response).
    /// If that turn budget is exceeded, the request returns a
    /// [`crate::completion::request::PromptError::MaxDepthError`].
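    ///
    /// If the budget runs out, the error carries the accumulated chat history so the
    /// caller can inspect or resume it. A hedged sketch:
    ///
    /// ```rust,ignore
    /// match agent.prompt("Plan a trip").multi_turn(2).await {
    ///     Ok(text) => println!("{text}"),
    ///     Err(PromptError::MaxDepthError { chat_history, .. }) => {
    ///         // The conversation so far, including tool calls and results.
    ///         eprintln!("ran out of turns after {} messages", chat_history.len());
    ///     }
    ///     Err(e) => return Err(e),
    /// }
    /// ```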
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Add chat history to the prompt request
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Attach a per-request hook for tool call events
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: Some(hook),
        }
    }
}

/// A shared cancellation flag passed to every [`PromptHook`] callback. Calling
/// [`CancelSignal::cancel`] from a hook causes the request to abort at the next
/// cancellation check, returning a cancellation error that carries the chat history
/// accumulated so far.
pub struct CancelSignal(Arc<AtomicBool>);

impl CancelSignal {
    fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Flag the current request as cancelled.
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}

impl Clone for CancelSignal {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

// `unused_variables` is allowed on the methods below because their default bodies are
// intentionally empty: implementors only override the callbacks they need.
/// Trait for per-request hooks to observe completion and tool call events.
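///
/// Every method has an empty default body, so an implementation only overrides the
/// events it cares about. A hedged sketch of a logging hook that also demonstrates
/// cancellation (`LoggingHook` and its call budget are illustrative, not part of rig):
///
/// ```rust,ignore
/// use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
///
/// #[derive(Clone)]
/// struct LoggingHook {
///     // Abort the request once this many tool calls have been observed.
///     max_tool_calls: usize,
///     calls_seen: Arc<AtomicUsize>,
/// }
///
/// impl<M: CompletionModel> PromptHook<M> for LoggingHook {
///     async fn on_tool_call(&self, tool_name: &str, args: &str, cancel_sig: CancelSignal) {
///         tracing::info!("calling tool {tool_name} with args {args}");
///         if self.calls_seen.fetch_add(1, Ordering::SeqCst) + 1 > self.max_tool_calls {
///             // The request returns a cancellation error after this hook runs.
///             cancel_sig.cancel();
///         }
///     }
/// }
///
/// let answer = agent
///     .prompt("...")
///     .with_hook(LoggingHook { max_tool_calls: 3, calls_seen: Default::default() })
///     .multi_turn(10)
///     .await?;
/// ```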
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before the prompt is sent to the model
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after the prompt is sent to the model and a response is received.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called before a tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after a tool is invoked (and a result has been returned).
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

impl<M> PromptHook<M> for () where M: CompletionModel {}

/// Because [RFC 2515](https://github.com/rust-lang/rust/issues/63063) (`impl Trait` in
/// type aliases) is not yet stable, we have to use a `BoxFuture` for the `IntoFuture`
/// implementation. Once it stabilizes, the associated type can be an `impl Future<...>`
/// directly.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<'_, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

/// The result of an extended prompt request: the final text output plus token usage
/// aggregated across every completion call made during the conversation.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    /// The final text response produced by the model
    pub output: String,
    /// Token usage summed over all turns, including tool round-trips
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let cancel_sig = CancelSignal::new();

        let mut current_max_depth = 0;
        let mut usage = Usage::new();
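        // Remember the id of the most recent chat/tool span so that successive spans
        // can be linked with `follows_from` across loop iterations.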
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // At least two loop iterations are needed for one tool round-trip, since the
        // user still expects a normal text message after the tools have run.
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }

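            // Split the response into tool calls and plain text; any tool call means
            // another turn is needed before the user gets a final answer.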
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                // If there are no tool calls, depth is not relevant; we can just return the merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

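            // Execute the requested tools one at a time (sequentially), threading the
            // hook and cancellation signal through each call.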
            let hook = self.hook.clone();
            let tool_content = stream::iter(tool_calls)
                .then(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args = tool_call.function.arguments.to_string();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(tool_name, &args, cancel_sig1.clone())
                                    .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(chat_history.to_vec())
                    } else {
                        e.into()
                    }
                })?;

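            // Feed the tool results back into the history as a user message so the
            // next completion call can see them.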
            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // If we reach this point, the loop exhausted its turn budget without the agent
        // producing a final text response, so surface the depth error along with the
        // accumulated history.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: last_prompt,
        })
    }
}