Skip to main content

ailake_file/
reader.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use ailake_core::{AilakeError, AilakeResult, Centroid, VectorMetric};
3use ailake_index::{AnyIndex, HnswIndex, IvfPqSerializer, MmapLoader, RaBitQSerializer};
4use ailake_parquet::ParquetVectorReader;
5use arrow_array::RecordBatch;
6use bytes::Bytes;
7
8use crate::footer::{
9    AilakeHeader, DistanceMetric, FLAG_INDEX_IVF_PQ, FLAG_INDEX_RABITQ, HEADER_SIZE,
10};
11
12pub struct AilakeFileReader {
13    bytes: Bytes,
14    vector_column: String,
15    #[allow(dead_code)]
16    dim: u32,
17}
18
19impl AilakeFileReader {
20    pub fn new(bytes: Bytes, vector_column: &str, dim: u32) -> Self {
21        Self {
22            bytes,
23            vector_column: vector_column.to_string(),
24            dim,
25        }
26    }
27
28    /// Returns the absolute byte offset of the primary AILK section.
29    /// Reads `ailake.footer_offset` from the Parquet footer key-value metadata.
30    pub fn ailk_offset(&self) -> AilakeResult<u64> {
31        let reader = ParquetVectorReader::new(self.bytes.clone(), &self.vector_column);
32        let val = reader
33            .kv_metadata("ailake.footer_offset")?
34            .ok_or(AilakeError::NotAnAilakeFile)?;
35        val.parse::<u64>().map_err(|_| AilakeError::NotAnAilakeFile)
36    }
37
38    /// Returns the absolute byte offset of the AILK section for a named vector column.
39    ///
40    /// For additional columns tries `ailake.{column}.footer_offset` first,
41    /// then falls back to `ailake.footer_offset` (primary / single-column files).
42    pub fn ailk_offset_for_column(&self, column: &str) -> AilakeResult<u64> {
43        let reader = ParquetVectorReader::new(self.bytes.clone(), column);
44        let col_key = format!("ailake.{column}.footer_offset");
45        if let Some(val) = reader.kv_metadata(&col_key)? {
46            return val.parse::<u64>().map_err(|_| AilakeError::NotAnAilakeFile);
47        }
48        let val = reader
49            .kv_metadata("ailake.footer_offset")?
50            .ok_or(AilakeError::NotAnAilakeFile)?;
51        val.parse::<u64>().map_err(|_| AilakeError::NotAnAilakeFile)
52    }
53
54    /// Returns true if the file contains an embedded AILK section.
55    pub fn is_ailake_file(&self) -> bool {
56        self.ailk_offset().is_ok()
57    }
58
59    /// Parse the 64-byte AI-Lake header from the embedded AILK section.
60    pub fn read_header(&self) -> AilakeResult<AilakeHeader> {
61        let offset = self.ailk_offset()? as usize;
62        if offset + HEADER_SIZE > self.bytes.len() {
63            return Err(AilakeError::NotAnAilakeFile);
64        }
65        let header_bytes: &[u8; HEADER_SIZE] = self.bytes[offset..offset + HEADER_SIZE]
66            .try_into()
67            .map_err(|_| AilakeError::NotAnAilakeFile)?;
68        AilakeHeader::from_bytes(header_bytes)
69    }
70
71    /// Read centroid + radius from the AILK section.
72    pub fn get_centroid(&self) -> AilakeResult<Centroid> {
73        let ailk_start = self.ailk_offset()? as usize;
74        let header = self.read_header()?;
75        let centroid_start = ailk_start + header.centroid_offset as usize;
76        let centroid_end = centroid_start + header.centroid_len as usize;
77
78        if centroid_end > self.bytes.len() {
79            return Err(AilakeError::NotAnAilakeFile);
80        }
81
82        let centroid_data = &self.bytes[centroid_start..centroid_end];
83        let dim = header.dim as usize;
84        let expected_len = dim * 4 + 4;
85        if centroid_data.len() != expected_len {
86            return Err(AilakeError::InvalidCentroidLength {
87                expected_dim: header.dim,
88                actual: centroid_data.len(),
89            });
90        }
91
92        let values: Vec<f32> = centroid_data[..dim * 4]
93            .chunks_exact(4)
94            .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
95            .collect();
96        let radius = f32::from_le_bytes(centroid_data[dim * 4..].try_into().unwrap());
97        let metric = distance_metric_to_vector_metric(header.distance_metric);
98
99        Ok(Centroid {
100            values,
101            radius,
102            metric,
103        })
104    }
105
106    /// Load the HNSW index from the primary AILK section.
107    pub fn load_index(&self) -> AilakeResult<HnswIndex> {
108        self.load_index_for_column(&self.vector_column.clone())
109    }
110
111    /// Load the HNSW index for a specific vector column.
112    ///
113    /// Works for both single-column files (falls back to primary AILK) and
114    /// multi-column files written with `AilakeFileWriter::write_multi`.
115    pub fn load_index_for_column(&self, column: &str) -> AilakeResult<HnswIndex> {
116        let ailk_start = self.ailk_offset_for_column(column)? as usize;
117
118        if ailk_start + HEADER_SIZE > self.bytes.len() {
119            return Err(AilakeError::NotAnAilakeFile);
120        }
121        let header_bytes: &[u8; HEADER_SIZE] = self.bytes[ailk_start..ailk_start + HEADER_SIZE]
122            .try_into()
123            .map_err(|_| AilakeError::NotAnAilakeFile)?;
124        let header = AilakeHeader::from_bytes(header_bytes)?;
125
126        let hnsw_start = ailk_start + header.hnsw_offset as usize;
127        let hnsw_end = hnsw_start + header.hnsw_len as usize;
128
129        if hnsw_end > self.bytes.len() {
130            return Err(AilakeError::NotAnAilakeFile);
131        }
132        MmapLoader::from_bytes(&self.bytes[hnsw_start..hnsw_end])
133    }
134
135    /// Load primary index as `AnyIndex`, dispatching on header flags.
136    pub fn load_any_index(&self) -> AilakeResult<AnyIndex> {
137        self.load_any_index_for_column(&self.vector_column.clone())
138    }
139
140    /// Load index for a specific vector column as `AnyIndex`.
141    pub fn load_any_index_for_column(&self, column: &str) -> AilakeResult<AnyIndex> {
142        let ailk_start = self.ailk_offset_for_column(column)? as usize;
143
144        if ailk_start + HEADER_SIZE > self.bytes.len() {
145            return Err(AilakeError::NotAnAilakeFile);
146        }
147        let header_bytes: &[u8; HEADER_SIZE] = self.bytes[ailk_start..ailk_start + HEADER_SIZE]
148            .try_into()
149            .map_err(|_| AilakeError::NotAnAilakeFile)?;
150        let header = AilakeHeader::from_bytes(header_bytes)?;
151
152        let index_start = ailk_start + header.hnsw_offset as usize;
153        let index_end = index_start + header.hnsw_len as usize;
154
155        if index_end > self.bytes.len() {
156            return Err(AilakeError::NotAnAilakeFile);
157        }
158        let index_bytes = &self.bytes[index_start..index_end];
159
160        if header.flags & FLAG_INDEX_RABITQ != 0 {
161            let idx = RaBitQSerializer::from_bytes(index_bytes)?;
162            Ok(AnyIndex::RaBitQ(idx))
163        } else if header.flags & FLAG_INDEX_IVF_PQ != 0 {
164            let idx = IvfPqSerializer::from_bytes(index_bytes)?;
165            Ok(AnyIndex::IvfPq(idx))
166        } else {
167            let idx = MmapLoader::from_bytes(index_bytes)?;
168            Ok(AnyIndex::Hnsw(idx))
169        }
170    }
171
172    /// Read the Parquet section (tabular data + decoded embeddings).
173    /// The full file is valid Parquet; the AILK section is invisible to standard readers.
174    pub fn read_parquet(&self) -> AilakeResult<(RecordBatch, Vec<Vec<f32>>)> {
175        let reader = ParquetVectorReader::new(self.bytes.clone(), &self.vector_column);
176        reader.read_all()
177    }
178
179    /// Verify the positional invariant: Parquet record_count == HNSW node_count.
180    pub fn verify_integrity(&self) -> AilakeResult<()> {
181        let header = self.read_header()?;
182        let index = self.load_index()?;
183        let reader = ParquetVectorReader::new(self.bytes.clone(), &self.vector_column);
184        let parquet_count = reader.record_count()?;
185
186        if parquet_count != index.node_count() {
187            return Err(AilakeError::RowCountMismatch {
188                parquet: parquet_count,
189                hnsw: index.node_count(),
190            });
191        }
192        if parquet_count != header.record_count {
193            return Err(AilakeError::RowCountMismatch {
194                parquet: parquet_count,
195                hnsw: header.record_count,
196            });
197        }
198        Ok(())
199    }
200}
201
202fn distance_metric_to_vector_metric(dm: DistanceMetric) -> VectorMetric {
203    match dm {
204        DistanceMetric::Cosine => VectorMetric::Cosine,
205        DistanceMetric::Euclidean => VectorMetric::Euclidean,
206        DistanceMetric::DotProduct => VectorMetric::DotProduct,
207        DistanceMetric::NormalizedCosine => VectorMetric::NormalizedCosine,
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::AilakeFileWriter;
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Storage policy used by every fixture: cosine metric, F16 precision,
    /// column `embedding`, all optional features switched off.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            rabitq: None,
        }
    }

    /// Writes a file with `rows` records: an `id` column holding `0..rows` and
    /// one-hot embeddings where row `i` has its 1.0 at position `i % dim`.
    fn write_file(rows: usize, dim: u32) -> Bytes {
        let id_field = Field::new("id", DataType::Int32, false);
        let schema = Arc::new(Schema::new(vec![id_field]));
        let id_values: Vec<i32> = (0..rows as i32).collect();
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(id_values))]).unwrap();

        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(rows);
        for row in 0..rows {
            let mut one_hot = vec![0.0f32; dim as usize];
            one_hot[row % dim as usize] = 1.0;
            embeddings.push(one_hot);
        }

        let writer = AilakeFileWriter::new(make_policy(dim));
        writer.write(&batch, &embeddings).unwrap()
    }

    /// Writes a fixture file and opens a reader on its `embedding` column.
    fn open_reader(rows: usize, dim: u32) -> AilakeFileReader {
        AilakeFileReader::new(write_file(rows, dim), "embedding", dim)
    }

    #[test]
    fn is_ailake_file() {
        let reader = open_reader(3, 4);
        assert!(reader.is_ailake_file());
    }

    #[test]
    fn integrity_check_passes() {
        let reader = open_reader(10, 8);
        reader.verify_integrity().unwrap();
    }

    #[test]
    fn centroid_has_correct_dim() {
        let reader = open_reader(5, 4);
        let centroid = reader.get_centroid().unwrap();
        assert_eq!(centroid.values.len(), 4);
    }

    #[test]
    fn search_finds_nearest() {
        let reader = open_reader(4, 4u32);
        let index = reader.load_index().unwrap();

        // Row 0 is the one-hot vector on axis 0, so it must be the top hit.
        let query = vec![1.0f32, 0.0, 0.0, 0.0];
        let results = index.search(&query, 1, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, ailake_core::RowId::new(0));
    }

    #[test]
    fn parquet_read_returns_tabular_data() {
        let reader = open_reader(3, 4);
        let (batch, embs) = reader.read_parquet().unwrap();
        assert_eq!(batch.num_rows(), 3);
        assert_eq!(embs.len(), 3);
    }
}
293}