Skip to main content

oai_rt_rs/protocol/models/
session.rs

1use serde::ser::SerializeMap;
2use serde::{Deserialize, Serialize};
3
4use super::{
5    AudioConfig, AudioFormat, InputAudioTranscription, MaxTokens, Modality, Nullable,
6    OutputModalities, PromptRef, Temperature, Tool, ToolChoice, TurnDetection, Voice,
7};
8
9#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
10#[serde(rename_all = "snake_case")]
11pub enum SessionKind {
12    #[default]
13    Realtime,
14    Transcription,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
18#[serde(rename_all = "snake_case")]
19pub enum TracingAuto {
20    Auto,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct TracingConfig {
25    pub workflow_name: Option<String>,
26    pub group_id: Option<String>,
27    /// Arbitrary tracing metadata (spec allows free-form JSON values).
28    pub metadata: Option<super::Metadata>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(untagged)]
33pub enum Tracing {
34    Auto(TracingAuto),
35    Config(TracingConfig),
36}
37
38#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum TruncationStrategy {
41    Auto,
42    Disabled,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
46#[serde(rename_all = "snake_case")]
47pub enum TruncationType {
48    RetentionRatio,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
52pub struct TokenLimits {
53    pub post_instructions: Option<u32>,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct RetentionRatioTruncation {
58    #[serde(rename = "type")]
59    pub kind: TruncationType,
60    pub retention_ratio: f32,
61    pub token_limits: Option<TokenLimits>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65#[serde(untagged)]
66pub enum Truncation {
67    Strategy(TruncationStrategy),
68    RetentionRatio(RetentionRatioTruncation),
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct SessionConfig {
73    #[serde(rename = "type")]
74    pub kind: SessionKind,
75    pub model: String,
76    pub output_modalities: OutputModalities,
77    pub modalities: Option<Vec<Modality>>,
78    pub include: Option<Vec<String>>,
79    pub prompt: Option<PromptRef>,
80    pub truncation: Option<Truncation>,
81    pub instructions: Option<String>,
82    pub input_audio_format: Option<AudioFormat>,
83    pub output_audio_format: Option<AudioFormat>,
84    pub input_audio_transcription: Option<Nullable<InputAudioTranscription>>,
85    pub turn_detection: Option<Nullable<TurnDetection>>,
86    pub tools: Option<Vec<Tool>>,
87    pub tool_choice: Option<ToolChoice>,
88    pub temperature: Option<Temperature>,
89    pub max_output_tokens: Option<MaxTokens>,
90    pub audio: Option<AudioConfig>,
91    pub tracing: Option<Tracing>,
92    pub voice: Option<Voice>,
93}
94
95impl SessionConfig {
96    #[must_use]
97    pub fn new(
98        kind: SessionKind,
99        model: impl Into<String>,
100        output_modalities: OutputModalities,
101    ) -> Self {
102        Self {
103            kind,
104            model: model.into(),
105            output_modalities,
106            modalities: None,
107            include: None,
108            prompt: None,
109            truncation: None,
110            instructions: None,
111            input_audio_format: None,
112            output_audio_format: None,
113            input_audio_transcription: None,
114            turn_detection: None,
115            tools: None,
116            tool_choice: None,
117            temperature: None,
118            max_output_tokens: None,
119            audio: None,
120            tracing: None,
121            voice: None,
122        }
123    }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize, Default)]
127pub struct SessionUpdateConfig {
128    /// Partial updates only; GA forbids changing `model` or session `type`.
129    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
130    pub kind: Option<SessionKind>,
131    pub output_modalities: Option<OutputModalities>,
132    pub modalities: Option<Vec<Modality>>,
133    pub include: Option<Vec<String>>,
134    pub prompt: Option<PromptRef>,
135    pub truncation: Option<Truncation>,
136    pub instructions: Option<String>,
137    pub input_audio_format: Option<AudioFormat>,
138    pub output_audio_format: Option<AudioFormat>,
139    pub input_audio_transcription: Option<Nullable<InputAudioTranscription>>,
140    pub turn_detection: Option<Nullable<TurnDetection>>,
141    pub tools: Option<Vec<Tool>>,
142    pub tool_choice: Option<ToolChoice>,
143    pub temperature: Option<Temperature>,
144    pub max_output_tokens: Option<MaxTokens>,
145    pub audio: Option<AudioConfig>,
146    pub tracing: Option<Tracing>,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct Session {
151    pub id: String,
152    pub object: String,
153    pub expires_at: u64,
154    /// Flattened to match the API's session JSON shape.
155    #[serde(flatten)]
156    pub config: SessionConfig,
157}
158
159#[derive(Debug, Clone, Deserialize, Default)]
160pub struct SessionUpdate {
161    /// Flattened to match the API's session.update JSON shape.
162    #[serde(flatten)]
163    pub config: SessionUpdateConfig,
164}
165
166impl Serialize for SessionUpdate {
167    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
168    where
169        S: serde::Serializer,
170    {
171        let mut map = serializer.serialize_map(None)?;
172        let value = serde_json::to_value(&self.config).map_err(serde::ser::Error::custom)?;
173        if let serde_json::Value::Object(obj) = value {
174            if !obj.contains_key("type") {
175                map.serialize_entry("type", "realtime")?;
176            }
177            for (k, v) in obj {
178                map.serialize_entry(&k, &v)?;
179            }
180        }
181        map.end()
182    }
183}