1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::sync::Arc;
3
4#[derive(Clone, Debug, Serialize, Deserialize)]
6pub struct ModelInfo {
7 pub id: String,
9 pub name: String,
11 pub description: Option<String>,
13 pub languages: Vec<String>,
15}
16
/// In-memory audio buffer: `f32` samples plus the metadata needed to interpret them.
///
/// Serde support is hand-written elsewhere in this file rather than derived:
/// on the wire `samples` travels as a base64 string of little-endian `f32`
/// bytes instead of a numeric array.
#[derive(Debug, Clone)]
pub struct AudioData {
    /// Raw sample values. The channel layout (interleaved vs. planar) is not
    /// established by this type — confirm with the producer.
    pub samples: Vec<f32>,
    /// Number of audio channels.
    pub channels: usize,
    /// Sample rate (presumably in Hz — confirm).
    pub sample_rate: u32,
}
27
28impl Serialize for AudioData {
29 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30 where
31 S: Serializer,
32 {
33 use serde::ser::SerializeStruct;
34
35 let mut state = serializer.serialize_struct("AudioData", 3)?;
36
37 let bytes: Vec<u8> = self.samples.iter().flat_map(|f| f.to_le_bytes()).collect();
39 let base64_samples =
40 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &bytes);
41
42 state.serialize_field("samples", &base64_samples)?;
43 state.serialize_field("channels", &self.channels)?;
44 state.serialize_field("sample_rate", &self.sample_rate)?;
45 state.end()
46 }
47}
48
49impl<'de> Deserialize<'de> for AudioData {
50 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51 where
52 D: Deserializer<'de>,
53 {
54 #[derive(Deserialize)]
55 struct AudioDataHelper {
56 samples: String,
57 channels: usize,
58 sample_rate: u32,
59 }
60
61 let helper = AudioDataHelper::deserialize(deserializer)?;
62
63 let bytes =
65 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &helper.samples)
66 .map_err(serde::de::Error::custom)?;
67
68 let samples: Vec<f32> = bytes
69 .chunks_exact(4)
70 .map(|chunk| {
71 let arr: [u8; 4] = [chunk[0], chunk[1], chunk[2], chunk[3]];
72 f32::from_le_bytes(arr)
73 })
74 .collect();
75
76 Ok(AudioData {
77 samples,
78 channels: helper.channels,
79 sample_rate: helper.sample_rate,
80 })
81 }
82}
83
/// Shared, atomically reference-counted handle to an [`AudioData`].
///
/// Cloning the handle copies the pointer, not the sample buffer.
pub type SharedAudioData = Arc<AudioData>;
86
87#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
89pub enum AudioFormat {
90 #[default]
91 Wav,
92 Mp3,
93 Flac,
94 Ogg,
95}
96
97#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
99pub struct VoiceIdentifier {
100 pub name: String,
102}
103
104impl VoiceIdentifier {
105 pub fn new(name: impl Into<String>) -> Self {
107 Self { name: name.into() }
108 }
109
110 pub fn name(&self) -> &str {
112 &self.name
113 }
114}
115
116impl From<String> for VoiceIdentifier {
117 fn from(name: String) -> Self {
118 Self::new(name)
119 }
120}
121
122impl From<&str> for VoiceIdentifier {
123 fn from(name: &str) -> Self {
124 Self::new(name)
125 }
126}
127
128#[derive(Clone, Debug)]
130pub struct SpeechRequest {
131 pub text: String,
132 pub voice: VoiceIdentifier,
133 pub format: AudioFormat,
134 pub sample_rate: Option<u32>,
135}
136
137#[derive(Clone, Debug)]
139pub struct SpeechResponse {
140 pub audio: AudioData,
141 pub text: String,
142 pub duration_ms: u64,
143}
144
/// One piece of audio delivered incrementally.
#[derive(Debug, Clone)]
pub struct AudioChunk {
    /// Sample values contained in this chunk.
    pub samples: Vec<f32>,
    /// Flag marking the final chunk (presumably the last one of a stream — confirm).
    pub is_final: bool,
}
151
152#[derive(Clone, Debug, Serialize, Deserialize)]
154pub struct TokenTimestamp {
155 pub text: String,
157 pub start: f32,
159 pub end: f32,
161}
162
163#[derive(Clone, Debug)]
165pub struct TranscriptionRequest {
166 pub audio: SharedAudioData,
168 pub language: Option<String>,
170 pub include_timestamps: bool,
172}
173
174#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct TranscriptionResponse {
177 pub text: String,
179 pub timestamps: Option<Vec<TokenTimestamp>>,
181 pub duration_ms: u64,
183}
184
185#[derive(Clone, Debug, Serialize, Deserialize)]
187pub struct TextChunk {
188 pub text: String,
190 pub is_final: bool,
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON round trip must preserve the metadata and every sample value.
    #[test]
    fn test_audio_data_serialization() {
        let original = AudioData {
            samples: vec![0.0, 0.5, -0.5, 1.0],
            channels: 1,
            sample_rate: 24000,
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: AudioData = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.samples.len(), decoded.samples.len());
        assert_eq!(original.channels, decoded.channels);
        assert_eq!(original.sample_rate, decoded.sample_rate);

        let pairs = original.samples.iter().zip(decoded.samples.iter());
        for (index, (expected, actual)) in pairs.enumerate() {
            assert!(
                (expected - actual).abs() < 1e-6,
                "sample {} changed during the round trip",
                index
            );
        }
    }

    /// A voice identifier must survive a JSON round trip with its name intact.
    #[test]
    fn test_voice_identifier_serialization() {
        let encoded = serde_json::to_string(&VoiceIdentifier::new("alba")).unwrap();
        let decoded: VoiceIdentifier = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.name, "alba");
    }

    /// Converting from a string slice must carry the name over unchanged.
    #[test]
    fn test_voice_identifier_from_string() {
        let voice = VoiceIdentifier::from("marius");
        assert_eq!(voice.name(), "marius");
    }
}