openai_tools/audio/request.rs

//! OpenAI Audio API Request Module
//!
//! This module provides the functionality to interact with the OpenAI Audio API.
//! It supports text-to-speech (TTS), transcription, and translation.
//!
//! # Key Features
//!
//! - **Text-to-Speech**: Convert text to natural-sounding audio
//! - **Transcription**: Convert audio to text (speech-to-text)
//! - **Translation**: Translate audio to English text
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::audio::request::{Audio, TtsOptions, Voice};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let audio = Audio::new()?;
//!
//!     // Generate speech from text
//!     let options = TtsOptions::default();
//!     let audio_bytes = audio.text_to_speech("Hello, world!", options).await?;
//!     std::fs::write("output.mp3", audio_bytes)?;
//!
//!     Ok(())
//! }
//! ```

use crate::audio::response::TranscriptionResponse;
use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use reqwest::multipart::{Form, Part};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::Duration;

/// Default API path for Audio
const AUDIO_PATH: &str = "audio";

/// Text-to-speech models.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum TtsModel {
    /// Standard quality TTS model
    #[serde(rename = "tts-1")]
    #[default]
    Tts1,
    /// High definition TTS model
    #[serde(rename = "tts-1-hd")]
    Tts1Hd,
    /// GPT-4o Mini TTS model
    #[serde(rename = "gpt-4o-mini-tts")]
    Gpt4oMiniTts,
}

impl TtsModel {
    /// Returns the model identifier string.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Tts1 => "tts-1",
            Self::Tts1Hd => "tts-1-hd",
            Self::Gpt4oMiniTts => "gpt-4o-mini-tts",
        }
    }

    /// Checks if this model supports the `instructions` parameter.
    ///
    /// Only `gpt-4o-mini-tts` supports the instructions parameter for
    /// controlling voice characteristics like tone, emotion, and pacing.
    ///
    /// # Example
    ///
    /// ```rust
    /// use openai_tools::audio::request::TtsModel;
    ///
    /// assert!(TtsModel::Gpt4oMiniTts.supports_instructions());
    /// assert!(!TtsModel::Tts1.supports_instructions());
    /// assert!(!TtsModel::Tts1Hd.supports_instructions());
    /// ```
    pub fn supports_instructions(&self) -> bool {
        matches!(self, Self::Gpt4oMiniTts)
    }
}

impl std::fmt::Display for TtsModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Voice options for text-to-speech.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Voice {
    /// Alloy voice
    #[default]
    Alloy,
    /// Ash voice
    Ash,
    /// Ballad voice
    Ballad,
    /// Cedar voice (recommended for quality)
    Cedar,
    /// Coral voice
    Coral,
    /// Echo voice
    Echo,
    /// Fable voice
    Fable,
    /// Marin voice (recommended for quality)
    Marin,
    /// Nova voice
    Nova,
    /// Onyx voice
    Onyx,
    /// Sage voice
    Sage,
    /// Shimmer voice
    Shimmer,
    /// Verse voice
    Verse,
}

impl Voice {
    /// Returns the voice identifier string.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Alloy => "alloy",
            Self::Ash => "ash",
            Self::Ballad => "ballad",
            Self::Cedar => "cedar",
            Self::Coral => "coral",
            Self::Echo => "echo",
            Self::Fable => "fable",
            Self::Marin => "marin",
            Self::Nova => "nova",
            Self::Onyx => "onyx",
            Self::Sage => "sage",
            Self::Shimmer => "shimmer",
            Self::Verse => "verse",
        }
    }
}

impl std::fmt::Display for Voice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Audio output formats for TTS.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
    /// MP3 format (default)
    #[default]
    Mp3,
    /// Opus format
    Opus,
    /// AAC format
    Aac,
    /// FLAC format
    Flac,
    /// WAV format
    Wav,
    /// PCM format
    Pcm,
}

impl AudioFormat {
    /// Returns the format string.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Mp3 => "mp3",
            Self::Opus => "opus",
            Self::Aac => "aac",
            Self::Flac => "flac",
            Self::Wav => "wav",
            Self::Pcm => "pcm",
        }
    }

    /// Returns the file extension for this format.
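    ///
    /// # Example
    ///
    /// A minimal sketch: building an output filename from the chosen format.
    ///
    /// ```rust
    /// use openai_tools::audio::request::AudioFormat;
    ///
    /// // The extension matches the format string (e.g., "wav" for Wav).
    /// let path = format!("speech.{}", AudioFormat::Wav.file_extension());
    /// assert_eq!(path, "speech.wav");
    /// ```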
    pub fn file_extension(&self) -> &'static str {
        self.as_str()
    }
}

impl std::fmt::Display for AudioFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Speech-to-text models for transcription and translation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum SttModel {
    /// Whisper v1 model
    #[serde(rename = "whisper-1")]
    #[default]
    Whisper1,
    /// GPT-4o Transcribe model
    #[serde(rename = "gpt-4o-transcribe")]
    Gpt4oTranscribe,
}

impl SttModel {
    /// Returns the model identifier string.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Whisper1 => "whisper-1",
            Self::Gpt4oTranscribe => "gpt-4o-transcribe",
        }
    }
}

impl std::fmt::Display for SttModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Transcription response formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptionFormat {
    /// JSON format
    #[default]
    Json,
    /// Plain text format
    Text,
    /// SRT subtitle format
    Srt,
    /// Verbose JSON with timestamps
    VerboseJson,
    /// VTT subtitle format
    Vtt,
}

impl TranscriptionFormat {
    /// Returns the format string.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Json => "json",
            Self::Text => "text",
            Self::Srt => "srt",
            Self::VerboseJson => "verbose_json",
            Self::Vtt => "vtt",
        }
    }
}

/// Timestamp granularity options.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
    /// Word-level timestamps
    Word,
    /// Segment-level timestamps
    Segment,
}

impl TimestampGranularity {
    /// Returns the granularity string.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Word => "word",
            Self::Segment => "segment",
        }
    }
}

/// Options for text-to-speech generation.
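///
/// # Example
///
/// A minimal configuration sketch; the instruction text is illustrative.
///
/// ```rust
/// use openai_tools::audio::request::{TtsOptions, TtsModel, Voice};
///
/// // Instructions are only honored by gpt-4o-mini-tts; other models
/// // ignore them with a logged warning.
/// let options = TtsOptions {
///     model: TtsModel::Gpt4oMiniTts,
///     voice: Voice::Coral,
///     instructions: Some("Speak in a calm, soothing tone.".to_string()),
///     ..Default::default()
/// };
/// ```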
#[derive(Debug, Clone, Default)]
pub struct TtsOptions {
    /// The model to use (defaults to tts-1)
    pub model: TtsModel,
    /// The voice to use (defaults to alloy)
    pub voice: Voice,
    /// The output audio format (defaults to mp3)
    pub response_format: AudioFormat,
    /// Speech speed (0.25 to 4.0, defaults to 1.0)
    pub speed: Option<f32>,
    /// Instructions for controlling voice characteristics.
    ///
    /// Only supported by the `gpt-4o-mini-tts` model.
    /// Use natural language to control tone, emotion, and pacing.
    ///
    /// # Examples
    ///
    /// - `"Speak in a cheerful and positive tone."`
    /// - `"Use a calm and soothing voice."`
    /// - `"Speak with enthusiasm and energy."`
    ///
    /// If set with an unsupported model (`tts-1` or `tts-1-hd`),
    /// this parameter will be ignored and a warning will be logged.
    pub instructions: Option<String>,
}

/// Options for audio transcription.
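///
/// # Example
///
/// A sketch of a verbose-JSON request. Note that the OpenAI API requires the
/// `verbose_json` response format when timestamp granularities are requested.
///
/// ```rust
/// use openai_tools::audio::request::{TranscribeOptions, TranscriptionFormat, TimestampGranularity};
///
/// let options = TranscribeOptions {
///     language: Some("en".to_string()),
///     response_format: Some(TranscriptionFormat::VerboseJson),
///     timestamp_granularities: Some(vec![TimestampGranularity::Word]),
///     ..Default::default()
/// };
/// ```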
#[derive(Debug, Clone, Default)]
pub struct TranscribeOptions {
    /// The model to use (defaults to whisper-1)
    pub model: Option<SttModel>,
    /// The language of the input audio (ISO-639-1 code)
    pub language: Option<String>,
    /// Optional prompt to guide the model's style
    pub prompt: Option<String>,
    /// Response format (defaults to json)
    pub response_format: Option<TranscriptionFormat>,
    /// Temperature for sampling (0.0 to 1.0)
    pub temperature: Option<f32>,
    /// Timestamp granularities to include
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}

/// Options for audio translation.
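///
/// # Example
///
/// A minimal sketch; the prompt text and temperature are illustrative.
///
/// ```rust
/// use openai_tools::audio::request::TranslateOptions;
///
/// let options = TranslateOptions {
///     prompt: Some("Technical vocabulary about astronomy.".to_string()),
///     temperature: Some(0.2),
///     ..Default::default()
/// };
/// ```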
#[derive(Debug, Clone, Default)]
pub struct TranslateOptions {
    /// The model to use (only whisper-1 is supported)
    pub model: Option<SttModel>,
    /// Optional prompt to guide the model's style
    pub prompt: Option<String>,
    /// Response format (defaults to json)
    pub response_format: Option<TranscriptionFormat>,
    /// Temperature for sampling (0.0 to 1.0)
    pub temperature: Option<f32>,
}

/// Request payload for TTS.
#[derive(Debug, Clone, Serialize)]
struct TtsRequest {
    model: String,
    input: String,
    voice: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    speed: Option<f32>,
    /// Instructions for voice control (only for gpt-4o-mini-tts).
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
}

/// Client for interacting with the OpenAI Audio API.
///
/// This struct provides methods for text-to-speech, transcription, and translation.
/// Use [`Audio::new()`] to create a new instance.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::audio::request::{Audio, TtsOptions, Voice, AudioFormat};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let audio = Audio::new()?;
///
///     let options = TtsOptions {
///         voice: Voice::Nova,
///         response_format: AudioFormat::Mp3,
///         ..Default::default()
///     };
///
///     let bytes = audio.text_to_speech("Welcome to our app!", options).await?;
///     std::fs::write("welcome.mp3", bytes)?;
///
///     Ok(())
/// }
/// ```
pub struct Audio {
    /// Authentication provider (OpenAI or Azure)
    auth: AuthProvider,
    /// Optional request timeout duration
    timeout: Option<Duration>,
}

impl Audio {
    /// Creates a new Audio client for the OpenAI API.
    ///
    /// Initializes the client by loading the OpenAI API key from
    /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
    /// via dotenvy.
    ///
    /// # Returns
    ///
    /// * `Ok(Audio)` - A new Audio client ready for use
    /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::audio::request::Audio;
    ///
    /// let audio = Audio::new().expect("API key should be set");
    /// ```
    pub fn new() -> Result<Self> {
        let auth = AuthProvider::openai_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Audio client with a custom authentication provider.
    pub fn with_auth(auth: AuthProvider) -> Self {
        Self { auth, timeout: None }
    }

    /// Creates a new Audio client for the Azure OpenAI API.
    pub fn azure() -> Result<Self> {
        let auth = AuthProvider::azure_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Audio client by auto-detecting the provider.
    pub fn detect_provider() -> Result<Self> {
        let auth = AuthProvider::from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Audio client with URL-based provider detection.
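    ///
    /// # Example
    ///
    /// A sketch with placeholder credentials; the URL shown is the standard
    /// OpenAI endpoint and the key is illustrative.
    ///
    /// ```rust,no_run
    /// use openai_tools::audio::request::Audio;
    ///
    /// let audio = Audio::with_url("https://api.openai.com/v1", "sk-your-key");
    /// ```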
    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
        let auth = AuthProvider::from_url_with_key(base_url, api_key);
        Self { auth, timeout: None }
    }

    /// Creates a new Audio client from a URL, using environment variables for credentials.
    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
        let auth = AuthProvider::from_url(url)?;
        Ok(Self { auth, timeout: None })
    }

    /// Returns the authentication provider.
    pub fn auth(&self) -> &AuthProvider {
        &self.auth
    }

    /// Sets the request timeout duration.
    ///
    /// # Arguments
    ///
    /// * `timeout` - The maximum time to wait for a response
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining
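    ///
    /// # Example
    ///
    /// A minimal sketch; 30 seconds is an arbitrary illustrative value.
    ///
    /// ```rust,no_run
    /// use openai_tools::audio::request::Audio;
    /// use std::time::Duration;
    ///
    /// let mut audio = Audio::new().expect("API key should be set");
    /// audio.timeout(Duration::from_secs(30));
    /// ```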
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    /// Creates the HTTP client with default headers.
    fn create_client(&self) -> Result<(reqwest::Client, reqwest::header::HeaderMap)> {
        let client = create_http_client(self.timeout)?;
        let mut headers = reqwest::header::HeaderMap::new();
        self.auth.apply_headers(&mut headers)?;
        headers.insert("User-Agent", reqwest::header::HeaderValue::from_static("openai-tools-rust"));
        Ok((client, headers))
    }

    /// Converts text to speech.
    ///
    /// Returns audio bytes in the specified format.
    ///
    /// # Arguments
    ///
    /// * `text` - The text to convert to speech (max 4096 characters)
    /// * `options` - TTS options (model, voice, format, speed)
    ///
    /// # Returns
    ///
    /// * `Ok(Vec<u8>)` - The audio data as bytes
    /// * `Err(OpenAIToolError)` - If the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::audio::request::{Audio, TtsOptions, TtsModel, Voice};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let audio = Audio::new()?;
    ///
    ///     let options = TtsOptions {
    ///         model: TtsModel::Tts1Hd,
    ///         voice: Voice::Shimmer,
    ///         speed: Some(1.2),
    ///         ..Default::default()
    ///     };
    ///
    ///     let bytes = audio.text_to_speech("Hello, this is a test.", options).await?;
    ///     std::fs::write("speech.mp3", bytes)?;
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn text_to_speech(&self, text: &str, options: TtsOptions) -> Result<Vec<u8>> {
        let (client, mut headers) = self.create_client()?;
        headers.insert("Content-Type", reqwest::header::HeaderValue::from_static("application/json"));

        // Check if the instructions parameter is supported by the model
        let instructions = if options.instructions.is_some() {
            if options.model.supports_instructions() {
                options.instructions
            } else {
                tracing::warn!("Model '{}' does not support the instructions parameter. Ignoring instructions.", options.model);
                None
            }
        } else {
            None
        };

        let request_body = TtsRequest {
            model: options.model.as_str().to_string(),
            input: text.to_string(),
            voice: options.voice.as_str().to_string(),
            response_format: Some(options.response_format.as_str().to_string()),
            speed: options.speed,
            instructions,
        };

        let body = serde_json::to_string(&request_body).map_err(OpenAIToolError::SerdeJsonError)?;

        let url = format!("{}/speech", self.auth.endpoint(AUDIO_PATH));

        let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;

        // Surface API errors instead of returning an error body as audio bytes.
        let status = response.status();
        if !status.is_success() {
            let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        let bytes = response.bytes().await.map_err(OpenAIToolError::RequestError)?;

        Ok(bytes.to_vec())
    }

    /// Transcribes audio from a file path.
    ///
    /// # Arguments
    ///
    /// * `audio_path` - Path to the audio file
    /// * `options` - Transcription options
    ///
    /// # Returns
    ///
    /// * `Ok(TranscriptionResponse)` - The transcription result
    /// * `Err(OpenAIToolError)` - If the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::audio::request::{Audio, TranscribeOptions};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let audio = Audio::new()?;
    ///
    ///     let options = TranscribeOptions {
    ///         language: Some("en".to_string()),
    ///         ..Default::default()
    ///     };
    ///
    ///     let response = audio.transcribe("audio.mp3", options).await?;
    ///     println!("Transcription: {}", response.text);
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn transcribe(&self, audio_path: &str, options: TranscribeOptions) -> Result<TranscriptionResponse> {
        let audio_content = tokio::fs::read(audio_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read audio file: {}", e)))?;

        let filename = Path::new(audio_path).file_name().and_then(|n| n.to_str()).unwrap_or("audio.mp3").to_string();

        self.transcribe_bytes(&audio_content, &filename, options).await
    }

    /// Transcribes audio from bytes.
    ///
    /// # Arguments
    ///
    /// * `audio_data` - The audio data as bytes
    /// * `filename` - The filename with extension (e.g., "audio.mp3")
    /// * `options` - Transcription options
    ///
    /// # Returns
    ///
    /// * `Ok(TranscriptionResponse)` - The transcription result
    /// * `Err(OpenAIToolError)` - If the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::audio::request::{Audio, TranscribeOptions, SttModel};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let audio = Audio::new()?;
    ///
    ///     let audio_data = std::fs::read("recording.mp3")?;
    ///     let options = TranscribeOptions {
    ///         model: Some(SttModel::Whisper1),
    ///         ..Default::default()
    ///     };
    ///
    ///     let response = audio.transcribe_bytes(&audio_data, "recording.mp3", options).await?;
    ///     println!("Transcription: {}", response.text);
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn transcribe_bytes(&self, audio_data: &[u8], filename: &str, options: TranscribeOptions) -> Result<TranscriptionResponse> {
        let (client, headers) = self.create_client()?;

        // Note: the MIME type is fixed to audio/mpeg regardless of the actual file format.
        let audio_part = Part::bytes(audio_data.to_vec())
            .file_name(filename.to_string())
            .mime_str("audio/mpeg")
            .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;

        let mut form = Form::new().part("file", audio_part);

        // Add model
        let model = options.model.unwrap_or_default();
        form = form.text("model", model.as_str().to_string());

        // Add optional parameters
        if let Some(language) = options.language {
            form = form.text("language", language);
        }
        if let Some(prompt) = options.prompt {
            form = form.text("prompt", prompt);
        }
        if let Some(response_format) = options.response_format {
            form = form.text("response_format", response_format.as_str().to_string());
        }
        if let Some(temperature) = options.temperature {
            form = form.text("temperature", temperature.to_string());
        }
        if let Some(granularities) = options.timestamp_granularities {
            for g in granularities {
                form = form.text("timestamp_granularities[]", g.as_str().to_string());
            }
        }

        let url = format!("{}/transcriptions", self.auth.endpoint(AUDIO_PATH));

        let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;

        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        serde_json::from_str::<TranscriptionResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }

    /// Translates audio to English text.
    ///
    /// Only supports translation to English using the whisper-1 model.
    ///
    /// # Arguments
    ///
    /// * `audio_path` - Path to the audio file
    /// * `options` - Translation options
    ///
    /// # Returns
    ///
    /// * `Ok(TranscriptionResponse)` - The translation result
    /// * `Err(OpenAIToolError)` - If the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::audio::request::{Audio, TranslateOptions};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let audio = Audio::new()?;
    ///
    ///     let options = TranslateOptions::default();
    ///     let response = audio.translate("french_audio.mp3", options).await?;
    ///     println!("English translation: {}", response.text);
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn translate(&self, audio_path: &str, options: TranslateOptions) -> Result<TranscriptionResponse> {
        let audio_content = tokio::fs::read(audio_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read audio file: {}", e)))?;

        let filename = Path::new(audio_path).file_name().and_then(|n| n.to_str()).unwrap_or("audio.mp3").to_string();

        self.translate_bytes(&audio_content, &filename, options).await
    }

    /// Translates audio from bytes to English text.
    ///
    /// # Arguments
    ///
    /// * `audio_data` - The audio data as bytes
    /// * `filename` - The filename with extension (e.g., "audio.mp3")
    /// * `options` - Translation options
    ///
    /// # Returns
    ///
    /// * `Ok(TranscriptionResponse)` - The translation result
    /// * `Err(OpenAIToolError)` - If the request fails
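    ///
    /// # Example
    ///
    /// A sketch mirroring [`Audio::transcribe_bytes`]; the file name is illustrative.
    ///
    /// ```rust,no_run
    /// use openai_tools::audio::request::{Audio, TranslateOptions};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let audio = Audio::new()?;
    ///
    ///     let audio_data = std::fs::read("speech_de.mp3")?;
    ///     let response = audio.translate_bytes(&audio_data, "speech_de.mp3", TranslateOptions::default()).await?;
    ///     println!("English translation: {}", response.text);
    ///
    ///     Ok(())
    /// }
    /// ```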
    pub async fn translate_bytes(&self, audio_data: &[u8], filename: &str, options: TranslateOptions) -> Result<TranscriptionResponse> {
        let (client, headers) = self.create_client()?;

        // Note: the MIME type is fixed to audio/mpeg regardless of the actual file format.
        let audio_part = Part::bytes(audio_data.to_vec())
            .file_name(filename.to_string())
            .mime_str("audio/mpeg")
            .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;

        let mut form = Form::new().part("file", audio_part);

        // Add model (whisper-1 is the only supported model for translation)
        let model = options.model.unwrap_or(SttModel::Whisper1);
        form = form.text("model", model.as_str().to_string());

        // Add optional parameters
        if let Some(prompt) = options.prompt {
            form = form.text("prompt", prompt);
        }
        if let Some(response_format) = options.response_format {
            form = form.text("response_format", response_format.as_str().to_string());
        }
        if let Some(temperature) = options.temperature {
            form = form.text("temperature", temperature.to_string());
        }

        let url = format!("{}/translations", self.auth.endpoint(AUDIO_PATH));

        let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;

        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        serde_json::from_str::<TranscriptionResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // TtsModel Tests
    // =========================================================================

    #[test]
    fn test_tts_model_as_str() {
        assert_eq!(TtsModel::Tts1.as_str(), "tts-1");
        assert_eq!(TtsModel::Tts1Hd.as_str(), "tts-1-hd");
        assert_eq!(TtsModel::Gpt4oMiniTts.as_str(), "gpt-4o-mini-tts");
    }

    #[test]
    fn test_tts_model_supports_instructions() {
        // Only gpt-4o-mini-tts supports instructions
        assert!(TtsModel::Gpt4oMiniTts.supports_instructions());
        assert!(!TtsModel::Tts1.supports_instructions());
        assert!(!TtsModel::Tts1Hd.supports_instructions());
    }

    #[test]
    fn test_tts_model_default() {
        let model = TtsModel::default();
        assert_eq!(model, TtsModel::Tts1);
    }

    #[test]
    fn test_tts_model_display() {
        assert_eq!(format!("{}", TtsModel::Gpt4oMiniTts), "gpt-4o-mini-tts");
    }

    // =========================================================================
    // Voice Tests
    // =========================================================================

    #[test]
    fn test_voice_as_str_all_voices() {
        assert_eq!(Voice::Alloy.as_str(), "alloy");
        assert_eq!(Voice::Ash.as_str(), "ash");
        assert_eq!(Voice::Ballad.as_str(), "ballad");
        assert_eq!(Voice::Cedar.as_str(), "cedar");
        assert_eq!(Voice::Coral.as_str(), "coral");
        assert_eq!(Voice::Echo.as_str(), "echo");
        assert_eq!(Voice::Fable.as_str(), "fable");
        assert_eq!(Voice::Marin.as_str(), "marin");
        assert_eq!(Voice::Nova.as_str(), "nova");
        assert_eq!(Voice::Onyx.as_str(), "onyx");
        assert_eq!(Voice::Sage.as_str(), "sage");
        assert_eq!(Voice::Shimmer.as_str(), "shimmer");
        assert_eq!(Voice::Verse.as_str(), "verse");
    }

    #[test]
    fn test_voice_new_voices() {
        // Test the newly added voices
        assert_eq!(Voice::Ballad.as_str(), "ballad");
        assert_eq!(Voice::Cedar.as_str(), "cedar");
        assert_eq!(Voice::Marin.as_str(), "marin");
        assert_eq!(Voice::Verse.as_str(), "verse");
    }

    #[test]
    fn test_voice_default() {
        let voice = Voice::default();
        assert_eq!(voice, Voice::Alloy);
    }

    #[test]
    fn test_voice_serialization() {
        let voice = Voice::Coral;
        let json = serde_json::to_string(&voice).unwrap();
        assert_eq!(json, "\"coral\"");

        // Test new voices
        let ballad = Voice::Ballad;
        let json = serde_json::to_string(&ballad).unwrap();
        assert_eq!(json, "\"ballad\"");
    }

    #[test]
    fn test_voice_deserialization() {
        let voice: Voice = serde_json::from_str("\"coral\"").unwrap();
        assert_eq!(voice, Voice::Coral);

        // Test new voices
        let cedar: Voice = serde_json::from_str("\"cedar\"").unwrap();
        assert_eq!(cedar, Voice::Cedar);

        let marin: Voice = serde_json::from_str("\"marin\"").unwrap();
        assert_eq!(marin, Voice::Marin);
    }

    // =========================================================================
    // TtsOptions Tests
    // =========================================================================

    #[test]
    fn test_tts_options_default() {
        let options = TtsOptions::default();
        assert_eq!(options.model, TtsModel::Tts1);
        assert_eq!(options.voice, Voice::Alloy);
        assert_eq!(options.response_format, AudioFormat::Mp3);
        assert!(options.speed.is_none());
        assert!(options.instructions.is_none());
    }

    #[test]
    fn test_tts_options_with_instructions() {
        let options = TtsOptions {
            model: TtsModel::Gpt4oMiniTts,
            voice: Voice::Coral,
            instructions: Some("Speak in a cheerful tone.".to_string()),
            ..Default::default()
        };
        assert_eq!(options.model, TtsModel::Gpt4oMiniTts);
        assert_eq!(options.instructions, Some("Speak in a cheerful tone.".to_string()));
    }

    // =========================================================================
    // TtsRequest Tests
    // =========================================================================

    #[test]
    fn test_tts_request_serialization_with_instructions() {
        let request = TtsRequest {
            model: "gpt-4o-mini-tts".to_string(),
            input: "Hello, world!".to_string(),
            voice: "coral".to_string(),
            response_format: Some("mp3".to_string()),
            speed: None,
            instructions: Some("Speak cheerfully.".to_string()),
        };
        let json = serde_json::to_value(&request).unwrap();

        assert_eq!(json["model"], "gpt-4o-mini-tts");
        assert_eq!(json["input"], "Hello, world!");
        assert_eq!(json["voice"], "coral");
        assert_eq!(json["response_format"], "mp3");
        assert_eq!(json["instructions"], "Speak cheerfully.");
        assert!(json.get("speed").is_none());
    }

    #[test]
    fn test_tts_request_serialization_without_instructions() {
        let request = TtsRequest {
            model: "tts-1".to_string(),
            input: "Hello".to_string(),
            voice: "alloy".to_string(),
            response_format: Some("mp3".to_string()),
            speed: Some(1.0),
            instructions: None,
        };
        let json = serde_json::to_value(&request).unwrap();

        assert_eq!(json["model"], "tts-1");
        assert_eq!(json["speed"], 1.0);
        // instructions should be omitted when None
        assert!(json.get("instructions").is_none());
    }

    #[test]
    fn test_tts_request_skip_serializing_none_fields() {
        let request = TtsRequest {
            model: "tts-1".to_string(),
            input: "Test".to_string(),
            voice: "echo".to_string(),
            response_format: None,
            speed: None,
            instructions: None,
        };
        let json = serde_json::to_value(&request).unwrap();

        // Required fields are present
        assert!(json.get("model").is_some());
        assert!(json.get("input").is_some());
        assert!(json.get("voice").is_some());

        // Optional fields with None are omitted
        assert!(json.get("response_format").is_none());
        assert!(json.get("speed").is_none());
        assert!(json.get("instructions").is_none());
    }

    // =========================================================================
    // AudioFormat Tests
    // =========================================================================

    #[test]
    fn test_audio_format_as_str() {
        assert_eq!(AudioFormat::Mp3.as_str(), "mp3");
        assert_eq!(AudioFormat::Opus.as_str(), "opus");
        assert_eq!(AudioFormat::Aac.as_str(), "aac");
        assert_eq!(AudioFormat::Flac.as_str(), "flac");
        assert_eq!(AudioFormat::Wav.as_str(), "wav");
        assert_eq!(AudioFormat::Pcm.as_str(), "pcm");
    }

    #[test]
    fn test_audio_format_file_extension() {
        assert_eq!(AudioFormat::Mp3.file_extension(), "mp3");
        assert_eq!(AudioFormat::Wav.file_extension(), "wav");
    }

    // =========================================================================
    // SttModel Tests
    // =========================================================================

    #[test]
    fn test_stt_model_as_str() {
        assert_eq!(SttModel::Whisper1.as_str(), "whisper-1");
        assert_eq!(SttModel::Gpt4oTranscribe.as_str(), "gpt-4o-transcribe");
    }

    // =========================================================================
    // TranscriptionFormat Tests
    // =========================================================================

    #[test]
    fn test_transcription_format_as_str() {
        assert_eq!(TranscriptionFormat::Json.as_str(), "json");
        assert_eq!(TranscriptionFormat::Text.as_str(), "text");
        assert_eq!(TranscriptionFormat::Srt.as_str(), "srt");
        assert_eq!(TranscriptionFormat::VerboseJson.as_str(), "verbose_json");
        assert_eq!(TranscriptionFormat::Vtt.as_str(), "vtt");
    }

    // =========================================================================
    // TimestampGranularity Tests
    // =========================================================================

    #[test]
    fn test_timestamp_granularity_as_str() {
        assert_eq!(TimestampGranularity::Word.as_str(), "word");
        assert_eq!(TimestampGranularity::Segment.as_str(), "segment");
    }
}
995}