oai_rt_rs/protocol/models/
session.rs1use serde::ser::SerializeMap;
2use serde::{Deserialize, Serialize};
3
4use super::{
5 AudioConfig, AudioFormat, InputAudioTranscription, MaxTokens, Modality, Nullable,
6 OutputModalities, PromptRef, Temperature, Tool, ToolChoice, TurnDetection, Voice,
7};
8
9#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
10#[serde(rename_all = "snake_case")]
11pub enum SessionKind {
12 #[default]
13 Realtime,
14 Transcription,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
18#[serde(rename_all = "snake_case")]
19pub enum TracingAuto {
20 Auto,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct TracingConfig {
25 pub workflow_name: Option<String>,
26 pub group_id: Option<String>,
27 pub metadata: Option<super::Metadata>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(untagged)]
33pub enum Tracing {
34 Auto(TracingAuto),
35 Config(TracingConfig),
36}
37
38#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum TruncationStrategy {
41 Auto,
42 Disabled,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
46#[serde(rename_all = "snake_case")]
47pub enum TruncationType {
48 RetentionRatio,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
52pub struct TokenLimits {
53 pub post_instructions: Option<u32>,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct RetentionRatioTruncation {
58 #[serde(rename = "type")]
59 pub kind: TruncationType,
60 pub retention_ratio: f32,
61 pub token_limits: Option<TokenLimits>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65#[serde(untagged)]
66pub enum Truncation {
67 Strategy(TruncationStrategy),
68 RetentionRatio(RetentionRatioTruncation),
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct SessionConfig {
73 #[serde(rename = "type")]
74 pub kind: SessionKind,
75 pub model: String,
76 pub output_modalities: OutputModalities,
77 pub modalities: Option<Vec<Modality>>,
78 pub include: Option<Vec<String>>,
79 pub prompt: Option<PromptRef>,
80 pub truncation: Option<Truncation>,
81 pub instructions: Option<String>,
82 pub input_audio_format: Option<AudioFormat>,
83 pub output_audio_format: Option<AudioFormat>,
84 pub input_audio_transcription: Option<Nullable<InputAudioTranscription>>,
85 pub turn_detection: Option<Nullable<TurnDetection>>,
86 pub tools: Option<Vec<Tool>>,
87 pub tool_choice: Option<ToolChoice>,
88 pub temperature: Option<Temperature>,
89 pub max_output_tokens: Option<MaxTokens>,
90 pub audio: Option<AudioConfig>,
91 pub tracing: Option<Tracing>,
92 pub voice: Option<Voice>,
93}
94
95impl SessionConfig {
96 #[must_use]
97 pub fn new(
98 kind: SessionKind,
99 model: impl Into<String>,
100 output_modalities: OutputModalities,
101 ) -> Self {
102 Self {
103 kind,
104 model: model.into(),
105 output_modalities,
106 modalities: None,
107 include: None,
108 prompt: None,
109 truncation: None,
110 instructions: None,
111 input_audio_format: None,
112 output_audio_format: None,
113 input_audio_transcription: None,
114 turn_detection: None,
115 tools: None,
116 tool_choice: None,
117 temperature: None,
118 max_output_tokens: None,
119 audio: None,
120 tracing: None,
121 voice: None,
122 }
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize, Default)]
127pub struct SessionUpdateConfig {
128 #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
130 pub kind: Option<SessionKind>,
131 pub output_modalities: Option<OutputModalities>,
132 pub modalities: Option<Vec<Modality>>,
133 pub include: Option<Vec<String>>,
134 pub prompt: Option<PromptRef>,
135 pub truncation: Option<Truncation>,
136 pub instructions: Option<String>,
137 pub input_audio_format: Option<AudioFormat>,
138 pub output_audio_format: Option<AudioFormat>,
139 pub input_audio_transcription: Option<Nullable<InputAudioTranscription>>,
140 pub turn_detection: Option<Nullable<TurnDetection>>,
141 pub tools: Option<Vec<Tool>>,
142 pub tool_choice: Option<ToolChoice>,
143 pub temperature: Option<Temperature>,
144 pub max_output_tokens: Option<MaxTokens>,
145 pub audio: Option<AudioConfig>,
146 pub tracing: Option<Tracing>,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct Session {
151 pub id: String,
152 pub object: String,
153 pub expires_at: u64,
154 #[serde(flatten)]
156 pub config: SessionConfig,
157}
158
159#[derive(Debug, Clone, Deserialize, Default)]
160pub struct SessionUpdate {
161 #[serde(flatten)]
163 pub config: SessionUpdateConfig,
164}
165
166impl Serialize for SessionUpdate {
167 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
168 where
169 S: serde::Serializer,
170 {
171 let mut map = serializer.serialize_map(None)?;
172 let value = serde_json::to_value(&self.config).map_err(serde::ser::Error::custom)?;
173 if let serde_json::Value::Object(obj) = value {
174 if !obj.contains_key("type") {
175 map.serialize_entry("type", "realtime")?;
176 }
177 for (k, v) in obj {
178 map.serialize_entry(&k, &v)?;
179 }
180 }
181 map.end()
182 }
183}