use std::{
fs, io,
path::{Path, PathBuf},
process,
};
use anyhow::{anyhow, bail, Context, Result};
#[tracing::instrument(skip_all)]
pub fn convert(
input_file: &Path,
output_file: Option<&Path>,
model_file: &Path,
normalize: bool,
) -> anyhow::Result<()> {
let input_file = if normalize {
tracing::info!(?input_file, "Normalizing.");
file_normalize(input_file)
.context(format!("Failed to normalize file: {input_file:?}"))?
} else {
input_file.to_owned()
};
let audio_data = read_wav(&input_file).context(format!(
"Failed to read audio from WAV file: {input_file:?}"
))?;
let text_segments = segments(&audio_data, model_file)
.context("Failed to determine text segments.")?;
let mut text_buf: Box<dyn io::Write> = match output_file {
None => {
tracing::info!("Output to stdout.");
Box::new(io::stdout().lock())
}
Some(ref path) => {
tracing::info!(?path, "Output to file.");
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).context(format!(
"Failed to create directory path: {path:?}"
))?;
}
let buf = fs::File::create(path).context(format!(
"Failed to open file for writing: {path:?}"
))?;
Box::new(buf)
}
};
for text in text_segments {
writeln!(text_buf, "{}", &text)
.context("Failed to output a text segment.")?;
}
Ok(())
}
fn exec(cmd: &str, args: &[&str]) -> Result<Vec<u8>> {
tracing::debug!(?cmd, ?args, "Executing command.");
let out = process::Command::new(cmd).args(args).output()?;
if out.status.success() {
Ok(out.stdout)
} else {
let stderr = String::from_utf8(out.stderr.clone())?;
tracing::error!(?cmd, ?args, ?stderr, "Command failed to execute");
Err(anyhow!("Failure in '{} {:?}'. out: {:?}", cmd, args, out))
}
}
/// Re-encodes `in_path` with `ffmpeg` into a WAV file that `read_wav` accepts.
///
/// The output is 16 kHz, single-channel, 16-bit little-endian PCM. It is
/// written as `normalized.wav` inside a freshly created temporary directory.
/// On success that directory is persisted and its file path is returned; the
/// caller becomes responsible for deleting it. On failure the temporary
/// directory is removed.
///
/// # Errors
/// Fails if either path is not valid UTF-8, if the temporary directory cannot
/// be created, or if `ffmpeg` fails to run or exits unsuccessfully.
#[tracing::instrument]
fn file_normalize(in_path: &Path) -> Result<PathBuf> {
    let in_str = in_path.to_str().ok_or_else(|| {
        anyhow!("Failed to convert input path to string: {:?}", &in_path)
    })?;
    // Keep the `TempDir` guard alive until ffmpeg succeeds, so any early
    // return drops it and deletes the directory instead of leaking it.
    let tmp_dir = tempfile::tempdir().context("Failed to create temporary directory")?;
    let out_path = tmp_dir.path().join("normalized.wav");
    let out_str = out_path.to_str().ok_or_else(|| {
        anyhow!("Failed to convert output path to string: {:?}", &out_path)
    })?;
    exec(
        "ffmpeg",
        &[
            "-y", "-i", in_str, "-ar:a", "16000", "-ac:a", "1", "-codec:a", "pcm_s16le", "-f",
            "wav", out_str,
        ],
    )?;
    // Success: persist the directory so the caller can read the file.
    let _ = tmp_dir.into_path();
    Ok(out_path)
}
#[tracing::instrument]
fn read_wav(path: &Path) -> Result<Vec<f32>> {
let mut wav_reader = hound::WavReader::open(path)?;
let spec = wav_reader.spec();
tracing::debug!(?spec, "Constructed a hound wav reader.");
let hound::WavSpec {
channels: ch,
sample_rate: rate,
bits_per_sample: bits,
sample_format: fmt,
..
} = spec;
match rate {
16000 => (),
n => bail!(
"Unsupported sample rate: {} Hz. Only 16000 Hz is supported.",
n
),
}
let convert_stereo2mono = match ch {
1 => false,
2 => true,
n => bail!("Unsupported number of channels: {}", n),
};
let convert_int2float = match (bits, fmt) {
(16, hound::SampleFormat::Int) => true,
(32, hound::SampleFormat::Float) => false,
(bits, fmt) => bail!(
"Unsupported combination of \
bits ({}) and \
format ({:?}) \
in file: {:?}",
bits,
fmt,
path
),
};
let mut samples: Vec<i32> = Vec::new();
for sample_result in wav_reader.samples() {
let sample = sample_result?;
samples.push(sample);
}
let mut samples: Vec<f32> = if convert_int2float {
let dat: Vec<i16> = samples.iter().map(|i| *i as i16).collect();
whisper_rs::convert_integer_to_float_audio(&dat)
} else {
samples.iter().map(|i| *i as f32).collect()
};
if convert_stereo2mono {
samples = whisper_rs::convert_stereo_to_mono_audio(&samples)
.map_err(|e| {
anyhow!("failed to convert stereo to mono: {:?}", e)
})?;
}
Ok(samples)
}
/// Runs the whisper model stored at `model` over `data` and returns the text
/// of every recognized segment, in segment order.
///
/// A new context and state are created on each call. Decoding uses the greedy
/// sampling strategy with `best_of: 1`.
///
/// # Errors
/// Fails if any of the following fails:
/// * converting the model path to UTF-8;
/// * loading the model;
/// * creating the inference state;
/// * running the model;
/// * querying the segment count or any segment's text.
#[tracing::instrument(skip_all)]
fn segments(data: &[f32], model: &Path) -> Result<Vec<String>> {
    let model_path = model
        .to_str()
        .ok_or_else(|| anyhow!("Failed to convert model path to &str: {:?}", model))?;
    let whisper = whisper_rs::WhisperContext::new(model_path).context("failed to load model")?;
    let strategy = whisper_rs::SamplingStrategy::Greedy { best_of: 1 };
    let full_params = whisper_rs::FullParams::new(strategy);
    let mut whisper_state = whisper.create_state().context("failed to create state")?;
    whisper_state
        .full(full_params, data)
        .context("failed to run model")?;
    let segment_count = whisper_state
        .full_n_segments()
        .context("failed to get number of segments")?;
    (0..segment_count)
        .map(|idx| {
            whisper_state
                .full_get_segment_text(idx)
                .context("failed to get segment")
        })
        .collect()
}