1use ailake_core::{AilakeError, AilakeResult, Centroid, VectorMetric};
3use ailake_index::{AnyIndex, HnswIndex, IvfPqSerializer, MmapLoader};
4use ailake_parquet::ParquetVectorReader;
5use arrow_array::RecordBatch;
6use bytes::Bytes;
7
8use crate::footer::{AilakeHeader, DistanceMetric, FLAG_INDEX_IVF_PQ, HEADER_SIZE};
9
10pub struct AilakeFileReader {
11 bytes: Bytes,
12 vector_column: String,
13 #[allow(dead_code)]
14 dim: u32,
15}
16
17impl AilakeFileReader {
18 pub fn new(bytes: Bytes, vector_column: &str, dim: u32) -> Self {
19 Self {
20 bytes,
21 vector_column: vector_column.to_string(),
22 dim,
23 }
24 }
25
26 pub fn ailk_offset(&self) -> AilakeResult<u64> {
29 let reader = ParquetVectorReader::new(self.bytes.clone(), &self.vector_column);
30 let val = reader
31 .kv_metadata("ailake.footer_offset")?
32 .ok_or(AilakeError::NotAnAilakeFile)?;
33 val.parse::<u64>().map_err(|_| AilakeError::NotAnAilakeFile)
34 }
35
36 pub fn ailk_offset_for_column(&self, column: &str) -> AilakeResult<u64> {
41 let reader = ParquetVectorReader::new(self.bytes.clone(), column);
42 let col_key = format!("ailake.{column}.footer_offset");
43 if let Some(val) = reader.kv_metadata(&col_key)? {
44 return val.parse::<u64>().map_err(|_| AilakeError::NotAnAilakeFile);
45 }
46 let val = reader
47 .kv_metadata("ailake.footer_offset")?
48 .ok_or(AilakeError::NotAnAilakeFile)?;
49 val.parse::<u64>().map_err(|_| AilakeError::NotAnAilakeFile)
50 }
51
52 pub fn is_ailake_file(&self) -> bool {
54 self.ailk_offset().is_ok()
55 }
56
57 pub fn read_header(&self) -> AilakeResult<AilakeHeader> {
59 let offset = self.ailk_offset()? as usize;
60 if offset + HEADER_SIZE > self.bytes.len() {
61 return Err(AilakeError::NotAnAilakeFile);
62 }
63 let header_bytes: &[u8; HEADER_SIZE] = self.bytes[offset..offset + HEADER_SIZE]
64 .try_into()
65 .map_err(|_| AilakeError::NotAnAilakeFile)?;
66 AilakeHeader::from_bytes(header_bytes)
67 }
68
69 pub fn get_centroid(&self) -> AilakeResult<Centroid> {
71 let ailk_start = self.ailk_offset()? as usize;
72 let header = self.read_header()?;
73 let centroid_start = ailk_start + header.centroid_offset as usize;
74 let centroid_end = centroid_start + header.centroid_len as usize;
75
76 if centroid_end > self.bytes.len() {
77 return Err(AilakeError::NotAnAilakeFile);
78 }
79
80 let centroid_data = &self.bytes[centroid_start..centroid_end];
81 let dim = header.dim as usize;
82 let expected_len = dim * 4 + 4;
83 if centroid_data.len() != expected_len {
84 return Err(AilakeError::InvalidCentroidLength {
85 expected_dim: header.dim,
86 actual: centroid_data.len(),
87 });
88 }
89
90 let values: Vec<f32> = centroid_data[..dim * 4]
91 .chunks_exact(4)
92 .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
93 .collect();
94 let radius = f32::from_le_bytes(centroid_data[dim * 4..].try_into().unwrap());
95 let metric = distance_metric_to_vector_metric(header.distance_metric);
96
97 Ok(Centroid {
98 values,
99 radius,
100 metric,
101 })
102 }
103
104 pub fn load_index(&self) -> AilakeResult<HnswIndex> {
106 self.load_index_for_column(&self.vector_column.clone())
107 }
108
109 pub fn load_index_for_column(&self, column: &str) -> AilakeResult<HnswIndex> {
114 let ailk_start = self.ailk_offset_for_column(column)? as usize;
115
116 if ailk_start + HEADER_SIZE > self.bytes.len() {
117 return Err(AilakeError::NotAnAilakeFile);
118 }
119 let header_bytes: &[u8; HEADER_SIZE] = self.bytes[ailk_start..ailk_start + HEADER_SIZE]
120 .try_into()
121 .map_err(|_| AilakeError::NotAnAilakeFile)?;
122 let header = AilakeHeader::from_bytes(header_bytes)?;
123
124 let hnsw_start = ailk_start + header.hnsw_offset as usize;
125 let hnsw_end = hnsw_start + header.hnsw_len as usize;
126
127 if hnsw_end > self.bytes.len() {
128 return Err(AilakeError::NotAnAilakeFile);
129 }
130 MmapLoader::from_bytes(&self.bytes[hnsw_start..hnsw_end])
131 }
132
133 pub fn load_any_index(&self) -> AilakeResult<AnyIndex> {
135 self.load_any_index_for_column(&self.vector_column.clone())
136 }
137
138 pub fn load_any_index_for_column(&self, column: &str) -> AilakeResult<AnyIndex> {
140 let ailk_start = self.ailk_offset_for_column(column)? as usize;
141
142 if ailk_start + HEADER_SIZE > self.bytes.len() {
143 return Err(AilakeError::NotAnAilakeFile);
144 }
145 let header_bytes: &[u8; HEADER_SIZE] = self.bytes[ailk_start..ailk_start + HEADER_SIZE]
146 .try_into()
147 .map_err(|_| AilakeError::NotAnAilakeFile)?;
148 let header = AilakeHeader::from_bytes(header_bytes)?;
149
150 let index_start = ailk_start + header.hnsw_offset as usize;
151 let index_end = index_start + header.hnsw_len as usize;
152
153 if index_end > self.bytes.len() {
154 return Err(AilakeError::NotAnAilakeFile);
155 }
156 let index_bytes = &self.bytes[index_start..index_end];
157
158 if header.flags & FLAG_INDEX_IVF_PQ != 0 {
159 let idx = IvfPqSerializer::from_bytes(index_bytes)?;
160 Ok(AnyIndex::IvfPq(idx))
161 } else {
162 let idx = MmapLoader::from_bytes(index_bytes)?;
163 Ok(AnyIndex::Hnsw(idx))
164 }
165 }
166
167 pub fn read_parquet(&self) -> AilakeResult<(RecordBatch, Vec<Vec<f32>>)> {
170 let reader = ParquetVectorReader::new(self.bytes.clone(), &self.vector_column);
171 reader.read_all()
172 }
173
174 pub fn verify_integrity(&self) -> AilakeResult<()> {
176 let header = self.read_header()?;
177 let index = self.load_index()?;
178 let reader = ParquetVectorReader::new(self.bytes.clone(), &self.vector_column);
179 let parquet_count = reader.record_count()?;
180
181 if parquet_count != index.node_count() {
182 return Err(AilakeError::RowCountMismatch {
183 parquet: parquet_count,
184 hnsw: index.node_count(),
185 });
186 }
187 if parquet_count != header.record_count {
188 return Err(AilakeError::RowCountMismatch {
189 parquet: parquet_count,
190 hnsw: header.record_count,
191 });
192 }
193 Ok(())
194 }
195}
196
197fn distance_metric_to_vector_metric(dm: DistanceMetric) -> VectorMetric {
198 match dm {
199 DistanceMetric::Cosine => VectorMetric::Cosine,
200 DistanceMetric::Euclidean => VectorMetric::Euclidean,
201 DistanceMetric::DotProduct => VectorMetric::DotProduct,
202 }
203}
204
205#[cfg(test)]
206mod tests {
207 use super::*;
208 use crate::writer::AilakeFileWriter;
209 use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
210 use arrow_array::{Int32Array, RecordBatch};
211 use arrow_schema::{DataType, Field, Schema};
212 use std::sync::Arc;
213
214 fn make_policy(dim: u32) -> VectorStoragePolicy {
215 VectorStoragePolicy {
216 column_name: "embedding".to_string(),
217 dim,
218 metric: VectorMetric::Cosine,
219 precision: VectorPrecision::F16,
220 pq: None,
221 keep_raw_for_reranking: false,
222 }
223 }
224
225 fn write_file(rows: usize, dim: u32) -> Bytes {
226 let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
227 let ids: Vec<i32> = (0..rows as i32).collect();
228 let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();
229 let embs: Vec<Vec<f32>> = (0..rows)
230 .map(|i| {
231 let mut v = vec![0.0f32; dim as usize];
232 v[i % dim as usize] = 1.0;
233 v
234 })
235 .collect();
236 AilakeFileWriter::new(make_policy(dim))
237 .write(&batch, &embs)
238 .unwrap()
239 }
240
241 #[test]
242 fn is_ailake_file() {
243 let file = write_file(3, 4);
244 let reader = AilakeFileReader::new(file, "embedding", 4);
245 assert!(reader.is_ailake_file());
246 }
247
248 #[test]
249 fn integrity_check_passes() {
250 let file = write_file(10, 8);
251 let reader = AilakeFileReader::new(file, "embedding", 8);
252 reader.verify_integrity().unwrap();
253 }
254
255 #[test]
256 fn centroid_has_correct_dim() {
257 let file = write_file(5, 4);
258 let reader = AilakeFileReader::new(file, "embedding", 4);
259 let centroid = reader.get_centroid().unwrap();
260 assert_eq!(centroid.values.len(), 4);
261 }
262
263 #[test]
264 fn search_finds_nearest() {
265 let dim = 4u32;
266 let file = write_file(4, dim);
267 let reader = AilakeFileReader::new(file, "embedding", dim);
268 let index = reader.load_index().unwrap();
269 let query = vec![1.0f32, 0.0, 0.0, 0.0];
270 let results = index.search(&query, 1, 50);
271 assert_eq!(results.len(), 1);
272 assert_eq!(results[0].0, ailake_core::RowId::new(0));
273 }
274
275 #[test]
276 fn parquet_read_returns_tabular_data() {
277 let file = write_file(3, 4);
278 let reader = AilakeFileReader::new(file, "embedding", 4);
279 let (batch, embs) = reader.read_parquet().unwrap();
280 assert_eq!(batch.num_rows(), 3);
281 assert_eq!(embs.len(), 3);
282 }
283}