use reqwest::multipart;
use serde::Deserialize;
use crate::LLMApi;
use crate::Retry;
/// Request payload for the `/audio/transcriptions` endpoint.
///
/// Every field is sent as a multipart form part; the `Option` fields are only
/// included in the request when they are `Some`.
#[derive(Debug, Clone)]
pub struct TranscribeInput {
    /// Raw bytes of the audio file, uploaded as the `file` part.
    pub audio: Vec<u8>,
    /// File extension used for the uploaded file name (`audio.{audio_format}`).
    pub audio_format: String,
    /// Sent verbatim as the `language` part.
    pub language: String,
    /// Sent as the `max_len` part when set (server-side meaning not defined here).
    pub max_len: Option<u64>,
    /// Sent as the `max_context` part when set (server-side meaning not defined here).
    pub max_context: Option<i32>,
    /// Sent as the `split_on_word` part (`"true"`/`"false"`) when set.
    pub split_on_word: Option<bool>,
}
/// Lets [`TranscribeInput`] be driven by the crate's retry loop.
///
/// Each attempt delegates to `transcribe_inner`, which decides whether a failure
/// is retryable.
impl LLMApi for TranscribeInput {
    type Output = TranscriptionOutput;

    async fn api(&self, endpoint: &str, api_key: &str) -> Retry<Self::Output> {
        // `self` is already `&TranscribeInput`; borrowing it again is redundant.
        transcribe_inner(endpoint, api_key, self).await
    }
}
/// JSON body returned by a successful `/audio/transcriptions` call.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct TranscriptionOutput {
    /// The `text` field of the response.
    pub text: String,
}
/// Request payload for the `/audio/translations` endpoint.
///
/// Every field is sent as a multipart form part; the `Option` fields are only
/// included in the request when they are `Some`.
#[derive(Debug, Clone)]
pub struct TranslateInput {
    /// Raw bytes of the audio file, uploaded as the `file` part.
    pub audio: Vec<u8>,
    /// File extension used for the uploaded file name (`audio.{audio_format}`).
    pub audio_format: String,
    /// Sent verbatim as the `language` part.
    pub language: String,
    /// Sent as the `max_len` part when set (server-side meaning not defined here).
    pub max_len: Option<u64>,
    /// Sent as the `max_context` part when set (server-side meaning not defined here).
    pub max_context: Option<i32>,
    /// Sent as the `split_on_word` part (`"true"`/`"false"`) when set.
    pub split_on_word: Option<bool>,
}
/// Lets [`TranslateInput`] be driven by the crate's retry loop.
///
/// Each attempt delegates to `translate_inner`, which decides whether a failure
/// is retryable.
impl LLMApi for TranslateInput {
    type Output = TranslationOutput;

    async fn api(&self, endpoint: &str, api_key: &str) -> Retry<Self::Output> {
        // `self` is already `&TranslateInput`; borrowing it again is redundant.
        translate_inner(endpoint, api_key, self).await
    }
}
/// JSON body returned by a successful `/audio/translations` call.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct TranslationOutput {
    /// The `text` field of the response.
    pub text: String,
}
/// Audio entry points on the service client.
///
/// Both methods hand their input to `keep_trying`, which repeatedly invokes the
/// input's [`LLMApi::api`] implementation.
impl crate::LLMServiceFlows<'_> {
    /// Transcribes `input.audio` and returns the recognised text, or the final error message.
    pub async fn transcribe(&self, input: TranscribeInput) -> Result<TranscriptionOutput, String> {
        self.keep_trying(input).await
    }

    /// Translates `input.audio` and returns the resulting text, or the final error message.
    pub async fn translate(&self, input: TranslateInput) -> Result<TranslationOutput, String> {
        self.keep_trying(input).await
    }
}
async fn transcribe_inner(
endpoint: &str,
_api_key: &str,
input: &TranscribeInput,
) -> Retry<TranscriptionOutput> {
let uri = format!("{}/audio/transcriptions", endpoint);
let mut form = multipart::Form::new()
.part(
"file",
multipart::Part::bytes(input.audio.clone())
.file_name(format!("audio.{}", input.audio_format)),
)
.part("language", multipart::Part::text(input.language.clone()));
if input.max_len.is_some() {
form = form.part(
"max_len",
multipart::Part::text(input.max_len.unwrap().to_string()),
);
}
if input.max_context.is_some() {
form = form.part(
"max_context",
multipart::Part::text(input.max_context.unwrap().to_string()),
);
}
if input.split_on_word.is_some() {
form = form.part(
"split_on_word",
multipart::Part::text(input.split_on_word.unwrap().to_string()),
);
}
match reqwest::Client::new()
.post(uri)
.multipart(form)
.send()
.await
{
Ok(res) => {
let status = res.status();
let body = res.bytes().await.unwrap();
match status.is_success() {
true => Retry::No(
serde_json::from_slice::<TranscriptionOutput>(&body.as_ref())
.or(Err(String::from("Unexpected error"))),
),
false => {
match status.into() {
409 | 429 | 503 => {
Retry::Yes(String::from_utf8_lossy(&body.as_ref()).into_owned())
}
_ => Retry::No(Err(String::from_utf8_lossy(&body.as_ref()).into_owned())),
}
}
}
}
Err(e) => Retry::No(Err(e.to_string())),
}
}
async fn translate_inner(
endpoint: &str,
_api_key: &str,
input: &TranslateInput,
) -> Retry<TranslationOutput> {
let uri = format!("{}/audio/translations", endpoint);
let mut form = multipart::Form::new()
.part(
"file",
multipart::Part::bytes(input.audio.clone())
.file_name(format!("audio.{}", input.audio_format)),
)
.part("language", multipart::Part::text(input.language.clone()));
if input.max_len.is_some() {
form = form.part(
"max_len",
multipart::Part::text(input.max_len.unwrap().to_string()),
);
}
if input.max_context.is_some() {
form = form.part(
"max_context",
multipart::Part::text(input.max_context.unwrap().to_string()),
);
}
if input.split_on_word.is_some() {
form = form.part(
"split_on_word",
multipart::Part::text(input.split_on_word.unwrap().to_string()),
);
}
match reqwest::Client::new()
.post(uri)
.multipart(form)
.send()
.await
{
Ok(res) => {
let status = res.status();
let body = res.bytes().await.unwrap();
match status.is_success() {
true => Retry::No(
serde_json::from_slice::<TranslationOutput>(&body.as_ref())
.or(Err(String::from("Unexpected error"))),
),
false => {
match status.into() {
409 | 429 | 503 => {
Retry::Yes(String::from_utf8_lossy(&body.as_ref()).into_owned())
}
_ => Retry::No(Err(String::from_utf8_lossy(&body.as_ref()).into_owned())),
}
}
}
}
Err(e) => Retry::No(Err(e.to_string())),
}
}