1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::sync::Arc;
3
4#[derive(Clone, Debug, Serialize, Deserialize)]
6pub struct ModelInfo {
7 pub id: String,
9 pub name: String,
11 pub description: Option<String>,
13 pub languages: Vec<String>,
15}
16
/// A buffer of `f32` audio samples together with its channel count and sample rate.
///
/// The sample layout for `channels > 1` is not enforced by this type (presumably interleaved —
/// confirm with producers). Serialization is implemented by hand: `samples` travels as base64
/// of the little-endian `f32` bytes.
#[derive(Debug, Clone)]
pub struct AudioData {
    /// Raw sample values.
    pub samples: Vec<f32>,
    /// Number of audio channels.
    pub channels: usize,
    /// Sample rate in samples per second (assumed Hz — not checked here).
    pub sample_rate: u32,
}
27
28impl Serialize for AudioData {
29 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30 where
31 S: Serializer,
32 {
33 use serde::ser::SerializeStruct;
34
35 let mut state = serializer.serialize_struct("AudioData", 3)?;
36
37 let bytes: Vec<u8> = self.samples.iter().flat_map(|f| f.to_le_bytes()).collect();
39 let base64_samples =
40 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &bytes);
41
42 state.serialize_field("samples", &base64_samples)?;
43 state.serialize_field("channels", &self.channels)?;
44 state.serialize_field("sample_rate", &self.sample_rate)?;
45 state.end()
46 }
47}
48
49impl<'de> Deserialize<'de> for AudioData {
50 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51 where
52 D: Deserializer<'de>,
53 {
54 #[derive(Deserialize)]
55 struct AudioDataHelper {
56 samples: String,
57 channels: usize,
58 sample_rate: u32,
59 }
60
61 let helper = AudioDataHelper::deserialize(deserializer)?;
62
63 let bytes =
65 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &helper.samples)
66 .map_err(serde::de::Error::custom)?;
67
68 let samples: Vec<f32> = bytes
69 .chunks_exact(4)
70 .map(|chunk| {
71 let arr: [u8; 4] = [chunk[0], chunk[1], chunk[2], chunk[3]];
72 f32::from_le_bytes(arr)
73 })
74 .collect();
75
76 Ok(AudioData {
77 samples,
78 channels: helper.channels,
79 sample_rate: helper.sample_rate,
80 })
81 }
82}
83
84pub type SharedAudioData = Arc<AudioData>;
86
87#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
89pub enum AudioFormat {
90 #[default]
91 Wav,
92 Mp3,
93 Flac,
94 Ogg,
95}
96
97#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
99pub struct VoiceIdentifier {
100 pub name: String,
102}
103
104impl VoiceIdentifier {
105 pub fn new(name: impl Into<String>) -> Self {
107 Self { name: name.into() }
108 }
109
110 pub fn name(&self) -> &str {
112 &self.name
113 }
114}
115
116impl From<String> for VoiceIdentifier {
117 fn from(name: String) -> Self {
118 Self::new(name)
119 }
120}
121
122impl From<&str> for VoiceIdentifier {
123 fn from(name: &str) -> Self {
124 Self::new(name)
125 }
126}
127
128#[derive(Clone, Debug)]
130pub struct SpeechRequest {
131 pub text: String,
132 pub voice: VoiceIdentifier,
133 pub format: AudioFormat,
134 pub sample_rate: Option<u32>,
135}
136
137#[derive(Clone, Debug)]
139pub struct SpeechResponse {
140 pub audio: AudioData,
141 pub text: String,
142 pub duration_ms: u64,
143}
144
/// A slice of audio delivered incrementally.
#[derive(Debug, Clone)]
pub struct AudioChunk {
    /// Samples contained in this chunk.
    pub samples: Vec<f32>,
    /// Sample rate of `samples`.
    pub sample_rate: u32,
    /// Set on the chunk that ends the stream (presumably the last one sent — confirm).
    pub is_final: bool,
}
152
153#[derive(Clone, Debug, Serialize, Deserialize)]
155pub struct TokenTimestamp {
156 pub text: String,
158 pub start: f32,
160 pub end: f32,
162}
163
164#[derive(Clone, Debug)]
166pub struct TranscriptionRequest {
167 pub audio: SharedAudioData,
169 pub language: Option<String>,
171 pub include_timestamps: bool,
173}
174
175#[derive(Clone, Debug, Serialize, Deserialize)]
177pub struct TranscriptionResponse {
178 pub text: String,
180 pub timestamps: Option<Vec<TokenTimestamp>>,
182 pub duration_ms: u64,
184}
185
186#[derive(Clone, Debug, Serialize, Deserialize)]
188pub struct TextChunk {
189 pub text: String,
191 pub is_final: bool,
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Bit patterns of a sample slice, so comparisons distinguish `-0.0`/`0.0` and handle NaN.
    fn sample_bits(samples: &[f32]) -> Vec<u32> {
        samples.iter().map(|s| s.to_bits()).collect()
    }

    #[test]
    fn test_audio_data_serialization() {
        let audio = AudioData {
            samples: vec![0.0, 0.5, -0.5, 1.0],
            channels: 1,
            sample_rate: 24000,
        };

        let json = serde_json::to_string(&audio).unwrap();
        let deserialized: AudioData = serde_json::from_str(&json).unwrap();

        // The encoding is a raw little-endian byte copy, so the round trip must be
        // bit-exact rather than merely within a tolerance.
        assert_eq!(sample_bits(&audio.samples), sample_bits(&deserialized.samples));
        assert_eq!(audio.channels, deserialized.channels);
        assert_eq!(audio.sample_rate, deserialized.sample_rate);
    }

    #[test]
    fn test_audio_data_roundtrip_special_values() {
        let audio = AudioData {
            samples: vec![-0.0, f32::MIN_POSITIVE, f32::MAX, f32::INFINITY, f32::NAN],
            channels: 1,
            sample_rate: 8000,
        };

        let json = serde_json::to_string(&audio).unwrap();
        let deserialized: AudioData = serde_json::from_str(&json).unwrap();

        assert_eq!(sample_bits(&audio.samples), sample_bits(&deserialized.samples));
    }

    #[test]
    fn test_audio_data_empty_roundtrip() {
        let audio = AudioData {
            samples: Vec::new(),
            channels: 2,
            sample_rate: 48000,
        };

        let json = serde_json::to_string(&audio).unwrap();
        let deserialized: AudioData = serde_json::from_str(&json).unwrap();

        assert!(deserialized.samples.is_empty());
        assert_eq!(deserialized.channels, 2);
        assert_eq!(deserialized.sample_rate, 48000);
    }

    #[test]
    fn test_audio_data_wire_format() {
        // 1.0f32 is 0x3F800000, i.e. little-endian bytes 00 00 80 3F -> base64 "AACAPw==".
        let audio = AudioData {
            samples: vec![1.0],
            channels: 2,
            sample_rate: 8000,
        };

        assert_eq!(
            serde_json::to_string(&audio).unwrap(),
            r#"{"samples":"AACAPw==","channels":2,"sample_rate":8000}"#
        );
    }

    #[test]
    fn test_audio_data_rejects_invalid_base64() {
        let result =
            serde_json::from_str::<AudioData>(r#"{"samples":"!!!","channels":1,"sample_rate":1}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_voice_identifier_serialization() {
        let voice = VoiceIdentifier::new("alba");
        let json = serde_json::to_string(&voice).unwrap();
        assert_eq!(json, r#"{"name":"alba"}"#);

        let deserialized: VoiceIdentifier = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, voice);
        assert_eq!(deserialized.name, "alba");
    }

    #[test]
    fn test_voice_identifier_from_string() {
        let voice: VoiceIdentifier = "marius".into();
        assert_eq!(voice.name(), "marius");

        let owned: VoiceIdentifier = String::from("marius").into();
        assert_eq!(owned, voice);
    }
}