
// rig/agent/prompt_request/mod.rs

pub mod hooks;
pub mod streaming;

use hooks::{HookAction, PromptHook, ToolCallHookAction};
use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
};
use tracing::{Instrument, info_span, span::Id};

use futures::{StreamExt, stream};

use crate::{
    OneOrMany,
    completion::{CompletionModel, Document, Message, PromptError, Usage},
    json_utils,
    message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
    tool::server::ToolServerHandle,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};

use super::{
    Agent,
    completion::{DynamicContextStore, build_completion_request},
};

pub trait PromptType {}
pub struct Standard;
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, you will want to use the `.max_turns()`
/// method to allow more turns: by default it is 0, meaning only one tool round-trip. Otherwise,
/// awaiting the request (which sends the prompt) can return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
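///
/// # Example
/// A minimal sketch of the builder flow (the agent and prompt text are illustrative, and we
/// assume the agent hands out this builder through its prompt method):
/// ```rust,ignore
/// let answer: String = agent
///     .prompt("What is the weather in Berlin?")
///     .max_turns(3) // allow a few tool round-trips
///     .await?;
/// ```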
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,

    /// Phantom data to track the type of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
}

impl<'a, M, P> PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Create a new PromptRequest from an agent, cloning the agent's data and default hook.
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        PromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            state: PhantomData,
            hook: agent.hook.clone(),
            concurrency: 1,
            output_schema: agent.output_schema.clone(),
        }
    }
}

impl<'a, S, M, P> PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage).
    ///
    /// Note: This changes the return type of `.send` from a simple `String` to a
    /// `PromptResponse` struct, which is useful for tracking token usage across multiple
    /// turns of conversation.
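    ///
    /// # Example
    /// A minimal sketch (agent construction elided):
    /// ```rust,ignore
    /// let resp = agent
    ///     .prompt("Summarize this document")
    ///     .extended_details()
    ///     .await?;
    /// println!("{} (usage: {:?})", resp.output, resp.total_usage);
    /// ```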
    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
            output_schema: self.output_schema,
        }
    }

    /// Set the maximum number of turns for multi-turn conversations. A given agent may require
    /// multiple turns of tool-calling before giving an answer.
    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
    pub fn max_turns(mut self, depth: usize) -> Self {
        self.max_turns = depth;
        self
    }

    /// Set how many tool calls may execute concurrently (the default is 1).
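    ///
    /// ```rust,ignore
    /// // Run up to four tool calls at once (the value is illustrative).
    /// let answer = agent
    ///     .prompt("Compare these three datasets")
    ///     .with_tool_concurrency(4)
    ///     .max_turns(2)
    ///     .await?;
    /// ```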
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Add chat history to the prompt request.
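    ///
    /// ```rust,ignore
    /// // Reusing one history across requests carries the conversation state forward
    /// // (the prompt texts are illustrative).
    /// let mut history = Vec::new();
    /// agent.prompt("Hi, I'm Alice").with_history(&mut history).await?;
    /// let reply = agent.prompt("What's my name?").with_history(&mut history).await?;
    /// ```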
    pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attach a per-request hook for tool call events.
    /// This overrides any default hook set on the agent.
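    ///
    /// ```rust,ignore
    /// // `LoggingHook` is a hypothetical type implementing `PromptHook<M>`.
    /// let answer = agent
    ///     .prompt("Look up the order status")
    ///     .with_hook(LoggingHook::default())
    ///     .await?;
    /// ```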
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
            output_schema: self.output_schema,
        }
    }
}

/// Because [RFC 2515](https://github.com/rust-lang/rust/issues/63063) (`impl Trait` in type
/// aliases) is not yet stable, we have to use a `BoxFuture` for the `IntoFuture`
/// implementation. Once it stabilizes, we should be able to use `impl Future<...>` directly
/// via the associated type.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<'_, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

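/// The response returned by an extended prompt request: the final text output plus
/// token usage aggregated across every turn of the conversation.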
#[derive(Debug, Clone)]
pub struct PromptResponse {
    pub output: String,
    pub total_usage: Usage,
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, total_usage: Usage) -> Self {
        Self {
            output: output.into(),
            total_usage,
        }
    }
}

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

impl<M, P> PromptRequest<'_, Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    async fn send(mut self) -> Result<PromptResponse, PromptError> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // Capture agent_name before borrowing chat_history
        let agent_name_for_span = self.agent_name.clone();

        let chat_history = if let Some(history) = self.chat_history.as_mut() {
            history.push(self.prompt.to_owned());
            history
        } else {
            &mut vec![self.prompt.to_owned()]
        };

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        let current_span_id: AtomicU64 = AtomicU64::new(0);
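        // `current_span_id` chains the chat and tool spans created in the loop below:
        // each new span is linked to the previously recorded one via `follows_from`.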

        // We need at least two iterations per tool round-trip: one turn that emits the tool
        // calls and one follow-up turn for the text answer the user expects.
        let last_prompt = loop {
            let prompt = chat_history
                .last()
                .cloned()
                .expect("there should always be at least one message in the chat history");

            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } = hook
                    .on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
                    .await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                chat_history[..chat_history.len() - 1].to_vec(),
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
            }

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            chat_history.push(Message::Assistant {
                id: resp.message_id.clone(),
                content: resp.choice.clone(),
            });

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);

                // If there are no tool calls, depth is not relevant; we can just return the merged text response.
                return Ok(PromptResponse::new(merged_texts, usage));
            }

            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();

            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
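            // Each tool call becomes its own future; `buffer_unordered` below runs up to
            // `self.concurrency` of them at a time and collects the results.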
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    let cloned_chat_history = chat_history.to_vec();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_chat_history,
                                        reason,
                                    ));
                                }

                                if let ToolCallHookAction::Skip { reason } = action {
                                    // Tool execution was rejected; return the rejection message as the tool result
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_chat_history,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            chat_history.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // If we reach here, the turn budget was exhausted before the agent produced a final
        // text response.
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: Box::new(chat_history.clone()),
            prompt: Box::new(last_prompt),
        })
    }
}

// ================================================================
// TypedPromptRequest - for structured output with automatic deserialization
// ================================================================

use crate::completion::StructuredOutputError;
use schemars::{JsonSchema, schema_for};
use serde::de::DeserializeOwned;

/// A builder for creating typed prompt requests that return deserialized structured output.
///
/// This struct wraps a standard `PromptRequest` and adds:
/// - Automatic JSON schema generation from the target type `T`
/// - Automatic deserialization of the response into `T`
///
/// # Example
/// ```rust,ignore
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .max_turns(3)
///     .await?;
/// ```
pub struct TypedPromptRequest<'a, T, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    M: CompletionModel,
    P: PromptHook<M>,
{
    inner: PromptRequest<'a, Standard, M, P>,
    _phantom: std::marker::PhantomData<T>,
}

impl<'a, T, M, P> TypedPromptRequest<'a, T, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Create a new TypedPromptRequest from an agent.
    ///
    /// This automatically sets the output schema based on the type parameter `T`.
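    ///
    /// A sketch of a target type (the shape is illustrative):
    /// ```rust,ignore
    /// #[derive(serde::Deserialize, schemars::JsonSchema)]
    /// struct WeatherForecast {
    ///     temperature_celsius: f64,
    ///     summary: String,
    /// }
    /// ```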
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        let mut inner = PromptRequest::from_agent(agent, prompt);
        // Override the output schema with the schema for T
        inner.output_schema = Some(schema_for!(T));
        Self {
            inner,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Set the maximum number of turns for multi-turn conversations.
    ///
    /// A given agent may require multiple turns for tool-calling before giving an answer.
    /// If the maximum turn number is exceeded, it will return a
    /// [`StructuredOutputError::PromptError`] wrapping a `MaxTurnsError`.
    pub fn max_turns(mut self, depth: usize) -> Self {
        self.inner = self.inner.max_turns(depth);
        self
    }

    /// Set how many tool calls may execute concurrently (the default is 1).
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.inner = self.inner.with_tool_concurrency(concurrency);
        self
    }

    /// Add chat history to the prompt request.
    pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
        self.inner = self.inner.with_history(history);
        self
    }

    /// Attach a per-request hook for tool call events.
    ///
    /// This overrides any default hook set on the agent.
    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<'a, T, M, P2>
    where
        P2: PromptHook<M>,
    {
        TypedPromptRequest {
            inner: self.inner.with_hook(hook),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Send the typed prompt request and deserialize the response.
    async fn send(self) -> Result<T, StructuredOutputError> {
        let response = self.inner.send().await?;

        if response.is_empty() {
            return Err(StructuredOutputError::EmptyResponse);
        }

        let parsed: T = serde_json::from_str(&response)?;
        Ok(parsed)
    }
}

impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}