//! rig/agent/prompt_request/mod.rs — builders for agent prompt requests.

1pub mod hooks;
2pub mod streaming;
3
4use hooks::{HookAction, PromptHook, ToolCallHookAction};
5use std::{
6    future::IntoFuture,
7    marker::PhantomData,
8    sync::{
9        Arc,
10        atomic::{AtomicU64, Ordering},
11    },
12};
13use tracing::{Instrument, span::Id};
14
15use futures::{StreamExt, stream};
16use tracing::info_span;
17
18use crate::{
19    OneOrMany,
20    completion::{CompletionModel, Document, Message, PromptError, Usage},
21    json_utils,
22    message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
23    tool::server::ToolServerHandle,
24    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
25};
26
27use super::{
28    Agent,
29    completion::{DynamicContextStore, build_completion_request},
30};
31
/// Marker trait for the type-state of a [`PromptRequest`].
pub trait PromptType {}
/// Type-state marker: awaiting the request yields a plain `String`.
pub struct Standard;
/// Type-state marker: awaiting the request yields a [`PromptResponse`] with usage details.
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
38
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// The `S` type parameter is the type-state ([`Standard`] or [`Extended`]) that decides
/// what awaiting the request returns.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
pub struct PromptRequest<'a, S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<&'a mut Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging / tracing spans
    agent_name: Option<String>,
    /// System prompt (preamble)
    preamble: Option<String>,
    /// Static context documents attached to every completion call
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,

    /// Phantom data to track the type-state of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
}
92
93impl<'a, M, P> PromptRequest<'a, Standard, M, P>
94where
95    M: CompletionModel,
96    P: PromptHook<M>,
97{
98    /// Create a new PromptRequest from an agent, cloning the agent's data and default hook.
99    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
100        PromptRequest {
101            prompt: prompt.into(),
102            chat_history: None,
103            max_turns: agent.default_max_turns.unwrap_or_default(),
104            model: agent.model.clone(),
105            agent_name: agent.name.clone(),
106            preamble: agent.preamble.clone(),
107            static_context: agent.static_context.clone(),
108            temperature: agent.temperature,
109            max_tokens: agent.max_tokens,
110            additional_params: agent.additional_params.clone(),
111            tool_server_handle: agent.tool_server_handle.clone(),
112            dynamic_context: agent.dynamic_context.clone(),
113            tool_choice: agent.tool_choice.clone(),
114            state: PhantomData,
115            hook: agent.hook.clone(),
116            concurrency: 1,
117            output_schema: agent.output_schema.clone(),
118        }
119    }
120}
121
122impl<'a, S, M, P> PromptRequest<'a, S, M, P>
123where
124    S: PromptType,
125    M: CompletionModel,
126    P: PromptHook<M>,
127{
128    /// Enable returning extended details for responses (includes aggregated token usage
129    /// and the full message history accumulated during the agent loop).
130    ///
131    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
132    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
133    /// of conversation and inspecting the full message exchange.
134    pub fn extended_details(self) -> PromptRequest<'a, Extended, M, P> {
135        PromptRequest {
136            prompt: self.prompt,
137            chat_history: self.chat_history,
138            max_turns: self.max_turns,
139            model: self.model,
140            agent_name: self.agent_name,
141            preamble: self.preamble,
142            static_context: self.static_context,
143            temperature: self.temperature,
144            max_tokens: self.max_tokens,
145            additional_params: self.additional_params,
146            tool_server_handle: self.tool_server_handle,
147            dynamic_context: self.dynamic_context,
148            tool_choice: self.tool_choice,
149            state: PhantomData,
150            hook: self.hook,
151            concurrency: self.concurrency,
152            output_schema: self.output_schema,
153        }
154    }
155
156    /// Set the maximum number of turns for multi-turn conversations. A given agent may require multiple turns for tool-calling before giving an answer.
157    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
158    pub fn max_turns(mut self, depth: usize) -> Self {
159        self.max_turns = depth;
160        self
161    }
162
163    /// Add concurrency to the prompt request.
164    /// This will cause the agent to execute tools concurrently.
165    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
166        self.concurrency = concurrency;
167        self
168    }
169
170    /// Add chat history to the prompt request
171    pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
172        self.chat_history = Some(history);
173        self
174    }
175
176    /// Attach a per-request hook for tool call events.
177    /// This overrides any default hook set on the agent.
178    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<'a, S, M, P2>
179    where
180        P2: PromptHook<M>,
181    {
182        PromptRequest {
183            prompt: self.prompt,
184            chat_history: self.chat_history,
185            max_turns: self.max_turns,
186            model: self.model,
187            agent_name: self.agent_name,
188            preamble: self.preamble,
189            static_context: self.static_context,
190            temperature: self.temperature,
191            max_tokens: self.max_tokens,
192            additional_params: self.additional_params,
193            tool_server_handle: self.tool_server_handle,
194            dynamic_context: self.dynamic_context,
195            tool_choice: self.tool_choice,
196            state: PhantomData,
197            hook: Some(hook),
198            concurrency: self.concurrency,
199            output_schema: self.output_schema,
200        }
201    }
202}
203
/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
///  directly via the associated type.
///
/// Awaiting a [`Standard`]-state request yields just the final text output.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Standard, M, P>
where
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        // Boxing erases the unnameable `async fn` future type.
        Box::pin(self.send())
    }
}
219
/// Awaiting an [`Extended`]-state request yields a full [`PromptResponse`]
/// (output text, aggregated usage, and message history) instead of a `String`.
impl<'a, M, P> IntoFuture for PromptRequest<'a, Extended, M, P>
where
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

    fn into_future(self) -> Self::IntoFuture {
        // Boxing erases the unnameable `async fn` future type.
        Box::pin(self.send())
    }
}
232
233impl<M, P> PromptRequest<'_, Standard, M, P>
234where
235    M: CompletionModel,
236    P: PromptHook<M>,
237{
238    async fn send(self) -> Result<String, PromptError> {
239        self.extended_details().send().await.map(|resp| resp.output)
240    }
241}
242
/// The result of an extended prompt request: the final text output plus token
/// usage aggregated across every turn of the agent loop.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PromptResponse {
    /// Final merged text answer produced by the model.
    pub output: String,
    /// Token usage summed over all completion calls made during the loop.
    pub usage: Usage,
    /// The message history accumulated during the agent loop.
    /// Always populated when using `extended_details()`.
    pub messages: Option<Vec<Message>>,
}
252
253impl std::fmt::Display for PromptResponse {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        self.output.fmt(f)
256    }
257}
258
259impl PromptResponse {
260    pub fn new(output: impl Into<String>, usage: Usage) -> Self {
261        Self {
262            output: output.into(),
263            usage,
264            messages: None,
265        }
266    }
267
268    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
269        self.messages = Some(messages);
270        self
271    }
272}
273
/// The result of an extended typed prompt request: the deserialized structured
/// output plus token usage aggregated across the agent loop.
#[derive(Debug, Clone)]
pub struct TypedPromptResponse<T> {
    /// The structured output parsed from the model's final answer.
    pub output: T,
    /// Token usage summed over all completion calls made during the loop.
    pub usage: Usage,
}

impl<T> TypedPromptResponse<T> {
    /// Bundle a parsed output value with its aggregated usage.
    pub fn new(output: T, usage: Usage) -> Self {
        Self { output, usage }
    }
}
285
/// Fallback value recorded in tracing spans when the agent has no name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
287
288impl<M, P> PromptRequest<'_, Extended, M, P>
289where
290    M: CompletionModel,
291    P: PromptHook<M>,
292{
293    fn agent_name(&self) -> &str {
294        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
295    }
296
297    async fn send(mut self) -> Result<PromptResponse, PromptError> {
298        let agent_span = if tracing::Span::current().is_disabled() {
299            info_span!(
300                "invoke_agent",
301                gen_ai.operation.name = "invoke_agent",
302                gen_ai.agent.name = self.agent_name(),
303                gen_ai.system_instructions = self.preamble,
304                gen_ai.prompt = tracing::field::Empty,
305                gen_ai.completion = tracing::field::Empty,
306                gen_ai.usage.input_tokens = tracing::field::Empty,
307                gen_ai.usage.output_tokens = tracing::field::Empty,
308            )
309        } else {
310            tracing::Span::current()
311        };
312
313        if let Some(text) = self.prompt.rag_text() {
314            agent_span.record("gen_ai.prompt", text);
315        }
316
317        // Capture agent_name before borrowing chat_history
318        let agent_name_for_span = self.agent_name.clone();
319
320        let chat_history = if let Some(history) = self.chat_history.as_mut() {
321            history.push(self.prompt.to_owned());
322            history
323        } else {
324            &mut vec![self.prompt.to_owned()]
325        };
326
327        let mut current_max_turns = 0;
328        let mut usage = Usage::new();
329        let current_span_id: AtomicU64 = AtomicU64::new(0);
330
331        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
332        let last_prompt = loop {
333            let prompt = chat_history
334                .last()
335                .cloned()
336                .expect("there should always be at least one message in the chat history");
337
338            if current_max_turns > self.max_turns + 1 {
339                break prompt;
340            }
341
342            current_max_turns += 1;
343
344            if self.max_turns > 1 {
345                tracing::info!(
346                    "Current conversation depth: {}/{}",
347                    current_max_turns,
348                    self.max_turns
349                );
350            }
351
352            if let Some(ref hook) = self.hook
353                && let HookAction::Terminate { reason } = hook
354                    .on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
355                    .await
356            {
357                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
358            }
359
360            let span = tracing::Span::current();
361            let chat_span = info_span!(
362                target: "rig::agent_chat",
363                parent: &span,
364                "chat",
365                gen_ai.operation.name = "chat",
366                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
367                gen_ai.system_instructions = self.preamble,
368                gen_ai.provider.name = tracing::field::Empty,
369                gen_ai.request.model = tracing::field::Empty,
370                gen_ai.response.id = tracing::field::Empty,
371                gen_ai.response.model = tracing::field::Empty,
372                gen_ai.usage.output_tokens = tracing::field::Empty,
373                gen_ai.usage.input_tokens = tracing::field::Empty,
374                gen_ai.input.messages = tracing::field::Empty,
375                gen_ai.output.messages = tracing::field::Empty,
376            );
377
378            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
379                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
380                chat_span.follows_from(id).to_owned()
381            } else {
382                chat_span
383            };
384
385            if let Some(id) = chat_span.id() {
386                current_span_id.store(id.into_u64(), Ordering::SeqCst);
387            };
388
389            let resp = build_completion_request(
390                &self.model,
391                prompt.clone(),
392                chat_history[..chat_history.len() - 1].to_vec(),
393                self.preamble.as_deref(),
394                &self.static_context,
395                self.temperature,
396                self.max_tokens,
397                self.additional_params.as_ref(),
398                self.tool_choice.as_ref(),
399                &self.tool_server_handle,
400                &self.dynamic_context,
401                self.output_schema.as_ref(),
402            )
403            .await?
404            .send()
405            .instrument(chat_span.clone())
406            .await?;
407
408            usage += resp.usage;
409
410            if let Some(ref hook) = self.hook
411                && let HookAction::Terminate { reason } =
412                    hook.on_completion_response(&prompt, &resp).await
413            {
414                return Err(PromptError::prompt_cancelled(chat_history.to_vec(), reason));
415            }
416
417            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
418                .choice
419                .iter()
420                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
421
422            chat_history.push(Message::Assistant {
423                id: resp.message_id.clone(),
424                content: resp.choice.clone(),
425            });
426
427            if tool_calls.is_empty() {
428                let merged_texts = texts
429                    .into_iter()
430                    .filter_map(|content| {
431                        if let AssistantContent::Text(text) = content {
432                            Some(text.text.clone())
433                        } else {
434                            None
435                        }
436                    })
437                    .collect::<Vec<_>>()
438                    .join("\n");
439
440                if self.max_turns > 1 {
441                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
442                }
443
444                agent_span.record("gen_ai.completion", &merged_texts);
445                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
446                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
447
448                // If there are no tool calls, depth is not relevant, we can just return the merged text response.
449                return Ok(
450                    PromptResponse::new(merged_texts, usage).with_messages(chat_history.to_vec())
451                );
452            }
453
454            let hook = self.hook.clone();
455            let tool_server_handle = self.tool_server_handle.clone();
456
457            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
458            let tool_content = stream::iter(tool_calls)
459                .map(|choice| {
460                    let hook1 = hook.clone();
461                    let hook2 = hook.clone();
462                    let tool_server_handle = tool_server_handle.clone();
463
464                    let tool_span = info_span!(
465                        "execute_tool",
466                        gen_ai.operation.name = "execute_tool",
467                        gen_ai.tool.type = "function",
468                        gen_ai.tool.name = tracing::field::Empty,
469                        gen_ai.tool.call.id = tracing::field::Empty,
470                        gen_ai.tool.call.arguments = tracing::field::Empty,
471                        gen_ai.tool.call.result = tracing::field::Empty
472                    );
473
474                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
475                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
476                        tool_span.follows_from(id).to_owned()
477                    } else {
478                        tool_span
479                    };
480
481                    if let Some(id) = tool_span.id() {
482                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
483                    };
484
485                    let cloned_chat_history = chat_history.clone().to_vec();
486
487                    async move {
488                        if let AssistantContent::ToolCall(tool_call) = choice {
489                            let tool_name = &tool_call.function.name;
490                            let args =
491                                json_utils::value_to_json_string(&tool_call.function.arguments);
492                            let internal_call_id = nanoid::nanoid!();
493                            let tool_span = tracing::Span::current();
494                            tool_span.record("gen_ai.tool.name", tool_name);
495                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
496                            tool_span.record("gen_ai.tool.call.arguments", &args);
497                            if let Some(hook) = hook1 {
498                                let action = hook
499                                    .on_tool_call(
500                                        tool_name,
501                                        tool_call.call_id.clone(),
502                                        &internal_call_id,
503                                        &args,
504                                    )
505                                    .await;
506
507                                if let ToolCallHookAction::Terminate { reason } = action {
508                                    return Err(PromptError::prompt_cancelled(
509                                        cloned_chat_history,
510                                        reason,
511                                    ));
512                                }
513
514                                if let ToolCallHookAction::Skip { reason } = action {
515                                    // Tool execution rejected, return rejection message as tool result
516                                    tracing::info!(
517                                        tool_name = tool_name,
518                                        reason = reason,
519                                        "Tool call rejected"
520                                    );
521                                    if let Some(call_id) = tool_call.call_id.clone() {
522                                        return Ok(UserContent::tool_result_with_call_id(
523                                            tool_call.id.clone(),
524                                            call_id,
525                                            OneOrMany::one(reason.into()),
526                                        ));
527                                    } else {
528                                        return Ok(UserContent::tool_result(
529                                            tool_call.id.clone(),
530                                            OneOrMany::one(reason.into()),
531                                        ));
532                                    }
533                                }
534                            }
535                            let output = match tool_server_handle.call_tool(tool_name, &args).await
536                            {
537                                Ok(res) => res,
538                                Err(e) => {
539                                    tracing::warn!("Error while executing tool: {e}");
540                                    e.to_string()
541                                }
542                            };
543                            if let Some(hook) = hook2
544                                && let HookAction::Terminate { reason } = hook
545                                    .on_tool_result(
546                                        tool_name,
547                                        tool_call.call_id.clone(),
548                                        &internal_call_id,
549                                        &args,
550                                        &output.to_string(),
551                                    )
552                                    .await
553                            {
554                                return Err(PromptError::prompt_cancelled(
555                                    cloned_chat_history,
556                                    reason,
557                                ));
558                            }
559
560                            tool_span.record("gen_ai.tool.call.result", &output);
561                            tracing::info!(
562                                "executed tool {tool_name} with args {args}. result: {output}"
563                            );
564                            if let Some(call_id) = tool_call.call_id.clone() {
565                                Ok(UserContent::tool_result_with_call_id(
566                                    tool_call.id.clone(),
567                                    call_id,
568                                    ToolResultContent::from_tool_output(output),
569                                ))
570                            } else {
571                                Ok(UserContent::tool_result(
572                                    tool_call.id.clone(),
573                                    ToolResultContent::from_tool_output(output),
574                                ))
575                            }
576                        } else {
577                            unreachable!(
578                                "This should never happen as we already filtered for `ToolCall`"
579                            )
580                        }
581                    }
582                    .instrument(tool_span)
583                })
584                .buffer_unordered(self.concurrency)
585                .collect::<Vec<Result<UserContent, PromptError>>>()
586                .await
587                .into_iter()
588                .collect::<Result<Vec<_>, _>>()?;
589
590            chat_history.push(Message::User {
591                content: OneOrMany::many(tool_content).expect("There is atleast one tool call"),
592            });
593        };
594
595        // If we reach here, we never resolved the final tool call. We need to do ... something.
596        Err(PromptError::MaxTurnsError {
597            max_turns: self.max_turns,
598            chat_history: Box::new(chat_history.clone()),
599            prompt: Box::new(last_prompt),
600        })
601    }
602}
603
604// ================================================================
605// TypedPromptRequest - for structured output with automatic deserialization
606// ================================================================
607
608use crate::completion::StructuredOutputError;
609use schemars::{JsonSchema, schema_for};
610use serde::de::DeserializeOwned;
611
/// A builder for creating typed prompt requests that return deserialized structured output.
///
/// This struct wraps a standard `PromptRequest` and adds:
/// - Automatic JSON schema generation from the target type `T`
/// - Automatic deserialization of the response into `T`
///
/// The type parameter `S` represents the state of the request (Standard or Extended).
/// Use `.extended_details()` to transition to Extended state for usage tracking.
///
/// # Example
/// ```rust,ignore
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .max_turns(3)
///     .await?;
/// ```
pub struct TypedPromptRequest<'a, T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The wrapped prompt request; its `output_schema` is derived from `T`.
    inner: PromptRequest<'a, S, M, P>,
    /// Marker tying this request to the target output type `T`.
    _phantom: std::marker::PhantomData<T>,
}
638
639impl<'a, T, M, P> TypedPromptRequest<'a, T, Standard, M, P>
640where
641    T: JsonSchema + DeserializeOwned + WasmCompatSend,
642    M: CompletionModel,
643    P: PromptHook<M>,
644{
645    /// Create a new TypedPromptRequest from an agent.
646    ///
647    /// This automatically sets the output schema based on the type parameter `T`.
648    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
649        let mut inner = PromptRequest::from_agent(agent, prompt);
650        // Override the output schema with the schema for T
651        inner.output_schema = Some(schema_for!(T));
652        Self {
653            inner,
654            _phantom: std::marker::PhantomData,
655        }
656    }
657}
658
659impl<'a, T, S, M, P> TypedPromptRequest<'a, T, S, M, P>
660where
661    T: JsonSchema + DeserializeOwned + WasmCompatSend,
662    S: PromptType,
663    M: CompletionModel,
664    P: PromptHook<M>,
665{
666    /// Enable returning extended details for responses (includes aggregated token usage).
667    ///
668    /// Note: This changes the type of the response from `.send()` to return a `TypedPromptResponse<T>` struct
669    /// instead of just `T`. This is useful for tracking token usage across multiple turns
670    /// of conversation.
671    pub fn extended_details(self) -> TypedPromptRequest<'a, T, Extended, M, P> {
672        TypedPromptRequest {
673            inner: self.inner.extended_details(),
674            _phantom: std::marker::PhantomData,
675        }
676    }
677
678    /// Set the maximum number of turns for multi-turn conversations.
679    ///
680    /// A given agent may require multiple turns for tool-calling before giving an answer.
681    /// If the maximum turn number is exceeded, it will return a
682    /// [`StructuredOutputError::PromptError`] wrapping a `MaxTurnsError`.
683    pub fn max_turns(mut self, depth: usize) -> Self {
684        self.inner = self.inner.max_turns(depth);
685        self
686    }
687
688    /// Add concurrency to the prompt request.
689    ///
690    /// This will cause the agent to execute tools concurrently.
691    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
692        self.inner = self.inner.with_tool_concurrency(concurrency);
693        self
694    }
695
696    /// Add chat history to the prompt request.
697    pub fn with_history(mut self, history: &'a mut Vec<Message>) -> Self {
698        self.inner = self.inner.with_history(history);
699        self
700    }
701
702    /// Attach a per-request hook for tool call events.
703    ///
704    /// This overrides any default hook set on the agent.
705    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<'a, T, S, M, P2>
706    where
707        P2: PromptHook<M>,
708    {
709        TypedPromptRequest {
710            inner: self.inner.with_hook(hook),
711            _phantom: std::marker::PhantomData,
712        }
713    }
714}
715
716impl<'a, T, M, P> TypedPromptRequest<'a, T, Standard, M, P>
717where
718    T: JsonSchema + DeserializeOwned + WasmCompatSend,
719    M: CompletionModel,
720    P: PromptHook<M>,
721{
722    /// Send the typed prompt request and deserialize the response.
723    async fn send(self) -> Result<T, StructuredOutputError> {
724        let response = self.inner.send().await?;
725
726        if response.is_empty() {
727            return Err(StructuredOutputError::EmptyResponse);
728        }
729
730        let parsed: T = serde_json::from_str(&response)?;
731        Ok(parsed)
732    }
733}
734
735impl<'a, T, M, P> TypedPromptRequest<'a, T, Extended, M, P>
736where
737    T: JsonSchema + DeserializeOwned + WasmCompatSend,
738    M: CompletionModel,
739    P: PromptHook<M>,
740{
741    /// Send the typed prompt request with extended details and deserialize the response.
742    async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
743        let response = self.inner.send().await?;
744
745        if response.output.is_empty() {
746            return Err(StructuredOutputError::EmptyResponse);
747        }
748
749        let parsed: T = serde_json::from_str(&response.output)?;
750        Ok(TypedPromptResponse::new(parsed, response.usage))
751    }
752}
753
/// Awaiting a [`Standard`]-state typed request yields the deserialized `T`.
/// Boxed for the same reason as the untyped `IntoFuture` impls (unnameable
/// `async fn` future type).
impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
767
/// Awaiting an [`Extended`]-state typed request yields a
/// [`TypedPromptResponse<T>`] (parsed output plus aggregated usage).
impl<'a, T, M, P> IntoFuture for TypedPromptRequest<'a, T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a,
    M: CompletionModel + 'a,
    P: PromptHook<M> + 'static,
{
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'a, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}