Skip to main content

ailake_query/
pruner.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use ailake_catalog::{decode_centroid, DataFileEntry};
3use ailake_core::VectorMetric;
4use ailake_vec::{cosine_distance, dot_product, euclidean_distance};
5
6pub struct VectorPruner;
7
8impl VectorPruner {
9    /// Remove files whose centroid is geometrically guaranteed to contain no vectors
10    /// within `threshold` distance of `query`.
11    ///
12    /// Pruning condition: `distance(query, centroid) - radius > threshold`
13    /// Files without centroid metadata are kept (conservative fallback).
14    pub fn prune(
15        files: Vec<DataFileEntry>,
16        query: &[f32],
17        metric: VectorMetric,
18        threshold: f32,
19    ) -> Vec<DataFileEntry> {
20        files
21            .into_iter()
22            .filter(|entry| {
23                match decode_centroid(entry, metric) {
24                    Some(centroid) => {
25                        let dist = compute_distance(query, &centroid.values, metric);
26                        // Keep file if any of its vectors could be within threshold
27                        dist - centroid.radius <= threshold
28                    }
29                    None => true, // no centroid → keep (safe fallback)
30                }
31            })
32            .collect()
33    }
34}
35
36fn compute_distance(a: &[f32], b: &[f32], metric: VectorMetric) -> f32 {
37    match metric {
38        VectorMetric::Cosine => cosine_distance(a, b),
39        VectorMetric::Euclidean => euclidean_distance(a, b),
40        VectorMetric::DotProduct => -dot_product(a, b),
41    }
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{make_data_file_entry, VectorIndexInfo};
    use ailake_core::VectorMetric;
    use ailake_vec::compute_centroid_and_radius;

    /// Builds a catalog entry whose centroid/radius metadata is computed from `vecs`.
    ///
    /// Panics if `vecs` is empty (the vector dimension is read from `vecs[0]`).
    fn make_entry(path: &str, vecs: &[Vec<f32>], metric: VectorMetric) -> DataFileEntry {
        let centroid = compute_centroid_and_radius(vecs, metric);
        make_data_file_entry(
            path,
            vecs.len() as u64,
            1024,
            &centroid,
            VectorIndexInfo {
                column: "embedding",
                dim: vecs[0].len() as u32,
                hnsw_offset: 0,
                hnsw_len: 0,
            },
        )
    }

    #[test]
    fn prunes_far_file() {
        // File vectors cluster around [1,0,0]; the query [0,0,1] is orthogonal to them → prune.
        let vecs = vec![vec![1.0f32, 0.0, 0.0], vec![0.9, 0.1, 0.0]];
        let entry = make_entry("far.parquet", &vecs, VectorMetric::Cosine);
        let query = vec![0.0f32, 0.0, 1.0];
        let pruned = VectorPruner::prune(vec![entry], &query, VectorMetric::Cosine, 0.1);
        assert!(pruned.is_empty(), "far file should be pruned");
    }

    #[test]
    fn keeps_nearby_file() {
        let vecs = vec![vec![1.0f32, 0.0, 0.0], vec![0.99, 0.1, 0.0]];
        let entry = make_entry("near.parquet", &vecs, VectorMetric::Cosine);
        let query = vec![1.0f32, 0.0, 0.0];
        let kept = VectorPruner::prune(vec![entry], &query, VectorMetric::Cosine, 0.5);
        assert_eq!(kept.len(), 1, "nearby file should be kept");
    }

    #[test]
    fn euclidean_prunes_far_file() {
        // File vectors lie within one unit of the origin; the query is ten units away.
        let vecs = vec![vec![0.0f32, 0.0], vec![1.0, 0.0]];
        let entry = make_entry("far_l2.parquet", &vecs, VectorMetric::Euclidean);
        let query = vec![10.0f32, 0.0];
        let pruned = VectorPruner::prune(vec![entry], &query, VectorMetric::Euclidean, 1.0);
        assert!(pruned.is_empty(), "far file should be pruned under Euclidean");
    }

    #[test]
    fn euclidean_keeps_file_with_match() {
        // [0,0] is 0.05 from the query, inside the 0.1 threshold, so the file must be kept.
        let vecs = vec![vec![0.0f32, 0.0], vec![1.0, 0.0]];
        let entry = make_entry("near_l2.parquet", &vecs, VectorMetric::Euclidean);
        let query = vec![0.05f32, 0.0];
        let kept = VectorPruner::prune(vec![entry], &query, VectorMetric::Euclidean, 0.1);
        assert_eq!(kept.len(), 1, "file containing a match must be kept");
    }

    #[test]
    fn empty_input_yields_empty_output() {
        let query = vec![1.0f32, 0.0, 0.0];
        let kept = VectorPruner::prune(Vec::new(), &query, VectorMetric::Cosine, 0.5);
        assert!(kept.is_empty());
    }

    #[test]
    fn no_centroid_always_kept() {
        let entry = DataFileEntry {
            path: "unknown.parquet".into(),
            record_count: 10,
            file_size_bytes: 512,
            centroid_b64: None,
            radius: None,
            hnsw_offset: None,
            hnsw_len: None,
            vector_column: None,
            vector_dim: None,
            extra_vector_indexes: vec![],
            index_status: ailake_catalog::IndexStatus::Ready,
            batch_id: None,
        };
        let query = vec![0.0f32, 0.0, 1.0];
        let kept = VectorPruner::prune(vec![entry], &query, VectorMetric::Cosine, 0.0);
        assert_eq!(kept.len(), 1);
    }
}