// rig_core/agent/prompt_request/mod.rs

pub mod hooks;
pub mod streaming;

use super::{
    Agent,
    completion::{DynamicContextStore, build_completion_request},
};
use crate::{
    OneOrMany,
    completion::{CompletionModel, Document, Message, PromptError, Usage},
    json_utils,
    memory::ConversationMemory,
    message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
    tool::server::ToolServerHandle,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{StreamExt, stream};
use hooks::{HookAction, PromptHook, ToolCallHookAction};
use serde::{Deserialize, Serialize};
use std::{
    future::IntoFuture,
    marker::PhantomData,
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
};
use tracing::{Instrument, info_span, span::Id};

pub trait PromptType {}
pub struct Standard;
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, make sure to raise the turn limit with
/// [`max_turns`](Self::max_turns): it defaults to 0, which allows only a single tool round trip.
/// Otherwise, awaiting the request (which sends it) can return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
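///
/// # Example
///
/// A minimal sketch, assuming an `agent` has already been built; `max_turns`
/// is the method defined below, everything else is illustrative:
///
/// ```rust,ignore
/// // Allow up to 5 turns so back-to-back tool calls do not trip MaxTurnsError.
/// let answer: String = agent
///     .prompt("Summarize the latest report")
///     .max_turns(5)
///     .await?;
/// ```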
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history provided by the caller.
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,

    /// Phantom data to track the type of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
    /// Optional conversation memory backend cloned from the agent.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Optional conversation id used for loading and saving memory.
    conversation_id: Option<String>,
}

impl<M, P> PromptRequest<Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Create a new PromptRequest from an agent, cloning the agent's data and default hook.
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        PromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            state: PhantomData,
            hook: agent.hook.clone(),
            concurrency: 1,
            output_schema: agent.output_schema.clone(),
            memory: agent.memory.clone(),
            conversation_id: agent.default_conversation_id.clone(),
        }
    }
}

impl<S, M, P> PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage
    /// and the full message history accumulated during the agent loop).
    ///
    /// Note: this changes the return type of `.send()` from a plain `String` to a
    /// `PromptResponse` struct, which is useful for tracking token usage across multiple
    /// turns of a conversation and for inspecting the full message exchange.
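    ///
    /// # Example
    ///
    /// A sketch, assuming an `agent` has already been built:
    ///
    /// ```rust,ignore
    /// let response = agent
    ///     .prompt("hello")
    ///     .extended_details()
    ///     .await?;
    /// println!("{} ({} input tokens)", response.output, response.usage.input_tokens);
    /// ```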
    pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
            output_schema: self.output_schema,
            memory: self.memory,
            conversation_id: self.conversation_id,
        }
    }

    /// Set the maximum number of turns for multi-turn conversations. A given agent may
    /// require multiple turns of tool-calling before giving an answer. If the maximum
    /// turn number is exceeded, the request returns a
    /// [`crate::completion::request::PromptError::MaxTurnsError`].
    pub fn max_turns(mut self, depth: usize) -> Self {
        self.max_turns = depth;
        self
    }

    /// Set how many tool calls may be executed concurrently (1 by default).
    /// Tools requested in a single assistant turn will then run in parallel.
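    ///
    /// # Example
    ///
    /// A sketch, assuming the agent was built with a set of tools:
    ///
    /// ```rust,ignore
    /// // Run up to 4 tool calls from a single assistant turn in parallel.
    /// let answer = agent
    ///     .prompt("compare these four datasets")
    ///     .max_turns(3)
    ///     .with_tool_concurrency(4)
    ///     .await?;
    /// ```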
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Add chat history to the prompt request.
    pub fn with_history<I, T>(mut self, history: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        self.chat_history = Some(history.into_iter().map(Into::into).collect());
        self
    }

    /// Set the conversation id used to load and persist memory for this request.
    ///
    /// Overrides any default conversation id set on the agent. If memory is not
    /// configured on the agent, this has no effect.
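    ///
    /// # Example
    ///
    /// A sketch, assuming the agent was built with a memory backend; both
    /// prompts share the history stored under `"thread-1"`:
    ///
    /// ```rust,ignore
    /// let _ = agent.prompt("my name is Ada").conversation("thread-1").await?;
    /// let reply = agent.prompt("what is my name?").conversation("thread-1").await?;
    /// ```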
    pub fn conversation(mut self, id: impl Into<String>) -> Self {
        self.conversation_id = Some(id.into());
        self
    }

    /// Disable conversation memory for this request.
    ///
    /// History will neither be loaded from nor saved to the agent's memory backend.
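    ///
    /// # Example
    ///
    /// A sketch of a one-off request that must not touch stored history:
    ///
    /// ```rust,ignore
    /// let scratch = agent
    ///     .prompt("quick throwaway question")
    ///     .without_memory()
    ///     .await?;
    /// ```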
    pub fn without_memory(mut self) -> Self {
        self.memory = None;
        self.conversation_id = None;
        self
    }

    /// Attach a per-request hook for tool call events.
    /// This overrides any default hook set on the agent.
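    ///
    /// # Example
    ///
    /// A sketch with a hypothetical `LoggingHook` type that implements
    /// [`PromptHook`]:
    ///
    /// ```rust,ignore
    /// let answer = agent
    ///     .prompt("do tool work")
    ///     .max_turns(2)
    ///     .with_hook(LoggingHook::default())
    ///     .await?;
    /// ```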
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
            output_schema: self.output_schema,
            memory: self.memory,
            conversation_id: self.conversation_id,
        }
    }
}

/// Until [RFC 2515](https://github.com/rust-lang/rust/issues/63063) is stabilized, we have to use a
/// `BoxFuture` for the `IntoFuture` implementation. Eventually we should be able to use
/// `impl Future<...>` directly via the associated type.
impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<M, P> PromptRequest<Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct PromptResponse {
    pub output: String,
    pub usage: Usage,
    pub messages: Option<Vec<Message>>,
}

impl std::fmt::Display for PromptResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.output.fmt(f)
    }
}

impl PromptResponse {
    pub fn new(output: impl Into<String>, usage: Usage) -> Self {
        Self {
            output: output.into(),
            usage,
            messages: None,
        }
    }

    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = Some(messages);
        self
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypedPromptResponse<T> {
    pub output: T,
    pub usage: Usage,
}

impl<T> TypedPromptResponse<T> {
    pub fn new(output: T, usage: Usage) -> Self {
        Self { output, usage }
    }
}

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

/// Combine input history with new messages for building completion requests.
fn build_history_for_request(
    chat_history: Option<&[Message]>,
    new_messages: &[Message],
) -> Vec<Message> {
    let input = chat_history.unwrap_or(&[]);
    input.iter().chain(new_messages.iter()).cloned().collect()
}

/// Build the full history for error reporting (input + new messages).
fn build_full_history(
    chat_history: Option<&[Message]>,
    new_messages: Vec<Message>,
) -> Vec<Message> {
    let input = chat_history.unwrap_or(&[]);
    input.iter().cloned().chain(new_messages).collect()
}

fn is_empty_assistant_turn(choice: &OneOrMany<AssistantContent>) -> bool {
    choice.len() == 1
        && matches!(
            choice.first(),
            AssistantContent::Text(text) if text.text.is_empty()
        )
}

impl<M, P> PromptRequest<Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    async fn send(self) -> Result<PromptResponse, PromptError> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent_name_for_span = self.agent_name.clone();
        // When the caller passes explicit history, memory is fully bypassed for this
        // request (no load AND no save). Otherwise, if a memory backend and
        // conversation id are both configured, load prior history; if either is
        // missing, behave as if no memory is configured.
        let (chat_history, memory_handle) = match self.chat_history {
            Some(history) => (Some(history), None),
            None => match (self.memory, self.conversation_id) {
                (Some(memory), Some(id)) => {
                    let loaded = memory.load(&id).await?;
                    (Some(loaded), Some((memory, id)))
                }
                _ => (None, None),
            },
        };
        let mut new_messages: Vec<Message> = vec![self.prompt.clone()];

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // One tool round trip costs two loop iterations (the tool-call turn plus the
        // follow-up completion the user expects), so even max_turns == 0 permits a single round trip.
        let last_prompt = loop {
            // Get the last message (the current prompt)
            let Some((prompt_ref, history_for_current_turn)) = new_messages.split_last() else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "prompt loop lost its pending prompt",
                ));
            };
            let prompt = prompt_ref.clone();

            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            // Build history for hook callback (input + new messages except last)
            let history_for_hook =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_call(&prompt, &history_for_hook).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            // Build history for completion request (input + new messages except last)
            let history_for_request =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                &history_for_request,
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;

            usage += resp.usage;

            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Some providers normalize textless terminal turns into a single empty text item
            // because the generic completion response cannot represent an empty choice. Treat
            // that sentinel as "no assistant output" so it does not pollute returned history.
            if !is_empty_assistant_turn(&resp.choice) {
                new_messages.push(Message::Assistant {
                    id: resp.message_id.clone(),
                    content: resp.choice.clone(),
                });
            }

            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                agent_span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    usage.cached_input_tokens,
                );
                agent_span.record(
                    "gen_ai.usage.cache_creation.input_tokens",
                    usage.cache_creation_input_tokens,
                );
                agent_span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);

                if let Some((memory, id)) = memory_handle.as_ref()
                    && let Err(err) = memory.append(id, new_messages.clone()).await
                {
                    tracing::warn!(
                        error = %err,
                        conversation_id = %id,
                        "conversation memory append failed; returning model response anyway"
                    );
                }

                return Ok(PromptResponse::new(merged_texts, usage).with_messages(new_messages));
            }

            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();

            // For error handling in concurrent tool execution, we need to build full history
            let full_history_for_errors =
                build_full_history(chat_history.as_deref(), new_messages.clone());

            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    // Clone full history for error reporting in concurrent tool execution
                    let cloned_history_for_error = full_history_for_errors.clone();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_history_for_error,
                                        reason,
                                    ));
                                }

                                if let ToolCallHookAction::Skip { reason } = action {
                                    // Tool execution rejected, return rejection message as tool result
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_history_for_error,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            Err(PromptError::prompt_cancelled(
                                Vec::new(),
                                "tool execution received non-tool assistant content",
                            ))
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            let Some(content) = OneOrMany::from_iter_optional(tool_content) else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "tool execution produced no tool results",
                ));
            };

            new_messages.push(Message::User { content });
        };

        // If we reach here, we exceeded max turns without a final response
        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: build_full_history(chat_history.as_deref(), new_messages).into(),
            prompt: last_prompt.into(),
        })
    }
}

// ================================================================
// TypedPromptRequest - for structured output with automatic deserialization
// ================================================================

use crate::completion::StructuredOutputError;
use schemars::{JsonSchema, schema_for};
use serde::de::DeserializeOwned;

/// A builder for creating typed prompt requests that return deserialized structured output.
///
/// This struct wraps a standard `PromptRequest` and adds:
/// - Automatic JSON schema generation from the target type `T`
/// - Automatic deserialization of the response into `T`
///
/// The type parameter `S` represents the state of the request (Standard or Extended).
/// Use `.extended_details()` to transition to Extended state for usage tracking.
///
/// # Example
/// ```rust,ignore
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .max_turns(3)
///     .await?;
/// ```
pub struct TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    inner: PromptRequest<S, M, P>,
    _phantom: std::marker::PhantomData<T>,
}

impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Create a new TypedPromptRequest from an agent.
    ///
    /// This automatically sets the output schema based on the type parameter `T`.
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        let mut inner = PromptRequest::from_agent(agent, prompt);
        // Override the output schema with the schema for T
        inner.output_schema = Some(schema_for!(T));
        Self {
            inner,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Enable returning extended details for responses (includes aggregated token usage).
    ///
    /// Note: this changes the return type of `.send()` from just `T` to a
    /// `TypedPromptResponse<T>` struct, which is useful for tracking token usage across
    /// multiple turns of a conversation.
    pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
        TypedPromptRequest {
            inner: self.inner.extended_details(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Set the maximum number of turns for multi-turn conversations.
    ///
    /// A given agent may require multiple turns for tool-calling before giving an answer.
    /// If the maximum turn number is exceeded, it will return a
    /// [`StructuredOutputError::PromptError`] wrapping a `MaxTurnsError`.
    pub fn max_turns(mut self, depth: usize) -> Self {
        self.inner = self.inner.max_turns(depth);
        self
    }

    /// Set how many tool calls may be executed concurrently (1 by default).
    ///
    /// Tools requested in a single assistant turn will then run in parallel.
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.inner = self.inner.with_tool_concurrency(concurrency);
        self
    }

    /// Add chat history to the prompt request.
    pub fn with_history<I, H>(mut self, history: I) -> Self
    where
        I: IntoIterator<Item = H>,
        H: Into<Message>,
    {
        self.inner = self.inner.with_history(history);
        self
    }

    /// Set the conversation id used to load and persist memory for this request.
    ///
    /// Overrides any default conversation id set on the agent. If memory is not
    /// configured on the agent, this has no effect.
    pub fn conversation(mut self, id: impl Into<String>) -> Self {
        self.inner = self.inner.conversation(id);
        self
    }

    /// Disable conversation memory for this request.
    ///
    /// History will neither be loaded from nor saved to the agent's memory backend.
    pub fn without_memory(mut self) -> Self {
        self.inner = self.inner.without_memory();
        self
    }

    /// Attach a per-request hook for tool call events.
    ///
    /// This overrides any default hook set on the agent.
    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        TypedPromptRequest {
            inner: self.inner.with_hook(hook),
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Send the typed prompt request and deserialize the response.
    async fn send(self) -> Result<T, StructuredOutputError> {
        let response = self.inner.send().await.map_err(Box::new)?;

        if response.is_empty() {
            return Err(StructuredOutputError::EmptyResponse);
        }

        let parsed: T = serde_json::from_str(&response)?;
        Ok(parsed)
    }
}

impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Send the typed prompt request with extended details and deserialize the response.
    async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
        let response = self.inner.send().await.map_err(Box::new)?;

        if response.output.is_empty() {
            return Err(StructuredOutputError::EmptyResponse);
        }

        let parsed: T = serde_json::from_str(&response.output)?;
        Ok(TypedPromptResponse::new(parsed, response.usage))
    }
}

impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}

#[cfg(test)]
mod tests {
    use super::TypedPromptResponse;
    use crate::{
        agent::AgentBuilder,
        completion::{
            AssistantContent, CompletionError, CompletionRequest, Message, Prompt, PromptError,
            Usage,
        },
        message::UserContent,
        test_utils::{
            AppendFailingMemory, CountingMemory, FailingMemory, MockCompletionModel, MockTurn,
        },
    };
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    #[derive(Serialize)]
    struct SerializeOnly {
        value: &'static str,
    }

    #[derive(Deserialize)]
    struct DeserializeOnly {
        value: String,
    }

    #[test]
    fn typed_prompt_response_serializes_with_serialize_only_output() {
        let response = TypedPromptResponse::new(
            SerializeOnly { value: "ok" },
            Usage {
                input_tokens: 1,
                output_tokens: 2,
                total_tokens: 3,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
                reasoning_tokens: 0,
            },
        );

        let json = serde_json::to_string(&response).expect("serialize typed prompt response");
        assert!(json.contains("\"value\":\"ok\""));
    }

    #[test]
    fn typed_prompt_response_deserializes_with_deserialize_only_output() {
        let response: TypedPromptResponse<DeserializeOnly> = serde_json::from_str(
            r#"{"output":{"value":"ok"},"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3,"cached_input_tokens":0,"cache_creation_input_tokens":0,"reasoning_tokens":0}}"#,
        )
        .expect("deserialize typed prompt response");

        assert_eq!(response.output.value, "ok");
        assert_eq!(response.usage.input_tokens, 1);
        assert_eq!(response.usage.output_tokens, 2);
        assert_eq!(response.usage.total_tokens, 3);
    }

    fn validate_follow_up_tool_history(request: &CompletionRequest) {
        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
        assert_eq!(
            history.len(),
            3,
            "follow-up request should contain the prompt, assistant tool call, and user tool result: {history:?}"
        );

        assert!(matches!(
            history.first(),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::Text(text) if text.text == "do tool work"
                )
        ));

        assert!(matches!(
            history.get(1),
            Some(Message::Assistant { content, .. })
                if matches!(
                    content.first(),
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.call_id.as_deref() == Some("call_1")
                )
        ));

        assert!(matches!(
            history.get(2),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::ToolResult(tool_result)
                        if tool_result.id == "tool_call_1"
                            && tool_result.call_id.as_deref() == Some("call_1")
                )
        ));
    }

    #[tokio::test]
    async fn prompt_request_stops_cleanly_on_empty_terminal_turn() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "missing_tool", json!({"input": "value"}))
                .with_call_id("call_1")
                .with_usage(Usage {
                    input_tokens: 1,
                    output_tokens: 1,
                    total_tokens: 2,
                    cached_input_tokens: 0,
                    cache_creation_input_tokens: 0,
                    reasoning_tokens: 0,
                }),
            MockTurn::text("").with_usage(Usage {
                input_tokens: 1,
                output_tokens: 1,
                total_tokens: 2,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
                reasoning_tokens: 0,
            }),
        ]);
        let agent = AgentBuilder::new(model).build();

        let response = agent
            .prompt("do tool work")
            .max_turns(3)
            .extended_details()
            .await
            .expect("empty terminal turn should not error");

        assert!(response.output.is_empty());
        assert_eq!(
            response.usage,
            Usage {
                input_tokens: 2,
                output_tokens: 2,
                total_tokens: 4,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
                reasoning_tokens: 0,
            }
        );

        let history = response
            .messages
            .expect("extended response should include history");
        assert_eq!(history.len(), 3);
        assert!(matches!(
            history.first(),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::Text(text) if text.text == "do tool work"
                )
        ));
        assert!(history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
                if matches!(
                    content.first(),
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.call_id.as_deref() == Some("call_1")
                )
        )));
        assert!(history.iter().any(|message| matches!(
            message,
            Message::User { content }
                if matches!(
                    content.first(),
                    UserContent::ToolResult(tool_result)
                        if tool_result.id == "tool_call_1"
                            && tool_result.call_id.as_deref() == Some("call_1")
                )
        )));
        assert!(!history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
                if content.iter().any(|item| matches!(
                    item,
                    AssistantContent::Text(text) if text.text.is_empty()
                ))
        )));
        let requests = agent.model.requests();
        assert_eq!(requests.len(), 2);
        validate_follow_up_tool_history(&requests[1]);
    }

    // ----- Conversation memory integration tests -----

    use crate::memory::{ConversationMemory, InMemoryConversationMemory};

    #[tokio::test]
    async fn memory_loads_into_request_history() {
        let memory = InMemoryConversationMemory::new();
        memory
            .append(
                "thread-1",
                vec![Message::user("hello"), Message::assistant("hi there")],
            )
            .await
            .unwrap();

        let model = MockCompletionModel::text("ack");
        let recorded = model.clone();

        let agent = AgentBuilder::new(model).memory(memory).build();
        let _ = agent
            .prompt("ping")
            .conversation("thread-1")
            .await
            .expect("prompt should succeed");

        let received = recorded.requests()[0]
            .chat_history
            .iter()
            .cloned()
            .collect::<Vec<_>>();
        assert_eq!(
            received.len(),
            3,
            "loaded memory (2) + current prompt should appear in request: {received:?}"
        );
    }

    #[tokio::test]
    async fn memory_appends_full_turn_after_success() {
        let memory = InMemoryConversationMemory::new();
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model).memory(memory.clone()).build();

        let _ = agent
            .prompt("hello")
            .conversation("t1")
            .await
            .expect("prompt should succeed");

        let stored = memory.load("t1").await.unwrap();
        assert_eq!(stored.len(), 2, "user prompt + assistant response saved");
    }

    #[tokio::test]
    async fn explicit_with_history_overrides_memory() {
        let memory = CountingMemory::default();
        memory
            .inner()
            .append("t1", vec![Message::user("from-memory")])
            .await
            .unwrap();

        let model = MockCompletionModel::text("ack");
        let recorded = model.clone();

        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
        let _ = agent
            .prompt("hello")
            .conversation("t1")
            .with_history(vec![Message::user("from-caller")])
            .await
            .expect("prompt should succeed");

        assert_eq!(memory.load_count(), 0, "load skipped");
        let appends = memory.append_count();
        assert_eq!(appends, 0, "append skipped");

        let received = recorded.requests()[0]
            .chat_history
            .iter()
            .cloned()
            .collect::<Vec<_>>();
        assert_eq!(received.len(), 2, "caller history (1) + current prompt");
        assert!(matches!(
            received.first(),
            Some(Message::User { content })
                if matches!(content.first(), UserContent::Text(t) if t.text == "from-caller")
        ));
    }

    #[tokio::test]
    async fn memory_unchanged_on_provider_error() {
        let memory = InMemoryConversationMemory::new();
        let model = MockCompletionModel::new([MockTurn::error("boom")]);

        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
        let result = agent.prompt("hello").conversation("t1").await;
        assert!(result.is_err());

        let stored = memory.load("t1").await.unwrap();
        assert!(stored.is_empty(), "no append on error");
    }

    #[tokio::test]
    async fn missing_conversation_id_behaves_as_no_memory() {
        let memory = CountingMemory::default();
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model).memory(memory.clone()).build();

        let _ = agent.prompt("hello").await.expect("prompt should succeed");

        assert_eq!(memory.load_count(), 0);
        assert_eq!(memory.append_count(), 0);
    }

    #[tokio::test]
    async fn default_conversation_id_is_used_when_none_per_request() {
        let memory = InMemoryConversationMemory::new();
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model)
            .memory(memory.clone())
            .conversation_id("default-thread")
            .build();

        let _ = agent.prompt("hello").await.expect("prompt should succeed");
        let stored = memory.load("default-thread").await.unwrap();
        assert_eq!(stored.len(), 2);
    }

    #[tokio::test]
    async fn with_filter_truncates_loaded_history() {
        let memory = InMemoryConversationMemory::new()
            .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
        memory
            .append(
                "t1",
                vec![
                    Message::user("1"),
                    Message::assistant("2"),
                    Message::user("3"),
                    Message::assistant("4"),
                ],
            )
            .await
            .unwrap();

        let model = MockCompletionModel::text("ack");
        let recorded = model.clone();
        let agent = AgentBuilder::new(model).memory(memory).build();

        let _ = agent
            .prompt("ping")
            .conversation("t1")
            .await
            .expect("prompt should succeed");

        let received = recorded.requests()[0]
            .chat_history
            .iter()
            .cloned()
            .collect::<Vec<_>>();
        assert_eq!(
            received.len(),
            3,
            "window-truncated history (2) + current prompt"
        );
    }

    #[tokio::test]
    async fn without_memory_disables_for_request() {
        let memory = CountingMemory::default();
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model)
            .memory(memory.clone())
            .conversation_id("t1")
            .build();

        let _ = agent
            .prompt("hello")
            .without_memory()
            .await
            .expect("prompt should succeed");

        assert_eq!(memory.load_count(), 0);
        assert_eq!(memory.append_count(), 0);
    }

    #[tokio::test]
    async fn memory_load_error_surfaces_as_prompt_error() {
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model)
            .memory(FailingMemory::default())
            .build();
        let result = agent.prompt("hello").conversation("t1").await;

        match result {
            Err(PromptError::CompletionError(CompletionError::RequestError(err))) => {
                let msg = format!("{err}");
                assert!(msg.contains("load boom"), "got: {msg}");
            }
            other => panic!("expected PromptError::CompletionError(RequestError), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn memory_append_error_does_not_drop_response() {
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model)
            .memory(AppendFailingMemory::default())
            .build();
        let response: String = agent
            .prompt("hello")
            .conversation("t1")
            .await
            .expect("append failure must not block successful completion");

        assert!(!response.is_empty());
    }
}