llmkit/
audio.rs

//! Audio APIs for text-to-speech (TTS) and speech-to-text (STT).
//!
//! This module provides unified interfaces for audio synthesis and transcription
//! across providers such as OpenAI, ElevenLabs, and Deepgram.
//!
//! # Text-to-Speech Example
//!
//! ```ignore
//! use llmkit::{OpenAIProvider, SpeechProvider, SpeechRequest};
//!
//! // Create provider
//! let provider = OpenAIProvider::from_env()?;
//!
//! // Generate speech
//! let request = SpeechRequest::new("tts-1", "Hello, world!", "alloy");
//!
//! let response = provider.speech(request).await?;
//! std::fs::write("output.mp3", &response.audio)?;
//! ```
//!
//! # Speech-to-Text Example
//!
//! ```ignore
//! use llmkit::{AudioInput, OpenAIProvider, TranscriptionProvider, TranscriptionRequest};
//!
//! // Create provider
//! let provider = OpenAIProvider::from_env()?;
//!
//! // Transcribe audio
//! let request = TranscriptionRequest::new("whisper-1", AudioInput::file("audio.mp3"));
//!
//! let response = provider.transcribe(request).await?;
//! println!("Transcription: {}", response.text);
//! ```
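//!
//! # Streaming Speech Example
//!
//! A minimal sketch of consuming [`SpeechProvider::speech_stream`]. The default
//! trait implementation yields the whole clip as a single chunk; providers may
//! override it with true incremental streaming.
//!
//! ```ignore
//! use futures::StreamExt;
//! use llmkit::{OpenAIProvider, SpeechProvider, SpeechRequest};
//!
//! let provider = OpenAIProvider::from_env()?;
//! let mut stream = provider
//!     .speech_stream(SpeechRequest::new("tts-1", "Hello, world!", "alloy"))
//!     .await?;
//!
//! let mut audio = Vec::new();
//! while let Some(chunk) = stream.next().await {
//!     audio.extend_from_slice(&chunk?);
//! }
//! std::fs::write("output.mp3", &audio)?;
//! ```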

use std::path::PathBuf;
use std::pin::Pin;

use async_trait::async_trait;
use bytes::Bytes;
use futures::Stream;
use serde::{Deserialize, Serialize};

use crate::error::{Error, Result};

// ============================================================================
// Text-to-Speech (TTS)
// ============================================================================

/// Request for generating speech from text.
#[derive(Debug, Clone)]
pub struct SpeechRequest {
    /// The text to convert to speech.
    pub input: String,
    /// The model to use (e.g., "tts-1", "tts-1-hd").
    pub model: String,
    /// The voice to use (e.g., "alloy", "echo", "fable").
    pub voice: String,
    /// Audio format for the output.
    pub response_format: Option<AudioFormat>,
    /// Speed of speech (0.25 to 4.0, default 1.0).
    pub speed: Option<f32>,
}

impl SpeechRequest {
    /// Create a new speech request.
    pub fn new(
        model: impl Into<String>,
        input: impl Into<String>,
        voice: impl Into<String>,
    ) -> Self {
        Self {
            input: input.into(),
            model: model.into(),
            voice: voice.into(),
            response_format: None,
            speed: None,
        }
    }

    /// Set the audio format.
    pub fn with_format(mut self, format: AudioFormat) -> Self {
        self.response_format = Some(format);
        self
    }

    /// Set the speech speed; values outside 0.25 to 4.0 are clamped to that range.
    pub fn with_speed(mut self, speed: f32) -> Self {
        self.speed = Some(speed.clamp(0.25, 4.0));
        self
    }
}

/// Audio output format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
    /// MP3 audio (default).
    #[default]
    Mp3,
    /// Opus audio (for WebRTC).
    Opus,
    /// AAC audio.
    Aac,
    /// FLAC lossless audio.
    Flac,
    /// WAV audio.
    Wav,
    /// Raw PCM audio.
    Pcm,
}

impl AudioFormat {
    /// Get the file extension for this format.
    pub fn extension(&self) -> &'static str {
        match self {
            AudioFormat::Mp3 => "mp3",
            AudioFormat::Opus => "opus",
            AudioFormat::Aac => "aac",
            AudioFormat::Flac => "flac",
            AudioFormat::Wav => "wav",
            AudioFormat::Pcm => "pcm",
        }
    }

    /// Get the MIME type for this format.
    pub fn mime_type(&self) -> &'static str {
        match self {
            AudioFormat::Mp3 => "audio/mpeg",
            AudioFormat::Opus => "audio/opus",
            AudioFormat::Aac => "audio/aac",
            AudioFormat::Flac => "audio/flac",
            AudioFormat::Wav => "audio/wav",
            AudioFormat::Pcm => "audio/L16",
        }
    }
}

/// Response from a speech generation request.
#[derive(Debug, Clone)]
pub struct SpeechResponse {
    /// The generated audio data.
    pub audio: Vec<u8>,
    /// The format of the audio.
    pub format: AudioFormat,
    /// Duration of the audio in seconds (if known).
    pub duration_seconds: Option<f32>,
}

impl SpeechResponse {
    /// Create a new speech response.
    pub fn new(audio: Vec<u8>, format: AudioFormat) -> Self {
        Self {
            audio,
            format,
            duration_seconds: None,
        }
    }

    /// Set the duration.
    pub fn with_duration(mut self, duration: f32) -> Self {
        self.duration_seconds = Some(duration);
        self
    }

    /// Save the audio to a file.
    pub fn save(&self, path: impl Into<PathBuf>) -> std::io::Result<()> {
        std::fs::write(path.into(), &self.audio)
    }
}
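
// A small usage sketch (test-only, not part of the public API): derive the
// output filename from the response's format via `AudioFormat::extension`.
// The audio bytes below are placeholder data, not a real MP3.
#[cfg(test)]
mod speech_response_example {
    use super::*;

    #[test]
    fn save_with_extension_from_format() {
        let response = SpeechResponse::new(vec![0u8; 4], AudioFormat::Mp3).with_duration(0.1);
        let path = std::env::temp_dir()
            .join(format!("llmkit_audio_example.{}", response.format.extension()));

        response.save(&path).expect("write audio file");
        assert_eq!(std::fs::read(&path).expect("read audio back"), response.audio);

        let _ = std::fs::remove_file(&path);
    }
}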

/// Information about a voice.
#[derive(Debug, Clone)]
pub struct VoiceInfo {
    /// Voice ID.
    pub id: String,
    /// Voice name.
    pub name: String,
    /// Voice description.
    pub description: Option<String>,
    /// Voice gender (if applicable).
    pub gender: Option<String>,
    /// Language/locale.
    pub locale: Option<String>,
}

/// Trait for providers that support text-to-speech.
#[async_trait]
pub trait SpeechProvider: Send + Sync {
    /// Get the provider name.
    fn name(&self) -> &str;

    /// Generate speech from text.
    async fn speech(&self, request: SpeechRequest) -> Result<SpeechResponse>;

    /// Generate speech as a stream (for real-time playback).
    async fn speech_stream(
        &self,
        request: SpeechRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>> {
        // Default implementation: generate the full audio and yield it as a single chunk.
        let response = self.speech(request).await?;
        let bytes = Bytes::from(response.audio);
        let stream = futures::stream::once(async move { Ok(bytes) });
        Ok(Box::pin(stream))
    }

    /// Get available voices for this provider.
    fn available_voices(&self) -> &[VoiceInfo] {
        &[]
    }

    /// Get supported audio formats.
    fn supported_formats(&self) -> &[AudioFormat] {
        &[AudioFormat::Mp3]
    }

    /// Get the default model for this provider.
    fn default_speech_model(&self) -> Option<&str> {
        None
    }
}
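
// A minimal sketch of implementing `SpeechProvider` for a stub backend. A real
// provider would call its HTTP API in `speech`; this one fabricates bytes so the
// default `speech_stream` can be exercised without a network. Test-only code.
#[cfg(test)]
mod speech_provider_example {
    use super::*;
    use futures::StreamExt;

    struct StubProvider;

    #[async_trait]
    impl SpeechProvider for StubProvider {
        fn name(&self) -> &str {
            "stub"
        }

        async fn speech(&self, request: SpeechRequest) -> Result<SpeechResponse> {
            // Pretend each input character becomes one audio byte.
            let format = request.response_format.unwrap_or_default();
            Ok(SpeechResponse::new(vec![0u8; request.input.len()], format))
        }
    }

    #[test]
    fn default_stream_yields_a_single_chunk() {
        let provider = StubProvider;
        let request = SpeechRequest::new("stub-tts", "hello", "voice");

        let mut stream =
            futures::executor::block_on(provider.speech_stream(request)).expect("stream");
        let chunk = futures::executor::block_on(stream.next())
            .expect("one chunk")
            .expect("chunk bytes");

        assert_eq!(chunk.len(), "hello".len());
        assert!(futures::executor::block_on(stream.next()).is_none());
    }
}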

// ============================================================================
// Speech-to-Text (STT) / Transcription
// ============================================================================

/// Request for transcribing audio to text.
#[derive(Debug, Clone)]
pub struct TranscriptionRequest {
    /// The audio to transcribe.
    pub audio: AudioInput,
    /// The model to use (e.g., "whisper-1").
    pub model: String,
    /// Language of the audio (ISO-639-1 code).
    pub language: Option<String>,
    /// Prompt to guide transcription style.
    pub prompt: Option<String>,
    /// Response format.
    pub response_format: Option<TranscriptFormat>,
    /// Timestamp granularities to include.
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}

impl TranscriptionRequest {
    /// Create a new transcription request.
    pub fn new(model: impl Into<String>, audio: AudioInput) -> Self {
        Self {
            audio,
            model: model.into(),
            language: None,
            prompt: None,
            response_format: None,
            timestamp_granularities: None,
        }
    }

    /// Set the language.
    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.language = Some(language.into());
        self
    }

    /// Set a prompt to guide transcription.
    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.prompt = Some(prompt.into());
        self
    }

    /// Set the response format.
    pub fn with_format(mut self, format: TranscriptFormat) -> Self {
        self.response_format = Some(format);
        self
    }

    /// Request word-level timestamps (replaces any previously set granularities).
    pub fn with_word_timestamps(mut self) -> Self {
        self.timestamp_granularities = Some(vec![TimestampGranularity::Word]);
        self
    }

    /// Request segment-level timestamps (replaces any previously set granularities).
    pub fn with_segment_timestamps(mut self) -> Self {
        self.timestamp_granularities = Some(vec![TimestampGranularity::Segment]);
        self
    }
}
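
// A usage sketch (test-only) for timestamp granularities. The `with_*_timestamps`
// builders each replace the whole list, so to request both word- and
// segment-level timing, set the field directly as shown here.
#[cfg(test)]
mod transcription_request_example {
    use super::*;

    #[test]
    fn request_both_timestamp_granularities() {
        let mut request = TranscriptionRequest::new("whisper-1", AudioInput::file("audio.mp3"))
            .with_language("en")
            .with_format(TranscriptFormat::VerboseJson);
        request.timestamp_granularities = Some(vec![
            TimestampGranularity::Word,
            TimestampGranularity::Segment,
        ]);

        assert_eq!(
            request.timestamp_granularities.as_deref(),
            Some(&[TimestampGranularity::Word, TimestampGranularity::Segment][..])
        );
    }
}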

/// Input audio source.
#[derive(Debug, Clone)]
pub enum AudioInput {
    /// Path to a local audio file.
    File(PathBuf),
    /// Audio data in memory.
    Bytes {
        data: Vec<u8>,
        filename: String,
        media_type: String,
    },
    /// URL to an audio file.
    Url(String),
}

impl AudioInput {
    /// Create an input from a file path.
    pub fn file(path: impl Into<PathBuf>) -> Self {
        AudioInput::File(path.into())
    }

    /// Create an input from bytes.
    pub fn bytes(
        data: Vec<u8>,
        filename: impl Into<String>,
        media_type: impl Into<String>,
    ) -> Self {
        AudioInput::Bytes {
            data,
            filename: filename.into(),
            media_type: media_type.into(),
        }
    }

    /// Create an input from a URL.
    pub fn url(url: impl Into<String>) -> Self {
        AudioInput::Url(url.into())
    }
}
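
// A sketch (test-only) tying `AudioInput::bytes` to `AudioFormat`: when the
// audio is already in memory, the format's extension and MIME type can name
// and tag the upload. The bytes below are placeholder data, not a real WAV.
#[cfg(test)]
mod audio_input_example {
    use super::*;

    #[test]
    fn bytes_input_named_from_format() {
        let format = AudioFormat::Wav;
        let input = AudioInput::bytes(
            vec![0u8; 16],
            format!("clip.{}", format.extension()),
            format.mime_type(),
        );

        match input {
            AudioInput::Bytes { filename, media_type, .. } => {
                assert_eq!(filename, "clip.wav");
                assert_eq!(media_type, "audio/wav");
            }
            _ => panic!("expected AudioInput::Bytes"),
        }
    }
}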

/// Response format for transcription.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptFormat {
    /// Plain text.
    #[default]
    Text,
    /// JSON with metadata.
    Json,
    /// Verbose JSON with timing info.
    VerboseJson,
    /// SRT subtitles.
    Srt,
    /// VTT subtitles.
    Vtt,
}

/// Timestamp granularity options.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
    /// Word-level timestamps.
    Word,
    /// Segment-level timestamps.
    Segment,
}

/// Response from a transcription request.
#[derive(Debug, Clone)]
pub struct TranscriptionResponse {
    /// The transcribed text.
    pub text: String,
    /// Detected language.
    pub language: Option<String>,
    /// Duration of the audio in seconds.
    pub duration: Option<f32>,
    /// Transcript segments with timing.
    pub segments: Option<Vec<TranscriptSegment>>,
    /// Word-level timing information.
    pub words: Option<Vec<TranscriptWord>>,
}

impl TranscriptionResponse {
    /// Create a new transcription response.
    pub fn new(text: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            language: None,
            duration: None,
            segments: None,
            words: None,
        }
    }

    /// Set the language.
    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.language = Some(language.into());
        self
    }

    /// Set the duration.
    pub fn with_duration(mut self, duration: f32) -> Self {
        self.duration = Some(duration);
        self
    }

    /// Set segments.
    pub fn with_segments(mut self, segments: Vec<TranscriptSegment>) -> Self {
        self.segments = Some(segments);
        self
    }

    /// Set words.
    pub fn with_words(mut self, words: Vec<TranscriptWord>) -> Self {
        self.words = Some(words);
        self
    }
}

/// A segment of the transcript with timing.
#[derive(Debug, Clone)]
pub struct TranscriptSegment {
    /// Segment index.
    pub id: usize,
    /// Start time in seconds.
    pub start: f32,
    /// End time in seconds.
    pub end: f32,
    /// Segment text.
    pub text: String,
}

/// A word with timing information.
#[derive(Debug, Clone)]
pub struct TranscriptWord {
    /// The word.
    pub word: String,
    /// Start time in seconds.
    pub start: f32,
    /// End time in seconds.
    pub end: f32,
}

/// Trait for providers that support speech-to-text transcription.
#[async_trait]
pub trait TranscriptionProvider: Send + Sync {
    /// Get the provider name.
    fn name(&self) -> &str;

    /// Transcribe audio to text.
    async fn transcribe(&self, request: TranscriptionRequest) -> Result<TranscriptionResponse>;

    /// Translate audio to English text (for non-English audio).
    async fn translate(&self, _request: TranscriptionRequest) -> Result<TranscriptionResponse> {
        Err(Error::not_supported("Audio translation"))
    }

    /// Get supported audio formats for input.
    fn supported_input_formats(&self) -> &[&str] {
        &["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
    }

    /// Get maximum file size in bytes.
    fn max_file_size(&self) -> usize {
        25 * 1024 * 1024 // 25 MB default
    }

    /// Get the default model for this provider.
    fn default_transcription_model(&self) -> Option<&str> {
        None
    }
}
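
// A minimal sketch (test-only) of implementing `TranscriptionProvider` for a
// stub backend, and of the `translate` default, which reports "not supported"
// until a provider overrides it. A real implementation would upload the audio.
#[cfg(test)]
mod transcription_provider_example {
    use super::*;

    struct StubProvider;

    #[async_trait]
    impl TranscriptionProvider for StubProvider {
        fn name(&self) -> &str {
            "stub"
        }

        async fn transcribe(
            &self,
            _request: TranscriptionRequest,
        ) -> Result<TranscriptionResponse> {
            Ok(TranscriptionResponse::new("hello world").with_language("en"))
        }
    }

    #[test]
    fn transcribe_and_default_translate() {
        let provider = StubProvider;
        let request = TranscriptionRequest::new("stub-stt", AudioInput::file("audio.mp3"));

        let response =
            futures::executor::block_on(provider.transcribe(request.clone())).expect("transcript");
        assert_eq!(response.text, "hello world");

        // `translate` falls back to the not-supported default here.
        assert!(futures::executor::block_on(provider.translate(request)).is_err());
    }
}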

/// Information about an audio model.
#[derive(Debug, Clone)]
pub struct AudioModelInfo {
    /// Model ID/name.
    pub id: &'static str,
    /// Provider that offers this model.
    pub provider: &'static str,
    /// Model type (TTS or STT).
    pub model_type: AudioModelType,
    /// Supported languages.
    pub languages: &'static [&'static str],
    /// Price per minute (USD).
    pub price_per_minute: f64,
}

/// Type of audio model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioModelType {
    /// Text-to-speech.
    Tts,
    /// Speech-to-text.
    Stt,
}

/// Registry of known audio models.
pub static AUDIO_MODELS: &[AudioModelInfo] = &[
    // OpenAI TTS
    AudioModelInfo {
        id: "tts-1",
        provider: "openai",
        model_type: AudioModelType::Tts,
        languages: &["en", "es", "fr", "de", "it", "pt", "ru", "zh", "ja", "ko"],
        price_per_minute: 0.015,
    },
    AudioModelInfo {
        id: "tts-1-hd",
        provider: "openai",
        model_type: AudioModelType::Tts,
        languages: &["en", "es", "fr", "de", "it", "pt", "ru", "zh", "ja", "ko"],
        price_per_minute: 0.030,
    },
    // OpenAI STT
    AudioModelInfo {
        id: "whisper-1",
        provider: "openai",
        model_type: AudioModelType::Stt,
        languages: &[
            "en", "es", "fr", "de", "it", "pt", "ru", "zh", "ja", "ko", "ar", "hi",
        ],
        price_per_minute: 0.006,
    },
];

/// Get audio model info by ID.
pub fn get_audio_model_info(model_id: &str) -> Option<&'static AudioModelInfo> {
    AUDIO_MODELS.iter().find(|m| m.id == model_id)
}
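
// A back-of-the-envelope cost sketch (test-only) using the registry's
// per-minute pricing: cost = price_per_minute * duration_seconds / 60. The
// registry prices are point-in-time list prices and may drift from providers'
// current rates.
#[cfg(test)]
mod pricing_example {
    use super::*;

    /// Estimated cost in USD for `duration_seconds` of audio, if the model is known.
    fn estimate_cost_usd(model_id: &str, duration_seconds: f64) -> Option<f64> {
        get_audio_model_info(model_id).map(|m| m.price_per_minute * duration_seconds / 60.0)
    }

    #[test]
    fn whisper_cost_for_ten_minutes() {
        // 600 seconds at $0.006/minute = $0.06.
        let cost = estimate_cost_usd("whisper-1", 600.0).expect("known model");
        assert!((cost - 0.06).abs() < 1e-9);
    }
}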

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_speech_request_builder() {
        let request = SpeechRequest::new("tts-1", "Hello", "alloy")
            .with_format(AudioFormat::Mp3)
            .with_speed(1.5);

        assert_eq!(request.model, "tts-1");
        assert_eq!(request.input, "Hello");
        assert_eq!(request.voice, "alloy");
        assert_eq!(request.response_format, Some(AudioFormat::Mp3));
        assert_eq!(request.speed, Some(1.5));
    }

    #[test]
    fn test_speed_clamping() {
        let request = SpeechRequest::new("tts-1", "test", "alloy").with_speed(10.0);
        assert_eq!(request.speed, Some(4.0));

        let request = SpeechRequest::new("tts-1", "test", "alloy").with_speed(0.1);
        assert_eq!(request.speed, Some(0.25));
    }

    #[test]
    fn test_audio_format() {
        assert_eq!(AudioFormat::Mp3.extension(), "mp3");
        assert_eq!(AudioFormat::Mp3.mime_type(), "audio/mpeg");
        assert_eq!(AudioFormat::Opus.extension(), "opus");
    }

    #[test]
    fn test_transcription_request_builder() {
        let request = TranscriptionRequest::new("whisper-1", AudioInput::file("test.mp3"))
            .with_language("en")
            .with_word_timestamps();

        assert_eq!(request.model, "whisper-1");
        assert_eq!(request.language, Some("en".to_string()));
        assert!(request.timestamp_granularities.is_some());
    }

    #[test]
    fn test_audio_input() {
        let file_input = AudioInput::file("audio.mp3");
        assert!(matches!(file_input, AudioInput::File(_)));

        let url_input = AudioInput::url("https://example.com/audio.mp3");
        assert!(matches!(url_input, AudioInput::Url(_)));
    }

    #[test]
    fn test_audio_model_registry() {
        let model = get_audio_model_info("whisper-1");
        assert!(model.is_some());
        let model = model.unwrap();
        assert_eq!(model.provider, "openai");
        assert_eq!(model.model_type, AudioModelType::Stt);
    }
}