use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::core::providers::unified_provider::ProviderError;
/// Path of Deepgram's `/listen` transcription endpoint; `build_stt_url` appends it
/// to the (slash-trimmed) API base URL before the query string.
pub const STT_ENDPOINT_PATH: &str = "/listen";
/// Parameters for one Deepgram speech-to-text request.
///
/// `model` is always sent as a query parameter. Every `Option` field other than
/// `filename` is forwarded by `build_query_params` only when it is `Some`; the
/// list-valued fields (`search`, `keywords`) become one repeated parameter per entry.
#[derive(Clone, Debug)]
pub struct TranscriptionRequest {
    /// Raw bytes of the audio to transcribe (not part of the query string).
    pub file: Vec<u8>,
    /// Deepgram model name, sent as `model=`.
    pub model: String,
    /// Language hint, sent as `language=`.
    pub language: Option<String>,
    /// Sent as `smart_format=` when set.
    pub smart_format: Option<bool>,
    /// Sent as `punctuate=` when set.
    pub punctuate: Option<bool>,
    /// Sent as `diarize=` when set.
    pub diarize: Option<bool>,
    /// Sent as `paragraphs=` when set.
    pub paragraphs: Option<bool>,
    /// Sent as `utterances=` when set.
    pub utterances: Option<bool>,
    /// Sent as `words=` when set.
    pub words: Option<bool>,
    /// Search terms; each one is sent as its own `search=` parameter.
    pub search: Option<Vec<String>>,
    /// Keywords; each one is sent as its own `keywords=` parameter.
    pub keywords: Option<Vec<String>>,
    /// Sent as `filler_words=` when set.
    pub filler_words: Option<bool>,
    /// Sent as `detect_language=` when set.
    pub detect_language: Option<bool>,
    /// Original name of the audio file, if known. Not sent as a query parameter;
    /// presumably used by callers for MIME detection — TODO confirm.
    pub filename: Option<String>,
}
impl Default for TranscriptionRequest {
fn default() -> Self {
Self {
file: Vec::new(),
model: "nova-2".to_string(),
language: None,
smart_format: None,
punctuate: None,
diarize: None,
paragraphs: None,
utterances: None,
words: None,
search: None,
keywords: None,
filler_words: None,
detect_language: None,
filename: None,
}
}
}
/// Top-level JSON body returned by Deepgram's `/listen` endpoint.
///
/// Converted into an `OpenAITranscriptionResponse` via `TryFrom`.
#[derive(Clone, Debug, Deserialize)]
pub struct DeepgramResponse {
    /// Request-level metadata such as the request id and audio duration.
    pub metadata: ResponseMetadata,
    /// Per-channel transcription results (and optional utterances).
    pub results: TranscriptionResults,
}
/// Request-level metadata attached to a Deepgram transcription response.
///
/// This type is deserialize-only. Optional fields may be absent from the JSON:
/// serde's derive maps a missing `Option<T>` field to `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct ResponseMetadata {
    /// Transaction key, when Deepgram includes one.
    pub transaction_key: Option<String>,
    /// Identifier Deepgram assigns to the request (required).
    pub request_id: String,
    /// Checksum field, when present in the response.
    pub sha256: Option<String>,
    /// Creation timestamp, kept as the raw string from the response.
    pub created: Option<String>,
    /// Audio duration; copied into `OpenAITranscriptionResponse::duration`.
    /// Presumably in seconds — confirm against Deepgram's API reference.
    pub duration: f32,
    /// Number of audio channels, when reported.
    pub channels: Option<u32>,
    /// Model identifiers used for the request, when reported.
    pub models: Option<Vec<String>>,
    /// Free-form per-model details, kept as untyped JSON.
    pub model_info: Option<HashMap<String, serde_json::Value>>,
}
/// The `results` object of a Deepgram response.
///
/// Deserialize-only; a missing `utterances` key deserializes as `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionResults {
    /// One entry per audio channel; the OpenAI conversion reads only the first.
    pub channels: Vec<ChannelResult>,
    /// Utterance segmentation, present when it was requested.
    pub utterances: Option<Vec<Utterance>>,
}
/// Transcription output for a single audio channel.
///
/// Deserialize-only; absent optional keys deserialize as `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct ChannelResult {
    /// Candidate transcripts; the OpenAI conversion reads only the first.
    pub alternatives: Vec<TranscriptionAlternative>,
    /// Detected language code. When absent, the OpenAI conversion reports `"en"`.
    pub detected_language: Option<String>,
    /// Confidence of `detected_language`, when reported.
    pub language_confidence: Option<f32>,
}
/// One candidate transcript for a channel.
///
/// Deserialize-only; `words` and `paragraphs` are `None` when absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionAlternative {
    /// Plain transcript text.
    pub transcript: String,
    /// Confidence score reported for this alternative.
    pub confidence: f32,
    /// Word-level timings, and speaker labels when diarized.
    pub words: Option<Vec<WordInfo>>,
    /// Paragraph segmentation, including a formatted transcript.
    pub paragraphs: Option<Paragraphs>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct WordInfo {
pub word: String,
pub start: f32,
pub end: f32,
pub confidence: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub speaker: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub punctuated_word: Option<String>,
}
/// Paragraph segmentation attached to a transcription alternative.
#[derive(Clone, Debug, Deserialize)]
pub struct Paragraphs {
    /// Formatted transcript. The OpenAI conversion uses it as the text when
    /// diarization is detected.
    pub transcript: String,
    /// The individual paragraphs.
    pub paragraphs: Vec<Paragraph>,
}
/// One paragraph of a segmented transcript.
///
/// Deserialize-only; `speaker` is `None` when absent (e.g. without diarization).
#[derive(Debug, Clone, Deserialize)]
pub struct Paragraph {
    /// Sentences making up the paragraph.
    pub sentences: Vec<Sentence>,
    /// Speaker index, when diarized.
    pub speaker: Option<u32>,
    /// Number of words in the paragraph.
    pub num_words: u32,
    /// Start time (presumably seconds).
    pub start: f32,
    /// End time (same unit as `start`).
    pub end: f32,
}
/// A single sentence inside a [`Paragraph`].
#[derive(Clone, Debug, Deserialize)]
pub struct Sentence {
    /// Sentence text.
    pub text: String,
    /// Start time (presumably seconds).
    pub start: f32,
    /// End time (same unit as `start`).
    pub end: f32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Utterance {
pub start: f32,
pub end: f32,
pub confidence: f32,
pub transcript: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub speaker: Option<u32>,
pub channel: u32,
pub id: String,
}
/// OpenAI-style transcription response produced from a [`DeepgramResponse`].
///
/// `duration` and `words` are omitted from the serialized JSON when `None`.
#[derive(Clone, Debug, Serialize)]
pub struct OpenAITranscriptionResponse {
    /// Transcript text; speaker-labelled when diarization was detected.
    pub text: String,
    /// Task name; the Deepgram conversion always sets `"transcribe"`.
    pub task: String,
    /// Language code; the Deepgram conversion falls back to `"en"`.
    pub language: String,
    /// Audio duration taken from the Deepgram metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration: Option<f32>,
    /// Word-level timings, when Deepgram returned words.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<OpenAIWordInfo>>,
}
/// A word with its timing, as emitted in [`OpenAITranscriptionResponse::words`].
#[derive(Clone, Debug, Serialize)]
pub struct OpenAIWordInfo {
    /// The raw (unpunctuated) word from Deepgram.
    pub word: String,
    /// Start time, copied from Deepgram.
    pub start: f32,
    /// End time, copied from Deepgram.
    pub end: f32,
}
impl TryFrom<DeepgramResponse> for OpenAITranscriptionResponse {
    type Error = ProviderError;

    /// Converts a Deepgram response into an [`OpenAITranscriptionResponse`].
    ///
    /// Only the first channel and that channel's first alternative are used.
    ///
    /// Text selection:
    /// - The plain transcript is used unless the first word carries a speaker label.
    /// - With a speaker label, the paragraph transcript is preferred when present.
    /// - Otherwise a speaker-labelled transcript is rebuilt from the words.
    ///
    /// The language falls back to `"en"` when none was detected. The duration is
    /// copied from the response metadata.
    ///
    /// # Errors
    /// Returns a `ProviderError::response_parsing` error when the response has
    /// no channels, or when the first channel has no alternatives.
    fn try_from(response: DeepgramResponse) -> Result<Self, Self::Error> {
        // `response` is owned, so move its strings out rather than cloning them.
        let duration = response.metadata.duration;

        let channel = response.results.channels.into_iter().next().ok_or_else(|| {
            ProviderError::response_parsing("deepgram", "Response contains no channels")
        })?;
        let alternative = channel.alternatives.into_iter().next().ok_or_else(|| {
            ProviderError::response_parsing("deepgram", "Channel contains no alternatives")
        })?;

        // Diarization is inferred from the first word carrying a speaker label.
        let has_diarization = alternative
            .words
            .as_deref()
            .and_then(|words| words.first())
            .is_some_and(|word| word.speaker.is_some());

        let text = if !has_diarization {
            alternative.transcript
        } else if let Some(paragraphs) = alternative.paragraphs {
            paragraphs.transcript
        } else if let Some(words) = alternative.words.as_deref() {
            reconstruct_diarized_transcript(words)
        } else {
            alternative.transcript
        };

        // `words` was only borrowed above, so it can now be consumed.
        let words = alternative.words.map(|words| {
            words
                .into_iter()
                .map(|w| OpenAIWordInfo {
                    word: w.word,
                    start: w.start,
                    end: w.end,
                })
                .collect()
        });

        let language = channel
            .detected_language
            .unwrap_or_else(|| "en".to_string());

        Ok(OpenAITranscriptionResponse {
            text,
            task: "transcribe".to_string(),
            language,
            duration: Some(duration),
            words,
        })
    }
}
fn reconstruct_diarized_transcript(words: &[WordInfo]) -> String {
if words.is_empty() {
return String::new();
}
let mut segments = Vec::new();
let mut current_speaker: Option<u32> = None;
let mut current_words: Vec<String> = Vec::new();
for word_obj in words {
let speaker = word_obj.speaker;
let word_text = word_obj
.punctuated_word
.clone()
.unwrap_or_else(|| word_obj.word.clone());
if speaker != current_speaker {
if !current_words.is_empty()
&& let Some(sp) = current_speaker
{
segments.push(format!("Speaker {}: {}", sp, current_words.join(" ")));
}
current_speaker = speaker;
current_words = vec![word_text];
} else {
current_words.push(word_text);
}
}
if !current_words.is_empty()
&& let Some(sp) = current_speaker
{
segments.push(format!("\nSpeaker {}: {}\n", sp, current_words.join(" ")));
}
segments.join("\n")
}
/// Builds the URL query string (without a leading `?`) for a Deepgram request.
///
/// Parameter order:
/// 1. `model`, which is always present.
/// 2. `language`, when set.
/// 3. Each boolean flag that is `Some`, in field order.
/// 4. One `keywords=` pair per keyword.
/// 5. One `search=` pair per search term.
///
/// String values are percent-encoded (RFC 3986 unreserved characters pass
/// through). This keeps spaces, `&`, `#`, `+`, `=` or non-ASCII text in
/// user-supplied terms from corrupting or truncating the URL.
pub fn build_query_params(request: &TranscriptionRequest) -> String {
    let mut params = vec![format!("model={}", encode_query_value(&request.model))];

    if let Some(lang) = &request.language {
        params.push(format!("language={}", encode_query_value(lang)));
    }

    let flags = [
        ("smart_format", request.smart_format),
        ("punctuate", request.punctuate),
        ("diarize", request.diarize),
        ("paragraphs", request.paragraphs),
        ("utterances", request.utterances),
        ("words", request.words),
        ("filler_words", request.filler_words),
        ("detect_language", request.detect_language),
    ];
    params.extend(
        flags
            .into_iter()
            .filter_map(|(name, value)| value.map(|on| format!("{name}={on}"))),
    );

    // List-valued options become one repeated parameter per entry.
    let lists = [("keywords", &request.keywords), ("search", &request.search)];
    for (name, values) in lists {
        for value in values.iter().flatten() {
            params.push(format!("{name}={}", encode_query_value(value)));
        }
    }

    params.join("&")
}

/// Percent-encodes `value` for use as a URL query value: RFC 3986 unreserved
/// bytes (`A-Z a-z 0-9 - . _ ~`) pass through, every other byte becomes `%XX`.
fn encode_query_value(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
/// Builds the full Deepgram transcription URL for `request`.
///
/// Trailing `/` characters on `base_url` are stripped, then [`STT_ENDPOINT_PATH`]
/// is appended. The query string from [`build_query_params`] follows after a
/// `?`, but only when that query string is non-empty.
pub fn build_stt_url(base_url: &str, request: &TranscriptionRequest) -> String {
    let mut url = String::from(base_url.trim_end_matches('/'));
    url.push_str(STT_ENDPOINT_PATH);

    let query = build_query_params(request);
    if !query.is_empty() {
        url.push('?');
        url.push_str(&query);
    }
    url
}
/// Guesses the MIME type of an audio file from its file-name extension.
///
/// The extension is the text after the last `.` and is matched
/// case-insensitively. Names without a `.` have no extension. Like
/// unrecognised extensions (e.g. raw `pcm`), they fall back to `"audio/mpeg"`.
pub fn detect_audio_mime_type(filename: &str) -> &'static str {
    let extension = filename
        .rsplit_once('.')
        .map(|(_, ext)| ext.to_ascii_lowercase())
        .unwrap_or_default();
    match extension.as_str() {
        // MP2 is MPEG-1 Audio Layer II and shares the MPEG audio media type.
        "mp3" | "mp2" => "audio/mpeg",
        "mp4" | "m4a" => "audio/mp4",
        "aac" => "audio/aac",
        "wav" => "audio/wav",
        "webm" => "audio/webm",
        // `.opus` files are Ogg-encapsulated Opus (RFC 7845), i.e. `audio/ogg`.
        "ogg" | "oga" | "opus" => "audio/ogg",
        "flac" => "audio/flac",
        _ => "audio/mpeg",
    }
}
/// Lists the audio file extensions this provider treats as supported.
///
/// Extensions are lowercase and carry no leading dot; the order is fixed.
pub fn supported_audio_formats() -> &'static [&'static str] {
    const FORMATS: &[&str] = &[
        "mp3", "mp4", "mp2", "aac", "wav", "flac", "pcm", "m4a", "ogg", "opus", "webm",
    ];
    FORMATS
}
/// Deepgram speech-to-text models known to this provider.
///
/// [`STTModel::as_str`] gives the model name sent to Deepgram, and
/// [`STTModel::parse`] performs the reverse lookup. `Nova2` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum STTModel {
    #[default]
    Nova2,
    Nova2General,
    Nova2Meeting,
    Nova2PhoneCall,
    Nova2Finance,
    Nova2ConversationalAI,
    Nova2Voicemail,
    Nova2Video,
    Nova2Medical,
    Nova2DriveThru,
    Nova2Automotive,
    Enhanced,
    Base,
}

impl STTModel {
    /// Every variant, in declaration order.
    ///
    /// `parse` scans this list, so the name table lives only in `as_str` and the
    /// two directions cannot disagree. Add new variants here as well.
    const ALL: [STTModel; 13] = [
        STTModel::Nova2,
        STTModel::Nova2General,
        STTModel::Nova2Meeting,
        STTModel::Nova2PhoneCall,
        STTModel::Nova2Finance,
        STTModel::Nova2ConversationalAI,
        STTModel::Nova2Voicemail,
        STTModel::Nova2Video,
        STTModel::Nova2Medical,
        STTModel::Nova2DriveThru,
        STTModel::Nova2Automotive,
        STTModel::Enhanced,
        STTModel::Base,
    ];

    /// Returns the model name as sent to Deepgram (e.g. `"nova-2-meeting"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            STTModel::Nova2 => "nova-2",
            STTModel::Nova2General => "nova-2-general",
            STTModel::Nova2Meeting => "nova-2-meeting",
            STTModel::Nova2PhoneCall => "nova-2-phonecall",
            STTModel::Nova2Finance => "nova-2-finance",
            STTModel::Nova2ConversationalAI => "nova-2-conversationalai",
            STTModel::Nova2Voicemail => "nova-2-voicemail",
            STTModel::Nova2Video => "nova-2-video",
            STTModel::Nova2Medical => "nova-2-medical",
            STTModel::Nova2DriveThru => "nova-2-drivethru",
            STTModel::Nova2Automotive => "nova-2-automotive",
            STTModel::Enhanced => "enhanced",
            STTModel::Base => "base",
        }
    }

    /// Looks up a model by its exact, case-sensitive Deepgram name.
    ///
    /// Returns `None` for unknown names.
    pub fn parse(s: &str) -> Option<Self> {
        Self::ALL.into_iter().find(|model| model.as_str() == s)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a diarized word fixture with both raw and punctuated text.
    fn diarized_word(
        raw: &str,
        punctuated: &str,
        (start, end): (f32, f32),
        confidence: f32,
        speaker: u32,
    ) -> WordInfo {
        WordInfo {
            word: raw.to_string(),
            start,
            end,
            confidence,
            speaker: Some(speaker),
            punctuated_word: Some(punctuated.to_string()),
        }
    }

    #[test]
    fn test_stt_model_as_str() {
        let expectations = [
            (STTModel::Nova2, "nova-2"),
            (STTModel::Nova2Meeting, "nova-2-meeting"),
            (STTModel::Enhanced, "enhanced"),
        ];
        for (model, wire_name) in expectations {
            assert_eq!(model.as_str(), wire_name);
        }
    }

    #[test]
    fn test_stt_model_from_str() {
        let expectations = [
            ("nova-2", Some(STTModel::Nova2)),
            ("nova-2-meeting", Some(STTModel::Nova2Meeting)),
            ("invalid", None),
        ];
        for (name, parsed) in expectations {
            assert_eq!(STTModel::parse(name), parsed);
        }
    }

    #[test]
    fn test_build_query_params() {
        let params = build_query_params(&TranscriptionRequest {
            model: "nova-2".to_string(),
            language: Some("en".to_string()),
            punctuate: Some(true),
            diarize: Some(true),
            ..Default::default()
        });
        for fragment in ["model=nova-2", "language=en", "punctuate=true", "diarize=true"] {
            assert!(params.contains(fragment), "missing `{fragment}` in `{params}`");
        }
    }

    #[test]
    fn test_build_stt_url() {
        let request = TranscriptionRequest {
            model: "nova-2".to_string(),
            ..Default::default()
        };
        let url = build_stt_url("https://api.deepgram.com/v1", &request);
        assert!(url.starts_with("https://api.deepgram.com/v1/listen?"));
        assert!(url.contains("model=nova-2"));
    }

    #[test]
    fn test_detect_audio_mime_type() {
        let expectations = [
            ("audio.mp3", "audio/mpeg"),
            ("audio.wav", "audio/wav"),
            ("audio.m4a", "audio/mp4"),
            ("audio.webm", "audio/webm"),
        ];
        for (name, mime) in expectations {
            assert_eq!(detect_audio_mime_type(name), mime);
        }
    }

    #[test]
    fn test_reconstruct_diarized_transcript() {
        let words = [
            diarized_word("Hello", "Hello,", (0.0, 0.5), 0.95, 0),
            diarized_word("world", "world.", (0.5, 1.0), 0.90, 0),
            diarized_word("Hi", "Hi!", (1.5, 2.0), 0.92, 1),
        ];
        let transcript = reconstruct_diarized_transcript(&words);
        for needle in ["Speaker 0", "Speaker 1", "Hello,", "Hi!"] {
            assert!(transcript.contains(needle), "missing `{needle}` in `{transcript}`");
        }
    }

    #[test]
    fn test_supported_audio_formats() {
        let formats = supported_audio_formats();
        for ext in ["mp3", "wav", "flac", "webm"] {
            assert!(formats.contains(&ext), "missing format `{ext}`");
        }
    }

    #[test]
    fn test_transcription_request_default() {
        let TranscriptionRequest {
            model,
            language,
            diarize,
            ..
        } = TranscriptionRequest::default();
        assert_eq!(model, "nova-2");
        assert!(language.is_none());
        assert!(diarize.is_none());
    }
}