// kind_openai/endpoints/chat/structured.rs

1use kind_openai_schema::{GeneratedOpenAISchema, OpenAISchema};
2use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize};
3
4use crate::{endpoints::OpenAIRequestProvider, OpenAIResult, Usage};
5
6use super::{standard::ChatCompletion, FinishReason, UnifiedChatCompletionResponseMessage};
7
8/// A chat completion request who's response conforms to a particular JSON schema.
9///
10/// All types which are structured must derive `kind_openai::OpenAISchema`, as well as
11/// `serde::Deserialize`. Take a look at the docs of that trait for a better idea of how
12/// to use it.
13#[derive(Serialize)]
14pub struct StructuredChatCompletion<'a, S> {
15    #[serde(flatten)]
16    pub(super) base_request: ChatCompletion<'a>,
17    pub(super) response_format: ChatCompletionRequestResponseFormat,
18    // tether the schema type to the request so that drifting between the request and response
19    // type when deserialization time comes is impossible
20    #[serde(skip)]
21    pub(super) _phantom: std::marker::PhantomData<S>,
22}
23
24/// Enum that serializes itself into the part of the request body where OpenAI expects the schema.
25#[derive(Serialize)]
26// TODO: fix this so that `content = "json_schema"` is not necessary
27#[serde(tag = "type", content = "json_schema", rename_all = "snake_case")]
28pub(super) enum ChatCompletionRequestResponseFormat {
29    JsonSchema(GeneratedOpenAISchema),
30}
31
32impl<S> OpenAIRequestProvider for StructuredChatCompletion<'_, S>
33where
34    S: OpenAISchema + for<'de> Deserialize<'de>,
35{
36    type Response = StructuredChatCompletionResponse<S>;
37
38    const METHOD: reqwest::Method = reqwest::Method::POST;
39
40    fn path_with_leading_slash() -> String {
41        "/chat/completions".to_string()
42    }
43}
44
45impl<S> super::super::private::Sealed for StructuredChatCompletion<'_, S> {}
46
47/// A response from a structured chat completion request.
48#[derive(Deserialize)]
49#[serde(bound(deserialize = "S: DeserializeOwned"))]
50pub struct StructuredChatCompletionResponse<S> {
51    choices: Vec<StructuredChatCompletionResponseChoice<S>>,
52    usage: Usage,
53}
54
55impl<S> StructuredChatCompletionResponse<S> {
56    /// Takes the first message in the response consumes the response.
57    pub fn take_first_choice(self) -> Option<StructuredChatCompletionResponseChoice<S>> {
58        self.choices.into_iter().next()
59    }
60
61    /// Gives the usage tokens of the response.
62    pub fn usage(&self) -> Usage {
63        self.usage
64    }
65}
66
67#[derive(Deserialize)]
68#[serde(bound(deserialize = "S: DeserializeOwned"))]
69pub struct StructuredChatCompletionResponseChoice<S> {
70    finish_reason: FinishReason,
71    index: i32,
72    message: StructuredChatCompletionResponseMessage<S>,
73}
74
75impl<S> StructuredChatCompletionResponseChoice<S> {
76    /// Returns your desired type that was produced from OpenAI.
77    pub fn message(self) -> OpenAIResult<S> {
78        Into::<UnifiedChatCompletionResponseMessage<S>>::into(self.message).into()
79    }
80
81    pub fn finish_reason(&self) -> FinishReason {
82        self.finish_reason
83    }
84
85    pub fn index(&self) -> i32 {
86        self.index
87    }
88}
89
90// leave private, messages should only be interacted with through the unified message type.
91#[derive(Deserialize)]
92#[serde(bound(deserialize = "S: DeserializeOwned"))]
93struct StructuredChatCompletionResponseMessage<S> {
94    #[serde(deserialize_with = "de_from_str")]
95    content: S,
96    refusal: Option<String>,
97}
98
99fn de_from_str<'de, D, S>(deserializer: D) -> Result<S, D::Error>
100where
101    D: Deserializer<'de>,
102    S: DeserializeOwned,
103{
104    let s = String::deserialize(deserializer)?;
105    serde_json::from_str(&s).map_err(serde::de::Error::custom)
106}
107
108impl<S> From<StructuredChatCompletionResponseMessage<S>>
109    for UnifiedChatCompletionResponseMessage<S>
110{
111    fn from(value: StructuredChatCompletionResponseMessage<S>) -> Self {
112        UnifiedChatCompletionResponseMessage {
113            content: value.content,
114            refusal: value.refusal,
115        }
116    }
117}