use serde::{Deserialize, Serialize};
use crate::client::Client;
use crate::error::Result;
/// Entry point for the audio endpoints, borrowing a [`Client`] for `'a`.
///
/// Its methods hand out per-endpoint handles (transcriptions, translations,
/// speech) that all share the same client reference.
pub struct Audio<'a> {
/// Borrowed client passed on to every endpoint handle.
client: &'a Client,
}
impl<'a> Audio<'a> {
pub(crate) fn new(client: &'a Client) -> Self {
Self { client }
}
pub fn transcriptions(&self) -> Transcriptions<'a> {
Transcriptions {
client: self.client,
}
}
pub fn translations(&self) -> Translations<'a> {
Translations {
client: self.client,
}
}
pub fn speech(&self) -> Speech<'a> {
Speech {
client: self.client,
}
}
}
/// Handle for the `/audio/transcriptions` endpoint.
pub struct Transcriptions<'a> {
/// Borrowed client used to build the URL and headers and to send requests.
client: &'a Client,
}
/// Parameters for an audio upload request.
///
/// The same request type is accepted by both the transcription and the
/// translation `create` methods; it is turned into a multipart form by
/// `build_audio_form`.
#[derive(Debug, Clone, Default)]
pub struct TranscriptionRequest {
/// Raw bytes of the audio file, uploaded as the `file` form part.
pub file: Vec<u8>,
/// File name attached to the `file` form part.
pub file_name: String,
/// MIME type of the `file` part; `audio/mpeg` is used when this is `None`.
pub mime_type: Option<String>,
/// Value of the `model` form field (always sent).
pub model: String,
/// Value of the `language` form field; omitted when `None`.
pub language: Option<String>,
/// Value of the `prompt` form field; omitted when `None`.
pub prompt: Option<String>,
/// Value of the `response_format` form field; omitted when `None`.
pub response_format: Option<TranscriptionFormat>,
/// Value of the `temperature` form field; omitted when `None`.
pub temperature: Option<f32>,
/// Each entry is sent as its own `timestamp_granularities[]` form field;
/// nothing is sent when `None`.
pub timestamp_granularities: Option<Vec<String>>,
}
/// Response formats that can be requested for an audio upload.
///
/// The type is a plain `Copy` enum and additionally implements `PartialEq`,
/// `Eq` and `Hash`, so formats can be compared directly and used as keys in
/// hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TranscriptionFormat {
    /// Sent as `json`.
    Json,
    /// Sent as `text`.
    Text,
    /// Sent as `srt`.
    Srt,
    /// Sent as `verbose_json`.
    VerboseJson,
    /// Sent as `vtt`.
    Vtt,
}

impl TranscriptionFormat {
    /// Returns the string placed in the `response_format` form field for this
    /// variant.
    fn as_str(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Json => "json",
            Self::Text => "text",
            Self::Srt => "srt",
            Self::VerboseJson => "verbose_json",
            Self::Vtt => "vtt",
        }
    }
}
/// Deserialized body of a transcription or translation response.
///
/// Only `text` is required when deserializing; every other field falls back
/// to its default when absent and is left out again when serializing if it is
/// `None` / empty.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TranscriptionResponse {
/// The `text` field of the response.
pub text: String,
/// The `language` field of the response, if present.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
/// The `duration` field of the response, if present.
// NOTE(review): presumably seconds of audio — confirm against the API docs.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub duration: Option<f64>,
/// The `segments` array, kept as untyped JSON values; empty when absent.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub segments: Vec<serde_json::Value>,
/// The `words` array, kept as untyped JSON values; empty when absent.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub words: Vec<serde_json::Value>,
}
/// Fluent builder for [`TranscriptionRequest`].
///
/// Every setter consumes the builder and returns it, so calls are meant to be
/// chained and finished with [`TranscriptionRequestBuilder::build`].
#[must_use = "a builder does nothing until `build` is called"]
#[derive(Debug, Clone)]
pub struct TranscriptionRequestBuilder {
    /// Request being assembled.
    inner: TranscriptionRequest,
}

impl TranscriptionRequestBuilder {
    /// Starts a request for `model`; all other fields take their `Default`
    /// values (empty file and file name, every optional field `None`).
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            inner: TranscriptionRequest {
                model: model.into(),
                ..Default::default()
            },
        }
    }

    /// Sets the audio payload and the file name reported for it.
    pub fn file_bytes(mut self, bytes: Vec<u8>, file_name: impl Into<String>) -> Self {
        self.inner.file = bytes;
        self.inner.file_name = file_name.into();
        self
    }

    /// Sets the MIME type of the uploaded file.
    pub fn mime_type(mut self, m: impl Into<String>) -> Self {
        self.inner.mime_type = Some(m.into());
        self
    }

    /// Sets the `language` field.
    pub fn language(mut self, l: impl Into<String>) -> Self {
        self.inner.language = Some(l.into());
        self
    }

    /// Sets the `prompt` field.
    pub fn prompt(mut self, p: impl Into<String>) -> Self {
        self.inner.prompt = Some(p.into());
        self
    }

    /// Sets the requested response format.
    pub fn response_format(mut self, f: TranscriptionFormat) -> Self {
        self.inner.response_format = Some(f);
        self
    }

    /// Sets the `temperature` field.
    pub fn temperature(mut self, t: f32) -> Self {
        self.inner.temperature = Some(t);
        self
    }

    /// Sets the timestamp granularities, replacing any previously set list.
    pub fn timestamp_granularities(mut self, g: Vec<String>) -> Self {
        self.inner.timestamp_granularities = Some(g);
        self
    }

    /// Finishes the builder and returns the assembled request.
    ///
    /// No validation is performed: if [`Self::file_bytes`] was never called,
    /// the returned request carries an empty file and an empty file name.
    #[must_use]
    pub fn build(self) -> TranscriptionRequest {
        self.inner
    }
}
impl<'a> Transcriptions<'a> {
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "debug", skip_all, fields(endpoint = "audio.transcriptions"))
)]
pub async fn create(&self, req: TranscriptionRequest) -> Result<TranscriptionResponse> {
let form = build_audio_form(&req)?;
super::post_multipart(self.client, "/audio/transcriptions", form).await
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(
level = "debug",
skip_all,
fields(endpoint = "audio.transcriptions.text")
)
)]
pub async fn create_text(&self, req: TranscriptionRequest) -> Result<String> {
let url = self.client.build_url("/audio/transcriptions")?;
let form = build_audio_form(&req)?;
let resp = self
.client
.http()
.post(url)
.headers(self.client.auth_headers())
.multipart(form)
.send()
.await?;
let status = resp.status();
let body = resp.text().await?;
if !status.is_success() {
return Err(crate::error::OpenAiError::from_response_body(
status.as_u16(),
&body,
));
}
Ok(body)
}
}
/// Converts a [`TranscriptionRequest`] into the multipart form that is uploaded.
///
/// The form always contains `model` and the `file` part (bytes, file name and
/// MIME type, defaulting to `audio/mpeg`). `language`, `prompt`,
/// `response_format` and `temperature` are appended in that order when set,
/// followed by one `timestamp_granularities[]` field per granularity.
///
/// # Errors
///
/// Returns a configuration error (`bad mime: …`) when the MIME type cannot be
/// parsed.
fn build_audio_form(req: &TranscriptionRequest) -> Result<reqwest::multipart::Form> {
    let mime = req.mime_type.as_deref().unwrap_or("audio/mpeg");
    let file_part = reqwest::multipart::Part::bytes(req.file.clone())
        .file_name(req.file_name.clone())
        .mime_str(mime)
        .map_err(|e| crate::error::OpenAiError::config(format!("bad mime: {e}")))?;

    let mut form = reqwest::multipart::Form::new()
        .text("model", req.model.clone())
        .part("file", file_part);

    // Optional scalar fields, listed in the order they are appended to the form.
    let optional_text_fields = [
        ("language", req.language.clone()),
        ("prompt", req.prompt.clone()),
        (
            "response_format",
            req.response_format.map(|f| f.as_str().to_owned()),
        ),
        ("temperature", req.temperature.map(|t| t.to_string())),
    ];
    for (name, value) in optional_text_fields {
        if let Some(value) = value {
            form = form.text(name, value);
        }
    }

    // Each granularity becomes its own repeated form field.
    let form = req
        .timestamp_granularities
        .iter()
        .flatten()
        .fold(form, |acc, granularity| {
            acc.text("timestamp_granularities[]", granularity.clone())
        });

    Ok(form)
}
/// Handle for the `/audio/translations` endpoint.
pub struct Translations<'a> {
/// Borrowed client used to send requests.
client: &'a Client,
}
impl<'a> Translations<'a> {
    /// Uploads `req` as a multipart form to `/audio/translations` and returns
    /// the response as a [`TranscriptionResponse`].
    ///
    /// # Errors
    ///
    /// Fails if the multipart form cannot be built (for example an invalid
    /// MIME type) or if `super::post_multipart` reports an error.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "debug", skip_all, fields(endpoint = "audio.translations"))
    )]
    pub async fn create(&self, req: TranscriptionRequest) -> Result<TranscriptionResponse> {
        super::post_multipart(
            self.client,
            "/audio/translations",
            build_audio_form(&req)?,
        )
        .await
    }
}
/// Handle for the `/audio/speech` endpoint.
pub struct Speech<'a> {
/// Borrowed client used to send requests.
client: &'a Client,
}
/// Serializable parameters for a request to the `/audio/speech` endpoint.
///
/// Optional fields are left out of the serialized output when they are `None`.
#[derive(Debug, Clone, Serialize)]
pub struct SpeechRequest {
/// Serialized as `model`.
pub model: String,
/// Serialized as `input`.
// NOTE(review): presumably the text to synthesize — confirm against the API docs.
pub input: String,
/// Serialized as `voice`.
pub voice: String,
/// Serialized as `response_format`; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<String>,
/// Serialized as `speed`; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub speed: Option<f32>,
/// Serialized as `instructions`; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
}
/// Fluent builder for [`SpeechRequest`].
///
/// Every setter consumes the builder and returns it, so calls are meant to be
/// chained and finished with [`SpeechRequestBuilder::build`].
#[must_use = "a builder does nothing until `build` is called"]
#[derive(Debug, Clone)]
pub struct SpeechRequestBuilder {
    /// Request being assembled.
    inner: SpeechRequest,
}

impl SpeechRequestBuilder {
    /// Starts a request with the three required fields; `response_format`,
    /// `speed` and `instructions` start out as `None`.
    ///
    /// Note the argument order: `model`, then `voice`, then `input`.
    pub fn new(
        model: impl Into<String>,
        voice: impl Into<String>,
        input: impl Into<String>,
    ) -> Self {
        Self {
            inner: SpeechRequest {
                model: model.into(),
                voice: voice.into(),
                input: input.into(),
                response_format: None,
                speed: None,
                instructions: None,
            },
        }
    }

    /// Sets the `response_format` field.
    pub fn response_format(mut self, f: impl Into<String>) -> Self {
        self.inner.response_format = Some(f.into());
        self
    }

    /// Sets the `speed` field.
    pub fn speed(mut self, s: f32) -> Self {
        self.inner.speed = Some(s);
        self
    }

    /// Sets the `instructions` field.
    pub fn instructions(mut self, i: impl Into<String>) -> Self {
        self.inner.instructions = Some(i.into());
        self
    }

    /// Finishes the builder and returns the assembled request; no validation
    /// is performed.
    #[must_use]
    pub fn build(self) -> SpeechRequest {
        self.inner
    }
}
impl<'a> Speech<'a> {
    /// Sends `req` to `/audio/speech` through `super::post_json_bytes` and
    /// returns the response body as raw bytes.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `super::post_json_bytes` reports.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "debug", skip_all, fields(endpoint = "audio.speech"))
    )]
    pub async fn create(&self, req: SpeechRequest) -> Result<bytes::Bytes> {
        let body = &req;
        super::post_json_bytes(self.client, "/audio/speech", body).await
    }
}