Skip to main content

rig/agent/prompt_request/
mod.rs

1pub mod hooks;
2pub mod streaming;
3
4use super::{
5    Agent,
6    completion::{DynamicContextStore, build_completion_request},
7};
8use crate::{
9    OneOrMany,
10    completion::{CompletionModel, Document, Message, PromptError, Usage},
11    json_utils,
12    message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
13    tool::server::ToolServerHandle,
14    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
15};
16use futures::{StreamExt, stream};
17use hooks::{HookAction, PromptHook, ToolCallHookAction};
18use serde::{Deserialize, Serialize};
19use std::{
20    future::IntoFuture,
21    marker::PhantomData,
22    sync::{
23        Arc,
24        atomic::{AtomicU64, Ordering},
25    },
26};
27use tracing::info_span;
28use tracing::{Instrument, span::Id};
29
/// Type-state marker trait for [`PromptRequest`]; implemented only by
/// [`Standard`] and [`Extended`].
pub trait PromptType {}
/// Type-state marker: awaiting the request yields a plain `String` answer.
pub struct Standard;
/// Type-state marker: awaiting the request yields a [`PromptResponse`]
/// (output text plus aggregated usage and message history).
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
36
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.max_turns()`
/// builder method to add more turns as by default, it is 0 (meaning only 1 tool round-trip).
/// Otherwise, attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history provided by the caller.
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging/telemetry spans
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,

    /// Phantom data to track the type-state (`Standard` vs `Extended`) of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
}
89
90impl<M, P> PromptRequest<Standard, M, P>
91where
92    M: CompletionModel,
93    P: PromptHook<M>,
94{
95    /// Create a new PromptRequest from an agent, cloning the agent's data and default hook.
96    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
97        PromptRequest {
98            prompt: prompt.into(),
99            chat_history: None,
100            max_turns: agent.default_max_turns.unwrap_or_default(),
101            model: agent.model.clone(),
102            agent_name: agent.name.clone(),
103            preamble: agent.preamble.clone(),
104            static_context: agent.static_context.clone(),
105            temperature: agent.temperature,
106            max_tokens: agent.max_tokens,
107            additional_params: agent.additional_params.clone(),
108            tool_server_handle: agent.tool_server_handle.clone(),
109            dynamic_context: agent.dynamic_context.clone(),
110            tool_choice: agent.tool_choice.clone(),
111            state: PhantomData,
112            hook: agent.hook.clone(),
113            concurrency: 1,
114            output_schema: agent.output_schema.clone(),
115        }
116    }
117}
118
119impl<S, M, P> PromptRequest<S, M, P>
120where
121    S: PromptType,
122    M: CompletionModel,
123    P: PromptHook<M>,
124{
125    /// Enable returning extended details for responses (includes aggregated token usage
126    /// and the full message history accumulated during the agent loop).
127    ///
128    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
129    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
130    /// of conversation and inspecting the full message exchange.
131    pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
132        PromptRequest {
133            prompt: self.prompt,
134            chat_history: self.chat_history,
135            max_turns: self.max_turns,
136            model: self.model,
137            agent_name: self.agent_name,
138            preamble: self.preamble,
139            static_context: self.static_context,
140            temperature: self.temperature,
141            max_tokens: self.max_tokens,
142            additional_params: self.additional_params,
143            tool_server_handle: self.tool_server_handle,
144            dynamic_context: self.dynamic_context,
145            tool_choice: self.tool_choice,
146            state: PhantomData,
147            hook: self.hook,
148            concurrency: self.concurrency,
149            output_schema: self.output_schema,
150        }
151    }
152
153    /// Set the maximum number of turns for multi-turn conversations. A given agent may require multiple turns for tool-calling before giving an answer.
154    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
155    pub fn max_turns(mut self, depth: usize) -> Self {
156        self.max_turns = depth;
157        self
158    }
159
160    /// Add concurrency to the prompt request.
161    /// This will cause the agent to execute tools concurrently.
162    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
163        self.concurrency = concurrency;
164        self
165    }
166
167    /// Add chat history to the prompt request.
168    pub fn with_history<I, T>(mut self, history: I) -> Self
169    where
170        I: IntoIterator<Item = T>,
171        T: Into<Message>,
172    {
173        self.chat_history = Some(history.into_iter().map(Into::into).collect());
174        self
175    }
176
177    /// Attach a per-request hook for tool call events.
178    /// This overrides any default hook set on the agent.
179    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
180    where
181        P2: PromptHook<M>,
182    {
183        PromptRequest {
184            prompt: self.prompt,
185            chat_history: self.chat_history,
186            max_turns: self.max_turns,
187            model: self.model,
188            agent_name: self.agent_name,
189            preamble: self.preamble,
190            static_context: self.static_context,
191            temperature: self.temperature,
192            max_tokens: self.max_tokens,
193            additional_params: self.additional_params,
194            tool_server_handle: self.tool_server_handle,
195            dynamic_context: self.dynamic_context,
196            tool_choice: self.tool_choice,
197            state: PhantomData,
198            hook: Some(hook),
199            concurrency: self.concurrency,
200            output_schema: self.output_schema,
201        }
202    }
203}
204
/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
///  directly via the associated type.
impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting a standard request yields only the final text answer.
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
220
// Extended requests resolve to the full `PromptResponse` (output, usage, messages)
// rather than a bare `String`; see the boxing note on the `Standard` impl above.
impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
233
234impl<M, P> PromptRequest<Standard, M, P>
235where
236    M: CompletionModel,
237    P: PromptHook<M>,
238{
239    async fn send(self) -> Result<String, PromptError> {
240        self.extended_details().send().await.map(|resp| resp.output)
241    }
242}
243
/// Result of an extended prompt request: the final text plus bookkeeping
/// gathered while the agent loop ran.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct PromptResponse {
    /// The merged text of the final assistant turn.
    pub output: String,
    /// Token usage aggregated across every completion call in the loop.
    pub usage: Usage,
    /// Message history accumulated during the loop; populated via
    /// [`Self::with_messages`], otherwise `None`.
    pub messages: Option<Vec<Message>>,
}
251
252impl std::fmt::Display for PromptResponse {
253    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        self.output.fmt(f)
255    }
256}
257
258impl PromptResponse {
259    pub fn new(output: impl Into<String>, usage: Usage) -> Self {
260        Self {
261            output: output.into(),
262            usage,
263            messages: None,
264        }
265    }
266
267    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
268        self.messages = Some(messages);
269        self
270    }
271}
272
/// Result of a typed (structured-output) prompt request: the deserialized
/// value plus aggregated token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypedPromptResponse<T> {
    /// The deserialized structured output.
    pub output: T,
    /// Token usage aggregated across the request.
    pub usage: Usage,
}
278
impl<T> TypedPromptResponse<T> {
    /// Bundle a deserialized output with its usage statistics.
    pub fn new(output: T, usage: Usage) -> Self {
        Self { output, usage }
    }
}
284
/// Fallback name used in logs/telemetry when the agent was not given one.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
286
287/// Combine input history with new messages for building completion requests.
288fn build_history_for_request(
289    chat_history: Option<&[Message]>,
290    new_messages: &[Message],
291) -> Vec<Message> {
292    let input = chat_history.unwrap_or(&[]);
293    input.iter().chain(new_messages.iter()).cloned().collect()
294}
295
296/// Build the full history for error reporting (input + new messages).
297fn build_full_history(
298    chat_history: Option<&[Message]>,
299    new_messages: Vec<Message>,
300) -> Vec<Message> {
301    let input = chat_history.unwrap_or(&[]);
302    input.iter().cloned().chain(new_messages).collect()
303}
304
305fn is_empty_assistant_turn(choice: &OneOrMany<AssistantContent>) -> bool {
306    choice.len() == 1
307        && matches!(
308            choice.first(),
309            AssistantContent::Text(text) if text.text.is_empty()
310        )
311}
312
impl<M, P> PromptRequest<Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Agent name for telemetry spans, falling back to [`UNKNOWN_AGENT_NAME`].
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Drive the agent loop: repeatedly call the completion model, executing any
    /// requested tool calls between model calls, until the model replies with
    /// text only or the turn budget (`max_turns`) is exhausted.
    ///
    /// On success, returns the merged text of the final assistant turn together
    /// with token usage aggregated over every completion call and the message
    /// history accumulated along the way. Returns `PromptError::MaxTurnsError`
    /// when the budget is exceeded, or a cancellation error when any hook asks
    /// to terminate.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse the caller's active span if there is one; otherwise open a
        // fresh "invoke_agent" root span for this request.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent_name_for_span = self.agent_name.clone();
        let chat_history = self.chat_history;
        // Messages produced during this request; the last entry is always the
        // "pending prompt" for the next completion call.
        let mut new_messages: Vec<Message> = vec![self.prompt.clone()];

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        // Id of the most recently created child span, shared with the tool
        // futures so successive chat/tool spans can be chained via
        // `follows_from`. Atomic because tool futures may run concurrently.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
        let last_prompt = loop {
            // Get the last message (the current prompt)
            let Some((prompt_ref, history_for_current_turn)) = new_messages.split_last() else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "prompt loop lost its pending prompt",
                ));
            };
            let prompt = prompt_ref.clone();

            // Budget exhausted: break with the pending prompt so the
            // MaxTurnsError below can report what was left unanswered.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            // Build history for hook callback (input + new messages except last)
            let history_for_hook =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_call(&prompt, &history_for_hook).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this chat span to the previous turn's span, if any.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            // Build history for completion request (input + new messages except last)
            let history_for_request =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                &history_for_request,
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;

            // Aggregate token usage across every completion call in the loop.
            usage += resp.usage;

            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            // Separate tool-call items from text items in the assistant turn.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Some providers normalize textless terminal turns into a single empty text item
            // because the generic completion response cannot represent an empty choice. Treat
            // that sentinel as "no assistant output" so it does not pollute returned history.
            if !is_empty_assistant_turn(&resp.choice) {
                new_messages.push(Message::Assistant {
                    id: resp.message_id.clone(),
                    content: resp.choice.clone(),
                });
            }

            // No tool calls means the model has answered: merge the text items
            // and return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                agent_span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    usage.cached_input_tokens,
                );
                agent_span.record(
                    "gen_ai.usage.cache_creation.input_tokens",
                    usage.cache_creation_input_tokens,
                );

                return Ok(PromptResponse::new(merged_texts, usage).with_messages(new_messages));
            }

            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();

            // For error handling in concurrent tool execution, we need to build full history
            let full_history_for_errors =
                build_full_history(chat_history.as_deref(), new_messages.clone());

            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            // Execute up to `concurrency` tool calls at once; the final collect
            // short-circuits on the first hook-driven cancellation.
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    // Clone full history for error reporting in concurrent tool execution
                    let cloned_history_for_error = full_history_for_errors.clone();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            // Pre-execution hook: may cancel the whole prompt or
                            // skip just this tool call.
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_history_for_error,
                                        reason,
                                    ));
                                }

                                if let ToolCallHookAction::Skip { reason } = action {
                                    // Tool execution rejected, return rejection message as tool result
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // A failed tool call is reported back to the model as
                            // the error text, not surfaced as a PromptError.
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_history_for_error,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            Err(PromptError::prompt_cancelled(
                                Vec::new(),
                                "tool execution received non-tool assistant content",
                            ))
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            let Some(content) = OneOrMany::from_iter_optional(tool_content) else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "tool execution produced no tool results",
                ));
            };

            // Tool results become the next pending prompt for the model.
            new_messages.push(Message::User { content });
        };

        // If we reach here, we exceeded max turns without a final response
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: build_full_history(chat_history.as_deref(), new_messages).into(),
            prompt: last_prompt.into(),
        })
    }
}
664
665// ================================================================
666// TypedPromptRequest - for structured output with automatic deserialization
667// ================================================================
668
669use crate::completion::StructuredOutputError;
670use schemars::{JsonSchema, schema_for};
671use serde::de::DeserializeOwned;
672
/// A builder for creating typed prompt requests that return deserialized structured output.
///
/// This struct wraps a standard `PromptRequest` and adds:
/// - Automatic JSON schema generation from the target type `T`
/// - Automatic deserialization of the response into `T`
///
/// The type parameter `S` represents the state of the request (Standard or Extended).
/// Use `.extended_details()` to transition to Extended state for usage tracking.
///
/// # Example
/// ```rust,ignore
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .max_turns(3)
///     .await?;
/// ```
pub struct TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The underlying untyped request that performs the actual agent loop.
    inner: PromptRequest<S, M, P>,
    /// Marks the target deserialization type without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
699
700impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
701where
702    T: JsonSchema + DeserializeOwned + WasmCompatSend,
703    M: CompletionModel,
704    P: PromptHook<M>,
705{
706    /// Create a new TypedPromptRequest from an agent.
707    ///
708    /// This automatically sets the output schema based on the type parameter `T`.
709    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
710        let mut inner = PromptRequest::from_agent(agent, prompt);
711        // Override the output schema with the schema for T
712        inner.output_schema = Some(schema_for!(T));
713        Self {
714            inner,
715            _phantom: std::marker::PhantomData,
716        }
717    }
718}
719
720impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
721where
722    T: JsonSchema + DeserializeOwned + WasmCompatSend,
723    S: PromptType,
724    M: CompletionModel,
725    P: PromptHook<M>,
726{
727    /// Enable returning extended details for responses (includes aggregated token usage).
728    ///
729    /// Note: This changes the type of the response from `.send()` to return a `TypedPromptResponse<T>` struct
730    /// instead of just `T`. This is useful for tracking token usage across multiple turns
731    /// of conversation.
732    pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
733        TypedPromptRequest {
734            inner: self.inner.extended_details(),
735            _phantom: std::marker::PhantomData,
736        }
737    }
738
739    /// Set the maximum number of turns for multi-turn conversations.
740    ///
741    /// A given agent may require multiple turns for tool-calling before giving an answer.
742    /// If the maximum turn number is exceeded, it will return a
743    /// [`StructuredOutputError::PromptError`] wrapping a `MaxTurnsError`.
744    pub fn max_turns(mut self, depth: usize) -> Self {
745        self.inner = self.inner.max_turns(depth);
746        self
747    }
748
749    /// Add concurrency to the prompt request.
750    ///
751    /// This will cause the agent to execute tools concurrently.
752    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
753        self.inner = self.inner.with_tool_concurrency(concurrency);
754        self
755    }
756
757    /// Add chat history to the prompt request.
758    pub fn with_history<I, H>(mut self, history: I) -> Self
759    where
760        I: IntoIterator<Item = H>,
761        H: Into<Message>,
762    {
763        self.inner = self.inner.with_history(history);
764        self
765    }
766
767    /// Attach a per-request hook for tool call events.
768    ///
769    /// This overrides any default hook set on the agent.
770    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
771    where
772        P2: PromptHook<M>,
773    {
774        TypedPromptRequest {
775            inner: self.inner.with_hook(hook),
776            _phantom: std::marker::PhantomData,
777        }
778    }
779}
780
781impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
782where
783    T: JsonSchema + DeserializeOwned + WasmCompatSend,
784    M: CompletionModel,
785    P: PromptHook<M>,
786{
787    /// Send the typed prompt request and deserialize the response.
788    async fn send(self) -> Result<T, StructuredOutputError> {
789        let response = self.inner.send().await.map_err(Box::new)?;
790
791        if response.is_empty() {
792            return Err(StructuredOutputError::EmptyResponse);
793        }
794
795        let parsed: T = serde_json::from_str(&response)?;
796        Ok(parsed)
797    }
798}
799
800impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
801where
802    T: JsonSchema + DeserializeOwned + WasmCompatSend,
803    M: CompletionModel,
804    P: PromptHook<M>,
805{
806    /// Send the typed prompt request with extended details and deserialize the response.
807    async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
808        let response = self.inner.send().await.map_err(Box::new)?;
809
810        if response.output.is_empty() {
811            return Err(StructuredOutputError::EmptyResponse);
812        }
813
814        let parsed: T = serde_json::from_str(&response.output)?;
815        Ok(TypedPromptResponse::new(parsed, response.usage))
816    }
817}
818
819impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
820where
821    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
822    M: CompletionModel + 'static,
823    P: PromptHook<M> + 'static,
824{
825    type Output = Result<T, StructuredOutputError>;
826    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
827
828    fn into_future(self) -> Self::IntoFuture {
829        Box::pin(self.send())
830    }
831}
832
833impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
834where
835    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
836    M: CompletionModel + 'static,
837    P: PromptHook<M> + 'static,
838{
839    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
840    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
841
842    fn into_future(self) -> Self::IntoFuture {
843        Box::pin(self.send())
844    }
845}
846
#[cfg(test)]
mod tests {
    use super::TypedPromptResponse;
    use crate::{
        OneOrMany,
        agent::AgentBuilder,
        completion::{
            AssistantContent, CompletionError, CompletionModel, CompletionRequest,
            CompletionResponse, Message, Prompt, Usage,
        },
        message::UserContent,
        streaming::StreamingCompletionResponse,
    };
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    // Payload that only implements `Serialize`, used to prove that serializing
    // a `TypedPromptResponse` does not also require `Deserialize` on the output type.
    #[derive(Serialize)]
    struct SerializeOnly {
        value: &'static str,
    }

    // Mirror case: `Deserialize`-only payload for the inverse trait-bound check.
    #[derive(Deserialize)]
    struct DeserializeOnly {
        value: String,
    }

    // `TypedPromptResponse<SerializeOnly>` must round out to JSON containing the
    // payload field, without any `Deserialize` bound on the payload.
    #[test]
    fn typed_prompt_response_serializes_with_serialize_only_output() {
        let response = TypedPromptResponse::new(
            SerializeOnly { value: "ok" },
            Usage {
                input_tokens: 1,
                output_tokens: 2,
                total_tokens: 3,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        );

        let json = serde_json::to_string(&response).expect("serialize typed prompt response");
        // Spot check only: the serialized form must include the payload field.
        assert!(json.contains("\"value\":\"ok\""));
    }

    // Inverse of the test above: deserializing a `TypedPromptResponse` must work
    // with a `Deserialize`-only payload, and both `output` and `usage` must be filled.
    #[test]
    fn typed_prompt_response_deserializes_with_deserialize_only_output() {
        let response: TypedPromptResponse<DeserializeOnly> = serde_json::from_str(
            r#"{"output":{"value":"ok"},"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3,"cached_input_tokens":0,"cache_creation_input_tokens":0}}"#,
        )
        .expect("deserialize typed prompt response");

        assert_eq!(response.output.value, "ok");
        assert_eq!(response.usage.input_tokens, 1);
        assert_eq!(response.usage.output_tokens, 2);
        assert_eq!(response.usage.total_tokens, 3);
    }

    // Asserts that a follow-up (second-turn) request replays the full tool
    // exchange in order: the original user prompt, the assistant's tool call,
    // and the user-role tool result — with matching tool-call ids throughout.
    fn validate_follow_up_tool_history(request: &CompletionRequest) {
        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
        assert_eq!(
            history.len(),
            3,
            "follow-up request should contain the prompt, assistant tool call, and user tool result: {history:?}"
        );

        // Entry 0: the original user prompt text.
        assert!(matches!(
            history.first(),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::Text(text) if text.text == "do tool work"
                )
        ));

        // Entry 1: the assistant's tool call, preserving both `id` and `call_id`.
        assert!(matches!(
            history.get(1),
            Some(Message::Assistant { content, .. })
                if matches!(
                    content.first(),
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.call_id.as_deref() == Some("call_1")
                )
        ));

        // Entry 2: the tool result, sent back as a user message with the same ids.
        assert!(matches!(
            history.get(2),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::ToolResult(tool_result)
                        if tool_result.id == "tool_call_1"
                            && tool_result.call_id.as_deref() == Some("call_1")
                )
        ));
    }

    // Mock model: turn 0 emits a tool call (to a tool the agent does not have),
    // every later turn validates the replayed history and answers with empty text.
    // `turn_counter` records how many completion calls were made.
    #[derive(Clone, Default)]
    struct EmptyFinalTurnMockModel {
        turn_counter: Arc<AtomicUsize>,
    }

    #[allow(refining_impl_trait)]
    impl CompletionModel for EmptyFinalTurnMockModel {
        type Response = ();
        type StreamingResponse = ();
        type Client = ();

        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self::default()
        }

        async fn completion(
            &self,
            request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            // `fetch_add` returns the previous value, so `turn` is 0-based.
            let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);

            let choice = if turn == 0 {
                // First turn: request a tool call; "missing_tool" is not
                // registered on the agent under test.
                OneOrMany::one(AssistantContent::tool_call_with_call_id(
                    "tool_call_1",
                    "call_1".to_string(),
                    "missing_tool",
                    json!({"input": "value"}),
                ))
            } else {
                // Follow-up turn: the prior exchange must have been replayed
                // intact; then terminate the conversation with empty text.
                validate_follow_up_tool_history(&request);
                OneOrMany::one(AssistantContent::text(""))
            };

            Ok(CompletionResponse {
                choice,
                // Fixed per-call usage (1 in / 1 out) so the test below can
                // assert the aggregated totals across turns.
                usage: Usage {
                    input_tokens: 1,
                    output_tokens: 1,
                    total_tokens: 2,
                    cached_input_tokens: 0,
                    cache_creation_input_tokens: 0,
                },
                raw_response: (),
                message_id: None,
            })
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            Err(CompletionError::ProviderError(
                "stream is unused in this non-streaming test".to_string(),
            ))
        }
    }

    // An empty assistant text on the terminal turn should end the request
    // cleanly (no error, no phantom empty message in history), with usage
    // aggregated over both turns.
    #[tokio::test]
    async fn prompt_request_stops_cleanly_on_empty_terminal_turn() {
        let model = EmptyFinalTurnMockModel::default();
        let turn_counter = model.turn_counter.clone();
        let agent = AgentBuilder::new(model).build();

        let response = agent
            .prompt("do tool work")
            .max_turns(3)
            .extended_details()
            .await
            .expect("empty terminal turn should not error");

        assert!(response.output.is_empty());
        // Two completion calls at 1/1/2 usage each should sum to 2/2/4.
        assert_eq!(
            response.usage,
            Usage {
                input_tokens: 2,
                output_tokens: 2,
                total_tokens: 4,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            }
        );

        let history = response
            .messages
            .expect("extended response should include history");
        // Prompt + assistant tool call + user tool result; the empty terminal
        // assistant message must NOT be recorded.
        assert_eq!(history.len(), 3);
        assert!(matches!(
            history.first(),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::Text(text) if text.text == "do tool work"
                )
        ));
        assert!(history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
                if matches!(
                    content.first(),
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.call_id.as_deref() == Some("call_1")
                )
        )));
        assert!(history.iter().any(|message| matches!(
            message,
            Message::User { content }
                if matches!(
                    content.first(),
                    UserContent::ToolResult(tool_result)
                        if tool_result.id == "tool_call_1"
                            && tool_result.call_id.as_deref() == Some("call_1")
                )
        )));
        // Explicitly verify no empty assistant text message leaked into history.
        assert!(!history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
                if content.iter().any(|item| matches!(
                    item,
                    AssistantContent::Text(text) if text.text.is_empty()
                ))
        )));
        // Exactly two model calls: the tool-call turn and the empty terminal turn.
        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
    }
}