1use ailake_core::{AilakeResult, Centroid, RowId, VectorStoragePolicy};
3use ailake_index::{HnswBuilder, HnswConfig, HnswSerializer, IvfPqConfig, IvfPqSerializer};
4use ailake_parquet::ParquetVectorWriter;
5use ailake_vec::compute_centroid_and_radius;
6use arrow_array::RecordBatch;
7use bytes::{BufMut, Bytes, BytesMut};
8
9use crate::footer::{
10 AilakeHeader, AilakeTrailer, DistanceMetric, Precision, AILAKE_FORMAT_VERSION,
11 FLAG_INDEX_IVF_PQ, HEADER_SIZE, TRAILER_SIZE,
12};
13
14#[derive(Debug, Clone)]
16pub enum IndexType {
17 Hnsw(HnswConfig),
19 IvfPq(IvfPqConfig),
21 Auto,
26}
27
28impl Default for IndexType {
29 fn default() -> Self {
30 IndexType::Hnsw(HnswConfig::default())
31 }
32}
33
34pub struct VectorColumnBatch<'a> {
36 pub policy: &'a VectorStoragePolicy,
37 pub embeddings: &'a [Vec<f32>],
38}
39
40pub struct AilakeFileWriter {
41 policy: VectorStoragePolicy,
42 index_type: IndexType,
43}
44
45impl AilakeFileWriter {
46 pub fn new(policy: VectorStoragePolicy) -> Self {
47 Self {
48 policy,
49 index_type: IndexType::default(),
50 }
51 }
52
53 pub fn with_hnsw_config(mut self, config: HnswConfig) -> Self {
54 self.index_type = IndexType::Hnsw(config);
55 self
56 }
57
58 pub fn with_ivf_pq(mut self, config: IvfPqConfig) -> Self {
59 self.index_type = IndexType::IvfPq(config);
60 self
61 }
62
63 pub fn with_index_type(mut self, index_type: IndexType) -> Self {
64 self.index_type = index_type;
65 self
66 }
67
68 pub fn with_auto_index(mut self) -> Self {
71 self.index_type = IndexType::Auto;
72 self
73 }
74
75 pub fn write_parquet_only(
83 &self,
84 batch: &RecordBatch,
85 embeddings: &[Vec<f32>],
86 ) -> AilakeResult<Bytes> {
87 let parquet_writer = ParquetVectorWriter::new(self.policy.clone());
88 let (bytes, _) = parquet_writer.write_batch(batch, embeddings)?;
89 Ok(bytes)
90 }
91
92 pub fn write(&self, batch: &RecordBatch, embeddings: &[Vec<f32>]) -> AilakeResult<Bytes> {
101 let col = VectorColumnBatch {
102 policy: &self.policy,
103 embeddings,
104 };
105 self.write_multi(batch, &[col])
106 }
107
108 pub fn write_multi(
117 &self,
118 batch: &RecordBatch,
119 columns: &[VectorColumnBatch<'_>],
120 ) -> AilakeResult<Bytes> {
121 use ailake_core::AilakeError;
122
123 if columns.is_empty() {
124 return Err(AilakeError::InvalidArgument(
125 "write_multi requires at least one vector column".into(),
126 ));
127 }
128
129 let primary = &columns[0];
130 let parquet_writer = ParquetVectorWriter::new(primary.policy.clone());
131
132 let (parquet_v1, record_count) = parquet_writer.write_batch(batch, primary.embeddings)?;
134 let footer_start = parquet_footer_start(&parquet_v1)?;
135
136 let mut ailk_sections: Vec<Bytes> = Vec::with_capacity(columns.len());
138 let mut kv_owned: Vec<(String, String)> = Vec::with_capacity(columns.len());
139 let mut current_offset = footer_start as u64;
140
141 for (i, col) in columns.iter().enumerate() {
142 let section = build_ailk_section(
143 col.policy,
144 col.embeddings,
145 record_count,
146 current_offset,
147 &self.index_type,
148 )?;
149 let kv_key = if i == 0 {
150 "ailake.footer_offset".to_string()
151 } else {
152 format!("ailake.{}.footer_offset", col.policy.column_name)
153 };
154 kv_owned.push((kv_key, current_offset.to_string()));
155 current_offset += section.len() as u64;
156 ailk_sections.push(section);
157 }
158
159 let kv_refs: Vec<(&str, &str)> = kv_owned
161 .iter()
162 .map(|(k, v)| (k.as_str(), v.as_str()))
163 .collect();
164 let (parquet_v2, _) =
165 parquet_writer.write_batch_with_kv(batch, primary.embeddings, &kv_refs)?;
166 let footer_start_v2 = parquet_footer_start(&parquet_v2)?;
167
168 let total_ailk: usize = ailk_sections.iter().map(|s| s.len()).sum();
170 let total = footer_start_v2 + total_ailk + (parquet_v2.len() - footer_start_v2);
171 let mut out = BytesMut::with_capacity(total);
172 out.put_slice(&parquet_v2[..footer_start_v2]);
173 for section in ailk_sections {
174 out.put(section);
175 }
176 out.put_slice(&parquet_v2[footer_start_v2..]);
177
178 Ok(out.freeze())
179 }
180}
181
182fn build_ailk_section(
184 policy: &VectorStoragePolicy,
185 embeddings: &[Vec<f32>],
186 record_count: u64,
187 ailk_abs_offset: u64,
188 index_type: &IndexType,
189) -> AilakeResult<Bytes> {
190 let centroid: Centroid = compute_centroid_and_radius(embeddings, policy.metric);
191 let centroid_bytes = encode_centroid(¢roid);
192
193 let resolved: IndexType;
195 let index_type = if matches!(index_type, IndexType::Auto) {
196 let profile = ailake_index::HardwareProfile::detect();
197 resolved = if profile.recommend_ivf_pq(embeddings.len()) {
198 IndexType::IvfPq(ailake_index::IvfPqConfig::for_dataset(
199 policy.dim as usize,
200 embeddings.len(),
201 ))
202 } else {
203 IndexType::Hnsw(ailake_index::HnswConfig::default())
204 };
205 &resolved
206 } else {
207 index_type
208 };
209
210 let (index_bytes, flags) = match index_type {
211 IndexType::Hnsw(hnsw_config) => {
212 let mut builder = HnswBuilder::new(policy.dim, policy.metric, hnsw_config.clone());
213 for (i, v) in embeddings.iter().enumerate() {
214 builder.insert(RowId::new(i as u64), v.clone());
215 }
216 let index = builder.build();
217 (HnswSerializer::to_bytes(&index)?, 0u16)
218 }
219 IndexType::IvfPq(ivf_config) => {
220 let row_ids: Vec<RowId> = (0..embeddings.len() as u64).map(RowId::new).collect();
221 let index = ailake_index::IvfPqIndex::train(
222 &row_ids,
223 embeddings,
224 policy.metric,
225 ivf_config.clone(),
226 )?;
227 (IvfPqSerializer::to_bytes(&index)?, FLAG_INDEX_IVF_PQ)
228 }
229 IndexType::Auto => unreachable!("Auto resolved above"),
230 };
231
232 let centroid_offset = HEADER_SIZE as u64;
233 let centroid_len = centroid_bytes.len() as u64;
234 let index_offset_in_ailk = centroid_offset + centroid_len;
235 let index_len = index_bytes.len() as u64;
236 let ailk_total_len = HEADER_SIZE as u64 + centroid_len + index_len + TRAILER_SIZE as u64;
237
238 let header = AilakeHeader {
239 format_version: AILAKE_FORMAT_VERSION,
240 flags,
241 dim: policy.dim,
242 precision: Precision::from(policy.precision),
243 distance_metric: DistanceMetric::from(policy.metric),
244 record_count,
245 centroid_offset,
246 centroid_len,
247 hnsw_offset: index_offset_in_ailk,
248 hnsw_len: index_len,
249 };
250 let trailer = AilakeTrailer {
251 footer_offset: ailk_abs_offset,
252 footer_len: ailk_total_len,
253 format_version: AILAKE_FORMAT_VERSION,
254 flags,
255 };
256
257 let mut buf = BytesMut::with_capacity(ailk_total_len as usize);
258 buf.put_slice(&header.to_bytes());
259 buf.put_slice(¢roid_bytes);
260 buf.put_slice(&index_bytes);
261 buf.put_slice(&trailer.to_bytes());
262 Ok(buf.freeze())
263}
264
265fn parquet_footer_start(buf: &[u8]) -> AilakeResult<usize> {
268 use ailake_core::AilakeError;
269 let len = buf.len();
270 if len < 8 {
271 return Err(AilakeError::Parquet("file too small".into()));
272 }
273 if &buf[len - 4..] != b"PAR1" {
274 return Err(AilakeError::Parquet("missing PAR1 footer magic".into()));
275 }
276 let footer_thrift_len = u32::from_le_bytes(buf[len - 8..len - 4].try_into().unwrap()) as usize;
277 let start = len
278 .checked_sub(8 + footer_thrift_len)
279 .ok_or_else(|| AilakeError::Parquet("footer length overflow".into()))?;
280 Ok(start)
281}
282
283fn encode_centroid(c: &Centroid) -> Vec<u8> {
284 let mut bytes = Vec::with_capacity(c.values.len() * 4 + 4);
285 for &v in &c.values {
286 bytes.extend_from_slice(&v.to_le_bytes());
287 }
288 bytes.extend_from_slice(&c.radius.to_le_bytes());
289 bytes
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_core::{AilakeError, VectorMetric, VectorPrecision};
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Cosine / F16 policy for a column called `column_name`.
    fn make_named_policy(column_name: &str, dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: column_name.to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Policy for the default `embedding` column.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        make_named_policy("embedding", dim)
    }

    /// Three-row batch with a single non-nullable `id` column.
    fn id_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap()
    }

    #[test]
    fn write_ends_with_par1() {
        let batch = id_batch();
        let embs: Vec<Vec<f32>> = (0..3).map(|_| vec![0.1, 0.2, 0.3, 0.4]).collect();

        let writer = AilakeFileWriter::new(make_policy(4));
        let file = writer.write(&batch, &embs).unwrap();

        // Still a Parquet file on both ends, with an AILK section somewhere inside.
        assert_eq!(&file[file.len() - 4..], b"PAR1");
        assert_eq!(&file[..4], b"PAR1");
        assert!(file.windows(4).any(|w| w == b"AILK"));
    }

    #[test]
    fn write_multi_two_columns() {
        let batch = id_batch();

        let embs: Vec<Vec<f32>> = (0..3).map(|i| vec![i as f32, 0.0, 0.0, 0.0]).collect();
        let ctx_embs: Vec<Vec<f32>> = (0..3).map(|i| vec![0.0, i as f32, 0.0, 0.0]).collect();

        let policy1 = make_policy(4);
        let policy2 = make_named_policy("context_embedding", 4);

        let writer = AilakeFileWriter::new(policy1.clone());
        let file = writer
            .write_multi(
                &batch,
                &[
                    VectorColumnBatch {
                        policy: &policy1,
                        embeddings: &embs,
                    },
                    VectorColumnBatch {
                        policy: &policy2,
                        embeddings: &ctx_embs,
                    },
                ],
            )
            .unwrap();

        assert_eq!(&file[..4], b"PAR1");
        assert_eq!(&file[file.len() - 4..], b"PAR1");
        // One marker per column at minimum.
        let ailk_count = file.windows(4).filter(|w| *w == b"AILK").count();
        assert!(
            ailk_count >= 2,
            "expected >= 2 AILK markers, got {ailk_count}"
        );
    }

    #[test]
    fn write_multi_rejects_empty_column_list() {
        let writer = AilakeFileWriter::new(make_policy(4));
        let result = writer.write_multi(&id_batch(), &[]);
        assert!(matches!(result, Err(AilakeError::InvalidArgument(_))));
    }
}