// openai_ergonomic/builders/audio.rs

1//! Audio API builders.
2//!
3//! This module provides ergonomic builders for working with the `OpenAI` audio
4//! endpoints, covering text-to-speech, transcription, and translation
5//! workflows. Builders perform lightweight validation and produce values that
6//! can be passed directly to `openai-client-base` request functions.
7
8use std::path::{Path, PathBuf};
9
10use openai_client_base::models::transcription_chunking_strategy::TranscriptionChunkingStrategyTextVariantEnum;
11use openai_client_base::models::{
12    create_speech_request::{
13        ResponseFormat as SpeechResponseFormat, StreamFormat as SpeechStreamFormat,
14    },
15    AudioResponseFormat, CreateSpeechRequest, TranscriptionChunkingStrategy, TranscriptionInclude,
16    VadConfig,
17};
18
19use crate::{Builder, Error, Result};
20
/// Builder for text-to-speech requests.
///
/// Collects the required model, input text, and voice plus optional tuning
/// options; range validation (e.g. `speed`) happens when the request is built.
#[derive(Debug, Clone)]
pub struct SpeechBuilder {
    // Required model identifier.
    model: String,
    // Required text to synthesize.
    input: String,
    // Required voice identifier.
    voice: String,
    // Optional extra voice instructions; ignored by legacy TTS models.
    instructions: Option<String>,
    // Optional output format; the API defaults to `mp3` when unset.
    response_format: Option<SpeechResponseFormat>,
    // Optional playback speed; validated to 0.25–4.0 at build time.
    speed: Option<f64>,
    // Optional streaming output format.
    stream_format: Option<SpeechStreamFormat>,
}
32
33impl SpeechBuilder {
34    /// Create a new speech builder with the required model, input text, and voice.
35    #[must_use]
36    pub fn new(
37        model: impl Into<String>,
38        input: impl Into<String>,
39        voice: impl Into<String>,
40    ) -> Self {
41        Self {
42            model: model.into(),
43            input: input.into(),
44            voice: voice.into(),
45            instructions: None,
46            response_format: None,
47            speed: None,
48            stream_format: None,
49        }
50    }
51
52    /// Add additional voice instructions (ignored for legacy TTS models).
53    #[must_use]
54    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
55        self.instructions = Some(instructions.into());
56        self
57    }
58
59    /// Choose the audio response format (default is `mp3`).
60    #[must_use]
61    pub fn response_format(mut self, format: SpeechResponseFormat) -> Self {
62        self.response_format = Some(format);
63        self
64    }
65
66    /// Set the playback speed. Must be between `0.25` and `4.0` inclusive.
67    #[must_use]
68    pub fn speed(mut self, speed: f64) -> Self {
69        self.speed = Some(speed);
70        self
71    }
72
73    /// Configure streaming output. Unsupported for some legacy models.
74    #[must_use]
75    pub fn stream_format(mut self, format: SpeechStreamFormat) -> Self {
76        self.stream_format = Some(format);
77        self
78    }
79
80    /// Access the configured model.
81    #[must_use]
82    pub fn model(&self) -> &str {
83        &self.model
84    }
85
86    /// Access the configured input text.
87    #[must_use]
88    pub fn input(&self) -> &str {
89        &self.input
90    }
91
92    /// Access the configured voice.
93    #[must_use]
94    pub fn voice(&self) -> &str {
95        &self.voice
96    }
97
98    /// Access the configured response format, if any.
99    #[must_use]
100    pub fn response_format_ref(&self) -> Option<SpeechResponseFormat> {
101        self.response_format
102    }
103
104    /// Access the configured stream format, if any.
105    #[must_use]
106    pub fn stream_format_ref(&self) -> Option<SpeechStreamFormat> {
107        self.stream_format
108    }
109}
110
111impl Builder<CreateSpeechRequest> for SpeechBuilder {
112    fn build(self) -> Result<CreateSpeechRequest> {
113        if let Some(speed) = self.speed {
114            if !(0.25..=4.0).contains(&speed) {
115                return Err(Error::InvalidRequest(format!(
116                    "Speech speed {speed} is outside the supported range 0.25–4.0"
117                )));
118            }
119        }
120
121        Ok(CreateSpeechRequest {
122            model: self.model,
123            input: self.input,
124            instructions: self.instructions,
125            voice: self.voice,
126            response_format: self.response_format,
127            speed: self.speed,
128            stream_format: self.stream_format,
129        })
130    }
131}
132
/// Granularity options for transcription timestamps.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum TimestampGranularity {
    /// Timestamps attached to whole segments.
    Segment,
    /// Timestamps attached to individual words (where supported).
    Word,
}

impl TimestampGranularity {
    /// Wire-format string for this granularity.
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            Self::Word => "word",
            Self::Segment => "segment",
        }
    }
}
150
/// Builder for audio transcription requests.
///
/// Collects the required file path and model plus optional hints; temperature
/// is range-validated when the request is built.
#[derive(Debug, Clone)]
pub struct TranscriptionBuilder {
    // Path of the audio file to upload.
    file: PathBuf,
    // Model identifier to transcribe with.
    model: String,
    // Optional input-language hint for better accuracy.
    language: Option<String>,
    // Optional prompt guiding transcription style.
    prompt: Option<String>,
    // Optional response format (`json`, `text`, `srt`, `verbose_json`, `vtt`).
    response_format: Option<AudioResponseFormat>,
    // Optional randomness control; validated to 0.0–1.0 at build time.
    temperature: Option<f64>,
    // Optional server-side streaming toggle for partial results.
    stream: Option<bool>,
    // Optional chunking strategy (automatic or VAD-based).
    chunking_strategy: Option<TranscriptionChunkingStrategy>,
    // Requested timestamp granularities; an empty vec becomes `None` at build.
    timestamp_granularities: Vec<TimestampGranularity>,
    // Extra metadata to include; an empty vec becomes `None` at build.
    include: Vec<TranscriptionInclude>,
}
165
166impl TranscriptionBuilder {
167    /// Create a new transcription builder for the given audio file and model.
168    #[must_use]
169    pub fn new(file: impl AsRef<Path>, model: impl Into<String>) -> Self {
170        Self {
171            file: file.as_ref().to_path_buf(),
172            model: model.into(),
173            language: None,
174            prompt: None,
175            response_format: None,
176            temperature: None,
177            stream: None,
178            chunking_strategy: None,
179            timestamp_granularities: Vec::new(),
180            include: Vec::new(),
181        }
182    }
183
184    /// Provide the input language to improve accuracy.
185    #[must_use]
186    pub fn language(mut self, language: impl Into<String>) -> Self {
187        self.language = Some(language.into());
188        self
189    }
190
191    /// Supply a prompt to guide the transcription style.
192    #[must_use]
193    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
194        self.prompt = Some(prompt.into());
195        self
196    }
197
198    /// Set the desired response format (`json`, `text`, `srt`, `verbose_json`, `vtt`).
199    #[must_use]
200    pub fn response_format(mut self, format: AudioResponseFormat) -> Self {
201        self.response_format = Some(format);
202        self
203    }
204
205    /// Control randomness (0.0–1.0). `0.0` yields deterministic output.
206    #[must_use]
207    pub fn temperature(mut self, temperature: f64) -> Self {
208        self.temperature = Some(temperature);
209        self
210    }
211
212    /// Enable or disable server-side streaming for partial results.
213    #[must_use]
214    pub fn stream(mut self, stream: bool) -> Self {
215        self.stream = Some(stream);
216        self
217    }
218
219    /// Use the default automatic chunking strategy.
220    #[must_use]
221    pub fn chunking_strategy_auto(mut self) -> Self {
222        self.chunking_strategy = Some(TranscriptionChunkingStrategy::TextVariant(
223            TranscriptionChunkingStrategyTextVariantEnum::Auto,
224        ));
225        self
226    }
227
228    /// Provide a custom VAD configuration for chunking.
229    #[must_use]
230    pub fn chunking_strategy_vad(mut self, config: VadConfig) -> Self {
231        self.chunking_strategy = Some(TranscriptionChunkingStrategy::Vadconfig(config));
232        self
233    }
234
235    /// Disable chunking hints and fall back to API defaults.
236    #[must_use]
237    pub fn clear_chunking_strategy(mut self) -> Self {
238        self.chunking_strategy = None;
239        self
240    }
241
242    /// Request specific timestamp granularities.
243    #[must_use]
244    pub fn timestamp_granularities(
245        mut self,
246        granularities: impl IntoIterator<Item = TimestampGranularity>,
247    ) -> Self {
248        self.timestamp_granularities = granularities.into_iter().collect();
249        self
250    }
251
252    /// Append a timestamp granularity option.
253    #[must_use]
254    pub fn add_timestamp_granularity(mut self, granularity: TimestampGranularity) -> Self {
255        if !self.timestamp_granularities.contains(&granularity) {
256            self.timestamp_granularities.push(granularity);
257        }
258        self
259    }
260
261    /// Include additional metadata (e.g., log probabilities).
262    #[must_use]
263    pub fn include(mut self, include: TranscriptionInclude) -> Self {
264        if !self.include.contains(&include) {
265            self.include.push(include);
266        }
267        self
268    }
269
270    /// Access the source file path.
271    #[must_use]
272    pub fn file(&self) -> &Path {
273        &self.file
274    }
275
276    /// Access the target model.
277    #[must_use]
278    pub fn model(&self) -> &str {
279        &self.model
280    }
281
282    /// Access the selected language.
283    #[must_use]
284    pub fn language_ref(&self) -> Option<&str> {
285        self.language.as_deref()
286    }
287
288    /// Access the selected response format.
289    #[must_use]
290    pub fn response_format_ref(&self) -> Option<AudioResponseFormat> {
291        self.response_format
292    }
293}
294
/// Fully prepared transcription request data.
///
/// Produced by [`TranscriptionBuilder::build`]; empty granularity/include
/// lists are normalised to `None` there.
#[derive(Debug, Clone)]
pub struct TranscriptionRequest {
    /// Audio file to upload for transcription.
    pub file: PathBuf,
    /// Model identifier to use (e.g., `gpt-4o-mini-transcribe`).
    pub model: String,
    /// Optional language hint.
    pub language: Option<String>,
    /// Optional style/context prompt.
    pub prompt: Option<String>,
    /// Desired response format.
    pub response_format: Option<AudioResponseFormat>,
    /// Randomness control (0.0–1.0).
    pub temperature: Option<f64>,
    /// Enable partial streaming responses.
    pub stream: Option<bool>,
    /// Chunking strategy configuration.
    pub chunking_strategy: Option<TranscriptionChunkingStrategy>,
    /// Requested timestamp granularities (`None` when none were requested).
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
    /// Additional metadata to include in the response (`None` when empty).
    pub include: Option<Vec<TranscriptionInclude>>,
}
319
320impl Builder<TranscriptionRequest> for TranscriptionBuilder {
321    fn build(self) -> Result<TranscriptionRequest> {
322        if let Some(temperature) = self.temperature {
323            if !(0.0..=1.0).contains(&temperature) {
324                return Err(Error::InvalidRequest(format!(
325                    "Transcription temperature {temperature} is outside the supported range 0.0–1.0"
326                )));
327            }
328        }
329
330        Ok(TranscriptionRequest {
331            file: self.file,
332            model: self.model,
333            language: self.language,
334            prompt: self.prompt,
335            response_format: self.response_format,
336            temperature: self.temperature,
337            stream: self.stream,
338            chunking_strategy: self.chunking_strategy,
339            timestamp_granularities: if self.timestamp_granularities.is_empty() {
340                None
341            } else {
342                Some(self.timestamp_granularities)
343            },
344            include: if self.include.is_empty() {
345                None
346            } else {
347                Some(self.include)
348            },
349        })
350    }
351}
352
/// Builder for audio translation (audio → English text).
#[derive(Debug, Clone)]
pub struct TranslationBuilder {
    // Path of the audio file to translate.
    file: PathBuf,
    // Model identifier to translate with.
    model: String,
    // Optional prompt guiding translation tone.
    prompt: Option<String>,
    // Optional output format; the API defaults to JSON when unset.
    response_format: Option<AudioResponseFormat>,
    // Optional randomness control; validated to 0.0–1.0 at build time.
    temperature: Option<f64>,
}
362
363impl TranslationBuilder {
364    /// Create a new translation builder for the given audio file and model.
365    #[must_use]
366    pub fn new(file: impl AsRef<Path>, model: impl Into<String>) -> Self {
367        Self {
368            file: file.as_ref().to_path_buf(),
369            model: model.into(),
370            prompt: None,
371            response_format: None,
372            temperature: None,
373        }
374    }
375
376    /// Provide an optional prompt to guide translation tone.
377    #[must_use]
378    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
379        self.prompt = Some(prompt.into());
380        self
381    }
382
383    /// Select the output format (defaults to JSON).
384    #[must_use]
385    pub fn response_format(mut self, format: AudioResponseFormat) -> Self {
386        self.response_format = Some(format);
387        self
388    }
389
390    /// Control randomness (0.0–1.0).
391    #[must_use]
392    pub fn temperature(mut self, temperature: f64) -> Self {
393        self.temperature = Some(temperature);
394        self
395    }
396
397    /// Access the configured model.
398    #[must_use]
399    pub fn model(&self) -> &str {
400        &self.model
401    }
402
403    /// Access the configured file path.
404    #[must_use]
405    pub fn file(&self) -> &Path {
406        &self.file
407    }
408}
409
/// Fully prepared translation request data.
///
/// Produced by [`TranslationBuilder::build`].
#[derive(Debug, Clone)]
pub struct TranslationRequest {
    /// Audio file to translate.
    pub file: PathBuf,
    /// Model to use for translation.
    pub model: String,
    /// Optional prompt for style control.
    pub prompt: Option<String>,
    /// Desired output format.
    pub response_format: Option<AudioResponseFormat>,
    /// Randomness control (0.0–1.0).
    pub temperature: Option<f64>,
}
424
425impl Builder<TranslationRequest> for TranslationBuilder {
426    fn build(self) -> Result<TranslationRequest> {
427        if let Some(temperature) = self.temperature {
428            if !(0.0..=1.0).contains(&temperature) {
429                return Err(Error::InvalidRequest(format!(
430                    "Translation temperature {temperature} is outside the supported range 0.0–1.0"
431                )));
432            }
433        }
434
435        Ok(TranslationRequest {
436            file: self.file,
437            model: self.model,
438            prompt: self.prompt,
439            response_format: self.response_format,
440            temperature: self.temperature,
441        })
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: every speech option round-trips into the built request.
    #[test]
    fn builds_speech_request() {
        let request = SpeechBuilder::new("gpt-4o-mini-tts", "Hello world", "alloy")
            .instructions("Speak calmly")
            .response_format(SpeechResponseFormat::Wav)
            .speed(1.25)
            .stream_format(SpeechStreamFormat::Audio)
            .build()
            .expect("valid speech builder");

        assert_eq!(request.model, "gpt-4o-mini-tts");
        assert_eq!(request.input, "Hello world");
        assert_eq!(request.voice, "alloy");
        assert_eq!(request.response_format, Some(SpeechResponseFormat::Wav));
        assert_eq!(request.speed, Some(1.25));
        assert_eq!(request.stream_format, Some(SpeechStreamFormat::Audio));
    }

    // A speed above 4.0 must be rejected at build time.
    #[test]
    fn speech_speed_validation() {
        let err = SpeechBuilder::new("gpt-4o-mini-tts", "Hi", "alloy")
            .speed(5.0)
            .build()
            .expect_err("speed outside supported range");
        assert!(matches!(err, Error::InvalidRequest(_)));
    }

    // Happy path: transcription options, including chunking, granularities,
    // and includes, all survive into the built request.
    #[test]
    fn builds_transcription_request() {
        let request = TranscriptionBuilder::new("audio.wav", "gpt-4o-mini-transcribe")
            .language("en")
            .prompt("Friendly tone")
            .response_format(AudioResponseFormat::VerboseJson)
            .temperature(0.2)
            .stream(true)
            .chunking_strategy_auto()
            .timestamp_granularities([TimestampGranularity::Segment, TimestampGranularity::Word])
            .include(TranscriptionInclude::Logprobs)
            .build()
            .expect("valid transcription builder");

        assert_eq!(request.model, "gpt-4o-mini-transcribe");
        assert_eq!(request.language.as_deref(), Some("en"));
        assert!(matches!(
            request.timestamp_granularities,
            Some(grans) if grans.contains(&TimestampGranularity::Word)
        ));
        assert!(matches!(
            request.chunking_strategy,
            Some(TranscriptionChunkingStrategy::TextVariant(_))
        ));
        assert!(matches!(
            request.include,
            Some(values) if values.contains(&TranscriptionInclude::Logprobs)
        ));
    }

    // A transcription temperature above 1.0 must be rejected at build time.
    #[test]
    fn transcription_temperature_validation() {
        let err = TranscriptionBuilder::new("audio.wav", "gpt-4o-mini-transcribe")
            .temperature(1.2)
            .build()
            .expect_err("temperature outside range");
        assert!(matches!(err, Error::InvalidRequest(_)));
    }

    // Happy path: translation options round-trip into the built request.
    #[test]
    fn builds_translation_request() {
        let request = TranslationBuilder::new("clip.mp3", "gpt-4o-mini-transcribe")
            .prompt("Keep humour")
            .response_format(AudioResponseFormat::Text)
            .temperature(0.3)
            .build()
            .expect("valid translation builder");

        assert_eq!(request.model, "gpt-4o-mini-transcribe");
        assert_eq!(request.response_format, Some(AudioResponseFormat::Text));
    }

    // A translation temperature above 1.0 must be rejected at build time.
    #[test]
    fn translation_temperature_validation() {
        let err = TranslationBuilder::new("clip.mp3", "gpt-4o-mini-transcribe")
            .temperature(1.5)
            .build()
            .expect_err("temperature outside range");
        assert!(matches!(err, Error::InvalidRequest(_)));
    }
}