Skip to main content

ailake_file/
writer.rs

1use ailake_core::{AilakeResult, Centroid, RowId, VectorStoragePolicy};
2use ailake_index::{HnswBuilder, HnswConfig, HnswSerializer, IvfPqConfig, IvfPqSerializer};
3use ailake_parquet::ParquetVectorWriter;
4use ailake_vec::compute_centroid_and_radius;
5use arrow_array::RecordBatch;
6use bytes::{BufMut, Bytes, BytesMut};
7
8use crate::footer::{
9    AilakeHeader, AilakeTrailer, DistanceMetric, Precision, AILAKE_FORMAT_VERSION,
10    FLAG_INDEX_IVF_PQ, HEADER_SIZE, TRAILER_SIZE,
11};
12
13/// Which index algorithm to embed in the AILK section.
14#[derive(Debug, Clone)]
15pub enum IndexType {
16    /// HNSW (default). Best recall for in-memory workloads.
17    Hnsw(HnswConfig),
18    /// IVF-PQ. Best for S3: 10-100x smaller index, sequential inverted-list reads.
19    IvfPq(IvfPqConfig),
20    /// Detect hardware at write time and pick the best index automatically.
21    ///
22    /// Chooses IVF-PQ when a GPU or ≥8 CPU cores are available AND the dataset
23    /// has ≥5 000 vectors. Falls back to HNSW otherwise (local/low-power hardware).
24    Auto,
25}
26
27impl Default for IndexType {
28    fn default() -> Self {
29        IndexType::Hnsw(HnswConfig::default())
30    }
31}
32
33/// One vector column to embed in a multi-column write.
34pub struct VectorColumnBatch<'a> {
35    pub policy: &'a VectorStoragePolicy,
36    pub embeddings: &'a [Vec<f32>],
37}
38
39pub struct AilakeFileWriter {
40    policy: VectorStoragePolicy,
41    index_type: IndexType,
42}
43
44impl AilakeFileWriter {
45    pub fn new(policy: VectorStoragePolicy) -> Self {
46        Self {
47            policy,
48            index_type: IndexType::default(),
49        }
50    }
51
52    pub fn with_hnsw_config(mut self, config: HnswConfig) -> Self {
53        self.index_type = IndexType::Hnsw(config);
54        self
55    }
56
57    pub fn with_ivf_pq(mut self, config: IvfPqConfig) -> Self {
58        self.index_type = IndexType::IvfPq(config);
59        self
60    }
61
62    pub fn with_index_type(mut self, index_type: IndexType) -> Self {
63        self.index_type = index_type;
64        self
65    }
66
67    /// Use `IndexType::Auto`: detect GPU / CPU cores at write time and pick the
68    /// best index. Equivalent to `.with_index_type(IndexType::Auto)`.
69    pub fn with_auto_index(mut self) -> Self {
70        self.index_type = IndexType::Auto;
71        self
72    }
73
74    /// Write RecordBatch + embeddings as plain Parquet, with no AILK section.
75    ///
76    /// Used by `TableWriter::write_batch_deferred()` to persist data immediately
77    /// while the HNSW index is built asynchronously in the background.
78    /// The resulting file is a valid Parquet readable by any standard reader,
79    /// but `AilakeFileReader::is_ailake_file()` returns false until the HNSW
80    /// section is appended by the background indexing task.
81    pub fn write_parquet_only(
82        &self,
83        batch: &RecordBatch,
84        embeddings: &[Vec<f32>],
85    ) -> AilakeResult<Bytes> {
86        let parquet_writer = ParquetVectorWriter::new(self.policy.clone());
87        let (bytes, _) = parquet_writer.write_batch(batch, embeddings)?;
88        Ok(bytes)
89    }
90
91    /// Write RecordBatch + embeddings into a single AI-Lake file.
92    ///
93    /// Layout:
94    ///   [PAR1][row groups][AILK header+centroid+HNSW+trailer][Parquet footer][footer_len][PAR1]
95    ///
96    /// Standard Parquet readers find PAR1 at the end, read the footer, skip directly to row
97    /// group offsets. The AILK section sits between row groups and footer and is never touched.
98    /// AI-Lake readers find the AILK section via `ailake.footer_offset` in the Parquet footer KV.
99    pub fn write(&self, batch: &RecordBatch, embeddings: &[Vec<f32>]) -> AilakeResult<Bytes> {
100        let col = VectorColumnBatch {
101            policy: &self.policy,
102            embeddings,
103        };
104        self.write_multi(batch, &[col])
105    }
106
107    /// Write RecordBatch + multiple vector columns into a single AI-Lake file.
108    ///
109    /// Each column gets its own AILK section appended sequentially before the Parquet footer.
110    /// Offsets are recorded in Parquet KV metadata:
111    ///   - Primary (first) column: `ailake.footer_offset`
112    ///   - Additional columns: `ailake.{column_name}.footer_offset`
113    ///
114    /// Readers use the column-specific KV key to locate the right AILK section.
115    pub fn write_multi(
116        &self,
117        batch: &RecordBatch,
118        columns: &[VectorColumnBatch<'_>],
119    ) -> AilakeResult<Bytes> {
120        use ailake_core::AilakeError;
121
122        if columns.is_empty() {
123            return Err(AilakeError::InvalidArgument(
124                "write_multi requires at least one vector column".into(),
125            ));
126        }
127
128        let primary = &columns[0];
129        let parquet_writer = ParquetVectorWriter::new(primary.policy.clone());
130
131        // Pass 1 — write Parquet without KV to measure the data section size.
132        let (parquet_v1, record_count) = parquet_writer.write_batch(batch, primary.embeddings)?;
133        let footer_start = parquet_footer_start(&parquet_v1)?;
134
135        // Build all AILK sections sequentially; track running absolute offset.
136        let mut ailk_sections: Vec<Bytes> = Vec::with_capacity(columns.len());
137        let mut kv_owned: Vec<(String, String)> = Vec::with_capacity(columns.len());
138        let mut current_offset = footer_start as u64;
139
140        for (i, col) in columns.iter().enumerate() {
141            let section = build_ailk_section(
142                col.policy,
143                col.embeddings,
144                record_count,
145                current_offset,
146                &self.index_type,
147            )?;
148            let kv_key = if i == 0 {
149                "ailake.footer_offset".to_string()
150            } else {
151                format!("ailake.{}.footer_offset", col.policy.column_name)
152            };
153            kv_owned.push((kv_key, current_offset.to_string()));
154            current_offset += section.len() as u64;
155            ailk_sections.push(section);
156        }
157
158        // Pass 2 — write Parquet with all AILK offset KVs embedded.
159        let kv_refs: Vec<(&str, &str)> = kv_owned
160            .iter()
161            .map(|(k, v)| (k.as_str(), v.as_str()))
162            .collect();
163        let (parquet_v2, _) =
164            parquet_writer.write_batch_with_kv(batch, primary.embeddings, &kv_refs)?;
165        let footer_start_v2 = parquet_footer_start(&parquet_v2)?;
166
167        // Splice: [PAR1 + row groups] + [all AILK sections] + [Parquet footer + PAR1]
168        let total_ailk: usize = ailk_sections.iter().map(|s| s.len()).sum();
169        let total = footer_start_v2 + total_ailk + (parquet_v2.len() - footer_start_v2);
170        let mut out = BytesMut::with_capacity(total);
171        out.put_slice(&parquet_v2[..footer_start_v2]);
172        for section in ailk_sections {
173            out.put(section);
174        }
175        out.put_slice(&parquet_v2[footer_start_v2..]);
176
177        Ok(out.freeze())
178    }
179}
180
181/// Build a complete AILK section (header + centroid + index + trailer) for one vector column.
182fn build_ailk_section(
183    policy: &VectorStoragePolicy,
184    embeddings: &[Vec<f32>],
185    record_count: u64,
186    ailk_abs_offset: u64,
187    index_type: &IndexType,
188) -> AilakeResult<Bytes> {
189    let centroid: Centroid = compute_centroid_and_radius(embeddings, policy.metric);
190    let centroid_bytes = encode_centroid(&centroid);
191
192    // Resolve Auto to a concrete variant before matching.
193    let resolved: IndexType;
194    let index_type = if matches!(index_type, IndexType::Auto) {
195        let profile = ailake_index::HardwareProfile::detect();
196        resolved = if profile.recommend_ivf_pq(embeddings.len()) {
197            IndexType::IvfPq(ailake_index::IvfPqConfig::for_dataset(
198                policy.dim as usize,
199                embeddings.len(),
200            ))
201        } else {
202            IndexType::Hnsw(ailake_index::HnswConfig::default())
203        };
204        &resolved
205    } else {
206        index_type
207    };
208
209    let (index_bytes, flags) = match index_type {
210        IndexType::Hnsw(hnsw_config) => {
211            let mut builder = HnswBuilder::new(policy.dim, policy.metric, hnsw_config.clone());
212            for (i, v) in embeddings.iter().enumerate() {
213                builder.insert(RowId::new(i as u64), v.clone());
214            }
215            let index = builder.build();
216            (HnswSerializer::to_bytes(&index)?, 0u16)
217        }
218        IndexType::IvfPq(ivf_config) => {
219            let row_ids: Vec<RowId> = (0..embeddings.len() as u64).map(RowId::new).collect();
220            let index = ailake_index::IvfPqIndex::train(
221                &row_ids,
222                embeddings,
223                policy.metric,
224                ivf_config.clone(),
225            )?;
226            (IvfPqSerializer::to_bytes(&index)?, FLAG_INDEX_IVF_PQ)
227        }
228        IndexType::Auto => unreachable!("Auto resolved above"),
229    };
230
231    let centroid_offset = HEADER_SIZE as u64;
232    let centroid_len = centroid_bytes.len() as u64;
233    let index_offset_in_ailk = centroid_offset + centroid_len;
234    let index_len = index_bytes.len() as u64;
235    let ailk_total_len = HEADER_SIZE as u64 + centroid_len + index_len + TRAILER_SIZE as u64;
236
237    let header = AilakeHeader {
238        format_version: AILAKE_FORMAT_VERSION,
239        flags,
240        dim: policy.dim,
241        precision: Precision::from(policy.precision),
242        distance_metric: DistanceMetric::from(policy.metric),
243        record_count,
244        centroid_offset,
245        centroid_len,
246        hnsw_offset: index_offset_in_ailk,
247        hnsw_len: index_len,
248    };
249    let trailer = AilakeTrailer {
250        footer_offset: ailk_abs_offset,
251        footer_len: ailk_total_len,
252        format_version: AILAKE_FORMAT_VERSION,
253        flags,
254    };
255
256    let mut buf = BytesMut::with_capacity(ailk_total_len as usize);
257    buf.put_slice(&header.to_bytes());
258    buf.put_slice(&centroid_bytes);
259    buf.put_slice(&index_bytes);
260    buf.put_slice(&trailer.to_bytes());
261    Ok(buf.freeze())
262}
263
264/// Returns the byte offset in `buf` where the Parquet footer thrift starts.
265/// Layout of buf tail: [...footer_thrift...][footer_len u32 LE][PAR1 4 bytes]
266fn parquet_footer_start(buf: &[u8]) -> AilakeResult<usize> {
267    use ailake_core::AilakeError;
268    let len = buf.len();
269    if len < 8 {
270        return Err(AilakeError::Parquet("file too small".into()));
271    }
272    if &buf[len - 4..] != b"PAR1" {
273        return Err(AilakeError::Parquet("missing PAR1 footer magic".into()));
274    }
275    let footer_thrift_len = u32::from_le_bytes(buf[len - 8..len - 4].try_into().unwrap()) as usize;
276    let start = len
277        .checked_sub(8 + footer_thrift_len)
278        .ok_or_else(|| AilakeError::Parquet("footer length overflow".into()))?;
279    Ok(start)
280}
281
282fn encode_centroid(c: &Centroid) -> Vec<u8> {
283    let mut bytes = Vec::with_capacity(c.values.len() * 4 + 4);
284    for &v in &c.values {
285        bytes.extend_from_slice(&v.to_le_bytes());
286    }
287    bytes.extend_from_slice(&c.radius.to_le_bytes());
288    bytes
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_core::{VectorMetric, VectorPrecision};
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Cosine / F16 policy for the `embedding` column with the given dimension.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Three-row batch with a single non-nullable Int32 `id` column.
    fn three_row_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids = Int32Array::from(vec![1, 2, 3]);
        RecordBatch::try_new(schema, vec![Arc::new(ids)]).unwrap()
    }

    /// Number of (possibly overlapping) occurrences of the AILK magic in `file`.
    fn count_ailk(file: &[u8]) -> usize {
        file.windows(4).filter(|w| *w == b"AILK").count()
    }

    #[test]
    fn write_ends_with_par1() {
        let batch = three_row_batch();
        let embs: Vec<Vec<f32>> = std::iter::repeat(vec![0.1, 0.2, 0.3, 0.4])
            .take(3)
            .collect();

        let file = AilakeFileWriter::new(make_policy(4))
            .write(&batch, &embs)
            .unwrap();

        assert_eq!(&file[..4], b"PAR1");
        assert_eq!(&file[file.len() - 4..], b"PAR1");
        assert!(count_ailk(&file) >= 1);
    }

    #[test]
    fn write_multi_two_columns() {
        let batch = three_row_batch();
        let embs: Vec<Vec<f32>> = (0..3).map(|i| vec![i as f32, 0.0, 0.0, 0.0]).collect();
        let ctx_embs: Vec<Vec<f32>> = (0..3).map(|i| vec![0.0, i as f32, 0.0, 0.0]).collect();

        let primary_policy = make_policy(4);
        let context_policy = VectorStoragePolicy {
            column_name: "context_embedding".to_string(),
            ..make_policy(4)
        };

        let columns = [
            VectorColumnBatch {
                policy: &primary_policy,
                embeddings: &embs,
            },
            VectorColumnBatch {
                policy: &context_policy,
                embeddings: &ctx_embs,
            },
        ];
        let file = AilakeFileWriter::new(primary_policy.clone())
            .write_multi(&batch, &columns)
            .unwrap();

        // Still a valid Parquet envelope.
        assert_eq!(&file[..4], b"PAR1");
        assert_eq!(&file[file.len() - 4..], b"PAR1");
        // One AILK section per column, so the magic shows up at least twice.
        let ailk_count = count_ailk(&file);
        assert!(
            ailk_count >= 2,
            "expected >= 2 AILK markers, got {ailk_count}"
        );
    }
}