rig/agent/prompt_request/mod.rs

pub mod streaming;

pub use streaming::StreamingPromptHook;

use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
};

use futures::{StreamExt, stream};
use tracing::{Instrument, info_span, span::Id};

use crate::{
    OneOrMany,
    completion::{Completion, CompletionModel, Message, PromptError, Usage},
    json_utils,
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};

use super::Agent;

/// Marker trait for the type-state of a [`PromptRequest`].
pub trait PromptType {}
/// Type-state marker: awaiting the request yields a plain `String`.
pub struct Standard;
/// Type-state marker: awaiting the request yields a [`PromptResponse`] with usage details.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, be sure to use `.multi_turn()`
/// to grant additional turns: the default is 0, which allows only one tool round-trip.
/// Otherwise, awaiting the request (which sends the prompt) may return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
/// back to back.
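///
/// # Example
///
/// A minimal sketch, assuming an `agent` already built via `AgentBuilder` and
/// equipped with tools; `Agent::prompt` returns a `PromptRequest`:
///
/// ```rust,ignore
/// let mut history: Vec<Message> = Vec::new();
///
/// // Allow up to 5 tool round-trips before failing with `MaxDepthError`.
/// let answer: String = agent
///     .prompt("What is the weather in Berlin?")
///     .with_history(&mut history)
///     .multi_turn(5)
///     .await?;
/// ```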
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    /// Phantom data to track the type of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
}

impl<'a, M> PromptRequest<'a, Standard, M, ()>
where
    M: CompletionModel,
{
    /// Create a new PromptRequest with the given prompt and agent
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            state: PhantomData,
            hook: None,
            concurrency: 1,
        }
    }
}

impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage).
    ///
    /// Note: this changes the type returned by awaiting the request from a plain `String`
    /// to a [`PromptResponse`] struct, which is useful for tracking token usage across
    /// multiple turns of conversation.
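    ///
    /// A minimal sketch of reading the aggregated usage (assuming an `agent` built
    /// elsewhere via `AgentBuilder`):
    ///
    /// ```rust,ignore
    /// let response = agent
    ///     .prompt("Summarize this document")
    ///     .extended_details()
    ///     .await?;
    /// println!("output: {}", response.output);
    /// println!(
    ///     "tokens: {} in / {} out",
    ///     response.total_usage.input_tokens,
    ///     response.total_usage.output_tokens
    /// );
    /// ```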
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of
    /// turns the LLM may spend calling tools before it must produce a text response).
    /// If the maximum number of turns is exceeded, the request returns a
    /// [`crate::completion::request::PromptError::MaxDepthError`].
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Set how many tool calls may be executed concurrently (1 by default).
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Add chat history to the prompt request
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Attach a per-request hook for completion and tool call events
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
        }
    }
}

/// A cloneable, thread-safe cancellation flag handed to [`PromptHook`] callbacks.
/// Once [`CancelSignal::cancel`] is called, the in-flight prompt request stops at
/// its next cancellation check and returns a cancellation error.
pub struct CancelSignal(Arc<AtomicBool>);

impl CancelSignal {
    fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Signal that the in-flight prompt request should be cancelled.
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}

impl Clone for CancelSignal {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

// `unused_variables` is allowed on the default methods below so that hook
// implementors don't have to implement every single callback.
/// Trait for per-request hooks to observe completion and tool call events.
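///
/// Each callback receives a [`CancelSignal`] that can be used to abort the
/// request early.
///
/// A minimal sketch of a logging hook that can also cancel (the names here are
/// illustrative, not part of the crate):
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> PromptHook<M> for LoggingHook {
///     async fn on_tool_call(
///         &self,
///         tool_name: &str,
///         tool_call_id: Option<String>,
///         args: &str,
///         cancel_sig: CancelSignal,
///     ) {
///         println!("calling {tool_name} with args {args}");
///         if tool_name == "forbidden_tool" {
///             cancel_sig.cancel(); // the request returns a cancellation error
///         }
///     }
/// }
///
/// // Attach it per request:
/// let reply = agent.prompt("hi").with_hook(LoggingHook).multi_turn(3).await?;
/// ```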
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before the prompt is sent to the model.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after the prompt is sent to the model and a response is received.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called before a tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after a tool is invoked (and a result has been returned).
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

/// The no-op hook used when no per-request hook is attached.
impl<M> PromptHook<M> for () where M: CompletionModel {}

/// Due to [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
/// for the `IntoFuture` implementation. In the future, we should be able to use
/// `impl Future<...>` directly via the associated type.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<'_, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

/// The result of an extended prompt request: the final text output together with
/// the token usage aggregated across every turn of the conversation.
#[derive(Debug, Clone)]
pub struct PromptResponse {
    pub output: String,
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
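        // Seed the history with the prompt; each loop iteration below reads
        // the message to send back off the tail of `chat_history`.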
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let cancel_sig = CancelSignal::new();

        let mut current_max_depth = 0;
        let mut usage = Usage::new();
        // Id of the most recent span, used to chain successive spans with
        // `follows_from`; wrapped in an atomic so the tool-call closures below
        // can update it without a mutable borrow.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // At least two loop iterations are needed for one tool round-trip
        // (the user still expects a normal text message at the end).
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

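            // Chain this turn's span to the previous turn's via `follows_from`,
            // so a multi-turn conversation traces as a linear sequence rather
            // than a nested tree.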
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                // No tool calls were made, so depth is irrelevant: return the
                // merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();

            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
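            // Execute the tool calls as a stream: `buffer_unordered` runs up
            // to `self.concurrency` futures at once and yields results in
            // completion order.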
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(
                                    tool_name,
                                    tool_call.call_id.clone(),
                                    &args,
                                    cancel_sig1.clone(),
                                )
                                .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    tool_call.call_id.clone(),
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(chat_history.to_vec())
                    } else {
                        e.into()
                    }
                })?;

            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // The loop ended without the model producing a final text response:
        // the conversation exceeded the configured maximum depth.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: Box::new(last_prompt),
        })
    }
}