use std::{fs, path::Path};
use thiserror::Error;
use crate::json_utils;
#[derive(Debug, Error)]
pub enum TranscriptionError {
#[error("HttpError: {0}")]
HttpError(#[from] reqwest::Error),
#[error("JsonError: {0}")]
JsonError(#[from] serde_json::Error),
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("ResponseError: {0}")]
ResponseError(String),
#[error("ProviderError: {0}")]
ProviderError(String),
}
/// Types that can start a transcription for model `M` from an in-memory payload.
pub trait Transcription<M>
where
    M: TranscriptionModel,
{
    /// Returns a future resolving to a [`TranscriptionRequestBuilder`] for `data`,
    /// or a [`TranscriptionError`].
    ///
    /// How `filename` and `data` are applied to the builder is up to the
    /// implementor (presumably they pre-fill it — confirm per implementation).
    fn transcription(
        &self,
        filename: &str,
        data: &[u8],
    ) -> impl core::future::Future<
        Output = Result<TranscriptionRequestBuilder<M>, TranscriptionError>,
    > + Send;
}
/// Result of a transcription: the extracted text plus the model's raw response.
///
/// `Debug`/`Clone` are derived (available whenever `T` provides them) so callers
/// can log or duplicate responses.
#[derive(Debug, Clone)]
pub struct TranscriptionResponse<T> {
    /// Text produced by the model.
    pub text: String,
    /// Model-specific raw response (`TranscriptionModel::Response`).
    pub response: T,
}
/// A model that can execute a [`TranscriptionRequest`].
///
/// `Clone` is required because [`TranscriptionModel::transcription_request`]
/// hands a copy of the model to the returned builder.
pub trait TranscriptionModel: Clone + Send + Sync {
    /// Model-specific raw response carried in [`TranscriptionResponse::response`].
    type Response: Send + Sync;

    /// Executes `request`, resolving to a [`TranscriptionResponse`] or a
    /// [`TranscriptionError`].
    fn transcription(
        &self,
        request: TranscriptionRequest,
    ) -> impl core::future::Future<
        Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
    > + Send;

    /// Starts a fresh [`TranscriptionRequestBuilder`] that owns a clone of this model.
    fn transcription_request(&self) -> TranscriptionRequestBuilder<Self> {
        let owned_model = Self::clone(self);
        TranscriptionRequestBuilder::new(owned_model)
    }
}
/// A fully-specified transcription request, usually produced by
/// `TranscriptionRequestBuilder::build`.
///
/// Derives `Debug`/`Clone` so requests can be inspected and re-sent.
#[derive(Debug, Clone)]
pub struct TranscriptionRequest {
    /// Audio bytes to transcribe.
    pub data: Vec<u8>,
    /// Filename sent with the audio (the builder falls back to `"file"`).
    pub filename: String,
    /// Language string forwarded to the model (the builder defaults to `"en"`).
    pub language: String,
    /// Optional prompt forwarded to the model.
    pub prompt: Option<String>,
    /// Optional temperature forwarded to the model.
    pub temperature: Option<f64>,
    /// Optional extra JSON parameters; presumably merged into the provider
    /// payload by each model implementation — confirm per provider.
    pub additional_params: Option<serde_json::Value>,
}
pub struct TranscriptionRequestBuilder<M: TranscriptionModel> {
model: M,
data: Vec<u8>,
filename: Option<String>,
language: String,
prompt: Option<String>,
temperature: Option<f64>,
additional_params: Option<serde_json::Value>,
}
impl<M: TranscriptionModel> TranscriptionRequestBuilder<M> {
    /// Creates a builder bound to `model`.
    ///
    /// It starts with empty audio data, no filename, language `"en"`, and no
    /// prompt, temperature, or additional parameters.
    pub fn new(model: M) -> Self {
        TranscriptionRequestBuilder {
            model,
            data: vec![],
            filename: None,
            language: "en".to_string(),
            prompt: None,
            temperature: None,
            additional_params: None,
        }
    }

    /// Sets the filename sent with the audio, or clears it with `None`.
    ///
    /// An unset filename becomes `"file"` when the request is built.
    pub fn filename(mut self, filename: Option<String>) -> Self {
        self.filename = filename;
        self
    }

    /// Sets the raw audio bytes, replacing any previously set data.
    pub fn data(mut self, data: Vec<u8>) -> Self {
        self.data = data;
        self
    }

    /// Reads the file at `path` and uses it as the request's audio.
    ///
    /// The file's contents become the audio data. Its final path component
    /// becomes the filename.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - the file cannot be read (missing, permission denied, ...);
    /// - `path` has no final component (e.g. it ends in `..`);
    /// - that component is not valid UTF-8.
    pub fn load_file<P>(self, path: P) -> Self
    where
        P: AsRef<Path>,
    {
        let path = path.as_ref();
        // Read first so an unreadable file is reported before any filename problem.
        let data = fs::read(path).unwrap_or_else(|err| {
            panic!("Failed to load audio file {}: {err}", path.display())
        });
        let filename = path
            .file_name()
            .unwrap_or_else(|| panic!("Path {} does not name a file", path.display()))
            .to_str()
            .unwrap_or_else(|| panic!("Filename of {} is not valid UTF-8", path.display()))
            .to_string();
        self.filename(Some(filename)).data(data)
    }

    /// Sets the language string forwarded to the model (default `"en"`).
    pub fn language(mut self, language: String) -> Self {
        self.language = language;
        self
    }

    /// Sets the prompt forwarded to the model.
    pub fn prompt(mut self, prompt: String) -> Self {
        self.prompt = Some(prompt);
        self
    }

    /// Sets the temperature forwarded to the model.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Adds extra JSON parameters.
    ///
    /// The first call stores `additional_params` as-is. Later calls combine it
    /// with the existing value via `json_utils::merge(existing, new)`.
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        let combined = match self.additional_params.take() {
            Some(existing) => json_utils::merge(existing, additional_params),
            None => additional_params,
        };
        self.additional_params = Some(combined);
        self
    }

    /// Replaces any previously set additional parameters.
    ///
    /// Unlike `additional_params`, this does not merge. Passing `None` clears
    /// the parameters.
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Consumes the builder and produces a [`TranscriptionRequest`].
    ///
    /// The filename defaults to `"file"` when none was set.
    ///
    /// # Panics
    ///
    /// Panics if the audio data is empty.
    pub fn build(self) -> TranscriptionRequest {
        self.into_parts().1
    }

    /// Builds the request and executes it with the builder's model.
    ///
    /// # Errors
    ///
    /// Returns whatever [`TranscriptionError`] the model's `transcription` yields.
    ///
    /// # Panics
    ///
    /// Panics (when first polled) if the audio data is empty, as `build` does.
    pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
        // Moving the model out alongside the request avoids cloning it.
        let (model, request) = self.into_parts();
        model.transcription(request).await
    }

    /// Splits the builder into its model and the finished request.
    ///
    /// This is shared by `build` and `send`.
    fn into_parts(self) -> (M, TranscriptionRequest) {
        assert!(!self.data.is_empty(), "Data cannot be empty!");
        let request = TranscriptionRequest {
            data: self.data,
            // Lazily allocate the fallback only when no filename was provided.
            filename: self.filename.unwrap_or_else(|| "file".to_string()),
            language: self.language,
            prompt: self.prompt,
            temperature: self.temperature,
            additional_params: self.additional_params,
        };
        (self.model, request)
    }
}