use crate::markers::{Missing, Provided};
use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
use crate::{http_client, json_utils};
use std::io;
use std::{fs, path::Path};
use thiserror::Error;
/// Errors produced while preparing or performing a transcription.
///
/// Every variant renders as `"<VariantName>: <detail>"` through its `Display`
/// implementation.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TranscriptionError {
    /// The HTTP client layer reported a failure.
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// A `serde_json` operation failed.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// An arbitrary boxed error, convertible via `From`.
    ///
    /// On non-wasm targets the boxed error must be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// An arbitrary boxed error, convertible via `From`.
    ///
    /// On wasm targets no `Send + Sync` bound is required.
    #[cfg(target_family = "wasm")]
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// A free-form message; presumably describes a problem with the response.
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// A free-form message; presumably an error reported by the provider.
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
/// Types that can start a transcription from an in-memory file.
///
/// The future resolves to a request builder whose `data` is already provided.
/// Callers can therefore set further options, or go straight to `build`/`send`.
pub trait Transcription<M: TranscriptionModel> {
    /// Create a request builder for `data`, labelled with `filename`.
    fn transcription(
        &self,
        filename: &str,
        data: &[u8],
    ) -> impl std::future::Future<
        Output = Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>, TranscriptionError>,
    > + WasmCompatSend;
}
/// The result of a transcription call.
///
/// `Debug` and `Clone` are available whenever the provider payload `T` supports them.
#[derive(Debug, Clone)]
pub struct TranscriptionResponse<T> {
    /// The transcribed text.
    pub text: String,
    /// The provider-specific response this result was produced from.
    pub response: T,
}
/// A model that can turn a [`TranscriptionRequest`] into a [`TranscriptionResponse`].
pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// Provider-specific payload carried in [`TranscriptionResponse::response`].
    type Response: WasmCompatSend + WasmCompatSync;

    /// The client type this model is constructed from.
    type Client;

    /// Build a model handle from `client` and the given model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Execute `request` against the model.
    fn transcription(
        &self,
        request: TranscriptionRequest,
    ) -> impl std::future::Future<
        Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
    > + WasmCompatSend;

    /// Start an empty request builder bound to a clone of this model.
    fn transcription_request(&self) -> TranscriptionRequestBuilder<Self, Missing> {
        let handle = self.clone();
        TranscriptionRequestBuilder::new(handle)
    }
}
/// A fully assembled transcription request, as passed to
/// [`TranscriptionModel::transcription`].
#[derive(Clone)]
pub struct TranscriptionRequest {
    /// Raw bytes of the file to transcribe.
    pub data: Vec<u8>,
    /// Filename sent alongside the data.
    pub filename: String,
    /// Optional language hint.
    pub language: Option<String>,
    /// Optional prompt text.
    pub prompt: Option<String>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Extra provider-specific JSON parameters.
    pub additional_params: Option<serde_json::Value>,
}

// Hand-written so that debug output reports the payload size instead of
// dumping every byte of the file into logs.
impl std::fmt::Debug for TranscriptionRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TranscriptionRequest")
            .field("data", &format_args!("<{} bytes>", self.data.len()))
            .field("filename", &self.filename)
            .field("language", &self.language)
            .field("prompt", &self.prompt)
            .field("temperature", &self.temperature)
            .field("additional_params", &self.additional_params)
            .finish()
    }
}
/// Typestate builder for a [`TranscriptionRequest`].
///
/// `D` tracks whether audio data has been supplied: `Missing` until `data` or
/// `load_file` is called, then `Provided<Vec<u8>>`. Only the `Provided` state
/// exposes `build` and `send`.
#[must_use = "builders do nothing unless built or sent"]
pub struct TranscriptionRequestBuilder<M, D>
where
    M: TranscriptionModel,
{
    /// Model that will execute the request.
    model: M,
    /// Typestate-tracked payload (`Missing` or `Provided<Vec<u8>>`).
    data: D,
    /// Filename to send; `build` substitutes `"file"` when unset.
    filename: Option<String>,
    language: Option<String>,
    prompt: Option<String>,
    temperature: Option<f64>,
    additional_params: Option<serde_json::Value>,
}
impl<M> TranscriptionRequestBuilder<M, Missing>
where
    M: TranscriptionModel,
{
    /// Create a builder for `model` with no data and every option unset.
    ///
    /// Data must be supplied (via `data` or `load_file`) before the request
    /// can be built or sent.
    pub fn new(model: M) -> Self {
        Self {
            model,
            data: Missing,
            filename: None,
            language: None,
            prompt: None,
            temperature: None,
            additional_params: None,
        }
    }
}
impl<M, D> TranscriptionRequestBuilder<M, D>
where
    M: TranscriptionModel,
{
    /// Set the filename, or clear it by passing `None`.
    pub fn filename(mut self, filename: Option<String>) -> Self {
        self.filename = filename;
        self
    }

    /// Supply the raw bytes to transcribe, moving the builder into the `Provided` state.
    ///
    /// Calling this on a builder that already has data replaces that data.
    pub fn data(self, data: Vec<u8>) -> TranscriptionRequestBuilder<M, Provided<Vec<u8>>> {
        let Self {
            model,
            data: _,
            filename,
            language,
            prompt,
            temperature,
            additional_params,
        } = self;

        TranscriptionRequestBuilder {
            model,
            data: Provided(data),
            filename,
            language,
            prompt,
            temperature,
            additional_params,
        }
    }

    /// Read the file at `path` (a blocking `std::fs` read) and use its bytes as the data.
    ///
    /// The path's final component, lossily converted to UTF-8, becomes the
    /// filename. If the path has no final component, any previously set
    /// filename is kept.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from `fs::read` if the file cannot be read.
    pub fn load_file<P>(
        self,
        path: P,
    ) -> io::Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>>
    where
        P: AsRef<Path>,
    {
        let path = path.as_ref();
        let bytes = fs::read(path)?;

        let builder = match path.file_name() {
            Some(name) => self.filename(Some(name.to_string_lossy().into_owned())),
            None => self,
        };

        Ok(builder.data(bytes))
    }

    /// Set the language hint.
    pub fn language(mut self, language: String) -> Self {
        self.language = Some(language);
        self
    }

    /// Set the prompt text.
    pub fn prompt(mut self, prompt: String) -> Self {
        self.prompt = Some(prompt);
        self
    }

    /// Set the sampling temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Add provider-specific JSON parameters.
    ///
    /// If parameters were already set, the new value is combined with them
    /// via `json_utils::merge`. Otherwise the new value is stored as-is.
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        let combined = match self.additional_params.take() {
            Some(existing) => json_utils::merge(existing, additional_params),
            None => additional_params,
        };
        self.additional_params = Some(combined);
        self
    }

    /// Replace (rather than merge) the additional parameters; `None` clears them.
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }
}
impl<M> TranscriptionRequestBuilder<M, Provided<Vec<u8>>>
where
    M: TranscriptionModel,
{
    /// Finalize the builder into a [`TranscriptionRequest`].
    ///
    /// If no filename was set, the request uses `"file"`.
    pub fn build(self) -> TranscriptionRequest {
        self.into_parts().1
    }

    /// Build the request and execute it with the builder's model.
    ///
    /// # Errors
    ///
    /// Returns whatever [`TranscriptionError`] the model's `transcription` call produces.
    pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
        // Move the model out together with the request instead of cloning it
        // just to work around `build` consuming `self`.
        let (model, request) = self.into_parts();
        model.transcription(request).await
    }

    /// Split the builder into its model and the finished request.
    fn into_parts(self) -> (M, TranscriptionRequest) {
        let request = TranscriptionRequest {
            data: self.data.0,
            // Build the default lazily so a caller-supplied filename costs no extra allocation.
            filename: self.filename.unwrap_or_else(|| "file".to_string()),
            language: self.language,
            prompt: self.prompt,
            temperature: self.temperature,
            additional_params: self.additional_params,
        };
        (self.model, request)
    }
}