use std::path::{Path, PathBuf};
use std::pin::Pin;

use async_trait::async_trait;
use bytes::Bytes;
use futures::Stream;
use serde::{Deserialize, Serialize};

use crate::error::{Error, Result};
45
/// Parameters for a text-to-speech synthesis request.
#[derive(Debug, Clone)]
pub struct SpeechRequest {
    // Text to synthesize into speech.
    pub input: String,
    // Identifier of the TTS model to use (e.g. "tts-1").
    pub model: String,
    // Voice identifier understood by the provider (e.g. "alloy").
    pub voice: String,
    // Desired output audio format; provider default applies when `None`.
    pub response_format: Option<AudioFormat>,
    // Playback speed multiplier; `with_speed` clamps it to 0.25..=4.0.
    pub speed: Option<f32>,
}
64
65impl SpeechRequest {
66 pub fn new(
68 model: impl Into<String>,
69 input: impl Into<String>,
70 voice: impl Into<String>,
71 ) -> Self {
72 Self {
73 input: input.into(),
74 model: model.into(),
75 voice: voice.into(),
76 response_format: None,
77 speed: None,
78 }
79 }
80
81 pub fn with_format(mut self, format: AudioFormat) -> Self {
83 self.response_format = Some(format);
84 self
85 }
86
87 pub fn with_speed(mut self, speed: f32) -> Self {
89 self.speed = Some(speed.clamp(0.25, 4.0));
90 self
91 }
92}
93
/// Audio output formats supported for speech synthesis.
///
/// Serialized in lowercase (e.g. `"mp3"`) for provider APIs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
    /// MPEG audio layer III (the default).
    #[default]
    Mp3,
    /// Opus codec.
    Opus,
    /// Advanced Audio Coding.
    Aac,
    /// Free Lossless Audio Codec.
    Flac,
    /// Waveform audio (RIFF container).
    Wav,
    /// Raw PCM samples without a container.
    Pcm,
}
112
113impl AudioFormat {
114 pub fn extension(&self) -> &'static str {
116 match self {
117 AudioFormat::Mp3 => "mp3",
118 AudioFormat::Opus => "opus",
119 AudioFormat::Aac => "aac",
120 AudioFormat::Flac => "flac",
121 AudioFormat::Wav => "wav",
122 AudioFormat::Pcm => "pcm",
123 }
124 }
125
126 pub fn mime_type(&self) -> &'static str {
128 match self {
129 AudioFormat::Mp3 => "audio/mpeg",
130 AudioFormat::Opus => "audio/opus",
131 AudioFormat::Aac => "audio/aac",
132 AudioFormat::Flac => "audio/flac",
133 AudioFormat::Wav => "audio/wav",
134 AudioFormat::Pcm => "audio/L16",
135 }
136 }
137}
138
/// Synthesized audio returned by a speech provider.
#[derive(Debug, Clone)]
pub struct SpeechResponse {
    // Encoded audio bytes, in the format named by `format`.
    pub audio: Vec<u8>,
    // Format of the bytes in `audio`.
    pub format: AudioFormat,
    // Audio length in seconds, when the provider reports it.
    pub duration_seconds: Option<f32>,
}
149
150impl SpeechResponse {
151 pub fn new(audio: Vec<u8>, format: AudioFormat) -> Self {
153 Self {
154 audio,
155 format,
156 duration_seconds: None,
157 }
158 }
159
160 pub fn with_duration(mut self, duration: f32) -> Self {
162 self.duration_seconds = Some(duration);
163 self
164 }
165
166 pub fn save(&self, path: impl Into<PathBuf>) -> std::io::Result<()> {
168 std::fs::write(path.into(), &self.audio)
169 }
170}
171
/// Metadata describing a voice offered by a speech provider.
#[derive(Debug, Clone)]
pub struct VoiceInfo {
    // Provider-specific voice identifier.
    pub id: String,
    // Human-readable display name.
    pub name: String,
    // Optional description of the voice.
    pub description: Option<String>,
    // Voice gender, if the provider reports one (free-form string).
    pub gender: Option<String>,
    // Locale of the voice, if known (format is provider-defined).
    pub locale: Option<String>,
}
186
187#[async_trait]
189pub trait SpeechProvider: Send + Sync {
190 fn name(&self) -> &str;
192
193 async fn speech(&self, request: SpeechRequest) -> Result<SpeechResponse>;
195
196 async fn speech_stream(
198 &self,
199 request: SpeechRequest,
200 ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>> {
201 let response = self.speech(request).await?;
203 let bytes = Bytes::from(response.audio);
204 let stream = futures::stream::once(async move { Ok(bytes) });
205 Ok(Box::pin(stream))
206 }
207
208 fn available_voices(&self) -> &[VoiceInfo] {
210 &[]
211 }
212
213 fn supported_formats(&self) -> &[AudioFormat] {
215 &[AudioFormat::Mp3]
216 }
217
218 fn default_speech_model(&self) -> Option<&str> {
220 None
221 }
222}
223
/// Parameters for a speech-to-text transcription request.
#[derive(Debug, Clone)]
pub struct TranscriptionRequest {
    // Audio to transcribe.
    pub audio: AudioInput,
    // Identifier of the transcription model to use (e.g. "whisper-1").
    pub model: String,
    // Language of the audio, if known (provider-defined codes, e.g. "en").
    pub language: Option<String>,
    // Optional text prompt passed to the model.
    pub prompt: Option<String>,
    // Desired transcript output format; provider default applies when `None`.
    pub response_format: Option<TranscriptFormat>,
    // Timestamp detail levels to request, if any.
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}
244
245impl TranscriptionRequest {
246 pub fn new(model: impl Into<String>, audio: AudioInput) -> Self {
248 Self {
249 audio,
250 model: model.into(),
251 language: None,
252 prompt: None,
253 response_format: None,
254 timestamp_granularities: None,
255 }
256 }
257
258 pub fn with_language(mut self, language: impl Into<String>) -> Self {
260 self.language = Some(language.into());
261 self
262 }
263
264 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
266 self.prompt = Some(prompt.into());
267 self
268 }
269
270 pub fn with_format(mut self, format: TranscriptFormat) -> Self {
272 self.response_format = Some(format);
273 self
274 }
275
276 pub fn with_word_timestamps(mut self) -> Self {
278 self.timestamp_granularities = Some(vec![TimestampGranularity::Word]);
279 self
280 }
281
282 pub fn with_segment_timestamps(mut self) -> Self {
284 self.timestamp_granularities = Some(vec![TimestampGranularity::Segment]);
285 self
286 }
287}
288
/// Source of audio data for transcription.
#[derive(Debug, Clone)]
pub enum AudioInput {
    /// Audio read from a file on disk.
    File(PathBuf),
    /// In-memory audio bytes with upload metadata.
    Bytes {
        // Raw encoded audio data.
        data: Vec<u8>,
        // Filename to associate with the data.
        filename: String,
        // MIME type of the audio data.
        media_type: String,
    },
    /// Audio fetched from a URL.
    Url(String),
}
303
304impl AudioInput {
305 pub fn file(path: impl Into<PathBuf>) -> Self {
307 AudioInput::File(path.into())
308 }
309
310 pub fn bytes(
312 data: Vec<u8>,
313 filename: impl Into<String>,
314 media_type: impl Into<String>,
315 ) -> Self {
316 AudioInput::Bytes {
317 data,
318 filename: filename.into(),
319 media_type: media_type.into(),
320 }
321 }
322
323 pub fn url(url: impl Into<String>) -> Self {
325 AudioInput::Url(url.into())
326 }
327}
328
/// Output formats for transcription results.
///
/// Serialized in snake_case (e.g. `"verbose_json"`) for provider APIs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptFormat {
    /// Plain text (the default).
    #[default]
    Text,
    /// JSON output.
    Json,
    /// JSON output with additional detail.
    VerboseJson,
    /// SubRip subtitle format.
    Srt,
    /// WebVTT subtitle format.
    Vtt,
}
345
/// Level of timestamp detail to request in a transcription.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
    /// Timestamps for each word.
    Word,
    /// Timestamps for each segment.
    Segment,
}
355
/// Result of transcribing audio to text.
#[derive(Debug, Clone)]
pub struct TranscriptionResponse {
    // Full transcribed text.
    pub text: String,
    // Language of the audio, when the provider reports it.
    pub language: Option<String>,
    // Duration of the audio in seconds, when the provider reports it.
    pub duration: Option<f32>,
    // Timed segments, when segment timestamps were requested.
    pub segments: Option<Vec<TranscriptSegment>>,
    // Timed words, when word timestamps were requested.
    pub words: Option<Vec<TranscriptWord>>,
}
370
371impl TranscriptionResponse {
372 pub fn new(text: impl Into<String>) -> Self {
374 Self {
375 text: text.into(),
376 language: None,
377 duration: None,
378 segments: None,
379 words: None,
380 }
381 }
382
383 pub fn with_language(mut self, language: impl Into<String>) -> Self {
385 self.language = Some(language.into());
386 self
387 }
388
389 pub fn with_duration(mut self, duration: f32) -> Self {
391 self.duration = Some(duration);
392 self
393 }
394
395 pub fn with_segments(mut self, segments: Vec<TranscriptSegment>) -> Self {
397 self.segments = Some(segments);
398 self
399 }
400
401 pub fn with_words(mut self, words: Vec<TranscriptWord>) -> Self {
403 self.words = Some(words);
404 self
405 }
406}
407
/// A contiguous span of transcribed speech with timing information.
#[derive(Debug, Clone)]
pub struct TranscriptSegment {
    // Segment identifier (index within the transcript).
    pub id: usize,
    // Start time of the segment; units assumed seconds, matching the
    // `duration`/`duration_seconds` convention in this module.
    pub start: f32,
    // End time of the segment (same units as `start`).
    pub end: f32,
    // Transcribed text for this segment.
    pub text: String,
}
420
/// A single transcribed word with timing information.
#[derive(Debug, Clone)]
pub struct TranscriptWord {
    // The transcribed word.
    pub word: String,
    // Start time of the word; units assumed seconds, matching the
    // `duration`/`duration_seconds` convention in this module.
    pub start: f32,
    // End time of the word (same units as `start`).
    pub end: f32,
}
431
432#[async_trait]
434pub trait TranscriptionProvider: Send + Sync {
435 fn name(&self) -> &str;
437
438 async fn transcribe(&self, request: TranscriptionRequest) -> Result<TranscriptionResponse>;
440
441 async fn translate(&self, _request: TranscriptionRequest) -> Result<TranscriptionResponse> {
443 Err(Error::not_supported("Audio translation"))
444 }
445
446 fn supported_input_formats(&self) -> &[&str] {
448 &["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
449 }
450
451 fn max_file_size(&self) -> usize {
453 25 * 1024 * 1024 }
455
456 fn default_transcription_model(&self) -> Option<&str> {
458 None
459 }
460}
461
/// Static metadata about a known audio model.
#[derive(Debug, Clone)]
pub struct AudioModelInfo {
    // Model identifier (e.g. "whisper-1").
    pub id: &'static str,
    // Provider that serves the model.
    pub provider: &'static str,
    // Whether the model does synthesis or transcription.
    pub model_type: AudioModelType,
    // Language codes the model supports (e.g. "en").
    pub languages: &'static [&'static str],
    // Price per minute of audio (currency assumed USD — confirm).
    pub price_per_minute: f64,
}

/// Kind of audio model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioModelType {
    /// Text-to-speech (synthesis).
    Tts,
    /// Speech-to-text (transcription).
    Stt,
}

/// Registry of known audio models and their metadata.
pub static AUDIO_MODELS: &[AudioModelInfo] = &[
    AudioModelInfo {
        id: "tts-1",
        provider: "openai",
        model_type: AudioModelType::Tts,
        languages: &["en", "es", "fr", "de", "it", "pt", "ru", "zh", "ja", "ko"],
        price_per_minute: 0.015,
    },
    AudioModelInfo {
        id: "tts-1-hd",
        provider: "openai",
        model_type: AudioModelType::Tts,
        languages: &["en", "es", "fr", "de", "it", "pt", "ru", "zh", "ja", "ko"],
        price_per_minute: 0.030,
    },
    AudioModelInfo {
        id: "whisper-1",
        provider: "openai",
        model_type: AudioModelType::Stt,
        languages: &[
            "en", "es", "fr", "de", "it", "pt", "ru", "zh", "ja", "ko", "ar", "hi",
        ],
        price_per_minute: 0.006,
    },
];

/// Looks up a model's registry entry by its identifier.
pub fn get_audio_model_info(model_id: &str) -> Option<&'static AudioModelInfo> {
    for model in AUDIO_MODELS {
        if model.id == model_id {
            return Some(model);
        }
    }
    None
}
519
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_speech_request_builder() {
        let req = SpeechRequest::new("tts-1", "Hello", "alloy")
            .with_format(AudioFormat::Mp3)
            .with_speed(1.5);

        assert_eq!(req.input, "Hello");
        assert_eq!(req.model, "tts-1");
        assert_eq!(req.voice, "alloy");
        assert_eq!(req.speed, Some(1.5));
        assert_eq!(req.response_format, Some(AudioFormat::Mp3));
    }

    #[test]
    fn test_speed_clamping() {
        // Out-of-range speeds snap to the nearest bound of 0.25..=4.0.
        let too_fast = SpeechRequest::new("tts-1", "test", "alloy").with_speed(10.0);
        let too_slow = SpeechRequest::new("tts-1", "test", "alloy").with_speed(0.1);
        assert_eq!(too_fast.speed, Some(4.0));
        assert_eq!(too_slow.speed, Some(0.25));
    }

    #[test]
    fn test_audio_format() {
        assert_eq!(AudioFormat::Opus.extension(), "opus");
        assert_eq!(AudioFormat::Mp3.extension(), "mp3");
        assert_eq!(AudioFormat::Mp3.mime_type(), "audio/mpeg");
    }

    #[test]
    fn test_transcription_request_builder() {
        let req = TranscriptionRequest::new("whisper-1", AudioInput::file("test.mp3"))
            .with_language("en")
            .with_word_timestamps();

        assert_eq!(req.model, "whisper-1");
        assert_eq!(req.language, Some("en".to_string()));
        assert!(req.timestamp_granularities.is_some());
    }

    #[test]
    fn test_audio_input() {
        assert!(matches!(AudioInput::file("audio.mp3"), AudioInput::File(_)));
        assert!(matches!(
            AudioInput::url("https://example.com/audio.mp3"),
            AudioInput::Url(_)
        ));
    }

    #[test]
    fn test_audio_model_registry() {
        let info = get_audio_model_info("whisper-1").expect("whisper-1 should be registered");
        assert_eq!(info.provider, "openai");
        assert_eq!(info.model_type, AudioModelType::Stt);
    }
}