1use ailake_core::{AilakeResult, Centroid, RowId, VectorStoragePolicy};
3use ailake_index::{
4 HnswBuilder, HnswConfig, HnswSerializer, IvfPqCodebook, IvfPqConfig, IvfPqIndex,
5 IvfPqSerializer, RaBitQConfig, RaBitQIndex, RaBitQSerializer,
6};
7use ailake_parquet::ParquetVectorWriter;
8use ailake_vec::compute_centroid_and_radius;
9use arrow_array::RecordBatch;
10use bytes::{BufMut, Bytes, BytesMut};
11
12use crate::footer::{
13 AilakeHeader, AilakeTrailer, DistanceMetric, Precision, AILAKE_FORMAT_VERSION,
14 FLAG_INDEX_IVF_PQ, FLAG_INDEX_RABITQ, HEADER_SIZE, TRAILER_SIZE,
15};
16
17#[derive(Debug, Clone)]
19pub enum IndexType {
20 Hnsw(HnswConfig),
22 IvfPq(IvfPqConfig),
24 Auto,
29 RaBitQ(RaBitQConfig),
34}
35
36impl Default for IndexType {
37 fn default() -> Self {
38 IndexType::Hnsw(HnswConfig::default())
39 }
40}
41
/// One vector column to index, passed to `AilakeFileWriter::write_multi`.
///
/// Both fields are borrowed, so building a batch copies no embeddings.
pub struct VectorColumnBatch<'a> {
    /// Storage policy for this column (dimension, metric, precision, index
    /// tuning). For non-primary columns its `column_name` also names the
    /// `ailake.<column_name>.footer_offset` key-value metadata entry.
    pub policy: &'a VectorStoragePolicy,
    /// The column's embeddings. Row ids in the resulting index are positional
    /// indices into this slice (0, 1, 2, ...).
    pub embeddings: &'a [Vec<f32>],
}
47
/// Writes Parquet files that carry embedded AILK vector-index sections.
pub struct AilakeFileWriter {
    /// Policy of the primary vector column, used by `write` and
    /// `write_parquet_only`.
    policy: VectorStoragePolicy,
    /// Index kind built for every vector column's AILK section.
    index_type: IndexType,
    /// Optional pre-trained IVF-PQ codebook. When present, IVF-PQ sections are
    /// built against it instead of training a fresh codebook per section.
    shared_codebook: Option<std::sync::Arc<IvfPqCodebook>>,
}
54
55impl AilakeFileWriter {
56 pub fn new(policy: VectorStoragePolicy) -> Self {
57 let index_type = if let Some(rb) = &policy.rabitq {
58 IndexType::RaBitQ(RaBitQConfig {
59 seed: rb.seed,
60 keep_raw: rb.keep_raw,
61 })
62 } else {
63 IndexType::default()
64 };
65 Self {
66 policy,
67 index_type,
68 shared_codebook: None,
69 }
70 }
71
72 pub fn with_shared_ivf_codebook(mut self, codebook: std::sync::Arc<IvfPqCodebook>) -> Self {
75 self.shared_codebook = Some(codebook);
76 self
77 }
78
79 pub fn with_hnsw_config(mut self, config: HnswConfig) -> Self {
80 self.index_type = IndexType::Hnsw(config);
81 self
82 }
83
84 pub fn with_ivf_pq(mut self, config: IvfPqConfig) -> Self {
85 self.index_type = IndexType::IvfPq(config);
86 self
87 }
88
89 pub fn with_index_type(mut self, index_type: IndexType) -> Self {
90 self.index_type = index_type;
91 self
92 }
93
94 pub fn with_auto_index(mut self) -> Self {
97 self.index_type = IndexType::Auto;
98 self
99 }
100
101 pub fn write_parquet_only(
109 &self,
110 batch: &RecordBatch,
111 embeddings: &[Vec<f32>],
112 ) -> AilakeResult<Bytes> {
113 let parquet_writer = ParquetVectorWriter::new(self.policy.clone());
114 let (bytes, _) = parquet_writer.write_batch(batch, embeddings)?;
115 Ok(bytes)
116 }
117
118 pub fn write(&self, batch: &RecordBatch, embeddings: &[Vec<f32>]) -> AilakeResult<Bytes> {
127 let col = VectorColumnBatch {
128 policy: &self.policy,
129 embeddings,
130 };
131 self.write_multi(batch, &[col])
132 }
133
134 pub fn write_multi(
143 &self,
144 batch: &RecordBatch,
145 columns: &[VectorColumnBatch<'_>],
146 ) -> AilakeResult<Bytes> {
147 use ailake_core::AilakeError;
148
149 if columns.is_empty() {
150 return Err(AilakeError::InvalidArgument(
151 "write_multi requires at least one vector column".into(),
152 ));
153 }
154
155 let primary = &columns[0];
156 let parquet_writer = ParquetVectorWriter::new(primary.policy.clone());
157
158 let (parquet_v1, record_count) = parquet_writer.write_batch(batch, primary.embeddings)?;
160 let footer_start = parquet_footer_start(&parquet_v1)?;
161
162 let mut ailk_sections: Vec<Bytes> = Vec::with_capacity(columns.len());
164 let mut kv_owned: Vec<(String, String)> = Vec::with_capacity(columns.len());
165 let mut current_offset = footer_start as u64;
166
167 for (i, col) in columns.iter().enumerate() {
168 let section = build_ailk_section(
169 col.policy,
170 col.embeddings,
171 record_count,
172 current_offset,
173 &self.index_type,
174 self.shared_codebook.as_deref(),
175 )?;
176 let kv_key = if i == 0 {
177 "ailake.footer_offset".to_string()
178 } else {
179 format!("ailake.{}.footer_offset", col.policy.column_name)
180 };
181 kv_owned.push((kv_key, current_offset.to_string()));
182 current_offset += section.len() as u64;
183 ailk_sections.push(section);
184 }
185
186 let kv_refs: Vec<(&str, &str)> = kv_owned
188 .iter()
189 .map(|(k, v)| (k.as_str(), v.as_str()))
190 .collect();
191 let (parquet_v2, _) =
192 parquet_writer.write_batch_with_kv(batch, primary.embeddings, &kv_refs)?;
193 let footer_start_v2 = parquet_footer_start(&parquet_v2)?;
194
195 let total_ailk: usize = ailk_sections.iter().map(|s| s.len()).sum();
197 let total = footer_start_v2 + total_ailk + (parquet_v2.len() - footer_start_v2);
198 let mut out = BytesMut::with_capacity(total);
199 out.put_slice(&parquet_v2[..footer_start_v2]);
200 for section in ailk_sections {
201 out.put(section);
202 }
203 out.put_slice(&parquet_v2[footer_start_v2..]);
204
205 Ok(out.freeze())
206 }
207}
208
209fn build_ailk_section(
211 policy: &VectorStoragePolicy,
212 embeddings: &[Vec<f32>],
213 record_count: u64,
214 ailk_abs_offset: u64,
215 index_type: &IndexType,
216 shared_codebook: Option<&IvfPqCodebook>,
217) -> AilakeResult<Bytes> {
218 let norm_storage: Vec<Vec<f32>>;
221 let (embeddings, hnsw_metric) =
222 if policy.pre_normalize && policy.metric == ailake_core::VectorMetric::Cosine {
223 norm_storage = embeddings
224 .iter()
225 .map(|v| ailake_vec::normalize_l2(v))
226 .collect();
227 (
228 norm_storage.as_slice(),
229 ailake_core::VectorMetric::NormalizedCosine,
230 )
231 } else {
232 (embeddings, policy.metric)
233 };
234
235 let centroid: Centroid = compute_centroid_and_radius(embeddings, hnsw_metric);
236 let centroid_bytes = encode_centroid(¢roid);
237
238 let resolved: IndexType;
240 let index_type = if matches!(index_type, IndexType::Auto) {
241 let profile = ailake_index::HardwareProfile::detect();
242 resolved = if profile.recommend_ivf_pq(embeddings.len()) {
243 IndexType::IvfPq(ailake_index::IvfPqConfig::for_dataset(
244 policy.dim as usize,
245 embeddings.len(),
246 ))
247 } else {
248 IndexType::Hnsw(ailake_index::HnswConfig::default())
249 };
250 &resolved
251 } else {
252 index_type
253 };
254
255 let (index_bytes, flags) = match index_type {
256 IndexType::Hnsw(hnsw_config) => {
257 let config = HnswConfig {
259 m: policy.hnsw_m.map(|v| v as usize).unwrap_or(hnsw_config.m),
260 ef_construction: policy
261 .hnsw_ef_construction
262 .map(|v| v as usize)
263 .unwrap_or(hnsw_config.ef_construction),
264 max_elements: hnsw_config.max_elements,
265 };
266 let mut builder = HnswBuilder::new(policy.dim, hnsw_metric, config);
267 for (i, v) in embeddings.iter().enumerate() {
268 builder.insert(RowId::new(i as u64), v.clone());
269 }
270 let index = builder.build();
271 (HnswSerializer::to_bytes(&index)?, 0u16)
272 }
273 IndexType::IvfPq(ivf_config) => {
274 let row_ids: Vec<RowId> = (0..embeddings.len() as u64).map(RowId::new).collect();
275 let index = if let Some(cb) = shared_codebook {
276 IvfPqIndex::build_with_codebook(&row_ids, embeddings, cb)?
277 } else {
278 ailake_index::IvfPqIndex::train(
279 &row_ids,
280 embeddings,
281 policy.metric,
282 ivf_config.clone(),
283 )?
284 };
285 (IvfPqSerializer::to_bytes(&index)?, FLAG_INDEX_IVF_PQ)
286 }
287 IndexType::RaBitQ(rb_config) => {
288 let row_ids: Vec<RowId> = (0..embeddings.len() as u64).map(RowId::new).collect();
289 let index = RaBitQIndex::build(
290 &row_ids,
291 embeddings,
292 hnsw_metric,
293 rb_config.clone(),
294 rb_config.keep_raw,
295 )?;
296 (RaBitQSerializer::to_bytes(&index)?, FLAG_INDEX_RABITQ)
297 }
298 IndexType::Auto => unreachable!("Auto resolved above"),
299 };
300
301 let centroid_offset = HEADER_SIZE as u64;
302 let centroid_len = centroid_bytes.len() as u64;
303 let index_offset_in_ailk = centroid_offset + centroid_len;
304 let index_len = index_bytes.len() as u64;
305 let ailk_total_len = HEADER_SIZE as u64 + centroid_len + index_len + TRAILER_SIZE as u64;
306
307 let header = AilakeHeader {
308 format_version: AILAKE_FORMAT_VERSION,
309 flags,
310 dim: policy.dim,
311 precision: Precision::from(policy.precision),
312 distance_metric: DistanceMetric::from(policy.metric),
313 record_count,
314 centroid_offset,
315 centroid_len,
316 hnsw_offset: index_offset_in_ailk,
317 hnsw_len: index_len,
318 };
319 let trailer = AilakeTrailer {
320 footer_offset: ailk_abs_offset,
321 footer_len: ailk_total_len,
322 format_version: AILAKE_FORMAT_VERSION,
323 flags,
324 };
325
326 let mut buf = BytesMut::with_capacity(ailk_total_len as usize);
327 buf.put_slice(&header.to_bytes());
328 buf.put_slice(¢roid_bytes);
329 buf.put_slice(&index_bytes);
330 buf.put_slice(&trailer.to_bytes());
331 Ok(buf.freeze())
332}
333
334fn parquet_footer_start(buf: &[u8]) -> AilakeResult<usize> {
337 use ailake_core::AilakeError;
338 let len = buf.len();
339 if len < 8 {
340 return Err(AilakeError::Parquet("file too small".into()));
341 }
342 if &buf[len - 4..] != b"PAR1" {
343 return Err(AilakeError::Parquet("missing PAR1 footer magic".into()));
344 }
345 let footer_thrift_len = u32::from_le_bytes(buf[len - 8..len - 4].try_into().unwrap()) as usize;
346 let start = len
347 .checked_sub(8 + footer_thrift_len)
348 .ok_or_else(|| AilakeError::Parquet("footer length overflow".into()))?;
349 Ok(start)
350}
351
352fn encode_centroid(c: &Centroid) -> Vec<u8> {
353 let mut bytes = Vec::with_capacity(c.values.len() * 4 + 4);
354 for &v in &c.values {
355 bytes.extend_from_slice(&v.to_le_bytes());
356 }
357 bytes.extend_from_slice(&c.radius.to_le_bytes());
358 bytes
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_core::{VectorMetric, VectorPrecision};
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Cosine / F16 policy for the `embedding` column with every optional
    /// feature disabled.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            rabitq: None,
        }
    }

    /// Batch with a single non-null `id: Int32` column holding 1, 2, 3.
    fn three_row_batch() -> RecordBatch {
        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
        let ids = Int32Array::from(vec![1, 2, 3]);
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ids)]).unwrap()
    }

    #[test]
    fn write_ends_with_par1() {
        let embs: Vec<Vec<f32>> = vec![vec![0.1, 0.2, 0.3, 0.4]; 3];

        let file = AilakeFileWriter::new(make_policy(4))
            .write(&three_row_batch(), &embs)
            .unwrap();

        assert_eq!(&file[..4], b"PAR1");
        assert_eq!(&file[file.len() - 4..], b"PAR1");
        assert!(file.windows(4).any(|w| w == b"AILK"));
    }

    #[test]
    fn write_multi_two_columns() {
        // Three 4-d vectors whose only non-zero coordinate is `axis`, set to the row index.
        let axis_vectors = |axis: usize| -> Vec<Vec<f32>> {
            (0..3)
                .map(|i| {
                    let mut v = vec![0.0f32; 4];
                    v[axis] = i as f32;
                    v
                })
                .collect()
        };
        let embs = axis_vectors(0);
        let ctx_embs = axis_vectors(1);

        let policy1 = make_policy(4);
        let policy2 = VectorStoragePolicy {
            column_name: "context_embedding".to_string(),
            ..make_policy(4)
        };

        let columns = [
            VectorColumnBatch {
                policy: &policy1,
                embeddings: &embs,
            },
            VectorColumnBatch {
                policy: &policy2,
                embeddings: &ctx_embs,
            },
        ];
        let file = AilakeFileWriter::new(policy1.clone())
            .write_multi(&three_row_batch(), &columns)
            .unwrap();

        assert_eq!(&file[..4], b"PAR1");
        assert_eq!(&file[file.len() - 4..], b"PAR1");
        let ailk_count = file.windows(4).filter(|w| *w == b"AILK").count();
        assert!(
            ailk_count >= 2,
            "expected >= 2 AILK markers, got {ailk_count}"
        );
    }
}