1use ailake_core::{AilakeResult, Centroid, RowId, VectorStoragePolicy};
2use ailake_index::{HnswBuilder, HnswConfig, HnswSerializer, IvfPqConfig, IvfPqSerializer};
3use ailake_parquet::ParquetVectorWriter;
4use ailake_vec::compute_centroid_and_radius;
5use arrow_array::RecordBatch;
6use bytes::{BufMut, Bytes, BytesMut};
7
8use crate::footer::{
9 AilakeHeader, AilakeTrailer, DistanceMetric, Precision, AILAKE_FORMAT_VERSION,
10 FLAG_INDEX_IVF_PQ, HEADER_SIZE, TRAILER_SIZE,
11};
12
/// Which approximate-nearest-neighbour index is built into each vector column's AILK section.
#[derive(Debug, Clone)]
pub enum IndexType {
    /// HNSW graph index built with the given configuration (section flags are `0`).
    Hnsw(HnswConfig),
    /// IVF-PQ index trained with the given configuration; sections built this way carry
    /// the `FLAG_INDEX_IVF_PQ` flag in their header and trailer.
    IvfPq(IvfPqConfig),
    /// Decide per column at write time: when `HardwareProfile::detect()` recommends IVF-PQ
    /// for the column's row count, use `IvfPqConfig::for_dataset(dim, rows)`; otherwise use
    /// HNSW with `HnswConfig::default()`.
    Auto,
}
26
27impl Default for IndexType {
28 fn default() -> Self {
29 IndexType::Hnsw(HnswConfig::default())
30 }
31}
32
/// One vector column handed to `AilakeFileWriter::write_multi`, borrowed for the duration
/// of the write.
pub struct VectorColumnBatch<'a> {
    /// Storage policy for this column. For every column except the first, `column_name`
    /// forms the Parquet key-value entry `ailake.<column_name>.footer_offset` that points
    /// at the column's AILK section.
    pub policy: &'a VectorStoragePolicy,
    /// Embeddings used to build this column's centroid and index; the vector at position
    /// `i` is indexed under `RowId` `i`.
    pub embeddings: &'a [Vec<f32>],
}
38
/// Writes Parquet files that carry one AILK section (header, centroid, index, trailer) per
/// vector column, spliced in between the Parquet data pages and the Parquet footer.
pub struct AilakeFileWriter {
    /// Policy of the vector column written by `write` and `write_parquet_only`.
    policy: VectorStoragePolicy,
    /// Index kind applied to every vector column's AILK section.
    index_type: IndexType,
}
43
44impl AilakeFileWriter {
45 pub fn new(policy: VectorStoragePolicy) -> Self {
46 Self {
47 policy,
48 index_type: IndexType::default(),
49 }
50 }
51
52 pub fn with_hnsw_config(mut self, config: HnswConfig) -> Self {
53 self.index_type = IndexType::Hnsw(config);
54 self
55 }
56
57 pub fn with_ivf_pq(mut self, config: IvfPqConfig) -> Self {
58 self.index_type = IndexType::IvfPq(config);
59 self
60 }
61
62 pub fn with_index_type(mut self, index_type: IndexType) -> Self {
63 self.index_type = index_type;
64 self
65 }
66
67 pub fn with_auto_index(mut self) -> Self {
70 self.index_type = IndexType::Auto;
71 self
72 }
73
74 pub fn write_parquet_only(
82 &self,
83 batch: &RecordBatch,
84 embeddings: &[Vec<f32>],
85 ) -> AilakeResult<Bytes> {
86 let parquet_writer = ParquetVectorWriter::new(self.policy.clone());
87 let (bytes, _) = parquet_writer.write_batch(batch, embeddings)?;
88 Ok(bytes)
89 }
90
91 pub fn write(&self, batch: &RecordBatch, embeddings: &[Vec<f32>]) -> AilakeResult<Bytes> {
100 let col = VectorColumnBatch {
101 policy: &self.policy,
102 embeddings,
103 };
104 self.write_multi(batch, &[col])
105 }
106
107 pub fn write_multi(
116 &self,
117 batch: &RecordBatch,
118 columns: &[VectorColumnBatch<'_>],
119 ) -> AilakeResult<Bytes> {
120 use ailake_core::AilakeError;
121
122 if columns.is_empty() {
123 return Err(AilakeError::InvalidArgument(
124 "write_multi requires at least one vector column".into(),
125 ));
126 }
127
128 let primary = &columns[0];
129 let parquet_writer = ParquetVectorWriter::new(primary.policy.clone());
130
131 let (parquet_v1, record_count) = parquet_writer.write_batch(batch, primary.embeddings)?;
133 let footer_start = parquet_footer_start(&parquet_v1)?;
134
135 let mut ailk_sections: Vec<Bytes> = Vec::with_capacity(columns.len());
137 let mut kv_owned: Vec<(String, String)> = Vec::with_capacity(columns.len());
138 let mut current_offset = footer_start as u64;
139
140 for (i, col) in columns.iter().enumerate() {
141 let section = build_ailk_section(
142 col.policy,
143 col.embeddings,
144 record_count,
145 current_offset,
146 &self.index_type,
147 )?;
148 let kv_key = if i == 0 {
149 "ailake.footer_offset".to_string()
150 } else {
151 format!("ailake.{}.footer_offset", col.policy.column_name)
152 };
153 kv_owned.push((kv_key, current_offset.to_string()));
154 current_offset += section.len() as u64;
155 ailk_sections.push(section);
156 }
157
158 let kv_refs: Vec<(&str, &str)> = kv_owned
160 .iter()
161 .map(|(k, v)| (k.as_str(), v.as_str()))
162 .collect();
163 let (parquet_v2, _) =
164 parquet_writer.write_batch_with_kv(batch, primary.embeddings, &kv_refs)?;
165 let footer_start_v2 = parquet_footer_start(&parquet_v2)?;
166
167 let total_ailk: usize = ailk_sections.iter().map(|s| s.len()).sum();
169 let total = footer_start_v2 + total_ailk + (parquet_v2.len() - footer_start_v2);
170 let mut out = BytesMut::with_capacity(total);
171 out.put_slice(&parquet_v2[..footer_start_v2]);
172 for section in ailk_sections {
173 out.put(section);
174 }
175 out.put_slice(&parquet_v2[footer_start_v2..]);
176
177 Ok(out.freeze())
178 }
179}
180
181fn build_ailk_section(
183 policy: &VectorStoragePolicy,
184 embeddings: &[Vec<f32>],
185 record_count: u64,
186 ailk_abs_offset: u64,
187 index_type: &IndexType,
188) -> AilakeResult<Bytes> {
189 let centroid: Centroid = compute_centroid_and_radius(embeddings, policy.metric);
190 let centroid_bytes = encode_centroid(¢roid);
191
192 let resolved: IndexType;
194 let index_type = if matches!(index_type, IndexType::Auto) {
195 let profile = ailake_index::HardwareProfile::detect();
196 resolved = if profile.recommend_ivf_pq(embeddings.len()) {
197 IndexType::IvfPq(ailake_index::IvfPqConfig::for_dataset(
198 policy.dim as usize,
199 embeddings.len(),
200 ))
201 } else {
202 IndexType::Hnsw(ailake_index::HnswConfig::default())
203 };
204 &resolved
205 } else {
206 index_type
207 };
208
209 let (index_bytes, flags) = match index_type {
210 IndexType::Hnsw(hnsw_config) => {
211 let mut builder = HnswBuilder::new(policy.dim, policy.metric, hnsw_config.clone());
212 for (i, v) in embeddings.iter().enumerate() {
213 builder.insert(RowId::new(i as u64), v.clone());
214 }
215 let index = builder.build();
216 (HnswSerializer::to_bytes(&index)?, 0u16)
217 }
218 IndexType::IvfPq(ivf_config) => {
219 let row_ids: Vec<RowId> = (0..embeddings.len() as u64).map(RowId::new).collect();
220 let index = ailake_index::IvfPqIndex::train(
221 &row_ids,
222 embeddings,
223 policy.metric,
224 ivf_config.clone(),
225 )?;
226 (IvfPqSerializer::to_bytes(&index)?, FLAG_INDEX_IVF_PQ)
227 }
228 IndexType::Auto => unreachable!("Auto resolved above"),
229 };
230
231 let centroid_offset = HEADER_SIZE as u64;
232 let centroid_len = centroid_bytes.len() as u64;
233 let index_offset_in_ailk = centroid_offset + centroid_len;
234 let index_len = index_bytes.len() as u64;
235 let ailk_total_len = HEADER_SIZE as u64 + centroid_len + index_len + TRAILER_SIZE as u64;
236
237 let header = AilakeHeader {
238 format_version: AILAKE_FORMAT_VERSION,
239 flags,
240 dim: policy.dim,
241 precision: Precision::from(policy.precision),
242 distance_metric: DistanceMetric::from(policy.metric),
243 record_count,
244 centroid_offset,
245 centroid_len,
246 hnsw_offset: index_offset_in_ailk,
247 hnsw_len: index_len,
248 };
249 let trailer = AilakeTrailer {
250 footer_offset: ailk_abs_offset,
251 footer_len: ailk_total_len,
252 format_version: AILAKE_FORMAT_VERSION,
253 flags,
254 };
255
256 let mut buf = BytesMut::with_capacity(ailk_total_len as usize);
257 buf.put_slice(&header.to_bytes());
258 buf.put_slice(¢roid_bytes);
259 buf.put_slice(&index_bytes);
260 buf.put_slice(&trailer.to_bytes());
261 Ok(buf.freeze())
262}
263
264fn parquet_footer_start(buf: &[u8]) -> AilakeResult<usize> {
267 use ailake_core::AilakeError;
268 let len = buf.len();
269 if len < 8 {
270 return Err(AilakeError::Parquet("file too small".into()));
271 }
272 if &buf[len - 4..] != b"PAR1" {
273 return Err(AilakeError::Parquet("missing PAR1 footer magic".into()));
274 }
275 let footer_thrift_len = u32::from_le_bytes(buf[len - 8..len - 4].try_into().unwrap()) as usize;
276 let start = len
277 .checked_sub(8 + footer_thrift_len)
278 .ok_or_else(|| AilakeError::Parquet("footer length overflow".into()))?;
279 Ok(start)
280}
281
282fn encode_centroid(c: &Centroid) -> Vec<u8> {
283 let mut bytes = Vec::with_capacity(c.values.len() * 4 + 4);
284 for &v in &c.values {
285 bytes.extend_from_slice(&v.to_le_bytes());
286 }
287 bytes.extend_from_slice(&c.radius.to_le_bytes());
288 bytes
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_core::{AilakeError, VectorMetric, VectorPrecision};
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Cosine / F16 policy (no PQ, no raw copy) for the named vector column.
    fn make_named_policy(column_name: &str, dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: column_name.to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    fn make_policy(dim: u32) -> VectorStoragePolicy {
        make_named_policy("embedding", dim)
    }

    /// Three-row batch with a single non-nullable `id: Int32` column.
    fn make_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap()
    }

    #[test]
    fn write_ends_with_par1() {
        let batch = make_batch();
        let embs: Vec<Vec<f32>> = (0..3).map(|_| vec![0.1, 0.2, 0.3, 0.4]).collect();

        let writer = AilakeFileWriter::new(make_policy(4));
        let file = writer.write(&batch, &embs).unwrap();

        assert_eq!(&file[file.len() - 4..], b"PAR1");
        assert_eq!(&file[..4], b"PAR1");
        assert!(file.windows(4).any(|w| w == b"AILK"));
    }

    #[test]
    fn write_multi_two_columns() {
        let batch = make_batch();
        let embs: Vec<Vec<f32>> = (0..3).map(|i| vec![i as f32, 0.0, 0.0, 0.0]).collect();
        let ctx_embs: Vec<Vec<f32>> = (0..3).map(|i| vec![0.0, i as f32, 0.0, 0.0]).collect();

        let policy1 = make_policy(4);
        let policy2 = make_named_policy("context_embedding", 4);

        let writer = AilakeFileWriter::new(policy1.clone());
        let file = writer
            .write_multi(
                &batch,
                &[
                    VectorColumnBatch {
                        policy: &policy1,
                        embeddings: &embs,
                    },
                    VectorColumnBatch {
                        policy: &policy2,
                        embeddings: &ctx_embs,
                    },
                ],
            )
            .unwrap();

        assert_eq!(&file[..4], b"PAR1");
        assert_eq!(&file[file.len() - 4..], b"PAR1");
        let ailk_count = file.windows(4).filter(|w| *w == b"AILK").count();
        assert!(
            ailk_count >= 2,
            "expected >= 2 AILK markers, got {ailk_count}"
        );
    }

    #[test]
    fn write_multi_rejects_empty_columns() {
        let writer = AilakeFileWriter::new(make_policy(4));
        let err = writer.write_multi(&make_batch(), &[]).unwrap_err();
        assert!(matches!(err, AilakeError::InvalidArgument(_)));
    }
}