use serde::{Deserialize, Serialize};
use crate::provider::Provider;
/// Speech-to-text models this module knows about.
///
/// Per-model metadata (description, translation support, provider mappings)
/// is available through `TranscriptionModel::info`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TranscriptionModel {
    /// Described by `info()` as "Fast multilingual transcription".
    WhisperLargeV3Turbo,
    /// Described by `info()` as "High accuracy transcription and translation".
    WhisperLargeV3,
}
/// Metadata record describing a single [`TranscriptionModel`].
#[derive(Debug, Clone)]
pub struct TranscriptionModelInfo {
    /// The model this record describes.
    pub model: TranscriptionModel,
    /// Human-readable, one-line description of the model.
    pub description: &'static str,
    /// Whether the model is flagged as supporting translation.
    pub supports_translation: bool,
    /// Providers that serve this model, each with its provider-specific model id.
    pub providers: Vec<TranscriptionProviderMapping>,
}
/// Pairs a [`Provider`] with the identifier that provider uses for a model.
#[derive(Debug, Clone)]
pub struct TranscriptionProviderMapping {
    /// The provider serving the model.
    pub provider: Provider,
    /// Provider-specific model identifier (e.g. `"whisper-large-v3-turbo"`).
    pub model_id: &'static str,
}
impl TranscriptionProviderMapping {
    /// Creates a mapping from `provider` to that provider's `model_id`.
    ///
    /// This is a `const fn`, so it can also be evaluated in constant contexts.
    pub const fn new(provider: Provider, model_id: &'static str) -> Self {
        Self { provider, model_id }
    }
}
impl TranscriptionModel {
    /// Builds the metadata record for this model.
    ///
    /// A new [`TranscriptionModelInfo`] (including a freshly allocated provider
    /// list) is constructed on every call.
    pub fn info(&self) -> TranscriptionModelInfo {
        // Exhaustive match on purpose: adding a variant must force an update here.
        match self {
            Self::WhisperLargeV3Turbo => TranscriptionModelInfo {
                model: *self,
                description: "Whisper Large V3 Turbo - Fast multilingual transcription",
                supports_translation: false,
                providers: vec![TranscriptionProviderMapping::new(
                    Provider::Groq,
                    "whisper-large-v3-turbo",
                )],
            },
            Self::WhisperLargeV3 => TranscriptionModelInfo {
                model: *self,
                description: "Whisper Large V3 - High accuracy transcription and translation",
                supports_translation: true,
                providers: vec![TranscriptionProviderMapping::new(
                    Provider::Groq,
                    "whisper-large-v3",
                )],
            },
        }
    }

    /// Returns the human-readable display name of this model.
    pub fn name(&self) -> &'static str {
        match self {
            Self::WhisperLargeV3Turbo => "Whisper Large V3 Turbo",
            Self::WhisperLargeV3 => "Whisper Large V3",
        }
    }

    /// Returns every known model, in declaration order.
    pub fn all() -> &'static [TranscriptionModel] {
        &[Self::WhisperLargeV3Turbo, Self::WhisperLargeV3]
    }
}

/// The default model is [`TranscriptionModel::WhisperLargeV3Turbo`].
///
/// Provided as a real `Default` implementation (rather than an inherent
/// `default()` function) so the type works with `#[derive(Default)]` on
/// containing structs, `Option::unwrap_or_default`, and `T: Default` bounds.
/// Existing `TranscriptionModel::default()` call sites keep working because
/// `Default` is in the prelude.
impl Default for TranscriptionModel {
    fn default() -> Self {
        Self::WhisperLargeV3Turbo
    }
}
/// Formats the model using its human-readable [`TranscriptionModel::name`].
impl std::fmt::Display for TranscriptionModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width/fill/alignment/precision given by the caller
        // (e.g. `{:<30}`); `write!(f, "{}", ..)` would silently drop them.
        // With a plain `{}` the output is exactly the name, as before.
        f.pad(self.name())
    }
}
/// Response format selector for a transcription.
///
/// Variants (de)serialize in `snake_case`: `"json"`, `"text"`, `"verbose_json"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptionResponseFormat {
    /// Serialized as `"json"`.
    Json,
    /// Serialized as `"text"`.
    Text,
    /// Serialized as `"verbose_json"`.
    VerboseJson,
}
impl Default for TranscriptionResponseFormat {
fn default() -> Self {
TranscriptionResponseFormat::Json
}
}
/// Granularity selector for transcription timestamps.
///
/// Variants (de)serialize in `snake_case`: `"word"` and `"segment"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TimestampGranularity {
    /// Serialized as `"word"`.
    Word,
    /// Serialized as `"segment"`.
    Segment,
}
/// Optional settings for a transcription.
///
/// Every field is optional and the derived `Default` leaves all of them `None`.
/// This type only stores the values; how they are sent to a provider is
/// decided elsewhere.
#[derive(Debug, Clone, Default)]
pub struct TranscriptionOptions {
    /// Language hint (this file's tests use `"en"`).
    pub language: Option<String>,
    /// Free-form prompt text.
    pub prompt: Option<String>,
    /// Requested response format.
    pub response_format: Option<TranscriptionResponseFormat>,
    /// Temperature setting; no range validation is applied here.
    pub temperature: Option<f32>,
    /// Requested timestamp granularities.
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}
impl TranscriptionOptions {
    /// Creates an options value with every field unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with `language` set, replacing any previous value.
    pub fn with_language(self, language: impl Into<String>) -> Self {
        Self {
            language: Some(language.into()),
            ..self
        }
    }

    /// Returns `self` with `prompt` set, replacing any previous value.
    pub fn with_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Returns `self` with `response_format` set, replacing any previous value.
    pub fn with_response_format(self, format: TranscriptionResponseFormat) -> Self {
        Self {
            response_format: Some(format),
            ..self
        }
    }

    /// Returns `self` with `temperature` set, replacing any previous value.
    ///
    /// The value is stored as given; it is not validated.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Returns `self` with `timestamp_granularities` set, replacing any previous value.
    pub fn with_timestamp_granularities(self, granularities: Vec<TimestampGranularity>) -> Self {
        Self {
            timestamp_granularities: Some(granularities),
            ..self
        }
    }
}
/// Minimal transcription payload: only the transcribed text.
///
/// Deserializes from an object such as `{"text": "Hello, world!"}`.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionResponse {
    /// The transcribed text (required).
    pub text: String,
}
/// Transcription payload with optional extra detail alongside the text.
///
/// Only `text` is required; every other field becomes `None` when it is
/// missing from the input.
#[derive(Debug, Clone, Deserialize)]
pub struct VerboseTranscriptionResponse {
    /// The transcribed text (required).
    pub text: String,
    /// Language reported in the payload, if any.
    #[serde(default)]
    pub language: Option<String>,
    /// Duration reported in the payload, if any (presumably seconds — confirm
    /// against the provider's documentation).
    #[serde(default)]
    pub duration: Option<f64>,
    /// Segment-level entries, if present.
    #[serde(default)]
    pub segments: Option<Vec<TranscriptionSegment>>,
    /// Word-level entries, if present.
    #[serde(default)]
    pub words: Option<Vec<TranscriptionWord>>,
}
/// One segment entry of a [`VerboseTranscriptionResponse`].
///
/// `id`, `start`, `end` and `text` are required; the remaining fields are
/// optional and become `None` when missing from the input.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionSegment {
    /// Segment identifier (required).
    pub id: i32,
    /// Optional `seek` value as reported in the payload.
    #[serde(default)]
    pub seek: Option<i32>,
    /// Segment start (presumably seconds from the start of the audio — confirm).
    pub start: f64,
    /// Segment end (presumably seconds from the start of the audio — confirm).
    pub end: f64,
    /// Text of this segment (required).
    pub text: String,
    /// Optional `tokens` list as reported in the payload.
    #[serde(default)]
    pub tokens: Option<Vec<i32>>,
    /// Optional `avg_logprob` value as reported in the payload.
    #[serde(default)]
    pub avg_logprob: Option<f64>,
    /// Optional `compression_ratio` value as reported in the payload.
    #[serde(default)]
    pub compression_ratio: Option<f64>,
    /// Optional `no_speech_prob` value as reported in the payload.
    #[serde(default)]
    pub no_speech_prob: Option<f64>,
}
/// One word entry of a [`VerboseTranscriptionResponse`].
///
/// All three fields are required when deserializing.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionWord {
    /// The word text.
    pub word: String,
    /// Word start (presumably seconds from the start of the audio — confirm).
    pub start: f64,
    /// Word end (presumably seconds from the start of the audio — confirm).
    pub end: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata for both models: providers are present and the translation
    /// flag matches the expectation for each variant.
    #[test]
    fn test_transcription_model_info() {
        let turbo = TranscriptionModel::WhisperLargeV3Turbo.info();
        assert!(!turbo.providers.is_empty());
        assert!(!turbo.supports_translation);

        let large = TranscriptionModel::WhisperLargeV3.info();
        assert!(large.supports_translation);
    }

    /// Every model listed by `all()` must map to at least one provider.
    #[test]
    fn test_all_transcription_models_have_providers() {
        TranscriptionModel::all().iter().for_each(|model| {
            let providers = model.info().providers;
            assert!(
                !providers.is_empty(),
                "Transcription model {} has no providers",
                model.name()
            );
        });
    }

    /// The builder methods store exactly the values they were given.
    #[test]
    fn test_transcription_options() {
        let built = TranscriptionOptions::new()
            .with_language("en")
            .with_temperature(0.2)
            .with_response_format(TranscriptionResponseFormat::VerboseJson);

        assert_eq!(built.language.as_deref(), Some("en"));
        assert_eq!(built.temperature, Some(0.2));
        assert_eq!(
            built.response_format,
            Some(TranscriptionResponseFormat::VerboseJson)
        );
    }

    /// The minimal payload deserializes from a bare `text` field.
    #[test]
    fn test_transcription_response_parsing() {
        let payload = r#"{"text": "Hello, world!"}"#;
        let parsed = serde_json::from_str::<TranscriptionResponse>(payload).unwrap();
        assert_eq!(parsed.text, "Hello, world!");
    }

    /// The verbose payload deserializes with only the required segment fields.
    #[test]
    fn test_verbose_transcription_response_parsing() {
        let payload = r#"{
"text": "Hello, world!",
"language": "en",
"duration": 1.5,
"segments": [{
"id": 0,
"start": 0.0,
"end": 1.5,
"text": "Hello, world!"
}]
}"#;
        let parsed = serde_json::from_str::<VerboseTranscriptionResponse>(payload).unwrap();
        assert_eq!(parsed.text, "Hello, world!");
        assert_eq!(parsed.language.as_deref(), Some("en"));
        assert_eq!(parsed.duration, Some(1.5));
        assert!(parsed.segments.is_some());
    }
}