kind_openai/endpoints/chat/
structured.rs1use kind_openai_schema::{GeneratedOpenAISchema, OpenAISchema};
2use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize};
3
4use crate::{endpoints::OpenAIRequestProvider, OpenAIResult, Usage};
5
6use super::{standard::ChatCompletion, FinishReason, UnifiedChatCompletionResponseMessage};
7
8#[derive(Serialize)]
14pub struct StructuredChatCompletion<'a, S> {
15 #[serde(flatten)]
16 pub(super) base_request: ChatCompletion<'a>,
17 pub(super) response_format: ChatCompletionRequestResponseFormat,
18 #[serde(skip)]
21 pub(super) _phantom: std::marker::PhantomData<S>,
22}
23
24#[derive(Serialize)]
26#[serde(tag = "type", content = "json_schema", rename_all = "snake_case")]
28pub(super) enum ChatCompletionRequestResponseFormat {
29 JsonSchema(GeneratedOpenAISchema),
30}
31
32impl<S> OpenAIRequestProvider for StructuredChatCompletion<'_, S>
33where
34 S: OpenAISchema + for<'de> Deserialize<'de>,
35{
36 type Response = StructuredChatCompletionResponse<S>;
37
38 const METHOD: reqwest::Method = reqwest::Method::POST;
39
40 fn path_with_leading_slash() -> String {
41 "/chat/completions".to_string()
42 }
43}
44
45impl<S> super::super::private::Sealed for StructuredChatCompletion<'_, S> {}
46
47#[derive(Deserialize)]
49#[serde(bound(deserialize = "S: DeserializeOwned"))]
50pub struct StructuredChatCompletionResponse<S> {
51 choices: Vec<StructuredChatCompletionResponseChoice<S>>,
52 usage: Usage,
53}
54
55impl<S> StructuredChatCompletionResponse<S> {
56 pub fn take_first_choice(self) -> Option<StructuredChatCompletionResponseChoice<S>> {
58 self.choices.into_iter().next()
59 }
60
61 pub fn usage(&self) -> Usage {
63 self.usage
64 }
65}
66
67#[derive(Deserialize)]
68#[serde(bound(deserialize = "S: DeserializeOwned"))]
69pub struct StructuredChatCompletionResponseChoice<S> {
70 finish_reason: FinishReason,
71 index: i32,
72 message: StructuredChatCompletionResponseMessage<S>,
73}
74
75impl<S> StructuredChatCompletionResponseChoice<S> {
76 pub fn message(self) -> OpenAIResult<S> {
78 Into::<UnifiedChatCompletionResponseMessage<S>>::into(self.message).into()
79 }
80
81 pub fn finish_reason(&self) -> FinishReason {
82 self.finish_reason
83 }
84
85 pub fn index(&self) -> i32 {
86 self.index
87 }
88}
89
90#[derive(Deserialize)]
92#[serde(bound(deserialize = "S: DeserializeOwned"))]
93struct StructuredChatCompletionResponseMessage<S> {
94 #[serde(deserialize_with = "de_from_str")]
95 content: S,
96 refusal: Option<String>,
97}
98
99fn de_from_str<'de, D, S>(deserializer: D) -> Result<S, D::Error>
100where
101 D: Deserializer<'de>,
102 S: DeserializeOwned,
103{
104 let s = String::deserialize(deserializer)?;
105 serde_json::from_str(&s).map_err(serde::de::Error::custom)
106}
107
108impl<S> From<StructuredChatCompletionResponseMessage<S>>
109 for UnifiedChatCompletionResponseMessage<S>
110{
111 fn from(value: StructuredChatCompletionResponseMessage<S>) -> Self {
112 UnifiedChatCompletionResponseMessage {
113 content: value.content,
114 refusal: value.refusal,
115 }
116 }
117}