//! rig/agent/prompt_request/mod.rs
//!
//! Prompt-request builders for agents: standard (plain `String` output),
//! extended (output plus usage and message history), and typed (structured
//! output deserialized into a caller-supplied type).
1pub mod hooks;
2pub mod streaming;
3
4use super::{
5    Agent,
6    completion::{DynamicContextStore, build_completion_request},
7};
8use crate::{
9    OneOrMany,
10    completion::{CompletionModel, Document, Message, PromptError, Usage},
11    json_utils,
12    message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
13    tool::server::ToolServerHandle,
14    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
15};
16use futures::{StreamExt, stream};
17use hooks::{HookAction, PromptHook, ToolCallHookAction};
18use std::{
19    future::IntoFuture,
20    marker::PhantomData,
21    sync::{
22        Arc,
23        atomic::{AtomicU64, Ordering},
24    },
25};
26use tracing::info_span;
27use tracing::{Instrument, span::Id};
28
/// Marker trait for the type-state parameter of [`PromptRequest`].
pub trait PromptType {}
/// Type-state: awaiting the request yields a plain `String` output.
pub struct Standard;
/// Type-state: awaiting the request yields a [`PromptResponse`] that also
/// carries aggregated token usage and the accumulated message history.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
35
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history provided by the caller.
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,

    /// Phantom data to track the type of the request (Standard vs Extended)
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
}
88
impl<M, P> PromptRequest<Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Create a new PromptRequest from an agent, cloning the agent's data and default hook.
    ///
    /// The agent's settings are copied into the request (rather than borrowed)
    /// so the request can later change its hook type via [`Self::with_hook`].
    /// Tool concurrency starts at 1, i.e. tools run sequentially by default.
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        PromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            state: PhantomData,
            hook: agent.hook.clone(),
            concurrency: 1,
            output_schema: agent.output_schema.clone(),
        }
    }
}
117
impl<S, M, P> PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage
    /// and the full message history accumulated during the agent loop).
    ///
    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
    /// of conversation and inspecting the full message exchange.
    pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
        // Fields are moved one-by-one because the type-state parameter `S`
        // changes here, so `..self` struct-update syntax cannot be used
        // (the source and destination are different generic instantiations).
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
            output_schema: self.output_schema,
        }
    }

    /// Set the maximum number of turns for multi-turn conversations. A given agent may require multiple turns for tool-calling before giving an answer.
    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
    pub fn max_turns(mut self, depth: usize) -> Self {
        self.max_turns = depth;
        self
    }

    /// Add concurrency to the prompt request.
    /// This will cause the agent to execute tools concurrently.
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Add chat history to the prompt request.
    ///
    /// Replaces any previously set history; the messages are prepended to
    /// the conversation before the prompt when the request is sent.
    pub fn with_history<I, T>(mut self, history: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        self.chat_history = Some(history.into_iter().map(Into::into).collect());
        self
    }

    /// Attach a per-request hook for tool call events.
    /// This overrides any default hook set on the agent.
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
    where
        P2: PromptHook<M>,
    {
        // Same field-by-field rebuild as `extended_details`: the hook type
        // parameter `P` changes to `P2`, ruling out `..self`.
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
            output_schema: self.output_schema,
        }
    }
}
203
/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
///  directly via the associated type.
impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting a `Standard` request yields just the final text output.
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
219
impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting an `Extended` request yields a full [`PromptResponse`]
    /// (output text, aggregated usage, and message history).
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
232
233impl<M, P> PromptRequest<Standard, M, P>
234where
235    M: CompletionModel,
236    P: PromptHook<M>,
237{
238    async fn send(self) -> Result<String, PromptError> {
239        self.extended_details().send().await.map(|resp| resp.output)
240    }
241}
242
/// The result of an [`Extended`] prompt request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PromptResponse {
    /// The final text produced by the model.
    pub output: String,
    /// Token usage aggregated across every turn of the agent loop.
    pub usage: Usage,
    /// The full message history accumulated during the agent loop, if recorded.
    pub messages: Option<Vec<Message>>,
}
250
251impl std::fmt::Display for PromptResponse {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        self.output.fmt(f)
254    }
255}
256
impl PromptResponse {
    /// Create a response with the given output text and aggregated usage;
    /// the message history starts out unset.
    pub fn new(output: impl Into<String>, usage: Usage) -> Self {
        Self {
            output: output.into(),
            usage,
            messages: None,
        }
    }

    /// Attach the full message history accumulated during the agent loop.
    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = Some(messages);
        self
    }
}
271
/// The result of a typed prompt request with extended details enabled.
#[derive(Debug, Clone)]
pub struct TypedPromptResponse<T> {
    /// The structured output, deserialized from the model's JSON response.
    pub output: T,
    /// Token usage aggregated across every turn of the agent loop.
    pub usage: Usage,
}
277
impl<T> TypedPromptResponse<T> {
    /// Create a typed response from an already-deserialized value and its usage.
    pub fn new(output: T, usage: Usage) -> Self {
        Self { output, usage }
    }
}
283
/// Fallback agent name used in tracing spans when the agent is unnamed.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
285
286/// Combine input history with new messages for building completion requests.
287fn build_history_for_request(
288    chat_history: Option<&[Message]>,
289    new_messages: &[Message],
290) -> Vec<Message> {
291    let input = chat_history.unwrap_or(&[]);
292    input.iter().chain(new_messages.iter()).cloned().collect()
293}
294
295/// Build the full history for error reporting (input + new messages).
296fn build_full_history(
297    chat_history: Option<&[Message]>,
298    new_messages: Vec<Message>,
299) -> Vec<Message> {
300    let input = chat_history.unwrap_or(&[]);
301    input.iter().cloned().chain(new_messages).collect()
302}
303
impl<M, P> PromptRequest<Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Name used in tracing spans, falling back to [`UNKNOWN_AGENT_NAME`].
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Drive the agent loop to completion.
    ///
    /// Repeatedly sends a completion request (prompt + accumulated history),
    /// executes any tool calls the model returns (at most `self.concurrency`
    /// at a time), and feeds the tool results back as the next user message.
    /// The loop ends successfully when the model replies with no tool calls.
    ///
    /// # Errors
    /// - [`PromptError::MaxTurnsError`] if the model keeps calling tools past
    ///   the configured turn budget.
    /// - A prompt-cancelled error if any hook returns a `Terminate` action.
    /// - Any error from building or sending the completion request.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse the caller's span if one is already active; otherwise open a
        // fresh "invoke_agent" root span with empty fields recorded later.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent_name_for_span = self.agent_name.clone();
        let chat_history = self.chat_history;
        // Messages produced during this call; the last element is always the
        // "current prompt" for the next completion request.
        let mut new_messages: Vec<Message> = vec![self.prompt.clone()];

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        // Id of the most recently created chat/tool span, so each new span
        // can be linked to its predecessor via `follows_from`. Atomic because
        // it is also written from the concurrent tool-execution closures.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
        let last_prompt = loop {
            // Get the last message (the current prompt)
            let prompt = new_messages
                .last()
                .expect("there should always be at least one message")
                .clone();

            // Budget check: `max_turns` tool round-trips plus the final
            // answer turn are allowed before bailing out.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            // Build history for hook callback (input + new messages except last)
            let history_for_hook = build_history_for_request(
                chat_history.as_deref(),
                &new_messages[..new_messages.len().saturating_sub(1)],
            );

            // The pre-completion hook may cancel the whole request.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_call(&prompt, &history_for_hook).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this span to the previous chat/tool span, if any.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            // Build history for completion request (input + new messages except last)
            let history_for_request = build_history_for_request(
                chat_history.as_deref(),
                &new_messages[..new_messages.len().saturating_sub(1)],
            );

            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                &history_for_request,
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;

            // Accumulate token usage across every turn of the loop.
            usage += resp.usage;

            // The post-completion hook may also cancel the request.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            // Split the model's reply into tool calls and plain text parts.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            new_messages.push(Message::Assistant {
                id: resp.message_id.clone(),
                content: resp.choice.clone(),
            });

            // No tool calls means the model has produced its final answer.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                // Record the final output and aggregated usage on the
                // top-level agent span before returning.
                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                agent_span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    usage.cached_input_tokens,
                );
                agent_span.record(
                    "gen_ai.usage.cache_creation.input_tokens",
                    usage.cache_creation_input_tokens,
                );

                return Ok(PromptResponse::new(merged_texts, usage).with_messages(new_messages));
            }

            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();

            // For error handling in concurrent tool execution, we need to build full history
            let full_history_for_errors =
                build_full_history(chat_history.as_deref(), new_messages.clone());

            // Execute every tool call, at most `self.concurrency` at a time.
            // Any Err (hook Terminate) short-circuits via the final collect.
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Chain this tool span after the previous chat/tool span.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    // Clone full history for error reporting in concurrent tool execution
                    let cloned_history_for_error = full_history_for_errors.clone();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            // The pre-tool hook may terminate the request or
                            // skip this one tool call.
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_history_for_error,
                                        reason,
                                    ));
                                }

                                if let ToolCallHookAction::Skip { reason } = action {
                                    // Tool execution rejected, return rejection message as tool result
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Tool errors are not fatal: the error text is fed
                            // back to the model as the tool's result.
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            // The post-tool hook may still terminate the request.
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_history_for_error,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            unreachable!(
                                "This should never happen as we already filtered for `ToolCall`"
                            )
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            // Feed all tool results back as the next user message; this
            // becomes the prompt for the next loop iteration.
            new_messages.push(Message::User {
                content: OneOrMany::many(tool_content).expect("There is at least one tool call"),
            });
        };

        // If we reach here, we exceeded max turns without a final response
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: build_full_history(chat_history.as_deref(), new_messages).into(),
            prompt: last_prompt.into(),
        })
    }
}
645
646// ================================================================
647// TypedPromptRequest - for structured output with automatic deserialization
648// ================================================================
649
650use crate::completion::StructuredOutputError;
651use schemars::{JsonSchema, schema_for};
652use serde::de::DeserializeOwned;
653
/// A builder for creating typed prompt requests that return deserialized structured output.
///
/// This struct wraps a standard `PromptRequest` and adds:
/// - Automatic JSON schema generation from the target type `T`
/// - Automatic deserialization of the response into `T`
///
/// The type parameter `S` represents the state of the request (Standard or Extended).
/// Use `.extended_details()` to transition to Extended state for usage tracking.
///
/// # Example
/// ```rust,ignore
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .max_turns(3)
///     .await?;
/// ```
pub struct TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The wrapped untyped request; its output schema is set from `T`.
    inner: PromptRequest<S, M, P>,
    /// Tracks `T` without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
680
impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Create a new TypedPromptRequest from an agent.
    ///
    /// This automatically sets the output schema based on the type parameter `T`.
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        let mut inner = PromptRequest::from_agent(agent, prompt);
        // Override the output schema with the schema for T (takes precedence
        // over any schema the agent itself was configured with).
        inner.output_schema = Some(schema_for!(T));
        Self {
            inner,
            _phantom: std::marker::PhantomData,
        }
    }
}
700
impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage).
    ///
    /// Note: This changes the type of the response from `.send()` to return a `TypedPromptResponse<T>` struct
    /// instead of just `T`. This is useful for tracking token usage across multiple turns
    /// of conversation.
    pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
        TypedPromptRequest {
            inner: self.inner.extended_details(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Set the maximum number of turns for multi-turn conversations.
    ///
    /// A given agent may require multiple turns for tool-calling before giving an answer.
    /// If the maximum turn number is exceeded, it will return a
    /// [`StructuredOutputError::PromptError`] wrapping a `MaxTurnsError`.
    pub fn max_turns(mut self, depth: usize) -> Self {
        self.inner = self.inner.max_turns(depth);
        self
    }

    /// Add concurrency to the prompt request.
    ///
    /// This will cause the agent to execute tools concurrently.
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.inner = self.inner.with_tool_concurrency(concurrency);
        self
    }

    /// Add chat history to the prompt request.
    pub fn with_history<I, H>(mut self, history: I) -> Self
    where
        I: IntoIterator<Item = H>,
        H: Into<Message>,
    {
        self.inner = self.inner.with_history(history);
        self
    }

    /// Attach a per-request hook for tool call events.
    ///
    /// This overrides any default hook set on the agent.
    // All of these builder methods simply delegate to the wrapped
    // `PromptRequest`, preserving the typed wrapper around it.
    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        TypedPromptRequest {
            inner: self.inner.with_hook(hook),
            _phantom: std::marker::PhantomData,
        }
    }
}
761
762impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
763where
764    T: JsonSchema + DeserializeOwned + WasmCompatSend,
765    M: CompletionModel,
766    P: PromptHook<M>,
767{
768    /// Send the typed prompt request and deserialize the response.
769    async fn send(self) -> Result<T, StructuredOutputError> {
770        let response = self.inner.send().await.map_err(Box::new)?;
771
772        if response.is_empty() {
773            return Err(StructuredOutputError::EmptyResponse);
774        }
775
776        let parsed: T = serde_json::from_str(&response)?;
777        Ok(parsed)
778    }
779}
780
781impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
782where
783    T: JsonSchema + DeserializeOwned + WasmCompatSend,
784    M: CompletionModel,
785    P: PromptHook<M>,
786{
787    /// Send the typed prompt request with extended details and deserialize the response.
788    async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
789        let response = self.inner.send().await.map_err(Box::new)?;
790
791        if response.output.is_empty() {
792            return Err(StructuredOutputError::EmptyResponse);
793        }
794
795        let parsed: T = serde_json::from_str(&response.output)?;
796        Ok(TypedPromptResponse::new(parsed, response.usage))
797    }
798}
799
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting a typed `Standard` request yields the deserialized value.
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
813
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting a typed `Extended` request yields the deserialized value
    /// together with aggregated token usage.
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}