Skip to main content

adk_realtime/
audio.rs

1//! Audio format definitions and utilities.
2
3use serde::{Deserialize, Serialize};
4
5/// Audio encoding formats supported by realtime APIs.
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
7#[serde(rename_all = "lowercase")]
8pub enum AudioEncoding {
9    /// 16-bit PCM audio (most common).
10    #[serde(rename = "pcm16")]
11    #[default]
12    Pcm16,
13    /// G.711 μ-law encoding.
14    #[serde(rename = "g711_ulaw")]
15    G711Ulaw,
16    /// G.711 A-law encoding.
17    #[serde(rename = "g711_alaw")]
18    G711Alaw,
19}
20
21impl std::fmt::Display for AudioEncoding {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Self::Pcm16 => write!(f, "pcm16"),
25            Self::G711Ulaw => write!(f, "g711_ulaw"),
26            Self::G711Alaw => write!(f, "g711_alaw"),
27        }
28    }
29}
30
31/// Complete audio format specification.
32#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
33pub struct AudioFormat {
34    /// Sample rate in Hz (e.g., 24000, 16000, 8000).
35    pub sample_rate: u32,
36    /// Number of audio channels (1 = mono, 2 = stereo).
37    pub channels: u8,
38    /// Bits per sample.
39    pub bits_per_sample: u8,
40    /// Audio encoding format.
41    pub encoding: AudioEncoding,
42}
43
44impl Default for AudioFormat {
45    fn default() -> Self {
46        Self::pcm16_24khz()
47    }
48}
49
50impl AudioFormat {
51    /// Create a new audio format specification.
52    pub fn new(
53        sample_rate: u32,
54        channels: u8,
55        bits_per_sample: u8,
56        encoding: AudioEncoding,
57    ) -> Self {
58        Self { sample_rate, channels, bits_per_sample, encoding }
59    }
60
61    /// Standard PCM16 format at 24kHz (OpenAI default).
62    pub fn pcm16_24khz() -> Self {
63        Self {
64            sample_rate: 24000,
65            channels: 1,
66            bits_per_sample: 16,
67            encoding: AudioEncoding::Pcm16,
68        }
69    }
70
71    /// PCM16 format at 16kHz (Gemini input default).
72    pub fn pcm16_16khz() -> Self {
73        Self {
74            sample_rate: 16000,
75            channels: 1,
76            bits_per_sample: 16,
77            encoding: AudioEncoding::Pcm16,
78        }
79    }
80
81    /// G.711 μ-law format at 8kHz (telephony standard).
82    pub fn g711_ulaw() -> Self {
83        Self {
84            sample_rate: 8000,
85            channels: 1,
86            bits_per_sample: 8,
87            encoding: AudioEncoding::G711Ulaw,
88        }
89    }
90
91    /// G.711 A-law format at 8kHz (telephony standard).
92    pub fn g711_alaw() -> Self {
93        Self {
94            sample_rate: 8000,
95            channels: 1,
96            bits_per_sample: 8,
97            encoding: AudioEncoding::G711Alaw,
98        }
99    }
100
101    /// Calculate bytes per second for this format.
102    pub fn bytes_per_second(&self) -> u32 {
103        self.sample_rate * self.channels as u32 * (self.bits_per_sample / 8) as u32
104    }
105
106    /// Calculate duration in milliseconds for a given number of bytes.
107    pub fn duration_ms(&self, bytes: usize) -> f64 {
108        let bytes_per_ms = self.bytes_per_second() as f64 / 1000.0;
109        bytes as f64 / bytes_per_ms
110    }
111}
112
113/// Audio chunk with format information.
114#[derive(Debug, Clone)]
115pub struct AudioChunk {
116    /// Raw audio data.
117    pub data: Vec<u8>,
118    /// Audio format of this chunk.
119    pub format: AudioFormat,
120}
121
122impl AudioChunk {
123    /// Create a new audio chunk.
124    pub fn new(data: Vec<u8>, format: AudioFormat) -> Self {
125        Self { data, format }
126    }
127
128    /// Create a PCM16 24kHz audio chunk (OpenAI format).
129    pub fn pcm16_24khz(data: Vec<u8>) -> Self {
130        Self::new(data, AudioFormat::pcm16_24khz())
131    }
132
133    /// Create a PCM16 16kHz audio chunk (Gemini input format).
134    pub fn pcm16_16khz(data: Vec<u8>) -> Self {
135        Self::new(data, AudioFormat::pcm16_16khz())
136    }
137
138    /// Get duration of this audio chunk in milliseconds.
139    pub fn duration_ms(&self) -> f64 {
140        self.format.duration_ms(self.data.len())
141    }
142
143    /// Encode audio data as base64.
144    pub fn to_base64(&self) -> String {
145        use base64::Engine;
146        base64::engine::general_purpose::STANDARD.encode(&self.data)
147    }
148
149    /// Decode audio data from base64.
150    pub fn from_base64(encoded: &str, format: AudioFormat) -> Result<Self, base64::DecodeError> {
151        use base64::Engine;
152        let data = base64::engine::general_purpose::STANDARD.decode(encoded)?;
153        Ok(Self::new(data, format))
154    }
155
156    /// Create an AudioChunk from i16 samples (converts to PCM16 little-endian bytes).
157    ///
158    /// This is useful when working with audio APIs (like LiveKit) that provide
159    /// samples as `i16` slices rather than raw byte buffers.
160    pub fn from_i16_samples(samples: &[i16], format: AudioFormat) -> Self {
161        let mut data = Vec::with_capacity(samples.len() * 2);
162        for sample in samples {
163            data.extend_from_slice(&sample.to_le_bytes());
164        }
165        Self::new(data, format)
166    }
167
168    /// Convert the audio data to a vector of i16 samples (assuming PCM16 little-endian).
169    ///
170    /// Returns an error string if the data length is not even (not valid PCM16).
171    pub fn to_i16_samples(&self) -> Result<Vec<i16>, String> {
172        if self.data.len() % 2 != 0 {
173            return Err(format!(
174                "Invalid data length for PCM16: {} (must be even)",
175                self.data.len()
176            ));
177        }
178        let mut samples = Vec::with_capacity(self.data.len() / 2);
179        for chunk in self.data.chunks_exact(2) {
180            samples.push(i16::from_le_bytes([chunk[0], chunk[1]]));
181        }
182        Ok(samples)
183    }
184}
185
/// Accumulates `i16` samples and releases them in batches once a target duration has been buffered.
///
/// Batching audio (for example into ~200 ms pieces) before handing it to an AI
/// voice service helps in several ways:
/// 1. **Less network overhead**: a few larger chunks replace many tiny frames,
///    lowering the packet rate along with CPU and bandwidth overhead.
/// 2. **Better model behavior**: Voice Activity Detection (VAD) gets enough
///    context to tell speech apart from noise.
/// 3. **Jitter tolerance**: irregular frame arrival, common on mobile networks,
///    is smoothed out.
/// 4. **Controlled latency**: a short target keeps interaction feeling real-time
///    while gaining stability.
///
/// Elapsed time is derived from the number of buffered `i16` values divided by
/// `sample_rate`, so every value counts as one sample. NOTE(review): for
/// interleaved multi-channel input, callers presumably need to account for the
/// channel count themselves — confirm with callers.
#[derive(Clone, Debug)]
pub struct SmartAudioBuffer {
    /// Samples pushed since the last flush.
    buffer: Vec<i16>,
    /// Sample rate in Hz, used to turn the buffered sample count into milliseconds.
    sample_rate: u32,
    /// Buffered duration, in milliseconds, at which a flush releases the samples.
    target_duration_ms: u32,
}
201
202impl SmartAudioBuffer {
203    /// Create a new smart audio buffer.
204    pub fn new(sample_rate: u32, target_duration_ms: u32) -> Self {
205        Self { buffer: Vec::new(), sample_rate, target_duration_ms }
206    }
207
208    /// Push new samples into the buffer.
209    pub fn push(&mut self, samples: &[i16]) {
210        self.buffer.extend_from_slice(samples);
211    }
212
213    fn should_flush(&self) -> bool {
214        let duration_ms = (self.buffer.len() as f64 / self.sample_rate as f64) * 1000.0;
215
216        duration_ms >= self.target_duration_ms as f64
217    }
218
219    /// Flush the buffer if the target duration has been reached.
220    pub fn flush(&mut self) -> Option<Vec<i16>> {
221        if self.should_flush() { Some(std::mem::take(&mut self.buffer)) } else { None }
222    }
223
224    /// Flush any remaining samples in the buffer.
225    pub fn flush_remaining(&mut self) -> Option<Vec<i16>> {
226        if self.buffer.is_empty() { None } else { Some(std::mem::take(&mut self.buffer)) }
227    }
228}
229
#[cfg(test)]
mod tests {
    //! Unit tests for the audio format, chunk and smart-buffer helpers.

    use super::*;

    #[test]
    fn test_smart_audio_buffer_flush_threshold() {
        let sample_rate = 1000;
        let target_ms = 100;
        // 1000 samples/sec -> 1 sample = 1ms.
        // target 100ms -> 100 samples.

        let mut buffer = SmartAudioBuffer::new(sample_rate, target_ms);

        // Push 50 samples (50ms)
        buffer.push(&[0; 50]);
        assert!(buffer.flush().is_none());

        // Push 49 samples (total 99ms)
        buffer.push(&[0; 49]);
        assert!(buffer.flush().is_none());

        // Push 1 sample (total 100ms)
        buffer.push(&[0; 1]);
        let flushed = buffer.flush();
        assert!(flushed.is_some());
        assert_eq!(flushed.unwrap().len(), 100);
        assert!(buffer.buffer.is_empty());
    }

    #[test]
    fn test_smart_audio_buffer_flush_remaining() {
        let sample_rate = 1000;
        let target_ms = 100;
        let mut buffer = SmartAudioBuffer::new(sample_rate, target_ms);

        buffer.push(&[0; 50]);
        assert!(buffer.flush().is_none());

        let remaining = buffer.flush_remaining();
        assert!(remaining.is_some());
        assert_eq!(remaining.unwrap().len(), 50);
        assert!(buffer.buffer.is_empty());
    }

    #[test]
    fn test_smart_audio_buffer_empty_flush() {
        let mut buffer = SmartAudioBuffer::new(1000, 100);
        assert!(buffer.flush().is_none());
        assert!(buffer.flush_remaining().is_none());
    }

    #[test]
    fn test_audio_format_bytes_per_second() {
        let pcm16_24k = AudioFormat::pcm16_24khz();
        assert_eq!(pcm16_24k.bytes_per_second(), 48000); // 24000 * 1 * 2

        let pcm16_16k = AudioFormat::pcm16_16khz();
        assert_eq!(pcm16_16k.bytes_per_second(), 32000); // 16000 * 1 * 2
    }

    // The telephony constructors are 8 kHz mono, one byte per sample.
    #[test]
    fn test_g711_formats() {
        let ulaw = AudioFormat::g711_ulaw();
        assert_eq!(ulaw.encoding, AudioEncoding::G711Ulaw);
        assert_eq!(ulaw.bytes_per_second(), 8000); // 8000 * 1 * 1

        let alaw = AudioFormat::g711_alaw();
        assert_eq!(alaw.encoding, AudioEncoding::G711Alaw);
        assert_eq!(alaw.bytes_per_second(), 8000);
    }

    // Display output must match the serde wire names.
    #[test]
    fn test_audio_encoding_display() {
        assert_eq!(AudioEncoding::Pcm16.to_string(), "pcm16");
        assert_eq!(AudioEncoding::G711Ulaw.to_string(), "g711_ulaw");
        assert_eq!(AudioEncoding::G711Alaw.to_string(), "g711_alaw");
    }

    // Defaults are PCM16 and the 24 kHz PCM16 format.
    #[test]
    fn test_defaults() {
        assert_eq!(AudioEncoding::default(), AudioEncoding::Pcm16);
        assert_eq!(AudioFormat::default(), AudioFormat::pcm16_24khz());
    }

    #[test]
    fn test_audio_format_duration() {
        let format = AudioFormat::pcm16_24khz();
        // 48000 bytes = 1 second
        let duration = format.duration_ms(48000);
        assert!((duration - 1000.0).abs() < 0.001);
    }

    // A chunk's duration is derived from its own format.
    #[test]
    fn test_audio_chunk_duration() {
        // 3200 bytes at 32000 bytes/s = 100 ms.
        let chunk = AudioChunk::pcm16_16khz(vec![0; 3200]);
        assert!((chunk.duration_ms() - 100.0).abs() < 0.001);
    }

    #[test]
    fn test_audio_chunk_base64() {
        let original = AudioChunk::pcm16_24khz(vec![0, 1, 2, 3, 4, 5]);
        let encoded = original.to_base64();
        let decoded = AudioChunk::from_base64(&encoded, AudioFormat::pcm16_24khz()).unwrap();
        assert_eq!(original.data, decoded.data);
    }

    // Malformed base64 must surface as an error rather than garbage audio.
    #[test]
    fn test_from_base64_rejects_invalid_input() {
        assert!(AudioChunk::from_base64("not base64!", AudioFormat::pcm16_24khz()).is_err());
    }

    #[test]
    fn test_i16_samples_roundtrip() {
        let samples: Vec<i16> = vec![0, 1, -1, 32767, -32768, 1000, -1000];
        let chunk = AudioChunk::from_i16_samples(&samples, AudioFormat::pcm16_24khz());
        let recovered = chunk.to_i16_samples().unwrap();
        assert_eq!(samples, recovered);
    }

    #[test]
    fn test_i16_samples_empty() {
        let chunk = AudioChunk::from_i16_samples(&[], AudioFormat::pcm16_24khz());
        assert!(chunk.data.is_empty());
        assert_eq!(chunk.to_i16_samples().unwrap(), Vec::<i16>::new());
    }

    #[test]
    fn test_i16_samples_odd_bytes_error() {
        let chunk = AudioChunk::pcm16_24khz(vec![0, 1, 2]); // 3 bytes = invalid PCM16
        let err = chunk.to_i16_samples().unwrap_err();
        assert_eq!(err, "Invalid data length for PCM16: 3 (must be even)");
    }
}