// whisper-cpp-plus 0.1.4
//
// Safe Rust bindings for whisper.cpp with real-time PCM streaming and VAD support
// Documentation
//! Integration tests for model quantization functionality

#![cfg(feature = "quantization")]

mod common;

use common::TestModels;
use std::fs;
use whisper_cpp_plus::{WhisperQuantize, QuantizationType};

/// Each listed quantization type must report a non-empty name and a size
/// factor strictly between 0 and 1.
#[test]
fn test_quantization_types() {
    let variants = [
        QuantizationType::Q4_0,
        QuantizationType::Q4_1,
        QuantizationType::Q5_0,
        QuantizationType::Q5_1,
        QuantizationType::Q8_0,
        QuantizationType::Q2_K,
        QuantizationType::Q3_K,
        QuantizationType::Q4_K,
        QuantizationType::Q5_K,
        QuantizationType::Q6_K,
    ];

    variants.iter().for_each(|variant| {
        let label = variant.name();
        assert!(!label.is_empty());

        // Both bounds are exclusive: a factor of exactly 0 or 1 is rejected.
        let ratio = variant.size_factor();
        let within_open_unit_interval = ratio > 0.0 && ratio < 1.0;
        assert!(
            within_open_unit_interval,
            "{} has invalid size factor: {}",
            variant,
            ratio
        );
    });
}

/// Parsing accepts the canonical, lowercase and underscore-free spellings
/// exercised below, and rejects unknown or empty input.
#[test]
fn test_quantization_type_parsing() {
    // (input text, variant it must parse to)
    let accepted = [
        ("Q4_0", QuantizationType::Q4_0),
        ("q4_0", QuantizationType::Q4_0),
        ("Q40", QuantizationType::Q4_0),
        ("Q5_K", QuantizationType::Q5_K),
        ("q5k", QuantizationType::Q5_K),
    ];
    for (text, expected) in &accepted {
        let parsed = text.parse::<QuantizationType>().unwrap();
        assert_eq!(parsed, *expected);
    }

    // Inputs that must fail to parse.
    let rejected = ["invalid", ""];
    for text in &rejected {
        assert!(text.parse::<QuantizationType>().is_err());
    }
}

/// The `Display` output of a quantization type is its canonical name.
#[test]
fn test_quantization_display() {
    let q4_0_text = QuantizationType::Q4_0.to_string();
    assert_eq!(q4_0_text, "Q4_0");

    let q5_k_text = QuantizationType::Q5_K.to_string();
    assert_eq!(q5_k_text, "Q5_K");
}

/// Quantizes the tiny English model to Q5_0 and checks that an output file is
/// produced and that it is smaller than the input model.
///
/// The test is skipped (with a note on stderr) when the model file is not
/// available locally.
#[test]
fn test_quantize_model() {
    let Some(model_path) = TestModels::tiny_en() else {
        eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
        return;
    };

    let output_path = model_path.with_file_name("ggml-tiny.en-q5_0.bin");
    // Best-effort removal of a stale artifact from an earlier run; a missing
    // file is the expected case, so the error is deliberately ignored.
    let _ = fs::remove_file(&output_path);

    let result = WhisperQuantize::quantize_model_file(
        model_path.to_str().expect("model path is not valid UTF-8"),
        output_path.to_str().expect("output path is not valid UTF-8"),
        QuantizationType::Q5_0,
    );

    // Collect everything the assertions need, then delete the artifact
    // *before* asserting, so a failing check does not leave the quantized
    // model behind next to the source model.
    let input_size = fs::metadata(&model_path).map(|meta| meta.len());
    let output_size = fs::metadata(&output_path).map(|meta| meta.len());
    let _ = fs::remove_file(&output_path);

    assert!(result.is_ok(), "Quantization failed: {:?}", result);
    // A metadata failure on the output path means the file is not there.
    assert!(output_size.is_ok(), "Output file was not created");

    let input_size = input_size.expect("failed to read input model metadata");
    let output_size = output_size.expect("Output file was not created");
    assert!(output_size < input_size,
        "Quantized model should be smaller: {} >= {}", output_size, input_size);
}

/// Quantizes the tiny English model to Q4_0 with a progress callback and
/// checks that the callback fires more than twice with values in `[0, 1]`.
///
/// The test is skipped (with a note on stderr) when the model file is not
/// available locally.
#[test]
fn test_quantize_with_progress() {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    let Some(model_path) = TestModels::tiny_en() else {
        eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
        return;
    };

    let output_path = model_path.with_file_name("ggml-tiny.en-q4_0.bin");
    // Best-effort removal of a stale artifact from an earlier run.
    let _ = fs::remove_file(&output_path);

    let call_count = Arc::new(AtomicU32::new(0));
    let call_count_clone = Arc::clone(&call_count);

    // The callback only *records* a bad value instead of asserting in place:
    // a panic inside the callback would unwind through the quantizer
    // (NOTE(review): possibly through native frames — confirm against the
    // binding) and would skip the cleanup below.
    let first_invalid: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
    let first_invalid_clone = Arc::clone(&first_invalid);

    let result = WhisperQuantize::quantize_model_file_with_progress(
        model_path.to_str().expect("model path is not valid UTF-8"),
        output_path.to_str().expect("output path is not valid UTF-8"),
        QuantizationType::Q4_0,
        move |progress| {
            call_count_clone.fetch_add(1, Ordering::Relaxed);

            // Written as a negated "in range" test so that NaN counts as invalid.
            if !(progress >= 0.0 && progress <= 1.0) {
                if let Ok(mut slot) = first_invalid_clone.lock() {
                    if slot.is_none() {
                        *slot = Some(format!("{}", progress));
                    }
                }
            }
        },
    );

    // Capture what the assertions need and remove the artifact first, so a
    // failing check does not leave the quantized model on disk.
    let output_created = output_path.exists();
    let _ = fs::remove_file(&output_path);

    assert!(result.is_ok(), "Quantization failed: {:?}", result);
    assert!(output_created, "Output file was not created");

    let invalid = first_invalid
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .take();
    if let Some(value) = invalid {
        panic!("Invalid progress value: {}", value);
    }

    let total_calls = call_count.load(Ordering::Relaxed);
    // Per-tensor progress: should fire many more than 2 times
    // (tiny model has ~50+ tensors, plus the initial 0.0 callback)
    assert!(total_calls > 2,
        "Expected per-tensor progress (>2 calls), got {} calls", total_calls);
    eprintln!("Progress callback fired {} times", total_calls);
}

/// Querying the quantization type of the tiny English model must succeed;
/// the detected type (or its absence) is printed for information only.
///
/// The test is skipped (with a note on stderr) when the model file is not
/// available locally.
#[test]
fn test_get_model_quantization_type() {
    let Some(model_path) = TestModels::tiny_en() else {
        eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
        return;
    };

    let model_file = model_path.to_str().unwrap();
    let outcome = WhisperQuantize::get_model_quantization_type(model_file);

    // Either `Ok` variant passes; only an error fails the test.
    match &outcome {
        Ok(Some(qtype)) => println!("Model is quantized as: {}", qtype),
        Ok(None) => println!("Model is in full precision"),
        Err(_) => panic!("Failed to check model type: {:?}", outcome),
    }
}

/// For every quantization type, the estimated quantized size must be smaller
/// than the original file and within 10% of `original_size * size_factor`.
///
/// The test is skipped (with a note on stderr) when the model file is not
/// available locally.
#[test]
fn test_estimate_quantized_size() {
    let Some(model_path) = TestModels::tiny_en() else {
        eprintln!("Skipping: model not found. Run `cargo xtask test-setup`");
        return;
    };

    let original_size = fs::metadata(&model_path).unwrap().len();

    for qtype in QuantizationType::all() {
        let model_file = model_path.to_str().unwrap();
        let estimated = WhisperQuantize::estimate_quantized_size(model_file, *qtype).unwrap();

        assert!(estimated < original_size,
            "{} estimation {} >= original {}", qtype, estimated, original_size);

        // Reference value: the original size scaled by the type's size factor.
        let scale = qtype.size_factor() as f64;
        let expected = (original_size as f64 * scale) as u64;
        let diff = estimated.abs_diff(expected);

        // Tolerance is 10% of the reference value; the bound is exclusive.
        let margin = (expected as f64 * 0.1) as u64;
        assert!(diff < margin,
            "{}: estimated {} differs too much from expected {} (diff: {})",
            qtype, estimated, expected, diff);
    }
}

/// Each entry point exercised here must return an error when given a path to
/// a file that does not exist.
#[test]
fn test_error_handling() {
    // Quantizing from a missing input file.
    let quantize_outcome = WhisperQuantize::quantize_model_file(
        "non_existent_model.bin",
        "output.bin",
        QuantizationType::Q4_0,
    );
    assert!(quantize_outcome.is_err(), "Should fail with non-existent input");

    // Inspecting the quantization type of a missing file.
    let type_outcome = WhisperQuantize::get_model_quantization_type("non_existent.bin");
    assert!(type_outcome.is_err(), "Should fail with non-existent file");

    // Estimating the quantized size of a missing file.
    let estimate_outcome =
        WhisperQuantize::estimate_quantized_size("non_existent.bin", QuantizationType::Q5_0);
    assert!(estimate_outcome.is_err(), "Should fail with non-existent file");
}