1use std::collections::HashMap;
3
4use ailake_catalog::{decode_centroid, DataFileEntry};
5use ailake_core::VectorMetric;
6use ailake_vec::{cosine_distance, dot_product, euclidean_distance};
7use tracing::debug;
8
9pub struct VectorPruner;
10
11impl VectorPruner {
12 pub fn prune(
18 files: Vec<DataFileEntry>,
19 query: &[f32],
20 metric: VectorMetric,
21 threshold: f32,
22 ) -> Vec<DataFileEntry> {
23 files
24 .into_iter()
25 .filter(|entry| {
26 match decode_centroid(entry, metric) {
27 Some(centroid) => {
28 let dist = compute_distance(query, ¢roid.values, metric);
29 let keep = dist - centroid.radius <= threshold;
30 debug!(
31 "ailake: pruner {} — dist={:.4} radius={:.4} edge={:.4} threshold={:.4} → {}",
32 entry.path,
33 dist,
34 centroid.radius,
35 dist - centroid.radius,
36 threshold,
37 if keep { "KEEP" } else { "PRUNE" }
38 );
39 keep
40 }
41 None => {
42 debug!(
43 "ailake: pruner {} — no centroid metadata, keeping (conservative fallback)",
44 entry.path
45 );
46 true }
48 }
49 })
50 .collect()
51 }
52}
53
54fn compute_distance(a: &[f32], b: &[f32], metric: VectorMetric) -> f32 {
55 match metric {
56 VectorMetric::Cosine | VectorMetric::NormalizedCosine => cosine_distance(a, b),
57 VectorMetric::Euclidean => euclidean_distance(a, b),
58 VectorMetric::DotProduct => -dot_product(a, b),
59 }
60}
61
62pub struct BloomPruner;
70
71impl BloomPruner {
72 pub fn prune(
77 files: Vec<DataFileEntry>,
78 query_text: &str,
79 bloom_map: &HashMap<String, crate::bloom::BloomFilter>,
80 ) -> Vec<DataFileEntry> {
81 let query_terms: Vec<String> = crate::bm25::tokenize(query_text);
82 if query_terms.is_empty() || bloom_map.is_empty() {
83 return files;
84 }
85 let before = files.len();
86 let surviving: Vec<DataFileEntry> = files
87 .into_iter()
88 .filter(|entry| match bloom_map.get(&entry.path) {
89 Some(bloom) => {
90 let keep = query_terms.iter().any(|t| bloom.may_contain(t));
91 debug!(
92 "ailake: bloom pruner {} — {} query terms, keep={}",
93 entry.path,
94 query_terms.len(),
95 keep
96 );
97 keep
98 }
99 None => true,
100 })
101 .collect();
102 debug!(
103 "ailake: bloom pruning — {}/{} files survive",
104 surviving.len(),
105 before
106 );
107 surviving
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{make_data_file_entry, VectorIndexInfo};
    use ailake_core::VectorMetric;
    use ailake_vec::compute_centroid_and_radius;

    /// Builds a catalog entry for `vecs`, with centroid/radius metadata computed under
    /// `metric` and a placeholder vector-index descriptor for the `embedding` column.
    fn make_entry(path: &str, vecs: &[Vec<f32>], metric: VectorMetric) -> DataFileEntry {
        let summary = compute_centroid_and_radius(vecs, metric);
        let index = VectorIndexInfo {
            column: "embedding",
            dim: vecs[0].len() as u32,
            hnsw_offset: 0,
            hnsw_len: 0,
        };
        make_data_file_entry(path, vecs.len() as u64, 1024, &summary, index)
    }

    #[test]
    fn prunes_far_file() {
        // File vectors cluster around +x; the query points along +z.
        let far = make_entry(
            "far.parquet",
            &[vec![1.0f32, 0.0, 0.0], vec![0.9, 0.1, 0.0]],
            VectorMetric::Cosine,
        );
        let survivors =
            VectorPruner::prune(vec![far], &[0.0, 0.0, 1.0], VectorMetric::Cosine, 0.1);
        assert!(survivors.is_empty(), "far file should be pruned");
    }

    #[test]
    fn keeps_nearby_file() {
        // File vectors cluster around +x, and so does the query.
        let near = make_entry(
            "near.parquet",
            &[vec![1.0f32, 0.0, 0.0], vec![0.99, 0.1, 0.0]],
            VectorMetric::Cosine,
        );
        let survivors =
            VectorPruner::prune(vec![near], &[1.0, 0.0, 0.0], VectorMetric::Cosine, 0.5);
        assert_eq!(survivors.len(), 1, "nearby file should be kept");
    }

    #[test]
    fn no_centroid_always_kept() {
        // An entry with no centroid metadata must survive even a zero threshold.
        let bare = DataFileEntry {
            path: "unknown.parquet".into(),
            record_count: 10,
            file_size_bytes: 512,
            centroid_b64: None,
            radius: None,
            hnsw_offset: None,
            hnsw_len: None,
            vector_column: None,
            vector_dim: None,
            extra_vector_indexes: vec![],
            index_status: ailake_catalog::IndexStatus::Ready,
            batch_id: None,
            embedding_model: None,
            partition_value: None,
            deletion_vector: None,
            first_row_id: None,
        };
        let survivors =
            VectorPruner::prune(vec![bare], &[0.0, 0.0, 1.0], VectorMetric::Cosine, 0.0);
        assert_eq!(survivors.len(), 1);
    }
}