Skip to main content

autoagents_speech/
types.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::sync::Arc;
3
4/// Model information
5#[derive(Clone, Debug, Serialize, Deserialize)]
6pub struct ModelInfo {
7    /// Model identifier
8    pub id: String,
9    /// Model name
10    pub name: String,
11    /// Model description
12    pub description: Option<String>,
13    /// Supported languages
14    pub languages: Vec<String>,
15}
16
/// Audio data held as `f32` samples together with its channel count and sample rate.
///
/// This type has hand-written `Serialize`/`Deserialize` impls that carry `samples` as a
/// base64 string of little-endian `f32` bytes.
///
/// `PartialEq` compares samples with `f32` semantics, so buffers containing `NaN` never
/// compare equal.
#[derive(Clone, Debug, PartialEq)]
pub struct AudioData {
    /// Audio samples, expected to be normalized to [-1.0, 1.0].
    ///
    /// The range is a convention only; it is not checked or enforced by this type.
    /// NOTE(review): for `channels > 1` the samples are presumably interleaved — confirm
    /// with producers/consumers.
    pub samples: Vec<f32>,
    /// Number of audio channels (typically 1 for mono)
    pub channels: usize,
    /// Sample rate in Hz
    pub sample_rate: u32,
}
27
28impl Serialize for AudioData {
29    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30    where
31        S: Serializer,
32    {
33        use serde::ser::SerializeStruct;
34
35        let mut state = serializer.serialize_struct("AudioData", 3)?;
36
37        // Serialize samples as base64
38        let bytes: Vec<u8> = self.samples.iter().flat_map(|f| f.to_le_bytes()).collect();
39        let base64_samples =
40            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &bytes);
41
42        state.serialize_field("samples", &base64_samples)?;
43        state.serialize_field("channels", &self.channels)?;
44        state.serialize_field("sample_rate", &self.sample_rate)?;
45        state.end()
46    }
47}
48
49impl<'de> Deserialize<'de> for AudioData {
50    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51    where
52        D: Deserializer<'de>,
53    {
54        #[derive(Deserialize)]
55        struct AudioDataHelper {
56            samples: String,
57            channels: usize,
58            sample_rate: u32,
59        }
60
61        let helper = AudioDataHelper::deserialize(deserializer)?;
62
63        // Deserialize samples from base64
64        let bytes =
65            base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &helper.samples)
66                .map_err(serde::de::Error::custom)?;
67
68        let samples: Vec<f32> = bytes
69            .chunks_exact(4)
70            .map(|chunk| {
71                let arr: [u8; 4] = [chunk[0], chunk[1], chunk[2], chunk[3]];
72                f32::from_le_bytes(arr)
73            })
74            .collect();
75
76        Ok(AudioData {
77            samples,
78            channels: helper.channels,
79            sample_rate: helper.sample_rate,
80        })
81    }
82}
83
84/// Shared reference to audio data for memory efficiency
85pub type SharedAudioData = Arc<AudioData>;
86
87/// Audio format for output
88#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
89pub enum AudioFormat {
90    #[default]
91    Wav,
92    Mp3,
93    Flac,
94    Ogg,
95}
96
97/// Voice identifier for TTS generation (predefined voices only)
98#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
99pub struct VoiceIdentifier {
100    /// Predefined voice name (e.g., "alba", "marius")
101    pub name: String,
102}
103
104impl VoiceIdentifier {
105    /// Create a voice identifier from a predefined voice name
106    pub fn new(name: impl Into<String>) -> Self {
107        Self { name: name.into() }
108    }
109
110    /// Get the voice name
111    pub fn name(&self) -> &str {
112        &self.name
113    }
114}
115
116impl From<String> for VoiceIdentifier {
117    fn from(name: String) -> Self {
118        Self::new(name)
119    }
120}
121
122impl From<&str> for VoiceIdentifier {
123    fn from(name: &str) -> Self {
124        Self::new(name)
125    }
126}
127
128/// Speech generation request
129#[derive(Clone, Debug)]
130pub struct SpeechRequest {
131    pub text: String,
132    pub voice: VoiceIdentifier,
133    pub format: AudioFormat,
134    pub sample_rate: Option<u32>,
135}
136
137/// Speech generation response
138#[derive(Clone, Debug)]
139pub struct SpeechResponse {
140    pub audio: AudioData,
141    pub text: String,
142    pub duration_ms: u64,
143}
144
/// Audio chunk for streaming TTS
///
/// `PartialEq` compares samples with `f32` semantics, so chunks containing `NaN` never
/// compare equal.
#[derive(Clone, Debug, PartialEq)]
pub struct AudioChunk {
    /// Samples carried by this chunk.
    /// NOTE(review): presumably these use the same normalization, channel layout and sample
    /// rate as the stream's full `AudioData` — confirm with streaming providers.
    pub samples: Vec<f32>,
    /// Whether this is the last chunk of the stream.
    pub is_final: bool,
}
151
152/// Timestamp for a token in transcription
153#[derive(Clone, Debug, Serialize, Deserialize)]
154pub struct TokenTimestamp {
155    /// Token text
156    pub text: String,
157    /// Start time in seconds
158    pub start: f32,
159    /// End time in seconds
160    pub end: f32,
161}
162
163/// Transcription request for STT
164#[derive(Clone, Debug)]
165pub struct TranscriptionRequest {
166    /// Audio input to transcribe (shared to avoid copies at segment boundaries)
167    pub audio: SharedAudioData,
168    /// Optional language hint (for multilingual models)
169    pub language: Option<String>,
170    /// Whether to include timestamps
171    pub include_timestamps: bool,
172}
173
174/// Transcription response from STT
175#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct TranscriptionResponse {
177    /// Transcribed text
178    pub text: String,
179    /// Optional token-level timestamps
180    pub timestamps: Option<Vec<TokenTimestamp>>,
181    /// Processing duration in milliseconds
182    pub duration_ms: u64,
183}
184
185/// Text chunk for streaming STT
186#[derive(Clone, Debug, Serialize, Deserialize)]
187pub struct TextChunk {
188    /// Partial or final transcribed text
189    pub text: String,
190    /// Whether this is the final chunk
191    pub is_final: bool,
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the raw bit patterns of samples so comparisons are exact, not approximate.
    fn sample_bits(samples: &[f32]) -> Vec<u32> {
        samples.iter().map(|s| s.to_bits()).collect()
    }

    #[test]
    fn test_audio_data_serialization() {
        let audio = AudioData {
            samples: vec![0.0, 0.5, -0.5, 1.0],
            channels: 1,
            sample_rate: 24000,
        };

        let json = serde_json::to_string(&audio).unwrap();
        let deserialized: AudioData = serde_json::from_str(&json).unwrap();

        assert_eq!(audio.channels, deserialized.channels);
        assert_eq!(audio.sample_rate, deserialized.sample_rate);
        // The little-endian byte encoding is lossless, so samples must round-trip bit-for-bit.
        assert_eq!(sample_bits(&audio.samples), sample_bits(&deserialized.samples));
    }

    #[test]
    fn test_audio_data_empty_samples_round_trip() {
        let audio = AudioData {
            samples: Vec::new(),
            channels: 2,
            sample_rate: 16000,
        };

        let json = serde_json::to_string(&audio).unwrap();
        let deserialized: AudioData = serde_json::from_str(&json).unwrap();

        assert!(deserialized.samples.is_empty());
        assert_eq!(deserialized.channels, 2);
        assert_eq!(deserialized.sample_rate, 16000);
    }

    #[test]
    fn test_audio_data_wire_format() {
        let audio = AudioData {
            samples: vec![1.0],
            channels: 1,
            sample_rate: 24000,
        };

        // 1.0f32 is 0x3F80_0000, i.e. bytes [00, 00, 80, 3F] in little-endian order.
        let json = serde_json::to_string(&audio).unwrap();
        assert_eq!(
            json,
            r#"{"samples":"AACAPw==","channels":1,"sample_rate":24000}"#
        );

        let decoded: AudioData = serde_json::from_str(&json).unwrap();
        assert_eq!(sample_bits(&decoded.samples), sample_bits(&[1.0]));
    }

    #[test]
    fn test_audio_format_default_is_wav() {
        assert_eq!(AudioFormat::default(), AudioFormat::Wav);
    }

    #[test]
    fn test_voice_identifier_serialization() {
        let voice = VoiceIdentifier::new("alba");
        let json = serde_json::to_string(&voice).unwrap();
        let deserialized: VoiceIdentifier = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.name, "alba");
        assert_eq!(deserialized, voice);
    }

    #[test]
    fn test_voice_identifier_from_string() {
        let voice: VoiceIdentifier = "marius".into();
        assert_eq!(voice.name(), "marius");

        let owned: VoiceIdentifier = String::from("alba").into();
        assert_eq!(owned.name(), "alba");
    }
}