Skip to main content

ailake_query/
pruner.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use ailake_catalog::{decode_centroid, DataFileEntry};
3use ailake_core::VectorMetric;
4use ailake_vec::{cosine_distance, dot_product, euclidean_distance};
5use tracing::debug;
6
7pub struct VectorPruner;
8
9impl VectorPruner {
10    /// Remove files whose centroid is geometrically guaranteed to contain no vectors
11    /// within `threshold` distance of `query`.
12    ///
13    /// Pruning condition: `distance(query, centroid) - radius > threshold`
14    /// Files without centroid metadata are kept (conservative fallback).
15    pub fn prune(
16        files: Vec<DataFileEntry>,
17        query: &[f32],
18        metric: VectorMetric,
19        threshold: f32,
20    ) -> Vec<DataFileEntry> {
21        files
22            .into_iter()
23            .filter(|entry| {
24                match decode_centroid(entry, metric) {
25                    Some(centroid) => {
26                        let dist = compute_distance(query, &centroid.values, metric);
27                        let keep = dist - centroid.radius <= threshold;
28                        debug!(
29                            "ailake: pruner {} — dist={:.4} radius={:.4} edge={:.4} threshold={:.4} → {}",
30                            entry.path,
31                            dist,
32                            centroid.radius,
33                            dist - centroid.radius,
34                            threshold,
35                            if keep { "KEEP" } else { "PRUNE" }
36                        );
37                        keep
38                    }
39                    None => {
40                        debug!(
41                            "ailake: pruner {} — no centroid metadata, keeping (conservative fallback)",
42                            entry.path
43                        );
44                        true // no centroid → keep (safe fallback)
45                    }
46                }
47            })
48            .collect()
49    }
50}
51
52fn compute_distance(a: &[f32], b: &[f32], metric: VectorMetric) -> f32 {
53    match metric {
54        VectorMetric::Cosine | VectorMetric::NormalizedCosine => cosine_distance(a, b),
55        VectorMetric::Euclidean => euclidean_distance(a, b),
56        VectorMetric::DotProduct => -dot_product(a, b),
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{make_data_file_entry, VectorIndexInfo};
    use ailake_core::VectorMetric;
    use ailake_vec::compute_centroid_and_radius;

    /// Build a catalog entry for `vecs`, with centroid/radius computed under `metric`.
    ///
    /// The vector dimension is taken from the first row, so `vecs` must be non-empty.
    fn make_entry(path: &str, vecs: &[Vec<f32>], metric: VectorMetric) -> DataFileEntry {
        let summary = compute_centroid_and_radius(vecs, metric);
        let row_count = vecs.len() as u64;
        let index = VectorIndexInfo {
            column: "embedding",
            dim: vecs[0].len() as u32,
            hnsw_offset: 0,
            hnsw_len: 0,
        };
        make_data_file_entry(path, row_count, 1024, &summary, index)
    }

    #[test]
    fn prunes_far_file() {
        // Rows cluster around the +x axis while the query points along +z; the directions
        // are orthogonal, so nothing in the file can be within a 0.1 cosine threshold.
        let rows = vec![vec![1.0f32, 0.0, 0.0], vec![0.9, 0.1, 0.0]];
        let files = vec![make_entry("far.parquet", &rows, VectorMetric::Cosine)];
        let query = [0.0f32, 0.0, 1.0];

        let survivors = VectorPruner::prune(files, &query, VectorMetric::Cosine, 0.1);

        assert!(survivors.is_empty(), "far file should be pruned");
    }

    #[test]
    fn keeps_nearby_file() {
        // Query coincides with the first row, so the file must survive a generous threshold.
        let rows = vec![vec![1.0f32, 0.0, 0.0], vec![0.99, 0.1, 0.0]];
        let files = vec![make_entry("near.parquet", &rows, VectorMetric::Cosine)];
        let query = [1.0f32, 0.0, 0.0];

        let survivors = VectorPruner::prune(files, &query, VectorMetric::Cosine, 0.5);

        assert_eq!(survivors.len(), 1, "nearby file should be kept");
    }

    #[test]
    fn no_centroid_always_kept() {
        // Without centroid metadata the pruner has nothing to reason with, so even a
        // zero threshold must not drop the file.
        let bare = DataFileEntry {
            path: "unknown.parquet".into(),
            record_count: 10,
            file_size_bytes: 512,
            centroid_b64: None,
            radius: None,
            hnsw_offset: None,
            hnsw_len: None,
            vector_column: None,
            vector_dim: None,
            extra_vector_indexes: vec![],
            index_status: ailake_catalog::IndexStatus::Ready,
            batch_id: None,
        };
        let query = [0.0f32, 0.0, 1.0];

        let survivors = VectorPruner::prune(vec![bare], &query, VectorMetric::Cosine, 0.0);

        assert_eq!(survivors.len(), 1);
    }
}