openai_ergonomic/builders/
audio.rs1use std::path::{Path, PathBuf};
9
10use openai_client_base::models::transcription_chunking_strategy::TranscriptionChunkingStrategyTextVariantEnum;
11use openai_client_base::models::{
12 create_speech_request::{
13 ResponseFormat as SpeechResponseFormat, StreamFormat as SpeechStreamFormat,
14 },
15 AudioResponseFormat, CreateSpeechRequest, TranscriptionChunkingStrategy, TranscriptionInclude,
16 VadConfig,
17};
18
19use crate::{Builder, Error, Result};
20
21#[derive(Debug, Clone)]
23pub struct SpeechBuilder {
24 model: String,
25 input: String,
26 voice: String,
27 instructions: Option<String>,
28 response_format: Option<SpeechResponseFormat>,
29 speed: Option<f64>,
30 stream_format: Option<SpeechStreamFormat>,
31}
32
33impl SpeechBuilder {
34 #[must_use]
36 pub fn new(
37 model: impl Into<String>,
38 input: impl Into<String>,
39 voice: impl Into<String>,
40 ) -> Self {
41 Self {
42 model: model.into(),
43 input: input.into(),
44 voice: voice.into(),
45 instructions: None,
46 response_format: None,
47 speed: None,
48 stream_format: None,
49 }
50 }
51
52 #[must_use]
54 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
55 self.instructions = Some(instructions.into());
56 self
57 }
58
59 #[must_use]
61 pub fn response_format(mut self, format: SpeechResponseFormat) -> Self {
62 self.response_format = Some(format);
63 self
64 }
65
66 #[must_use]
68 pub fn speed(mut self, speed: f64) -> Self {
69 self.speed = Some(speed);
70 self
71 }
72
73 #[must_use]
75 pub fn stream_format(mut self, format: SpeechStreamFormat) -> Self {
76 self.stream_format = Some(format);
77 self
78 }
79
80 #[must_use]
82 pub fn model(&self) -> &str {
83 &self.model
84 }
85
86 #[must_use]
88 pub fn input(&self) -> &str {
89 &self.input
90 }
91
92 #[must_use]
94 pub fn voice(&self) -> &str {
95 &self.voice
96 }
97
98 #[must_use]
100 pub fn response_format_ref(&self) -> Option<SpeechResponseFormat> {
101 self.response_format
102 }
103
104 #[must_use]
106 pub fn stream_format_ref(&self) -> Option<SpeechStreamFormat> {
107 self.stream_format
108 }
109}
110
111impl Builder<CreateSpeechRequest> for SpeechBuilder {
112 fn build(self) -> Result<CreateSpeechRequest> {
113 if let Some(speed) = self.speed {
114 if !(0.25..=4.0).contains(&speed) {
115 return Err(Error::InvalidRequest(format!(
116 "Speech speed {speed} is outside the supported range 0.25–4.0"
117 )));
118 }
119 }
120
121 Ok(CreateSpeechRequest {
122 model: self.model,
123 input: self.input,
124 instructions: self.instructions,
125 voice: self.voice,
126 response_format: self.response_format,
127 speed: self.speed,
128 stream_format: self.stream_format,
129 })
130 }
131}
132
/// Granularity of timestamps requested for a transcription.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum TimestampGranularity {
    /// Segment-level timestamps.
    Segment,
    /// Word-level timestamps.
    Word,
}

impl TimestampGranularity {
    /// Returns the lowercase string form of this granularity
    /// (`"segment"` or `"word"`).
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            TimestampGranularity::Word => "word",
            TimestampGranularity::Segment => "segment",
        }
    }
}
150
151#[derive(Debug, Clone)]
153pub struct TranscriptionBuilder {
154 file: PathBuf,
155 model: String,
156 language: Option<String>,
157 prompt: Option<String>,
158 response_format: Option<AudioResponseFormat>,
159 temperature: Option<f64>,
160 stream: Option<bool>,
161 chunking_strategy: Option<TranscriptionChunkingStrategy>,
162 timestamp_granularities: Vec<TimestampGranularity>,
163 include: Vec<TranscriptionInclude>,
164}
165
166impl TranscriptionBuilder {
167 #[must_use]
169 pub fn new(file: impl AsRef<Path>, model: impl Into<String>) -> Self {
170 Self {
171 file: file.as_ref().to_path_buf(),
172 model: model.into(),
173 language: None,
174 prompt: None,
175 response_format: None,
176 temperature: None,
177 stream: None,
178 chunking_strategy: None,
179 timestamp_granularities: Vec::new(),
180 include: Vec::new(),
181 }
182 }
183
184 #[must_use]
186 pub fn language(mut self, language: impl Into<String>) -> Self {
187 self.language = Some(language.into());
188 self
189 }
190
191 #[must_use]
193 pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
194 self.prompt = Some(prompt.into());
195 self
196 }
197
198 #[must_use]
200 pub fn response_format(mut self, format: AudioResponseFormat) -> Self {
201 self.response_format = Some(format);
202 self
203 }
204
205 #[must_use]
207 pub fn temperature(mut self, temperature: f64) -> Self {
208 self.temperature = Some(temperature);
209 self
210 }
211
212 #[must_use]
214 pub fn stream(mut self, stream: bool) -> Self {
215 self.stream = Some(stream);
216 self
217 }
218
219 #[must_use]
221 pub fn chunking_strategy_auto(mut self) -> Self {
222 self.chunking_strategy = Some(TranscriptionChunkingStrategy::TextVariant(
223 TranscriptionChunkingStrategyTextVariantEnum::Auto,
224 ));
225 self
226 }
227
228 #[must_use]
230 pub fn chunking_strategy_vad(mut self, config: VadConfig) -> Self {
231 self.chunking_strategy = Some(TranscriptionChunkingStrategy::Vadconfig(config));
232 self
233 }
234
235 #[must_use]
237 pub fn clear_chunking_strategy(mut self) -> Self {
238 self.chunking_strategy = None;
239 self
240 }
241
242 #[must_use]
244 pub fn timestamp_granularities(
245 mut self,
246 granularities: impl IntoIterator<Item = TimestampGranularity>,
247 ) -> Self {
248 self.timestamp_granularities = granularities.into_iter().collect();
249 self
250 }
251
252 #[must_use]
254 pub fn add_timestamp_granularity(mut self, granularity: TimestampGranularity) -> Self {
255 if !self.timestamp_granularities.contains(&granularity) {
256 self.timestamp_granularities.push(granularity);
257 }
258 self
259 }
260
261 #[must_use]
263 pub fn include(mut self, include: TranscriptionInclude) -> Self {
264 if !self.include.contains(&include) {
265 self.include.push(include);
266 }
267 self
268 }
269
270 #[must_use]
272 pub fn file(&self) -> &Path {
273 &self.file
274 }
275
276 #[must_use]
278 pub fn model(&self) -> &str {
279 &self.model
280 }
281
282 #[must_use]
284 pub fn language_ref(&self) -> Option<&str> {
285 self.language.as_deref()
286 }
287
288 #[must_use]
290 pub fn response_format_ref(&self) -> Option<AudioResponseFormat> {
291 self.response_format
292 }
293}
294
295#[derive(Debug, Clone)]
297pub struct TranscriptionRequest {
298 pub file: PathBuf,
300 pub model: String,
302 pub language: Option<String>,
304 pub prompt: Option<String>,
306 pub response_format: Option<AudioResponseFormat>,
308 pub temperature: Option<f64>,
310 pub stream: Option<bool>,
312 pub chunking_strategy: Option<TranscriptionChunkingStrategy>,
314 pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
316 pub include: Option<Vec<TranscriptionInclude>>,
318}
319
320impl Builder<TranscriptionRequest> for TranscriptionBuilder {
321 fn build(self) -> Result<TranscriptionRequest> {
322 if let Some(temperature) = self.temperature {
323 if !(0.0..=1.0).contains(&temperature) {
324 return Err(Error::InvalidRequest(format!(
325 "Transcription temperature {temperature} is outside the supported range 0.0–1.0"
326 )));
327 }
328 }
329
330 Ok(TranscriptionRequest {
331 file: self.file,
332 model: self.model,
333 language: self.language,
334 prompt: self.prompt,
335 response_format: self.response_format,
336 temperature: self.temperature,
337 stream: self.stream,
338 chunking_strategy: self.chunking_strategy,
339 timestamp_granularities: if self.timestamp_granularities.is_empty() {
340 None
341 } else {
342 Some(self.timestamp_granularities)
343 },
344 include: if self.include.is_empty() {
345 None
346 } else {
347 Some(self.include)
348 },
349 })
350 }
351}
352
353#[derive(Debug, Clone)]
355pub struct TranslationBuilder {
356 file: PathBuf,
357 model: String,
358 prompt: Option<String>,
359 response_format: Option<AudioResponseFormat>,
360 temperature: Option<f64>,
361}
362
363impl TranslationBuilder {
364 #[must_use]
366 pub fn new(file: impl AsRef<Path>, model: impl Into<String>) -> Self {
367 Self {
368 file: file.as_ref().to_path_buf(),
369 model: model.into(),
370 prompt: None,
371 response_format: None,
372 temperature: None,
373 }
374 }
375
376 #[must_use]
378 pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
379 self.prompt = Some(prompt.into());
380 self
381 }
382
383 #[must_use]
385 pub fn response_format(mut self, format: AudioResponseFormat) -> Self {
386 self.response_format = Some(format);
387 self
388 }
389
390 #[must_use]
392 pub fn temperature(mut self, temperature: f64) -> Self {
393 self.temperature = Some(temperature);
394 self
395 }
396
397 #[must_use]
399 pub fn model(&self) -> &str {
400 &self.model
401 }
402
403 #[must_use]
405 pub fn file(&self) -> &Path {
406 &self.file
407 }
408}
409
410#[derive(Debug, Clone)]
412pub struct TranslationRequest {
413 pub file: PathBuf,
415 pub model: String,
417 pub prompt: Option<String>,
419 pub response_format: Option<AudioResponseFormat>,
421 pub temperature: Option<f64>,
423}
424
425impl Builder<TranslationRequest> for TranslationBuilder {
426 fn build(self) -> Result<TranslationRequest> {
427 if let Some(temperature) = self.temperature {
428 if !(0.0..=1.0).contains(&temperature) {
429 return Err(Error::InvalidRequest(format!(
430 "Translation temperature {temperature} is outside the supported range 0.0–1.0"
431 )));
432 }
433 }
434
435 Ok(TranslationRequest {
436 file: self.file,
437 model: self.model,
438 prompt: self.prompt,
439 response_format: self.response_format,
440 temperature: self.temperature,
441 })
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builds_speech_request() {
        let request = SpeechBuilder::new("gpt-4o-mini-tts", "Hello world", "alloy")
            .instructions("Speak calmly")
            .response_format(SpeechResponseFormat::Wav)
            .speed(1.25)
            .stream_format(SpeechStreamFormat::Audio)
            .build()
            .expect("valid speech builder");

        assert_eq!(request.model, "gpt-4o-mini-tts");
        assert_eq!(request.input, "Hello world");
        assert_eq!(request.voice, "alloy");
        assert_eq!(request.response_format, Some(SpeechResponseFormat::Wav));
        assert_eq!(request.speed, Some(1.25));
        assert_eq!(request.stream_format, Some(SpeechStreamFormat::Audio));
    }

    #[test]
    fn speech_speed_validation() {
        let err = SpeechBuilder::new("gpt-4o-mini-tts", "Hi", "alloy")
            .speed(5.0)
            .build()
            .expect_err("speed outside supported range");
        assert!(matches!(err, Error::InvalidRequest(_)));
    }

    /// The speed range is inclusive at both ends.
    #[test]
    fn speech_speed_bounds_are_inclusive() {
        for speed in [0.25, 4.0] {
            let request = SpeechBuilder::new("gpt-4o-mini-tts", "Hi", "alloy")
                .speed(speed)
                .build()
                .expect("boundary speed is accepted");
            assert_eq!(request.speed, Some(speed));
        }
    }

    /// NaN never satisfies the range check, so it must be rejected.
    #[test]
    fn speech_speed_rejects_nan() {
        let err = SpeechBuilder::new("gpt-4o-mini-tts", "Hi", "alloy")
            .speed(f64::NAN)
            .build()
            .expect_err("NaN speed is rejected");
        assert!(matches!(err, Error::InvalidRequest(_)));
    }

    #[test]
    fn builds_transcription_request() {
        let request = TranscriptionBuilder::new("audio.wav", "gpt-4o-mini-transcribe")
            .language("en")
            .prompt("Friendly tone")
            .response_format(AudioResponseFormat::VerboseJson)
            .temperature(0.2)
            .stream(true)
            .chunking_strategy_auto()
            .timestamp_granularities([TimestampGranularity::Segment, TimestampGranularity::Word])
            .include(TranscriptionInclude::Logprobs)
            .build()
            .expect("valid transcription builder");

        assert_eq!(request.model, "gpt-4o-mini-transcribe");
        assert_eq!(request.language.as_deref(), Some("en"));
        assert!(matches!(
            request.timestamp_granularities,
            Some(grans) if grans.contains(&TimestampGranularity::Word)
        ));
        assert!(matches!(
            request.chunking_strategy,
            Some(TranscriptionChunkingStrategy::TextVariant(_))
        ));
        assert!(matches!(
            request.include,
            Some(values) if values.contains(&TranscriptionInclude::Logprobs)
        ));
    }

    /// A builder with no optional settings yields `None` for every optional field,
    /// including the collections (empty lists are not sent as `Some(vec![])`).
    #[test]
    fn minimal_transcription_has_no_optional_fields() {
        let request = TranscriptionBuilder::new("audio.wav", "gpt-4o-mini-transcribe")
            .build()
            .expect("minimal transcription builder");

        assert_eq!(request.file, PathBuf::from("audio.wav"));
        assert!(request.language.is_none());
        assert!(request.temperature.is_none());
        assert!(request.chunking_strategy.is_none());
        assert!(request.timestamp_granularities.is_none());
        assert!(request.include.is_none());
    }

    #[test]
    fn incremental_additions_are_deduplicated() {
        let request = TranscriptionBuilder::new("audio.wav", "gpt-4o-mini-transcribe")
            .add_timestamp_granularity(TimestampGranularity::Word)
            .add_timestamp_granularity(TimestampGranularity::Word)
            .include(TranscriptionInclude::Logprobs)
            .include(TranscriptionInclude::Logprobs)
            .build()
            .expect("valid transcription builder");

        assert_eq!(
            request.timestamp_granularities,
            Some(vec![TimestampGranularity::Word])
        );
        assert_eq!(request.include.map(|values| values.len()), Some(1));
    }

    /// `timestamp_granularities` replaces, rather than extends, earlier values.
    #[test]
    fn timestamp_granularities_replaces_previous_list() {
        let request = TranscriptionBuilder::new("audio.wav", "gpt-4o-mini-transcribe")
            .add_timestamp_granularity(TimestampGranularity::Word)
            .timestamp_granularities([TimestampGranularity::Segment])
            .build()
            .expect("valid transcription builder");

        assert_eq!(
            request.timestamp_granularities,
            Some(vec![TimestampGranularity::Segment])
        );
    }

    #[test]
    fn clear_chunking_strategy_removes_previous_choice() {
        let request = TranscriptionBuilder::new("audio.wav", "gpt-4o-mini-transcribe")
            .chunking_strategy_auto()
            .clear_chunking_strategy()
            .build()
            .expect("valid transcription builder");

        assert!(request.chunking_strategy.is_none());
    }

    #[test]
    fn transcription_temperature_validation() {
        let err = TranscriptionBuilder::new("audio.wav", "gpt-4o-mini-transcribe")
            .temperature(1.2)
            .build()
            .expect_err("temperature outside range");
        assert!(matches!(err, Error::InvalidRequest(_)));
    }

    /// The temperature range is inclusive at both ends.
    #[test]
    fn transcription_temperature_bounds_are_inclusive() {
        for temperature in [0.0, 1.0] {
            let request = TranscriptionBuilder::new("audio.wav", "gpt-4o-mini-transcribe")
                .temperature(temperature)
                .build()
                .expect("boundary temperature is accepted");
            assert_eq!(request.temperature, Some(temperature));
        }
    }

    #[test]
    fn builds_translation_request() {
        let request = TranslationBuilder::new("clip.mp3", "gpt-4o-mini-transcribe")
            .prompt("Keep humour")
            .response_format(AudioResponseFormat::Text)
            .temperature(0.3)
            .build()
            .expect("valid translation builder");

        assert_eq!(request.model, "gpt-4o-mini-transcribe");
        assert_eq!(request.response_format, Some(AudioResponseFormat::Text));
    }

    #[test]
    fn minimal_translation_has_no_optional_fields() {
        let request = TranslationBuilder::new("clip.mp3", "whisper-1")
            .build()
            .expect("minimal translation builder");

        assert_eq!(request.file, PathBuf::from("clip.mp3"));
        assert!(request.prompt.is_none());
        assert!(request.temperature.is_none());
    }

    #[test]
    fn translation_temperature_validation() {
        let err = TranslationBuilder::new("clip.mp3", "gpt-4o-mini-transcribe")
            .temperature(1.5)
            .build()
            .expect_err("temperature outside range");
        assert!(matches!(err, Error::InvalidRequest(_)));
    }
}