Skip to main content

autoagents_speech/
types.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::sync::Arc;
3
4/// Model information
5#[derive(Clone, Debug, Serialize, Deserialize)]
6pub struct ModelInfo {
7    /// Model identifier
8    pub id: String,
9    /// Model name
10    pub name: String,
11    /// Model description
12    pub description: Option<String>,
13    /// Supported languages
14    pub languages: Vec<String>,
15}
16
/// Audio buffer holding floating-point samples together with their format.
///
/// Serialization is implemented by hand: `samples` is written as a base64 string of
/// little-endian `f32` bytes rather than as a sequence of numbers.
#[derive(Debug, Clone)]
pub struct AudioData {
    /// Sample values, normalized to [-1.0, 1.0] by convention (not enforced by this type).
    ///
    /// NOTE(review): the sample layout for multi-channel audio (interleaved vs. planar)
    /// is not specified here — confirm with producers.
    // TODO: Check whether a lower-precision sample representation would be sufficient.
    pub samples: Vec<f32>,
    /// Number of audio channels (typically 1 for mono).
    pub channels: usize,
    /// Sample rate in Hz.
    pub sample_rate: u32,
}
27
28impl Serialize for AudioData {
29    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30    where
31        S: Serializer,
32    {
33        use serde::ser::SerializeStruct;
34
35        let mut state = serializer.serialize_struct("AudioData", 3)?;
36
37        // Serialize samples as base64
38        let bytes: Vec<u8> = self.samples.iter().flat_map(|f| f.to_le_bytes()).collect();
39        let base64_samples =
40            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &bytes);
41
42        state.serialize_field("samples", &base64_samples)?;
43        state.serialize_field("channels", &self.channels)?;
44        state.serialize_field("sample_rate", &self.sample_rate)?;
45        state.end()
46    }
47}
48
49impl<'de> Deserialize<'de> for AudioData {
50    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51    where
52        D: Deserializer<'de>,
53    {
54        #[derive(Deserialize)]
55        struct AudioDataHelper {
56            samples: String,
57            channels: usize,
58            sample_rate: u32,
59        }
60
61        let helper = AudioDataHelper::deserialize(deserializer)?;
62
63        // Deserialize samples from base64
64        let bytes =
65            base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &helper.samples)
66                .map_err(serde::de::Error::custom)?;
67
68        let samples: Vec<f32> = bytes
69            .chunks_exact(4)
70            .map(|chunk| {
71                let arr: [u8; 4] = [chunk[0], chunk[1], chunk[2], chunk[3]];
72                f32::from_le_bytes(arr)
73            })
74            .collect();
75
76        Ok(AudioData {
77            samples,
78            channels: helper.channels,
79            sample_rate: helper.sample_rate,
80        })
81    }
82}
83
84/// Shared reference to audio data for memory efficiency
85pub type SharedAudioData = Arc<AudioData>;
86
87/// Audio format for output
88#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
89pub enum AudioFormat {
90    #[default]
91    Wav,
92    Mp3,
93    Flac,
94    Ogg,
95}
96
97/// Voice identifier for TTS generation (predefined voices only)
98#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
99pub struct VoiceIdentifier {
100    /// Predefined voice name (e.g., "alba", "marius")
101    pub name: String,
102}
103
104impl VoiceIdentifier {
105    /// Create a voice identifier from a predefined voice name
106    pub fn new(name: impl Into<String>) -> Self {
107        Self { name: name.into() }
108    }
109
110    /// Get the voice name
111    pub fn name(&self) -> &str {
112        &self.name
113    }
114}
115
116impl From<String> for VoiceIdentifier {
117    fn from(name: String) -> Self {
118        Self::new(name)
119    }
120}
121
122impl From<&str> for VoiceIdentifier {
123    fn from(name: &str) -> Self {
124        Self::new(name)
125    }
126}
127
128/// Speech generation request
129#[derive(Clone, Debug)]
130pub struct SpeechRequest {
131    pub text: String,
132    pub voice: VoiceIdentifier,
133    pub format: AudioFormat,
134    pub sample_rate: Option<u32>,
135}
136
137/// Speech generation response
138#[derive(Clone, Debug)]
139pub struct SpeechResponse {
140    pub audio: AudioData,
141    pub text: String,
142    pub duration_ms: u64,
143}
144
/// A piece of audio delivered while streaming.
///
/// Chunks carry only samples; the sample rate and channel count are not included
/// per chunk.
#[derive(Debug, Clone)]
pub struct AudioChunk {
    /// Samples contained in this chunk.
    pub samples: Vec<f32>,
    /// `true` when this is the last chunk of the stream.
    pub is_final: bool,
}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `audio` to JSON and parses it back.
    fn round_trip(audio: &AudioData) -> AudioData {
        let json = serde_json::to_string(audio).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    /// Asserts two sample buffers have the same length and identical bit patterns.
    fn assert_samples_bit_exact(expected: &[f32], actual: &[f32]) {
        assert_eq!(expected.len(), actual.len());
        for (a, b) in expected.iter().zip(actual) {
            assert_eq!(a.to_bits(), b.to_bits(), "sample mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_audio_data_serialization() {
        let audio = AudioData {
            samples: vec![0.0, 0.5, -0.5, 1.0],
            channels: 1,
            sample_rate: 24000,
        };

        let deserialized = round_trip(&audio);

        assert_eq!(audio.channels, deserialized.channels);
        assert_eq!(audio.sample_rate, deserialized.sample_rate);
        // The wire format stores raw little-endian bytes, so the round trip is lossless.
        // Compare bit patterns rather than using a tolerance that could mask corruption.
        assert_samples_bit_exact(&audio.samples, &deserialized.samples);
    }

    #[test]
    fn test_audio_data_serialization_preserves_special_values() {
        let audio = AudioData {
            samples: vec![
                -0.0,
                f32::MIN_POSITIVE,
                f32::NAN,
                f32::INFINITY,
                f32::NEG_INFINITY,
            ],
            channels: 1,
            sample_rate: 48000,
        };

        let deserialized = round_trip(&audio);

        assert_samples_bit_exact(&audio.samples, &deserialized.samples);
    }

    #[test]
    fn test_audio_data_serialization_empty() {
        let audio = AudioData {
            samples: Vec::new(),
            channels: 1,
            sample_rate: 16000,
        };

        let deserialized = round_trip(&audio);

        assert!(deserialized.samples.is_empty());
        assert_eq!(deserialized.channels, 1);
        assert_eq!(deserialized.sample_rate, 16000);
    }

    #[test]
    fn test_audio_data_wire_format_is_base64_le_f32() {
        let audio = AudioData {
            samples: vec![1.0],
            channels: 1,
            sample_rate: 24000,
        };

        let value = serde_json::to_value(&audio).unwrap();

        // 1.0f32 is 0x3F800000, i.e. little-endian bytes 00 00 80 3F.
        assert_eq!(value["samples"], "AACAPw==");
        assert_eq!(value["channels"], 1);
        assert_eq!(value["sample_rate"], 24000);
    }

    #[test]
    fn test_audio_format_default_is_wav() {
        assert_eq!(AudioFormat::default(), AudioFormat::Wav);
    }

    #[test]
    fn test_voice_identifier_serialization() {
        let voice = VoiceIdentifier::new("alba");
        let json = serde_json::to_string(&voice).unwrap();
        let deserialized: VoiceIdentifier = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.name, "alba");
        assert_eq!(deserialized, voice);
    }

    #[test]
    fn test_voice_identifier_from_string() {
        let voice: VoiceIdentifier = "marius".into();
        assert_eq!(voice.name(), "marius");

        let owned: VoiceIdentifier = String::from("alba").into();
        assert_eq!(owned, VoiceIdentifier::new("alba"));
    }
}