Skip to main content

oai_rt_rs/protocol/models/
common.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::collections::HashMap;
4
/// Model identifier used as this crate's default (`"gpt-realtime"`).
pub const DEFAULT_MODEL: &str = "gpt-realtime";
6
7/// Arbitrary JSON payloads allowed by the API (e.g. metadata values).
8pub type Metadata = HashMap<String, Value>;
9
10/// JSON Schema / tool parameter definitions are intentionally untyped.
11pub type JsonSchema = Value;
12
13/// Free-form JSON payloads where the spec is open-ended.
14pub type ArbitraryJson = Value;
15
16/// Tri-state helper for fields that can be omitted, set to null, or set to a value.
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
18#[serde(untagged)]
19pub enum Nullable<T> {
20    Value(T),
21    Null,
22}
23
24impl<T> Nullable<T> {
25    #[must_use]
26    pub const fn as_ref(&self) -> Option<&T> {
27        match self {
28            Self::Value(value) => Some(value),
29            Self::Null => None,
30        }
31    }
32}
33
34#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
35#[serde(rename_all = "snake_case")]
36pub enum Role {
37    #[default]
38    User,
39    Assistant,
40    System,
41}
42
43#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
44#[serde(rename_all = "snake_case")]
45pub enum ItemStatus {
46    #[default]
47    InProgress,
48    Completed,
49    Incomplete,
50}
51
52#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
53#[serde(rename_all = "snake_case")]
54pub enum Modality {
55    #[default]
56    Audio,
57    Text,
58}
59
60#[derive(Debug, Clone, Copy, PartialEq, Eq)]
61pub enum OutputModalities {
62    Audio,
63    Text,
64}
65
66impl Serialize for OutputModalities {
67    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
68    where
69        S: Serializer,
70    {
71        let values = match self {
72            Self::Audio => vec![Modality::Audio],
73            Self::Text => vec![Modality::Text],
74        };
75        values.serialize(serializer)
76    }
77}
78
79impl<'de> Deserialize<'de> for OutputModalities {
80    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
81    where
82        D: Deserializer<'de>,
83    {
84        #[derive(Deserialize)]
85        #[serde(untagged)]
86        enum Repr {
87            Single(Modality),
88            Many(Vec<Modality>),
89        }
90
91        match Repr::deserialize(deserializer)? {
92            Repr::Single(Modality::Audio) => Ok(Self::Audio),
93            Repr::Single(Modality::Text) => Ok(Self::Text),
94            Repr::Many(values) => match values.as_slice() {
95                [Modality::Audio] => Ok(Self::Audio),
96                [Modality::Text] => Ok(Self::Text),
97                _ => Err(serde::de::Error::custom(
98                    "output_modalities must contain exactly one of: audio or text",
99                )),
100            },
101        }
102    }
103}
104
105#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
106#[serde(rename_all = "snake_case")]
107pub enum Eagerness {
108    Auto,
109    Low,
110    #[default]
111    Medium,
112    High,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
116#[serde(untagged)]
117pub enum Voice {
118    Id(String),
119    Object { id: String },
120}
121
122impl<S: Into<String>> From<S> for Voice {
123    fn from(s: S) -> Self {
124        // Own the string to avoid lifetime plumbing in public APIs.
125        Self::Id(s.into())
126    }
127}
128
129impl std::fmt::Display for Voice {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        match self {
132            Self::Id(id) | Self::Object { id } => write!(f, "{id}"),
133        }
134    }
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
138#[serde(untagged)]
139pub enum MaxTokens {
140    Count(u32),
141    Infinite(Infinite),
142}
143
144#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
145#[serde(rename_all = "lowercase")]
146pub enum Infinite {
147    #[serde(rename = "inf")]
148    Inf,
149}
150
151#[derive(Debug, Clone, Copy, Serialize, PartialEq)]
152#[serde(transparent)]
153pub struct Temperature(f32);
154
155impl Temperature {
156    /// # Errors
157    /// Returns an error if `val` is outside the inclusive range [0.0, 2.0].
158    pub fn new(val: f32) -> Result<Self, TemperatureError> {
159        if (0.0..=2.0).contains(&val) {
160            Ok(Self(val))
161        } else {
162            Err(TemperatureError { value: val })
163        }
164    }
165}
166
167impl Default for Temperature {
168    fn default() -> Self {
169        Self(0.8)
170    }
171}
172
/// Error produced when a temperature value lies outside `[0.0, 2.0]`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TemperatureError {
    /// The rejected input.
    pub value: f32,
}

impl std::fmt::Display for TemperatureError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rejected = self.value;
        write!(f, "temperature must be between 0.0 and 2.0, got {rejected}")
    }
}

impl std::error::Error for TemperatureError {}
189
190impl TryFrom<f32> for Temperature {
191    type Error = TemperatureError;
192
193    fn try_from(value: f32) -> Result<Self, Self::Error> {
194        Self::new(value)
195    }
196}
197
198impl<'de> Deserialize<'de> for Temperature {
199    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
200    where
201        D: Deserializer<'de>,
202    {
203        let value = f32::deserialize(deserializer)?;
204        Self::new(value).map_err(serde::de::Error::custom)
205    }
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
209#[serde(untagged)]
210pub enum PromptRef {
211    Id(String),
212    Object { id: String },
213}