use anyhow::{Context, Result};
use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::audio::AudioChunk;
use crate::config::TranscriptionProvider;
use crate::provider::{DEFAULT_TIMEOUT_SECS, TranscriptionRequest, registry};
/// Number of permits in the semaphore that gates provider calls in
/// `parallel_transcribe`, i.e. the most chunk requests in flight at once.
const MAX_CONCURRENT_REQUESTS: usize = 3;
/// Upper bound on how many trailing/leading words `remove_overlap` compares
/// when looking for text repeated across a chunk boundary.
const MAX_OVERLAP_WORDS: usize = 15;
/// Transcript of a single audio chunk, tagged with what is needed to merge it
/// back into the full transcript.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkTranscription {
    /// Position of the chunk in the original audio; results are sorted by
    /// this value before merging.
    pub index: usize,
    /// Text returned by the provider for this chunk.
    pub text: String,
    /// When `true`, the merge step strips words at the start of `text` that
    /// repeat the end of the transcript accumulated so far (presumably because
    /// the chunk's audio overlaps the previous chunk — confirm against the
    /// chunking code).
    pub has_leading_overlap: bool,
}
/// Transcribes `audio_data` without specifying a MIME type or filename.
///
/// Shorthand for [`transcribe_audio_with_format`] with both format hints set
/// to `None`, so that function's defaults apply.
///
/// # Errors
///
/// Returns whatever error [`transcribe_audio_with_format`] returns.
pub fn transcribe_audio(
    provider: &TranscriptionProvider,
    api_key: &str,
    language: Option<&str>,
    audio_data: Vec<u8>,
) -> Result<String> {
    // No explicit format hints: the callee falls back to its own defaults.
    let (mime_type, filename) = (None, None);
    transcribe_audio_with_format(provider, api_key, language, audio_data, mime_type, filename)
}
/// Transcribes one in-memory audio payload through the selected provider's
/// `transcribe_sync` method and returns the transcript text.
///
/// * `provider` – which provider to look up in the registry.
/// * `api_key` – credential forwarded to the provider.
/// * `language` – optional language hint, forwarded as-is.
/// * `audio_data` – raw audio bytes; ownership moves into the request.
/// * `mime_type` – defaults to `"audio/mpeg"` when `None`.
/// * `filename` – defaults to `"audio.mp3"` when `None`.
///
/// # Errors
///
/// Fails if the registry lookup for `provider` fails or if the provider's
/// `transcribe_sync` call returns an error.
pub fn transcribe_audio_with_format(
    provider: &TranscriptionProvider,
    api_key: &str,
    language: Option<&str>,
    audio_data: Vec<u8>,
    mime_type: Option<&str>,
    filename: Option<&str>,
) -> Result<String> {
    let backend = registry().get_by_kind(provider)?;

    // Resolve the optional hints to owned strings up front so the request
    // literal below stays compact.
    let language = language.map(str::to_owned);
    let filename = filename.unwrap_or("audio.mp3").to_owned();
    let mime_type = mime_type.unwrap_or("audio/mpeg").to_owned();

    let response = backend.transcribe_sync(
        api_key,
        TranscriptionRequest {
            audio_data,
            language,
            filename,
            mime_type,
        },
    )?;
    Ok(response.text)
}
pub async fn parallel_transcribe(
provider: &TranscriptionProvider,
api_key: &str,
language: Option<&str>,
chunks: Vec<AudioChunk>,
progress_callback: Option<Box<dyn Fn(usize, usize) + Send + Sync>>,
) -> Result<String> {
let total_chunks = chunks.len();
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.context("Failed to create HTTP client")?;
let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_REQUESTS));
let client = Arc::new(client);
let api_key = Arc::new(api_key.to_string());
let language = language.map(|s| Arc::new(s.to_string()));
let provider_impl = registry().get_by_kind(provider)?;
let completed = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let progress_callback = progress_callback.map(Arc::new);
let mut handles = Vec::with_capacity(total_chunks);
for chunk in chunks {
let semaphore = semaphore.clone();
let client = client.clone();
let api_key = api_key.clone();
let language = language.clone();
let provider_impl = provider_impl.clone();
let completed = completed.clone();
let progress_callback = progress_callback.clone();
let handle = tokio::spawn(async move {
let _permit = semaphore.acquire_owned().await?;
let chunk_index = chunk.index;
let has_leading_overlap = chunk.has_leading_overlap;
let request = TranscriptionRequest {
audio_data: chunk.data,
language: language.as_ref().map(|s| s.to_string()),
filename: format!("audio_chunk_{chunk_index}.mp3"),
mime_type: "audio/mpeg".to_string(),
};
let result = provider_impl
.transcribe_async(&client, &api_key, request)
.await;
let transcription = match result {
Ok(r) => ChunkTranscription {
index: chunk_index,
text: r.text,
has_leading_overlap,
},
Err(e) => return Err(e),
};
let done = completed.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
if let Some(ref cb) = progress_callback {
cb(done, total_chunks);
}
Ok(transcription)
});
handles.push(handle);
}
let mut results = Vec::with_capacity(total_chunks);
let mut errors = Vec::new();
for handle in handles {
match handle.await {
Ok(Ok(transcription)) => results.push(transcription),
Ok(Err(e)) => errors.push(e),
Err(e) => errors.push(anyhow::anyhow!("Task panicked: {e}")),
}
}
if !errors.is_empty() {
let error_msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
anyhow::bail!(
"Failed to transcribe {} of {} chunks:\n{}",
errors.len(),
total_chunks,
error_msgs.join("\n")
);
}
results.sort_by_key(|r| r.index);
Ok(merge_transcriptions(results))
}
/// Joins per-chunk transcripts, given in transcript order, into one string.
///
/// * No chunks: returns an empty string.
/// * Exactly one chunk: returns its text verbatim (not trimmed).
/// * Otherwise each chunk's text is trimmed and appended, separated from the
///   preceding text by a single space. For every chunk after the first whose
///   `has_leading_overlap` is set, words repeating the end of the transcript
///   so far are removed first via `remove_overlap`.
///
/// Chunks that contribute no text (empty, whitespace-only, or entirely
/// overlapped) add nothing — in particular they never produce a leading or
/// doubled separator.
fn merge_transcriptions(transcriptions: Vec<ChunkTranscription>) -> String {
    if transcriptions.len() == 1 {
        return transcriptions
            .into_iter()
            .next()
            .map(|only| only.text)
            .unwrap_or_default();
    }

    let mut merged = String::new();
    for (position, transcription) in transcriptions.into_iter().enumerate() {
        let text = transcription.text.trim();
        let piece = if position > 0 && transcription.has_leading_overlap {
            remove_overlap(&merged, text)
        } else {
            text.to_string()
        };

        if piece.is_empty() {
            continue;
        }
        // Separate only when there is something to separate from: if every
        // earlier chunk was empty, pushing a space here would make the
        // transcript start with a stray blank. `merged` and `piece` never
        // carry whitespace at their edges, so no further checks are needed.
        if !merged.is_empty() {
            merged.push(' ');
        }
        merged.push_str(&piece);
    }
    merged
}
fn remove_overlap(existing: &str, new_text: &str) -> String {
let existing_words: Vec<&str> = existing.split_whitespace().collect();
let new_words: Vec<&str> = new_text.split_whitespace().collect();
if existing_words.is_empty() || new_words.is_empty() {
return new_text.to_string();
}
let search_end = existing_words.len().min(MAX_OVERLAP_WORDS);
let search_new = new_words.len().min(MAX_OVERLAP_WORDS);
let mut best_overlap = 0;
for overlap_len in 1..=search_end.min(search_new) {
let end_slice = &existing_words[existing_words.len() - overlap_len..];
let start_slice = &new_words[..overlap_len];
let matches = end_slice
.iter()
.zip(start_slice.iter())
.all(|(a, b)| a.eq_ignore_ascii_case(b));
if matches {
best_overlap = overlap_len;
}
}
if best_overlap > 0 {
new_words[best_overlap..].join(" ")
} else {
new_text.to_string()
}
}