use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Result;
/// Audio encodings that speech and transcription requests can carry.
///
/// Under serde each variant is written as its lowercase name (for example
/// `WebM` becomes `"webm"`). `Wav` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
    /// WAV audio; this is the default format.
    #[default]
    Wav,
    /// MP3 audio.
    Mp3,
    /// FLAC audio.
    Flac,
    /// Ogg audio.
    Ogg,
    /// WebM audio.
    WebM,
    /// M4A audio.
    M4a,
    /// Opus audio.
    Opus,
    /// AAC audio.
    Aac,
    /// PCM audio.
    Pcm,
}
impl AudioFormat {
#[must_use]
pub const fn extension(&self) -> &'static str {
match self {
Self::Wav => "wav",
Self::Mp3 => "mp3",
Self::Flac => "flac",
Self::Ogg => "ogg",
Self::WebM => "webm",
Self::M4a => "m4a",
Self::Opus => "opus",
Self::Aac => "aac",
Self::Pcm => "pcm",
}
}
#[must_use]
pub const fn mime_type(&self) -> &'static str {
match self {
Self::Wav => "audio/wav",
Self::Mp3 => "audio/mpeg",
Self::Flac => "audio/flac",
Self::Ogg => "audio/ogg",
Self::WebM => "audio/webm",
Self::M4a => "audio/m4a",
Self::Opus => "audio/opus",
Self::Aac => "audio/aac",
Self::Pcm => "audio/pcm",
}
}
#[must_use]
pub const fn as_str(&self) -> &'static str {
self.extension()
}
#[must_use]
pub fn from_extension(ext: &str) -> Option<Self> {
match ext.to_ascii_lowercase().as_str() {
"wav" => Some(Self::Wav),
"mp3" => Some(Self::Mp3),
"flac" => Some(Self::Flac),
"ogg" => Some(Self::Ogg),
"webm" => Some(Self::WebM),
"m4a" => Some(Self::M4a),
"opus" => Some(Self::Opus),
"aac" => Some(Self::Aac),
"pcm" => Some(Self::Pcm),
_ => None,
}
}
}
/// A voice that a text-to-speech request can select.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Voice {
    /// Identifier of the voice.
    pub id: String,
    /// Optional free-form description of the voice.
    ///
    /// Marked `#[serde(skip)]`: it is never serialized, and deserialization
    /// always leaves it as `None`.
    #[serde(skip)]
    pub description: Option<String>,
}
impl Voice {
    /// Creates a voice with the given identifier and no description.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        let id = id.into();
        Self {
            id,
            description: None,
        }
    }

    /// Builder-style setter: returns the voice with `description` replaced by
    /// `desc`.
    #[must_use]
    pub fn description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }
}
impl<S: Into<String>> From<S> for Voice {
fn from(s: S) -> Self {
Self::new(s)
}
}
/// Parameters for a text-to-speech call.
///
/// The optional fields are left out of the serialized form when they are
/// `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeechRequest {
    /// Name of the model to use.
    pub model: String,
    /// Input text for the request.
    pub input: String,
    /// Voice selected for the request.
    pub voice: Voice,
    /// Audio format requested for the result.
    pub response_format: AudioFormat,
    /// Optional speed value; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed: Option<f32>,
    /// Optional extra instructions; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
}
impl SpeechRequest {
    /// Builds a request from a model name, input text and voice.
    ///
    /// `response_format` starts as [`AudioFormat::Mp3`]; `speed` and
    /// `instructions` start unset.
    #[must_use]
    pub fn new(
        model: impl Into<String>,
        input: impl Into<String>,
        voice: impl Into<Voice>,
    ) -> Self {
        let model = model.into();
        let input = input.into();
        let voice = voice.into();
        Self {
            model,
            input,
            voice,
            response_format: AudioFormat::Mp3,
            speed: None,
            instructions: None,
        }
    }

    /// Builder-style setter for the requested audio format.
    #[must_use]
    pub const fn format(mut self, format: AudioFormat) -> Self {
        // Plain field assignment keeps this usable as a `const fn`.
        self.response_format = format;
        self
    }

    /// Builder-style setter for the speed value.
    #[must_use]
    pub const fn speed(mut self, speed: f32) -> Self {
        self.speed = Some(speed);
        self
    }

    /// Builder-style setter for the instructions text.
    #[must_use]
    pub fn instructions(self, instructions: impl Into<String>) -> Self {
        Self {
            instructions: Some(instructions.into()),
            ..self
        }
    }
}
/// Result of a text-to-speech call: the audio bytes plus their format.
#[derive(Debug, Clone)]
pub struct SpeechResponse {
    /// Audio payload as bytes.
    pub audio: Vec<u8>,
    /// Format recorded for `audio`.
    pub format: AudioFormat,
}
impl SpeechResponse {
#[must_use]
pub const fn new(audio: Vec<u8>, format: AudioFormat) -> Self {
Self { audio, format }
}
pub fn save(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
std::fs::write(path, &self.audio)
}
#[must_use]
pub const fn extension(&self) -> &'static str {
self.format.extension()
}
}
/// Output formats that can be requested for a transcription.
///
/// Under serde each variant is written in `snake_case` (for example
/// `VerboseJson` becomes `"verbose_json"`). `Json` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptionResponseFormat {
    /// `json` output; this is the default.
    #[default]
    Json,
    /// `text` output.
    Text,
    /// `srt` output.
    Srt,
    /// `vtt` output.
    Vtt,
    /// `verbose_json` output; the timestamp builders on
    /// `TranscriptionRequest` select this format.
    VerboseJson,
}
impl TranscriptionResponseFormat {
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
Self::Json => "json",
Self::Text => "text",
Self::Srt => "srt",
Self::Vtt => "vtt",
Self::VerboseJson => "verbose_json",
}
}
}
/// Level of detail at which transcription timestamps can be requested.
///
/// Under serde each variant is written as its lowercase name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
    /// Word-level timestamps.
    Word,
    /// Segment-level timestamps.
    Segment,
}
impl TimestampGranularity {
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
Self::Word => "word",
Self::Segment => "segment",
}
}
}
/// Parameters for a speech-to-text call.
///
/// `audio` and `format` are marked `#[serde(skip)]`, so they never appear in
/// the serialized form and take their `Default` value on deserialization.
/// The optional fields are left out of the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionRequest {
    /// Name of the model to use.
    pub model: String,
    /// Audio bytes to transcribe (not serialized).
    #[serde(skip)]
    pub audio: Vec<u8>,
    /// Format of `audio` (not serialized).
    #[serde(skip)]
    pub format: AudioFormat,
    /// Optional language value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Optional prompt text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// Optional output format for the transcription.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<TranscriptionResponseFormat>,
    /// Optional temperature value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional list of requested timestamp granularities.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}
impl TranscriptionRequest {
#[must_use]
pub fn new(model: impl Into<String>, audio: Vec<u8>) -> Self {
Self {
model: model.into(),
audio,
format: AudioFormat::default(),
language: None,
prompt: None,
response_format: None,
temperature: None,
timestamp_granularities: None,
}
}
#[must_use]
pub const fn format(mut self, format: AudioFormat) -> Self {
self.format = format;
self
}
#[must_use]
pub fn language(mut self, lang: impl Into<String>) -> Self {
self.language = Some(lang.into());
self
}
#[must_use]
pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
self.prompt = Some(prompt.into());
self
}
#[must_use]
pub const fn response_format(mut self, format: TranscriptionResponseFormat) -> Self {
self.response_format = Some(format);
self
}
#[must_use]
pub const fn temperature(mut self, temp: f32) -> Self {
self.temperature = Some(temp);
self
}
#[must_use]
pub fn with_word_timestamps(mut self) -> Self {
let mut granularities = self.timestamp_granularities.unwrap_or_default();
if !granularities.contains(&TimestampGranularity::Word) {
granularities.push(TimestampGranularity::Word);
}
self.timestamp_granularities = Some(granularities);
self.response_format = Some(TranscriptionResponseFormat::VerboseJson);
self
}
#[must_use]
pub fn with_segment_timestamps(mut self) -> Self {
let mut granularities = self.timestamp_granularities.unwrap_or_default();
if !granularities.contains(&TimestampGranularity::Segment) {
granularities.push(TimestampGranularity::Segment);
}
self.timestamp_granularities = Some(granularities);
self.response_format = Some(TranscriptionResponseFormat::VerboseJson);
self
}
}
/// A single word of a transcription together with its timing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionWord {
    /// The word text.
    pub word: String,
    /// Start position of the word (presumably seconds — confirm against the
    /// provider that fills this in).
    pub start: f32,
    /// End position of the word, in the same unit as `start`.
    pub end: f32,
}
/// A segment of a transcription together with its timing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionSegment {
    /// Numeric identifier of the segment.
    pub id: usize,
    /// Start position of the segment (presumably seconds — confirm against
    /// the provider that fills this in).
    pub start: f32,
    /// End position of the segment, in the same unit as `start`.
    pub end: f32,
    /// Text of the segment.
    pub text: String,
}
/// Result of a speech-to-text call.
///
/// Only `text` is always present; the optional fields are left out of the
/// serialized form when `None`. The derived `Default` is an empty `text` with
/// every optional field unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TranscriptionResponse {
    /// The transcribed text.
    pub text: String,
    /// Optional language value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Optional duration value (unit not established in this file).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration: Option<f32>,
    /// Optional per-word entries.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<TranscriptionWord>>,
    /// Optional per-segment entries.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub segments: Option<Vec<TranscriptionSegment>>,
}
impl TranscriptionResponse {
    /// Creates a response holding `text`, with every other field at its
    /// default (unset) value.
    #[must_use]
    pub fn new(text: impl Into<String>) -> Self {
        let text = text.into();
        Self {
            text,
            ..Self::default()
        }
    }

    /// Builder-style setter for the language value.
    #[must_use]
    pub fn with_language(self, language: impl Into<String>) -> Self {
        Self {
            language: Some(language.into()),
            ..self
        }
    }

    /// Builder-style setter for the duration value.
    #[must_use]
    pub const fn with_duration(mut self, duration: f32) -> Self {
        // Plain field assignment keeps this usable as a `const fn`.
        self.duration = Some(duration);
        self
    }
}
#[async_trait]
pub trait TextToSpeechProvider: Send + Sync {
async fn speech(&self, request: &SpeechRequest) -> Result<SpeechResponse>;
async fn speech_to_file(
&self,
request: &SpeechRequest,
path: impl AsRef<std::path::Path> + Send,
) -> Result<SpeechResponse> {
use crate::error::LlmError;
let response = self.speech(request).await?;
response
.save(&path)
.map_err(|e| LlmError::internal(format!("Failed to save audio file: {e}")))?;
Ok(response)
}
fn available_voices(&self) -> Vec<Voice> {
Vec::new()
}
}
#[async_trait]
pub trait SpeechToTextProvider: Send + Sync {
async fn transcribe(&self, request: &TranscriptionRequest) -> Result<TranscriptionResponse>;
async fn transcribe_file(&self, model: &str, file_path: &str) -> Result<TranscriptionResponse> {
use crate::error::LlmError;
let audio = std::fs::read(file_path)
.map_err(|e| LlmError::internal(format!("Failed to read audio file: {e}")))?;
let format = std::path::Path::new(file_path)
.extension()
.and_then(|ext| ext.to_str())
.and_then(AudioFormat::from_extension)
.unwrap_or_default();
let request = TranscriptionRequest::new(model, audio).format(format);
self.transcribe(&request).await
}
}