rig/agent/prompt_request/mod.rs

pub mod streaming;

pub use streaming::StreamingPromptHook;

use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
};

use futures::{StreamExt, stream};
use tracing::{Instrument, info_span, span::Id};

use crate::{
    OneOrMany,
    completion::{Completion, CompletionModel, Message, PromptError, Usage},
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};

use super::Agent;

/// Typestate marker for [`PromptRequest`], tracking what awaiting the request returns.
pub trait PromptType {}
/// Typestate: awaiting the request resolves to a plain `String`.
pub struct Standard;
/// Typestate: awaiting the request resolves to a [`PromptResponse`] with usage details.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, use `.multi_turn()` to allow
/// additional turns; the default depth is 0, which permits only a single tool
/// round-trip. Otherwise, awaiting the request (which sends the prompt) can return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides
/// to call tools back to back.
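///
/// # Example
///
/// A minimal sketch; the `openai` client and model name are illustrative, and
/// any provider's agent works the same way:
///
/// ```ignore
/// let agent = openai::Client::from_env()
///     .agent("gpt-4o")
///     .preamble("You are a helpful assistant.")
///     .build();
///
/// // `agent.prompt(..)` returns a `PromptRequest`; awaiting it sends the prompt.
/// let answer: String = agent
///     .prompt("What is 2 + 2?")
///     .multi_turn(5) // allow up to 5 tool round-trips
///     .await?;
/// ```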
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    /// Phantom data to track the type of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
}

impl<'a, M> PromptRequest<'a, Standard, M, ()>
where
    M: CompletionModel,
{
    /// Create a new PromptRequest for the given agent and prompt
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            state: PhantomData,
            hook: None,
        }
    }
}

impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage)
    ///
    /// Note: this changes the result of awaiting the request from a plain `String`
    /// to a [`PromptResponse`] struct, which is useful for tracking token usage
    /// across multiple turns of conversation.
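    ///
    /// # Example
    ///
    /// A sketch, assuming `agent` is any built agent:
    ///
    /// ```ignore
    /// let response = agent
    ///     .prompt("Summarize this document.")
    ///     .extended_details()
    ///     .await?;
    ///
    /// println!("output: {}", response.output);
    /// // `total_usage` aggregates token counts across every turn.
    /// println!("input tokens: {}", response.total_usage.input_tokens);
    /// ```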
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }
    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of
    /// tool-calling turns the LLM may take before it must produce a text response).
    /// If that maximum is exceeded, awaiting the request returns a
    /// [`crate::completion::request::PromptError::MaxDepthError`].
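    ///
    /// # Example
    ///
    /// A sketch of handling an exhausted depth:
    ///
    /// ```ignore
    /// match agent.prompt("Plan my trip.").multi_turn(2).await {
    ///     Ok(text) => println!("{text}"),
    ///     // `MaxDepthError` carries the chat history and the last prompt, so
    ///     // the conversation could be resumed with a larger depth.
    ///     Err(PromptError::MaxDepthError { max_depth, .. }) => {
    ///         eprintln!("model was still calling tools after {max_depth} turns");
    ///     }
    ///     Err(e) => return Err(e.into()),
    /// }
    /// ```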
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Add chat history to the prompt request
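    ///
    /// # Example
    ///
    /// A sketch of carrying context across two prompts; `send` pushes both the
    /// prompt and the assistant's reply into the supplied history:
    ///
    /// ```ignore
    /// let mut history = Vec::new();
    ///
    /// let _ = agent
    ///     .prompt("Hi, my name is Alice.")
    ///     .with_history(&mut history)
    ///     .await?;
    ///
    /// // The follow-up request sees the earlier exchange.
    /// let reply = agent
    ///     .prompt("What is my name?")
    ///     .with_history(&mut history)
    ///     .await?;
    /// ```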
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Attach a per-request hook for tool call and completion events.
    /// See [`PromptHook`] for an example implementation.
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: Some(hook),
        }
    }
}

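/// A cloneable cancellation flag passed to every [`PromptHook`] callback.
///
/// Calling [`CancelSignal::cancel`] from within a hook aborts the in-flight
/// request: `send` stops after the hook returns and surfaces a cancellation
/// error carrying the chat history accumulated so far. Clones share the same
/// underlying flag.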
#[derive(Clone)]
pub struct CancelSignal(Arc<AtomicBool>);

impl CancelSignal {
    fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Flag the in-flight prompt request for cancellation.
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}

// `unused_variables` is allowed on the methods below: their default bodies are
// intentionally empty so that implementors only override the hooks they need.
/// Trait for per-request hooks observing prompt lifecycle events: completion
/// calls and responses, tool calls and tool results.
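///
/// All methods have empty default implementations, so only the events of
/// interest need to be overridden. Each callback receives a [`CancelSignal`];
/// calling [`CancelSignal::cancel`] inside a hook aborts the request.
///
/// # Example
///
/// A sketch of a simple logging hook attached via [`PromptRequest::with_hook`]:
///
/// ```ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> PromptHook<M> for LoggingHook {
///     async fn on_tool_call(&self, tool_name: &str, args: &str, _cancel: CancelSignal) {
///         println!("calling tool `{tool_name}` with args {args}");
///     }
///
///     async fn on_tool_result(&self, tool_name: &str, _args: &str, result: &str, _cancel: CancelSignal) {
///         println!("tool `{tool_name}` returned: {result}");
///     }
/// }
///
/// let reply = agent
///     .prompt("What's the weather in Berlin?")
///     .multi_turn(3)
///     .with_hook(LoggingHook)
///     .await?;
/// ```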
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before the prompt is sent to the model.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after the prompt is sent to the model and a response is received.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called before a tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after a tool is invoked (and a result has been returned).
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

impl<M> PromptHook<M> for () where M: CompletionModel {}

/// Due to [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a
/// boxed future ([`WasmBoxedFuture`]) for the `IntoFuture` implementation. Once `impl Trait`
/// in associated types stabilizes, we will be able to use `impl Future<...>` directly.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<'_, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

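/// The result of an [`Extended`] prompt request: the model's final text output
/// together with token usage aggregated across every turn of the conversation.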
#[derive(Debug, Clone)]
pub struct PromptResponse {
    pub output: String,
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let cancel_sig = CancelSignal::new();

        let mut current_depth = 0;
        let mut usage = Usage::new();
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // One tool round-trip takes two loop iterations: one turn in which the model
        // calls tools, and at least one more to produce the final text the user expects.
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_depth > self.max_depth + 1 {
                break prompt;
            }

            current_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_depth,
                    self.max_depth
                );
            }

            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            }

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                // If there are no tool calls, depth is not relevant; we can just return the merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();
            let tool_content = stream::iter(tool_calls)
                .then(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    }

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args = tool_call.function.arguments.to_string();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(tool_name, &args, cancel_sig1.clone())
                                    .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(chat_history.to_vec())
                    } else {
                        e.into()
                    }
                })?;

            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // If we reach this point, the model was still calling tools when the maximum
        // number of turns ran out, so surface the depth error to the caller.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: last_prompt,
        })
    }
}
541}