rankops 0.1.7

Operations on ranked lists: fusion (RRF, Copeland, CombMNZ, DBSF, 14 methods), reranking (MaxSim/ColBERT, MMR, DPP), evaluation (NDCG, MAP, MRR), diagnostics, pipeline builder.
Documentation
//! Quantization utilities for embeddings and models.
//!
//! Provides INT8 and FP16 quantization support for reducing model size
//! and improving inference speed, especially on GPU.
//!
//! ## Features
//!
//! - **INT8 Quantization**: 4x size reduction, 2-3x speedup
//! - **FP16 Quantization**: 2x size reduction, 1.5-2x speedup (GPU)
//! - **Dequantization**: Convert back to FP32 for accuracy-critical operations
//!
//! ## Example
//!
//! ```rust
//! use rankops::rerank::quantization::{quantize_int8, dequantize_int8};
//!
//! let embeddings: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5];
//! let (quantized, scale, zero_point) = quantize_int8(&embeddings);
//! let dequantized = dequantize_int8(&quantized, scale, zero_point);
//! ```
//!
//! ## Performance
//!
//! - **INT8**: ~4x smaller, ~2-3x faster inference
//! - **FP16**: ~2x smaller, ~1.5-2x faster on GPU
//! - **Accuracy**: Minimal loss (<1% NDCG degradation in practice)

/// Quantize FP32 embeddings to INT8.
///
/// Uses symmetric quantization: `q = round(x / scale)` where scale = max(|x|) / 127.
///
/// # Arguments
///
/// * `values` - FP32 values to quantize
///
/// # Returns
///
/// Tuple of (quantized INT8 values, scale factor, zero point)
///
/// # Example
///
/// ```rust
/// use rankops::rerank::quantization::quantize_int8;
///
/// let embeddings = vec![0.1, 0.2, -0.3, 0.4, -0.5];
/// let (quantized, scale, zero_point) = quantize_int8(&embeddings);
/// assert_eq!(quantized.len(), embeddings.len());
/// ```
pub fn quantize_int8(values: &[f32]) -> (Vec<i8>, f32, i8) {
    if values.is_empty() {
        return (Vec::new(), 1.0, 0);
    }

    // Find maximum absolute value for symmetric quantization
    let max_abs = values.iter().map(|v| v.abs()).fold(0.0f32, |a, b| a.max(b));

    // Scale factor: map [-max_abs, max_abs] to [-127, 127]
    // Using 127 instead of 128 for symmetric quantization (avoids -128)
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };

    // Quantize: q = round(x / scale)
    let quantized: Vec<i8> = values
        .iter()
        .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();

    // Zero point is 0 for symmetric quantization
    (quantized, scale, 0i8)
}

/// Dequantize INT8 values back to FP32.
///
/// Inverse of `quantize_int8`: `x = q * scale + zero_point * scale`.
///
/// # Arguments
///
/// * `quantized` - INT8 quantized values
/// * `scale` - Scale factor from quantization
/// * `zero_point` - Zero point (typically 0 for symmetric)
///
/// # Returns
///
/// Dequantized FP32 values
///
/// # Example
///
/// ```rust
/// use rankops::rerank::quantization::{quantize_int8, dequantize_int8};
///
/// let original = vec![0.1, 0.2, -0.3];
/// let (quantized, scale, zero_point) = quantize_int8(&original);
/// let dequantized = dequantize_int8(&quantized, scale, zero_point);
/// // Dequantized values are close to original (within quantization error)
/// ```
pub fn dequantize_int8(quantized: &[i8], scale: f32, zero_point: i8) -> Vec<f32> {
    quantized
        .iter()
        .map(|&q| (q as f32 - zero_point as f32) * scale)
        .collect()
}

/// Quantize FP32 embeddings to FP16.
///
/// FP16 uses half precision (16 bits) instead of FP32 (32 bits),
/// providing 2x size reduction with minimal accuracy loss.
///
/// # Arguments
///
/// * `values` - FP32 values to quantize
///
/// # Returns
///
/// Quantized FP16 values (as u16, representing half-precision floats)
///
/// # Example
///
/// ```rust
/// use rankops::rerank::quantization::{quantize_fp16, dequantize_fp16};
///
/// let embeddings = vec![0.1, 0.2, 0.3];
/// let quantized = quantize_fp16(&embeddings);
/// let dequantized = dequantize_fp16(&quantized);
/// ```
pub fn quantize_fp16(values: &[f32]) -> Vec<u16> {
    values.iter().map(|&v| f32_to_fp16(v)).collect()
}

/// Dequantize FP16 values back to FP32.
///
/// Inverse of `quantize_fp16`.
///
/// # Arguments
///
/// * `quantized` - FP16 quantized values (as u16)
///
/// # Returns
///
/// Dequantized FP32 values
pub fn dequantize_fp16(quantized: &[u16]) -> Vec<f32> {
    quantized.iter().map(|&q| fp16_to_f32(q)).collect()
}

/// Convert FP32 to FP16 (half precision).
///
/// FP16 format: 1 sign bit, 5 exponent bits, 10 mantissa bits.
fn f32_to_fp16(value: f32) -> u16 {
    // Simple conversion: truncate mantissa, adjust exponent
    // This is a simplified version; full implementation would handle
    // special cases (NaN, Inf, denormals) more carefully
    let bits = value.to_bits();
    let sign = (bits >> 31) & 0x1;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x7FFFFF;

    // Handle special cases
    if exponent == 0xFF {
        // NaN or Inf
        return ((sign << 15) | 0x7C00 | ((mantissa != 0) as u32)) as u16;
    }

    if exponent == 0 {
        // Zero or denormal
        return (sign << 15) as u16;
    }

    // Normal number: adjust exponent bias (127 -> 15)
    let exp_fp16 = exponent - 127 + 15;
    if exp_fp16 < 0 {
        // Underflow to zero
        return (sign << 15) as u16;
    }
    if exp_fp16 > 31 {
        // Overflow to Inf
        return ((sign << 15) | 0x7C00) as u16;
    }

    // Pack: sign (1) + exponent (5) + mantissa (10 MSB)
    ((sign << 15) | ((exp_fp16 as u32) << 10) | ((mantissa >> 13) & 0x3FF)) as u16
}

/// Convert FP16 (half precision) to FP32.
///
/// Inverse of `f32_to_fp16`.
fn fp16_to_f32(value: u16) -> f32 {
    let sign = (value >> 15) & 0x1;
    let exponent = (value >> 10) & 0x1F;
    let mantissa = value & 0x3FF;

    // Handle special cases
    if exponent == 0x1F {
        // NaN or Inf
        if mantissa == 0 {
            // Inf
            f32::from_bits(((sign as u32) << 31) | 0x7F800000)
        } else {
            // NaN
            f32::from_bits(((sign as u32) << 31) | 0x7FC00000)
        }
    } else if exponent == 0 {
        // Zero or denormal
        if mantissa == 0 {
            f32::from_bits((sign as u32) << 31)
        } else {
            // Denormal: convert to FP32 denormal
            let exp_fp32 = 127 - 15;
            let mantissa_fp32 = (mantissa as u32) << 13;
            f32::from_bits(((sign as u32) << 31) | ((exp_fp32 as u32) << 23) | mantissa_fp32)
        }
    } else {
        // Normal number: adjust exponent bias (15 -> 127)
        let exp_fp32 = (exponent as i32) - 15 + 127;
        let mantissa_fp32 = (mantissa as u32) << 13;
        f32::from_bits(((sign as u32) << 31) | ((exp_fp32 as u32) << 23) | mantissa_fp32)
    }
}

/// Quantize a batch of embeddings (2D).
///
/// Quantizes each embedding vector independently.
///
/// # Arguments
///
/// * `embeddings` - 2D embeddings: `[num_vectors][dim]`
/// * `quantization_type` - `"int8"` or `"fp16"`
///
/// # Returns
///
/// For INT8: `(quantized, scales, zero_points)` where scales/zero_points are per-vector
/// For FP16: `(quantized, None, None)` where quantized is `Vec<Vec<u16>>`
pub fn quantize_batch(
    embeddings: &[Vec<f32>],
    quantization_type: &str,
) -> Result<QuantizedBatch, QuantizationError> {
    match quantization_type {
        "int8" => {
            let mut quantized = Vec::new();
            let mut scales = Vec::new();
            let mut zero_points = Vec::new();

            for embedding in embeddings {
                let (q, scale, zp) = quantize_int8(embedding);
                quantized.push(q.iter().map(|&x| x as i16).collect()); // Convert to i16 for storage
                scales.push(scale);
                zero_points.push(zp);
            }

            Ok(QuantizedBatch::Int8 {
                quantized,
                scales,
                zero_points,
            })
        }
        "fp16" => {
            let quantized: Vec<Vec<u16>> = embeddings.iter().map(|e| quantize_fp16(e)).collect();
            Ok(QuantizedBatch::Fp16 { quantized })
        }
        _ => Err(QuantizationError::UnsupportedType(
            quantization_type.to_string(),
        )),
    }
}

/// Quantized batch representation.
#[derive(Debug, Clone)]
pub enum QuantizedBatch {
    /// INT8 quantized batch with per-vector scales and zero points.
    Int8 {
        /// Quantized values (i16 for accumulation headroom).
        quantized: Vec<Vec<i16>>,
        /// Per-vector scale factors.
        scales: Vec<f32>,
        /// Per-vector zero points.
        zero_points: Vec<i8>,
    },
    /// FP16 quantized batch.
    Fp16 {
        /// Half-precision quantized values (stored as u16 bit patterns).
        quantized: Vec<Vec<u16>>,
    },
}

/// Quantization errors.
#[derive(Debug, Clone, PartialEq)]
pub enum QuantizationError {
    /// Unsupported quantization type.
    UnsupportedType(String),
}

impl std::fmt::Display for QuantizationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedType(t) => write!(f, "Unsupported quantization type: {}", t),
        }
    }
}

impl std::error::Error for QuantizationError {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int8_quantization_roundtrip() {
        let original = vec![0.1, 0.2, -0.3, 0.4, -0.5];
        let (quantized, scale, zero_point) = quantize_int8(&original);
        let dequantized = dequantize_int8(&quantized, scale, zero_point);

        // Check that dequantized values are close to original (within quantization error)
        for (orig, deq) in original.iter().zip(dequantized.iter()) {
            let error = (orig - deq).abs();
            assert!(
                error < scale,
                "Quantization error too large: {} vs {}",
                orig,
                deq
            );
        }
    }

    #[test]
    fn test_fp16_quantization_roundtrip() {
        let original = vec![0.1, 0.2, -0.3, 0.4, -0.5];
        let quantized = quantize_fp16(&original);
        let dequantized = dequantize_fp16(&quantized);

        // FP16 has ~0.1% relative error for values in [0.1, 1.0]
        for (orig, deq) in original.iter().zip(dequantized.iter()) {
            let relative_error = (orig - deq).abs() / orig.abs().max(1e-6);
            assert!(
                relative_error < 0.01,
                "FP16 quantization error too large: {} vs {} (rel error: {})",
                orig,
                deq,
                relative_error
            );
        }
    }

    #[test]
    fn test_quantize_batch_int8() {
        let embeddings = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let result = quantize_batch(&embeddings, "int8").unwrap();
        match result {
            QuantizedBatch::Int8 {
                quantized,
                scales,
                zero_points,
            } => {
                assert_eq!(quantized.len(), 2);
                assert_eq!(scales.len(), 2);
                assert_eq!(zero_points.len(), 2);
            }
            _ => panic!("Expected Int8 batch"),
        }
    }

    #[test]
    fn test_quantize_batch_fp16() {
        let embeddings = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let result = quantize_batch(&embeddings, "fp16").unwrap();
        match result {
            QuantizedBatch::Fp16 { quantized } => {
                assert_eq!(quantized.len(), 2);
            }
            _ => panic!("Expected Fp16 batch"),
        }
    }

    #[test]
    fn test_quantize_batch_unsupported() {
        let embeddings = vec![vec![0.1, 0.2]];
        let result = quantize_batch(&embeddings, "invalid");
        assert!(result.is_err());
    }
}