use std::collections::VecDeque;
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use ndarray::{s, Array2, Array3};
use ort::{inputs, value::Tensor};
use realfft::num_complex::Complex;
use realfft::{RealFftPlanner, RealToComplex};
use crate::onnx;
use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState};
/// Sample rate, in Hz, that an incoming audio frame must report to be accepted.
const SAMPLE_RATE: u32 = 16_000;
/// Length, in samples, of each analysis window and of the real FFT applied to it.
const N_FFT: usize = 400;
/// Number of samples between the starts of two consecutive analysis frames.
const HOP_LENGTH: usize = 160;
/// Number of mel bands (rows) in the extracted feature matrix.
const N_MELS: usize = 80;
/// Number of time frames (columns) handed to the model.
const N_FRAMES: usize = 800;
// N_FREQS: number of bins in the one-sided spectrum of an N_FFT-point real FFT.
// RING_CAPACITY: maximum number of buffered samples, i.e. 8 seconds at SAMPLE_RATE.
const N_FREQS: usize = N_FFT / 2 + 1; const RING_CAPACITY: usize = 8 * SAMPLE_RATE as usize;
/// ONNX model bytes embedded at compile time from the build's `OUT_DIR`.
const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx"));
/// Computes normalised log-mel features from a fixed-length audio window,
/// keeping the previous call's spectrograms so that later calls can reuse them.
struct MelExtractor {
/// Triangular mel filter bank, shape `(N_MELS, N_FREQS)`.
mel_filters: Array2<f32>,
/// Periodic Hann window of `N_FFT` samples applied to every frame before the FFT.
hann_window: Vec<f32>,
/// Forward real-to-complex FFT plan of length `N_FFT`.
fft: Arc<dyn RealToComplex<f32>>,
/// Scratch space required by the FFT plan.
fft_scratch: Vec<Complex<f32>>,
/// Output buffer that receives the spectrum of the frame being processed.
spectrum_buf: Vec<Complex<f32>>,
/// Power spectrogram produced by the previous `extract` call, if any.
cached_power_spec: Option<Array2<f32>>,
/// Mel spectrogram (before log scaling) produced by the previous `extract` call, if any.
cached_mel_spec: Option<Array2<f32>>,
}
impl MelExtractor {
    /// Builds an extractor with a precomputed mel filter bank, Hann window and
    /// FFT plan, and with empty caches.
    fn new() -> Self {
        let mel_filters = build_mel_filters(
            SAMPLE_RATE as usize,
            N_FFT,
            N_MELS,
            0.0,
            SAMPLE_RATE as f32 / 2.0,
        );
        let hann_window = periodic_hann(N_FFT);
        let mut planner = RealFftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(N_FFT);
        let fft_scratch = fft.make_scratch_vec();
        let spectrum_buf = fft.make_output_vec();
        Self {
            mel_filters,
            hann_window,
            fft,
            fft_scratch,
            spectrum_buf,
            cached_power_spec: None,
            cached_mel_spec: None,
        }
    }

    /// Computes the normalised log-mel spectrogram of `audio`, shape `(N_MELS, N_FRAMES)`.
    ///
    /// `audio` is reflect-padded by `N_FFT / 2` samples on both sides, cut into
    /// Hann-windowed frames every `HOP_LENGTH` samples, transformed to a power
    /// spectrum, projected onto the mel filter bank, log10-scaled, clamped to
    /// 8.0 below the maximum and mapped with `(x + 4) / 4`.
    ///
    /// `shift_frames` is the number of whole hops by which `audio` has advanced
    /// since the previous call: the caller must guarantee that the window moved
    /// by exactly `shift_frames * HOP_LENGTH` samples. Pass `0` (or call
    /// [`Self::invalidate_cache`]) to force a full recomputation.
    ///
    /// Only frames whose FFT window lies entirely inside real audio, both in the
    /// previous window and in the current one, are taken from the cache. Frames
    /// that touch the reflected padding depend on where the window edges are, so
    /// they are always recomputed; this keeps the incremental result in line
    /// with a from-scratch extraction of the same `audio`.
    ///
    /// # Panics
    /// Panics if `audio` is too short to be reflect-padded or to fill `N_FRAMES`
    /// frames, or if the FFT buffers do not match the plan.
    fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2<f32> {
        debug_assert_eq!(audio.len(), RING_CAPACITY);
        let pad = N_FFT / 2;
        let n = audio.len();

        // Reflect padding (edge sample not repeated) on both sides.
        let mut padded = vec![0.0f32; pad + n + pad];
        padded[pad..pad + n].copy_from_slice(audio);
        for i in 0..pad {
            padded[i] = audio[pad - i];
            padded[pad + n + i] = audio[n - 2 - i];
        }

        let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;

        // Frames `[0, left_edge)` read left padding; frames
        // `[right_edge, n_total_frames)` read right padding.
        let left_edge = (pad + HOP_LENGTH - 1) / HOP_LENGTH;
        let right_edge = ((pad + n).saturating_sub(N_FFT) / HOP_LENGTH + 1).min(n_total_frames);
        // New frame `j` corresponds to old frame `j + shift_frames`; it is
        // reusable only if both are interior frames.
        let reuse_end = right_edge.min(N_FRAMES).saturating_sub(shift_frames);

        let mut power_spec = Array2::<f32>::zeros((N_FREQS, n_total_frames));
        let mut mel_spec = Array2::<f32>::zeros((N_MELS, N_FRAMES));
        let mut reused = 0..0;
        if let (Some(old_power), Some(old_mel)) = (&self.cached_power_spec, &self.cached_mel_spec) {
            if shift_frames > 0 && reuse_end > left_edge && old_power.ncols() == n_total_frames {
                let src_lo = left_edge + shift_frames;
                let src_hi = reuse_end + shift_frames;
                // The power cache is kept coherent with the mel cache even
                // though only freshly computed power columns are read below.
                power_spec
                    .slice_mut(s![.., left_edge..reuse_end])
                    .assign(&old_power.slice(s![.., src_lo..src_hi]));
                mel_spec
                    .slice_mut(s![.., left_edge..reuse_end])
                    .assign(&old_mel.slice(s![.., src_lo..src_hi]));
                reused = left_edge..reuse_end;
            }
        }

        // Power spectrum of every frame that was not taken from the cache.
        let mut frame_buf = vec![0.0f32; N_FFT];
        for frame_idx in (0..reused.start).chain(reused.end..n_total_frames) {
            let start = frame_idx * HOP_LENGTH;
            let window = &padded[start..start + N_FFT];
            for (dst, (&sample, &w)) in frame_buf
                .iter_mut()
                .zip(window.iter().zip(self.hann_window.iter()))
            {
                *dst = sample * w;
            }
            self.fft
                .process_with_scratch(
                    &mut frame_buf,
                    &mut self.spectrum_buf,
                    &mut self.fft_scratch,
                )
                .expect("FFT failed: internal buffer size mismatch");
            for (k, c) in self.spectrum_buf.iter().enumerate() {
                power_spec[[k, frame_idx]] = c.re * c.re + c.im * c.im;
            }
        }

        // Mel projection of the recomputed columns; the trailing frame beyond
        // N_FRAMES is not part of the model input.
        {
            let power_view = power_spec.slice(s![.., ..N_FRAMES]);
            for &(lo, hi) in &[(0, reused.start), (reused.end, N_FRAMES)] {
                if lo < hi {
                    let fresh = self.mel_filters.dot(&power_view.slice(s![.., lo..hi]));
                    mel_spec.slice_mut(s![.., lo..hi]).assign(&fresh);
                }
            }
        }

        let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10());
        self.cached_power_spec = Some(power_spec);
        self.cached_mel_spec = Some(mel_spec);

        let max_val = log_mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        log_mel.mapv_inplace(|x| (x.max(max_val - 8.0) + 4.0) / 4.0);
        log_mel
    }

    /// Drops both cached spectrograms so the next `extract` recomputes every frame.
    fn invalidate_cache(&mut self) {
        self.cached_power_spec = None;
        self.cached_mel_spec = None;
    }
}
/// Converts a frequency in Hz to mels.
///
/// The mapping is linear (`hz / (200 / 3)`) below 1000 Hz and logarithmic at
/// or above it, with a step of `ln(6.4) / 27` per mel.
fn hz_to_mel(hz: f32) -> f32 {
    const LINEAR_HZ_PER_MEL: f32 = 200.0 / 3.0;
    const BREAK_HZ: f32 = 1000.0;
    const BREAK_MEL: f32 = BREAK_HZ / LINEAR_HZ_PER_MEL;

    // Linear region.
    if hz < BREAK_HZ {
        return hz / LINEAR_HZ_PER_MEL;
    }

    // Logarithmic region.
    let log_step = (6.4_f32).ln() / 27.0;
    BREAK_MEL + (hz / BREAK_HZ).ln() / log_step
}
/// Converts a mel value back to a frequency in Hz.
///
/// Inverse of the piecewise mapping used by `hz_to_mel`: linear below the mel
/// value that corresponds to 1000 Hz, exponential at or above it.
fn mel_to_hz(mel: f32) -> f32 {
    const LINEAR_HZ_PER_MEL: f32 = 200.0 / 3.0;
    const BREAK_HZ: f32 = 1000.0;
    const BREAK_MEL: f32 = BREAK_HZ / LINEAR_HZ_PER_MEL;

    // Linear region.
    if mel < BREAK_MEL {
        return mel * LINEAR_HZ_PER_MEL;
    }

    // Exponential region.
    let log_step = (6.4_f32).ln() / 27.0;
    BREAK_HZ * ((mel - BREAK_MEL) * log_step).exp()
}
/// Builds a bank of `n_mels` triangular filters over the one-sided spectrum of
/// an `n_fft`-point FFT at sample rate `sr`.
///
/// Filter edges are spaced evenly on the mel scale between `f_min` and `f_max`
/// (both in Hz). Each triangle is scaled by `2 / (right_edge - left_edge)`.
/// The result has shape `(n_mels, n_fft / 2 + 1)`.
fn build_mel_filters(
    sr: usize,
    n_fft: usize,
    n_mels: usize,
    f_min: f32,
    f_max: f32,
) -> Array2<f32> {
    let n_freqs = n_fft / 2 + 1;

    // Centre frequency of every FFT bin, in Hz.
    let bin_hz: Vec<f32> = (0..n_freqs)
        .map(|i| i as f32 * sr as f32 / n_fft as f32)
        .collect();

    // n_mels + 2 edge frequencies, equally spaced in mel and mapped back to Hz.
    let mel_lo = hz_to_mel(f_min);
    let mel_hi = hz_to_mel(f_max);
    let edges_hz: Vec<f32> = (0..n_mels + 2)
        .map(|i| mel_to_hz(mel_lo + (mel_hi - mel_lo) * i as f32 / (n_mels + 1) as f32))
        .collect();

    let mut filters = Array2::<f32>::zeros((n_mels, n_freqs));
    // Every run of three consecutive edges defines one triangle.
    for (m, edge) in edges_hz.windows(3).enumerate() {
        let (f_left, f_center, f_right) = (edge[0], edge[1], edge[2]);
        let enorm = 2.0 / (f_right - f_left);
        for (k, &f) in bin_hz.iter().enumerate() {
            let weight = if (f_left..=f_center).contains(&f) {
                // Rising slope.
                (f - f_left) / (f_center - f_left)
            } else if f > f_center && f <= f_right {
                // Falling slope.
                (f_right - f) / (f_right - f_center)
            } else {
                0.0
            };
            filters[[m, k]] = weight * enorm;
        }
    }
    filters
}
/// Returns the `n`-point periodic Hann window, `0.5 * (1 - cos(2πk / n))` for
/// `k` in `0..n`. An `n` of zero yields an empty vector.
fn periodic_hann(n: usize) -> Vec<f32> {
    let mut window = Vec::with_capacity(n);
    for k in 0..n {
        let phase = 2.0 * std::f32::consts::PI * k as f32 / n as f32;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
/// Returns exactly `RING_CAPACITY` samples: the newest samples of `samples`,
/// preceded by as many zeros as needed when the input is shorter.
fn prepare_audio(samples: &[f32]) -> Vec<f32> {
    // Number of input samples that fit, always taken from the end.
    let kept = samples.len().min(RING_CAPACITY);
    let tail = &samples[samples.len() - kept..];

    let mut out = Vec::with_capacity(RING_CAPACITY);
    out.resize(RING_CAPACITY - kept, 0.0f32);
    out.extend_from_slice(tail);
    out
}
/// Turn detector that feeds log-mel features of the most recent audio to an
/// ONNX model and reports whether the speaker's turn looks finished.
pub struct PipecatSmartTurn {
/// ONNX Runtime session holding the loaded model.
session: ort::session::Session,
/// Most recent samples, capped at `RING_CAPACITY`; oldest samples are dropped first.
ring_buffer: VecDeque<f32>,
/// Feature extractor, including its incremental spectrogram cache.
mel: MelExtractor,
/// Number of samples pushed since the last `predict` call.
samples_since_predict: usize,
}
// SAFETY: NOTE(review) — no justification is recorded in this file. These impls
// assert that every field (the `ort` session, the ring buffer and the
// `MelExtractor` with its `Arc<dyn RealToComplex<f32>>`) may be moved to and
// shared between threads. Confirm that against the `ort` and `realfft`
// documentation; if the compiler already derives both auto traits, the manual
// impls are redundant and should be removed.
unsafe impl Send for PipecatSmartTurn {}
unsafe impl Sync for PipecatSmartTurn {}
impl PipecatSmartTurn {
    /// Creates a detector backed by the model bytes embedded in the binary.
    ///
    /// # Errors
    /// Propagates the error returned by `onnx::session_from_memory`.
    pub fn new() -> Result<Self, TurnError> {
        Ok(Self::build(onnx::session_from_memory(MODEL_BYTES)?))
    }

    /// Creates a detector backed by the ONNX model stored at `path`.
    ///
    /// # Errors
    /// Propagates the error returned by `onnx::session_from_file`.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TurnError> {
        Ok(Self::build(onnx::session_from_file(path)?))
    }

    /// Wraps `session` with an empty ring buffer, a fresh feature extractor and
    /// a zeroed sample counter.
    fn build(session: ort::session::Session) -> Self {
        let ring_buffer = VecDeque::with_capacity(RING_CAPACITY);
        let mel = MelExtractor::new();
        Self {
            session,
            ring_buffer,
            mel,
            samples_since_predict: 0,
        }
    }
}
impl AudioTurnDetector for PipecatSmartTurn {
    /// Appends the samples of `frame` to the ring buffer, discarding the oldest
    /// samples once `RING_CAPACITY` is exceeded.
    ///
    /// Frames whose sample rate is not `SAMPLE_RATE` are ignored. A frame that
    /// is longer than the whole buffer contributes only its newest
    /// `RING_CAPACITY` samples.
    fn push_audio(&mut self, frame: &AudioFrame) {
        if frame.sample_rate() != SAMPLE_RATE {
            return;
        }
        let samples = frame.samples();
        let incoming = samples.len();

        // Samples at the front of an oversized frame would be evicted
        // immediately, so they are never inserted. This also keeps `overflow`
        // within the current buffer length; draining past the end would panic.
        let skipped = incoming.saturating_sub(RING_CAPACITY);
        let accepted = incoming - skipped;
        let overflow = (self.ring_buffer.len() + accepted).saturating_sub(RING_CAPACITY);
        if overflow > 0 {
            self.ring_buffer.drain(..overflow);
        }
        self.ring_buffer
            .extend(samples.iter().skip(skipped).copied());

        // The analysis window advanced by every pushed sample, kept or not.
        self.samples_since_predict = self.samples_since_predict.saturating_add(incoming);
    }

    /// Runs the model on the buffered audio and classifies the turn.
    ///
    /// The buffer is zero-padded on the left to `RING_CAPACITY` samples, turned
    /// into log-mel features and passed to the model as `input_features`. The
    /// first value of the `logits` output is compared against 0.5.
    ///
    /// # Errors
    /// Returns `TurnError::BackendError` if the input tensor cannot be built,
    /// inference fails, or the `logits` output is missing, unreadable or empty.
    fn predict(&mut self) -> Result<TurnPrediction, TurnError> {
        let t_start = Instant::now();

        // Cached spectrogram columns sit on a HOP_LENGTH grid. They can only be
        // reused when the window advanced by a whole number of hops; otherwise
        // every cached frame is misaligned with the audio, so a shift of 0 is
        // passed, which makes the extractor recompute all frames.
        let pending = std::mem::take(&mut self.samples_since_predict);
        let shift_frames = if pending % HOP_LENGTH == 0 {
            pending / HOP_LENGTH
        } else {
            0
        };

        // `make_contiguous` exposes the deque as one slice, avoiding an
        // intermediate copy before padding.
        let audio = prepare_audio(self.ring_buffer.make_contiguous());
        let t_after_audio_prep = Instant::now();

        let mel_spec = self.mel.extract(&audio, shift_frames);
        let t_after_mel = Instant::now();

        let (raw, _) = mel_spec.into_raw_vec_and_offset();
        let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw)
            .expect("internal: mel output has wrong element count");
        let input_tensor = Tensor::from_array(input_array)
            .map_err(|e| TurnError::BackendError(format!("failed to create input tensor: {e}")))?;
        let outputs = self
            .session
            .run(inputs!["input_features" => input_tensor])
            .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?;
        let t_after_onnx = Instant::now();

        let output = outputs
            .get("logits")
            .ok_or_else(|| TurnError::BackendError("missing 'logits' output tensor".into()))?;
        let (_, data): (_, &[f32]) = output
            .try_extract_tensor()
            .map_err(|e| TurnError::BackendError(format!("failed to extract logits: {e}")))?;
        // NOTE(review): the output named `logits` is used directly as a
        // probability — confirm the exported graph already applies a sigmoid.
        let probability = *data
            .first()
            .ok_or_else(|| TurnError::BackendError("logits tensor is empty".into()))?;

        let latency_ms = t_start.elapsed().as_millis() as u64;
        let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0;
        let stage_times = vec![
            StageTiming {
                name: "audio_prep",
                us: us(t_start, t_after_audio_prep),
            },
            StageTiming {
                name: "mel",
                us: us(t_after_audio_prep, t_after_mel),
            },
            StageTiming {
                name: "onnx",
                us: us(t_after_mel, t_after_onnx),
            },
        ];

        // Confidence is the probability of whichever state was chosen.
        let (state, confidence) = if probability > 0.5 {
            (TurnState::Finished, probability)
        } else {
            (TurnState::Unfinished, 1.0 - probability)
        };
        let audio_duration_ms = (self.ring_buffer.len() as u64 * 1000) / SAMPLE_RATE as u64;

        Ok(TurnPrediction {
            state,
            confidence,
            latency_ms,
            stage_times,
            audio_duration_ms,
        })
    }

    /// Clears the buffered audio, the pending-sample counter and the feature cache.
    fn reset(&mut self) {
        self.ring_buffer.clear();
        self.samples_since_predict = 0;
        self.mel.invalidate_cache();
    }
}
#[cfg(test)]
mod mel_tests {
    use std::path::{Path, PathBuf};
    use ndarray::Array2;
    use ndarray_npy::ReadNpyExt;
    use super::{prepare_audio, MelExtractor, RING_CAPACITY, SAMPLE_RATE};

    /// Largest absolute per-cell difference for which a clip is reported as PASS.
    const MEL_TOLERANCE: f32 = 0.05;

    /// Returns `tests/fixtures` under the directory two levels above this crate's manifest.
    fn fixtures_dir() -> PathBuf {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
        let two_levels_up = manifest_dir.ancestors().nth(2).unwrap();
        two_levels_up.join("tests/fixtures")
    }

    /// Loads a mono WAV file at `SAMPLE_RATE` as `f32` samples; integer samples
    /// are read as `i16` and divided by 32768.
    fn load_wav_f32(path: &Path) -> Vec<f32> {
        let mut reader = match hound::WavReader::open(path) {
            Ok(reader) => reader,
            Err(e) => panic!("failed to open {}: {}", path.display(), e),
        };
        let spec = reader.spec();
        assert_eq!(spec.sample_rate, SAMPLE_RATE, "expected 16 kHz");
        assert_eq!(spec.channels, 1, "expected mono");
        match spec.sample_format {
            hound::SampleFormat::Float => reader.samples::<f32>().map(|s| s.unwrap()).collect(),
            hound::SampleFormat::Int => reader
                .samples::<i16>()
                .map(|s| s.unwrap() as f32 / 32768.0)
                .collect(),
        }
    }

    /// Loads the reference mel matrix stored next to `clip` as `<clip>.mel.npy`.
    fn load_python_mel(clip: &str) -> Array2<f32> {
        let path = fixtures_dir().join(format!("{clip}.mel.npy"));
        let file = match std::fs::File::open(&path) {
            Ok(file) => file,
            Err(_) => panic!(
                "missing {}: run `python scripts/gen_reference.py` first",
                path.display()
            ),
        };
        Array2::<f32>::read_npy(file).expect("failed to parse .npy")
    }

    /// Summary of the element-wise difference between two mel matrices.
    struct MelDiff {
        /// Largest absolute difference.
        max_diff: f32,
        /// Mean absolute difference.
        mean_diff: f32,
        /// `(mel, frame)` index of the largest difference.
        max_at: (usize, usize),
        /// Fraction of cells whose absolute difference exceeds 0.01.
        outlier_frac: f32,
    }

    /// Extracts features for `clip` from scratch and compares them with the reference.
    fn compare_mel(clip: &str) -> MelDiff {
        let samples = load_wav_f32(&fixtures_dir().join(clip));
        let audio = prepare_audio(&samples);
        assert_eq!(audio.len(), RING_CAPACITY);

        let mut extractor = MelExtractor::new();
        let rust_mel = extractor.extract(&audio, 0);
        let python_mel = load_python_mel(clip);
        assert_eq!(
            rust_mel.shape(),
            python_mel.shape(),
            "{clip}: mel shape mismatch"
        );

        let mut worst = 0.0f32;
        let mut worst_at = (0, 0);
        let mut total_diff = 0.0f32;
        let mut outlier_count = 0usize;
        // Visits cells mel band by mel band, frame by frame.
        for ((m, t), &ours) in rust_mel.indexed_iter() {
            let d = (ours - python_mel[[m, t]]).abs();
            total_diff += d;
            if d > worst {
                worst = d;
                worst_at = (m, t);
            }
            if d > 0.01 {
                outlier_count += 1;
            }
        }

        let cells = rust_mel.len() as f32;
        MelDiff {
            max_diff: worst,
            mean_diff: total_diff / cells,
            max_at: worst_at,
            outlier_frac: outlier_count as f32 / cells,
        }
    }

    /// Prints a Markdown table of the differences for each fixture clip.
    /// Ignored by default; it reports PASS/FAIL but does not assert.
    #[test]
    #[ignore]
    fn mel_report() {
        println!();
        println!("MEL_TOLERANCE={MEL_TOLERANCE}");
        println!();
        println!("| Clip | Max Diff | Mean Diff | Max at (mel,frame) | Outliers >0.01 | Status |");
        println!("|------|----------|-----------|---------------------|----------------|--------|");
        for clip in ["silence_2s.wav", "speech_finished.wav", "speech_mid.wav"] {
            let diff = compare_mel(clip);
            let (mel_idx, frame_idx) = diff.max_at;
            let status = if diff.max_diff > MEL_TOLERANCE {
                "FAIL"
            } else {
                "PASS"
            };
            println!(
                "| `{clip}` | {:.6} | {:.6} | ({},{}) | {:.2}% | {status} |",
                diff.max_diff,
                diff.mean_diff,
                mel_idx,
                frame_idx,
                diff.outlier_frac * 100.0,
            );
        }
        println!();
    }
}