rig/agent/prompt_request/mod.rs

pub mod streaming;

pub use streaming::StreamingPromptHook;

use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc, OnceLock,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
};

use futures::{StreamExt, stream};
use tracing::{Instrument, info_span, span::Id};

use crate::{
    OneOrMany,
    completion::{Completion, CompletionModel, Message, PromptError, Usage},
    json_utils,
    message::{AssistantContent, UserContent},
    tool::ToolSetError,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};

use super::Agent;

pub trait PromptType {}
pub struct Standard;
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, use the `.multi_turn()` method to allow
/// additional turns: the default is 0, meaning only one tool round-trip. Otherwise, awaiting
/// the request (which sends the prompt) can return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
/// back to back.
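///
/// A minimal usage sketch (the provider client and model name below are illustrative, not
/// part of this module):
///
/// ```ignore
/// // Build an agent from any provider client, then prompt it with multi-turn enabled.
/// let agent = client.agent("some-model").preamble("You are helpful.").build();
///
/// // Allow up to 5 tool round-trips before a text answer is required.
/// let answer: String = agent
///     .prompt("What's the weather in Paris?")
///     .multi_turn(5)
///     .await?;
/// ```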
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt.
    /// Note: the chat history needs to outlive the agent as it might be used with other agents.
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: &'a Agent<M>,
    /// Phantom data to track the type of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
}

impl<'a, M> PromptRequest<'a, Standard, M, ()>
where
    M: CompletionModel,
{
    /// Create a new PromptRequest with the given prompt and agent
    pub fn new(agent: &'a Agent<M>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: agent.default_max_depth.unwrap_or_default(),
            agent,
            state: PhantomData,
            hook: None,
            concurrency: 1,
        }
    }
}

impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage)
    ///
    /// Note: This changes the type of the response from `.send` to return a [`PromptResponse`]
    /// struct instead of a simple `String`. This is useful for tracking token usage across
    /// multiple turns of conversation.
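    ///
    /// A sketch of reading the aggregated usage (the agent itself is assumed to be built
    /// elsewhere):
    ///
    /// ```ignore
    /// let response = agent
    ///     .prompt("Summarize this document.")
    ///     .extended_details()
    ///     .await?;
    /// println!(
    ///     "{} (input tokens: {}, output tokens: {})",
    ///     response.output, response.total_usage.input_tokens, response.total_usage.output_tokens
    /// );
    /// ```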
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of turns an
    /// LLM can spend calling tools before it must produce a text response).
    /// If the maximum number of turns is exceeded, the request returns a
    /// [`crate::completion::request::PromptError::MaxDepthError`].
    pub fn multi_turn(self, depth: usize) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Set how many tool calls may be executed concurrently for this request.
    /// Values greater than 1 make the agent run the tool calls from a single response in parallel.
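    ///
    /// For example (sketch; `agent` is assumed from elsewhere):
    ///
    /// ```ignore
    /// let answer = agent
    ///     .prompt("Compare the weather in Paris and Berlin.")
    ///     .multi_turn(3)
    ///     .with_tool_concurrency(4) // run up to 4 tool calls at once
    ///     .await?;
    /// ```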
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Add chat history to the prompt request. The history is updated in place: the prompt,
    /// the model's responses, and any tool results are appended as the request executes.
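    ///
    /// A sketch of carrying context across requests (assumes `agent` exists):
    ///
    /// ```ignore
    /// let mut history = Vec::new();
    /// let first = agent.prompt("My name is Ada.").with_history(&mut history).await?;
    /// // `history` now contains the prompt and the model's reply.
    /// let second = agent.prompt("What is my name?").with_history(&mut history).await?;
    /// ```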
    pub fn with_history(self, history: &'a mut Vec<Message>) -> PromptRequest<'a, S, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: Some(history),
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
        }
    }

    /// Attach a per-request hook for completion and tool call events. See [`PromptHook`] for
    /// the available callbacks.
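    ///
    /// For example (sketch):
    ///
    /// ```ignore
    /// // `LoggingHook` is a hypothetical type implementing `PromptHook<M>`.
    /// let answer = agent.prompt("hi").with_hook(LoggingHook).await?;
    /// ```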
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
        }
    }
}

/// Handles cancellations from a [`PromptHook`] in an agentic loop.
/// When `CancelSignal::cancel()` is called, the agent loop terminates early and returns the
/// messages generated so far. A reason for the early termination can additionally be attached
/// with `CancelSignal::cancel_with_reason()`.
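///
/// A sketch of cancelling from inside a hook (the budget check is illustrative):
///
/// ```ignore
/// // Inside an implementation of `PromptHook::on_tool_call`:
/// if tool_calls_so_far > budget {
///     cancel_sig.cancel_with_reason("tool call budget exceeded");
///     return; // return promptly once cancelled
/// }
/// ```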
pub struct CancelSignal {
    sig: Arc<AtomicBool>,
    reason: Arc<OnceLock<String>>,
}

impl CancelSignal {
    fn new() -> Self {
        Self {
            sig: Arc::new(AtomicBool::new(false)),
            reason: Arc::new(OnceLock::new()),
        }
    }

    pub fn cancel(&self) {
        self.sig.store(true, Ordering::SeqCst);
    }

    pub fn cancel_with_reason(&self, reason: &str) {
        // Note: the reason can only be set once. The agent loop returns as soon as the current
        // hook finishes with the internal `AtomicBool` set to true; it is up to the user to
        // return promptly from the hook after cancelling, which is usually the obvious thing to do.
        let _ = self.reason.set(reason.to_string());
        self.cancel();
    }

    fn is_cancelled(&self) -> bool {
        self.sig.load(Ordering::SeqCst)
    }

    fn cancel_reason(&self) -> Option<&str> {
        self.reason.get().map(|x| x.as_str())
    }
}

impl Clone for CancelSignal {
    fn clone(&self) -> Self {
        Self {
            sig: self.sig.clone(),
            reason: self.reason.clone(),
        }
    }
}

/// Trait for per-request hooks to observe completion and tool call events.
///
/// Every method has an empty default body, so implementors only need to override the events
/// they care about (hence the `#[allow(unused_variables)]` on each method).
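///
/// A sketch of a hook that only overrides one event (`LoggingHook` is illustrative):
///
/// ```ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> PromptHook<M> for LoggingHook {
///     fn on_tool_call(
///         &self,
///         tool_name: &str,
///         tool_call_id: Option<String>,
///         args: &str,
///         cancel_sig: CancelSignal,
///     ) -> impl Future<Output = ()> + WasmCompatSend {
///         async move { println!("calling tool {tool_name} with args {args}") }
///     }
/// }
/// ```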
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Called before the prompt is sent to the model
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after the prompt is sent to the model and a response is received.
    #[allow(unused_variables)]
    fn on_completion_response(
        &self,
        prompt: &Message,
        response: &crate::completion::CompletionResponse<M::Response>,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called before a tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }

    /// Called after a tool is invoked (and a result has been returned).
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + WasmCompatSend {
        async {}
    }
}

impl<M> PromptHook<M> for () where M: CompletionModel {}

/// Due to [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
/// for the `IntoFuture` implementation. In the future, we should be able to use
/// `impl Future<...>` directly via the associated type.
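///
/// In practice this means a request can be awaited directly (sketch):
///
/// ```ignore
/// let answer: String = agent.prompt("hello").await?;
/// ```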
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<'_, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

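/// The response from an [`Extended`] prompt request: the final text output plus the token
/// usage aggregated across every turn of the conversation.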
#[derive(Debug, Clone)]
pub struct PromptResponse {
    pub output: String,
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let agent = self.agent;
        let chat_history = if let Some(history) = self.chat_history {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let cancel_sig = CancelSignal::new();

        let mut current_max_depth = 0;
        let mut usage = Usage::new();
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // We need at least two iterations per tool round-trip: one that produces the tool
        // calls and one that produces the text response the user expects.
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_depth > self.max_depth + 1 {
                break prompt;
            }

            current_max_depth += 1;

            if self.max_depth > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_depth,
                    self.max_depth
                );
            }

            if let Some(ref hook) = self.hook {
                hook.on_completion_call(
                    &prompt,
                    &chat_history[..chat_history.len() - 1],
                    cancel_sig.clone(),
                )
                .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(
                        chat_history.to_vec(),
                        cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                    ));
                }
            }
            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

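            // Chain each turn's span to the previous one with `follows_from` (rather than
            // parent/child), so successive turns appear as sibling steps of the same
            // agent invocation.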
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let resp = agent
                .completion(
                    prompt.clone(),
                    chat_history[..chat_history.len() - 1].to_vec(),
                )
                .await?
                .send()
                .instrument(chat_span.clone())
                .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook {
                hook.on_completion_response(&prompt, &resp, cancel_sig.clone())
                    .await;
                if cancel_sig.is_cancelled() {
                    return Err(PromptError::prompt_cancelled(
                        chat_history.to_vec(),
                        cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                    ));
                }
            }

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

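            // Record the assistant's full response (text and any tool calls) in the history
            // before executing tools, so the tool results can follow it as a user message.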
            chat_history.push(Message::Assistant {
                id: None,
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_depth > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                // If there are no tool calls, depth is not relevant; we can just return the merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();

            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();

                    let cancel_sig1 = cancel_sig.clone();
                    let cancel_sig2 = cancel_sig.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                hook.on_tool_call(
                                    tool_name,
                                    tool_call.call_id.clone(),
                                    &args,
                                    cancel_sig1.clone(),
                                )
                                .await;
                                if cancel_sig1.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            let output =
                                match agent.tool_server_handle.call_tool(tool_name, &args).await {
                                    Ok(res) => res,
                                    Err(e) => {
                                        tracing::warn!("Error while executing tool: {e}");
                                        e.to_string()
                                    }
                                };
                            if let Some(hook) = hook2 {
                                hook.on_tool_result(
                                    tool_name,
                                    tool_call.call_id.clone(),
                                    &args,
                                    &output.to_string(),
                                    cancel_sig2.clone(),
                                )
                                .await;

                                if cancel_sig2.is_cancelled() {
                                    return Err(ToolSetError::Interrupted);
                                }
                            }
                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(output.into()),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(output.into()),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, ToolSetError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    if matches!(e, ToolSetError::Interrupted) {
                        PromptError::prompt_cancelled(
                            chat_history.to_vec(),
                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                        )
                    } else {
                        e.into()
                    }
                })?;

            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // If we reach this point, the loop exceeded the maximum depth without the model
        // producing a final text response, so surface a `MaxDepthError` with the state so far.
        Err(PromptError::MaxDepthError {
            max_depth: self.max_depth,
            chat_history: Box::new(chat_history.clone()),
            prompt: Box::new(last_prompt),
        })
    }
}