use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::LlmError;
use crate::traits::AudioCapability;
use crate::types::{
AudioFeature, AudioTranslationRequest, LanguageInfo, SttRequest, SttResponse, TtsRequest,
TtsResponse, VoiceInfo, WordTimestamp,
};
use super::config::OpenAiConfig;
/// JSON body sent to OpenAI's `POST {base_url}/audio/speech` endpoint.
///
/// `None` optional fields are left out of the serialized JSON entirely
/// (presumably so the API's own defaults apply — confirm against the API docs).
#[derive(Debug, Clone, Serialize)]
struct OpenAiTtsRequest {
    /// TTS model id (e.g. `tts-1`, `tts-1-hd`, `gpt-4o-mini-tts`).
    model: String,
    /// Text to synthesize.
    input: String,
    /// Voice id; `text_to_speech` checks it against the voice table first.
    voice: String,
    /// Requested output audio format (e.g. `mp3`).
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,
    /// Speech speed multiplier.
    #[serde(skip_serializing_if = "Option::is_none")]
    speed: Option<f32>,
    /// Voice-style instructions; validation rejects these for `tts-1` / `tts-1-hd`.
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
}
/// Parameters for a multipart `POST {base_url}/audio/transcriptions` upload.
///
/// This type is not serialized directly: `make_stt_request` turns each field
/// into a multipart form part and skips fields that are `None`.
#[derive(Debug, Clone)]
struct OpenAiSttRequest {
    /// Raw audio bytes, sent as the `file` part.
    file_data: Vec<u8>,
    /// File name attached to the `file` part.
    filename: String,
    /// Transcription model id (e.g. `whisper-1`).
    model: String,
    /// Sent as the `language` form field when present.
    language: Option<String>,
    /// Sent as the `prompt` form field when present.
    prompt: Option<String>,
    /// Sent as the `response_format` form field (e.g. `verbose_json`).
    response_format: Option<String>,
    /// Sent as the `temperature` form field when present.
    temperature: Option<f32>,
    /// Each entry becomes a repeated `timestamp_granularities[]` form field.
    timestamp_granularities: Option<Vec<String>>,
}
/// JSON body returned by the transcription and translation endpoints.
///
/// Only `text` is required. The remaining fields are present only in richer
/// responses (e.g. `verbose_json`; `words` presumably only when word-level
/// timestamps were requested), so each one falls back to `None` when absent.
///
/// The previous `skip_serializing_if` attributes had no effect because this
/// type is never serialized. `#[serde(default)]` states the actual intent.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiSttResponse {
    /// Transcribed (or translated) text.
    text: String,
    /// Language reported by the API, if any.
    #[serde(default)]
    language: Option<String>,
    /// Audio duration reported by the API, if any.
    #[serde(default)]
    duration: Option<f32>,
    /// Word-level timings, if the API returned them.
    #[serde(default)]
    words: Option<Vec<OpenAiWordTimestamp>>,
}
/// One entry of the `words` array in a transcription response.
///
/// Converted 1:1 into the crate's `WordTimestamp` (with `confidence: None`)
/// by `convert_stt_response`.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiWordTimestamp {
    /// The transcribed word.
    word: String,
    /// Start time as reported by the API (units presumably seconds — confirm).
    start: f32,
    /// End time as reported by the API (units presumably seconds — confirm).
    end: f32,
}
/// OpenAI implementation of `AudioCapability` (TTS, STT and translation).
#[derive(Debug, Clone)]
pub struct OpenAiAudio {
    /// Provider configuration; supplies `base_url` and request headers.
    config: OpenAiConfig,
    /// HTTP client used for every audio request.
    http_client: reqwest::Client,
    /// Features reported by `supported_features`; fixed in `new`.
    features: Vec<AudioFeature>,
}
impl OpenAiAudio {
pub fn new(config: OpenAiConfig, http_client: reqwest::Client) -> Self {
let features = vec![
AudioFeature::TextToSpeech,
AudioFeature::SpeechToText,
AudioFeature::AudioTranslation,
AudioFeature::CharacterTiming,
];
Self {
config,
http_client,
features,
}
}
fn get_tts_voices(&self) -> Vec<VoiceInfo> {
vec![
VoiceInfo {
id: "alloy".to_string(),
name: "Alloy".to_string(),
description: Some("Neutral, balanced voice".to_string()),
language: Some("en".to_string()),
gender: Some("neutral".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "ash".to_string(),
name: "Ash".to_string(),
description: Some("Warm, expressive voice".to_string()),
language: Some("en".to_string()),
gender: Some("neutral".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "ballad".to_string(),
name: "Ballad".to_string(),
description: Some("Melodic, storytelling voice".to_string()),
language: Some("en".to_string()),
gender: Some("neutral".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "coral".to_string(),
name: "Coral".to_string(),
description: Some("Bright, cheerful voice".to_string()),
language: Some("en".to_string()),
gender: Some("female".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "echo".to_string(),
name: "Echo".to_string(),
description: Some("Male voice".to_string()),
language: Some("en".to_string()),
gender: Some("male".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "fable".to_string(),
name: "Fable".to_string(),
description: Some("British accent".to_string()),
language: Some("en".to_string()),
gender: Some("male".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "nova".to_string(),
name: "Nova".to_string(),
description: Some("Female voice".to_string()),
language: Some("en".to_string()),
gender: Some("female".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "onyx".to_string(),
name: "Onyx".to_string(),
description: Some("Deep male voice".to_string()),
language: Some("en".to_string()),
gender: Some("male".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "sage".to_string(),
name: "Sage".to_string(),
description: Some("Wise, thoughtful voice".to_string()),
language: Some("en".to_string()),
gender: Some("neutral".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "shimmer".to_string(),
name: "Shimmer".to_string(),
description: Some("Soft female voice".to_string()),
language: Some("en".to_string()),
gender: Some("female".to_string()),
category: Some("standard".to_string()),
},
VoiceInfo {
id: "verse".to_string(),
name: "Verse".to_string(),
description: Some("Poetic, rhythmic voice".to_string()),
language: Some("en".to_string()),
gender: Some("neutral".to_string()),
category: Some("standard".to_string()),
},
]
}
async fn make_tts_request(&self, request: OpenAiTtsRequest) -> Result<Vec<u8>, LlmError> {
let url = format!("{}/audio/speech", self.config.base_url);
let mut headers = reqwest::header::HeaderMap::new();
for (key, value) in self.config.get_headers() {
let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
.map_err(|e| LlmError::HttpError(format!("Invalid header name: {e}")))?;
let header_value = reqwest::header::HeaderValue::from_str(&value)
.map_err(|e| LlmError::HttpError(format!("Invalid header value: {e}")))?;
headers.insert(header_name, header_value);
}
let response = self
.http_client
.post(&url)
.headers(headers)
.json(&request)
.send()
.await
.map_err(|e| LlmError::HttpError(format!("TTS request failed: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::ApiError {
code: status.as_u16(),
message: format!("OpenAI TTS API error {status}: {error_text}"),
details: None,
});
}
let audio_data = response
.bytes()
.await
.map_err(|e| LlmError::HttpError(format!("Failed to read audio data: {e}")))?;
Ok(audio_data.to_vec())
}
async fn make_stt_request(
&self,
request: OpenAiSttRequest,
) -> Result<OpenAiSttResponse, LlmError> {
let url = format!("{}/audio/transcriptions", self.config.base_url);
let mut form = reqwest::multipart::Form::new();
let file_part = reqwest::multipart::Part::bytes(request.file_data)
.file_name(request.filename)
.mime_str("audio/mpeg")
.map_err(|e| LlmError::HttpError(format!("Failed to create file part: {e}")))?;
form = form.part("file", file_part);
form = form.text("model", request.model);
if let Some(language) = request.language {
form = form.text("language", language);
}
if let Some(prompt) = request.prompt {
form = form.text("prompt", prompt);
}
if let Some(format) = request.response_format {
form = form.text("response_format", format);
}
if let Some(temp) = request.temperature {
form = form.text("temperature", temp.to_string());
}
if let Some(granularities) = request.timestamp_granularities {
for granularity in granularities {
form = form.text("timestamp_granularities[]", granularity);
}
}
let mut headers = reqwest::header::HeaderMap::new();
for (key, value) in self.config.get_headers() {
if key == "Content-Type" {
continue;
}
let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
.map_err(|e| LlmError::HttpError(format!("Invalid header name: {e}")))?;
let header_value = reqwest::header::HeaderValue::from_str(&value)
.map_err(|e| LlmError::HttpError(format!("Invalid header value: {e}")))?;
headers.insert(header_name, header_value);
}
let response = self
.http_client
.post(&url)
.headers(headers)
.multipart(form)
.send()
.await
.map_err(|e| LlmError::HttpError(format!("STT request failed: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::ApiError {
code: status.as_u16(),
message: format!("OpenAI STT API error {status}: {error_text}"),
details: None,
});
}
let openai_response: OpenAiSttResponse = response
.json()
.await
.map_err(|e| LlmError::ParseError(format!("Failed to parse STT response: {e}")))?;
Ok(openai_response)
}
fn convert_stt_response(&self, openai_response: OpenAiSttResponse) -> SttResponse {
let words = openai_response.words.map(|words| {
words
.into_iter()
.map(|w| WordTimestamp {
word: w.word,
start: w.start,
end: w.end,
confidence: None, })
.collect()
});
SttResponse {
text: openai_response.text,
language: openai_response.language,
confidence: None, words,
duration: openai_response.duration,
metadata: HashMap::new(),
}
}
fn get_supported_tts_models(&self) -> Vec<String> {
vec![
"tts-1".to_string(),
"tts-1-hd".to_string(),
"gpt-4o-mini-tts".to_string(), ]
}
fn validate_tts_request(
&self,
model: &str,
instructions: &Option<String>,
) -> Result<(), LlmError> {
if !self.get_supported_tts_models().contains(&model.to_string()) {
return Err(LlmError::InvalidInput(format!(
"Unsupported TTS model: {}. Supported models: {}",
model,
self.get_supported_tts_models().join(", ")
)));
}
if let Some(instructions) = instructions {
if model == "tts-1" || model == "tts-1-hd" {
return Err(LlmError::InvalidInput(
"Instructions parameter is not supported for tts-1 and tts-1-hd models"
.to_string(),
));
}
if instructions.len() > 4096 {
return Err(LlmError::InvalidInput(
"Instructions cannot exceed 4096 characters".to_string(),
));
}
}
Ok(())
}
fn is_voice_supported(&self, voice: &str) -> bool {
self.get_tts_voices().iter().any(|v| v.id == voice)
}
}
/// Resolves an upload's audio from in-memory bytes or a file path.
///
/// In-memory bytes take precedence and are labelled `audio.mp3`. Otherwise the
/// file is read from disk and keeps its own file name, falling back to
/// `audio.mp3` when the name is missing or not valid UTF-8.
///
/// NOTE(review): `std::fs::read` blocks the calling thread and is invoked from
/// async trait methods. Consider an async file API or a blocking-task offload if
/// the crate's runtime allows it.
///
/// # Errors
/// - `LlmError::IoError` if the file cannot be read.
/// - `LlmError::InvalidInput` if neither source is provided.
fn load_audio_input<P: AsRef<std::path::Path>>(
    audio_data: Option<Vec<u8>>,
    file_path: Option<P>,
) -> Result<(Vec<u8>, String), LlmError> {
    if let Some(data) = audio_data {
        return Ok((data, "audio.mp3".to_string()));
    }
    match file_path {
        Some(path) => {
            let path = path.as_ref();
            let data = std::fs::read(path)
                .map_err(|e| LlmError::IoError(format!("Failed to read audio file: {e}")))?;
            let filename = path
                .file_name()
                .and_then(|n| n.to_str())
                .unwrap_or("audio.mp3")
                .to_string();
            Ok((data, filename))
        }
        None => Err(LlmError::InvalidInput(
            "Either audio_data or file_path must be provided".to_string(),
        )),
    }
}

/// Picks a MIME type for an uploaded audio file from its extension.
///
/// The match is case-insensitive. Unknown or missing extensions fall back to
/// `audio/mpeg`, the value that used to be hard-coded for every upload.
fn upload_mime_type(filename: &str) -> &'static str {
    let extension = std::path::Path::new(filename)
        .extension()
        .and_then(|ext| ext.to_str())
        .map(str::to_ascii_lowercase);
    match extension.as_deref() {
        Some("mp4" | "m4a") => "audio/mp4",
        Some("wav") => "audio/wav",
        Some("webm") => "audio/webm",
        Some("flac") => "audio/flac",
        Some("ogg" | "oga") => "audio/ogg",
        _ => "audio/mpeg",
    }
}

#[async_trait]
impl AudioCapability for OpenAiAudio {
    /// Returns the fixed feature list set up in `OpenAiAudio::new`.
    fn supported_features(&self) -> &[AudioFeature] {
        &self.features
    }

    /// Synthesizes speech for `request.text`.
    ///
    /// Defaults: voice `alloy`, format `mp3`, model `tts-1`. Optional
    /// `instructions` are read from `extra_params["instructions"]` when that
    /// value is a string.
    ///
    /// # Errors
    /// - `LlmError::InvalidInput` for an unknown voice, model or invalid instructions.
    /// - Otherwise any error from the HTTP call.
    async fn text_to_speech(&self, request: TtsRequest) -> Result<TtsResponse, LlmError> {
        let voice = request.voice.unwrap_or_else(|| "alloy".to_string());
        let format = request.format.unwrap_or_else(|| "mp3".to_string());
        let model = request.model.unwrap_or_else(|| "tts-1".to_string());
        let instructions = request
            .extra_params
            .get("instructions")
            .and_then(|v| v.as_str())
            .map(std::string::ToString::to_string);
        if !self.is_voice_supported(&voice) {
            return Err(LlmError::InvalidInput(format!(
                "Unsupported voice: {}. Supported voices: {}",
                voice,
                self.get_tts_voices()
                    .iter()
                    .map(|v| v.id.as_str())
                    .collect::<Vec<_>>()
                    .join(", ")
            )));
        }
        self.validate_tts_request(&model, &instructions)?;
        let openai_request = OpenAiTtsRequest {
            model,
            input: request.text,
            voice,
            response_format: Some(format.clone()),
            speed: request.speed,
            instructions,
        };
        let audio_data = self.make_tts_request(openai_request).await?;
        // The speech endpoint returns raw audio only, so duration and sample
        // rate are not known here.
        Ok(TtsResponse {
            audio_data,
            format,
            duration: None,
            sample_rate: None,
            metadata: HashMap::new(),
        })
    }

    /// Transcribes audio from bytes or a file path.
    ///
    /// Uses `whisper-1` unless a model is given. Always requests
    /// `verbose_json`, so language, duration and word timings can be returned.
    async fn speech_to_text(&self, request: SttRequest) -> Result<SttResponse, LlmError> {
        let (file_data, filename) = load_audio_input(request.audio_data, request.file_path)?;
        let openai_request = OpenAiSttRequest {
            file_data,
            filename,
            model: request.model.unwrap_or_else(|| "whisper-1".to_string()),
            language: request.language,
            prompt: None,
            response_format: Some("verbose_json".to_string()),
            temperature: None,
            timestamp_granularities: request.timestamp_granularities,
        };
        let openai_response = self.make_stt_request(openai_request).await?;
        Ok(self.convert_stt_response(openai_response))
    }

    /// Sends audio to `{base_url}/audio/translations` and returns the text.
    ///
    /// Uses `whisper-1` unless a model is given, and requests plain `json`, so
    /// only `text` is expected back.
    ///
    /// # Errors
    /// - `LlmError::IoError` / `LlmError::InvalidInput` for bad audio input.
    /// - `LlmError::HttpError` for form, header or transport failures.
    /// - `LlmError::ApiError` for non-success statuses.
    /// - `LlmError::ParseError` for an undecodable response.
    async fn translate_audio(
        &self,
        request: AudioTranslationRequest,
    ) -> Result<SttResponse, LlmError> {
        let url = format!("{}/audio/translations", self.config.base_url);
        let (file_data, filename) = load_audio_input(request.audio_data, request.file_path)?;
        let mime = upload_mime_type(&filename);
        let file_part = reqwest::multipart::Part::bytes(file_data)
            .file_name(filename)
            .mime_str(mime)
            .map_err(|e| LlmError::HttpError(format!("Failed to create file part: {e}")))?;
        let model = request.model.unwrap_or_else(|| "whisper-1".to_string());
        let form = reqwest::multipart::Form::new()
            .part("file", file_part)
            .text("model", model)
            .text("response_format", "json");
        let mut headers = reqwest::header::HeaderMap::new();
        for (key, value) in self.config.get_headers() {
            // reqwest sets the multipart Content-Type (with boundary) itself. Drop
            // any configured one, matched case-insensitively like HTTP header names.
            if key.eq_ignore_ascii_case("content-type") {
                continue;
            }
            let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
                .map_err(|e| LlmError::HttpError(format!("Invalid header name: {e}")))?;
            let header_value = reqwest::header::HeaderValue::from_str(&value)
                .map_err(|e| LlmError::HttpError(format!("Invalid header value: {e}")))?;
            headers.insert(header_name, header_value);
        }
        let response = self
            .http_client
            .post(&url)
            .headers(headers)
            .multipart(form)
            .send()
            .await
            .map_err(|e| LlmError::HttpError(format!("Translation request failed: {e}")))?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::ApiError {
                code: status.as_u16(),
                message: format!("OpenAI Translation API error {status}: {error_text}"),
                details: None,
            });
        }
        let openai_response: OpenAiSttResponse = response.json().await.map_err(|e| {
            LlmError::ParseError(format!("Failed to parse translation response: {e}"))
        })?;
        Ok(self.convert_stt_response(openai_response))
    }

    /// Returns the static TTS voice catalogue.
    async fn get_voices(&self) -> Result<Vec<VoiceInfo>, LlmError> {
        Ok(self.get_tts_voices())
    }

    /// Returns a fixed list of languages, each marked as supporting both
    /// transcription and translation.
    async fn get_supported_languages(&self) -> Result<Vec<LanguageInfo>, LlmError> {
        const LANGUAGES: [(&str, &str); 7] = [
            ("en", "English"),
            ("zh", "Chinese"),
            ("es", "Spanish"),
            ("fr", "French"),
            ("de", "German"),
            ("ja", "Japanese"),
            ("ko", "Korean"),
        ];
        Ok(LANGUAGES
            .iter()
            .map(|&(code, name)| LanguageInfo {
                code: code.to_string(),
                name: name.to_string(),
                supports_transcription: true,
                supports_translation: true,
            })
            .collect())
    }

    /// Returns the audio file formats accepted for upload.
    ///
    /// Includes `flac` and `ogg`, which the OpenAI transcription and
    /// translation endpoints also accept.
    fn get_supported_audio_formats(&self) -> Vec<String> {
        ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "flac", "ogg"]
            .iter()
            .map(|format| (*format).to_string())
            .collect()
    }
}