lucisearch 0.8.0

Embeddable, in-process search engine — the SQLite/DuckDB of Elasticsearch
Documentation
//! Scalar quantization for vector search.
//!
//! Per-dimension linear quantization: maps float32 values to a reduced
//! integer representation using per-dimension min/max computed across
//! the dataset. Distances are computed asymmetrically: stored vectors
//! at reduced precision, query vector at float32.
//!
//! Currently implements `Int8` only. The user-facing `QuantizationType`
//! enum lives in `luci-mapping` and includes recognized-but-unimplemented
//! variants (`Int4`, `Bbq`); those are rejected at mapping parse time so
//! they cannot reach this layer. See [[code-must-not-lie]],
//! [[optimization-knn-int8-quantization]], and [[quantization]].

use super::DistanceMetric;

/// Int8-quantized vector set plus the per-dimension calibration needed to
/// map the stored codes back to float32.
///
/// Component `d` of a stored vector dequantizes to
/// `mins[d] + code as f32 * scales[d]`. A `scales[d]` of `0.0` marks a
/// dimension that was constant across the quantized input.
///
/// [`QuantizedVectors::quantize`] and [`QuantizedVectors::from_bytes`] build
/// the buffers so that `data.len() == num_vectors * dims`,
/// `mins.len() == scales.len() == dims` and `norms.len() == num_vectors`.
/// The fields are public, so code constructing the struct directly is
/// expected to keep those invariants.
pub struct QuantizedVectors {
    /// Metric used to score a float32 query against the stored vectors.
    pub metric: DistanceMetric,

    /// Number of components in every vector.
    pub dims: usize,
    /// Number of stored vectors.
    pub num_vectors: usize,

    /// Lower bound of each dimension's calibration range.
    pub mins: Vec<f32>,
    /// Step size of each dimension, `(max - min) / 255.0`, or `0.0` when the
    /// dimension is constant.
    pub scales: Vec<f32>,

    /// Codes for all vectors back to back; vector `i` is
    /// `data[i * dims..(i + 1) * dims]`.
    pub data: Vec<u8>,
    /// Per-vector norm of the *dequantized* vector, consumed by cosine
    /// scoring.
    pub norms: Vec<f32>,
}

impl QuantizedVectors {
    /// Quantize a set of float32 vectors to int8.
    pub fn quantize(vectors: &[Vec<f32>], metric: DistanceMetric) -> Self {
        if vectors.is_empty() {
            return Self {
                dims: 0,
                num_vectors: 0,
                data: Vec::new(),
                mins: Vec::new(),
                scales: Vec::new(),
                norms: Vec::new(),
                metric,
            };
        }

        let dims = vectors[0].len();
        let num_vectors = vectors.len();

        // Compute per-dimension min and max
        let mut mins = vec![f32::MAX; dims];
        let mut maxs = vec![f32::MIN; dims];
        for v in vectors {
            for d in 0..dims {
                if v[d] < mins[d] {
                    mins[d] = v[d];
                }
                if v[d] > maxs[d] {
                    maxs[d] = v[d];
                }
            }
        }

        // Compute scales (avoid division by zero for constant dimensions)
        let scales: Vec<f32> = (0..dims)
            .map(|d| {
                let range = maxs[d] - mins[d];
                if range == 0.0 { 0.0 } else { range / 255.0 }
            })
            .collect();

        // Quantize all vectors
        let mut data = vec![0u8; num_vectors * dims];
        let mut norms = vec![0.0f32; num_vectors];

        for (i, v) in vectors.iter().enumerate() {
            let offset = i * dims;
            let mut norm_sq = 0.0f32;
            for d in 0..dims {
                let q = if scales[d] == 0.0 {
                    128u8 // midpoint for constant dimensions
                } else {
                    ((v[d] - mins[d]) / scales[d]).round().clamp(0.0, 255.0) as u8
                };
                data[offset + d] = q;
                // Approximate dequantized value for norm computation
                let dequant = mins[d] + q as f32 * scales[d];
                norm_sq += dequant * dequant;
            }
            norms[i] = norm_sq.sqrt();
        }

        Self {
            dims,
            num_vectors,
            data,
            mins,
            scales,
            norms,
            metric,
        }
    }

    /// Get the quantized vector for index `idx`.
    #[inline]
    pub fn get(&self, idx: usize) -> &[u8] {
        let start = idx * self.dims;
        &self.data[start..start + self.dims]
    }

    /// Compute asymmetric distance between quantized stored vector and
    /// float32 query vector.
    ///
    /// For cosine, the query is expected to be unit-length (caller
    /// normalizes at entry per [[optimize-cosine-norm-precompute]]).
    /// `stored_norm` is the norm of the *dequantized* stored vector,
    /// which approximates 1.0 but drifts due to int8 quantization
    /// rounding — keeping it cancels that drift in the score.
    #[inline]
    pub fn asymmetric_distance(&self, idx: usize, query: &[f32]) -> f32 {
        match self.metric {
            DistanceMetric::Cosine => self.asymmetric_cosine(idx, query),
            DistanceMetric::DotProduct => self.asymmetric_dot(idx, query),
            DistanceMetric::L2 => self.asymmetric_l2(idx, query),
        }
    }

    /// Asymmetric cosine distance: quantized stored × float32 query.
    /// Caller guarantees the query is unit-length, so the cosine
    /// denominator collapses to `stored_norm`.
    fn asymmetric_cosine(&self, idx: usize, query: &[f32]) -> f32 {
        let quantized = self.get(idx);
        let mut dot = 0.0f32;
        for d in 0..self.dims {
            let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
            dot += dequant * query[d];
        }
        let stored_norm = self.norms[idx];
        if stored_norm == 0.0 {
            1.0
        } else {
            1.0 - dot / stored_norm
        }
    }

    /// Asymmetric negative dot product: quantized stored × float32 query.
    fn asymmetric_dot(&self, idx: usize, query: &[f32]) -> f32 {
        let quantized = self.get(idx);
        let mut dot = 0.0f32;
        for d in 0..self.dims {
            let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
            dot += dequant * query[d];
        }
        -dot
    }

    /// Asymmetric L2 distance: quantized stored × float32 query.
    fn asymmetric_l2(&self, idx: usize, query: &[f32]) -> f32 {
        let quantized = self.get(idx);
        let mut sum_sq = 0.0f32;
        for d in 0..self.dims {
            let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
            let diff = dequant - query[d];
            sum_sq += diff * diff;
        }
        sum_sq.sqrt()
    }

    /// Serialize to bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(&(self.dims as u32).to_le_bytes());
        buf.extend_from_slice(&(self.num_vectors as u32).to_le_bytes());
        buf.push(self.metric as u8);
        // Mins
        for &m in &self.mins {
            buf.extend_from_slice(&m.to_le_bytes());
        }
        // Scales
        for &s in &self.scales {
            buf.extend_from_slice(&s.to_le_bytes());
        }
        // Norms
        for &n in &self.norms {
            buf.extend_from_slice(&n.to_le_bytes());
        }
        // Quantized data
        buf.extend_from_slice(&self.data);
        buf
    }

    /// Deserialize from bytes.
    pub fn from_bytes(data: &[u8]) -> Self {
        let dims = u32::from_le_bytes(data[0..4].try_into().unwrap()) as usize;
        let num_vectors = u32::from_le_bytes(data[4..8].try_into().unwrap()) as usize;
        let metric = DistanceMetric::from_byte(data[8]);
        let mut pos = 9;

        let mut mins = vec![0.0f32; dims];
        for d in 0..dims {
            mins[d] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
            pos += 4;
        }

        let mut scales = vec![0.0f32; dims];
        for d in 0..dims {
            scales[d] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
            pos += 4;
        }

        let mut norms = vec![0.0f32; num_vectors];
        for i in 0..num_vectors {
            norms[i] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
            pos += 4;
        }

        let qdata = data[pos..pos + num_vectors * dims].to_vec();

        Self {
            dims,
            num_vectors,
            data: qdata,
            mins,
            scales,
            norms,
            metric,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn quantize_round_trip() {
        let vectors = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::Cosine);

        assert_eq!(qv.dims, 3);
        assert_eq!(qv.num_vectors, 3);
        // First vector should quantize to low values (near min)
        assert_eq!(qv.get(0), &[0, 0, 0]);
        // Last vector should quantize to 255 (at max)
        assert_eq!(qv.get(2), &[255, 255, 255]);
    }

    #[test]
    fn asymmetric_cosine_close_to_exact() {
        // Inputs are unit-length — production code normalizes the query
        // before reaching this kernel under the v0.7.2 invariant.
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.707, 0.707, 0.0],
        ];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::Cosine);

        let query = vec![1.0, 0.0, 0.0];

        let d0 = qv.asymmetric_distance(0, &query);
        let d1 = qv.asymmetric_distance(1, &query);
        let d2 = qv.asymmetric_distance(2, &query);

        // Vector 0 is closest to query (same direction)
        assert!(d0 < d2, "d0={d0} should be < d2={d2}");
        assert!(d2 < d1, "d2={d2} should be < d1={d1}");
    }

    #[test]
    fn asymmetric_l2_and_dot_exact_on_grid_values() {
        // Per-dimension ranges 0..255 and 0..510 give scales 1.0 and 2.0,
        // so every value below lands exactly on a quantization step and
        // dequantizes without error.
        let vectors = vec![vec![0.0, 0.0], vec![255.0, 510.0]];
        let query = [3.0, 4.0];

        let l2 = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        assert_eq!(l2.asymmetric_distance(0, &query), 5.0);

        let dot = QuantizedVectors::quantize(&vectors, DistanceMetric::DotProduct);
        // -(255 * 3 + 510 * 4)
        assert_eq!(dot.asymmetric_distance(1, &query), -2805.0);
    }

    #[test]
    fn serialization_round_trip() {
        let vectors = vec![vec![1.5, -2.3, 0.7, 4.1], vec![-0.5, 3.2, 1.1, -1.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        let bytes = qv.to_bytes();
        let qv2 = QuantizedVectors::from_bytes(&bytes);

        assert_eq!(qv.dims, qv2.dims);
        assert_eq!(qv.num_vectors, qv2.num_vectors);
        assert_eq!(qv.data, qv2.data);
        assert_eq!(qv.mins, qv2.mins);
        assert_eq!(qv.scales, qv2.scales);
        // Norms and the metric tag are serialized too; they must survive.
        assert_eq!(qv.norms, qv2.norms);
        assert_eq!(qv.metric as u8, qv2.metric as u8);
    }

    #[test]
    #[should_panic]
    fn from_bytes_panics_on_truncated_input() {
        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let bytes = QuantizedVectors::quantize(&vectors, DistanceMetric::L2).to_bytes();
        let _ = QuantizedVectors::from_bytes(&bytes[..bytes.len() - 1]);
    }

    #[test]
    fn empty_vectors() {
        let qv = QuantizedVectors::quantize(&[], DistanceMetric::Cosine);
        assert_eq!(qv.num_vectors, 0);
        assert_eq!(qv.dims, 0);
    }

    #[test]
    fn constant_dimension() {
        // All vectors have same value in one dimension
        let vectors = vec![vec![1.0, 5.0], vec![2.0, 5.0], vec![3.0, 5.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        // Second dimension is constant — should get midpoint (128)
        assert_eq!(qv.get(0)[1], 128);
        assert_eq!(qv.get(1)[1], 128);
        assert_eq!(qv.get(2)[1], 128);
        // ...and must still dequantize to exactly the constant value.
        assert_eq!(qv.asymmetric_distance(0, &[1.0, 5.0]), 0.0);
    }
}