Skip to main content

ailake_file/
writer.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use ailake_core::{AilakeResult, Centroid, RowId, VectorStoragePolicy};
3use ailake_index::{
4    HnswBuilder, HnswConfig, HnswSerializer, IvfPqCodebook, IvfPqConfig, IvfPqIndex,
5    IvfPqSerializer, RaBitQConfig, RaBitQIndex, RaBitQSerializer,
6};
7use ailake_parquet::ParquetVectorWriter;
8use ailake_vec::compute_centroid_and_radius;
9use arrow_array::RecordBatch;
10use bytes::{BufMut, Bytes, BytesMut};
11
12use crate::footer::{
13    AilakeHeader, AilakeTrailer, DistanceMetric, Precision, AILAKE_FORMAT_VERSION,
14    FLAG_INDEX_IVF_PQ, FLAG_INDEX_RABITQ, HEADER_SIZE, TRAILER_SIZE,
15};
16
17/// Which index algorithm to embed in the AILK section.
18#[derive(Debug, Clone)]
19pub enum IndexType {
20    /// HNSW (default). Best recall for in-memory workloads.
21    Hnsw(HnswConfig),
22    /// IVF-PQ. Best for S3: 10-100x smaller index, sequential inverted-list reads.
23    IvfPq(IvfPqConfig),
24    /// Detect hardware at write time and pick the best index automatically.
25    ///
26    /// Chooses IVF-PQ when a GPU or ≥8 CPU cores are available AND the dataset
27    /// has ≥5 000 vectors. Falls back to HNSW otherwise (local/low-power hardware).
28    Auto,
29    /// RaBitQ flat index. Best when storage is the primary constraint:
30    /// 1 bit/dim = 16× smaller than F16. Better recall than naive binary
31    /// quantization via random rotation + unbiased IP estimator.
32    /// Recommended: use with `keep_raw = true` + `rerank_factor ≥ 3` at search time.
33    RaBitQ(RaBitQConfig),
34}
35
36impl Default for IndexType {
37    fn default() -> Self {
38        IndexType::Hnsw(HnswConfig::default())
39    }
40}
41
42/// One vector column to embed in a multi-column write.
43pub struct VectorColumnBatch<'a> {
44    pub policy: &'a VectorStoragePolicy,
45    pub embeddings: &'a [Vec<f32>],
46}
47
48pub struct AilakeFileWriter {
49    policy: VectorStoragePolicy,
50    index_type: IndexType,
51    /// Pre-trained shared codebook. When set, skips k-means for IVF-PQ builds.
52    shared_codebook: Option<std::sync::Arc<IvfPqCodebook>>,
53}
54
55impl AilakeFileWriter {
56    pub fn new(policy: VectorStoragePolicy) -> Self {
57        let index_type = if let Some(rb) = &policy.rabitq {
58            IndexType::RaBitQ(RaBitQConfig {
59                seed: rb.seed,
60                keep_raw: rb.keep_raw,
61            })
62        } else {
63            IndexType::default()
64        };
65        Self {
66            policy,
67            index_type,
68            shared_codebook: None,
69        }
70    }
71
72    /// Use a pre-trained IVF-PQ codebook instead of running k-means.
73    /// Shards built from the same codebook produce comparable ADC distances.
74    pub fn with_shared_ivf_codebook(mut self, codebook: std::sync::Arc<IvfPqCodebook>) -> Self {
75        self.shared_codebook = Some(codebook);
76        self
77    }
78
79    pub fn with_hnsw_config(mut self, config: HnswConfig) -> Self {
80        self.index_type = IndexType::Hnsw(config);
81        self
82    }
83
84    pub fn with_ivf_pq(mut self, config: IvfPqConfig) -> Self {
85        self.index_type = IndexType::IvfPq(config);
86        self
87    }
88
89    pub fn with_index_type(mut self, index_type: IndexType) -> Self {
90        self.index_type = index_type;
91        self
92    }
93
94    /// Use `IndexType::Auto`: detect GPU / CPU cores at write time and pick the
95    /// best index. Equivalent to `.with_index_type(IndexType::Auto)`.
96    pub fn with_auto_index(mut self) -> Self {
97        self.index_type = IndexType::Auto;
98        self
99    }
100
101    /// Write RecordBatch + embeddings as plain Parquet, with no AILK section.
102    ///
103    /// Used by `TableWriter::write_batch_deferred()` to persist data immediately
104    /// while the HNSW index is built asynchronously in the background.
105    /// The resulting file is a valid Parquet readable by any standard reader,
106    /// but `AilakeFileReader::is_ailake_file()` returns false until the HNSW
107    /// section is appended by the background indexing task.
108    pub fn write_parquet_only(
109        &self,
110        batch: &RecordBatch,
111        embeddings: &[Vec<f32>],
112    ) -> AilakeResult<Bytes> {
113        let parquet_writer = ParquetVectorWriter::new(self.policy.clone());
114        let (bytes, _) = parquet_writer.write_batch(batch, embeddings)?;
115        Ok(bytes)
116    }
117
118    /// Write RecordBatch + embeddings into a single AI-Lake file.
119    ///
120    /// Layout:
121    ///   [PAR1][row groups][AILK header+centroid+HNSW+trailer][Parquet footer][footer_len][PAR1]
122    ///
123    /// Standard Parquet readers find PAR1 at the end, read the footer, skip directly to row
124    /// group offsets. The AILK section sits between row groups and footer and is never touched.
125    /// AI-Lake readers find the AILK section via `ailake.footer_offset` in the Parquet footer KV.
126    pub fn write(&self, batch: &RecordBatch, embeddings: &[Vec<f32>]) -> AilakeResult<Bytes> {
127        let col = VectorColumnBatch {
128            policy: &self.policy,
129            embeddings,
130        };
131        self.write_multi(batch, &[col])
132    }
133
134    /// Write RecordBatch + multiple vector columns into a single AI-Lake file.
135    ///
136    /// Each column gets its own AILK section appended sequentially before the Parquet footer.
137    /// Offsets are recorded in Parquet KV metadata:
138    ///   - Primary (first) column: `ailake.footer_offset`
139    ///   - Additional columns: `ailake.{column_name}.footer_offset`
140    ///
141    /// Readers use the column-specific KV key to locate the right AILK section.
142    pub fn write_multi(
143        &self,
144        batch: &RecordBatch,
145        columns: &[VectorColumnBatch<'_>],
146    ) -> AilakeResult<Bytes> {
147        use ailake_core::AilakeError;
148
149        if columns.is_empty() {
150            return Err(AilakeError::InvalidArgument(
151                "write_multi requires at least one vector column".into(),
152            ));
153        }
154
155        let primary = &columns[0];
156        let parquet_writer = ParquetVectorWriter::new(primary.policy.clone());
157
158        // Pass 1 — write Parquet without KV to measure the data section size.
159        let (parquet_v1, record_count) = parquet_writer.write_batch(batch, primary.embeddings)?;
160        let footer_start = parquet_footer_start(&parquet_v1)?;
161
162        // Build all AILK sections sequentially; track running absolute offset.
163        let mut ailk_sections: Vec<Bytes> = Vec::with_capacity(columns.len());
164        let mut kv_owned: Vec<(String, String)> = Vec::with_capacity(columns.len());
165        let mut current_offset = footer_start as u64;
166
167        for (i, col) in columns.iter().enumerate() {
168            let section = build_ailk_section(
169                col.policy,
170                col.embeddings,
171                record_count,
172                current_offset,
173                &self.index_type,
174                self.shared_codebook.as_deref(),
175            )?;
176            let kv_key = if i == 0 {
177                "ailake.footer_offset".to_string()
178            } else {
179                format!("ailake.{}.footer_offset", col.policy.column_name)
180            };
181            kv_owned.push((kv_key, current_offset.to_string()));
182            current_offset += section.len() as u64;
183            ailk_sections.push(section);
184        }
185
186        // Pass 2 — write Parquet with all AILK offset KVs embedded.
187        let kv_refs: Vec<(&str, &str)> = kv_owned
188            .iter()
189            .map(|(k, v)| (k.as_str(), v.as_str()))
190            .collect();
191        let (parquet_v2, _) =
192            parquet_writer.write_batch_with_kv(batch, primary.embeddings, &kv_refs)?;
193        let footer_start_v2 = parquet_footer_start(&parquet_v2)?;
194
195        // Splice: [PAR1 + row groups] + [all AILK sections] + [Parquet footer + PAR1]
196        let total_ailk: usize = ailk_sections.iter().map(|s| s.len()).sum();
197        let total = footer_start_v2 + total_ailk + (parquet_v2.len() - footer_start_v2);
198        let mut out = BytesMut::with_capacity(total);
199        out.put_slice(&parquet_v2[..footer_start_v2]);
200        for section in ailk_sections {
201            out.put(section);
202        }
203        out.put_slice(&parquet_v2[footer_start_v2..]);
204
205        Ok(out.freeze())
206    }
207}
208
209/// Build a complete AILK section (header + centroid + index + trailer) for one vector column.
210fn build_ailk_section(
211    policy: &VectorStoragePolicy,
212    embeddings: &[Vec<f32>],
213    record_count: u64,
214    ailk_abs_offset: u64,
215    index_type: &IndexType,
216    shared_codebook: Option<&IvfPqCodebook>,
217) -> AilakeResult<Bytes> {
218    // Normalize to unit L2 when pre_normalize is set.
219    // Enables the NormalizedCosine fast path: 1-dot(a,b) instead of full cosine.
220    let norm_storage: Vec<Vec<f32>>;
221    let (embeddings, hnsw_metric) =
222        if policy.pre_normalize && policy.metric == ailake_core::VectorMetric::Cosine {
223            norm_storage = embeddings
224                .iter()
225                .map(|v| ailake_vec::normalize_l2(v))
226                .collect();
227            (
228                norm_storage.as_slice(),
229                ailake_core::VectorMetric::NormalizedCosine,
230            )
231        } else {
232            (embeddings, policy.metric)
233        };
234
235    let centroid: Centroid = compute_centroid_and_radius(embeddings, hnsw_metric);
236    let centroid_bytes = encode_centroid(&centroid);
237
238    // Resolve Auto to a concrete variant before matching.
239    let resolved: IndexType;
240    let index_type = if matches!(index_type, IndexType::Auto) {
241        let profile = ailake_index::HardwareProfile::detect();
242        resolved = if profile.recommend_ivf_pq(embeddings.len()) {
243            IndexType::IvfPq(ailake_index::IvfPqConfig::for_dataset(
244                policy.dim as usize,
245                embeddings.len(),
246            ))
247        } else {
248            IndexType::Hnsw(ailake_index::HnswConfig::default())
249        };
250        &resolved
251    } else {
252        index_type
253    };
254
255    let (index_bytes, flags) = match index_type {
256        IndexType::Hnsw(hnsw_config) => {
257            // Policy-level M/ef_construction override the IndexType defaults when set.
258            let config = HnswConfig {
259                m: policy.hnsw_m.map(|v| v as usize).unwrap_or(hnsw_config.m),
260                ef_construction: policy
261                    .hnsw_ef_construction
262                    .map(|v| v as usize)
263                    .unwrap_or(hnsw_config.ef_construction),
264                max_elements: hnsw_config.max_elements,
265            };
266            let mut builder = HnswBuilder::new(policy.dim, hnsw_metric, config);
267            for (i, v) in embeddings.iter().enumerate() {
268                builder.insert(RowId::new(i as u64), v.clone());
269            }
270            let index = builder.build();
271            (HnswSerializer::to_bytes(&index)?, 0u16)
272        }
273        IndexType::IvfPq(ivf_config) => {
274            let row_ids: Vec<RowId> = (0..embeddings.len() as u64).map(RowId::new).collect();
275            let index = if let Some(cb) = shared_codebook {
276                IvfPqIndex::build_with_codebook(&row_ids, embeddings, cb)?
277            } else {
278                ailake_index::IvfPqIndex::train(
279                    &row_ids,
280                    embeddings,
281                    policy.metric,
282                    ivf_config.clone(),
283                )?
284            };
285            (IvfPqSerializer::to_bytes(&index)?, FLAG_INDEX_IVF_PQ)
286        }
287        IndexType::RaBitQ(rb_config) => {
288            let row_ids: Vec<RowId> = (0..embeddings.len() as u64).map(RowId::new).collect();
289            let index = RaBitQIndex::build(
290                &row_ids,
291                embeddings,
292                hnsw_metric,
293                rb_config.clone(),
294                rb_config.keep_raw,
295            )?;
296            (RaBitQSerializer::to_bytes(&index)?, FLAG_INDEX_RABITQ)
297        }
298        IndexType::Auto => unreachable!("Auto resolved above"),
299    };
300
301    let centroid_offset = HEADER_SIZE as u64;
302    let centroid_len = centroid_bytes.len() as u64;
303    let index_offset_in_ailk = centroid_offset + centroid_len;
304    let index_len = index_bytes.len() as u64;
305    let ailk_total_len = HEADER_SIZE as u64 + centroid_len + index_len + TRAILER_SIZE as u64;
306
307    let header = AilakeHeader {
308        format_version: AILAKE_FORMAT_VERSION,
309        flags,
310        dim: policy.dim,
311        precision: Precision::from(policy.precision),
312        distance_metric: DistanceMetric::from(policy.metric),
313        record_count,
314        centroid_offset,
315        centroid_len,
316        hnsw_offset: index_offset_in_ailk,
317        hnsw_len: index_len,
318    };
319    let trailer = AilakeTrailer {
320        footer_offset: ailk_abs_offset,
321        footer_len: ailk_total_len,
322        format_version: AILAKE_FORMAT_VERSION,
323        flags,
324    };
325
326    let mut buf = BytesMut::with_capacity(ailk_total_len as usize);
327    buf.put_slice(&header.to_bytes());
328    buf.put_slice(&centroid_bytes);
329    buf.put_slice(&index_bytes);
330    buf.put_slice(&trailer.to_bytes());
331    Ok(buf.freeze())
332}
333
334/// Returns the byte offset in `buf` where the Parquet footer thrift starts.
335/// Layout of buf tail: [...footer_thrift...][footer_len u32 LE][PAR1 4 bytes]
336fn parquet_footer_start(buf: &[u8]) -> AilakeResult<usize> {
337    use ailake_core::AilakeError;
338    let len = buf.len();
339    if len < 8 {
340        return Err(AilakeError::Parquet("file too small".into()));
341    }
342    if &buf[len - 4..] != b"PAR1" {
343        return Err(AilakeError::Parquet("missing PAR1 footer magic".into()));
344    }
345    let footer_thrift_len = u32::from_le_bytes(buf[len - 8..len - 4].try_into().unwrap()) as usize;
346    let start = len
347        .checked_sub(8 + footer_thrift_len)
348        .ok_or_else(|| AilakeError::Parquet("footer length overflow".into()))?;
349    Ok(start)
350}
351
352fn encode_centroid(c: &Centroid) -> Vec<u8> {
353    let mut bytes = Vec::with_capacity(c.values.len() * 4 + 4);
354    for &v in &c.values {
355        bytes.extend_from_slice(&v.to_le_bytes());
356    }
357    bytes.extend_from_slice(&c.radius.to_le_bytes());
358    bytes
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_core::{VectorMetric, VectorPrecision};
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Cosine / F16 policy for a column named `embedding`, with no optional features.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            rabitq: None,
        }
    }

    /// Three-row batch with a single non-null `id` column.
    fn make_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap()
    }

    /// True when `needle` occurs anywhere in `haystack`.
    fn contains(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|w| w == needle)
    }

    #[test]
    fn write_ends_with_par1() {
        let batch = make_batch();
        let embs: Vec<Vec<f32>> = (0..3).map(|_| vec![0.1, 0.2, 0.3, 0.4]).collect();

        let writer = AilakeFileWriter::new(make_policy(4));
        let file = writer.write(&batch, &embs).unwrap();

        assert_eq!(&file[file.len() - 4..], b"PAR1");
        assert_eq!(&file[..4], b"PAR1");
        assert!(file.windows(4).any(|w| w == b"AILK"));

        // The AILK section is spliced in before the Parquet footer, and the footer
        // carries the KV key pointing at it.
        let footer_start = parquet_footer_start(&file).unwrap();
        let ailk_at = file
            .windows(4)
            .position(|w| w == b"AILK")
            .expect("AILK magic present");
        assert!(ailk_at < footer_start);
        assert!(contains(&file[footer_start..], b"ailake.footer_offset"));
    }

    #[test]
    fn write_multi_two_columns() {
        let batch = make_batch();

        let embs: Vec<Vec<f32>> = (0..3).map(|i| vec![i as f32, 0.0, 0.0, 0.0]).collect();
        let ctx_embs: Vec<Vec<f32>> = (0..3).map(|i| vec![0.0, i as f32, 0.0, 0.0]).collect();

        let policy1 = make_policy(4);
        let policy2 = VectorStoragePolicy {
            column_name: "context_embedding".to_string(),
            ..make_policy(4)
        };

        let writer = AilakeFileWriter::new(policy1.clone());
        let file = writer
            .write_multi(
                &batch,
                &[
                    VectorColumnBatch {
                        policy: &policy1,
                        embeddings: &embs,
                    },
                    VectorColumnBatch {
                        policy: &policy2,
                        embeddings: &ctx_embs,
                    },
                ],
            )
            .unwrap();

        // Valid Parquet envelope
        assert_eq!(&file[..4], b"PAR1");
        assert_eq!(&file[file.len() - 4..], b"PAR1");
        // Two AILK sections — magic appears at least twice
        let ailk_count = file.windows(4).filter(|w| *w == b"AILK").count();
        assert!(
            ailk_count >= 2,
            "expected >= 2 AILK markers, got {ailk_count}"
        );
        // Both offset keys are recorded in the footer KV metadata.
        let footer_start = parquet_footer_start(&file).unwrap();
        let footer = &file[footer_start..];
        assert!(contains(footer, b"ailake.footer_offset"));
        assert!(contains(footer, b"ailake.context_embedding.footer_offset"));
    }

    #[test]
    fn write_multi_requires_a_column() {
        let writer = AilakeFileWriter::new(make_policy(4));
        assert!(writer.write_multi(&make_batch(), &[]).is_err());
    }
}