use crate::{FormRequest, OpenAiClient, process_response, process_text_response};
use std::io;
use std::path::{PathBuf};
use reqwest::multipart::{Form, Part};
use serde::{Serialize,Deserialize};
use crate::file_to_part;
use async_trait::async_trait;
use anyhow::Result;
use futures_util::TryFutureExt;
use reqwest::Response;
use strum_macros::Display;
use crate::conversions::AsyncTryFrom;
/// Output formats that can be requested from the audio endpoints.
///
/// The value sent to the server is the lowercase name produced by this
/// type's `to_string()` (for example `verbose_json` for
/// [`ResponseFormat::VerboseJson`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ResponseFormat {
    /// Sent as `json`.
    Json,
    /// Sent as `verbose_json`.
    VerboseJson,
    /// Sent as `text`.
    Text,
    /// Sent as `srt`.
    Srt,
    /// Sent as `vtt`.
    Vtt,
}
/// Renders a [`ResponseFormat`] as the lowercase name used for the
/// `response_format` form field.
///
/// `Display` is implemented instead of `ToString` so that `to_string()`
/// keeps working through the standard blanket impl while the type also
/// becomes usable with `format!`, `write!` and friends.
impl std::fmt::Display for ResponseFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully-qualified `std::fmt` paths are used on purpose: this file
        // imports `anyhow::Result` and strum's `Display` derive macro.
        let wire_name = match self {
            ResponseFormat::Json => "json",
            ResponseFormat::Text => "text",
            ResponseFormat::Srt => "srt",
            ResponseFormat::VerboseJson => "verbose_json",
            ResponseFormat::Vtt => "vtt",
        };
        f.write_str(wire_name)
    }
}
/// Result of an audio request, one variant per [`ResponseFormat`].
///
/// The enum is `#[serde(untagged)]`, so when deserializing serde tries the
/// variants top to bottom and keeps the first one that fits:
///
/// * `VerboseJson` is listed before `Json` because `ShortAudioResponse` only
///   needs a `text` field; if it came first, every verbose payload would be
///   accepted as `Json` and its extra data silently dropped.
/// * `Text`, `Srt` and `Vtt` all wrap a plain `String` and cannot be told
///   apart from the data alone, so deserializing a bare string always
///   yields `Text`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AudioResponse {
    /// Parsed `verbose_json` payload.
    VerboseJson(VerboseAudioResponse),
    /// Parsed `json` payload containing only the text.
    Json(ShortAudioResponse),
    /// Raw body returned for the `text` format.
    Text(String),
    /// Raw body returned for the `srt` format.
    Srt(String),
    /// Raw body returned for the `vtt` format.
    Vtt(String),
}
/// Minimal JSON payload: an object carrying a single `text` field.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ShortAudioResponse {
    /// The text returned by the endpoint.
    pub text: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Segment {
pub id: i64,
pub seek: i64,
pub start: f64,
pub end: f64,
pub text: String,
pub tokens: Vec<i64>,
pub temperature: f64,
pub avg_logprob: f64,
pub compression_ratio: f64,
pub no_speech_prob: f64,
pub transient: bool,
}
/// Payload decoded when [`ResponseFormat::VerboseJson`] is requested.
///
/// Field names mirror the keys of the JSON payload one-to-one.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct VerboseAudioResponse {
    pub task: String,
    pub language: String,
    // NOTE(review): presumably the audio length in seconds — confirm.
    pub duration: f64,
    /// Per-segment details accompanying `text`.
    pub segments: Vec<Segment>,
    /// The complete text.
    pub text: String,
}
/// Parameters of a request to the `/audio/transcriptions` endpoint.
///
/// Build one with `TranscriptionRequest::new` or
/// `TranscriptionRequest::with_model`, then chain the setter methods.
/// Optional fields left as `None` are omitted from the multipart form.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TranscriptionRequest {
    /// Path of the audio file uploaded as the `file` part.
    file: PathBuf,
    /// Value of the `model` part.
    model: String,
    /// Optional value of the `prompt` part.
    prompt: Option<String>,
    /// Optional output format; also selects how the response is decoded.
    response_format: Option<ResponseFormat>,
    /// Optional value of the `temperature` part.
    temperature: Option<f64>,
    /// Optional language code sent as the `language` part.
    language: Option<Iso639_1>,
}
/// Sends a [`TranscriptionRequest`] to the transcription endpoint.
#[async_trait]
impl FormRequest<AudioResponse> for TranscriptionRequest {
    const ENDPOINT: &'static str = "/audio/transcriptions";

    /// Posts the request through `client` and decodes the reply according
    /// to the request's `response_format`.
    ///
    /// # Errors
    /// Propagates any error from `get_response` or from decoding the reply.
    async fn run(&self, client: &OpenAiClient) -> Result<AudioResponse> {
        // Full URL = client base URL followed by this request's endpoint.
        let endpoint_url = format!("{}{}", client.url, Self::ENDPOINT);
        let raw = self
            .get_response(&client.client, endpoint_url, &client.key)
            .await?;
        process_audio_response(&self.response_format, raw).await
    }
}
#[async_trait]
impl AsyncTryFrom<TranscriptionRequest> for Form {
type Error = io::Error;
async fn try_from(transcription_request: TranscriptionRequest) -> Result<Self, Self::Error> {
let mut form = Form::new();
form = form.part("model", Part::text(transcription_request.model));
form = form.part("file", file_to_part(&transcription_request.file).await?);
if let Some(prompt) = transcription_request.prompt {
form = form.part("prompt", Part::text(prompt));
}
if let Some(response_format) = transcription_request.response_format {
form = form.part("response_format", Part::text(response_format.to_string()));
}
if let Some(temperature) = transcription_request.temperature {
form = form.part("temperature", Part::text(temperature.to_string()));
}
if let Some(language) = transcription_request.language {
form = form.part("language", Part::text(language.to_string()));
}
Ok(form)
}
}
impl TranscriptionRequest {
pub fn new(file: PathBuf) -> Self {
TranscriptionRequest {
file,
model: "whisper-1".to_string(),
prompt: None,
response_format: None,
temperature: None,
language: None
}
}
pub fn with_model(file: PathBuf, model: String) -> Self {
TranscriptionRequest {
file,
model,
prompt: None,
response_format: None,
temperature: None,
language: None
}
}
pub fn file(mut self, file: PathBuf) -> Self {
self.file = file;
self
}
pub fn model(mut self, model: String) -> Self {
self.model = model;
self
}
pub fn prompt(mut self, prompt: String) -> Self {
self.prompt = Some(prompt);
self
}
pub fn response_format(mut self, response_format: ResponseFormat) -> Self {
self.response_format = Some(response_format);
self
}
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
pub fn language(mut self, language: Iso639_1) -> Self {
self.language = Some(language);
self
}
}
/// Parameters of a request to the `/audio/translations` endpoint.
///
/// Build one with `TranslationRequest::new` or
/// `TranslationRequest::with_model`, then chain the setter methods.
/// Optional fields left as `None` are omitted from the multipart form.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TranslationRequest {
    /// Path of the audio file uploaded as the `file` part.
    file: PathBuf,
    /// Value of the `model` part.
    model: String,
    /// Optional value of the `prompt` part.
    prompt: Option<String>,
    /// Optional output format; also selects how the response is decoded.
    response_format: Option<ResponseFormat>,
    /// Optional value of the `temperature` part.
    temperature: Option<f64>,
}
/// Sends a [`TranslationRequest`] to the translation endpoint.
#[async_trait]
impl FormRequest<AudioResponse> for TranslationRequest {
    const ENDPOINT: &'static str = "/audio/translations";

    /// Posts the request through `client` and decodes the reply according
    /// to the request's `response_format`.
    ///
    /// # Errors
    /// Propagates any error from `get_response` or from decoding the reply.
    async fn run(&self, client: &OpenAiClient) -> Result<AudioResponse> {
        // Full URL = client base URL followed by this request's endpoint.
        let endpoint_url = format!("{}{}", client.url, Self::ENDPOINT);
        let raw = self
            .get_response(&client.client, endpoint_url, &client.key)
            .await?;
        process_audio_response(&self.response_format, raw).await
    }
}
#[async_trait]
impl AsyncTryFrom<TranslationRequest> for Form {
type Error = io::Error;
async fn try_from(translation_request: TranslationRequest) -> Result<Self, Self::Error> {
let mut form = Form::new();
form = form.part("model", Part::text(translation_request.model));
form = form.part("file", file_to_part(&translation_request.file).await?);
if let Some(prompt) = translation_request.prompt {
form = form.part("prompt", Part::text(prompt));
}
if let Some(response_format) = translation_request.response_format {
form = form.part("response_format", Part::text(response_format.to_string()));
}
if let Some(temperature) = translation_request.temperature {
form = form.part("temperature", Part::text(temperature.to_string()));
}
Ok(form)
}
}
impl TranslationRequest {
pub fn new(file: PathBuf) -> Self {
TranslationRequest {
file,
model:"whisper-1".to_string(),
prompt: None,
response_format: None,
temperature: None
}
}
pub fn with_model(file: PathBuf, model: String) -> Self {
TranslationRequest {
file,
model,
prompt: None,
response_format: None,
temperature: None
}
}
pub fn file(mut self, file: PathBuf) -> Self {
self.file = file;
self
}
pub fn model(mut self, model: String) -> Self {
self.model = model;
self
}
pub fn prompt(mut self, prompt: String) -> Self {
self.prompt = Some(prompt);
self
}
pub fn response_format(mut self, response_format: ResponseFormat) -> Self {
self.response_format = Some(response_format);
self
}
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}
}
/// Two-letter ISO 639-1 language codes.
///
/// The strum-derived `Display` (configured with
/// `serialize_all = "lowercase"`) renders a variant as its lowercase code,
/// e.g. `Iso639_1::En` becomes `en`; that string is what is placed in the
/// `language` form field.
#[derive(Clone, Debug, PartialEq, Display, Serialize, Deserialize)]
#[strum(serialize_all = "lowercase")]
pub enum Iso639_1 {
    Aa,
    Ab,
    Ae,
    Af,
    Ak,
    Am,
    An,
    Ar,
    As,
    Av,
    Ay,
    Az,
    Ba,
    Be,
    Bg,
    Bh,
    Bi,
    Bm,
    Bn,
    Bo,
    Br,
    Bs,
    Ca,
    Ce,
    Ch,
    Co,
    Cr,
    Cs,
    Cu,
    Cv,
    Cy,
    Da,
    De,
    Dv,
    Dz,
    Ee,
    El,
    En,
    Eo,
    Es,
    Et,
    Eu,
    Fa,
    Ff,
    Fi,
    Fj,
    Fo,
    Fr,
    Fy,
    Ga,
    Gd,
    Gl,
    Gn,
    Gu,
    Gv,
    Ha,
    He,
    Hi,
    Ho,
    Hr,
    Ht,
    Hu,
    Hy,
    Hz,
    Ia,
    Id,
    Ie,
    Ig,
    Ii,
    Ik,
    Io,
    Is,
    It,
    Iu,
    Ja,
    Jv,
    Ka,
    Kg,
    Ki,
    Kj,
    Kk,
    Kl,
    Km,
    Kn,
    Ko,
    Kr,
    Ks,
    Ku,
    Kv,
    Kw,
    Ky,
    La,
    Lb,
    Lg,
    Li,
    Ln,
    Lo,
    Lt,
    Lu,
    Lv,
    Mg,
    Mh,
    Mi,
    Mk,
    Ml,
    Mn,
    Mr,
    Ms,
    Mt,
    My,
    Na,
    Nb,
    Nd,
    Ne,
    Ng,
    Nl,
    Nn,
    No,
    Nr,
    Nv,
    Ny,
    Oc,
    Oj,
    Om,
    Or,
    Os,
    Pa,
    Pi,
    Pl,
    Ps,
    Pt,
    Qu,
    Rm,
    Rn,
    Ro,
    Ru,
    Rw,
    Sa,
    Sc,
    Sd,
    Se,
    Sg,
    Si,
    Sk,
    Sl,
    Sm,
    Sn,
    So,
    Sq,
    Sr,
    Ss,
    St,
    Su,
    Sv,
    Sw,
    Ta,
    Te,
    Tg,
    Th,
    Ti,
    Tk,
    Tl,
    Tn,
    To,
    Tr,
    Ts,
    Tt,
    Tw,
    Ty,
    Ug,
    Uk,
    Ur,
    Uz,
    Ve,
    Vi,
    Vo,
    Wa,
    Wo,
    Xh,
    Yi,
    Yo,
    Za,
    Zh,
    Zu,
}
async fn process_audio_response(resp_format:&Option<ResponseFormat>,res:Response)-> Result<AudioResponse>{
match resp_format {
None => process_response::<ShortAudioResponse>(res).map_ok(AudioResponse::Json).await,
Some(format) => match format {
ResponseFormat::Json => process_response::<ShortAudioResponse>(res).map_ok(AudioResponse::Json).await,
ResponseFormat::VerboseJson => process_response::<VerboseAudioResponse>(res).map_ok(AudioResponse::VerboseJson).await,
ResponseFormat::Text => process_text_response(res).map_ok(AudioResponse::Text).await,
ResponseFormat::Srt => process_text_response(res).map_ok(AudioResponse::Srt).await,
ResponseFormat::Vtt => process_text_response(res).map_ok(AudioResponse::Vtt).await
}
}
}