rig/agent/prompt_request/mod.rs

pub mod streaming;

pub use streaming::StreamingPromptHook;

use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
};
use tracing::{Instrument, span::Id};

use futures::{StreamExt, stream};
use tracing::info_span;

use crate::{
    OneOrMany,
    completion::{Completion, CompletionModel, Message, PromptError, Usage},
    json_utils,
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};

use super::Agent;

pub trait PromptType {}
pub struct Standard;
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, use the `.multi_turn()` method to allow
/// additional turns; by default the limit is 0, which permits only a single tool round-trip.
/// Otherwise, awaiting the request (which sends it) can return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call
/// tools back to back.
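///
/// # Example
///
/// A minimal sketch, assuming an already-built agent and that `agent.prompt(..)` (via the
/// `Prompt` trait) returns a `PromptRequest`; the prompt text and turn budget are illustrative:
///
/// ```no_run
/// # use rig::completion::Prompt;
/// # async fn demo(
/// #     agent: rig::agent::Agent<impl rig::completion::CompletionModel>,
/// # ) -> Result<(), rig::completion::PromptError> {
/// let mut history = Vec::new();
/// let answer: String = agent
///     .prompt("What is the weather in Berlin?")
///     .with_history(&mut history)
///     .multi_turn(5) // raise the multi-turn limit (default is 0: a single tool round-trip)
///     .await?;
/// # Ok(())
/// # }
/// ```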
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    /// Phantom data to track the type of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
}

impl<'a, M> PromptRequest<'a, Standard, M, ()>
where
    M: CompletionModel,
{
    /// Create a new PromptRequest for the given agent and prompt
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            state: PhantomData,
            hook: None,
        }
    }
}

impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage)
    ///
    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
    /// of conversation.
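    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an already-built agent and that `agent.prompt(..)` returns a
    /// `PromptRequest` (the prompt text is illustrative):
    ///
    /// ```no_run
    /// # use rig::completion::Prompt;
    /// # async fn demo(
    /// #     agent: rig::agent::Agent<impl rig::completion::CompletionModel>,
    /// # ) -> Result<(), rig::completion::PromptError> {
    /// let response = agent
    ///     .prompt("Summarize the latest report")
    ///     .extended_details()
    ///     .await?;
    /// println!("{}", response.output);
    /// println!(
    ///     "tokens in/out: {}/{}",
    ///     response.total_usage.input_tokens, response.total_usage.output_tokens
    /// );
    /// # Ok(())
    /// # }
    /// ```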
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of turns
    /// the LLM can spend calling tools before it must produce a text response).
    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxDepthError`].
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Add chat history to the prompt request
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
        }
    }

    /// Attach a per-request hook for tool call events
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: Some(hook),
        }
    }
}

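/// A shared, thread-safe cancellation flag handed to every [`PromptHook`] callback.
/// Calling [`CancelSignal::cancel`] from inside a hook makes the in-flight request stop at the
/// next checkpoint and return a cancellation error (see the example on [`PromptHook`]).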
pub struct CancelSignal(Arc<AtomicBool>);

impl CancelSignal {
    fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Request cancellation of the in-flight prompt request.
    pub fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }

    fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}

impl Clone for CancelSignal {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

// `unused_variables` is allowed on the default method bodies below, which are intentionally
// left empty so that implementors only need to override the callbacks they care about.
/// Trait for per-request hooks to observe tool call events.
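///
/// # Example
///
/// A minimal sketch of a logging hook; `LoggingHook` is illustrative and the import paths
/// assume this module is exposed as shown:
///
/// ```no_run
/// use rig::agent::prompt_request::{CancelSignal, PromptHook};
/// use rig::completion::CompletionModel;
///
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> PromptHook<M> for LoggingHook {
///     async fn on_tool_call(&self, tool_name: &str, args: &str, _cancel_sig: CancelSignal) {
///         println!("calling tool {tool_name} with args {args}");
///     }
/// }
/// ```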
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before the prompt is sent to the model
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after the prompt is sent to the model and a response is received.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called before a tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after a tool is invoked (and a result has been returned).
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

/// The no-op hook used when no hook has been attached.
impl<M> PromptHook<M> for () where M: CompletionModel {}

/// Because [RFC 2515](https://github.com/rust-lang/rust/issues/63063) is not yet stabilized,
/// we have to use a boxed future (`WasmBoxedFuture`) for the `IntoFuture` implementation. Once
/// `impl Trait` in type aliases is stable, we should be able to use `impl Future<...>` directly
/// in the associated type.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<'_, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

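/// The result of a prompt request sent with extended details enabled: the final text output
/// plus token usage aggregated across every turn of the conversation.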
#[derive(Debug, Clone)]
pub struct PromptResponse {
    pub output: String,
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };
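        // The prompt is now the last entry of `chat_history`; each iteration of the loop
        // below sends whatever message is last in the history (user prompt or tool results).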

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let cancel_sig = CancelSignal::new();

        let mut current_max_depth = 0;
        let mut usage = Usage::new();
        // Tracks the id of the most recent chat/tool span so consecutive spans created in
        // the loop below can be linked together with `follows_from`.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // A single tool round-trip needs two completion calls: one that returns the tool
        // calls and one that turns the tool results into a text response. Hence the loop
        // runs until `current_max_depth` exceeds `max_depth + 1` before giving up.
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

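            // Link this chat span to the previous chat/tool span (if any) so the sequential
            // turns of the conversation remain connected in trace output.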
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(chat_history.to_vec()));
                }
            }

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                // If there are no tool calls, depth is not relevant; we can just return the merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

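            // Execute each requested tool call in order, giving the hook a chance to observe
            // or cancel each invocation, and collect the results as user-side tool messages.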
            let hook = self.hook.clone();
            let tool_content = stream::iter(tool_calls)
                .then(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(tool_name, &args, cancel_sig1.clone())
                                    .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(chat_history.to_vec())
                    } else {
                        e.into()
                    }
                })?;

            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // If we reach here, the loop exceeded the multi-turn budget without the model
        // producing a final text response, so surface the accumulated state as an error.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: last_prompt,
        })
    }
}