//! transcribe-cli 0.0.4
//!
//! Whisper CLI transcription pipeline on CTranslate2 with CPU and optional CUDA support.
use std::fs::File;
use std::io::BufReader;
use std::path::Path;

use anyhow::{Context, Result, anyhow};
use ct2rs::Config;
use ct2rs::sys::{Device, StorageView, Whisper as SysWhisper};
use mel_spec::mel::{log_mel_spectrogram, mel, norm_mel};
use mel_spec::stft::Spectrogram;
use ndarray::{Array2, Axis, s, stack};
use serde::Deserialize;
use tokenizers::{Decoder, Tokenizer};

const PREPROCESSOR_CONFIG_FILE: &str = "preprocessor_config.json";
const TOKENIZER_FILE: &str = "tokenizer.json";

pub use ct2rs::sys::WhisperOptions;

/// High-level Whisper transcription model: bundles the CTranslate2 engine
/// with the tokenizer and feature-extraction settings loaded from the same
/// model directory.
pub struct Whisper {
    // CTranslate2 Whisper engine; performs language detection and generation.
    inner: SysWhisper,
    // Hugging Face tokenizer loaded from `tokenizer.json`; its decoder turns
    // generated token sequences back into text.
    tokenizer: Tokenizer,
    // Audio preprocessing parameters read from `preprocessor_config.json`.
    config: PreprocessorConfig,
}

impl Whisper {
    /// Loads a converted Whisper model from `model_path`.
    ///
    /// The directory must contain the CTranslate2 model files plus
    /// `tokenizer.json` and `preprocessor_config.json`.
    ///
    /// # Errors
    /// Fails when the engine, tokenizer, or preprocessor config cannot be
    /// loaded; tokenizer and config errors carry the offending path.
    pub fn new<T: AsRef<Path>>(model_path: T, config: Config) -> Result<Self> {
        let model_path = model_path.as_ref();
        let tokenizer_path = model_path.join(TOKENIZER_FILE);
        let preprocessor_path = model_path.join(PREPROCESSOR_CONFIG_FILE);

        Ok(Self {
            inner: SysWhisper::new(model_path, config)?,
            // The tokenizer error type doesn't convert via `?` into `anyhow`,
            // so map it manually and attach the path for diagnostics.
            tokenizer: Tokenizer::from_file(&tokenizer_path).map_err(|error| {
                anyhow!(
                    "failed to load tokenizer from `{}`: {error}",
                    tokenizer_path.display()
                )
            })?,
            config: PreprocessorConfig::read(&preprocessor_path).with_context(|| {
                format!(
                    "failed to load Whisper preprocessor config from `{}`",
                    preprocessor_path.display()
                )
            })?,
        })
    }

    /// Transcribes `samples` (mono PCM, expected at [`Self::sampling_rate`])
    /// and returns one decoded string per chunk of [`Self::n_samples`]
    /// samples.
    ///
    /// * `language` — ISO code such as `"en"`; when `None`, the language is
    ///   auto-detected from the audio.
    /// * `timestamp` — when `false`, the `<|notimestamps|>` token is added so
    ///   the model does not emit timestamp tokens.
    /// * `options` — generation options forwarded to CTranslate2.
    ///
    /// # Errors
    /// Fails if the mel batches cannot be stacked, the input tensor cannot be
    /// built, language detection fails, or generation/decoding fails.
    pub fn generate(
        &self,
        samples: &[f32],
        language: Option<&str>,
        timestamp: bool,
        options: &WhisperOptions,
    ) -> Result<Vec<String>> {
        if samples.is_empty() {
            return Ok(Vec::new());
        }

        // NOTE(review): one `Spectrogram` is reused across all chunks, so its
        // internal window state carries over chunk boundaries — presumably
        // intentional since chunks are contiguous audio; confirm against
        // mel_spec's streaming semantics.
        let mut stft = Spectrogram::new(self.config.n_fft, self.config.hop_length);
        let mut mel_batches = Vec::new();

        // Split the audio into fixed-size chunks (`.max(1)` guards against a
        // zero `n_samples`, which would make `chunks` panic).
        for chunk in samples.chunks(self.config.n_samples.max(1)) {
            // Zero-initialized (feature_size x nb_max_frames) chunk features;
            // frames never written (short final chunk) stay zero-padded.
            let mut mel_per_chunk =
                Array2::zeros((self.config.feature_size, self.config.nb_max_frames));

            for (frame_index, frame) in chunk.chunks(self.config.hop_length).enumerate() {
                // Drop any frames beyond the model's fixed frame budget.
                if frame_index >= self.config.nb_max_frames {
                    break;
                }

                // `add` yields an FFT frame only once enough samples have
                // accumulated; until then the column stays zero.
                if let Some(fft_frame) = stft.add(frame) {
                    let mel = norm_mel(&log_mel_spectrogram(&fft_frame, &self.config.mel_filters))
                        .mapv(|value| value as f32);
                    mel_per_chunk
                        .slice_mut(s![.., frame_index])
                        .assign(&mel.slice(s![.., 0]));
                }
            }

            mel_batches.push(mel_per_chunk);
        }

        // Stack per-chunk features into a (batch, feature, frame) tensor.
        let mut mel_spectrogram = stack(
            Axis(0),
            &mel_batches
                .iter()
                .map(|batch| batch.view())
                .collect::<Vec<_>>(),
        )?;

        // StorageView needs a contiguous row-major buffer.
        if !mel_spectrogram.is_standard_layout() {
            mel_spectrogram = mel_spectrogram.as_standard_layout().into_owned();
        }

        let shape = mel_spectrogram.shape().to_vec();
        // Zero-copy CPU view over the ndarray buffer, handed to CTranslate2.
        let storage = StorageView::new(
            &shape,
            mel_spectrogram
                .as_slice_mut()
                .context("failed to access mel spectrogram buffer")?,
            Device::CPU,
        )?;

        // Either wrap the caller's language as a Whisper token (`<|en|>`) or
        // ask the model to detect it from the audio.
        let language_token = match language {
            Some(language) => format!("<|{language}|>"),
            None => self.detect_language(&storage)?,
        };

        // Standard Whisper prompt: SOT, language, task; token order matters.
        let mut prompt = vec![
            String::from("<|startoftranscript|>"),
            language_token,
            String::from("<|transcribe|>"),
        ];
        if !timestamp {
            prompt.push(String::from("<|notimestamps|>"));
        }

        // Same prompt for every chunk in the batch.
        let prompts = vec![prompt; mel_batches.len()];
        self.inner
            .generate(&storage, &prompts, options)?
            .into_iter()
            .map(|result| self.decode_result(result))
            .collect()
    }

    /// Sampling rate (Hz) the model's feature extractor expects.
    pub fn sampling_rate(&self) -> usize {
        self.config.sampling_rate
    }

    /// Number of audio samples per transcription chunk.
    pub fn n_samples(&self) -> usize {
        self.config.n_samples
    }

    /// Detects the language of the first batch entry, returning the token the
    /// engine reports (used verbatim in the generation prompt).
    ///
    /// # Errors
    /// Fails if the engine returns no detection result.
    fn detect_language(&self, storage: &StorageView<'_>) -> Result<String> {
        self.inner
            .detect_language(storage)?
            .into_iter()
            .next()
            // Take the top-ranked candidate of the first batch entry.
            .and_then(|result| result.into_iter().next())
            .map(|result| result.language)
            .ok_or_else(|| anyhow!("failed to detect language"))
    }

    /// Decodes the first generated sequence of `result` into text using the
    /// tokenizer's decoder.
    ///
    /// # Errors
    /// Fails if the result has no sequences, the tokenizer has no decoder, or
    /// decoding itself fails.
    fn decode_result(&self, result: ct2rs::sys::WhisperGenerationResult) -> Result<String> {
        let tokens = result
            .sequences
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("failed to decode empty Whisper sequence"))?;
        let decoder = self
            .tokenizer
            .get_decoder()
            .ok_or_else(|| anyhow!("tokenizer does not provide a decoder"))?;

        decoder
            .decode(tokens)
            .map_err(|error| anyhow!("failed to decode Whisper tokens: {error}"))
    }
}

/// Whisper feature-extractor settings, mirroring the Hugging Face
/// `preprocessor_config.json` schema (plus the resolved mel filter bank).
#[derive(Debug)]
#[allow(dead_code)]
struct PreprocessorConfig {
    // Chunk length in seconds — TODO confirm unit against upstream config.
    chunk_length: usize,
    feature_extractor_type: String,
    // Number of mel bins per frame (rows of `mel_filters`).
    feature_size: usize,
    // Samples between successive STFT frames.
    hop_length: usize,
    // FFT window size in samples.
    n_fft: usize,
    // Samples per transcription chunk.
    n_samples: usize,
    // Maximum number of STFT frames per chunk.
    nb_max_frames: usize,
    padding_side: String,
    padding_value: f32,
    processor_class: String,
    return_attention_mask: bool,
    // Expected input sampling rate in Hz.
    sampling_rate: usize,
    // Mel filter bank: read from the JSON when present, otherwise computed.
    mel_filters: Array2<f64>,
}

impl PreprocessorConfig {
    fn read<T: AsRef<Path>>(path: T) -> Result<Self> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);

        #[derive(Deserialize)]
        struct PreprocessorConfigAux {
            chunk_length: usize,
            feature_extractor_type: String,
            feature_size: usize,
            hop_length: usize,
            n_fft: usize,
            n_samples: usize,
            nb_max_frames: usize,
            padding_side: String,
            padding_value: f32,
            processor_class: String,
            return_attention_mask: bool,
            sampling_rate: usize,
            mel_filters: Option<Vec<Vec<f64>>>,
        }

        let aux: PreprocessorConfigAux = serde_json::from_reader(reader)?;
        let mel_filters = if let Some(mel_filters) = aux.mel_filters {
            let rows = mel_filters.len();
            let cols = mel_filters.first().map(|row| row.len()).unwrap_or_default();
            Array2::from_shape_vec((rows, cols), mel_filters.into_iter().flatten().collect())?
        } else {
            mel(
                aux.sampling_rate as f64,
                aux.n_fft,
                aux.feature_size,
                None,
                None,
                false,
                true,
            )
        };

        Ok(Self {
            chunk_length: aux.chunk_length,
            feature_extractor_type: aux.feature_extractor_type,
            feature_size: aux.feature_size,
            hop_length: aux.hop_length,
            n_fft: aux.n_fft,
            n_samples: aux.n_samples,
            nb_max_frames: aux.nb_max_frames,
            padding_side: aux.padding_side,
            padding_value: aux.padding_value,
            processor_class: aux.processor_class,
            return_attention_mask: aux.return_attention_mask,
            sampling_rate: aux.sampling_rate,
            mel_filters,
        })
    }
}