use std::path::PathBuf;
use std::pin::Pin;
use async_trait::async_trait;
use bytes::Bytes;
use futures::Stream;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// Parameters for a text-to-speech (TTS) synthesis request.
#[derive(Debug, Clone)]
pub struct SpeechRequest {
    /// The text to convert to speech.
    pub input: String,
    /// Identifier of the TTS model to use (e.g. "tts-1").
    pub model: String,
    /// Identifier of the voice to speak with (e.g. "alloy").
    pub voice: String,
    /// Desired audio encoding; `None` leaves the provider's default.
    pub response_format: Option<AudioFormat>,
    /// Playback speed multiplier; the builder clamps it to 0.25..=4.0.
    pub speed: Option<f32>,
}
impl SpeechRequest {
pub fn new(
model: impl Into<String>,
input: impl Into<String>,
voice: impl Into<String>,
) -> Self {
Self {
input: input.into(),
model: model.into(),
voice: voice.into(),
response_format: None,
speed: None,
}
}
pub fn with_format(mut self, format: AudioFormat) -> Self {
self.response_format = Some(format);
self
}
pub fn with_speed(mut self, speed: f32) -> Self {
self.speed = Some(speed.clamp(0.25, 4.0));
self
}
}
/// Audio encodings a TTS provider may produce.
///
/// Serialized in lowercase (e.g. `"mp3"`) for wire compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
    /// MPEG Layer III — the default format.
    #[default]
    Mp3,
    Opus,
    Aac,
    Flac,
    Wav,
    /// Raw PCM samples (no container).
    Pcm,
}
impl AudioFormat {
    /// Conventional file extension for this encoding, without the dot.
    pub fn extension(&self) -> &'static str {
        self.meta().0
    }

    /// MIME type to use when transmitting audio in this encoding.
    pub fn mime_type(&self) -> &'static str {
        self.meta().1
    }

    /// Single source of truth mapping each variant to its
    /// (file extension, MIME type) pair.
    fn meta(&self) -> (&'static str, &'static str) {
        match self {
            AudioFormat::Mp3 => ("mp3", "audio/mpeg"),
            AudioFormat::Opus => ("opus", "audio/opus"),
            AudioFormat::Aac => ("aac", "audio/aac"),
            AudioFormat::Flac => ("flac", "audio/flac"),
            AudioFormat::Wav => ("wav", "audio/wav"),
            // "audio/L16" is the label used for raw linear PCM samples.
            AudioFormat::Pcm => ("pcm", "audio/L16"),
        }
    }
}
/// Synthesized audio returned by a TTS provider.
#[derive(Debug, Clone)]
pub struct SpeechResponse {
    /// Encoded audio bytes in `format`.
    pub audio: Vec<u8>,
    /// Encoding of `audio`.
    pub format: AudioFormat,
    /// Clip length in seconds, when the provider reports it.
    pub duration_seconds: Option<f32>,
}
impl SpeechResponse {
pub fn new(audio: Vec<u8>, format: AudioFormat) -> Self {
Self {
audio,
format,
duration_seconds: None,
}
}
pub fn with_duration(mut self, duration: f32) -> Self {
self.duration_seconds = Some(duration);
self
}
pub fn save(&self, path: impl Into<PathBuf>) -> std::io::Result<()> {
std::fs::write(path.into(), &self.audio)
}
}
/// Descriptive metadata for a voice offered by a TTS provider.
#[derive(Debug, Clone)]
pub struct VoiceInfo {
    /// Identifier to pass as `SpeechRequest::voice`.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Optional free-form description of the voice.
    pub description: Option<String>,
    /// Voice gender, if the provider reports one.
    pub gender: Option<String>,
    /// Locale tag, if the provider reports one (format is provider-specific).
    pub locale: Option<String>,
}
/// A text-to-speech backend.
#[async_trait]
pub trait SpeechProvider: Send + Sync {
    /// Provider identifier (e.g. for logging or registry lookup).
    fn name(&self) -> &str;

    /// Synthesizes the requested speech and returns the complete audio.
    async fn speech(&self, request: SpeechRequest) -> Result<SpeechResponse>;

    /// Streams synthesized audio chunks.
    ///
    /// The default implementation synthesizes the whole clip via
    /// [`SpeechProvider::speech`] and yields it as a single chunk;
    /// providers with native streaming should override this.
    async fn speech_stream(
        &self,
        request: SpeechRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>> {
        let audio = self.speech(request).await?.audio;
        let chunk = Bytes::from(audio);
        Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })))
    }

    /// Voices this provider offers; empty when not enumerable.
    fn available_voices(&self) -> &[VoiceInfo] {
        &[]
    }

    /// Output formats this provider supports; MP3 only by default.
    fn supported_formats(&self) -> &[AudioFormat] {
        &[AudioFormat::Mp3]
    }

    /// Model id to use when the caller does not specify one.
    fn default_speech_model(&self) -> Option<&str> {
        None
    }
}
/// Parameters for a speech-to-text (transcription) request.
#[derive(Debug, Clone)]
pub struct TranscriptionRequest {
    /// The audio to transcribe.
    pub audio: AudioInput,
    /// Identifier of the STT model to use (e.g. "whisper-1").
    pub model: String,
    /// Language of the input audio (e.g. "en"), when known.
    pub language: Option<String>,
    /// Optional text prompt to guide the transcription.
    pub prompt: Option<String>,
    /// Desired transcript output format; `None` uses the provider default.
    pub response_format: Option<TranscriptFormat>,
    /// Which timestamp levels (word and/or segment) to request, if any.
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}
impl TranscriptionRequest {
pub fn new(model: impl Into<String>, audio: AudioInput) -> Self {
Self {
audio,
model: model.into(),
language: None,
prompt: None,
response_format: None,
timestamp_granularities: None,
}
}
pub fn with_language(mut self, language: impl Into<String>) -> Self {
self.language = Some(language.into());
self
}
pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
self.prompt = Some(prompt.into());
self
}
pub fn with_format(mut self, format: TranscriptFormat) -> Self {
self.response_format = Some(format);
self
}
pub fn with_word_timestamps(mut self) -> Self {
self.timestamp_granularities = Some(vec![TimestampGranularity::Word]);
self
}
pub fn with_segment_timestamps(mut self) -> Self {
self.timestamp_granularities = Some(vec![TimestampGranularity::Segment]);
self
}
}
/// Source of audio for a transcription request.
#[derive(Debug, Clone)]
pub enum AudioInput {
    /// Audio read from a file on disk.
    File(PathBuf),
    /// Audio supplied as an in-memory buffer, with a filename and MIME
    /// type for providers that need them when uploading.
    Bytes {
        data: Vec<u8>,
        filename: String,
        media_type: String,
    },
    /// Audio fetched from a URL.
    Url(String),
}

impl AudioInput {
    /// Builds a file-backed input from any path-like value.
    pub fn file(path: impl Into<PathBuf>) -> Self {
        Self::File(path.into())
    }

    /// Builds an in-memory input from raw `data` plus its `filename`
    /// and `media_type`.
    pub fn bytes(
        data: Vec<u8>,
        filename: impl Into<String>,
        media_type: impl Into<String>,
    ) -> Self {
        Self::Bytes {
            data,
            filename: filename.into(),
            media_type: media_type.into(),
        }
    }

    /// Builds a URL-backed input.
    pub fn url(url: impl Into<String>) -> Self {
        Self::Url(url.into())
    }
}
/// Output formats for a transcription result.
///
/// Serialized in snake_case (e.g. `VerboseJson` → `"verbose_json"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptFormat {
    /// Plain transcript text — the default.
    #[default]
    Text,
    Json,
    /// JSON with extra detail (e.g. timing metadata).
    VerboseJson,
    /// SubRip subtitle format.
    Srt,
    /// WebVTT subtitle format.
    Vtt,
}
/// Granularity levels at which transcript timestamps can be requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
    /// Per-word start/end times.
    Word,
    /// Per-segment start/end times.
    Segment,
}
/// A transcription result, optionally enriched with timing data.
#[derive(Debug, Clone)]
pub struct TranscriptionResponse {
    /// The full transcript text.
    pub text: String,
    /// Detected or declared language, when reported.
    pub language: Option<String>,
    /// Audio duration in seconds, when reported.
    pub duration: Option<f32>,
    /// Segment-level timing, when requested and supported.
    pub segments: Option<Vec<TranscriptSegment>>,
    /// Word-level timing, when requested and supported.
    pub words: Option<Vec<TranscriptWord>>,
}

impl TranscriptionResponse {
    /// Builds a response containing only the transcript text.
    pub fn new(text: impl Into<String>) -> Self {
        TranscriptionResponse {
            text: text.into(),
            language: None,
            duration: None,
            segments: None,
            words: None,
        }
    }

    /// Records the transcript language.
    pub fn with_language(self, language: impl Into<String>) -> Self {
        TranscriptionResponse {
            language: Some(language.into()),
            ..self
        }
    }

    /// Records the audio duration in seconds.
    pub fn with_duration(self, duration: f32) -> Self {
        TranscriptionResponse {
            duration: Some(duration),
            ..self
        }
    }

    /// Attaches segment-level timing information.
    pub fn with_segments(self, segments: Vec<TranscriptSegment>) -> Self {
        TranscriptionResponse {
            segments: Some(segments),
            ..self
        }
    }

    /// Attaches word-level timing information.
    pub fn with_words(self, words: Vec<TranscriptWord>) -> Self {
        TranscriptionResponse {
            words: Some(words),
            ..self
        }
    }
}

/// One timed span of the transcript.
#[derive(Debug, Clone)]
pub struct TranscriptSegment {
    /// Position of this segment within the transcript.
    pub id: usize,
    /// Segment start time in seconds.
    pub start: f32,
    /// Segment end time in seconds.
    pub end: f32,
    /// Text spoken within this segment.
    pub text: String,
}

/// One timed word of the transcript.
#[derive(Debug, Clone)]
pub struct TranscriptWord {
    /// The word itself.
    pub word: String,
    /// Word start time in seconds.
    pub start: f32,
    /// Word end time in seconds.
    pub end: f32,
}
/// A speech-to-text backend.
#[async_trait]
pub trait TranscriptionProvider: Send + Sync {
    /// Provider identifier (e.g. for logging or registry lookup).
    fn name(&self) -> &str;

    /// Transcribes the audio in `request` into text.
    async fn transcribe(&self, request: TranscriptionRequest) -> Result<TranscriptionResponse>;

    /// Transcribes and translates the audio; the default reports the
    /// capability as unsupported, so only providers with translation
    /// support need to override it.
    async fn translate(&self, _request: TranscriptionRequest) -> Result<TranscriptionResponse> {
        Err(Error::not_supported("Audio translation"))
    }

    /// File extensions of audio containers this provider accepts.
    fn supported_input_formats(&self) -> &[&str] {
        &["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
    }

    /// Largest accepted upload in bytes (default: 25 MiB).
    fn max_file_size(&self) -> usize {
        25 * 1024 * 1024
    }

    /// Model id to use when the caller does not specify one.
    fn default_transcription_model(&self) -> Option<&str> {
        None
    }
}
/// Static metadata describing a known hosted audio model.
#[derive(Debug, Clone)]
pub struct AudioModelInfo {
    /// Model identifier as used in API requests (e.g. "whisper-1").
    pub id: &'static str,
    /// Provider hosting the model (e.g. "openai").
    pub provider: &'static str,
    /// Whether the model performs TTS or STT.
    pub model_type: AudioModelType,
    /// Language codes the model supports.
    pub languages: &'static [&'static str],
    /// Price per minute of audio — presumably USD; confirm against the
    /// provider's published pricing.
    pub price_per_minute: f64,
}
/// Distinguishes text-to-speech models from speech-to-text models.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioModelType {
    /// Text-to-speech.
    Tts,
    /// Speech-to-text.
    Stt,
}
/// Registry of known hosted audio models and their static metadata;
/// queried by [`get_audio_model_info`].
pub static AUDIO_MODELS: &[AudioModelInfo] = &[
    AudioModelInfo {
        id: "tts-1",
        provider: "openai",
        model_type: AudioModelType::Tts,
        languages: &["en", "es", "fr", "de", "it", "pt", "ru", "zh", "ja", "ko"],
        price_per_minute: 0.015,
    },
    AudioModelInfo {
        id: "tts-1-hd",
        provider: "openai",
        model_type: AudioModelType::Tts,
        languages: &["en", "es", "fr", "de", "it", "pt", "ru", "zh", "ja", "ko"],
        price_per_minute: 0.030,
    },
    AudioModelInfo {
        id: "whisper-1",
        provider: "openai",
        model_type: AudioModelType::Stt,
        languages: &[
            "en", "es", "fr", "de", "it", "pt", "ru", "zh", "ja", "ko", "ar", "hi",
        ],
        price_per_minute: 0.006,
    },
];
/// Looks up static metadata for `model_id` in [`AUDIO_MODELS`],
/// returning `None` for unknown ids.
pub fn get_audio_model_info(model_id: &str) -> Option<&'static AudioModelInfo> {
    for entry in AUDIO_MODELS {
        if entry.id == model_id {
            return Some(entry);
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder stores every field it is given.
    #[test]
    fn test_speech_request_builder() {
        let req = SpeechRequest::new("tts-1", "Hello", "alloy")
            .with_format(AudioFormat::Mp3)
            .with_speed(1.5);
        assert_eq!(req.model, "tts-1");
        assert_eq!(req.input, "Hello");
        assert_eq!(req.voice, "alloy");
        assert_eq!(req.response_format, Some(AudioFormat::Mp3));
        assert_eq!(req.speed, Some(1.5));
    }

    /// Out-of-range speeds are clamped to the 0.25..=4.0 bounds.
    #[test]
    fn test_speed_clamping() {
        let speed_after =
            |s: f32| SpeechRequest::new("tts-1", "test", "alloy").with_speed(s).speed;
        assert_eq!(speed_after(10.0), Some(4.0));
        assert_eq!(speed_after(0.1), Some(0.25));
    }

    /// Spot-check the extension/MIME mappings.
    #[test]
    fn test_audio_format() {
        assert_eq!(AudioFormat::Mp3.extension(), "mp3");
        assert_eq!(AudioFormat::Mp3.mime_type(), "audio/mpeg");
        assert_eq!(AudioFormat::Opus.extension(), "opus");
    }

    /// The transcription builder stores model, language, and timestamps.
    #[test]
    fn test_transcription_request_builder() {
        let req = TranscriptionRequest::new("whisper-1", AudioInput::file("test.mp3"))
            .with_language("en")
            .with_word_timestamps();
        assert_eq!(req.model, "whisper-1");
        assert_eq!(req.language.as_deref(), Some("en"));
        assert!(req.timestamp_granularities.is_some());
    }

    /// Constructor helpers produce the matching enum variants.
    #[test]
    fn test_audio_input() {
        assert!(matches!(AudioInput::file("audio.mp3"), AudioInput::File(_)));
        assert!(matches!(
            AudioInput::url("https://example.com/audio.mp3"),
            AudioInput::Url(_)
        ));
    }

    /// The registry resolves whisper-1 to OpenAI's STT entry.
    #[test]
    fn test_audio_model_registry() {
        let info = get_audio_model_info("whisper-1").expect("whisper-1 should be registered");
        assert_eq!(info.provider, "openai");
        assert_eq!(info.model_type, AudioModelType::Stt);
    }
}