speakrs 0.3.2

Fast Rust speaker diarization with pyannote-level accuracy and native CoreML/CUDA acceleration
use ndarray::{Array2, s};
use ort::value::TensorRef;

#[cfg(feature = "coreml")]
use super::tensor::{array2_slice, array3_slice};
use super::{
    EmbeddingModel, FBANK_FEATURES, FBANK_FRAMES, MULTI_MASK_BATCH_SIZE, MaskedEmbeddingInput,
    NUM_SPEAKERS, PRIMARY_BATCH_SIZE, SplitTailInput, array2_from_shape_vec, array2_slice_mut,
    array3_slice_mut, select_mask,
};

impl EmbeddingModel {
    /// Extract speaker embeddings for a batch of masked audio windows
    pub fn embed_batch(
        &mut self,
        inputs: &[MaskedEmbeddingInput<'_>],
    ) -> Result<Array2<f32>, ort::Error> {
        if let Some(sess) = self
            .ort
            .primary_batched_session
            .as_mut()
            .filter(|_| inputs.len() == PRIMARY_BATCH_SIZE)
        {
            for (batch_idx, input) in inputs.iter().enumerate() {
                let used_mask = select_mask(
                    input.mask,
                    input.clean_mask,
                    input.audio.len(),
                    self.meta.min_num_samples,
                );
                Self::prepare_waveform(
                    batch_idx,
                    input.audio,
                    self.meta.window_samples,
                    &mut self.buffers.primary_batch_waveform_buffer.view_mut(),
                );
                Self::prepare_weights(
                    batch_idx,
                    used_mask,
                    self.meta.mask_frames,
                    &mut self.buffers.primary_batch_weights_buffer.view_mut(),
                );
            }

            let waveform_tensor =
                TensorRef::from_array_view(self.buffers.primary_batch_waveform_buffer.view())?;
            let weights_tensor =
                TensorRef::from_array_view(self.buffers.primary_batch_weights_buffer.view())?;
            let ort_inputs =
                ort::inputs!["waveform" => waveform_tensor, "weights" => weights_tensor];
            let outputs = if let Some(opts) = &self.ort.primary_batch_run_options {
                sess.run_with_options(ort_inputs, opts)?
            } else {
                sess.run(ort_inputs)?
            };
            let (_shape, data) = outputs[0].try_extract_tensor::<f32>()?;
            let n = inputs.len();
            let mut result = Array2::<f32>::zeros((n, 256));
            array2_slice_mut(&mut result, "batched embedding output")?
                .copy_from_slice(&data[..n * 256]);
            return Ok(result);
        }

        let mut stacked = Array2::<f32>::zeros((inputs.len(), 256));
        for (idx, input) in inputs.iter().enumerate() {
            let embedding = self.embed_masked(input.audio, input.mask, input.clean_mask)?;
            stacked.row_mut(idx).assign(&embedding);
        }
        Ok(stacked)
    }

    pub(crate) fn embed_multi_mask_batch(
        &mut self,
        fbanks: &[&Array2<f32>],
        masks: &[&[f32]],
    ) -> Result<Array2<f32>, ort::Error> {
        let num_fbanks = fbanks.len();
        let num_masks = masks.len();
        debug_assert_eq!(num_masks, num_fbanks * NUM_SPEAKERS);
        debug_assert!(num_fbanks <= MULTI_MASK_BATCH_SIZE);

        let fbank_row_stride = FBANK_FRAMES * FBANK_FEATURES;
        for (idx, fbank) in fbanks.iter().enumerate() {
            self.buffers
                .multi_mask_fbank_buffer
                .slice_mut(s![idx, ..fbank.nrows(), ..fbank.ncols()])
                .assign(fbank);
        }

        for (idx, mask) in masks.iter().enumerate() {
            Self::prepare_weights(
                idx,
                mask,
                self.meta.mask_frames,
                &mut self.buffers.multi_mask_masks_buffer.view_mut(),
            );
        }
        if num_fbanks < MULTI_MASK_BATCH_SIZE {
            let start = num_fbanks * fbank_row_stride;
            let buf = array3_slice_mut(
                &mut self.buffers.multi_mask_fbank_buffer,
                "multi-mask fbank scratch buffer",
            )?;
            buf[start..].fill(0.0);
        }
        if num_masks < MULTI_MASK_BATCH_SIZE * NUM_SPEAKERS {
            self.buffers
                .multi_mask_masks_buffer
                .slice_mut(s![num_masks.., ..])
                .fill(0.0);
        }

        let full_mask_batch = MULTI_MASK_BATCH_SIZE * NUM_SPEAKERS;

        #[cfg(feature = "coreml")]
        {
            self.ensure_native_multi_mask_loaded()?;
        }
        #[cfg(feature = "coreml")]
        if let Some(native) = self.coreml.native_multi_mask_session.as_ref() {
            let fbank_data = array3_slice(
                &self.buffers.multi_mask_fbank_buffer,
                "native multi-mask fbank input",
            )?;
            let masks_data = array2_slice(
                &self.buffers.multi_mask_masks_buffer,
                "native multi-mask masks input",
            )?;
            let (data, _) = native
                .predict_cached(&[
                    (&self.coreml.cached_multi_mask_fbank_shape, fbank_data),
                    (&self.coreml.cached_multi_mask_masks_shape, masks_data),
                ])
                .map_err(|e| ort::Error::new(e.to_string()))?;
            let batch =
                array2_from_shape_vec(full_mask_batch, 256, data, "native multi-mask output")?;
            return Ok(batch.slice(s![0..num_masks, ..]).to_owned());
        }

        let use_batched =
            num_fbanks == MULTI_MASK_BATCH_SIZE && self.ort.multi_mask_batched_session.is_some();

        if use_batched {
            let fbank_tensor =
                TensorRef::from_array_view(self.buffers.multi_mask_fbank_buffer.view())?;
            let masks_tensor =
                TensorRef::from_array_view(self.buffers.multi_mask_masks_buffer.view())?;
            let outputs = self
                .ort
                .multi_mask_batched_session
                .as_mut()
                .ok_or_else(|| ort::Error::new("missing multi-mask batched session"))?
                .run(ort::inputs!["fbank" => fbank_tensor, "masks" => masks_tensor])?;
            let (_shape, data) = outputs[0].try_extract_tensor::<f32>()?;
            let batch = array2_from_shape_vec(
                full_mask_batch,
                256,
                data.to_vec(),
                "multi-mask batched output",
            )?;
            Ok(batch.slice(s![0..num_masks, ..]).to_owned())
        } else {
            let mut all_embeddings = Array2::<f32>::zeros((num_masks, 256));
            for fbank_idx in 0..num_fbanks {
                let fbank_slice = self.buffers.multi_mask_fbank_buffer.slice(s![
                    fbank_idx..fbank_idx + 1,
                    ..,
                    ..
                ]);
                let mask_start = fbank_idx * NUM_SPEAKERS;
                let mask_end = mask_start + NUM_SPEAKERS;
                let masks_slice = self
                    .buffers
                    .multi_mask_masks_buffer
                    .slice(s![mask_start..mask_end, ..]);
                let fbank_tensor = TensorRef::from_array_view(fbank_slice.view())?;
                let masks_tensor = TensorRef::from_array_view(masks_slice.view())?;
                let outputs = self
                    .ort
                    .multi_mask_session
                    .as_mut()
                    .ok_or_else(|| ort::Error::new("missing multi-mask session"))?
                    .run(ort::inputs!["fbank" => fbank_tensor, "masks" => masks_tensor])?;
                let (_shape, data) = outputs[0].try_extract_tensor::<f32>()?;
                for (local_idx, row_idx) in (mask_start..mask_end).enumerate() {
                    let start = local_idx * 256;
                    all_embeddings
                        .row_mut(row_idx)
                        .assign(&ndarray::ArrayView1::from(&data[start..start + 256]));
                }
            }
            Ok(all_embeddings)
        }
    }

    pub(crate) fn embed_tail_batch_inputs(
        &mut self,
        inputs: &[SplitTailInput<'_>],
    ) -> Result<Array2<f32>, ort::Error> {
        debug_assert!(inputs.len() <= PRIMARY_BATCH_SIZE);

        let row_stride = FBANK_FRAMES * FBANK_FEATURES;
        for (batch_idx, input) in inputs.iter().enumerate() {
            debug_assert_eq!(input.fbank.ncols(), FBANK_FEATURES);

            if batch_idx > 0 && std::ptr::eq(input.fbank, inputs[batch_idx - 1].fbank) {
                let buf = self
                    .buffers
                    .split_primary_feature_batch_buffer
                    .as_slice_mut()
                    .ok_or_else(|| {
                        ort::Error::new("split primary feature batch buffer was not contiguous")
                    })?;
                let prev_start = (batch_idx - 1) * row_stride;
                buf.copy_within(prev_start..prev_start + row_stride, batch_idx * row_stride);
            } else {
                self.buffers
                    .split_primary_feature_batch_buffer
                    .slice_mut(s![batch_idx, ..input.fbank.nrows(), ..input.fbank.ncols()])
                    .assign(input.fbank);
            }

            Self::prepare_weights(
                batch_idx,
                input.weights,
                self.meta.mask_frames,
                &mut self.buffers.split_primary_weights_batch_buffer.view_mut(),
            );
        }
        if inputs.len() < PRIMARY_BATCH_SIZE {
            self.buffers
                .split_primary_weights_batch_buffer
                .slice_mut(s![inputs.len().., ..])
                .fill(0.0);
        }

        #[cfg(feature = "coreml")]
        {
            self.ensure_native_tail_primary_batched_loaded()?;
        }
        #[cfg(feature = "coreml")]
        if let Some(native) = self.coreml.native_tail_primary_batched_session.as_mut() {
            let fbank_data = array3_slice(
                &self.buffers.split_primary_feature_batch_buffer,
                "native primary tail fbank input",
            )?;
            let weights_data = array2_slice(
                &self.buffers.split_primary_weights_batch_buffer,
                "native primary tail weights input",
            )?;
            let (data, _) = native
                .predict_cached(&[
                    (&self.coreml.cached_tail_fbank_shape, fbank_data),
                    (&self.coreml.cached_tail_weights_shape, weights_data),
                ])
                .map_err(|e| ort::Error::new(e.to_string()))?;
            let batch =
                array2_from_shape_vec(PRIMARY_BATCH_SIZE, 256, data, "native primary tail output")?;
            return Ok(batch.slice(s![0..inputs.len(), ..]).to_owned());
        }

        let fbank_tensor =
            TensorRef::from_array_view(self.buffers.split_primary_feature_batch_buffer.view())?;
        let weights_tensor =
            TensorRef::from_array_view(self.buffers.split_primary_weights_batch_buffer.view())?;
        let outputs = self
            .ort
            .split_primary_tail_batched_session
            .as_mut()
            .ok_or_else(|| ort::Error::new("missing primary tail batched session"))?
            .run(ort::inputs!["fbank" => fbank_tensor, "weights" => weights_tensor])?;
        let (_shape, data) = outputs[0].try_extract_tensor::<f32>()?;
        let batch = array2_from_shape_vec(
            PRIMARY_BATCH_SIZE,
            256,
            data.to_vec(),
            "primary tail batched output",
        )?;
        Ok(batch.slice(s![0..inputs.len(), ..]).to_owned())
    }
}