use crate::error::{Error, Result};
use crate::execution::ModelConfig;
use ndarray::{s, Array1, Array2, Array3, Axis};
use ort::session::Session;
use realfft::RealFftPlanner;
use std::f32::consts::PI;
use std::path::Path;
const N_FFT: usize = 512;
const WIN_LENGTH: usize = 400;
const HOP_LENGTH: usize = 160;
const N_MELS: usize = 128;
const PREEMPH: f32 = 0.97;
const LOG_ZERO_GUARD: f32 = 5.960_464_5e-8;
const SAMPLE_RATE: usize = 16000;
const CHUNK_LEN: usize = 124; const FIFO_LEN: usize = 124; const SPKCACHE_LEN: usize = 188; const RIGHT_CONTEXT: usize = 1; const SUBSAMPLING: usize = 8; const EMB_DIM: usize = 512; pub const NUM_SPEAKERS: usize = 4; const FRAME_DURATION: f32 = 0.08;
const SPKCACHE_SIL_FRAMES_PER_SPK: usize = 3;
const PRED_SCORE_THRESHOLD: f32 = 0.25;
const STRONG_BOOST_RATE: f32 = 0.75;
const WEAK_BOOST_RATE: f32 = 1.5;
const MIN_POS_SCORES_RATE: f32 = 0.5;
const SIL_THRESHOLD: f32 = 0.2;
const MAX_INDEX: usize = 99999;
#[derive(Debug, Clone)]
pub struct DiarizationConfig {
pub onset: f32,
pub offset: f32,
pub pad_onset: f32,
pub pad_offset: f32,
pub min_duration_on: f32,
pub min_duration_off: f32,
pub median_window: usize,
}
impl Default for DiarizationConfig {
fn default() -> Self {
Self::callhome()
}
}
impl DiarizationConfig {
pub fn callhome() -> Self {
Self {
onset: 0.641,
offset: 0.561,
pad_onset: 0.229,
pad_offset: 0.079,
min_duration_on: 0.511,
min_duration_off: 0.296,
median_window: 11,
}
}
pub fn dihard3() -> Self {
Self {
onset: 0.56,
offset: 1.0,
pad_onset: 0.063,
pad_offset: 0.002,
min_duration_on: 0.007,
min_duration_off: 0.151,
median_window: 11,
}
}
pub fn custom(onset: f32, offset: f32) -> Self {
Self {
onset,
offset,
pad_onset: 0.0,
pad_offset: 0.0,
min_duration_on: 0.1,
min_duration_off: 0.1,
median_window: 11,
}
}
}
#[derive(Debug, Clone)]
pub struct SpeakerSegment {
pub start: u64,
pub end: u64,
pub speaker_id: usize,
}
#[derive(Debug, Clone)]
pub struct RawDiarizationPredictions {
pub predictions: Array2<f32>,
pub num_valid_frames: usize,
}
pub struct Sortformer {
session: Session,
config: DiarizationConfig,
pub chunk_len: usize,
pub fifo_len: usize,
pub spkcache_len: usize,
pub right_context: usize,
spkcache: Array3<f32>, spkcache_preds: Option<Array3<f32>>, fifo: Array3<f32>, fifo_preds: Array3<f32>, mean_sil_emb: Array2<f32>, n_sil_frames: usize,
audio_buffer: Vec<f32>,
elapsed_samples: usize,
mel_basis: Array2<f32>,
}
impl Sortformer {
pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
Self::with_config(model_path, None, DiarizationConfig::default())
}
pub fn with_config<P: AsRef<Path>>(
model_path: P,
execution_config: Option<ModelConfig>,
config: DiarizationConfig,
) -> Result<Self> {
let config_to_use = execution_config.unwrap_or_default();
let mut builder = config_to_use
.apply_to_session_builder(Session::builder()?)?;
let session = builder.commit_from_file(model_path.as_ref())?;
let (chunk_len, fifo_len, spkcache_len, right_context) =
if let Ok(metadata) = session.metadata() {
let c = metadata
.custom("chunk_len")
.and_then(|v| v.parse().ok())
.unwrap_or(CHUNK_LEN);
let f = metadata
.custom("fifo_len")
.and_then(|v| v.parse().ok())
.unwrap_or(FIFO_LEN);
let s = metadata
.custom("spkcache_len")
.and_then(|v| v.parse().ok())
.unwrap_or(SPKCACHE_LEN);
let r = metadata
.custom("right_context")
.and_then(|v| v.parse().ok())
.unwrap_or(RIGHT_CONTEXT);
(c, f, s, r)
} else {
(CHUNK_LEN, FIFO_LEN, SPKCACHE_LEN, RIGHT_CONTEXT)
};
let mel_basis = crate::audio::create_mel_filterbank(N_FFT, N_MELS, SAMPLE_RATE);
let mut instance = Self {
session,
config,
chunk_len,
fifo_len,
spkcache_len,
right_context,
spkcache: Array3::zeros((1, 0, EMB_DIM)),
spkcache_preds: None,
fifo: Array3::zeros((1, 0, EMB_DIM)),
fifo_preds: Array3::zeros((1, 0, NUM_SPEAKERS)),
mean_sil_emb: Array2::zeros((1, EMB_DIM)),
n_sil_frames: 0,
audio_buffer: Vec::new(),
elapsed_samples: 0,
mel_basis,
};
instance.reset_state();
Ok(instance)
}
pub fn latency(&self) -> f32 {
(self.chunk_len + self.right_context) as f32 * FRAME_DURATION
}
pub fn reset_state(&mut self) {
self.spkcache = Array3::zeros((1, 0, EMB_DIM));
self.spkcache_preds = None;
self.fifo = Array3::zeros((1, 0, EMB_DIM));
self.fifo_preds = Array3::zeros((1, 0, NUM_SPEAKERS));
self.mean_sil_emb = Array2::zeros((1, EMB_DIM));
self.n_sil_frames = 0;
self.audio_buffer.clear();
self.elapsed_samples = 0;
}
pub fn diarize(
&mut self,
mut audio: Vec<f32>,
sample_rate: u32,
channels: u16,
) -> Result<Vec<SpeakerSegment>> {
if sample_rate != SAMPLE_RATE as u32 {
return Err(Error::Audio(format!(
"Expected {} Hz, got {} Hz",
SAMPLE_RATE, sample_rate
)));
}
if channels > 1 {
audio = audio
.chunks(channels as usize)
.map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
.collect();
}
self.reset_state();
let features = self.extract_mel_features(&audio)?;
let full_preds = self.process_features(&features)?;
let filtered_preds = if self.config.median_window > 1 {
self.median_filter(&full_preds)
} else {
full_preds
};
let n_audio_samples = audio.len() as u64;
let mut segments = self.binarize(&filtered_preds);
for seg in &mut segments {
seg.end = seg.end.min(n_audio_samples);
}
segments.retain(|s| s.end > s.start);
Ok(segments)
}
pub fn diarize_chunk(&mut self, audio_16k_mono: &[f32]) -> Result<Vec<SpeakerSegment>> {
if audio_16k_mono.is_empty() {
return Ok(vec![]);
}
let features = self.extract_mel_features(audio_16k_mono)?;
let full_preds = self.process_features(&features)?;
let filtered_preds = if self.config.median_window > 1 {
self.median_filter(&full_preds)
} else {
full_preds
};
let n_audio_samples = audio_16k_mono.len() as u64;
let mut segments = self.binarize(&filtered_preds);
for seg in &mut segments {
seg.end = seg.end.min(n_audio_samples);
}
segments.retain(|s| s.end > s.start);
Ok(segments)
}
pub fn diarize_chunk_raw(
&mut self,
audio_16k_mono: &[f32],
) -> Result<RawDiarizationPredictions> {
if audio_16k_mono.is_empty() {
return Ok(RawDiarizationPredictions {
predictions: Array2::zeros((0, NUM_SPEAKERS)),
num_valid_frames: 0,
});
}
let features = self.extract_mel_features(audio_16k_mono)?;
let full_preds = self.process_features(&features)?;
let num_valid_frames = full_preds.nrows();
Ok(RawDiarizationPredictions {
predictions: full_preds,
num_valid_frames,
})
}
pub fn feed(&mut self, audio_16k_mono: &[f32]) -> Result<Vec<SpeakerSegment>> {
self.audio_buffer.extend_from_slice(audio_16k_mono);
let feed_size = (self.chunk_len + self.right_context) * SUBSAMPLING;
let stride_samples = self.chunk_len * SUBSAMPLING * HOP_LENGTH;
let feed_samples = (self.chunk_len + self.right_context) * SUBSAMPLING * HOP_LENGTH;
let mut all_segments = Vec::new();
while self.audio_buffer.len() >= feed_samples {
let window = &self.audio_buffer[..feed_samples];
let features = self.extract_mel_features(window)?;
let chunk_feat = features.slice(s![.., ..feed_size, ..]).to_owned();
let current_len = feed_size;
let chunk_preds = self.streaming_update(&chunk_feat, current_len)?;
let filtered_preds = if self.config.median_window > 1 {
self.median_filter(&chunk_preds)
} else {
chunk_preds
};
let sample_offset = self.elapsed_samples as u64;
let chunk_samples = (self.chunk_len * SUBSAMPLING * HOP_LENGTH) as u64;
let mut segments = self.binarize(&filtered_preds);
for seg in &mut segments {
seg.start += sample_offset;
seg.end = (seg.end + sample_offset).min(sample_offset + chunk_samples);
}
segments.retain(|s| s.end > s.start);
all_segments.extend(segments);
self.audio_buffer.drain(..stride_samples);
self.elapsed_samples += stride_samples;
}
Ok(all_segments)
}
pub fn flush(&mut self) -> Result<Vec<SpeakerSegment>> {
if self.audio_buffer.is_empty() {
return Ok(vec![]);
}
let feed_size = (self.chunk_len + self.right_context) * SUBSAMPLING;
let remaining = std::mem::take(&mut self.audio_buffer);
let features = self.extract_mel_features(&remaining)?;
let total_mel = features.shape()[1];
let current_len = total_mel.min(feed_size);
let chunk_feat = if current_len < feed_size {
let mut padded = Array3::zeros((1, feed_size, N_MELS));
padded
.slice_mut(s![.., ..current_len, ..])
.assign(&features.slice(s![.., ..current_len, ..]));
padded
} else {
features.slice(s![.., ..feed_size, ..]).to_owned()
};
let chunk_preds = self.streaming_update(&chunk_feat, current_len)?;
let filtered_preds = if self.config.median_window > 1 {
self.median_filter(&chunk_preds)
} else {
chunk_preds
};
let sample_offset = self.elapsed_samples as u64;
let remaining_samples = remaining.len() as u64;
let mut segments = self.binarize(&filtered_preds);
for seg in &mut segments {
seg.start += sample_offset;
seg.end = (seg.end + sample_offset).min(sample_offset + remaining_samples);
}
segments.retain(|s| s.end > s.start);
self.elapsed_samples += remaining.len();
Ok(segments)
}
fn process_features(&mut self, features: &Array3<f32>) -> Result<Array2<f32>> {
let total_frames = features.shape()[1];
let chunk_stride = self.chunk_len * SUBSAMPLING;
let feed_size = (self.chunk_len + self.right_context) * SUBSAMPLING;
let num_chunks = total_frames.div_ceil(chunk_stride);
let mut all_chunk_preds = Vec::new();
for chunk_idx in 0..num_chunks {
let start = chunk_idx * chunk_stride;
let end = (start + feed_size).min(total_frames);
let current_len = end - start;
let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned();
if current_len < feed_size {
let mut padded = Array3::zeros((1, feed_size, N_MELS));
padded
.slice_mut(s![.., ..current_len, ..])
.assign(&chunk_feat);
chunk_feat = padded;
}
let chunk_preds = self.streaming_update(&chunk_feat, current_len)?;
all_chunk_preds.push(chunk_preds);
}
Ok(Self::concat_predictions(&all_chunk_preds))
}
fn streaming_update(
&mut self,
chunk_feat: &Array3<f32>,
current_len: usize,
) -> Result<Array2<f32>> {
let spkcache_len = self.spkcache.shape()[1];
let fifo_len = self.fifo.shape()[1];
let chunk_lengths = Array1::from_vec(vec![current_len as i64]);
let spkcache_lengths = Array1::from_vec(vec![spkcache_len as i64]);
let fifo_lengths = Array1::from_vec(vec![fifo_len as i64]);
let empty_3d = Array3::<f32>::zeros((1, 0, EMB_DIM));
let fifo_ref = if fifo_len > 0 { &self.fifo } else { &empty_3d };
let spkcache_ref = if spkcache_len > 0 {
&self.spkcache
} else {
&empty_3d
};
let chunk_value = ort::value::TensorRef::<f32>::from_array_view(chunk_feat.view())?;
let chunk_lengths_value = ort::value::Value::from_array(chunk_lengths)?;
let spkcache_value = ort::value::TensorRef::<f32>::from_array_view(spkcache_ref.view())?;
let spkcache_lengths_value = ort::value::Value::from_array(spkcache_lengths)?;
let fifo_value = ort::value::TensorRef::<f32>::from_array_view(fifo_ref.view())?;
let fifo_lengths_value = ort::value::Value::from_array(fifo_lengths)?;
let (preds, new_embs, chunk_len) = {
let outputs = self.session.run(ort::inputs!(
"chunk" => chunk_value,
"chunk_lengths" => chunk_lengths_value,
"spkcache" => spkcache_value,
"spkcache_lengths" => spkcache_lengths_value,
"fifo" => fifo_value,
"fifo_lengths" => fifo_lengths_value
))?;
let (preds_shape, preds_data) = outputs["spkcache_fifo_chunk_preds"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract preds: {e}")))?;
let (embs_shape, embs_data) = outputs["chunk_pre_encode_embs"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract embs: {e}")))?;
let preds_dims = preds_shape.as_ref();
let embs_dims = embs_shape.as_ref();
let preds = Array3::from_shape_vec(
(
preds_dims[0] as usize,
preds_dims[1] as usize,
preds_dims[2] as usize,
),
preds_data.to_vec(),
)
.map_err(|e| Error::Model(format!("Failed to reshape preds: {e}")))?;
let new_embs = Array3::from_shape_vec(
(
embs_dims[0] as usize,
embs_dims[1] as usize,
embs_dims[2] as usize,
),
embs_data.to_vec(),
)
.map_err(|e| Error::Model(format!("Failed to reshape embs: {e}")))?;
let valid_frames = current_len.div_ceil(SUBSAMPLING);
(preds, new_embs, valid_frames)
};
let fifo_preds = if fifo_len > 0 {
preds
.slice(s![0, spkcache_len..spkcache_len + fifo_len, ..])
.to_owned()
} else {
Array2::zeros((0, NUM_SPEAKERS))
};
let keep = self.chunk_len.min(chunk_len);
let chunk_preds = preds
.slice(s![
0,
spkcache_len + fifo_len..spkcache_len + fifo_len + keep,
..
])
.to_owned();
let chunk_embs = new_embs.slice(s![0, ..keep, ..]).to_owned();
self.fifo = Self::concat_axis1(&self.fifo, &chunk_embs.insert_axis(Axis(0)));
if fifo_len > 0 {
let combined = Self::concat_axis1_2d(&fifo_preds, &chunk_preds);
self.fifo_preds = combined.insert_axis(Axis(0));
} else {
self.fifo_preds = chunk_preds.clone().insert_axis(Axis(0));
}
let fifo_len_after = self.fifo.shape()[1];
if fifo_len_after > self.fifo_len {
let mut pop_out_len = self.chunk_len;
pop_out_len = pop_out_len.max(chunk_len.saturating_sub(self.fifo_len) + fifo_len);
pop_out_len = pop_out_len.min(fifo_len_after);
let pop_out_embs = self.fifo.slice(s![.., ..pop_out_len, ..]).to_owned();
let pop_out_preds = self.fifo_preds.slice(s![.., ..pop_out_len, ..]).to_owned();
self.update_silence_profile(&pop_out_embs, &pop_out_preds);
self.fifo = self.fifo.slice(s![.., pop_out_len.., ..]).to_owned();
self.fifo_preds = self.fifo_preds.slice(s![.., pop_out_len.., ..]).to_owned();
self.spkcache = Self::concat_axis1(&self.spkcache, &pop_out_embs);
if let Some(ref cache_preds) = self.spkcache_preds {
self.spkcache_preds = Some(Self::concat_axis1(cache_preds, &pop_out_preds));
}
if self.spkcache.shape()[1] > self.spkcache_len {
if self.spkcache_preds.is_none() {
let initial_cache_preds = preds.slice(s![.., ..spkcache_len, ..]).to_owned();
let combined = Self::concat_axis1(&initial_cache_preds, &pop_out_preds);
self.spkcache_preds = Some(combined);
}
self.compress_spkcache();
}
}
Ok(chunk_preds)
}
fn update_silence_profile(&mut self, embs: &Array3<f32>, preds: &Array3<f32>) {
let preds_2d = preds.slice(s![0, .., ..]);
for t in 0..preds_2d.shape()[0] {
let sum: f32 = (0..NUM_SPEAKERS).map(|s| preds_2d[[t, s]]).sum();
if sum < SIL_THRESHOLD {
let emb = embs.slice(s![0, t, ..]);
let old_sum: Vec<f32> = self
.mean_sil_emb
.slice(s![0, ..])
.iter()
.map(|&x| x * self.n_sil_frames as f32)
.collect();
self.n_sil_frames += 1;
for i in 0..EMB_DIM {
self.mean_sil_emb[[0, i]] = (old_sum[i] + emb[i]) / self.n_sil_frames as f32;
}
}
}
}
fn compress_spkcache(&mut self) {
let cache_preds = match &self.spkcache_preds {
Some(p) => p.clone(),
None => return,
};
let n_frames = self.spkcache.shape()[1];
let per_spk = self.spkcache_len / NUM_SPEAKERS;
if per_spk <= SPKCACHE_SIL_FRAMES_PER_SPK {
self.spkcache = self.spkcache.slice(s![.., ..self.spkcache_len, ..]).to_owned();
if let Some(ref p) = self.spkcache_preds {
self.spkcache_preds = Some(p.slice(s![.., ..self.spkcache_len, ..]).to_owned());
}
return;
}
let spkcache_len_per_spk = per_spk - SPKCACHE_SIL_FRAMES_PER_SPK;
let strong_boost_per_spk = (spkcache_len_per_spk as f32 * STRONG_BOOST_RATE) as usize;
let weak_boost_per_spk = (spkcache_len_per_spk as f32 * WEAK_BOOST_RATE) as usize;
let min_pos_scores_per_spk = (spkcache_len_per_spk as f32 * MIN_POS_SCORES_RATE) as usize;
let preds_2d = cache_preds.slice(s![0, .., ..]).to_owned();
let mut scores = self.get_log_pred_scores(&preds_2d);
scores = self.disable_low_scores(&preds_2d, scores, min_pos_scores_per_spk);
scores = self.boost_topk_scores(scores, strong_boost_per_spk, 2.0);
scores = self.boost_topk_scores(scores, weak_boost_per_spk, 1.0);
if SPKCACHE_SIL_FRAMES_PER_SPK > 0 {
let mut padded = Array2::from_elem(
(n_frames + SPKCACHE_SIL_FRAMES_PER_SPK, NUM_SPEAKERS),
f32::NEG_INFINITY,
);
padded.slice_mut(s![..n_frames, ..]).assign(&scores);
for i in n_frames..n_frames + SPKCACHE_SIL_FRAMES_PER_SPK {
for j in 0..NUM_SPEAKERS {
padded[[i, j]] = f32::INFINITY;
}
}
scores = padded;
}
let (topk_indices, is_disabled) = self.get_topk_indices(&scores, n_frames);
let (new_embs, new_preds) = self.gather_spkcache(&topk_indices, &is_disabled);
self.spkcache = new_embs;
self.spkcache_preds = Some(new_preds);
}
fn get_log_pred_scores(&self, preds: &Array2<f32>) -> Array2<f32> {
let mut scores = Array2::zeros(preds.dim());
for t in 0..preds.shape()[0] {
let mut log_1_probs_sum = 0.0f32;
for s in 0..NUM_SPEAKERS {
let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD);
let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln();
log_1_probs_sum += log_1_p;
}
for s in 0..NUM_SPEAKERS {
let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD);
let log_p = p.ln();
let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln();
scores[[t, s]] = log_p - log_1_p + log_1_probs_sum - 0.5f32.ln();
}
}
scores
}
fn disable_low_scores(
&self,
preds: &Array2<f32>,
mut scores: Array2<f32>,
min_pos_scores_per_spk: usize,
) -> Array2<f32> {
let mut pos_count = [0usize; NUM_SPEAKERS];
for t in 0..scores.shape()[0] {
for s in 0..NUM_SPEAKERS {
if scores[[t, s]] > 0.0 {
pos_count[s] += 1;
}
}
}
for t in 0..preds.shape()[0] {
for s in 0..NUM_SPEAKERS {
let is_speech = preds[[t, s]] > 0.5;
if !is_speech {
scores[[t, s]] = f32::NEG_INFINITY;
} else {
let is_pos = scores[[t, s]] > 0.0;
if !is_pos && pos_count[s] >= min_pos_scores_per_spk {
scores[[t, s]] = f32::NEG_INFINITY;
}
}
}
}
scores
}
fn boost_topk_scores(
&self,
mut scores: Array2<f32>,
n_boost_per_spk: usize,
scale_factor: f32,
) -> Array2<f32> {
for s in 0..NUM_SPEAKERS {
let col: Vec<(usize, f32)> = (0..scores.shape()[0])
.map(|t| (t, scores[[t, s]]))
.collect();
let mut sorted = col.clone();
sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
for item in sorted.iter().take(n_boost_per_spk.min(sorted.len())) {
let t = item.0;
if scores[[t, s]] != f32::NEG_INFINITY {
scores[[t, s]] -= scale_factor * 0.5f32.ln();
}
}
}
scores
}
fn get_topk_indices(
&self,
scores: &Array2<f32>,
n_frames_no_sil: usize,
) -> (Vec<usize>, Vec<bool>) {
let n_frames = scores.shape()[0];
let mut flat_scores: Vec<(usize, f32)> = Vec::with_capacity(n_frames * NUM_SPEAKERS);
for s in 0..NUM_SPEAKERS {
for t in 0..n_frames {
let flat_idx = s * n_frames + t;
flat_scores.push((flat_idx, scores[[t, s]]));
}
}
flat_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut topk_flat: Vec<usize> = flat_scores
.iter()
.take(self.spkcache_len)
.map(|(idx, score)| {
if *score == f32::NEG_INFINITY {
MAX_INDEX
} else {
*idx
}
})
.collect();
topk_flat.sort();
let mut is_disabled = vec![false; self.spkcache_len];
let mut frame_indices = vec![0usize; self.spkcache_len];
for (i, &flat_idx) in topk_flat.iter().enumerate() {
if flat_idx == MAX_INDEX {
is_disabled[i] = true;
frame_indices[i] = 0; } else {
let frame_idx = flat_idx % n_frames;
if frame_idx >= n_frames_no_sil {
is_disabled[i] = true;
frame_indices[i] = 0; } else {
frame_indices[i] = frame_idx;
}
}
}
(frame_indices, is_disabled)
}
fn gather_spkcache(
&self,
indices: &[usize],
is_disabled: &[bool],
) -> (Array3<f32>, Array3<f32>) {
let mut new_embs = Array3::zeros((1, self.spkcache_len, EMB_DIM));
let mut new_preds = Array3::zeros((1, self.spkcache_len, NUM_SPEAKERS));
let cache_preds = self.spkcache_preds.as_ref().unwrap();
for (i, (&idx, &disabled)) in indices.iter().zip(is_disabled.iter()).enumerate() {
if i >= self.spkcache_len {
break;
}
if disabled {
new_embs
.slice_mut(s![0, i, ..])
.assign(&self.mean_sil_emb.slice(s![0, ..]));
} else if idx < self.spkcache.shape()[1] {
new_embs
.slice_mut(s![0, i, ..])
.assign(&self.spkcache.slice(s![0, idx, ..]));
new_preds
.slice_mut(s![0, i, ..])
.assign(&cache_preds.slice(s![0, idx, ..]));
}
}
(new_embs, new_preds)
}
fn concat_axis1(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> {
if a.shape()[1] == 0 {
return b.clone();
}
if b.shape()[1] == 0 {
return a.clone();
}
ndarray::concatenate(Axis(1), &[a.view(), b.view()]).unwrap()
}
fn concat_axis1_2d(a: &Array2<f32>, b: &Array2<f32>) -> Array2<f32> {
if a.shape()[0] == 0 {
return b.clone();
}
if b.shape()[0] == 0 {
return a.clone();
}
ndarray::concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
}
fn concat_predictions(preds: &[Array2<f32>]) -> Array2<f32> {
if preds.is_empty() {
return Array2::zeros((0, NUM_SPEAKERS));
}
if preds.len() == 1 {
return preds[0].clone();
}
let views: Vec<_> = preds.iter().map(|p| p.view()).collect();
ndarray::concatenate(Axis(0), &views).unwrap()
}
fn median_filter(&self, preds: &Array2<f32>) -> Array2<f32> {
let window = self.config.median_window;
let half = window / 2;
let mut filtered = preds.clone();
for spk in 0..NUM_SPEAKERS {
for t in 0..preds.shape()[0] {
let start = t.saturating_sub(half);
let end = (t + half + 1).min(preds.shape()[0]);
let mut values: Vec<f32> = (start..end).map(|i| preds[[i, spk]]).collect();
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
filtered[[t, spk]] = values[values.len() / 2];
}
}
filtered
}
fn binarize(&self, preds: &Array2<f32>) -> Vec<SpeakerSegment> {
let mut segments = Vec::new();
let num_frames = preds.shape()[0];
let pad_onset_samples = (self.config.pad_onset * SAMPLE_RATE as f32) as u64;
let pad_offset_samples = (self.config.pad_offset * SAMPLE_RATE as f32) as u64;
let min_dur_on_samples = (self.config.min_duration_on * SAMPLE_RATE as f32) as u64;
let min_dur_off_samples = (self.config.min_duration_off * SAMPLE_RATE as f32) as u64;
let samples_per_frame = (FRAME_DURATION * SAMPLE_RATE as f32) as u64;
for spk in 0..NUM_SPEAKERS {
let mut in_seg = false;
let mut seg_start = 0;
let mut temp_segments = Vec::new();
for t in 0..num_frames {
let p = preds[[t, spk]];
if p >= self.config.onset && !in_seg {
in_seg = true;
seg_start = t;
} else if p < self.config.offset && in_seg {
in_seg = false;
let start_s = (seg_start as u64 * samples_per_frame)
.saturating_sub(pad_onset_samples);
let end_s = t as u64 * samples_per_frame + pad_offset_samples;
if end_s - start_s >= min_dur_on_samples {
temp_segments.push(SpeakerSegment {
start: start_s,
end: end_s,
speaker_id: spk,
});
}
}
}
if in_seg {
let start_s = (seg_start as u64 * samples_per_frame)
.saturating_sub(pad_onset_samples);
let end_s = num_frames as u64 * samples_per_frame + pad_offset_samples;
if end_s - start_s >= min_dur_on_samples {
temp_segments.push(SpeakerSegment {
start: start_s,
end: end_s,
speaker_id: spk,
});
}
}
if temp_segments.len() > 1 {
let mut filtered = vec![temp_segments[0].clone()];
for seg in temp_segments.into_iter().skip(1) {
let last = filtered.last_mut().unwrap();
let gap = seg.start.saturating_sub(last.end);
if gap < min_dur_off_samples {
last.end = seg.end; } else {
filtered.push(seg);
}
}
segments.extend(filtered);
} else {
segments.extend(temp_segments);
}
}
segments.sort_by_key(|s| s.start);
segments
}
fn apply_preemphasis(audio: &[f32]) -> Vec<f32> {
let mut result = Vec::with_capacity(audio.len());
result.push(audio[0]);
for i in 1..audio.len() {
result.push(audio[i] - PREEMPH * audio[i - 1]);
}
result
}
fn hann_window(window_length: usize) -> Vec<f32> {
(0..window_length)
.map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / window_length as f32).cos())
.collect()
}
fn stft(audio: &[f32]) -> Result<Array2<f32>> {
let mut planner = RealFftPlanner::<f32>::new();
let r2c = planner.plan_fft_forward(N_FFT);
let hann = Self::hann_window(WIN_LENGTH);
let win_offset = (N_FFT - WIN_LENGTH) / 2;
let mut fft_window = vec![0.0f32; N_FFT];
fft_window[win_offset..(WIN_LENGTH + win_offset)].copy_from_slice(&hann[..WIN_LENGTH]);
let pad_amount = N_FFT / 2;
let mut padded_audio = vec![0.0; pad_amount];
padded_audio.extend_from_slice(audio);
padded_audio.extend(vec![0.0; pad_amount]);
let num_frames = (padded_audio.len() - N_FFT) / HOP_LENGTH + 1;
let freq_bins = N_FFT / 2 + 1;
let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames));
let mut input = vec![0.0f32; N_FFT];
let mut output = r2c.make_output_vec();
let mut scratch = r2c.make_scratch_vec();
for frame_idx in 0..num_frames {
let start = frame_idx * HOP_LENGTH;
for i in 0..N_FFT {
input[i] = if start + i < padded_audio.len() {
padded_audio[start + i] * fft_window[i]
} else {
0.0
};
}
r2c.process_with_scratch(&mut input, &mut output, &mut scratch)
.map_err(|e| Error::Audio(format!("FFT failed: {e}")))?;
for k in 0..freq_bins {
spectrogram[[k, frame_idx]] = output[k].norm_sqr();
}
}
Ok(spectrogram)
}
fn extract_mel_features(&self, audio: &[f32]) -> Result<Array3<f32>> {
let preemphasized = Self::apply_preemphasis(audio);
let spectrogram = Self::stft(&preemphasized)?;
let mel_spec = self.mel_basis.dot(&spectrogram);
let log_mel_spec = mel_spec.mapv(|x| (x + LOG_ZERO_GUARD).ln());
Ok(log_mel_spec.t().to_owned().insert_axis(Axis(0)))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn sine_wave(freq_hz: f32, sample_rate: usize, num_samples: usize) -> Vec<f32> {
(0..num_samples)
.map(|i| (2.0 * PI * freq_hz * i as f32 / sample_rate as f32).sin())
.collect()
}
#[test]
fn stft_concentrates_power_at_expected_bin() {
let audio = sine_wave(1000.0, SAMPLE_RATE, SAMPLE_RATE);
let spec = Sortformer::stft(&audio).unwrap();
let expected_bin = 32;
let freq_bins = N_FFT / 2 + 1;
let num_frames = spec.shape()[1];
let mut correct_frames = 0;
for frame in 2..num_frames.saturating_sub(2) {
let mut max_bin = 0;
let mut max_power = 0.0f32;
for bin in 0..freq_bins {
if spec[[bin, frame]] > max_power {
max_power = spec[[bin, frame]];
max_bin = bin;
}
}
if max_bin == expected_bin {
correct_frames += 1;
}
}
let interior_frames = num_frames.saturating_sub(4);
assert!(
correct_frames > interior_frames / 2,
"Expected bin {expected_bin} to dominate, but only {correct_frames}/{interior_frames}"
);
}
#[test]
fn stft_output_shape_is_correct() {
let audio = vec![0.0f32; SAMPLE_RATE]; let spec = Sortformer::stft(&audio).unwrap();
let freq_bins = N_FFT / 2 + 1;
assert_eq!(spec.shape()[0], freq_bins);
assert!(spec.shape()[1] > 0);
}
}