Skip to main content

ailake_file/
writer.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use ailake_core::{AilakeResult, Centroid, RowId, VectorStoragePolicy};
3use ailake_index::{HnswBuilder, HnswConfig, HnswSerializer, IvfPqConfig, IvfPqSerializer};
4use ailake_parquet::ParquetVectorWriter;
5use ailake_vec::compute_centroid_and_radius;
6use arrow_array::RecordBatch;
7use bytes::{BufMut, Bytes, BytesMut};
8
9use crate::footer::{
10    AilakeHeader, AilakeTrailer, DistanceMetric, Precision, AILAKE_FORMAT_VERSION,
11    FLAG_INDEX_IVF_PQ, HEADER_SIZE, TRAILER_SIZE,
12};
13
14/// Which index algorithm to embed in the AILK section.
15#[derive(Debug, Clone)]
16pub enum IndexType {
17    /// HNSW (default). Best recall for in-memory workloads.
18    Hnsw(HnswConfig),
19    /// IVF-PQ. Best for S3: 10-100x smaller index, sequential inverted-list reads.
20    IvfPq(IvfPqConfig),
21    /// Detect hardware at write time and pick the best index automatically.
22    ///
23    /// Chooses IVF-PQ when a GPU or ≥8 CPU cores are available AND the dataset
24    /// has ≥5 000 vectors. Falls back to HNSW otherwise (local/low-power hardware).
25    Auto,
26}
27
28impl Default for IndexType {
29    fn default() -> Self {
30        IndexType::Hnsw(HnswConfig::default())
31    }
32}
33
34/// One vector column to embed in a multi-column write.
35pub struct VectorColumnBatch<'a> {
36    pub policy: &'a VectorStoragePolicy,
37    pub embeddings: &'a [Vec<f32>],
38}
39
40pub struct AilakeFileWriter {
41    policy: VectorStoragePolicy,
42    index_type: IndexType,
43}
44
45impl AilakeFileWriter {
46    pub fn new(policy: VectorStoragePolicy) -> Self {
47        Self {
48            policy,
49            index_type: IndexType::default(),
50        }
51    }
52
53    pub fn with_hnsw_config(mut self, config: HnswConfig) -> Self {
54        self.index_type = IndexType::Hnsw(config);
55        self
56    }
57
58    pub fn with_ivf_pq(mut self, config: IvfPqConfig) -> Self {
59        self.index_type = IndexType::IvfPq(config);
60        self
61    }
62
63    pub fn with_index_type(mut self, index_type: IndexType) -> Self {
64        self.index_type = index_type;
65        self
66    }
67
68    /// Use `IndexType::Auto`: detect GPU / CPU cores at write time and pick the
69    /// best index. Equivalent to `.with_index_type(IndexType::Auto)`.
70    pub fn with_auto_index(mut self) -> Self {
71        self.index_type = IndexType::Auto;
72        self
73    }
74
75    /// Write RecordBatch + embeddings as plain Parquet, with no AILK section.
76    ///
77    /// Used by `TableWriter::write_batch_deferred()` to persist data immediately
78    /// while the HNSW index is built asynchronously in the background.
79    /// The resulting file is a valid Parquet readable by any standard reader,
80    /// but `AilakeFileReader::is_ailake_file()` returns false until the HNSW
81    /// section is appended by the background indexing task.
82    pub fn write_parquet_only(
83        &self,
84        batch: &RecordBatch,
85        embeddings: &[Vec<f32>],
86    ) -> AilakeResult<Bytes> {
87        let parquet_writer = ParquetVectorWriter::new(self.policy.clone());
88        let (bytes, _) = parquet_writer.write_batch(batch, embeddings)?;
89        Ok(bytes)
90    }
91
92    /// Write RecordBatch + embeddings into a single AI-Lake file.
93    ///
94    /// Layout:
95    ///   [PAR1][row groups][AILK header+centroid+HNSW+trailer][Parquet footer][footer_len][PAR1]
96    ///
97    /// Standard Parquet readers find PAR1 at the end, read the footer, skip directly to row
98    /// group offsets. The AILK section sits between row groups and footer and is never touched.
99    /// AI-Lake readers find the AILK section via `ailake.footer_offset` in the Parquet footer KV.
100    pub fn write(&self, batch: &RecordBatch, embeddings: &[Vec<f32>]) -> AilakeResult<Bytes> {
101        let col = VectorColumnBatch {
102            policy: &self.policy,
103            embeddings,
104        };
105        self.write_multi(batch, &[col])
106    }
107
108    /// Write RecordBatch + multiple vector columns into a single AI-Lake file.
109    ///
110    /// Each column gets its own AILK section appended sequentially before the Parquet footer.
111    /// Offsets are recorded in Parquet KV metadata:
112    ///   - Primary (first) column: `ailake.footer_offset`
113    ///   - Additional columns: `ailake.{column_name}.footer_offset`
114    ///
115    /// Readers use the column-specific KV key to locate the right AILK section.
116    pub fn write_multi(
117        &self,
118        batch: &RecordBatch,
119        columns: &[VectorColumnBatch<'_>],
120    ) -> AilakeResult<Bytes> {
121        use ailake_core::AilakeError;
122
123        if columns.is_empty() {
124            return Err(AilakeError::InvalidArgument(
125                "write_multi requires at least one vector column".into(),
126            ));
127        }
128
129        let primary = &columns[0];
130        let parquet_writer = ParquetVectorWriter::new(primary.policy.clone());
131
132        // Pass 1 — write Parquet without KV to measure the data section size.
133        let (parquet_v1, record_count) = parquet_writer.write_batch(batch, primary.embeddings)?;
134        let footer_start = parquet_footer_start(&parquet_v1)?;
135
136        // Build all AILK sections sequentially; track running absolute offset.
137        let mut ailk_sections: Vec<Bytes> = Vec::with_capacity(columns.len());
138        let mut kv_owned: Vec<(String, String)> = Vec::with_capacity(columns.len());
139        let mut current_offset = footer_start as u64;
140
141        for (i, col) in columns.iter().enumerate() {
142            let section = build_ailk_section(
143                col.policy,
144                col.embeddings,
145                record_count,
146                current_offset,
147                &self.index_type,
148            )?;
149            let kv_key = if i == 0 {
150                "ailake.footer_offset".to_string()
151            } else {
152                format!("ailake.{}.footer_offset", col.policy.column_name)
153            };
154            kv_owned.push((kv_key, current_offset.to_string()));
155            current_offset += section.len() as u64;
156            ailk_sections.push(section);
157        }
158
159        // Pass 2 — write Parquet with all AILK offset KVs embedded.
160        let kv_refs: Vec<(&str, &str)> = kv_owned
161            .iter()
162            .map(|(k, v)| (k.as_str(), v.as_str()))
163            .collect();
164        let (parquet_v2, _) =
165            parquet_writer.write_batch_with_kv(batch, primary.embeddings, &kv_refs)?;
166        let footer_start_v2 = parquet_footer_start(&parquet_v2)?;
167
168        // Splice: [PAR1 + row groups] + [all AILK sections] + [Parquet footer + PAR1]
169        let total_ailk: usize = ailk_sections.iter().map(|s| s.len()).sum();
170        let total = footer_start_v2 + total_ailk + (parquet_v2.len() - footer_start_v2);
171        let mut out = BytesMut::with_capacity(total);
172        out.put_slice(&parquet_v2[..footer_start_v2]);
173        for section in ailk_sections {
174            out.put(section);
175        }
176        out.put_slice(&parquet_v2[footer_start_v2..]);
177
178        Ok(out.freeze())
179    }
180}
181
182/// Build a complete AILK section (header + centroid + index + trailer) for one vector column.
183fn build_ailk_section(
184    policy: &VectorStoragePolicy,
185    embeddings: &[Vec<f32>],
186    record_count: u64,
187    ailk_abs_offset: u64,
188    index_type: &IndexType,
189) -> AilakeResult<Bytes> {
190    let centroid: Centroid = compute_centroid_and_radius(embeddings, policy.metric);
191    let centroid_bytes = encode_centroid(&centroid);
192
193    // Resolve Auto to a concrete variant before matching.
194    let resolved: IndexType;
195    let index_type = if matches!(index_type, IndexType::Auto) {
196        let profile = ailake_index::HardwareProfile::detect();
197        resolved = if profile.recommend_ivf_pq(embeddings.len()) {
198            IndexType::IvfPq(ailake_index::IvfPqConfig::for_dataset(
199                policy.dim as usize,
200                embeddings.len(),
201            ))
202        } else {
203            IndexType::Hnsw(ailake_index::HnswConfig::default())
204        };
205        &resolved
206    } else {
207        index_type
208    };
209
210    let (index_bytes, flags) = match index_type {
211        IndexType::Hnsw(hnsw_config) => {
212            let mut builder = HnswBuilder::new(policy.dim, policy.metric, hnsw_config.clone());
213            for (i, v) in embeddings.iter().enumerate() {
214                builder.insert(RowId::new(i as u64), v.clone());
215            }
216            let index = builder.build();
217            (HnswSerializer::to_bytes(&index)?, 0u16)
218        }
219        IndexType::IvfPq(ivf_config) => {
220            let row_ids: Vec<RowId> = (0..embeddings.len() as u64).map(RowId::new).collect();
221            let index = ailake_index::IvfPqIndex::train(
222                &row_ids,
223                embeddings,
224                policy.metric,
225                ivf_config.clone(),
226            )?;
227            (IvfPqSerializer::to_bytes(&index)?, FLAG_INDEX_IVF_PQ)
228        }
229        IndexType::Auto => unreachable!("Auto resolved above"),
230    };
231
232    let centroid_offset = HEADER_SIZE as u64;
233    let centroid_len = centroid_bytes.len() as u64;
234    let index_offset_in_ailk = centroid_offset + centroid_len;
235    let index_len = index_bytes.len() as u64;
236    let ailk_total_len = HEADER_SIZE as u64 + centroid_len + index_len + TRAILER_SIZE as u64;
237
238    let header = AilakeHeader {
239        format_version: AILAKE_FORMAT_VERSION,
240        flags,
241        dim: policy.dim,
242        precision: Precision::from(policy.precision),
243        distance_metric: DistanceMetric::from(policy.metric),
244        record_count,
245        centroid_offset,
246        centroid_len,
247        hnsw_offset: index_offset_in_ailk,
248        hnsw_len: index_len,
249    };
250    let trailer = AilakeTrailer {
251        footer_offset: ailk_abs_offset,
252        footer_len: ailk_total_len,
253        format_version: AILAKE_FORMAT_VERSION,
254        flags,
255    };
256
257    let mut buf = BytesMut::with_capacity(ailk_total_len as usize);
258    buf.put_slice(&header.to_bytes());
259    buf.put_slice(&centroid_bytes);
260    buf.put_slice(&index_bytes);
261    buf.put_slice(&trailer.to_bytes());
262    Ok(buf.freeze())
263}
264
265/// Returns the byte offset in `buf` where the Parquet footer thrift starts.
266/// Layout of buf tail: [...footer_thrift...][footer_len u32 LE][PAR1 4 bytes]
267fn parquet_footer_start(buf: &[u8]) -> AilakeResult<usize> {
268    use ailake_core::AilakeError;
269    let len = buf.len();
270    if len < 8 {
271        return Err(AilakeError::Parquet("file too small".into()));
272    }
273    if &buf[len - 4..] != b"PAR1" {
274        return Err(AilakeError::Parquet("missing PAR1 footer magic".into()));
275    }
276    let footer_thrift_len = u32::from_le_bytes(buf[len - 8..len - 4].try_into().unwrap()) as usize;
277    let start = len
278        .checked_sub(8 + footer_thrift_len)
279        .ok_or_else(|| AilakeError::Parquet("footer length overflow".into()))?;
280    Ok(start)
281}
282
283fn encode_centroid(c: &Centroid) -> Vec<u8> {
284    let mut bytes = Vec::with_capacity(c.values.len() * 4 + 4);
285    for &v in &c.values {
286        bytes.extend_from_slice(&v.to_le_bytes());
287    }
288    bytes.extend_from_slice(&c.radius.to_le_bytes());
289    bytes
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_core::{VectorMetric, VectorPrecision};
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Policy fixture: an F16 cosine column named `embedding` with `dim` dimensions.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Batch fixture: three rows with a single non-nullable Int32 `id` column.
    fn three_row_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids = Int32Array::from(vec![1, 2, 3]);
        RecordBatch::try_new(schema, vec![Arc::new(ids)]).unwrap()
    }

    /// Number of (possibly overlapping) occurrences of the `AILK` magic in `file`.
    fn count_ailk_markers(file: &[u8]) -> usize {
        file.windows(4).filter(|w| *w == b"AILK").count()
    }

    /// A single-column write is wrapped in the Parquet magic at both ends and
    /// contains an AILK marker.
    #[test]
    fn write_ends_with_par1() {
        let batch = three_row_batch();
        let embs: Vec<Vec<f32>> = vec![vec![0.1, 0.2, 0.3, 0.4]; 3];

        let file = AilakeFileWriter::new(make_policy(4))
            .write(&batch, &embs)
            .unwrap();

        assert_eq!(&file[file.len() - 4..], b"PAR1");
        assert_eq!(&file[..4], b"PAR1");
        assert!(file.windows(4).any(|w| w == b"AILK"));
    }

    /// A two-column write keeps the Parquet envelope and embeds one AILK
    /// section per vector column.
    #[test]
    fn write_multi_two_columns() {
        let batch = three_row_batch();

        let embs: Vec<Vec<f32>> = (0..3).map(|i| vec![i as f32, 0.0, 0.0, 0.0]).collect();
        let ctx_embs: Vec<Vec<f32>> = (0..3).map(|i| vec![0.0, i as f32, 0.0, 0.0]).collect();

        let policy1 = make_policy(4);
        // Identical to `policy1` apart from the column name.
        let policy2 = VectorStoragePolicy {
            column_name: "context_embedding".to_string(),
            ..make_policy(4)
        };

        let columns = [
            VectorColumnBatch {
                policy: &policy1,
                embeddings: &embs,
            },
            VectorColumnBatch {
                policy: &policy2,
                embeddings: &ctx_embs,
            },
        ];
        let file = AilakeFileWriter::new(policy1.clone())
            .write_multi(&batch, &columns)
            .unwrap();

        // Valid Parquet envelope
        assert_eq!(&file[..4], b"PAR1");
        assert_eq!(&file[file.len() - 4..], b"PAR1");
        // Two AILK sections — magic appears at least twice
        let ailk_count = count_ailk_markers(&file);
        assert!(
            ailk_count >= 2,
            "expected >= 2 AILK markers, got {ailk_count}"
        );
    }
}
374}