#[allow(deprecated)]
use crate::client::transcription::TranscriptionModelHandle;
use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
use crate::{http_client, json_utils};
use std::sync::Arc;
use std::{fs, path::Path};
use thiserror::Error;
/// Errors produced by the transcription API in this module.
///
/// The enum is `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
/// Variants marked `#[from]` can be produced from their source error with `?`.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TranscriptionError {
/// An error from the crate's HTTP client layer (`http_client::Error`).
#[error("HttpError: {0}")]
HttpError(#[from] http_client::Error),
/// A `serde_json` (de)serialization error.
#[error("JsonError: {0}")]
JsonError(#[from] serde_json::Error),
/// A type-erased boxed error; on non-wasm targets it must be `Send + Sync`.
#[cfg(not(target_family = "wasm"))]
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
/// A type-erased boxed error; on wasm targets no `Send + Sync` bound is required.
#[cfg(target_family = "wasm")]
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + 'static>),
/// An error described by a message string. NOTE(review): presumably raised
/// when a provider response cannot be used — confirm at construction sites.
#[error("ResponseError: {0}")]
ResponseError(String),
/// An error message attributed to the provider. NOTE(review): assumed to be
/// reported by the upstream service — verify at construction sites.
#[error("ProviderError: {0}")]
ProviderError(String),
}
/// Entry point for starting a transcription from an in-memory file.
///
/// Implementors turn a file name and its raw bytes into a
/// [`TranscriptionRequestBuilder`] for model `M`. The caller can configure the
/// builder further and then send it.
pub trait Transcription<M>
where
M: TranscriptionModel,
{
/// Starts a transcription for `data`, labelled with `filename`.
///
/// The returned future is `WasmCompatSend`.
///
/// # Errors
///
/// Resolves to a [`TranscriptionError`] if the implementor cannot produce a builder.
fn transcription(
&self,
filename: &str,
data: &[u8],
) -> impl std::future::Future<
Output = Result<TranscriptionRequestBuilder<M>, TranscriptionError>,
> + WasmCompatSend;
}
/// The result of a transcription call.
///
/// `T` is the model's `TranscriptionModel::Response` type. It is `()` when the
/// response comes through the type-erased `TranscriptionModelDyn` path.
pub struct TranscriptionResponse<T> {
/// The transcribed text.
pub text: String,
/// The model-specific response value that accompanies `text`.
pub response: T,
}
/// A model that can execute [`TranscriptionRequest`]s.
///
/// Models must be cheaply usable by value: `transcription_request` clones the
/// model into each builder it creates.
pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// Client type from which a model instance is constructed via `make`.
    type Client;

    /// Model-specific response value returned alongside the transcribed text.
    type Response: WasmCompatSend + WasmCompatSync;

    /// Creates a model instance from `client`.
    ///
    /// `model` presumably identifies which provider model to use — confirm per provider.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Executes `request`, resolving to the transcribed text plus the model's response.
    ///
    /// # Errors
    ///
    /// Resolves to a [`TranscriptionError`] when the implementation fails.
    fn transcription(
        &self,
        request: TranscriptionRequest,
    ) -> impl std::future::Future<
        Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
    > + WasmCompatSend;

    /// Starts a [`TranscriptionRequestBuilder`] that owns a clone of this model.
    fn transcription_request(&self) -> TranscriptionRequestBuilder<Self> {
        let owned_model = Self::clone(self);
        TranscriptionRequestBuilder::new(owned_model)
    }
}
/// Type-erased counterpart of [`TranscriptionModel`].
///
/// The methods return a boxed future and a builder over `TranscriptionModelHandle`
/// instead of associated types. NOTE(review): presumably this is so the trait can
/// be used behind `dyn` — confirm against `TranscriptionModelHandle`.
/// Deprecated since 0.25.0; prefer [`TranscriptionModel`].
#[allow(deprecated)]
#[deprecated(
since = "0.25.0",
note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `TranscriptionModel` instead."
)]
pub trait TranscriptionModelDyn: WasmCompatSend + WasmCompatSync {
/// Executes `request`.
///
/// The model-specific response is discarded; it is replaced by `()`.
fn transcription(
&self,
request: TranscriptionRequest,
) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>>;
/// Starts a request builder bound to a `TranscriptionModelHandle` for this model.
fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>>;
}
// Blanket impl: every concrete `TranscriptionModel` is usable through the
// deprecated type-erased `TranscriptionModelDyn` interface.
#[allow(deprecated)]
impl<T> TranscriptionModelDyn for T
where
    T: TranscriptionModel,
{
    /// Delegates to [`TranscriptionModel::transcription`].
    ///
    /// Only the transcribed text is kept; the model-specific response is replaced by `()`.
    fn transcription(
        &self,
        request: TranscriptionRequest,
    ) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>> {
        Box::pin(async move {
            // Fully qualified so it is obvious this calls the concrete model,
            // not this (same-named) type-erased method.
            <T as TranscriptionModel>::transcription(self, request)
                .await
                .map(|full| TranscriptionResponse {
                    text: full.text,
                    response: (),
                })
        })
    }

    /// Clones the model into a shared handle and starts a builder for that handle.
    fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>> {
        let handle = TranscriptionModelHandle {
            inner: Arc::new(T::clone(self)),
        };
        TranscriptionRequestBuilder::new(handle)
    }
}
/// A fully built transcription request, normally produced by `TranscriptionRequestBuilder::build`.
pub struct TranscriptionRequest {
/// Raw file bytes to transcribe; the builder refuses (panics on) an empty buffer.
pub data: Vec<u8>,
/// File name associated with `data`; the builder defaults it to `"file"`.
pub filename: String,
/// Optional language setting. NOTE(review): the expected format is presumably
/// provider-defined — confirm.
pub language: Option<String>,
/// Optional prompt text.
pub prompt: Option<String>,
/// Optional temperature. NOTE(review): presumably a sampling temperature whose
/// valid range is provider-defined — confirm.
pub temperature: Option<f64>,
/// Optional extra JSON parameters. The builder combines repeated
/// `additional_params` calls with `json_utils::merge`.
pub additional_params: Option<serde_json::Value>,
}
/// Builder for a [`TranscriptionRequest`].
///
/// The builder also carries the model that `send` uses to execute the request.
pub struct TranscriptionRequestBuilder<M>
where
M: TranscriptionModel,
{
/// Model that executes the request when `send` is called.
model: M,
/// Raw bytes to transcribe; `build` panics if this is still empty.
data: Vec<u8>,
/// File name; `build` substitutes `"file"` when this is `None`.
filename: Option<String>,
/// Optional language setting, copied into the request.
language: Option<String>,
/// Optional prompt, copied into the request.
prompt: Option<String>,
/// Optional temperature, copied into the request.
temperature: Option<f64>,
/// Extra JSON parameters; `additional_params` merges into this, `additional_params_opt` replaces it.
additional_params: Option<serde_json::Value>,
}
impl<M> TranscriptionRequestBuilder<M>
where
    M: TranscriptionModel,
{
    /// Creates an empty builder bound to `model`.
    ///
    /// `data` starts empty and every optional setting starts as `None`. Audio
    /// must be supplied with [`data`](Self::data) or [`load_file`](Self::load_file)
    /// before calling [`build`](Self::build) or [`send`](Self::send); both panic
    /// on empty data.
    pub fn new(model: M) -> Self {
        TranscriptionRequestBuilder {
            model,
            data: vec![],
            filename: None,
            language: None,
            prompt: None,
            temperature: None,
            additional_params: None,
        }
    }

    /// Sets the file name attached to the request, or clears it with `None`.
    ///
    /// When no name is set, [`build`](Self::build) uses `"file"`.
    pub fn filename(mut self, filename: Option<String>) -> Self {
        self.filename = filename;
        self
    }

    /// Sets the raw bytes to transcribe, replacing any previously set data.
    pub fn data(mut self, data: Vec<u8>) -> Self {
        self.data = data;
        self
    }

    /// Reads the file at `path` into the request.
    ///
    /// The file's bytes become the request data, and the final path component
    /// becomes the file name.
    ///
    /// A file name that is not valid UTF-8 is converted lossily (invalid
    /// sequences become U+FFFD) instead of causing a panic. This performs
    /// blocking file I/O.
    ///
    /// # Panics
    ///
    /// - If the file cannot be read. The message includes the path and the
    ///   underlying I/O error.
    /// - If `path` has no final file-name component (e.g. it ends in `..`).
    pub fn load_file<P>(self, path: P) -> Self
    where
        P: AsRef<Path>,
    {
        let path = path.as_ref();
        // Report the real I/O error (permission denied, is-a-directory, ...),
        // not just "did not exist".
        let data = fs::read(path).unwrap_or_else(|err| {
            panic!("Failed to load audio file `{}`: {err}", path.display())
        });
        let filename = path
            .file_name()
            .unwrap_or_else(|| panic!("Path `{}` does not name a file", path.display()))
            // The name is only a label for the upload, so a lossy conversion
            // beats aborting on non-UTF-8 names.
            .to_string_lossy()
            .into_owned();
        self.filename(Some(filename)).data(data)
    }

    /// Sets the language setting for the request.
    pub fn language(mut self, language: String) -> Self {
        self.language = Some(language);
        self
    }

    /// Sets the prompt text for the request.
    pub fn prompt(mut self, prompt: String) -> Self {
        self.prompt = Some(prompt);
        self
    }

    /// Sets the temperature for the request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Adds extra JSON parameters to the request.
    ///
    /// If parameters are already present, the new value is combined with them
    /// via `json_utils::merge`. Otherwise it becomes the parameters as-is.
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        self.additional_params = Some(match self.additional_params.take() {
            Some(existing) => json_utils::merge(existing, additional_params),
            None => additional_params,
        });
        self
    }

    /// Replaces the extra JSON parameters outright, with no merging.
    ///
    /// Passing `None` clears them.
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Finalizes the builder into a [`TranscriptionRequest`], dropping the model.
    ///
    /// # Panics
    ///
    /// Panics with `"Data cannot be empty!"` if no audio data was provided.
    pub fn build(self) -> TranscriptionRequest {
        self.into_parts().1
    }

    /// Builds the request and executes it with the builder's model.
    ///
    /// # Errors
    ///
    /// Returns whatever [`TranscriptionError`] the model's
    /// [`TranscriptionModel::transcription`] produces.
    ///
    /// # Panics
    ///
    /// Panics with `"Data cannot be empty!"` when the returned future is first
    /// polled, if no audio data was provided.
    pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
        // Splitting the builder hands the model over by value, so no clone is needed.
        let (model, request) = self.into_parts();
        model.transcription(request).await
    }

    /// Validates the builder and splits it into its model and the finished request.
    ///
    /// Panics with `"Data cannot be empty!"` when `data` is empty.
    fn into_parts(self) -> (M, TranscriptionRequest) {
        assert!(!self.data.is_empty(), "Data cannot be empty!");
        let request = TranscriptionRequest {
            data: self.data,
            filename: self.filename.unwrap_or_else(|| "file".to_string()),
            language: self.language,
            prompt: self.prompt,
            temperature: self.temperature,
            additional_params: self.additional_params,
        };
        (self.model, request)
    }
}