//! Prompt request builders for agents (`rig/agent/prompt_request/mod.rs`).
1pub mod hooks;
2pub mod streaming;
3
4use hooks::{HookAction, PromptHook, ToolCallHookAction};
5use std::{
6    future::IntoFuture,
7    marker::PhantomData,
8    sync::{
9        Arc,
10        atomic::{AtomicU64, Ordering},
11    },
12};
13use tracing::{Instrument, span::Id};
14
15use futures::{StreamExt, stream};
16use tracing::info_span;
17
18use crate::{
19    OneOrMany,
20    completion::{CompletionModel, Document, Message, PromptError, Usage},
21    json_utils,
22    message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
23    tool::server::ToolServerHandle,
24    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
25};
26
27use super::{
28    Agent,
29    completion::{DynamicContextStore, build_completion_request},
30};
31
/// Sealed-style marker trait for the typestate parameter of [`PromptRequest`].
pub trait PromptType {}
/// Typestate marker: awaiting the request yields a plain `String` output.
pub struct Standard;
/// Typestate marker: awaiting the request yields a [`PromptResponse`] with
/// aggregated usage and the full message history.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
38
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,

    /// Phantom data to track the typestate (`Standard` or `Extended`) of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
}
92
impl<'a, M, P> PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Create a new PromptRequest from an agent, cloning the agent's data and default hook.
    ///
    /// The request starts in the `Standard` typestate, with no chat history
    /// attached and tool concurrency of 1 (sequential tool execution).
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        PromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            // Defaults to 0 (a single tool round-trip) when the agent sets no default.
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            state: PhantomData,
            hook: agent.hook.clone(),
            // Tools run one at a time unless raised via `with_tool_concurrency`.
            concurrency: 1,
            output_schema: agent.output_schema.clone(),
        }
    }
}
121
122impl<'a, S, M, P> PromptRequest<'a, S, M, P>
123where
124    S: PromptType,
125    M: CompletionModel,
126    P: PromptHook<M>,
127{
128    /// Enable returning extended details for responses (includes aggregated token usage
129    /// and the full message history accumulated during the agent loop).
130    ///
131    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
132    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
133    /// of conversation and inspecting the full message exchange.
134    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
135        PromptRequest {
136            prompt: self.prompt,
137            chat_history: self.chat_history,
138            max_turns: self.max_turns,
139            model: self.model,
140            agent_name: self.agent_name,
141            preamble: self.preamble,
142            static_context: self.static_context,
143            temperature: self.temperature,
144            max_tokens: self.max_tokens,
145            additional_params: self.additional_params,
146            tool_server_handle: self.tool_server_handle,
147            dynamic_context: self.dynamic_context,
148            tool_choice: self.tool_choice,
149            state: PhantomData,
150            hook: self.hook,
151            concurrency: self.concurrency,
152            output_schema: self.output_schema,
153        }
154    }
155
156    /// Set the maximum number of turns for multi-turn conversations. A given agent may require multiple turns for tool-calling before giving an answer.
157    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
158    pub fn max_turns(mut self, depth: usize) -> Self {
159        self.max_turns = depth;
160        self
161    }
162
163    /// Add concurrency to the prompt request.
164    /// This will cause the agent to execute tools concurrently.
165    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
166        self.concurrency = concurrency;
167        self
168    }
169
170    /// Add chat history to the prompt request
171    pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
172        self.chat_history = Some(history);
173        self
174    }
175
176    /// Attach a per-request hook for tool call events.
177    /// This overrides any default hook set on the agent.
178    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
179    where
180        P2: PromptHook<M>,
181    {
182        PromptRequest {
183            prompt: self.prompt,
184            chat_history: self.chat_history,
185            max_turns: self.max_turns,
186            model: self.model,
187            agent_name: self.agent_name,
188            preamble: self.preamble,
189            static_context: self.static_context,
190            temperature: self.temperature,
191            max_tokens: self.max_tokens,
192            additional_params: self.additional_params,
193            tool_server_handle: self.tool_server_handle,
194            dynamic_context: self.dynamic_context,
195            tool_choice: self.tool_choice,
196            state: PhantomData,
197            hook: Some(hook),
198            concurrency: self.concurrency,
199            output_schema: self.output_schema,
200        }
201    }
202}
203
/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
///  directly via the associated type.
///
/// This impl is what allows a `Standard` request to be `.await`ed directly,
/// resolving to the final text output.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
219
/// Allows an `Extended` request to be `.await`ed directly, resolving to a
/// [`PromptResponse`] with aggregated usage and message history.
/// See the `Standard` impl above for why a boxed future is required.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
232
233impl<M, P> PromptRequest<'_, Standard, M, P>
234where
235    M: CompletionModel,
236    P: PromptHook<M>,
237{
238    async fn send(self) -> Result<String, PromptError> {
239        self.extended_details().send().await.map(|resp| resp.output)
240    }
241}
242
/// The result of an `Extended` prompt request: the model's final text output
/// plus details aggregated over the whole agent loop.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PromptResponse {
    /// The final text output produced by the model.
    pub output: String,
    /// Token usage summed across every completion call made during the loop.
    pub usage: Usage,
    /// The message history accumulated during the agent loop.
    /// Always populated when using `extended_details()`.
    pub messages: Option<Vec<Message>>,
}
252
253impl std::fmt::Display for PromptResponse {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        self.output.fmt(f)
256    }
257}
258
259impl PromptResponse {
260    pub fn new(output: impl Into<String>, usage: Usage) -> Self {
261        Self {
262            output: output.into(),
263            usage,
264            messages: None,
265        }
266    }
267
268    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
269        self.messages = Some(messages);
270        self
271    }
272}
273
/// The result of a typed (structured-output) prompt request with extended
/// details: the deserialized value plus aggregated token usage.
#[derive(Debug, Clone)]
pub struct TypedPromptResponse<T> {
    /// The structured output, deserialized from the model's JSON response.
    pub output: T,
    /// Token usage aggregated across all turns of the request.
    pub usage: Usage,
}
279
impl<T> TypedPromptResponse<T> {
    /// Pair an already-deserialized output with its aggregated usage.
    pub fn new(output: T, usage: Usage) -> Self {
        Self { output, usage }
    }
}
285
/// Fallback agent name recorded in tracing spans when the agent has no name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
287
288impl<M, P> PromptRequest<'_, Extended, M, P>
289where
290    M: CompletionModel,
291    P: PromptHook<M>,
292{
293    fn agent_name(&self) -> &str {
294        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
295    }
296
297    async fn send(mut self) -> Result<PromptResponse, PromptError> {
298        let agent_span = if tracing::Span::current().is_disabled() {
299            info_span!(
300                "invoke_agent",
301                gen_ai.operation.name = "invoke_agent",
302                gen_ai.agent.name = self.agent_name(),
303                gen_ai.system_instructions = self.preamble,
304                gen_ai.prompt = tracing::field::Empty,
305                gen_ai.completion = tracing::field::Empty,
306                gen_ai.usage.input_tokens = tracing::field::Empty,
307                gen_ai.usage.output_tokens = tracing::field::Empty,
308                gen_ai.usage.cached_tokens = tracing::field::Empty,
309            )
310        } else {
311            tracing::Span::current()
312        };
313
314        if let Some(text) = self.prompt.rag_text() {
315            agent_span.record("gen_ai.prompt", text);
316        }
317
318        // Capture agent_name before borrowing chat_history
319        let agent_name_for_span = self.agent_name.clone();
320
321        let chat_history = if let Some(history) = self.chat_history.as_mut() {
322            history.push(self.prompt.to_owned());
323            history
324        } else {
325            &mut vec![self.prompt.to_owned()]
326        };
327
328        let mut current_max_turns = 0;
329        let mut usage = Usage::new();
330        let current_span_id: AtomicU64 = AtomicU64::new(0);
331
332        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
333        let last_prompt = loop {
334            let prompt = chat_history
335                .last()
336                .cloned()
337                .expect("there should always be at least one message in the chat history");
338
339            if current_max_turns > self.max_turns + 1 {
340                break prompt;
341            }
342
343            current_max_turns += 1;
344
345            if self.max_turns > 1 {
346                tracing::info!(
347                    "Current conversation depth: {}/{}",
348                    current_max_turns,
349                    self.max_turns
350                );
351            }
352
353            if let Some(ref hook) = self.hook
354                && let HookAction::Terminate { reason } = hook
355                    .on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
356                    .await
357            {
358                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
359            }
360
361            let span = tracing::Span::current();
362            let chat_span = info_span!(
363                target: "rig::agent_chat",
364                parent: &span,
365                "chat",
366                gen_ai.operation.name = "chat",
367                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
368                gen_ai.system_instructions = self.preamble,
369                gen_ai.provider.name = tracing::field::Empty,
370                gen_ai.request.model = tracing::field::Empty,
371                gen_ai.response.id = tracing::field::Empty,
372                gen_ai.response.model = tracing::field::Empty,
373                gen_ai.usage.output_tokens = tracing::field::Empty,
374                gen_ai.usage.input_tokens = tracing::field::Empty,
375                gen_ai.usage.cached_tokens = tracing::field::Empty,
376                gen_ai.input.messages = tracing::field::Empty,
377                gen_ai.output.messages = tracing::field::Empty,
378            );
379
380            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
381                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
382                chat_span.follows_from(id).to_owned()
383            } else {
384                chat_span
385            };
386
387            if let Some(id) = chat_span.id() {
388                current_span_id.store(id.into_u64(), Ordering::SeqCst);
389            };
390
391            let resp = build_completion_request(
392                &self.model,
393                prompt.clone(),
394                chat_history[..chat_history.len() - 1].to_vec(),
395                self.preamble.as_deref(),
396                &self.static_context,
397                self.temperature,
398                self.max_tokens,
399                self.additional_params.as_ref(),
400                self.tool_choice.as_ref(),
401                &self.tool_server_handle,
402                &self.dynamic_context,
403                self.output_schema.as_ref(),
404            )
405            .await?
406            .send()
407            .instrument(chat_span.clone())
408            .await?;
409
410            usage += resp.usage;
411
412            if let Some(ref hook) = self.hook
413                && let HookAction::Terminate { reason } =
414                    hook.on_completion_response(&prompt, &resp).await
415            {
416                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
417            }
418
419            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
420                .choice
421                .iter()
422                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
423
424            chat_history.push(Message::Assistant {
425                id: resp.message_id.clone(),
426                content: resp.choice.clone(),
427            });
428
429            if tool_calls.is_empty() {
430                let merged_texts = texts
431                    .into_iter()
432                    .filter_map(|content| {
433                        if let AssistantContent::Text(text) = content {
434                            Some(text.text.clone())
435                        } else {
436                            None
437                        }
438                    })
439                    .collect::<Vec<_>>()
440                    .join("\n");
441
442                if self.max_turns > 1 {
443                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
444                }
445
446                agent_span.record("gen_ai.completion", &merged_texts);
447                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
448                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
449                agent_span.record("gen_ai.usage.cached_tokens", usage.cached_input_tokens);
450
451                // If there are no tool calls, depth is not relevant, we can just return the merged text response.
452                return Ok(
453                    PromptResponse::new(merged_texts, usage).with_messages(chat_history.to_vec())
454                );
455            }
456
457            let hook = self.hook.clone();
458            let tool_server_handle = self.tool_server_handle.clone();
459
460            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
461            let tool_content = stream::iter(tool_calls)
462                .map(|choice| {
463                    let hook1 = hook.clone();
464                    let hook2 = hook.clone();
465                    let tool_server_handle = tool_server_handle.clone();
466
467                    let tool_span = info_span!(
468                        "execute_tool",
469                        gen_ai.operation.name = "execute_tool",
470                        gen_ai.tool.type = "function",
471                        gen_ai.tool.name = tracing::field::Empty,
472                        gen_ai.tool.call.id = tracing::field::Empty,
473                        gen_ai.tool.call.arguments = tracing::field::Empty,
474                        gen_ai.tool.call.result = tracing::field::Empty
475                    );
476
477                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
478                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
479                        tool_span.follows_from(id).to_owned()
480                    } else {
481                        tool_span
482                    };
483
484                    if let Some(id) = tool_span.id() {
485                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
486                    };
487
488                    let cloned_chat_history = chat_history.clone().to_vec();
489
490                    async move {
491                        if let AssistantContent::ToolCall(tool_call) = choice {
492                            let tool_name = &tool_call.function.name;
493                            let args =
494                                json_utils::value_to_json_string(&tool_call.function.arguments);
495                            let internal_call_id = nanoid::nanoid!();
496                            let tool_span = tracing::Span::current();
497                            tool_span.record("gen_ai.tool.name", tool_name);
498                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
499                            tool_span.record("gen_ai.tool.call.arguments", &args);
500                            if let Some(hook) = hook1 {
501                                let action = hook
502                                    .on_tool_call(
503                                        tool_name,
504                                        tool_call.call_id.clone(),
505                                        &internal_call_id,
506                                        &args,
507                                    )
508                                    .await;
509
510                                if let ToolCallHookAction::Terminate { reason } = action {
511                                    return Err(PromptError::prompt_cancelled(
512                                        cloned_chat_history,
513                                        reason,
514                                    ));
515                                }
516
517                                if let ToolCallHookAction::Skip { reason } = action {
518                                    // Tool execution rejected, return rejection message as tool result
519                                    tracing::info!(
520                                        tool_name = tool_name,
521                                        reason = reason,
522                                        "Tool call rejected"
523                                    );
524                                    if let Some(call_id) = tool_call.call_id.clone() {
525                                        return Ok(UserContent::tool_result_with_call_id(
526                                            tool_call.id.clone(),
527                                            call_id,
528                                            OneOrMany::one(reason.into()),
529                                        ));
530                                    } else {
531                                        return Ok(UserContent::tool_result(
532                                            tool_call.id.clone(),
533                                            OneOrMany::one(reason.into()),
534                                        ));
535                                    }
536                                }
537                            }
538                            let output = match tool_server_handle.call_tool(tool_name, &args).await
539                            {
540                                Ok(res) => res,
541                                Err(e) => {
542                                    tracing::warn!("Error while executing tool: {e}");
543                                    e.to_string()
544                                }
545                            };
546                            if let Some(hook) = hook2
547                                && let HookAction::Terminate { reason } = hook
548                                    .on_tool_result(
549                                        tool_name,
550                                        tool_call.call_id.clone(),
551                                        &internal_call_id,
552                                        &args,
553                                        &output.to_string(),
554                                    )
555                                    .await
556                            {
557                                return Err(PromptError::prompt_cancelled(
558                                    cloned_chat_history,
559                                    reason,
560                                ));
561                            }
562
563                            tool_span.record("gen_ai.tool.call.result", &output);
564                            tracing::info!(
565                                "executed tool {tool_name} with args {args}. result: {output}"
566                            );
567                            if let Some(call_id) = tool_call.call_id.clone() {
568                                Ok(UserContent::tool_result_with_call_id(
569                                    tool_call.id.clone(),
570                                    call_id,
571                                    ToolResultContent::from_tool_output(output),
572                                ))
573                            } else {
574                                Ok(UserContent::tool_result(
575                                    tool_call.id.clone(),
576                                    ToolResultContent::from_tool_output(output),
577                                ))
578                            }
579                        } else {
580                            unreachable!(
581                                "This should never happen as we already filtered for `ToolCall`"
582                            )
583                        }
584                    }
585                    .instrument(tool_span)
586                })
587                .buffer_unordered(self.concurrency)
588                .collect::<Vec<Result<UserContent, PromptError>>>()
589                .await
590                .into_iter()
591                .collect::<Result<Vec<_>, _>>()?;
592
593            chat_history.push(Message::User {
594                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
595            });
596        };
597
598        // If we reach here, we never resolved the final tool call. We need to do ... something.
599        Err(PromptError::MaxTurnsError {
600            max_turns: self.max_turns,
601            chat_history: Box::new(chat_history.clone()),
602            prompt: Box::new(last_prompt),
603        })
604    }
605}
606
607// ================================================================
608// TypedPromptRequest - for structured output with automatic deserialization
609// ================================================================
610
611use crate::completion::StructuredOutputError;
612use schemars::{JsonSchema, schema_for};
613use serde::de::DeserializeOwned;
614
/// A builder for creating typed prompt requests that return deserialized structured output.
///
/// This struct wraps a standard `PromptRequest` and adds:
/// - Automatic JSON schema generation from the target type `T`
/// - Automatic deserialization of the response into `T`
///
/// The type parameter `S` represents the state of the request (Standard or Extended).
/// Use `.extended_details()` to transition to Extended state for usage tracking.
///
/// # Example
/// ```rust,ignore
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .max_turns(3)
///     .await?;
/// ```
pub struct TypedPromptRequest<'a, T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The wrapped request; its `output_schema` is derived from `T`.
    inner: PromptRequest<'a, S, M, P>,
    /// Ties the deserialization target `T` to this request without storing a value.
    _phantom: std::marker::PhantomData<T>,
}
641
642impl<'a, T, M, P> TypedPromptRequest<'a, T, Standard, M, P>
643where
644    T: JsonSchema + DeserializeOwned + WasmCompatSend,
645    M: CompletionModel,
646    P: PromptHook<M>,
647{
648    /// Create a new TypedPromptRequest from an agent.
649    ///
650    /// This automatically sets the output schema based on the type parameter `T`.
651    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
652        let mut inner = PromptRequest::from_agent(agent, prompt);
653        // Override the output schema with the schema for T
654        inner.output_schema = Some(schema_for!(T));
655        Self {
656            inner,
657            _phantom: std::marker::PhantomData,
658        }
659    }
660}
661
662impl<'a, T, S, M, P> TypedPromptRequest<'a, T, S, M, P>
663where
664    T: JsonSchema + DeserializeOwned + WasmCompatSend,
665    S: PromptType,
666    M: CompletionModel,
667    P: PromptHook<M>,
668{
669    /// Enable returning extended details for responses (includes aggregated token usage).
670    ///
671    /// Note: This changes the type of the response from `.send()` to return a `TypedPromptResponse<T>` struct
672    /// instead of just `T`. This is useful for tracking token usage across multiple turns
673    /// of conversation.
674    pub fn extended_details(self) -> TypedPromptRequest<'a, T, Extended, M, P> {
675        TypedPromptRequest {
676            inner: self.inner.extended_details(),
677            _phantom: std::marker::PhantomData,
678        }
679    }
680
681    /// Set the maximum number of turns for multi-turn conversations.
682    ///
683    /// A given agent may require multiple turns for tool-calling before giving an answer.
684    /// If the maximum turn number is exceeded, it will return a
685    /// [`StructuredOutputError::PromptError`] wrapping a `MaxTurnsError`.
686    pub fn max_turns(mut self, depth: usize) -> Self {
687        self.inner = self.inner.max_turns(depth);
688        self
689    }
690
691    /// Add concurrency to the prompt request.
692    ///
693    /// This will cause the agent to execute tools concurrently.
694    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
695        self.inner = self.inner.with_tool_concurrency(concurrency);
696        self
697    }
698
699    /// Add chat history to the prompt request.
700    pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
701        self.inner = self.inner.with_history(history);
702        self
703    }
704
705    /// Attach a per-request hook for tool call events.
706    ///
707    /// This overrides any default hook set on the agent.
708    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<'a, T, S, M, P2>
709    where
710        P2: PromptHook<M>,
711    {
712        TypedPromptRequest {
713            inner: self.inner.with_hook(hook),
714            _phantom: std::marker::PhantomData,
715        }
716    }
717}
718
719impl<'a, T, M, P> TypedPromptRequest<'a, T, Standard, M, P>
720where
721    T: JsonSchema + DeserializeOwned + WasmCompatSend,
722    M: CompletionModel,
723    P: PromptHook<M>,
724{
725    /// Send the typed prompt request and deserialize the response.
726    async fn send(self) -> Result<T, StructuredOutputError> {
727        let response = self.inner.send().await?;
728
729        if response.is_empty() {
730            return Err(StructuredOutputError::EmptyResponse);
731        }
732
733        let parsed: T = serde_json::from_str(&response)?;
734        Ok(parsed)
735    }
736}
737
738impl<'a, T, M, P> TypedPromptRequest<'a, T, Extended, M, P>
739where
740    T: JsonSchema + DeserializeOwned + WasmCompatSend,
741    M: CompletionModel,
742    P: PromptHook<M>,
743{
744    /// Send the typed prompt request with extended details and deserialize the response.
745    async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
746        let response = self.inner.send().await?;
747
748        if response.output.is_empty() {
749            return Err(StructuredOutputError::EmptyResponse);
750        }
751
752        let parsed: T = serde_json::from_str(&response.output)?;
753        Ok(TypedPromptResponse::new(parsed, response.usage))
754    }
755}
756
/// Allows a `Standard` typed request to be `.await`ed directly, resolving to
/// the deserialized `T`. A boxed future is required until `impl Trait` in
/// associated types is stable.
impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
770
/// Allows an `Extended` typed request to be `.await`ed directly, resolving to
/// a [`TypedPromptResponse<T>`] carrying the deserialized value and usage.
impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}