use reqwest::multipart::Form;
use std::fmt::{Debug, Display};
use subtp::srt::SubRip;
use subtp::vtt::WebVtt;
use crate::audio::AudioApiError;
use crate::audio::AudioApiResult;
use crate::audio::AudioModel;
use crate::audio::File;
use crate::audio::Iso639_1;
use crate::audio::JsonResponse;
use crate::audio::JsonResponseFormatter;
use crate::audio::PlainTextResponseFormatter;
use crate::audio::SrtResponseFormatter;
use crate::audio::TextResponseFormat;
use crate::audio::TextResponseFormatter;
use crate::audio::TimestampGranularity;
use crate::audio::VerboseJsonResponse;
use crate::audio::VerboseJsonResponseFormatter;
use crate::audio::VttResponseFormatter;
use crate::ApiError;
use crate::Client;
use crate::ClientError;
use crate::Prompt;
use crate::Temperature;
/// Request body for the `/v1/audio/transcriptions` endpoint.
///
/// `file` and `model` are always sent; every `Option` field is omitted from
/// the multipart form when `None` (see `build_form`).
#[derive(Debug, Default)]
pub struct TranscriptionsRequestBody {
    /// The audio file to transcribe.
    pub file: File,
    /// The audio model to transcribe with.
    pub model: AudioModel,
    /// Optional ISO 639-1 language code of the input audio.
    pub language: Option<Iso639_1>,
    /// Optional prompt text sent alongside the audio.
    pub prompt: Option<Prompt>,
    /// Optional temperature setting.
    pub temperature: Option<Temperature>,
    /// Optional timestamp granularities. Only accepted together with the
    /// verbose-JSON response format; `transcribe` rejects other combinations
    /// with `AudioApiError::TimestampOptionMismatch`.
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}
impl Display for TranscriptionsRequestBody {
    /// Renders the request body as a comma-separated `key: value` list,
    /// omitting every field that is `None`.
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        write!(f, "file: {}", self.file)?;
        write!(f, ", model: {}", self.model)?;
        // Borrow the optional fields instead of copying/cloning them just to
        // format them (the original cloned `prompt` and the granularities Vec).
        if let Some(language) = &self.language {
            write!(f, ", language: {}", language)?;
        }
        if let Some(prompt) = &self.prompt {
            write!(f, ", prompt: {}", prompt)?;
        }
        if let Some(temperature) = &self.temperature {
            write!(f, ", temperature: {}", temperature)?;
        }
        if let Some(timestamp_granularities) = &self.timestamp_granularities {
            write!(
                f,
                ", timestamp_granularities: [{}]",
                timestamp_granularities
                    .iter()
                    .map(|granularity| granularity.to_string())
                    .collect::<Vec<String>>()
                    .join(", ")
            )?;
        }
        Ok(())
    }
}
impl TranscriptionsRequestBody {
    /// Assembles a request body from its individual parts.
    pub fn new(
        file: File,
        model: AudioModel,
        language: Option<Iso639_1>,
        prompt: Option<Prompt>,
        temperature: Option<Temperature>,
        timestamp_granularities: Option<Vec<TimestampGranularity>>,
    ) -> Self {
        Self {
            file,
            model,
            language,
            prompt,
            temperature,
            timestamp_granularities,
        }
    }

    /// Consumes the body and builds the multipart form for the API call.
    ///
    /// The `response_format` field is taken from the target format `F`;
    /// `None` fields are left out of the form entirely.
    async fn build_form<F, T>(self) -> Form
    where
        F: TextResponseFormat,
        T: TextResponseFormatter<F>,
    {
        let Self {
            file,
            model,
            language,
            prompt,
            temperature,
            timestamp_granularities,
        } = self;
        let mut multipart = Form::new()
            .part("file", file.part)
            .text("model", model.to_string())
            .text("response_format", F::format());
        if let Some(language) = language {
            multipart = multipart.text("language", language.to_string());
        }
        if let Some(prompt) = prompt {
            multipart = multipart.text("prompt", prompt.to_string());
        }
        if let Some(temperature) = temperature {
            multipart = multipart.text("temperature", temperature.to_string());
        }
        // The array-style key is repeated once per requested granularity.
        for granularity in timestamp_granularities.unwrap_or_default() {
            multipart = multipart.text(
                "timestamp_granularities[]",
                granularity.to_string(),
            );
        }
        multipart
    }
}
/// Calls the `/v1/audio/transcriptions` endpoint and parses the response
/// with formatter `T` into the response type `F`.
///
/// # Errors
/// - `AudioApiError::TimestampOptionMismatch` when `timestamp_granularities`
///   is set but `F` is not the verbose-JSON format.
/// - `ClientError::HttpRequestError` / `ReadResponseTextFailed` for transport
///   failures.
/// - `AudioApiError::FormatResponseFailed` when the successful response body
///   cannot be parsed as `F`.
/// - `ApiError` (converted) for non-success HTTP status codes.
async fn transcribe<F, T>(
    client: &Client,
    request_body: TranscriptionsRequestBody,
) -> AudioApiResult<F>
where
    F: TextResponseFormat,
    T: TextResponseFormatter<F>,
{
    // Timestamp granularities are only valid with the verbose-JSON format;
    // reject the mismatch locally instead of sending a doomed request.
    if request_body
        .timestamp_granularities
        .is_some()
        && F::format() != VerboseJsonResponse::format()
    {
        return Err(AudioApiError::TimestampOptionMismatch);
    }
    let form = request_body
        .build_form::<F, T>()
        .await;
    let response = client
        .post("https://api.openai.com/v1/audio/transcriptions")
        .multipart(form)
        .send()
        .await
        .map_err(ClientError::HttpRequestError)?;
    let status_code = response.status();
    let response_text = response
        .text()
        .await
        .map_err(ClientError::ReadResponseTextFailed)?;
    if status_code.is_success() {
        // NOTE: a leftover debug `println!` used to dump the entire response
        // body (the user's transcript) to stdout here; removed.
        T::format(response_text).map_err(AudioApiError::FormatResponseFailed)
    }
    else {
        // Keep the raw text in the error when the error payload itself
        // cannot be deserialized, so callers can still inspect it.
        let error_response =
            serde_json::from_str(&response_text).map_err(|error| {
                ClientError::ErrorResponseDeserializationFailed {
                    error,
                    text: response_text,
                }
            })?;
        Err(ApiError {
            status_code,
            error_response,
        }
        .into())
    }
}
/// Transcribes the audio in `request_body`, returning the JSON response form.
pub(crate) async fn transcribe_into_json(
    client: &Client,
    request_body: TranscriptionsRequestBody,
) -> AudioApiResult<JsonResponse> {
    let call = transcribe::<JsonResponse, JsonResponseFormatter>(
        client,
        request_body,
    );
    call.await
}
/// Transcribes the audio in `request_body`, returning the plain-text form.
pub(crate) async fn transcribe_into_plain_text(
    client: &Client,
    request_body: TranscriptionsRequestBody,
) -> AudioApiResult<String> {
    let call = transcribe::<String, PlainTextResponseFormatter>(
        client,
        request_body,
    );
    call.await
}
/// Transcribes the audio in `request_body`, returning the verbose-JSON form.
/// This is the only entry point compatible with `timestamp_granularities`.
pub(crate) async fn transcribe_into_verbose_json(
    client: &Client,
    request_body: TranscriptionsRequestBody,
) -> AudioApiResult<VerboseJsonResponse> {
    let call = transcribe::<VerboseJsonResponse, VerboseJsonResponseFormatter>(
        client,
        request_body,
    );
    call.await
}
/// Transcribes the audio in `request_body`, returning parsed SRT subtitles.
pub(crate) async fn transcribe_into_srt(
    client: &Client,
    request_body: TranscriptionsRequestBody,
) -> AudioApiResult<SubRip> {
    let call = transcribe::<SubRip, SrtResponseFormatter>(
        client,
        request_body,
    );
    call.await
}
/// Transcribes the audio in `request_body`, returning parsed WebVTT subtitles.
pub(crate) async fn transcribe_into_vtt(
    client: &Client,
    request_body: TranscriptionsRequestBody,
) -> AudioApiResult<WebVtt> {
    let call = transcribe::<WebVtt, VttResponseFormatter>(
        client,
        request_body,
    );
    call.await
}