use super::*;
use std::collections::{HashMap, HashSet};
use std::io::IsTerminal;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::thread::{self, JoinHandle};
/// Entry point for the transcribe command.
///
/// Expands and validates the inputs, builds one pipeline configuration per
/// input, runs them through `run_many`, and reports the results. Terminal
/// feedback goes to stderr; when stdout is not a terminal the report(s) are
/// printed there as pretty JSON (a single object for one report, an array
/// otherwise).
///
/// # Errors
///
/// Returns the first error raised by input expansion, argument validation,
/// `run_many`, or JSON serialization.
pub(crate) fn transcribe_command(mut args: TranscribeArgs) -> anyhow::Result<()> {
    args.input = expand_transcribe_inputs(&args.input)?;
    validate_transcribe_args(&args)?;
    validate_explicit_output_dir_collisions(&args)?;

    let feedback = TerminalFeedback::new();
    feedback.inputs(&args);

    let mut configs = Vec::with_capacity(args.input.len());
    for input in &args.input {
        configs.push(transcribe_config(&args, input.clone()));
    }

    let progress = feedback.start_progress();
    let outcome = run_many(configs);
    // Settle the progress line before anything else is printed or returned.
    if let Some(progress) = progress {
        if outcome.is_ok() {
            progress.finish_success();
        } else {
            progress.finish_error();
        }
    }
    let reports = outcome?;

    feedback.outputs(&reports);
    if should_print_json_report() {
        let rendered = if reports.len() == 1 {
            serde_json::to_string_pretty(&reports[0])?
        } else {
            serde_json::to_string_pretty(&reports)?
        };
        println!("{}", rendered);
    }
    Ok(())
}
fn validate_transcribe_args(args: &TranscribeArgs) -> anyhow::Result<()> {
validate_speaker_directory_args(&args.speaker_directory)?;
let subtitle_layout_requested =
args.highlight_words || args.max_line_width.is_some() || args.max_line_count.is_some();
if args.no_align && subtitle_layout_requested {
anyhow::bail!(
"--highlight_words, --max_line_width, and --max_line_count require alignment; remove --no_align"
);
}
if args.task == CliTask::Translate
&& args.provider == CliProvider::Native
&& args.translation_model.is_none()
&& args.translation_bundle.is_none()
{
anyhow::bail!(
"native --task translate requires --translation-model or --translation-bundle; use --provider external-whisperx for WhisperX built-in translation"
);
}
let native_pyannote_model = args.provider == CliProvider::Native
&& args
.diarize_model
.as_deref()
.is_some_and(is_pyannote_diarization_model);
if args.speaker_embeddings
&& args.provider == CliProvider::Native
&& !(native_pyannote_model && args.diarization_model_bundle.is_some())
{
anyhow::bail!(
"native speaker embeddings require --diarize-model pyannote/... and --diarization-model-bundle"
);
}
if native_pyannote_model && args.diarization_model_bundle.is_none() {
anyhow::bail!("native pyannote diarization requires --diarization-model-bundle");
}
if args.provider == CliProvider::Native
&& args.diarization_model_bundle.is_some()
&& !native_pyannote_model
{
anyhow::bail!("native --diarization-model-bundle requires --diarize-model pyannote/...");
}
if args.basename.is_some() && args.input.len() > 1 {
anyhow::bail!("--basename cannot be used with multiple input files");
}
Ok(())
}
/// Resolves the raw input arguments into the list of inputs to process.
///
/// An argument that exists on disk, or that contains no glob metacharacter,
/// is taken verbatim. Any other argument is expanded as a glob pattern; its
/// matches are sorted and must all be regular files. Duplicates are dropped
/// via `push_unique_input`, keeping the first occurrence.
///
/// # Errors
///
/// Fails when a pattern is invalid, a match cannot be read, a pattern matches
/// nothing, or a pattern matches something that is not a file.
fn expand_transcribe_inputs(inputs: &[PathBuf]) -> anyhow::Result<Vec<PathBuf>> {
    let mut expanded = Vec::new();
    let mut seen = HashSet::new();

    for input in inputs {
        // Existing paths win over pattern interpretation, so a real file whose
        // name contains `*`, `?` or `[` is still accepted as-is.
        let treat_as_pattern = !input.exists() && is_glob_pattern(input);
        if !treat_as_pattern {
            push_unique_input(&mut expanded, &mut seen, input.clone())?;
            continue;
        }

        let pattern = input.to_string_lossy();
        let walker = glob::glob(&pattern)
            .with_context(|| format!("invalid input pattern `{pattern}`"))?;

        let mut matches = Vec::new();
        for entry in walker {
            let path = entry
                .with_context(|| format!("failed to read input pattern `{pattern}` match"))?;
            matches.push(path);
        }
        if matches.is_empty() {
            anyhow::bail!("input pattern `{pattern}` matched no input files");
        }
        matches.sort();

        for matched in matches {
            if !matched.is_file() {
                anyhow::bail!(
                    "input pattern `{pattern}` matched non-file input `{}`",
                    matched.display()
                );
            }
            push_unique_input(&mut expanded, &mut seen, matched)?;
        }
    }

    Ok(expanded)
}
/// Returns `true` when the (lossily decoded) path contains any of the glob
/// metacharacters `*`, `?` or `[`.
fn is_glob_pattern(path: &Path) -> bool {
    const GLOB_METACHARACTERS: [char; 3] = ['*', '?', '['];
    let text = path.to_string_lossy();
    text.contains(&GLOB_METACHARACTERS[..])
}
/// Appends `input` to `expanded` unless an equivalent path was already added.
///
/// Equivalence is decided on a key derived from the path: its canonical form
/// when that can be resolved, otherwise the result of `absolute_from_cwd`,
/// otherwise the path exactly as given. The path pushed to `expanded` is the
/// original `input`, not the key. Always returns `Ok(())`.
fn push_unique_input(
    expanded: &mut Vec<PathBuf>,
    seen: &mut HashSet<PathBuf>,
    input: PathBuf,
) -> anyhow::Result<()> {
    let dedupe_key = match input.canonicalize() {
        Ok(resolved) => resolved,
        Err(_) => absolute_from_cwd(input.clone()).unwrap_or_else(|_| input.clone()),
    };

    if !seen.insert(dedupe_key) {
        return Ok(());
    }
    expanded.push(input);
    Ok(())
}
/// Ensures that no two inputs share a basename when `--output-dir` is given.
///
/// The basename is the input's UTF-8 file stem, or `transcript` when the stem
/// is missing, empty, or not valid UTF-8. Nothing is checked when no explicit
/// output directory was requested.
///
/// # Errors
///
/// Fails on the first pair of inputs that map to the same basename.
fn validate_explicit_output_dir_collisions(args: &TranscribeArgs) -> anyhow::Result<()> {
    use std::collections::hash_map::Entry;

    if args.output_dir.is_none() {
        return Ok(());
    }

    let mut claimed: HashMap<String, PathBuf> = HashMap::new();
    for input in &args.input {
        let stem = input
            .file_stem()
            .and_then(|stem| stem.to_str())
            .unwrap_or("");
        let basename = if stem.is_empty() { "transcript" } else { stem };

        match claimed.entry(basename.to_string()) {
            Entry::Vacant(slot) => {
                slot.insert(input.clone());
            }
            Entry::Occupied(slot) => {
                let previous = slot.get();
                anyhow::bail!(
                    "output basename collision `{basename}` for inputs `{}` and `{}`; choose distinct input filenames or omit --output-dir to write beside each input",
                    previous.display(),
                    input.display()
                );
            }
        }
    }
    Ok(())
}
/// Translates the CLI arguments into the pipeline configuration for one input.
///
/// Derived settings:
/// - the output directory comes from `transcribe_output_dir`;
/// - the external WhisperX output directory is the explicit `--output-dir`
///   when given, a freshly generated temp path otherwise, and `None` for the
///   native provider;
/// - diarization is enabled by `--diarize` or by any flag that only makes
///   sense with diarization (speaker embeddings, diarization/embedding
///   bundles, min/max speakers);
/// - alignment is disabled by `--no_align`, and also for a native translate
///   task that has neither a translation model nor a translation bundle.
///
/// # Panics
///
/// Panics if the speaker-directory arguments cannot be converted; callers are
/// expected to have run `validate_transcribe_args` first.
fn transcribe_config(args: &TranscribeArgs, input: PathBuf) -> NativeWhisperxConfig {
    let output_dir = transcribe_output_dir(args, &input);
    let uses_external_whisperx = matches!(args.provider, CliProvider::ExternalWhisperx);

    let asr_provider = match args.provider {
        CliProvider::Native => AsrProvider::Native,
        CliProvider::ExternalWhisperx => AsrProvider::ExternalWhisperX,
    };

    let external_output_dir = if !uses_external_whisperx {
        None
    } else if args.output_dir.is_some() {
        output_dir.clone()
    } else {
        Some(unique_external_whisperx_output_dir())
    };

    // Any diarization-specific flag implies diarization itself.
    let diarization_enabled = [
        args.diarize,
        args.speaker_embeddings,
        args.diarization_model_bundle.is_some(),
        args.speaker_embedding_bundle.is_some(),
        args.min_speakers.is_some(),
        args.max_speakers.is_some(),
    ]
    .contains(&true);

    let diarization_model_id = match &args.diarize_model {
        Some(model_id) => model_id.clone(),
        None => match args.provider {
            CliProvider::Native => DiarizationConfig::default().model_id,
            CliProvider::ExternalWhisperx => {
                "pyannote/speaker-diarization-community-1".to_string()
            }
        },
    };

    let native_translate_without_model = args.task == CliTask::Translate
        && args.provider == CliProvider::Native
        && args.translation_model.is_none()
        && args.translation_bundle.is_none();
    let skip_alignment = args.no_align || native_translate_without_model;

    let asr = AsrConfig {
        provider: asr_provider,
        task: args.task.into(),
        model_id: args.model.clone(),
        language: args.language.clone(),
        whisper_bundle: args.whisper_bundle.clone(),
        model_dir: args.model_dir.clone(),
        model_cache_only: args.model_cache_only,
        device: args.device.into(),
        device_index: args.device_index.clone(),
        compute_type: args.compute_type.clone(),
        batch_chunks: true,
        max_batch_size: args.batch_size.or_else(default_native_batch_size),
        decode: decode_config(args),
        external_whisperx: ExternalWhisperxConfig {
            model: args.model.clone(),
            output_dir: external_output_dir,
            extra_args: logging_extra_args(args),
            ..ExternalWhisperxConfig::default()
        },
    };

    let translation = translation_config(
        args.translation_model.clone(),
        args.translation_bundle.clone(),
        args.model_dir.clone(),
        args.model_cache_only,
        args.translation_source_language.clone(),
        args.translation_target_language.clone(),
        args.translation_max_new_tokens,
    );

    let vad = VadConfig {
        method: args.vad_method.into(),
        onset: args.vad_onset,
        offset: args.vad_offset,
        chunk_size: args.chunk_size,
        model_bundle: args.vad_model_bundle.clone(),
        model_file: args.vad_model_file.clone(),
        input_name: args.vad_input_name.clone(),
        output_name: args.vad_output_name.clone(),
        ..VadConfig::default()
    };

    let alignment = alignment_config(
        skip_alignment,
        args.alignment_model.clone(),
        args.alignment_bundle.clone(),
        args.model_dir.clone(),
        args.model_cache_only,
        args.interpolate_method,
        args.return_char_alignments,
    );

    let diarization = DiarizationConfig {
        enabled: diarization_enabled,
        model_id: diarization_model_id,
        hf_token: args.hf_token.clone(),
        return_speaker_embeddings: args.speaker_embeddings,
        model_bundle: args.diarization_model_bundle.clone(),
        manifest_file: args.diarization_manifest_file.clone(),
        segmentation_model_file: args.diarization_segmentation_model_file.clone(),
        embedding_model_file: args.diarization_embedding_model_file.clone(),
        plda_transform_file: args.diarization_plda_transform_file.clone(),
        plda_model_file: args.diarization_plda_model_file.clone(),
        clustering_config_file: args.diarization_clustering_config_file.clone(),
        speaker_embedding_model_bundle: args.speaker_embedding_bundle.clone(),
        speaker_embedding_model_file: args.speaker_embedding_model_file.clone(),
        speaker_embedding_dimension: args.speaker_embedding_dim,
        speaker_embedding_sample_rate: args.speaker_embedding_sample_rate,
        min_speakers: args.min_speakers,
        max_speakers: args.max_speakers,
        assignment_policy: args.speaker_assignment_policy.into(),
        speaker_directory: args
            .speaker_directory
            .clone()
            .try_into()
            .expect("transcribe args were validated"),
        disable_speaker_library: args.no_speaker_library || args.no_speaker_store,
        save_draft_speakers: !args.no_save_draft_speakers,
        use_draft_speakers: !args.no_use_draft_speakers,
        ..DiarizationConfig::default()
    };

    let output = OutputConfig {
        output_dir,
        formats: args.formats.iter().copied().map(Into::into).collect(),
        basename: args.basename.clone(),
        pretty_json: true,
        subtitles: SubtitleConfig {
            max_line_width: args.max_line_width,
            max_line_count: args.max_line_count,
            highlight_words: args.highlight_words,
            segment_resolution: args.segment_resolution.into(),
        },
    };

    NativeWhisperxConfig {
        input: InputSource::Path { path: input },
        asr,
        translation,
        vad,
        alignment,
        diarization,
        output,
    }
}
/// Picks the directory that outputs for `input` are written to.
///
/// An explicit `--output-dir` always wins. Otherwise the input's parent
/// directory is used, falling back to `.` when the input has no parent or
/// the parent is the empty path. The result is therefore always `Some`.
fn transcribe_output_dir(args: &TranscribeArgs, input: &Path) -> Option<PathBuf> {
    if let Some(explicit) = &args.output_dir {
        return Some(explicit.clone());
    }
    match input.parent() {
        Some(parent) if !parent.as_os_str().is_empty() => Some(parent.to_path_buf()),
        _ => Some(PathBuf::from(".")),
    }
}
/// Builds a scratch directory path for one external WhisperX run.
///
/// The path lives under the system temp directory and is named
/// `native-whisperx-external-<pid>-<millis>-<sequence>`. The directory itself
/// is not created here.
///
/// The per-process sequence number is what makes consecutive calls distinct:
/// configurations for several inputs are built back to back, well within a
/// single millisecond, so the process id and timestamp alone would hand every
/// input the same directory.
fn unique_external_whisperx_output_dir() -> PathBuf {
    static NEXT_SEQUENCE: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    // Only uniqueness of the counter value matters, so relaxed ordering suffices.
    let sequence = NEXT_SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    // A clock set before the Unix epoch degrades to 0 instead of failing.
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    std::env::temp_dir().join(format!(
        "native-whisperx-external-{}-{millis}-{sequence}",
        std::process::id()
    ))
}
/// Returns `true` when `model_id`, ignoring surrounding whitespace and ASCII
/// case, starts with `pyannote/`.
fn is_pyannote_diarization_model(model_id: &str) -> bool {
    const PREFIX: &str = "pyannote/";
    let trimmed = model_id.trim();
    // `get` yields `None` when the id is too short or the cut would split a
    // multi-byte character; neither case can match an all-ASCII prefix.
    match trimmed.get(..PREFIX.len()) {
        Some(head) => head.eq_ignore_ascii_case(PREFIX),
        None => false,
    }
}
fn decode_config(args: &TranscribeArgs) -> WhisperxDecodeConfig {
WhisperxDecodeConfig {
temperature: args.temperature.clone(),
best_of: args.best_of,
beam_size: args.beam_size,
patience: args.patience,
length_penalty: args.length_penalty,
suppress_tokens: args.suppress_tokens.clone(),
suppress_numerals: args.suppress_numerals,
initial_prompt: args.initial_prompt.clone(),
hotwords: args.hotwords.clone(),
condition_on_previous_text: args.condition_on_previous_text,
fp16: args.fp16,
compression_ratio_threshold: args.compression_ratio_threshold,
logprob_threshold: args.logprob_threshold,
no_speech_threshold: args.no_speech_threshold,
threads: args.threads,
}
}
/// Collects the logging-related options as extra command-line arguments, in
/// the order `--verbose <value>`, `--log-level <value>`, `--print_progress`.
/// Options that were not set are omitted.
fn logging_extra_args(args: &TranscribeArgs) -> Vec<String> {
    let mut forwarded: Vec<String> = Vec::new();

    if let Some(verbose) = args.verbose.as_ref() {
        forwarded.push(String::from("--verbose"));
        forwarded.push(verbose.clone());
    }
    if let Some(log_level) = args.log_level.as_ref() {
        forwarded.push(String::from("--log-level"));
        forwarded.push(log_level.clone());
    }
    if args.print_progress {
        forwarded.push(String::from("--print_progress"));
    }

    forwarded
}
/// Returns the `max_batch_size` of a default-constructed `AsrConfig`; used as
/// the fallback when no batch size is given on the command line.
fn default_native_batch_size() -> Option<usize> {
    let defaults = AsrConfig::default();
    defaults.max_batch_size
}
/// Human-oriented status output on stderr.
///
/// Every method is a no-op unless stderr was a terminal when the value was
/// created.
struct TerminalFeedback {
    enabled: bool,
}

impl TerminalFeedback {
    /// Creates the reporter, enabling it only when stderr is a terminal.
    fn new() -> Self {
        let enabled = std::io::stderr().is_terminal();
        Self { enabled }
    }

    /// Prints a one-line summary of the run followed by one line per queued input.
    fn inputs(&self, args: &TranscribeArgs) {
        if !self.enabled {
            return;
        }

        let provider = match args.provider {
            CliProvider::ExternalWhisperx => "external-whisperx",
            CliProvider::Native => "native",
        };
        let batch_size_label = format_batch_size(args.batch_size.or_else(default_native_batch_size));
        eprintln!(
            "native-whisperx: transcribing {} input(s) with {provider} provider, model {}, max batch size {}",
            args.input.len(),
            args.model,
            batch_size_label
        );

        args.input
            .iter()
            .for_each(|input| eprintln!("native-whisperx: queued {}", input.display()));
    }

    /// Starts the animated progress line, or returns `None` when disabled.
    fn start_progress(&self) -> Option<TerminalProgress> {
        self.enabled.then(TerminalProgress::start)
    }

    /// For each report, echoes the CUDA out-of-memory retry and alignment
    /// fallback diagnostics and then lists the files that were written.
    fn outputs(&self, reports: &[NativeWhisperxReport]) {
        // Only diagnostics with these prefixes are surfaced to the user.
        const SURFACED_DIAGNOSTIC_PREFIXES: [&str; 5] = [
            "cudaOomRetryCount=",
            "cudaOomRetriedBatchSizes=",
            "cudaOomFinalBatchSize=",
            "alignmentFallbackCount=",
            "alignmentFallbackReason=",
        ];

        if !self.enabled {
            return;
        }

        for report in reports {
            for diagnostic in &report.response.diagnostics {
                let surfaced = SURFACED_DIAGNOSTIC_PREFIXES
                    .iter()
                    .any(|prefix| diagnostic.starts_with(*prefix));
                if surfaced {
                    eprintln!("native-whisperx: {diagnostic}");
                }
            }
            for output in &report.output_files {
                eprintln!("native-whisperx: wrote {}", output.path.display());
            }
        }
    }
}
/// Handle to the background thread that animates the stderr progress line.
///
/// `done` is the stop flag shared with the worker; `handle` is taken when the
/// worker is joined.
struct TerminalProgress {
    done: Arc<AtomicBool>,
    handle: Option<JoinHandle<()>>,
}

impl TerminalProgress {
    /// Spawns the worker, which redraws the progress line roughly every
    /// 250 ms until it is told to stop.
    fn start() -> Self {
        let done = Arc::new(AtomicBool::new(false));
        let thread_done = Arc::clone(&done);
        let handle = thread::spawn(move || {
            let started = Instant::now();
            while !thread_done.load(Ordering::Relaxed) {
                let progress = estimated_progress_percent(started.elapsed());
                eprint!(
                    "\rnative-whisperx: [{}] {:>3}% processing media",
                    render_progress_bar(progress),
                    progress
                );
                let _ = std::io::stderr().flush();
                // Parking (rather than sleeping) lets `stop` wake the worker
                // immediately. A spurious wake-up only causes an early redraw.
                thread::park_timeout(Duration::from_millis(250));
            }
        });
        Self {
            done,
            handle: Some(handle),
        }
    }

    /// Signals the worker to stop and waits for it to exit, so that nothing
    /// else is drawn over whatever the caller prints next.
    fn stop(&mut self) {
        self.done.store(true, Ordering::Relaxed);
        if let Some(handle) = self.handle.take() {
            // Wake the worker now instead of waiting out its current interval.
            handle.thread().unpark();
            let _ = handle.join();
        }
    }

    /// Stops the animation and replaces the progress line with a completed bar.
    fn finish_success(mut self) {
        self.stop();
        // The status is padded to the width of "processing media" (16
        // characters) so the carriage-return overwrite leaves no residue of
        // the previous frame on the line.
        eprintln!(
            "\rnative-whisperx: [{}] 100% {:<16}",
            render_progress_bar(100),
            "done"
        );
    }

    /// Stops the animation and ends the progress line, leaving the last frame
    /// visible above the error output.
    fn finish_error(mut self) {
        self.stop();
        eprintln!();
    }
}
/// Maps elapsed time to a display percentage.
///
/// The value follows `t / (t + 30 s)` scaled to 95, so it starts at 0,
/// reaches about half of 95 after 30 seconds, and never exceeds 95.
fn estimated_progress_percent(elapsed: Duration) -> u8 {
    let seconds = elapsed.as_secs_f64();
    let fraction = seconds / (seconds + 30.0);
    let percent = (fraction * 95.0).round();
    percent as u8
}
/// Renders a 28-character bar: `#` for the completed share, `-` for the rest.
/// Values above 100 produce a full bar.
fn render_progress_bar(progress: u8) -> String {
    const WIDTH: usize = 28;
    let filled = (usize::from(progress) * WIDTH / 100).min(WIDTH);

    let mut bar = String::with_capacity(WIDTH);
    bar.extend(std::iter::repeat('#').take(filled));
    bar.extend(std::iter::repeat('-').take(WIDTH - filled));
    bar
}
/// Returns `true` when stdout is not a terminal (for example when piped or
/// redirected), which is when the JSON report is printed.
fn should_print_json_report() -> bool {
    let stdout_is_terminal = std::io::stdout().is_terminal();
    !stdout_is_terminal
}
/// Formats a batch-size limit for display: the number itself, or `unbounded`
/// when there is no limit.
fn format_batch_size(batch_size: Option<usize>) -> String {
    match batch_size {
        Some(limit) => limit.to_string(),
        None => String::from("unbounded"),
    }
}
pub(crate) fn import_whisperx_command(args: ImportWhisperxArgs) -> anyhow::Result<()> {
let bytes = fs::read(&args.whisperx_json)
.with_context(|| format!("failed to read {}", args.whisperx_json.display()))?;
let transcript = import_whisperx_json(&bytes)?;
let json = serde_json::to_string_pretty(&transcript)?;
if let Some(output) = args.output {
fs::write(&output, json)
.with_context(|| format!("failed to write {}", output.display()))?;
} else {
println!("{json}");
}
Ok(())
}