speakrs 0.3.2

Fast Rust speaker diarization with pyannote-level accuracy and native CoreML/CUDA acceleration
use ndarray::{Array2, Array3, s};

use super::gpu::EmbeddedChunk;
use super::{
    ChunkEmbeddings, ChunkLayout, DecodedSegmentations, InferenceArtifacts, PipelineError,
    invariant_error,
};

pub(super) fn batch_embeddings(
    num_masks: usize,
    data: Vec<f32>,
    context: &str,
) -> Result<Array2<f32>, PipelineError> {
    Array2::from_shape_vec((num_masks, 256), data).map_err(|error| {
        invariant_error(format!(
            "{context} produced invalid embedding shape: {error}"
        ))
    })
}

/// Packages an embedding summary into [`InferenceArtifacts`].
///
/// The layout is built from the step/window parameters together with the
/// number of chunks recorded in `summary`; the summary's segmentations and
/// embeddings are moved into their wrapper types unchanged.
///
/// Returns `None` when the summary contains no chunks.
pub(super) fn build_chunk_artifacts(
    step_seconds: f64,
    step_samples: usize,
    window_samples: usize,
    summary: super::EmbeddingSummary,
) -> Option<InferenceArtifacts> {
    let chunk_count = summary.num_chunks;
    (chunk_count != 0).then(|| {
        let layout = ChunkLayout::new(step_seconds, step_samples, window_samples, chunk_count);
        InferenceArtifacts {
            layout,
            segmentations: DecodedSegmentations(summary.segmentations),
            embeddings: ChunkEmbeddings(summary.embeddings),
        }
    })
}

/// Accumulates per-chunk segmentations and embeddings for one file until all
/// expected chunks have been added.
pub(super) struct FileCollector {
    /// Segmentation buffer, allocated in `new` as
    /// `(max_slots, num_frames, num_speakers)` and zero-initialised.
    seg_array: Array3<f32>,
    /// Embedding buffer, allocated in `new` as `(max_slots, num_speakers, 256)`
    /// and filled with NaN until `add` overwrites a row.
    emb_array: Array3<f32>,
    /// One past the highest slot that has received a segmentation; `0` means
    /// nothing has been stored yet.
    max_slot_used: usize,
    /// Number of `add` calls that completed successfully.
    chunks_received: usize,
    /// Number of `add` calls after which `is_complete` reports `true`.
    expected_chunks: usize,
}

impl FileCollector {
    pub(super) fn new(
        max_slots: usize,
        num_frames: usize,
        num_speakers: usize,
        expected_chunks: usize,
    ) -> Self {
        Self {
            seg_array: Array3::zeros((max_slots, num_frames, num_speakers)),
            emb_array: Array3::from_elem((max_slots, num_speakers, 256), f32::NAN),
            max_slot_used: 0,
            chunks_received: 0,
            expected_chunks,
        }
    }

    pub(super) fn add(
        &mut self,
        local_start: usize,
        chunk_win_capacity: usize,
        num_speakers: usize,
        embedded: EmbeddedChunk,
    ) -> Result<(), PipelineError> {
        let batch_emb =
            batch_embeddings(embedded.num_masks, embedded.data, "batch chunk embedding")?;

        for &(local, speaker_idx) in &embedded.active {
            let slot = local_start * chunk_win_capacity + local;
            if slot < self.emb_array.shape()[0] {
                let mask_idx = local * num_speakers + speaker_idx;
                self.emb_array
                    .slice_mut(s![slot, speaker_idx, ..])
                    .assign(&batch_emb.row(mask_idx));
            }
        }

        for (local, decoded) in embedded.decoded_chunk.into_iter().enumerate() {
            let slot = local_start * chunk_win_capacity + local;
            if slot < self.seg_array.shape()[0] {
                self.seg_array.slice_mut(s![slot, .., ..]).assign(&decoded);
                self.max_slot_used = self.max_slot_used.max(slot + 1);
            }
        }

        self.chunks_received += 1;
        Ok(())
    }

    pub(super) fn is_complete(&self) -> bool {
        self.chunks_received >= self.expected_chunks
    }

    pub(super) fn into_artifacts(
        self,
        step_seconds: f64,
        step_samples: usize,
        window_samples: usize,
    ) -> Option<InferenceArtifacts> {
        if self.max_slot_used == 0 {
            return None;
        }
        let n = self.max_slot_used;
        Some(InferenceArtifacts {
            layout: ChunkLayout::new(step_seconds, step_samples, window_samples, n),
            segmentations: DecodedSegmentations(self.seg_array.slice_move(s![..n, .., ..])),
            embeddings: ChunkEmbeddings(self.emb_array.slice_move(s![..n, .., ..])),
        })
    }
}