//! speakrs 0.3.2
//!
//! Fast Rust speaker diarization with pyannote-level accuracy and native
//! CoreML/CUDA acceleration.
use std::sync::Arc;

use crossbeam_channel::{Receiver, Sender};

use crate::inference::coreml::{CachedInputShape, SharedCoreMlModel};

use super::prep::{DecodedChunk, PrepScratch, TaggedDecoded};
use super::{EmbeddingModel, PipelineError, backend_error, invariant_error};

impl ChunkEmbeddingResources {
    /// Returns the chunk-embedding session with the largest capacity.
    ///
    /// `chunk_sessions` and `chunk_lookup` are parallel vectors, so the last
    /// entry of each describes the same (largest) session.
    ///
    /// # Errors
    ///
    /// Returns an invariant error if either vector is empty, which would mean
    /// the resources were constructed without any sessions.
    pub(super) fn largest_session(&self) -> Result<LargestChunkSession, PipelineError> {
        let Some(session) = self.chunk_sessions.last().cloned() else {
            return Err(invariant_error("missing chunk embedding session"));
        };
        let Some(&(_, fbank_frames, num_masks)) = self.chunk_lookup.last() else {
            return Err(invariant_error("missing chunk embedding session metadata"));
        };

        Ok(LargestChunkSession {
            session,
            fbank_frames,
            num_masks,
        })
    }
}

/// Builds shared session handles and per-session lookup metadata for chunk
/// embedding from the embedding model's prepared resources.
///
/// Returns `Ok(None)` when the model exposes no chunk resources (e.g. a
/// backend that does not support this path).
///
/// # Errors
///
/// Propagates any error from `prepare_chunk_resources`.
pub(super) fn chunk_embedding_resources(
    emb_model: &mut EmbeddingModel,
) -> Result<Option<ChunkEmbeddingResources>, PipelineError> {
    let bundle = match emb_model.prepare_chunk_resources()? {
        Some(bundle) => bundle,
        None => return Ok(None),
    };

    // Build the handle and lookup vectors in one pass; they stay parallel by
    // construction (entry i of each describes session i).
    let mut chunk_sessions = Vec::with_capacity(bundle.sessions.len());
    let mut chunk_lookup = Vec::with_capacity(bundle.sessions.len());
    for session in &bundle.sessions {
        chunk_sessions.push(ChunkSessionHandle {
            cached_fbank_shape: Arc::clone(&session.cached_fbank_shape),
            cached_masks_shape: Arc::clone(&session.cached_masks_shape),
            model: Arc::clone(&session.model),
        });
        chunk_lookup.push((session.num_windows, session.fbank_frames, session.num_masks));
    }

    Ok(Some(ChunkEmbeddingResources {
        chunk_sessions,
        chunk_lookup,
        fbank_30s: bundle.fbank_30s,
        fbank_10s: bundle.fbank_10s,
    }))
}

/// Worker that runs CoreML chunk-embedding predictions for a single audio
/// stream, preparing decoded chunks itself when no pre-prepared ones are
/// queued.
pub(super) struct GpuWorker {
    /// Shared CoreML model used for chunk-embedding predictions.
    pub(super) model: Arc<SharedCoreMlModel>,
    /// Cached input-shape descriptor for the fbank input tensor.
    pub(super) fbank_shape: Arc<CachedInputShape>,
    /// Cached input-shape descriptor for the masks input tensor.
    pub(super) masks_shape: Arc<CachedInputShape>,
    /// Converts a `DecodedChunk` into a model-ready `PreparedChunk`.
    pub(super) prep: super::ChunkPrep,
    /// Reusable scratch buffers for in-worker chunk preparation.
    pub(super) scratch: PrepScratch,
}

impl GpuWorker {
    /// Fetches the next chunk to run inference on, preferring already-prepared
    /// chunks over preparing decoded ones inline.
    ///
    /// Order of attempts:
    /// 1. Non-blocking poll of `prep_rx`.
    /// 2. If decoding is known to be finished (`decoded_done`), block on
    ///    `prep_rx` until it drains.
    /// 3. Non-blocking poll of `chunk_rx`; a decoded chunk is prepared inline
    ///    and the time spent is added to `total_prep_us`.
    /// 4. Otherwise block via `select!` on whichever channel is ready first.
    ///
    /// Returns `Ok(None)` when there is no more work (channels disconnected
    /// and drained). Propagates preparation errors.
    fn next_prepared(
        &mut self,
        audio: &[f32],
        prep_rx: &Receiver<PreparedChunk>,
        chunk_rx: &Receiver<DecodedChunk>,
        decoded_done: &mut bool,
        total_prep_us: &mut u64,
    ) -> Result<Option<PreparedChunk>, PipelineError> {
        // Fast path: a prepared chunk is already waiting.
        // NOTE(review): a disconnected prep_rx ends the worker even if
        // chunk_rx might still hold decoded chunks — presumably the prep
        // sender only drops after decoding is complete; confirm upstream.
        match prep_rx.try_recv() {
            Ok(prepared) => return Ok(Some(prepared)),
            Err(crossbeam_channel::TryRecvError::Disconnected) => return Ok(None),
            Err(crossbeam_channel::TryRecvError::Empty) => {}
        }

        // Decoding finished earlier: only prepared chunks can still arrive,
        // so block until one does or the channel closes.
        if *decoded_done {
            return match prep_rx.recv() {
                Ok(prepared) => Ok(Some(prepared)),
                Err(_) => Ok(None),
            };
        }

        match chunk_rx.try_recv() {
            // A decoded chunk is available: prepare it ourselves and charge
            // the time to this worker's prep counter.
            Ok(decoded) => {
                let prep_start = std::time::Instant::now();
                let prepared = self.prep.prep(decoded, audio, &mut self.scratch)?;
                *total_prep_us += prep_start.elapsed().as_micros() as u64;
                Ok(Some(prepared))
            }
            // Nothing ready on either channel: block until one produces.
            Err(crossbeam_channel::TryRecvError::Empty) => crossbeam_channel::select! {
                recv(prep_rx) -> message => match message {
                    Ok(prepared) => Ok(Some(prepared)),
                    Err(_) => Ok(None),
                },
                recv(chunk_rx) -> message => match message {
                    Ok(decoded) => {
                        let prep_start = std::time::Instant::now();
                        let prepared = self.prep.prep(decoded, audio, &mut self.scratch)?;
                        *total_prep_us += prep_start.elapsed().as_micros() as u64;
                        Ok(Some(prepared))
                    }
                    Err(_) => {
                        // Decoder hung up inside the select: remember that and
                        // fall back to draining prepared chunks.
                        *decoded_done = true;
                        match prep_rx.recv() {
                            Ok(prepared) => Ok(Some(prepared)),
                            Err(_) => Ok(None),
                        }
                    }
                },
            },
            // Decoder hung up: drain whatever prepared chunks remain.
            Err(crossbeam_channel::TryRecvError::Disconnected) => {
                *decoded_done = true;
                match prep_rx.recv() {
                    Ok(prepared) => Ok(Some(prepared)),
                    Err(_) => Ok(None),
                }
            }
        }
    }

    /// Runs one cached CoreML prediction on a prepared chunk's fbank and mask
    /// tensors, returning the raw output data and the time spent (µs).
    ///
    /// # Errors
    ///
    /// Maps any model failure to a backend error.
    fn predict(&self, prepared: &PreparedChunk) -> Result<(Vec<f32>, u64), PipelineError> {
        let predict_start = std::time::Instant::now();
        let (data, _) = self
            .model
            .predict_cached(&[
                (&*self.fbank_shape, &prepared.fbank),
                (&*self.masks_shape, &prepared.masks),
            ])
            .map_err(|error| backend_error("chunk embedding prediction failed", error))?;
        Ok((data, predict_start.elapsed().as_micros() as u64))
    }

    /// Main worker loop: pull prepared chunks (preparing decoded ones inline
    /// when needed), run prediction, and forward results on `emb_tx`.
    ///
    /// Exits when the input channels are exhausted or the embedding receiver
    /// hangs up; in both cases the accumulated timing stats are returned.
    ///
    /// # Errors
    ///
    /// Propagates preparation and prediction failures.
    pub(super) fn run(
        mut self,
        audio: &[f32],
        prep_rx: Receiver<PreparedChunk>,
        chunk_rx: Receiver<DecodedChunk>,
        emb_tx: Sender<EmbeddedChunk>,
    ) -> Result<GpuStats, PipelineError> {
        let mut total_predict_us = 0u64;
        let mut total_prep_us = 0u64;
        let mut chunk_num = 0u32;
        let mut decoded_done = false;

        loop {
            let Some(prepared) = self.next_prepared(
                audio,
                &prep_rx,
                &chunk_rx,
                &mut decoded_done,
                &mut total_prep_us,
            )?
            else {
                break;
            };

            let (data, predict_us) = self.predict(&prepared)?;
            total_predict_us += predict_us;
            chunk_num += 1;

            // A send error means the downstream consumer dropped its receiver;
            // stop producing rather than treating it as a failure.
            if emb_tx
                .send(EmbeddedChunk {
                    global_start: prepared.global_start,
                    decoded_chunk: prepared.decoded_chunk,
                    data,
                    active: prepared.active,
                    num_masks: prepared.num_masks,
                    predict_us,
                })
                .is_err()
            {
                break;
            }
        }

        Ok(GpuStats {
            predict_us: total_predict_us,
            chunks: chunk_num,
            self_prep_us: total_prep_us,
        })
    }
}

/// Batch-mode variant of [`GpuWorker`]: runs chunk-embedding predictions for
/// multiple audio files at once, with chunks tagged by file index.
pub(super) struct BatchGpuWorker {
    /// Shared CoreML model used for chunk-embedding predictions.
    pub(super) model: Arc<SharedCoreMlModel>,
    /// Cached input-shape descriptor for the fbank input tensor.
    pub(super) fbank_shape: Arc<CachedInputShape>,
    /// Cached input-shape descriptor for the masks input tensor.
    pub(super) masks_shape: Arc<CachedInputShape>,
    /// Converts decoded chunks into model-ready prepared chunks.
    pub(super) prep: super::ChunkPrep,
    /// Reusable scratch buffers for in-worker chunk preparation.
    pub(super) scratch: PrepScratch,
}

impl BatchGpuWorker {
    /// Runs one cached CoreML prediction on a prepared chunk's fbank and mask
    /// tensors, returning the raw output data and the time spent (µs).
    ///
    /// # Errors
    ///
    /// Maps any model failure to a backend error.
    fn predict(&self, prepared: &PreparedChunk) -> Result<(Vec<f32>, u64), PipelineError> {
        let predict_start = std::time::Instant::now();
        let (data, _) = self
            .model
            .predict_cached(&[
                (&*self.fbank_shape, &prepared.fbank),
                (&*self.masks_shape, &prepared.masks),
            ])
            .map_err(|error| backend_error("batch chunk embedding prediction failed", error))?;
        Ok((data, predict_start.elapsed().as_micros() as u64))
    }

    /// Batch worker loop over multiple files.
    ///
    /// `audios[i]` is the full audio for file `i`; incoming decoded chunks are
    /// tagged with their `file_idx` so the matching audio can be looked up.
    ///
    /// The channel strategy mirrors `GpuWorker::next_prepared`, inlined here:
    /// prefer prepared chunks (non-blocking), then decoded chunks prepared
    /// inline, then a blocking `select!` over both; once `decoded_rx`
    /// disconnects, drain `prepared_rx` until it closes.
    ///
    /// # Errors
    ///
    /// Propagates preparation and prediction failures.
    pub(super) fn run(
        mut self,
        audios: &[&[f32]],
        prepared_rx: Receiver<TaggedPrepared>,
        decoded_rx: Receiver<TaggedDecoded>,
        embedded_tx: Sender<TaggedEmbedded>,
    ) -> Result<GpuStats, PipelineError> {
        let mut total_predict_us = 0u64;
        let mut total_prep_us = 0u64;
        let mut chunk_num = 0u32;
        let mut decoded_done = false;

        loop {
            let (file_idx, local_start, prepared) = match prepared_rx.try_recv() {
                Ok(tagged) => (tagged.file_idx, tagged.local_start, tagged.prepared),
                // NOTE(review): as in GpuWorker, a disconnected prepared_rx
                // ends the worker even if decoded_rx might still hold chunks —
                // presumably the prep sender outlives decoding; confirm.
                Err(crossbeam_channel::TryRecvError::Disconnected) => break,
                Err(crossbeam_channel::TryRecvError::Empty) => {
                    if decoded_done {
                        // Only prepared chunks can still arrive: block for one.
                        match prepared_rx.recv() {
                            Ok(tagged) => (tagged.file_idx, tagged.local_start, tagged.prepared),
                            Err(_) => break,
                        }
                    } else {
                        match decoded_rx.try_recv() {
                            // Prepare a decoded chunk inline, charging the
                            // time to this worker's prep counter.
                            Ok(tagged) => {
                                let audio = audios[tagged.file_idx];
                                // global_start is file-local here: windows are
                                // counted from the start of this file.
                                let decoded = DecodedChunk {
                                    global_start: tagged.local_start
                                        * self.prep.chunk_win_capacity(),
                                    decoded_chunk: tagged.decoded_chunk,
                                };
                                let prep_start = std::time::Instant::now();
                                let prepared = self.prep.prep(decoded, audio, &mut self.scratch)?;
                                total_prep_us += prep_start.elapsed().as_micros() as u64;
                                (tagged.file_idx, tagged.local_start, prepared)
                            }
                            // Nothing ready anywhere: block on both channels.
                            Err(crossbeam_channel::TryRecvError::Empty) => {
                                crossbeam_channel::select! {
                                    recv(prepared_rx) -> message => match message {
                                        Ok(tagged) => {
                                            (tagged.file_idx, tagged.local_start, tagged.prepared)
                                        }
                                        Err(_) => break,
                                    },
                                    recv(decoded_rx) -> message => match message {
                                        Ok(tagged) => {
                                            let audio = audios[tagged.file_idx];
                                            let decoded = DecodedChunk {
                                                global_start: tagged.local_start * self.prep.chunk_win_capacity(),
                                                decoded_chunk: tagged.decoded_chunk,
                                            };
                                            let prep_start = std::time::Instant::now();
                                            let prepared = self.prep.prep(decoded, audio, &mut self.scratch)?;
                                            total_prep_us += prep_start.elapsed().as_micros() as u64;
                                            (tagged.file_idx, tagged.local_start, prepared)
                                        }
                                        Err(_) => {
                                            // Decoder hung up inside the select:
                                            // drain remaining prepared chunks.
                                            decoded_done = true;
                                            match prepared_rx.recv() {
                                                Ok(tagged) => (tagged.file_idx, tagged.local_start, tagged.prepared),
                                                Err(_) => break,
                                            }
                                        }
                                    },
                                }
                            }
                            // Decoder hung up: drain remaining prepared chunks.
                            Err(crossbeam_channel::TryRecvError::Disconnected) => {
                                decoded_done = true;
                                match prepared_rx.recv() {
                                    Ok(tagged) => {
                                        (tagged.file_idx, tagged.local_start, tagged.prepared)
                                    }
                                    Err(_) => break,
                                }
                            }
                        }
                    }
                }
            };

            let (data, predict_us) = self.predict(&prepared)?;
            total_predict_us += predict_us;
            chunk_num += 1;

            // Send error means the consumer dropped its receiver; stop
            // producing rather than treating it as a failure.
            if embedded_tx
                .send(TaggedEmbedded {
                    file_idx,
                    local_start,
                    embedded: EmbeddedChunk {
                        // Recomputed from local_start; matches the value used
                        // when preparing inline above — presumably the
                        // upstream prep thread uses the same formula; confirm.
                        global_start: local_start * self.prep.chunk_win_capacity(),
                        decoded_chunk: prepared.decoded_chunk,
                        data,
                        active: prepared.active,
                        num_masks: prepared.num_masks,
                        predict_us,
                    },
                })
                .is_err()
            {
                break;
            }
        }

        Ok(GpuStats {
            predict_us: total_predict_us,
            chunks: chunk_num,
            self_prep_us: total_prep_us,
        })
    }
}

/// Cheap-to-clone handle to one CoreML chunk-embedding session: the shared
/// model plus its cached input-shape descriptors (clones only bump `Arc`
/// refcounts).
#[derive(Clone)]
pub(super) struct ChunkSessionHandle {
    /// Cached shape descriptor for the fbank input tensor.
    pub(super) cached_fbank_shape: Arc<CachedInputShape>,
    /// Cached shape descriptor for the masks input tensor.
    pub(super) cached_masks_shape: Arc<CachedInputShape>,
    /// Shared CoreML model backing this session.
    pub(super) model: Arc<SharedCoreMlModel>,
}

/// Session handles and metadata for chunk embedding.
///
/// `chunk_sessions` and `chunk_lookup` are parallel vectors: entry `i` of
/// `chunk_lookup` holds `(num_windows, fbank_frames, num_masks)` for session
/// `i`. The last entry of each describes the largest session.
#[derive(Clone)]
pub(super) struct ChunkEmbeddingResources {
    /// One handle per available chunk-embedding session.
    pub(super) chunk_sessions: Vec<ChunkSessionHandle>,
    /// Per-session `(num_windows, fbank_frames, num_masks)` metadata,
    /// parallel to `chunk_sessions`.
    pub(super) chunk_lookup: Vec<(usize, usize, usize)>,
    /// Optional fbank model for 30-second input, when the backend provides one.
    pub(super) fbank_30s: Option<Arc<SharedCoreMlModel>>,
    /// Optional fbank model for 10-second input, when the backend provides one.
    pub(super) fbank_10s: Option<Arc<SharedCoreMlModel>>,
}

/// The chunk-embedding session with the largest capacity, together with the
/// metadata needed to drive it (see `ChunkEmbeddingResources::largest_session`).
pub(super) struct LargestChunkSession {
    /// Handle to the session's model and cached input shapes.
    pub(super) session: ChunkSessionHandle,
    /// Number of fbank frames this session expects.
    pub(super) fbank_frames: usize,
    /// Number of masks this session expects.
    pub(super) num_masks: usize,
}

/// A decoded chunk converted into model-ready flat input tensors.
pub(super) struct PreparedChunk {
    /// Start offset of this chunk in the stream — presumably in windows
    /// (computed as `local_start * chunk_win_capacity()` in batch mode);
    /// TODO confirm units.
    pub(super) global_start: usize,
    /// The original decoded segmentation data carried through the pipeline.
    pub(super) decoded_chunk: Vec<ndarray::Array2<f32>>,
    /// Flattened fbank input tensor for the model.
    pub(super) fbank: Vec<f32>,
    /// Flattened mask input tensor for the model.
    pub(super) masks: Vec<f32>,
    /// Active entries as index pairs — semantics set by `ChunkPrep::prep`.
    pub(super) active: Vec<(usize, usize)>,
    /// Number of masks in the `masks` tensor.
    pub(super) num_masks: usize,
}

/// A prepared chunk after inference: carries the model output alongside the
/// data passed through from the `PreparedChunk`.
pub(super) struct EmbeddedChunk {
    /// Start offset of this chunk in the stream (same value as on the
    /// originating `PreparedChunk`; recomputed in batch mode).
    pub(super) global_start: usize,
    /// The original decoded segmentation data carried through the pipeline.
    pub(super) decoded_chunk: Vec<ndarray::Array2<f32>>,
    /// Flat model output from the embedding prediction.
    pub(super) data: Vec<f32>,
    /// Active entries carried over from the prepared chunk.
    pub(super) active: Vec<(usize, usize)>,
    /// Number of masks carried over from the prepared chunk.
    pub(super) num_masks: usize,
    /// Time spent in the model prediction, in microseconds.
    pub(super) predict_us: u64,
}

/// Batch-mode wrapper: a prepared chunk tagged with the file it belongs to
/// and its position within that file.
pub(super) struct TaggedPrepared {
    /// Index into the batch's `audios` slice.
    pub(super) file_idx: usize,
    /// Chunk position within the file, in chunk-window-capacity units.
    pub(super) local_start: usize,
    /// The prepared chunk itself.
    pub(super) prepared: PreparedChunk,
}

/// Batch-mode wrapper: an embedded chunk tagged with the file it belongs to
/// and its position within that file.
pub(super) struct TaggedEmbedded {
    /// Index into the batch's `audios` slice.
    pub(super) file_idx: usize,
    /// Chunk position within the file, in chunk-window-capacity units.
    pub(super) local_start: usize,
    /// The embedded chunk itself.
    pub(super) embedded: EmbeddedChunk,
}

/// Timing and throughput statistics accumulated over a GPU worker's run.
pub(super) struct GpuStats {
    /// Total time spent in model predictions, in microseconds.
    pub(super) predict_us: u64,
    /// Number of chunks processed.
    pub(super) chunks: u32,
    /// Total time this worker spent preparing chunks inline (as opposed to
    /// receiving pre-prepared ones), in microseconds.
    pub(super) self_prep_us: u64,
}