use std::path::Path;
use std::time::Duration;

// NOTE(review): crate name corrected from `request` to `reqwest` — the API used
// (multipart::{Form, Part}, header::HeaderMap, Client) is reqwest's.
use reqwest::multipart::{Form, Part};
use serde::{Deserialize, Serialize};

use crate::audio::response::TranscriptionResponse;
use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
/// Path segment passed to `AuthProvider::endpoint` to build the audio endpoint
/// base URL (requests append `/speech`, `/transcriptions`, or `/translations`).
const AUDIO_PATH: &str = "audio";
/// Text-to-speech model selection for the speech endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum TtsModel {
    /// The "tts-1" model (the default).
    #[serde(rename = "tts-1")]
    #[default]
    Tts1,
    /// The "tts-1-hd" model.
    #[serde(rename = "tts-1-hd")]
    Tts1Hd,
    /// The "gpt-4o-mini-tts" model; the only variant for which
    /// `supports_instructions()` returns true.
    #[serde(rename = "gpt-4o-mini-tts")]
    Gpt4oMiniTts,
}
impl TtsModel {
pub fn as_str(&self) -> &'static str {
match self {
Self::Tts1 => "tts-1",
Self::Tts1Hd => "tts-1-hd",
Self::Gpt4oMiniTts => "gpt-4o-mini-tts",
}
}
pub fn supports_instructions(&self) -> bool {
matches!(self, Self::Gpt4oMiniTts)
}
}
impl std::fmt::Display for TtsModel {
    /// Formats the model as its API identifier, e.g. "tts-1".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Voice selection for text-to-speech; serialized as the lowercase voice name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Voice {
    /// The default voice.
    #[default]
    Alloy,
    Ash,
    Ballad,
    Cedar,
    Coral,
    Echo,
    Fable,
    Marin,
    Nova,
    Onyx,
    Sage,
    Shimmer,
    Verse,
}
impl Voice {
pub fn as_str(&self) -> &'static str {
match self {
Self::Alloy => "alloy",
Self::Ash => "ash",
Self::Ballad => "ballad",
Self::Cedar => "cedar",
Self::Coral => "coral",
Self::Echo => "echo",
Self::Fable => "fable",
Self::Marin => "marin",
Self::Nova => "nova",
Self::Onyx => "onyx",
Self::Sage => "sage",
Self::Shimmer => "shimmer",
Self::Verse => "verse",
}
}
}
impl std::fmt::Display for Voice {
    /// Formats the voice as its lowercase API name, e.g. "alloy".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Output audio container/codec for text-to-speech; serialized lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
    /// The default output format.
    #[default]
    Mp3,
    Opus,
    Aac,
    Flac,
    Wav,
    Pcm,
}
impl AudioFormat {
pub fn as_str(&self) -> &'static str {
match self {
Self::Mp3 => "mp3",
Self::Opus => "opus",
Self::Aac => "aac",
Self::Flac => "flac",
Self::Wav => "wav",
Self::Pcm => "pcm",
}
}
pub fn file_extension(&self) -> &'static str {
self.as_str()
}
}
impl std::fmt::Display for AudioFormat {
    /// Formats the audio format as its lowercase API name, e.g. "mp3".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Speech-to-text model selection for transcription and translation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum SttModel {
    /// The "whisper-1" model (the default).
    #[serde(rename = "whisper-1")]
    #[default]
    Whisper1,
    /// The "gpt-4o-transcribe" model.
    #[serde(rename = "gpt-4o-transcribe")]
    Gpt4oTranscribe,
}
impl SttModel {
pub fn as_str(&self) -> &'static str {
match self {
Self::Whisper1 => "whisper-1",
Self::Gpt4oTranscribe => "gpt-4o-transcribe",
}
}
}
impl std::fmt::Display for SttModel {
    /// Formats the model as its API identifier, e.g. "whisper-1".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Response format for transcription/translation; serialized in snake_case
/// (e.g. `VerboseJson` -> "verbose_json").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptionFormat {
    /// The default format.
    #[default]
    Json,
    Text,
    Srt,
    VerboseJson,
    Vtt,
}
impl TranscriptionFormat {
pub fn as_str(&self) -> &'static str {
match self {
Self::Json => "json",
Self::Text => "text",
Self::Srt => "srt",
Self::VerboseJson => "verbose_json",
Self::Vtt => "vtt",
}
}
}
/// Timestamp granularity requested for transcriptions; sent as repeated
/// `timestamp_granularities[]` multipart fields. Serialized lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
    Word,
    Segment,
}
impl TimestampGranularity {
pub fn as_str(&self) -> &'static str {
match self {
Self::Word => "word",
Self::Segment => "segment",
}
}
}
/// Options for `Audio::text_to_speech`.
#[derive(Debug, Clone, Default)]
pub struct TtsOptions {
    /// Model to synthesize with (default: `TtsModel::Tts1`).
    pub model: TtsModel,
    /// Voice to use (default: `Voice::Alloy`).
    pub voice: Voice,
    /// Output audio format (default: `AudioFormat::Mp3`).
    pub response_format: AudioFormat,
    /// Optional playback speed; omitted from the request when `None`.
    pub speed: Option<f32>,
    /// Optional style instructions; only sent when the model supports them
    /// (see `TtsModel::supports_instructions`), otherwise dropped with a warning.
    pub instructions: Option<String>,
}
/// Options for `Audio::transcribe` / `Audio::transcribe_bytes`.
/// `None` fields are omitted from the multipart request.
#[derive(Debug, Clone, Default)]
pub struct TranscribeOptions {
    /// Model to use; defaults to `SttModel::Whisper1` when `None`.
    pub model: Option<SttModel>,
    /// Input language hint, sent as the `language` field.
    pub language: Option<String>,
    /// Optional prompt to guide the transcription.
    pub prompt: Option<String>,
    /// Desired response format, sent as `response_format`.
    pub response_format: Option<TranscriptionFormat>,
    /// Sampling temperature, sent as `temperature`.
    pub temperature: Option<f32>,
    /// Timestamp granularities, sent as repeated `timestamp_granularities[]` fields.
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}
/// Options for `Audio::translate` / `Audio::translate_bytes`.
/// `None` fields are omitted from the multipart request.
#[derive(Debug, Clone, Default)]
pub struct TranslateOptions {
    /// Model to use; defaults to `SttModel::Whisper1` when `None`.
    pub model: Option<SttModel>,
    /// Optional prompt to guide the translation.
    pub prompt: Option<String>,
    /// Desired response format, sent as `response_format`.
    pub response_format: Option<TranscriptionFormat>,
    /// Sampling temperature, sent as `temperature`.
    pub temperature: Option<f32>,
}
/// JSON request body for the text-to-speech endpoint.
/// Optional fields are skipped entirely when `None`.
#[derive(Debug, Clone, Serialize)]
struct TtsRequest {
    /// Model identifier, e.g. "tts-1".
    model: String,
    /// Text to synthesize.
    input: String,
    /// Voice name, e.g. "alloy".
    voice: String,
    /// Output format name, e.g. "mp3".
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,
    /// Playback speed multiplier.
    #[serde(skip_serializing_if = "Option::is_none")]
    speed: Option<f32>,
    /// Style instructions (model-dependent).
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
}
/// Client for the OpenAI-compatible audio endpoints
/// (speech synthesis, transcription, translation).
pub struct Audio {
    /// Resolves credentials, request headers, and endpoint URLs.
    auth: AuthProvider,
    /// Optional HTTP timeout, forwarded to `create_http_client`.
    timeout: Option<Duration>,
}
impl Audio {
pub fn new() -> Result<Self> {
let auth = AuthProvider::openai_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_auth(auth: AuthProvider) -> Self {
Self { auth, timeout: None }
}
pub fn azure() -> Result<Self> {
let auth = AuthProvider::azure_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn detect_provider() -> Result<Self> {
let auth = AuthProvider::from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
let auth = AuthProvider::from_url_with_key(base_url, api_key);
Self { auth, timeout: None }
}
pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
let auth = AuthProvider::from_url(url)?;
Ok(Self { auth, timeout: None })
}
pub fn auth(&self) -> &AuthProvider {
&self.auth
}
pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
self.timeout = Some(timeout);
self
}
fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
let client = create_http_client(self.timeout)?;
let mut headers = request::header::HeaderMap::new();
self.auth.apply_headers(&mut headers)?;
headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
Ok((client, headers))
}
pub async fn text_to_speech(&self, text: &str, options: TtsOptions) -> Result<Vec<u8>> {
let (client, mut headers) = self.create_client()?;
headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
let instructions = if options.instructions.is_some() {
if options.model.supports_instructions() {
options.instructions
} else {
tracing::warn!("Model '{}' does not support instructions parameter. Ignoring instructions.", options.model);
None
}
} else {
None
};
let request_body = TtsRequest {
model: options.model.as_str().to_string(),
input: text.to_string(),
voice: options.voice.as_str().to_string(),
response_format: Some(options.response_format.as_str().to_string()),
speed: options.speed,
instructions,
};
let body = serde_json::to_string(&request_body).map_err(OpenAIToolError::SerdeJsonError)?;
let url = format!("{}/speech", self.auth.endpoint(AUDIO_PATH));
let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
let bytes = response.bytes().await.map_err(OpenAIToolError::RequestError)?;
Ok(bytes.to_vec())
}
pub async fn transcribe(&self, audio_path: &str, options: TranscribeOptions) -> Result<TranscriptionResponse> {
let audio_content = tokio::fs::read(audio_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read audio file: {}", e)))?;
let filename = Path::new(audio_path).file_name().and_then(|n| n.to_str()).unwrap_or("audio.mp3").to_string();
self.transcribe_bytes(&audio_content, &filename, options).await
}
pub async fn transcribe_bytes(&self, audio_data: &[u8], filename: &str, options: TranscribeOptions) -> Result<TranscriptionResponse> {
let (client, headers) = self.create_client()?;
let audio_part = Part::bytes(audio_data.to_vec())
.file_name(filename.to_string())
.mime_str("audio/mpeg")
.map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;
let mut form = Form::new().part("file", audio_part);
let model = options.model.unwrap_or_default();
form = form.text("model", model.as_str().to_string());
if let Some(language) = options.language {
form = form.text("language", language);
}
if let Some(prompt) = options.prompt {
form = form.text("prompt", prompt);
}
if let Some(response_format) = options.response_format {
form = form.text("response_format", response_format.as_str().to_string());
}
if let Some(temperature) = options.temperature {
form = form.text("temperature", temperature.to_string());
}
if let Some(granularities) = options.timestamp_granularities {
for g in granularities {
form = form.text("timestamp_granularities[]", g.as_str().to_string());
}
}
let url = format!("{}/transcriptions", self.auth.endpoint(AUDIO_PATH));
let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<TranscriptionResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn translate(&self, audio_path: &str, options: TranslateOptions) -> Result<TranscriptionResponse> {
let audio_content = tokio::fs::read(audio_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read audio file: {}", e)))?;
let filename = Path::new(audio_path).file_name().and_then(|n| n.to_str()).unwrap_or("audio.mp3").to_string();
self.translate_bytes(&audio_content, &filename, options).await
}
pub async fn translate_bytes(&self, audio_data: &[u8], filename: &str, options: TranslateOptions) -> Result<TranscriptionResponse> {
let (client, headers) = self.create_client()?;
let audio_part = Part::bytes(audio_data.to_vec())
.file_name(filename.to_string())
.mime_str("audio/mpeg")
.map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;
let mut form = Form::new().part("file", audio_part);
let model = options.model.unwrap_or(SttModel::Whisper1);
form = form.text("model", model.as_str().to_string());
if let Some(prompt) = options.prompt {
form = form.text("prompt", prompt);
}
if let Some(response_format) = options.response_format {
form = form.text("response_format", response_format.as_str().to_string());
}
if let Some(temperature) = options.temperature {
form = form.text("temperature", temperature.to_string());
}
let url = format!("{}/translations", self.auth.endpoint(AUDIO_PATH));
let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<TranscriptionResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
}
// Unit tests: enum string mappings, defaults, and serde round-trips only.
// No network requests are made here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tts_model_as_str() {
        assert_eq!(TtsModel::Tts1.as_str(), "tts-1");
        assert_eq!(TtsModel::Tts1Hd.as_str(), "tts-1-hd");
        assert_eq!(TtsModel::Gpt4oMiniTts.as_str(), "gpt-4o-mini-tts");
    }

    // Only the gpt-4o-mini-tts model accepts the `instructions` parameter.
    #[test]
    fn test_tts_model_supports_instructions() {
        assert!(TtsModel::Gpt4oMiniTts.supports_instructions());
        assert!(!TtsModel::Tts1.supports_instructions());
        assert!(!TtsModel::Tts1Hd.supports_instructions());
    }

    #[test]
    fn test_tts_model_default() {
        let model = TtsModel::default();
        assert_eq!(model, TtsModel::Tts1);
    }

    #[test]
    fn test_tts_model_display() {
        assert_eq!(format!("{}", TtsModel::Gpt4oMiniTts), "gpt-4o-mini-tts");
    }

    // Every voice variant maps to its lowercase API name.
    #[test]
    fn test_voice_as_str_all_voices() {
        assert_eq!(Voice::Alloy.as_str(), "alloy");
        assert_eq!(Voice::Ash.as_str(), "ash");
        assert_eq!(Voice::Ballad.as_str(), "ballad");
        assert_eq!(Voice::Cedar.as_str(), "cedar");
        assert_eq!(Voice::Coral.as_str(), "coral");
        assert_eq!(Voice::Echo.as_str(), "echo");
        assert_eq!(Voice::Fable.as_str(), "fable");
        assert_eq!(Voice::Marin.as_str(), "marin");
        assert_eq!(Voice::Nova.as_str(), "nova");
        assert_eq!(Voice::Onyx.as_str(), "onyx");
        assert_eq!(Voice::Sage.as_str(), "sage");
        assert_eq!(Voice::Shimmer.as_str(), "shimmer");
        assert_eq!(Voice::Verse.as_str(), "verse");
    }

    #[test]
    fn test_voice_new_voices() {
        assert_eq!(Voice::Ballad.as_str(), "ballad");
        assert_eq!(Voice::Cedar.as_str(), "cedar");
        assert_eq!(Voice::Marin.as_str(), "marin");
        assert_eq!(Voice::Verse.as_str(), "verse");
    }

    #[test]
    fn test_voice_default() {
        let voice = Voice::default();
        assert_eq!(voice, Voice::Alloy);
    }

    // serde lowercase renaming must agree with as_str().
    #[test]
    fn test_voice_serialization() {
        let voice = Voice::Coral;
        let json = serde_json::to_string(&voice).unwrap();
        assert_eq!(json, "\"coral\"");
        let ballad = Voice::Ballad;
        let json = serde_json::to_string(&ballad).unwrap();
        assert_eq!(json, "\"ballad\"");
    }

    #[test]
    fn test_voice_deserialization() {
        let voice: Voice = serde_json::from_str("\"coral\"").unwrap();
        assert_eq!(voice, Voice::Coral);
        let cedar: Voice = serde_json::from_str("\"cedar\"").unwrap();
        assert_eq!(cedar, Voice::Cedar);
        let marin: Voice = serde_json::from_str("\"marin\"").unwrap();
        assert_eq!(marin, Voice::Marin);
    }

    #[test]
    fn test_tts_options_default() {
        let options = TtsOptions::default();
        assert_eq!(options.model, TtsModel::Tts1);
        assert_eq!(options.voice, Voice::Alloy);
        assert_eq!(options.response_format, AudioFormat::Mp3);
        assert!(options.speed.is_none());
        assert!(options.instructions.is_none());
    }

    #[test]
    fn test_tts_options_with_instructions() {
        let options = TtsOptions {
            model: TtsModel::Gpt4oMiniTts,
            voice: Voice::Coral,
            instructions: Some("Speak in a cheerful tone.".to_string()),
            ..Default::default()
        };
        assert_eq!(options.model, TtsModel::Gpt4oMiniTts);
        assert_eq!(options.instructions, Some("Speak in a cheerful tone.".to_string()));
    }

    #[test]
    fn test_tts_request_serialization_with_instructions() {
        let request = TtsRequest {
            model: "gpt-4o-mini-tts".to_string(),
            input: "Hello, world!".to_string(),
            voice: "coral".to_string(),
            response_format: Some("mp3".to_string()),
            speed: None,
            instructions: Some("Speak cheerfully.".to_string()),
        };
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["model"], "gpt-4o-mini-tts");
        assert_eq!(json["input"], "Hello, world!");
        assert_eq!(json["voice"], "coral");
        assert_eq!(json["response_format"], "mp3");
        assert_eq!(json["instructions"], "Speak cheerfully.");
        // `speed` was None, so skip_serializing_if must omit the key entirely.
        assert!(json.get("speed").is_none());
    }

    #[test]
    fn test_tts_request_serialization_without_instructions() {
        let request = TtsRequest {
            model: "tts-1".to_string(),
            input: "Hello".to_string(),
            voice: "alloy".to_string(),
            response_format: Some("mp3".to_string()),
            speed: Some(1.0),
            instructions: None,
        };
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["model"], "tts-1");
        assert_eq!(json["speed"], 1.0);
        assert!(json.get("instructions").is_none());
    }

    // All optional fields set to None must be absent from the serialized body.
    #[test]
    fn test_tts_request_skip_serializing_none_fields() {
        let request = TtsRequest {
            model: "tts-1".to_string(),
            input: "Test".to_string(),
            voice: "echo".to_string(),
            response_format: None,
            speed: None,
            instructions: None,
        };
        let json = serde_json::to_value(&request).unwrap();
        assert!(json.get("model").is_some());
        assert!(json.get("input").is_some());
        assert!(json.get("voice").is_some());
        assert!(json.get("response_format").is_none());
        assert!(json.get("speed").is_none());
        assert!(json.get("instructions").is_none());
    }

    #[test]
    fn test_audio_format_as_str() {
        assert_eq!(AudioFormat::Mp3.as_str(), "mp3");
        assert_eq!(AudioFormat::Opus.as_str(), "opus");
        assert_eq!(AudioFormat::Aac.as_str(), "aac");
        assert_eq!(AudioFormat::Flac.as_str(), "flac");
        assert_eq!(AudioFormat::Wav.as_str(), "wav");
        assert_eq!(AudioFormat::Pcm.as_str(), "pcm");
    }

    #[test]
    fn test_audio_format_file_extension() {
        assert_eq!(AudioFormat::Mp3.file_extension(), "mp3");
        assert_eq!(AudioFormat::Wav.file_extension(), "wav");
    }

    #[test]
    fn test_stt_model_as_str() {
        assert_eq!(SttModel::Whisper1.as_str(), "whisper-1");
        assert_eq!(SttModel::Gpt4oTranscribe.as_str(), "gpt-4o-transcribe");
    }

    #[test]
    fn test_transcription_format_as_str() {
        assert_eq!(TranscriptionFormat::Json.as_str(), "json");
        assert_eq!(TranscriptionFormat::Text.as_str(), "text");
        assert_eq!(TranscriptionFormat::Srt.as_str(), "srt");
        assert_eq!(TranscriptionFormat::VerboseJson.as_str(), "verbose_json");
        assert_eq!(TranscriptionFormat::Vtt.as_str(), "vtt");
    }

    #[test]
    fn test_timestamp_granularity_as_str() {
        assert_eq!(TimestampGranularity::Word.as_str(), "word");
        assert_eq!(TimestampGranularity::Segment.as_str(), "segment");
    }
}