use crate::shared::FileUpload;
use bytes::Bytes;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::Display;
/// Request parameters for an audio speech call.
///
/// The type serializes and deserializes through serde. `derive_builder`
/// generates `AudioSpeechParametersBuilder`; its setters accept any value that
/// converts `Into` the field type, take the inner value of `Option` fields
/// directly, and unset fields fall back to the struct's `Default` value.
///
/// Every `Option` field is left out of the serialized output when it is `None`.
#[derive(Serialize, Deserialize, Debug, Default, Builder, Clone, PartialEq)]
#[builder(name = "AudioSpeechParametersBuilder")]
#[builder(setter(into, strip_option), default)]
pub struct AudioSpeechParameters {
/// Model identifier, serialized under `model`.
pub model: String,
/// Input string, serialized under `input`.
/// NOTE(review): presumably the text to synthesize — confirm against the API reference.
pub input: String,
/// Optional voice name, given as a free-form string.
#[serde(skip_serializing_if = "Option::is_none")]
pub voice: Option<String>,
/// Optional `voice_text` string; its semantics are not visible in this file —
/// TODO confirm with the client code that consumes it.
#[serde(skip_serializing_if = "Option::is_none")]
pub voice_text: Option<String>,
/// Optional free-form instructions string.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Optional format requested for the generated audio.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<AudioSpeechResponseFormat>,
/// Optional speed value; the accepted range is not established in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub speed: Option<f32>,
}
/// Request parameters for an audio transcription call.
///
/// The type serializes and deserializes through serde. `derive_builder`
/// generates `AudioTranscriptionParametersBuilder`; its setters accept any
/// value that converts `Into` the field type, take the inner value of `Option`
/// fields directly, and unset fields fall back to the struct's `Default` value.
///
/// Every `Option` field is left out of the serialized output when it is `None`.
#[derive(Serialize, Deserialize, Debug, Default, Builder, Clone, PartialEq)]
#[builder(name = "AudioTranscriptionParametersBuilder")]
#[builder(setter(into, strip_option), default)]
pub struct AudioTranscriptionParameters {
/// File to upload, serialized under `file`.
pub file: FileUpload,
/// Model identifier, serialized under `model`.
pub model: String,
/// Optional language string.
/// NOTE(review): presumably the language of the input audio — confirm the expected code format.
#[serde(skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
/// Optional chunking strategy: either `auto` or an explicit VAD configuration.
#[serde(skip_serializing_if = "Option::is_none")]
pub chunking_strategy: Option<TranscriptionChunkingStrategy>,
/// Optional prompt string.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<String>,
/// Optional format requested for the transcription output.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<AudioOutputFormat>,
/// Optional streaming flag.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Optional temperature value; the accepted range is not established in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Optional list of timestamp granularities to request.
#[serde(skip_serializing_if = "Option::is_none")]
pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
/// Additional entries merged into the top level of the serialized object
/// (serde `flatten`) instead of being nested under an `extra_body` key.
/// NOTE(review): serde only flattens map-like values, so this presumably has
/// to hold a JSON object for serialization to succeed — confirm.
#[serde(flatten)]
#[serde(skip_serializing_if = "Option::is_none")]
pub extra_body: Option<Value>,
}
/// Request parameters for an audio translation call.
///
/// The type serializes and deserializes through serde. `derive_builder`
/// generates `AudioTranslationParametersBuilder`; its setters accept any value
/// that converts `Into` the field type, take the inner value of `Option`
/// fields directly, and unset fields fall back to the struct's `Default` value.
///
/// Every `Option` field is left out of the serialized output when it is `None`.
#[derive(Serialize, Deserialize, Debug, Default, Builder, Clone, PartialEq)]
#[builder(name = "AudioTranslationParametersBuilder")]
#[builder(setter(into, strip_option), default)]
pub struct AudioTranslationParameters {
/// File to upload, serialized under `file`.
pub file: FileUpload,
/// Model identifier, serialized under `model`.
pub model: String,
/// Optional prompt string.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<String>,
/// Optional format requested for the translation output.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<AudioOutputFormat>,
/// Optional temperature value; the accepted range is not established in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
}
/// Wrapper around the raw bytes of an audio speech response.
///
/// NOTE(review): presumably the encoded audio in the requested response
/// format — confirm against the client code that constructs this value.
#[derive(Debug, Clone)]
pub struct AudioSpeechResponse {
/// Raw response bytes.
pub bytes: Bytes,
}
/// Request parameters for a streamed audio speech call.
///
/// Only compiled when the `stream` cargo feature is enabled. Unlike
/// `AudioSpeechParameters`, this type derives neither `Default` nor a builder,
/// and it has no `instructions` field.
///
/// Every `Option` field is left out of the serialized output when it is `None`;
/// `stream` is always serialized.
#[cfg(feature = "stream")]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct StreamAudioSpeechParameters {
/// Model identifier, serialized under `model`.
pub model: String,
/// Input string, serialized under `input`.
/// NOTE(review): presumably the text to synthesize — confirm against the API reference.
pub input: String,
/// Optional voice name, given as a free-form string.
#[serde(skip_serializing_if = "Option::is_none")]
pub voice: Option<String>,
/// Optional `voice_text` string; its semantics are not visible in this file —
/// TODO confirm with the client code that consumes it.
#[serde(skip_serializing_if = "Option::is_none")]
pub voice_text: Option<String>,
/// Optional format requested for the generated audio.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<AudioSpeechResponseFormat>,
/// Optional speed value; the accepted range is not established in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub speed: Option<f32>,
/// Streaming flag.
/// NOTE(review): presumably expected to be `true` for this request type — confirm with the caller.
pub stream: bool,
}
/// Wrapper around the raw bytes of one chunk of a streamed audio speech response.
///
/// Only compiled when the `stream` cargo feature is enabled.
/// NOTE(review): chunk boundaries and framing are not visible in this file —
/// confirm with the streaming client code.
#[cfg(feature = "stream")]
#[derive(Debug, Clone, PartialEq)]
pub struct AudioSpeechResponseChunkResponse {
/// Raw bytes carried by this chunk.
pub bytes: Bytes,
}
/// Output formats selectable for transcription and translation requests.
///
/// Variants are (de)serialized by serde in `snake_case`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AudioOutputFormat {
/// Serialized as `json`.
Json,
/// Serialized as `text`.
Text,
/// Serialized as `srt`.
Srt,
/// Serialized as `verbose_json`.
VerboseJson,
/// Serialized as `vtt`.
Vtt,
}
/// Audio formats selectable for speech requests.
///
/// Variants are (de)serialized by serde in `snake_case`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AudioSpeechResponseFormat {
/// Serialized as `mp3`.
Mp3,
/// Serialized as `opus`.
Opus,
/// Serialized as `aac`.
Aac,
/// Serialized as `flac`.
Flac,
/// Serialized as `wav`.
Wav,
/// Serialized as `pcm`.
Pcm,
}
/// Named voices, (de)serialized by serde in `snake_case`.
///
/// `Alloy` is the `Default` value. The parameter structs in this file carry the
/// voice as a plain `String`, so this enum is not referenced by them directly.
#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AudioVoice {
/// Serialized as `alloy`; the default variant.
#[default]
Alloy,
/// Serialized as `ash`.
Ash,
/// Serialized as `coral`.
Coral,
/// Serialized as `echo`.
Echo,
/// Serialized as `fable`.
Fable,
/// Serialized as `onyx`.
Onyx,
/// Serialized as `nova`.
Nova,
/// Serialized as `sage`.
Sage,
/// Serialized as `shimmer`.
Shimmer,
}
/// Granularity levels that can be requested for transcription timestamps.
///
/// Variants are (de)serialized by serde in `snake_case`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TimestampGranularity {
/// Serialized as `word`.
Word,
/// Serialized as `segment`.
Segment,
}
/// Chunking strategy for a transcription request.
///
/// The serialized form is either the string `"auto"` or, for the untagged
/// `VadConfig` variant, the bare `VadConfig` object with no wrapping tag.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptionChunkingStrategy {
/// Serialized as the string `auto`.
Auto,
/// Explicit VAD configuration, serialized as the inner object itself.
#[serde(untagged)]
VadConfig(VadConfig),
}
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
pub struct VadConfig {
pub r#type: VadConfigType,
pub prefix_padding_ms: Option<usize>,
pub silence_duration_ms: Option<usize>,
pub threshold: Option<f32>,
}
/// Discriminator stored in the `type` field of a `VadConfig`.
///
/// Variants are (de)serialized by serde in `snake_case`.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum VadConfigType {
/// Serialized as `server_vad`.
ServerVad,
}
impl Display for AudioOutputFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
AudioOutputFormat::Json => "json",
AudioOutputFormat::Text => "text",
AudioOutputFormat::Srt => "srt",
AudioOutputFormat::VerboseJson => "verbose_json",
AudioOutputFormat::Vtt => "vtt",
}
)
}
}
impl Display for TimestampGranularity {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
TimestampGranularity::Word => "word",
TimestampGranularity::Segment => "segment",
}
)
}
}
impl Display for TranscriptionChunkingStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TranscriptionChunkingStrategy::Auto => "auto".fmt(f),
TranscriptionChunkingStrategy::VadConfig(vad_config) => vad_config.fmt(f),
}
}
}
impl Display for VadConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = serde_json::to_string(self).map_err(|_| std::fmt::Error)?;
write!(f, "{s}")
}
}
/// Serialization round-trip tests for the transcription request types.
#[cfg(test)]
mod tests {
    use crate::audio::{
        AudioTranscriptionParameters, AudioTranscriptionParametersBuilder,
        TranscriptionChunkingStrategy, VadConfig, VadConfigType,
    };
    use crate::shared::FileUpload;

    /// `Auto` must round-trip through JSON as the bare string `"auto"`.
    #[test]
    fn test_audio_transcription_chunking_strategy_auto_serialization_deserialization() {
        let chunking_strategy = TranscriptionChunkingStrategy::Auto;

        let serialized = serde_json::to_string(&chunking_strategy).unwrap();
        assert_eq!(serialized, "\"auto\"");

        let deserialized: TranscriptionChunkingStrategy =
            serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, chunking_strategy);
    }

    /// The untagged `VadConfig` variant must round-trip as the bare config object.
    #[test]
    fn test_audio_transcription_chunking_strategy_vad_config_serialization_deserialization() {
        let chunking_strategy = TranscriptionChunkingStrategy::VadConfig(VadConfig {
            r#type: VadConfigType::ServerVad,
            prefix_padding_ms: Some(10),
            silence_duration_ms: Some(20),
            threshold: Some(0.5),
        });

        let serialized = serde_json::to_string(&chunking_strategy).unwrap();
        assert_eq!(serialized, "{\"type\":\"server_vad\",\"prefix_padding_ms\":10,\"silence_duration_ms\":20,\"threshold\":0.5}");

        let deserialized: TranscriptionChunkingStrategy =
            serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, chunking_strategy);
    }

    /// `extra_body` entries must be flattened into the top-level object and
    /// collected back into `extra_body` on deserialization.
    #[test]
    fn test_audio_transcription_extra_body_serialization_deserialization() {
        let extra = serde_json::json!({
            "enable_my_feature": true,
            "my_param": 10
        });

        let params: AudioTranscriptionParameters = AudioTranscriptionParametersBuilder::default()
            .file(FileUpload::File("test.wav".to_string()))
            .model("test")
            .extra_body(extra)
            .build()
            .unwrap();

        // Pass the parameters by reference; the previous text here was a
        // mis-encoded `&params` token that did not compile.
        let serialized = serde_json::to_string(&params).unwrap();
        assert_eq!(serialized, "{\"file\":{\"File\":\"test.wav\"},\"model\":\"test\",\"enable_my_feature\":true,\"my_param\":10}");

        let deserialized: AudioTranscriptionParameters =
            serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, params);
    }
}