use crate::IntoRequest;
use derive_builder::Builder;
use reqwest::{
multipart::{Form, Part},
Client, RequestBuilder,
};
use serde::Deserialize;
use strum::{Display, EnumString};
/// Parameters of a Whisper audio request.
///
/// `into_form` converts the fields into a multipart form and `into_request`
/// posts that form to the endpoint selected by `request_type`.
/// The `Builder` derive provides `WhisperRequestBuilder` for constructing values.
#[derive(Debug, Clone, Builder)]
#[builder(pattern = "mutable")]
pub struct WhisperRequest {
/// Raw bytes of the audio file; uploaded as the `file` part of the form.
file: Vec<u8>,
/// Model whose string form is sent in the `model` form field.
#[builder(default)]
model: WhisperModel,
/// Optional value for the `language` form field. It is only added to the
/// form when `request_type` is `Transcription`.
#[builder(default, setter(strip_option, into))]
language: Option<String>,
/// Optional value for the `prompt` form field.
#[builder(default, setter(strip_option, into))]
prompt: Option<String>,
/// Format whose string form is sent in the `response_format` form field.
#[builder(default)]
pub(crate) response_format: WhisperResponseFormat,
/// Optional value for the `temperature` form field (sent via `to_string()`).
#[builder(default, setter(strip_option))]
temperature: Option<f32>,
/// Selects the transcription or translation endpoint. Unlike the other
/// non-file fields, no `#[builder(default)]` is declared for it.
request_type: WhisperRequestType,
}
/// Model used for a Whisper request.
///
/// The strum `serialize` attribute pins the string form of the single
/// variant to `whisper-1`, which is what ends up in the `model` form field.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
pub enum WhisperModel {
/// The default (and only) model, rendered as `whisper-1`.
#[default]
#[strum(serialize = "whisper-1")]
Whisper1,
}
/// Output format requested through the `response_format` form field.
///
/// `serialize_all = "snake_case"` makes strum render the variant names in
/// snake case (for example `VerboseJson` becomes `verbose_json`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
#[strum(serialize_all = "snake_case")]
pub enum WhisperResponseFormat {
/// Default format.
#[default]
Json,
Text,
Srt,
VerboseJson,
Vtt,
}
/// Kind of audio request.
///
/// `into_request` uses it to pick the endpoint, and `into_form` only sends
/// the `language` field for `Transcription`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
pub enum WhisperRequestType {
/// Posts to `{base_url}/audio/transcriptions`; this is the default variant.
#[default]
Transcription,
/// Posts to `{base_url}/audio/translations`.
Translation,
}
/// Response to a Whisper request, deserializable with serde.
#[derive(Debug, Clone, Deserialize)]
pub struct WhisperResponse {
/// Text returned for the request. The tests in this file compare it against
/// the transcript rendered in the requested response format.
pub text: String,
}
impl WhisperRequest {
pub fn transcription(data: Vec<u8>) -> Self {
WhisperRequestBuilder::default()
.file(data)
.request_type(WhisperRequestType::Transcription)
.build()
.unwrap()
}
pub fn translation(data: Vec<u8>) -> Self {
WhisperRequestBuilder::default()
.file(data)
.request_type(WhisperRequestType::Translation)
.build()
.unwrap()
}
fn into_form(self) -> Form {
let part = Part::bytes(self.file)
.file_name("file")
.mime_str("audio/mp3")
.unwrap();
let mut form = Form::new()
.part("file", part)
.text("model", self.model.to_string())
.text("response_format", self.response_format.to_string());
form = match (self.request_type, self.language) {
(WhisperRequestType::Transcription, Some(language)) => form.text("language", language),
_ => form,
};
form = if let Some(prompt) = self.prompt {
form.text("prompt", prompt)
} else {
form
};
if let Some(temperature) = self.temperature {
form.text("temperature", temperature.to_string())
} else {
form
}
}
}
impl IntoRequest for WhisperRequest {
    /// Builds a multipart POST against the endpoint matching the request
    /// type: `{base_url}/audio/transcriptions` or `{base_url}/audio/translations`.
    fn into_request(self, base_url: &str, client: Client) -> RequestBuilder {
        // Copy the kind out first: `into_form` consumes `self`.
        let kind = self.request_type;
        let form = self.into_form();
        let endpoint = match kind {
            WhisperRequestType::Translation => format!("{}/audio/translations", base_url),
            WhisperRequestType::Transcription => format!("{}/audio/transcriptions", base_url),
        };
        client.post(endpoint).multipart(form)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;
    use std::fs;

    /// Reads the audio fixture at `path` and builds a request for it with
    /// the given response format and request type.
    fn fixture_request(
        path: &str,
        format: WhisperResponseFormat,
        kind: WhisperRequestType,
    ) -> Result<WhisperRequest> {
        let audio = fs::read(path)?;
        let mut builder = WhisperRequestBuilder::default();
        builder.file(audio);
        builder.response_format(format);
        builder.request_type(kind);
        Ok(builder.build()?)
    }

    /// The convenience constructor with default settings yields the plain transcript.
    #[tokio::test]
    async fn transcription_should_work() -> Result<()> {
        let audio = fs::read("fixtures/speech.mp3")?;
        let res = SDK.whisper(WhisperRequest::transcription(audio)).await?;
        assert_eq!(res.text, "The quick brown fox jumped over the lazy dog.");
        Ok(())
    }

    /// The `Text` response format yields the transcript with a trailing newline.
    #[tokio::test]
    async fn transcription_with_response_format_should_work() -> Result<()> {
        let req = fixture_request(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Text,
            WhisperRequestType::Transcription,
        )?;
        let res = SDK.whisper(req).await?;
        assert_eq!(res.text, "The quick brown fox jumped over the lazy dog.\n");
        Ok(())
    }

    /// The `Vtt` response format yields a WEBVTT document.
    #[tokio::test]
    async fn transcription_with_vtt_response_format_should_work() -> Result<()> {
        let req = fixture_request(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Vtt,
            WhisperRequestType::Transcription,
        )?;
        let res = SDK.whisper(req).await?;
        assert_eq!(res.text, "WEBVTT\n\n00:00:00.000 --> 00:00:02.800\nThe quick brown fox jumped over the lazy dog.\n\n");
        Ok(())
    }

    /// A translation request with the `Srt` format yields English subtitles.
    #[tokio::test]
    async fn translate_should_work() -> Result<()> {
        let req = fixture_request(
            "fixtures/chinese.mp3",
            WhisperResponseFormat::Srt,
            WhisperRequestType::Translation,
        )?;
        let res = SDK.whisper(req).await?;
        assert_eq!(res.text, "1\n00:00:00,000 --> 00:00:03,000\nThe red scarf hangs on the chest, the motherland is always in my heart.\n\n\n");
        Ok(())
    }
}