speakrs 0.3.2

Fast Rust speaker diarization with pyannote-level accuracy and native CoreML/CUDA acceleration
use crossbeam_channel::Receiver;
use ndarray::{Array3, s};
use tracing::{debug, trace};

use super::collect::batch_embeddings;
use super::gpu::{ChunkEmbeddingResources, GpuWorker, chunk_embedding_resources};
use super::prep::{ChunkPrep, DecodedChunk, PrepScratch, PrepWorker};
use super::{
    ChunkParams, EmbeddingModel, EmbeddingSummary, PipelineError, SegmentationModel,
    chunk_audio_raw, chunk_session_for_windows, join_scoped_result, write_speaker_mask_to_slice,
};

/// Tuple produced by `setup_chunk_embedding`, in this order:
/// `(chunk_win_capacity, total_windows, use_pipelined, resources)`.
///
/// `resources` is only looked up when `use_pipelined` is true, and may still be
/// `None` in that case.
type ChunkEmbeddingSetup = (usize, usize, bool, Option<ChunkEmbeddingResources>);

/// Picks the number of segmentation worker threads.
///
/// Uses the parallelism reported by the standard library, falls back to 4 when
/// that query fails, and never returns more than 8.
pub(super) fn seg_worker_count() -> usize {
    const FALLBACK_WORKERS: usize = 4;
    const MAX_WORKERS: usize = 8;

    let detected = match std::thread::available_parallelism() {
        Ok(cores) => cores.get(),
        Err(_) => FALLBACK_WORKERS,
    };
    std::cmp::min(detected, MAX_WORKERS)
}

/// Decides whether chunk-level embedding applies to `audio` and sizes it.
///
/// Returns `Ok(None)` when `audio` is shorter than one segmentation window, or
/// when the embedding model reports no chunk window capacity. Otherwise returns
/// `(chunk_win_capacity, total_windows, use_pipelined, resources)`:
///
/// * `total_windows` counts the windows obtained by stepping `step_samples`
///   at a time across `audio`.
/// * `use_pipelined` is true when those windows span at least two chunks of
///   `chunk_win_capacity` windows.
/// * `resources` is fetched via `chunk_embedding_resources` only when
///   `use_pipelined` is true; it is `None` otherwise.
///
/// # Errors
///
/// Propagates any error returned by `chunk_embedding_resources`.
pub(super) fn setup_chunk_embedding(
    seg_model: &SegmentationModel,
    emb_model: &mut EmbeddingModel,
    audio: &[f32],
) -> Result<Option<ChunkEmbeddingSetup>, PipelineError> {
    let window_samples = seg_model.window_samples();
    if audio.len() < window_samples {
        return Ok(None);
    }

    let chunk_win_capacity = match emb_model.chunk_window_capacity() {
        Some(capacity) => capacity,
        None => return Ok(None),
    };

    let hop = seg_model.step_samples();
    let total_windows = audio.len().saturating_sub(window_samples) / hop + 1;
    let use_pipelined = total_windows.div_ceil(chunk_win_capacity) >= 2;

    // Only the pipelined path needs the extra resources.
    let resources = if use_pipelined {
        chunk_embedding_resources(emb_model)?
    } else {
        None
    };

    let setup: ChunkEmbeddingSetup = (chunk_win_capacity, total_windows, use_pipelined, resources);
    Ok(Some(setup))
}

/// Runs chunk embedding as a multi-threaded pipeline on `scope`.
///
/// Spawns two `PrepWorker` threads and one `GpuWorker` thread, then collects the
/// `GpuWorker` output on the calling thread into dense segmentation and
/// embedding arrays indexed by window slot (`global_start + local`).
///
/// * `scope` - thread scope the workers are spawned on.
/// * `emb_start` - start instant, used only for the wall-time field of the debug log.
/// * `resources` - supplies the largest session and the fbank handles given to the workers.
/// * `chunk_rx` - source of decoded chunks, read by the prep workers and the `GpuWorker`.
/// * `audio` - audio samples handed to the workers.
/// * `params` - window/step/speaker sizing for this run.
///
/// Returns an [`EmbeddingSummary`] whose arrays are trimmed to the highest slot
/// written. Embedding entries that were never written stay `NaN`. If the worker
/// output channel closes before producing anything, an empty summary is returned.
///
/// # Errors
///
/// Propagates errors from `resources.largest_session()`, `batch_embeddings`,
/// and `join_scoped_result` for any worker.
pub(super) fn run_pipelined<'scope>(
    scope: &'scope std::thread::Scope<'scope, '_>,
    emb_start: std::time::Instant,
    resources: ChunkEmbeddingResources,
    chunk_rx: &'scope Receiver<DecodedChunk>,
    audio: &'scope [f32],
    params: &'scope ChunkParams,
) -> Result<EmbeddingSummary, PipelineError> {
    let largest = resources.largest_session()?;
    let max_active = params.chunk_win_capacity * params.num_speakers;
    // Shared prep configuration; each prep worker gets a clone, the GPU worker takes the original.
    let prep_config = ChunkPrep {
        step_samples: params.step_samples,
        window_samples: params.window_samples,
        num_speakers: params.num_speakers,
        min_num_samples: params.min_num_samples,
        largest_fbank_frames: largest.fbank_frames,
        largest_num_masks: largest.num_masks,
        max_active,
        fbank_30s: resources.fbank_30s.clone(),
        fbank_10s: resources.fbank_10s.clone(),
    };

    // Bounded queues: prep workers -> GPU worker (48 items), GPU worker -> this thread (8 items).
    let (prep_tx, prep_rx) = crossbeam_channel::bounded(48);
    let (emb_tx, emb_rx) = crossbeam_channel::bounded(8);

    let mut prep_handles = Vec::with_capacity(2);
    for _ in 0..2usize {
        // Every worker owns its own scratch buffers and its own sender clone.
        let worker = PrepWorker {
            prep: prep_config.clone(),
            scratch: PrepScratch::new(params.window_samples),
        };
        let prep_tx = prep_tx.clone();
        prep_handles.push(scope.spawn(move || {
            worker.run(
                audio,
                params.step_samples,
                params.window_samples,
                chunk_rx,
                prep_tx,
            )
        }));
    }
    // Drop the original sender so only the workers' clones keep the prep channel open.
    drop(prep_tx);

    let gpu_emb_tx = emb_tx.clone();
    // The GPU worker receives prepared chunks and also gets its own handle on `chunk_rx`.
    // NOTE(review): presumably it pulls raw chunks itself when no prepared chunk is ready
    // (the log below is labelled "Priority-pull") — confirm in the gpu module.
    let gpu_handle = scope.spawn(move || {
        GpuWorker {
            model: largest.session.model,
            fbank_shape: largest.session.cached_fbank_shape,
            masks_shape: largest.session.cached_masks_shape,
            prep: prep_config,
            scratch: PrepScratch::new(params.window_samples),
        }
        .run(audio, prep_rx, chunk_rx.clone(), gpu_emb_tx)
    });
    // Drop the original sender so `emb_rx.recv()` fails once the GPU worker's clone is gone.
    drop(emb_tx);

    // Upper bound on window slots; writes at or beyond this index are ignored below.
    let max_slots = params.total_windows + params.chunk_win_capacity;
    let first = match emb_rx.recv() {
        Ok(embedded) => embedded,
        Err(_) => {
            // Channel closed before any result arrived: join the workers (surfacing their
            // errors) and return an empty summary.
            let _gpu_stats = join_scoped_result("chunk embedding gpu", gpu_handle)?;
            let mut total_prep_fbank_us = 0u64;
            for handle in prep_handles {
                total_prep_fbank_us += join_scoped_result("chunk embedding prep", handle)?.fbank_us;
            }
            return Ok(EmbeddingSummary {
                segmentations: Array3::zeros((0, 0, params.num_speakers)),
                embeddings: Array3::from_elem((0, params.num_speakers, 256), f32::NAN),
                num_chunks: 0,
                gpu_predict_us: 0,
                prep_fbank_us: total_prep_fbank_us,
                prep_mask_us: 0,
            });
        }
    };

    // The frame count is taken from the first decoded window of the first result;
    // the `[0]` index assumes that result holds at least one window.
    let num_frames = first.decoded_chunk[0].nrows();
    let mut seg_array = Array3::<f32>::zeros((max_slots, num_frames, params.num_speakers));
    // Embeddings start as NaN; only (slot, speaker) pairs listed in `active` are overwritten.
    let mut emb_array = Array3::<f32>::from_elem((max_slots, params.num_speakers, 256), f32::NAN);
    let mut max_slot_used = 0usize;
    let mut total_predict_us = 0u64;
    let mut total_chunks = 0u32;

    // Process the first result, then keep receiving until the emb channel disconnects.
    for embedded in std::iter::once(first).chain(std::iter::from_fn(|| emb_rx.recv().ok())) {
        total_predict_us += embedded.predict_us;
        let batch_emb = batch_embeddings(
            embedded.num_masks,
            embedded.data,
            "pipelined chunk embedding",
        )?;

        // `active` holds (window-in-chunk, speaker) pairs; the matching batch row is
        // `local * num_speakers + speaker_idx`.
        for &(local, speaker_idx) in &embedded.active {
            let slot = embedded.global_start + local;
            if slot < max_slots {
                let mask_idx = local * params.num_speakers + speaker_idx;
                emb_array
                    .slice_mut(s![slot, speaker_idx, ..])
                    .assign(&batch_emb.row(mask_idx));
            }
        }

        // Copy each decoded window into its slot and track the highest slot written.
        for (local, decoded) in embedded.decoded_chunk.into_iter().enumerate() {
            let slot = embedded.global_start + local;
            if slot < max_slots {
                seg_array.slice_mut(s![slot, .., ..]).assign(&decoded);
                max_slot_used = max_slot_used.max(slot + 1);
            }
        }

        total_chunks += 1;
    }

    let gpu_stats = join_scoped_result("chunk embedding gpu", gpu_handle)?;
    let mut total_prep_fbank_us = 0u64;
    for handle in prep_handles {
        total_prep_fbank_us += join_scoped_result("chunk embedding prep", handle)?.fbank_us;
    }

    // `num_chunks` is the number of window slots filled (highest slot + 1), whereas
    // `total_chunks` counts the results received from the GPU worker.
    let num_chunks = max_slot_used;
    let seg_array = seg_array.slice_move(s![..num_chunks, .., ..]);
    let emb_array = emb_array.slice_move(s![..num_chunks, .., ..]);

    debug!(
        total_chunks,
        num_chunks,
        gpu_chunks = gpu_stats.chunks,
        gpu_predict_ms = gpu_stats.predict_us / 1000,
        gpu_self_prep_ms = gpu_stats.self_prep_us / 1000,
        cpu_prep_ms = total_prep_fbank_us / 1000,
        predict_ms = total_predict_us / 1000,
        emb_wall_ms = emb_start.elapsed().as_millis(),
        "Priority-pull breakdown"
    );

    // Reported prep time combines the prep workers' fbank time with the GPU worker's own prep time.
    Ok(EmbeddingSummary {
        segmentations: seg_array,
        embeddings: emb_array,
        num_chunks,
        gpu_predict_us: total_predict_us,
        prep_fbank_us: total_prep_fbank_us + gpu_stats.self_prep_us,
        prep_mask_us: 0,
    })
}

/// Embeds decoded chunks one at a time on the calling thread.
///
/// Receives chunks from `chunk_rx` until the channel disconnects. For each chunk
/// it builds the fbank buffer (via the 30s path when the model provides one,
/// otherwise window-sized pieces), writes the per-speaker masks, runs the chunk
/// session, and stores the results.
///
/// Output slots are assigned in arrival order (a running counter), not from the
/// chunk's `global_start`. A chunk whose first window would run past the end of
/// `audio` is skipped and consumes no slots.
///
/// * `emb_model` - embedding model used for fbank computation and inference.
/// * `chunk_rx` - source of decoded chunks.
/// * `audio` - audio samples the chunks refer to.
/// * `params` - window/step/speaker sizing for this run.
/// * `emb_start` - start instant, used only for the wall-time field of the trace log.
///
/// Returns an [`EmbeddingSummary`] with arrays trimmed to the number of windows
/// stored. Embedding entries that were never written stay `NaN`.
///
/// # Errors
///
/// Propagates errors from session lookup, fbank computation, inference, and the
/// contiguity check on fbank rows.
pub(super) fn run_sequential_chunks(
    emb_model: &mut EmbeddingModel,
    chunk_rx: &Receiver<DecodedChunk>,
    audio: &[f32],
    params: &ChunkParams,
    emb_start: std::time::Instant,
) -> Result<EmbeddingSummary, PipelineError> {
    // Buffer layout constants used throughout this function.
    const FBANK_ROW_LEN: usize = 80;
    const FBANK_ROWS_PER_WINDOW: usize = 998;
    const MASK_LEN: usize = 589;
    const EMBEDDING_LEN: usize = 256;

    let max_slots = params.total_windows + params.chunk_win_capacity;
    let mut seg_store: Option<Array3<f32>> = None;
    let mut emb_store: Option<Array3<f32>> = None;
    let mut next_slot = 0usize;
    let mut fbank_total_us = 0u64;
    let mut mask_total_us = 0u64;
    let mut predict_total_us = 0u64;
    let mut processed_chunks = 0u32;

    for chunk in chunk_rx.iter() {
        let DecodedChunk {
            global_start,
            decoded_chunk,
        } = chunk;
        let wins = decoded_chunk.len();

        // Copy the session sizes out so the model borrow ends before the fbank calls below.
        let (sess_fbank_frames, sess_num_masks) = {
            let session = chunk_session_for_windows(emb_model, wins)?;
            (session.fbank_frames, session.num_masks)
        };

        // Skip chunks whose first window does not fit inside the audio.
        let first_sample = global_start * params.step_samples;
        if first_sample + params.window_samples > audio.len() {
            continue;
        }
        let span = params.window_samples + (wins - 1) * params.step_samples;
        let last_sample = audio.len().min(first_sample + span);
        let chunk_audio = &audio[first_sample..last_sample];

        // Output arrays are allocated lazily, sized from the first chunk seen.
        let num_frames = decoded_chunk[0].nrows();
        let seg = seg_store
            .get_or_insert_with(|| Array3::zeros((max_slots, num_frames, params.num_speakers)));
        let emb = emb_store.get_or_insert_with(|| {
            Array3::from_elem((max_slots, params.num_speakers, EMBEDDING_LEN), f32::NAN)
        });

        let mut fbank = vec![0.0f32; sess_fbank_frames * FBANK_ROW_LEN];
        let fbank_timer = std::time::Instant::now();

        if let Some(full_fbank) = emb_model.compute_chunk_fbank_30s(chunk_audio)? {
            // Whole-chunk fbank: copy as many rows as the session buffer holds.
            let rows_to_copy = full_fbank.nrows().min(sess_fbank_frames);
            let dst_rows = fbank.chunks_exact_mut(FBANK_ROW_LEN).take(rows_to_copy);
            for (row_idx, dst_row) in dst_rows.enumerate() {
                let src_view = full_fbank.row(row_idx);
                let src_row = src_view.as_slice().ok_or_else(|| {
                    super::invariant_error("30s chunk fbank row was not contiguous")
                })?;
                dst_row.copy_from_slice(src_row);
            }
        } else {
            // Fallback: compute the fbank one window-sized piece at a time.
            let mut frames_filled = 0usize;
            let mut samples_consumed = 0usize;
            loop {
                if frames_filled >= sess_fbank_frames || samples_consumed >= chunk_audio.len() {
                    break;
                }
                let piece_end = chunk_audio
                    .len()
                    .min(samples_consumed + params.window_samples);
                let piece_fbank =
                    emb_model.compute_chunk_fbank(&chunk_audio[samples_consumed..piece_end])?;
                let rows_to_copy = piece_fbank.nrows().min(sess_fbank_frames - frames_filled);
                let dst_rows = fbank[frames_filled * FBANK_ROW_LEN..]
                    .chunks_exact_mut(FBANK_ROW_LEN)
                    .take(rows_to_copy);
                for (row_idx, dst_row) in dst_rows.enumerate() {
                    let src_view = piece_fbank.row(row_idx);
                    let src_row = src_view.as_slice().ok_or_else(|| {
                        super::invariant_error("10s chunk fbank row was not contiguous")
                    })?;
                    dst_row.copy_from_slice(src_row);
                }
                frames_filled += FBANK_ROWS_PER_WINDOW;
                samples_consumed += params.window_samples;
            }
        }
        fbank_total_us += fbank_timer.elapsed().as_micros() as u64;

        let mask_timer = std::time::Instant::now();
        let mut masks = vec![0.0f32; sess_num_masks * MASK_LEN];
        let mut active = Vec::new();
        for (local_idx, decoded) in decoded_chunk.iter().enumerate() {
            let win_audio = chunk_audio_raw(
                audio,
                params.step_samples,
                params.window_samples,
                global_start + local_idx,
            );
            let mask_base = local_idx * params.num_speakers;
            for speaker_idx in 0..params.num_speakers {
                let mask_idx = mask_base + speaker_idx;
                // Stop once the session has no more mask rows.
                if mask_idx >= sess_num_masks {
                    break;
                }
                let mask_begin = mask_idx * MASK_LEN;
                let dest = &mut masks[mask_begin..mask_begin + MASK_LEN];
                let is_active = write_speaker_mask_to_slice(
                    &decoded.view(),
                    speaker_idx,
                    win_audio.len(),
                    params.min_num_samples,
                    dest,
                );
                if is_active {
                    active.push((local_idx, speaker_idx));
                }
            }
        }
        mask_total_us += mask_timer.elapsed().as_micros() as u64;

        let predict_timer = std::time::Instant::now();
        let infer_session = chunk_session_for_windows(emb_model, wins)?;
        let batch_emb = EmbeddingModel::embed_chunk_session(infer_session, &fbank, &masks)?;
        predict_total_us += predict_timer.elapsed().as_micros() as u64;
        processed_chunks += 1;

        // Scatter embeddings for active (window, speaker) pairs into their output slots.
        for (local_idx, speaker_idx) in active.iter().copied() {
            let batch_row = local_idx * params.num_speakers + speaker_idx;
            emb.slice_mut(s![next_slot + local_idx, speaker_idx, ..])
                .assign(&batch_emb.row(batch_row));
        }
        // Store every decoded window of this chunk in consecutive slots.
        for (slot, decoded) in (next_slot..).zip(decoded_chunk) {
            seg.slice_mut(s![slot, .., ..]).assign(&decoded);
        }
        next_slot += wins;
    }

    // Trim to the slots written, or produce empty arrays when nothing was processed.
    let num_chunks = next_slot;
    let seg_array = seg_store
        .map(|full| full.slice_move(s![..num_chunks, .., ..]))
        .unwrap_or_else(|| Array3::zeros((0, 0, params.num_speakers)));
    let emb_array = emb_store
        .map(|full| full.slice_move(s![..num_chunks, .., ..]))
        .unwrap_or_else(|| Array3::from_elem((0, params.num_speakers, EMBEDDING_LEN), f32::NAN));

    trace!(
        seq_chunks = processed_chunks,
        num_chunks,
        fbank_ms = fbank_total_us / 1000,
        mask_ms = mask_total_us / 1000,
        predict_ms = predict_total_us / 1000,
        wall_ms = emb_start.elapsed().as_millis(),
        "EMB sequential",
    );

    Ok(EmbeddingSummary {
        segmentations: seg_array,
        embeddings: emb_array,
        num_chunks,
        gpu_predict_us: predict_total_us,
        prep_fbank_us: fbank_total_us,
        prep_mask_us: mask_total_us,
    })
}