Skip to main content

ailake_query/
writer.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::atomic::{AtomicU32, Ordering};
3use std::sync::Arc;
4
5use ailake_catalog::{
6    encode_centroid_b64, make_data_file_entry, make_data_file_entry_indexing,
7    make_multi_column_data_file_entry, new_snapshot_id, CatalogProvider, DataFileEntry,
8    ExtraVectorIndex, IcebergSchemaUpdate, IndexStatus, NewSnapshot, SnapshotId, SnapshotOperation,
9    TableIdent, TableProperties, VectorIndexInfo,
10};
11use ailake_core::{AilakeResult, VectorStoragePolicy};
12use ailake_file::{AilakeFileReader, AilakeFileWriter, IndexType, VectorColumnBatch};
13use ailake_index::IvfPqConfig;
14use ailake_store::Store;
15use ailake_vec::compute_centroid_and_radius;
16use arrow_array::RecordBatch;
17use arrow_schema::SchemaRef;
18use bytes::Bytes;
19use serde_json;
20
/// One vector column for a multi-column write batch.
///
/// A slice of these is handed to `TableWriter::write_batch_multi`, which treats
/// the first element as the primary column and every later element as an extra
/// column with its own index section.
pub struct MultiVectorBatch<'a> {
    /// Storage policy of this column; the writer reads its `column_name`, `dim`
    /// and `metric`.
    pub policy: VectorStoragePolicy,
    /// Borrowed embeddings for this column (presumably one `Vec<f32>` per row of
    /// the accompanying `RecordBatch` — confirm against `AilakeFileWriter`).
    pub embeddings: &'a [Vec<f32>],
}
26
/// Stages vector data files for a single table and publishes them as a snapshot.
///
/// Each `write_batch*` method persists one file through `store` and queues a
/// `DataFileEntry` in `pending_files`; `commit` hands the queued entries to
/// `catalog` as one new snapshot.
pub struct TableWriter {
    /// Catalog used to list files, load table metadata and commit snapshots.
    catalog: Arc<dyn CatalogProvider>,
    /// Object store that receives the written data files.
    store: Arc<dyn Store>,
    /// Policy of the primary vector column (column name, dim, metric, precision).
    policy: VectorStoragePolicy,
    /// Table this writer stages files for.
    table: TableIdent,
    /// Per-writer file counter: starts at 0 in `new` and is incremented once per
    /// written file to derive the `data/part-NNNNN.parquet` path.
    part_counter: Arc<AtomicU32>,
    /// Entries staged by the write methods and drained by `commit`.
    pending_files: Vec<DataFileEntry>,
    /// Parent id recorded on the snapshot created by `commit`, when one was set
    /// through `with_parent_snapshot`.
    parent_snapshot_id: Option<SnapshotId>,
    /// Arrow schema captured from the first write_batch call; used to populate
    /// Iceberg schema fields and schema.name-mapping.default on commit.
    captured_schema: Option<SchemaRef>,
    /// Extra vector column policies from write_batch_multi (columns beyond primary).
    extra_vec_policies: Vec<VectorStoragePolicy>,
}
41
42impl TableWriter {
43    pub fn new(
44        catalog: Arc<dyn CatalogProvider>,
45        store: Arc<dyn Store>,
46        policy: VectorStoragePolicy,
47        table: TableIdent,
48    ) -> Self {
49        Self {
50            catalog,
51            store,
52            policy,
53            table,
54            part_counter: Arc::new(AtomicU32::new(0)),
55            pending_files: Vec::new(),
56            parent_snapshot_id: None,
57            captured_schema: None,
58            extra_vec_policies: Vec::new(),
59        }
60    }
61
62    pub fn with_parent_snapshot(mut self, id: SnapshotId) -> Self {
63        self.parent_snapshot_id = Some(id);
64        self
65    }
66
67    /// Write batch as Parquet-only immediately, build HNSW in background.
68    ///
69    /// Returns after the Parquet file is persisted (~LanceDB write speed).
70    /// A tokio task runs concurrently to build the HNSW index, rewrite the
71    /// file with the AILK section, and update the catalog entry.
72    ///
73    /// During the build window, `SearchSession` serves this shard via flat scan
74    /// (brute-force, exact) instead of HNSW. The transition is automatic once
75    /// the background task commits the updated manifest entry.
76    pub async fn write_batch_deferred(
77        &mut self,
78        batch: &RecordBatch,
79        embeddings: &[Vec<f32>],
80    ) -> AilakeResult<()> {
81        if self.captured_schema.is_none() {
82            self.captured_schema = Some(batch.schema());
83        }
84        let part_num = self.part_counter.fetch_add(1, Ordering::SeqCst);
85        let file_path = format!("data/part-{:05}.parquet", part_num);
86
87        // Fast path: persist Parquet without HNSW.
88        let file_writer = AilakeFileWriter::new(self.policy.clone());
89        let parquet_bytes = file_writer.write_parquet_only(batch, embeddings)?;
90        let file_size = parquet_bytes.len() as u64;
91        self.store.put(&file_path, parquet_bytes).await?;
92
93        // Centroid needed immediately for geometric pruning during the build window.
94        let centroid = compute_centroid_and_radius(embeddings, self.policy.metric);
95        let entry = make_data_file_entry_indexing(
96            &file_path,
97            embeddings.len() as u64,
98            file_size,
99            &centroid,
100            &self.policy.column_name,
101            self.policy.dim,
102        );
103        self.pending_files.push(entry);
104
105        // Spawn background HNSW build (fire-and-forget; errors are logged).
106        let store = self.store.clone();
107        let catalog = self.catalog.clone();
108        let policy = self.policy.clone();
109        let table = self.table.clone();
110        let fp = file_path.clone();
111        tokio::spawn(async move {
112            if let Err(e) = build_and_patch_index(store, catalog, policy, table, fp).await {
113                eprintln!("[ailake] deferred HNSW build failed: {e}");
114            }
115        });
116
117        Ok(())
118    }
119
120    /// Idempotent variant of `write_batch`.
121    ///
122    /// Before any I/O, checks if `batch_id` already appears in the current
123    /// snapshot. If it does, this is a no-op — safe for Airflow/Kestra retries.
124    /// If not found, writes the batch and tags the `DataFileEntry` with `batch_id`
125    /// so future retries can detect it.
126    ///
127    /// `commit()` is likewise a no-op when `pending_files` is empty.
128    pub async fn write_batch_idempotent(
129        &mut self,
130        batch: &RecordBatch,
131        embeddings: &[Vec<f32>],
132        batch_id: &str,
133    ) -> AilakeResult<()> {
134        let existing = self.catalog.list_files(&self.table, None).await?;
135        if existing
136            .iter()
137            .any(|f| f.batch_id.as_deref() == Some(batch_id))
138        {
139            return Ok(());
140        }
141        self.write_batch_with_id(batch, embeddings, Some(batch_id.to_string()))
142            .await
143    }
144
145    /// Write a batch to a new AI-Lake file and stage it for commit.
146    pub async fn write_batch(
147        &mut self,
148        batch: &RecordBatch,
149        embeddings: &[Vec<f32>],
150    ) -> AilakeResult<()> {
151        self.write_batch_with_id(batch, embeddings, None).await
152    }
153
154    async fn write_batch_with_id(
155        &mut self,
156        batch: &RecordBatch,
157        embeddings: &[Vec<f32>],
158        batch_id: Option<String>,
159    ) -> AilakeResult<()> {
160        if self.captured_schema.is_none() {
161            self.captured_schema = Some(batch.schema());
162        }
163        let part_num = self.part_counter.fetch_add(1, Ordering::SeqCst);
164        let file_path = format!("data/part-{:05}.parquet", part_num);
165
166        // Write AI-Lake file
167        let file_writer = AilakeFileWriter::new(self.policy.clone());
168        let file_bytes: Bytes = file_writer.write(batch, embeddings)?;
169        let file_size = file_bytes.len() as u64;
170
171        // Store the file
172        self.store.put(&file_path, file_bytes.clone()).await?;
173
174        // Compute centroid for catalog entry
175        let centroid = compute_centroid_and_radius(embeddings, self.policy.metric);
176
177        // Read back the HNSW offsets from the written file
178        let reader = ailake_file::AilakeFileReader::new(
179            file_bytes,
180            &self.policy.column_name,
181            self.policy.dim,
182        );
183        let header = reader.read_header()?;
184        let ailk_start = reader.ailk_offset()?;
185        let hnsw_abs_offset = ailk_start + header.hnsw_offset;
186        let hnsw_len = header.hnsw_len;
187
188        let mut entry = make_data_file_entry(
189            &file_path,
190            embeddings.len() as u64,
191            file_size,
192            &centroid,
193            VectorIndexInfo {
194                column: &self.policy.column_name,
195                dim: self.policy.dim,
196                hnsw_offset: hnsw_abs_offset,
197                hnsw_len,
198            },
199        );
200        entry.batch_id = batch_id;
201        self.pending_files.push(entry);
202        Ok(())
203    }
204
205    /// Write batch, auto-selecting the index based on detected hardware.
206    ///
207    /// Picks IVF-PQ when a CUDA GPU or ≥8 CPU cores are present AND the batch
208    /// has ≥5 000 vectors. Falls back to HNSW for weaker / local hardware.
209    /// Uses `IvfPqConfig::for_dataset` to scale nlist with dataset size.
210    pub async fn write_batch_auto(
211        &mut self,
212        batch: &RecordBatch,
213        embeddings: &[Vec<f32>],
214    ) -> AilakeResult<()> {
215        let profile = ailake_index::HardwareProfile::detect();
216        if profile.recommend_ivf_pq(embeddings.len()) {
217            let ivf_config =
218                ailake_index::IvfPqConfig::for_dataset(self.policy.dim as usize, embeddings.len());
219            self.write_batch_ivf_pq(batch, embeddings, ivf_config).await
220        } else {
221            self.write_batch(batch, embeddings).await
222        }
223    }
224
225    /// Write batch with IVF-PQ index built synchronously (no background task).
226    ///
227    /// Smaller index than HNSW; better for S3 sequential-scan workloads.
228    pub async fn write_batch_ivf_pq(
229        &mut self,
230        batch: &RecordBatch,
231        embeddings: &[Vec<f32>],
232        ivf_config: IvfPqConfig,
233    ) -> AilakeResult<()> {
234        if self.captured_schema.is_none() {
235            self.captured_schema = Some(batch.schema());
236        }
237        let part_num = self.part_counter.fetch_add(1, Ordering::SeqCst);
238        let file_path = format!("data/part-{:05}.parquet", part_num);
239
240        let file_writer = AilakeFileWriter::new(self.policy.clone())
241            .with_index_type(IndexType::IvfPq(ivf_config));
242        let file_bytes: Bytes = file_writer.write(batch, embeddings)?;
243        let file_size = file_bytes.len() as u64;
244
245        self.store.put(&file_path, file_bytes.clone()).await?;
246
247        let centroid = compute_centroid_and_radius(embeddings, self.policy.metric);
248
249        let reader = ailake_file::AilakeFileReader::new(
250            file_bytes,
251            &self.policy.column_name,
252            self.policy.dim,
253        );
254        let header = reader.read_header()?;
255        let ailk_start = reader.ailk_offset()?;
256        let index_abs_offset = ailk_start + header.hnsw_offset;
257        let index_len = header.hnsw_len;
258
259        let entry = make_data_file_entry(
260            &file_path,
261            embeddings.len() as u64,
262            file_size,
263            &centroid,
264            VectorIndexInfo {
265                column: &self.policy.column_name,
266                dim: self.policy.dim,
267                hnsw_offset: index_abs_offset,
268                hnsw_len: index_len,
269            },
270        );
271        self.pending_files.push(entry);
272        Ok(())
273    }
274
275    /// Write a batch with multiple vector columns into a single AI-Lake file.
276    ///
277    /// The first entry in `columns` is treated as the primary column (used for
278    /// geometric pruning). Additional columns each get their own HNSW section.
279    pub async fn write_batch_multi(
280        &mut self,
281        batch: &RecordBatch,
282        columns: &[MultiVectorBatch<'_>],
283    ) -> AilakeResult<()> {
284        use ailake_core::AilakeError;
285        if self.captured_schema.is_none() {
286            self.captured_schema = Some(batch.schema());
287        }
288        if self.extra_vec_policies.is_empty() && columns.len() > 1 {
289            self.extra_vec_policies = columns[1..].iter().map(|c| c.policy.clone()).collect();
290        }
291
292        if columns.is_empty() {
293            return Err(AilakeError::InvalidArgument(
294                "write_batch_multi requires at least one column".into(),
295            ));
296        }
297
298        let part_num = self.part_counter.fetch_add(1, Ordering::SeqCst);
299        let file_path = format!("data/part-{:05}.parquet", part_num);
300
301        let col_batches: Vec<VectorColumnBatch<'_>> = columns
302            .iter()
303            .map(|c| VectorColumnBatch {
304                policy: &c.policy,
305                embeddings: c.embeddings,
306            })
307            .collect();
308
309        let primary_policy = &columns[0].policy;
310        let file_writer = AilakeFileWriter::new(primary_policy.clone());
311        let file_bytes: Bytes = file_writer.write_multi(batch, &col_batches)?;
312        let file_size = file_bytes.len() as u64;
313
314        self.store.put(&file_path, file_bytes.clone()).await?;
315
316        // Primary centroid for pruning
317        let primary_centroid =
318            compute_centroid_and_radius(columns[0].embeddings, primary_policy.metric);
319
320        // Read primary AILK header for offsets
321        let reader = ailake_file::AilakeFileReader::new(
322            file_bytes.clone(),
323            &primary_policy.column_name,
324            primary_policy.dim,
325        );
326        let primary_ailk_start = reader.ailk_offset()?;
327        let primary_header = {
328            use ailake_file::HEADER_SIZE;
329            let start = primary_ailk_start as usize;
330            let hdr_bytes: &[u8; HEADER_SIZE] = file_bytes[start..start + HEADER_SIZE]
331                .try_into()
332                .map_err(|_| AilakeError::NotAnAilakeFile)?;
333            ailake_file::AilakeHeader::from_bytes(hdr_bytes)?
334        };
335        let primary_hnsw_abs = primary_ailk_start + primary_header.hnsw_offset;
336
337        // Extra column index metadata
338        let mut extra: Vec<ExtraVectorIndex> = Vec::new();
339        for col in columns.iter().skip(1) {
340            let col_ailk_start = reader.ailk_offset_for_column(&col.policy.column_name)?;
341            let col_header = {
342                use ailake_file::HEADER_SIZE;
343                let start = col_ailk_start as usize;
344                let hdr_bytes: &[u8; HEADER_SIZE] = file_bytes[start..start + HEADER_SIZE]
345                    .try_into()
346                    .map_err(|_| AilakeError::NotAnAilakeFile)?;
347                ailake_file::AilakeHeader::from_bytes(hdr_bytes)?
348            };
349            let col_centroid = compute_centroid_and_radius(col.embeddings, col.policy.metric);
350            extra.push(ExtraVectorIndex {
351                column: col.policy.column_name.clone(),
352                dim: col.policy.dim,
353                hnsw_offset: col_ailk_start + col_header.hnsw_offset,
354                hnsw_len: col_header.hnsw_len,
355                centroid_b64: Some(encode_centroid_b64(&col_centroid)),
356                radius: Some(col_centroid.radius),
357            });
358        }
359
360        let entry = make_multi_column_data_file_entry(
361            &file_path,
362            columns[0].embeddings.len() as u64,
363            file_size,
364            &primary_centroid,
365            VectorIndexInfo {
366                column: &primary_policy.column_name,
367                dim: primary_policy.dim,
368                hnsw_offset: primary_hnsw_abs,
369                hnsw_len: primary_header.hnsw_len,
370            },
371            &extra,
372        );
373        self.pending_files.push(entry);
374        Ok(())
375    }
376
377    /// Commit all staged files as a new Iceberg snapshot.
378    ///
379    /// No-op when `pending_files` is empty (e.g., all `write_batch_idempotent`
380    /// calls were skipped because their `batch_id` was already committed).
381    /// Returns the current snapshot id in that case (or 0 if no snapshot exists yet).
382    pub async fn commit(mut self) -> AilakeResult<SnapshotId> {
383        if self.pending_files.is_empty() {
384            let current = self
385                .catalog
386                .load_table(&self.table)
387                .await
388                .ok()
389                .and_then(|m| m.current_snapshot_id)
390                .unwrap_or(0);
391            return Ok(current);
392        }
393        let iceberg_schema = self
394            .captured_schema
395            .as_deref()
396            .map(|s| arrow_schema_to_iceberg_update(s, &self.policy, &self.extra_vec_policies));
397        let snapshot = NewSnapshot {
398            snapshot_id: new_snapshot_id(),
399            parent_snapshot_id: self.parent_snapshot_id,
400            files: std::mem::take(&mut self.pending_files),
401            operation: SnapshotOperation::Append,
402            iceberg_schema,
403        };
404        self.catalog.commit_snapshot(&self.table, snapshot).await
405    }
406
407    /// Create a table if it doesn't exist, then return a writer for it.
408    pub async fn create_or_open(
409        catalog: Arc<dyn CatalogProvider>,
410        store: Arc<dyn Store>,
411        policy: VectorStoragePolicy,
412        table: TableIdent,
413    ) -> AilakeResult<Self> {
414        // Try to load; if not found, create
415        if catalog.load_table(&table).await.is_err() {
416            catalog
417                .create_table(
418                    &table,
419                    &TableProperties {
420                        policy: policy.clone(),
421                        extra: std::collections::HashMap::new(),
422                    },
423                )
424                .await?;
425        }
426        Ok(Self::new(catalog, store, policy, table))
427    }
428}
429
430/// Convert an Arrow schema to an Iceberg schema update for catalog commits.
431///
432/// Top-level field IDs are assigned sequentially (1-based) and match the
433/// `PARQUET:field_id` stamps written by `ParquetVectorWriter`. Nested element
434/// IDs (inside List/Struct/Map) are assigned after all top-level IDs are
435/// pre-reserved, so they never collide with Parquet column field IDs.
436fn arrow_schema_to_iceberg_update(
437    schema: &arrow_schema::Schema,
438    policy: &VectorStoragePolicy,
439    extra_vec_policies: &[VectorStoragePolicy],
440) -> IcebergSchemaUpdate {
441    let bytes_per_dim = policy.precision.bytes_per_element() as u32;
442    let vec_fixed_len = policy.dim * bytes_per_dim;
443
444    // Collect all vector column names that will appear in the final schema.
445    let has_primary_in_batch = schema
446        .fields()
447        .iter()
448        .any(|f| f.name() == &policy.column_name);
449    let vec_cols: Vec<(String, u32)> = {
450        let mut v = Vec::new();
451        if !has_primary_in_batch {
452            v.push((policy.column_name.clone(), vec_fixed_len));
453        }
454        for ep in extra_vec_policies {
455            let ep_fixed_len = ep.dim * ep.precision.bytes_per_element() as u32;
456            if !schema.fields().iter().any(|f| f.name() == &ep.column_name) {
457                v.push((ep.column_name.clone(), ep_fixed_len));
458            }
459        }
460        v
461    };
462
463    // Total top-level columns = batch fields + appended vec columns.
464    let top_level_count = schema.fields().len() + vec_cols.len();
465    // Nested element IDs start after all top-level IDs are pre-reserved.
466    let mut nested_id = top_level_count as i32;
467
468    let mut fields: Vec<serde_json::Value> = Vec::new();
469    let mut name_mapping: Vec<serde_json::Value> = Vec::new();
470
471    for (idx, field) in schema.fields().iter().enumerate() {
472        let field_id = (idx + 1) as i32;
473        let iceberg_type = arrow_type_to_iceberg(field.data_type(), &mut nested_id);
474        fields.push(serde_json::json!({
475            "id": field_id,
476            "name": field.name(),
477            "required": false,
478            "type": iceberg_type,
479        }));
480        name_mapping.push(serde_json::json!({
481            "field-id": field_id,
482            "names": [field.name()],
483        }));
484    }
485
486    // Append vector columns that live outside the RecordBatch schema.
487    for (i, (col_name, fixed_len)) in vec_cols.iter().enumerate() {
488        let field_id = (schema.fields().len() + 1 + i) as i32;
489        fields.push(serde_json::json!({
490            "id": field_id,
491            "name": col_name,
492            "required": false,
493            "type": format!("fixed[{fixed_len}]"),
494        }));
495        name_mapping.push(serde_json::json!({
496            "field-id": field_id,
497            "names": [col_name],
498        }));
499    }
500
501    let last_column_id = nested_id;
502    let name_mapping_json = serde_json::to_string(&name_mapping).unwrap_or_else(|_| "[]".into());
503
504    IcebergSchemaUpdate {
505        fields,
506        last_column_id,
507        name_mapping_json,
508    }
509}
510
511/// Map an Arrow DataType to an Iceberg schema type value (string or JSON object).
512///
513/// `nested_id` is a shared counter for generating unique element/field IDs inside
514/// List, Struct, and Map types. It must start beyond all pre-reserved top-level IDs.
515fn arrow_type_to_iceberg(dt: &arrow_schema::DataType, nested_id: &mut i32) -> serde_json::Value {
516    use arrow_schema::DataType;
517    match dt {
518        DataType::Boolean => serde_json::json!("boolean"),
519        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::UInt8 | DataType::UInt16 => {
520            serde_json::json!("int")
521        }
522        DataType::Int64 | DataType::UInt32 | DataType::UInt64 => serde_json::json!("long"),
523        DataType::Float16 | DataType::Float32 => serde_json::json!("float"),
524        DataType::Float64 => serde_json::json!("double"),
525        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => serde_json::json!("string"),
526        DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
527            serde_json::json!("binary")
528        }
529        DataType::Date32 | DataType::Date64 => serde_json::json!("date"),
530        // Timestamp with timezone → timestamptz; without → timestamp.
531        DataType::Timestamp(_, Some(_)) => serde_json::json!("timestamptz"),
532        DataType::Timestamp(_, None) => serde_json::json!("timestamp"),
533        DataType::Time32(_) | DataType::Time64(_) => serde_json::json!("time"),
534        DataType::FixedSizeBinary(n) => serde_json::json!(format!("fixed[{n}]")),
535        DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => {
536            serde_json::json!(format!("decimal({p}, {s})"))
537        }
538        DataType::List(inner)
539        | DataType::LargeList(inner)
540        | DataType::ListView(inner)
541        | DataType::FixedSizeList(inner, _) => {
542            *nested_id += 1;
543            let element_id = *nested_id;
544            let element_type = arrow_type_to_iceberg(inner.data_type(), nested_id);
545            serde_json::json!({
546                "type": "list",
547                "element-id": element_id,
548                "element": element_type,
549                "element-required": !inner.is_nullable(),
550            })
551        }
552        DataType::Struct(arrow_fields) => {
553            let struct_fields: Vec<serde_json::Value> = arrow_fields
554                .iter()
555                .map(|f| {
556                    *nested_id += 1;
557                    let fid = *nested_id;
558                    let ftype = arrow_type_to_iceberg(f.data_type(), nested_id);
559                    serde_json::json!({
560                        "id": fid,
561                        "name": f.name(),
562                        "required": !f.is_nullable(),
563                        "type": ftype,
564                    })
565                })
566                .collect();
567            serde_json::json!({ "type": "struct", "fields": struct_fields })
568        }
569        DataType::Map(entries, _) => {
570            // Arrow Map is List<Struct<key: K, value: V>>.
571            *nested_id += 1;
572            let key_id = *nested_id;
573            *nested_id += 1;
574            let val_id = *nested_id;
575            if let DataType::Struct(kv_fields) = entries.data_type() {
576                let key_f = kv_fields
577                    .iter()
578                    .find(|f| f.name() == "key" || f.name() == "keys");
579                let val_f = kv_fields
580                    .iter()
581                    .find(|f| f.name() == "value" || f.name() == "values");
582                let key_type = key_f
583                    .map(|f| arrow_type_to_iceberg(f.data_type(), nested_id))
584                    .unwrap_or(serde_json::json!("binary"));
585                let val_type = val_f
586                    .map(|f| arrow_type_to_iceberg(f.data_type(), nested_id))
587                    .unwrap_or(serde_json::json!("binary"));
588                let val_required = val_f.map(|f| !f.is_nullable()).unwrap_or(false);
589                serde_json::json!({
590                    "type": "map",
591                    "key-id": key_id,
592                    "key": key_type,
593                    "value-id": val_id,
594                    "value": val_type,
595                    "value-required": val_required,
596                })
597            } else {
598                serde_json::json!("binary")
599            }
600        }
601        _ => serde_json::json!("binary"),
602    }
603}
604
605/// Background task: reads a Parquet-only shard, builds full AILK file, patches catalog.
606async fn build_and_patch_index(
607    store: Arc<dyn Store>,
608    catalog: Arc<dyn CatalogProvider>,
609    policy: VectorStoragePolicy,
610    table: TableIdent,
611    file_path: String,
612) -> AilakeResult<()> {
613    // Read the Parquet-only bytes already stored.
614    let parquet_bytes = store.get(&file_path).await?;
615    let reader = AilakeFileReader::new(parquet_bytes, &policy.column_name, policy.dim);
616    let (batch, embeddings) = reader.read_parquet()?;
617
618    // Build the full AILK file (Parquet + HNSW) — CPU-intensive; run on blocking pool
619    // so the tokio async threads aren't starved when many shards build concurrently.
620    let full_bytes = tokio::task::spawn_blocking({
621        let policy = policy.clone();
622        move || {
623            let file_writer = AilakeFileWriter::new(policy);
624            file_writer.write(&batch, &embeddings)
625        }
626    })
627    .await
628    .map_err(|e| ailake_core::AilakeError::Store(format!("spawn_blocking panic: {e}")))??;
629
630    // Extract HNSW offsets from the newly written file.
631    let full_reader = AilakeFileReader::new(full_bytes.clone(), &policy.column_name, policy.dim);
632    let header = full_reader.read_header()?;
633    let ailk_start = full_reader.ailk_offset()?;
634    let hnsw_abs_offset = ailk_start + header.hnsw_offset;
635    let hnsw_len = header.hnsw_len;
636
637    // Overwrite the Parquet-only file with the full AILK version.
638    store.put(&file_path, full_bytes).await?;
639
640    // Wait for the initial writer commit to appear (HNSW builds can finish before
641    // the main write loop calls commit_snapshot, so the catalog has no snapshot yet).
642    for _ in 0..120u32 {
643        match catalog.load_table(&table).await {
644            Ok(meta) if meta.current_snapshot_id.is_some() => break,
645            _ => tokio::time::sleep(std::time::Duration::from_millis(500)).await,
646        }
647    }
648
649    // Update the catalog with CAS-like retry to handle concurrent background tasks.
650    // Multiple tasks can race on commit_snapshot(Replace): the last writer wins and
651    // may overwrite a sibling task's Ready status. Retry until we confirm our file
652    // is marked Ready in the current snapshot.
653    for attempt in 0..50u32 {
654        let table_meta = catalog.load_table(&table).await?;
655        let parent_snapshot_id = table_meta.current_snapshot_id;
656        let mut files = catalog.list_files(&table, None).await?;
657
658        // Already marked Ready by a previous successful attempt.
659        if files
660            .iter()
661            .any(|f| f.path == file_path && f.index_status == IndexStatus::Ready)
662        {
663            break;
664        }
665
666        for f in &mut files {
667            if f.path == file_path {
668                f.hnsw_offset = Some(hnsw_abs_offset);
669                f.hnsw_len = Some(hnsw_len);
670                f.index_status = IndexStatus::Ready;
671                break;
672            }
673        }
674        catalog
675            .commit_snapshot(
676                &table,
677                NewSnapshot {
678                    snapshot_id: new_snapshot_id(),
679                    parent_snapshot_id,
680                    files,
681                    operation: SnapshotOperation::Replace,
682                    iceberg_schema: None,
683                },
684            )
685            .await?;
686
687        // Brief yield so sibling tasks can commit, then verify our change survived.
688        tokio::time::sleep(std::time::Duration::from_millis(10 + attempt as u64 * 5)).await;
689
690        let verify = catalog.list_files(&table, None).await?;
691        if verify
692            .iter()
693            .any(|f| f.path == file_path && f.index_status == IndexStatus::Ready)
694        {
695            break;
696        }
697        // Another task overwrote us — retry.
698    }
699
700    eprintln!(
701        "[ailake] deferred HNSW built for {file_path} (offset={hnsw_abs_offset}, len={hnsw_len})"
702    );
703    Ok(())
704}
705
706#[cfg(test)]
707mod tests {
708    use super::*;
709    use ailake_core::{VectorMetric, VectorPrecision};
710    use arrow_schema::{DataType, Field, Schema, TimeUnit};
711
712    fn policy(col: &str, dim: u32) -> VectorStoragePolicy {
713        VectorStoragePolicy {
714            column_name: col.to_string(),
715            dim,
716            metric: VectorMetric::Cosine,
717            precision: VectorPrecision::F16,
718            pq: None,
719            keep_raw_for_reranking: false,
720        }
721    }
722
723    fn update_for(schema: &Schema, pol: &VectorStoragePolicy) -> IcebergSchemaUpdate {
724        arrow_schema_to_iceberg_update(schema, pol, &[])
725    }
726
727    #[test]
728    fn simple_schema_produces_correct_fields() {
729        let schema = Schema::new(vec![
730            Field::new("id", DataType::Int32, false),
731            Field::new("text", DataType::Utf8, false),
732        ]);
733        let pol = policy("embedding", 8);
734        let upd = update_for(&schema, &pol);
735
736        assert_eq!(upd.fields.len(), 3);
737        assert_eq!(upd.fields[0]["id"], 1);
738        assert_eq!(upd.fields[0]["type"], "int");
739        assert_eq!(upd.fields[1]["id"], 2);
740        assert_eq!(upd.fields[1]["type"], "string");
741        assert_eq!(upd.fields[2]["id"], 3);
742        assert_eq!(upd.fields[2]["type"], "fixed[16]"); // dim=8, F16=2 bytes
743
744        let nm: Vec<serde_json::Value> = serde_json::from_str(&upd.name_mapping_json).unwrap();
745        assert_eq!(nm.len(), 3);
746        assert_eq!(nm[2]["field-id"], 3);
747        assert_eq!(nm[2]["names"][0], "embedding");
748        assert_eq!(upd.last_column_id, 3);
749    }
750
751    #[test]
752    fn timestamp_without_tz_maps_to_timestamp_not_timestamptz() {
753        let schema = Schema::new(vec![
754            Field::new(
755                "created_at",
756                DataType::Timestamp(TimeUnit::Microsecond, None),
757                true,
758            ),
759            Field::new(
760                "updated_at",
761                DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
762                true,
763            ),
764        ]);
765        let pol = policy("vec", 4);
766        let upd = update_for(&schema, &pol);
767
768        assert_eq!(upd.fields[0]["type"], "timestamp");
769        assert_eq!(upd.fields[1]["type"], "timestamptz");
770    }
771
772    #[test]
773    fn list_type_produces_iceberg_list_object() {
774        let schema = Schema::new(vec![Field::new(
775            "tags",
776            DataType::List(std::sync::Arc::new(Field::new(
777                "item",
778                DataType::Utf8,
779                true,
780            ))),
781            true,
782        )]);
783        let pol = policy("vec", 4);
784        let upd = update_for(&schema, &pol);
785
786        let t = &upd.fields[0]["type"];
787        assert_eq!(t["type"], "list");
788        assert_eq!(t["element"], "string");
789        // element-id must be > top-level field count (2: tags + vec)
790        assert!(t["element-id"].as_i64().unwrap() > 2);
791    }
792
793    #[test]
794    fn struct_type_produces_nested_fields() {
795        let schema = Schema::new(vec![Field::new(
796            "meta",
797            DataType::Struct(
798                vec![
799                    Field::new("key", DataType::Utf8, false),
800                    Field::new("val", DataType::Int64, false),
801                ]
802                .into(),
803            ),
804            true,
805        )]);
806        let pol = policy("vec", 4);
807        let upd = update_for(&schema, &pol);
808
809        let t = &upd.fields[0]["type"];
810        assert_eq!(t["type"], "struct");
811        let nested = t["fields"].as_array().unwrap();
812        assert_eq!(nested.len(), 2);
813        assert_eq!(nested[0]["name"], "key");
814        assert_eq!(nested[0]["type"], "string");
815        assert_eq!(nested[1]["name"], "val");
816        assert_eq!(nested[1]["type"], "long");
817        // Nested IDs must be > top-level count (2: meta + vec)
818        assert!(nested[0]["id"].as_i64().unwrap() > 2);
819    }
820
821    #[test]
822    fn no_duplicate_vec_column_when_already_in_batch() {
823        // If for some reason the vec column is in the batch schema, don't add it twice.
824        let schema = Schema::new(vec![
825            Field::new("id", DataType::Int32, false),
826            Field::new("embedding", DataType::FixedSizeBinary(16), false),
827        ]);
828        let pol = policy("embedding", 8);
829        let upd = update_for(&schema, &pol);
830
831        assert_eq!(upd.fields.len(), 2, "should not add embedding twice");
832        let names: Vec<&str> = upd
833            .fields
834            .iter()
835            .map(|f| f["name"].as_str().unwrap())
836            .collect();
837        assert_eq!(names.iter().filter(|&&n| n == "embedding").count(), 1);
838    }
839
840    #[test]
841    fn multi_vec_policies_all_appended() {
842        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
843        let primary = policy("embedding", 4);
844        let extra = vec![policy("context_embedding", 4)];
845        let upd = arrow_schema_to_iceberg_update(&schema, &primary, &extra);
846
847        assert_eq!(upd.fields.len(), 3); // id + embedding + context_embedding
848        let names: Vec<&str> = upd
849            .fields
850            .iter()
851            .map(|f| f["name"].as_str().unwrap())
852            .collect();
853        assert!(names.contains(&"embedding"));
854        assert!(names.contains(&"context_embedding"));
855    }
856
857    #[test]
858    fn top_level_field_ids_match_parquet_stamp_sequence() {
859        // Top-level IDs must be 1, 2, ..., N regardless of nested element IDs.
860        let schema = Schema::new(vec![
861            Field::new("id", DataType::Int64, false),
862            Field::new(
863                "tags",
864                DataType::List(std::sync::Arc::new(Field::new(
865                    "item",
866                    DataType::Utf8,
867                    true,
868                ))),
869                true,
870            ),
871        ]);
872        let pol = policy("vec", 4);
873        let upd = update_for(&schema, &pol);
874
875        // Top-level: id=1, tags=2, vec=3
876        assert_eq!(upd.fields[0]["id"], 1);
877        assert_eq!(upd.fields[1]["id"], 2);
878        assert_eq!(upd.fields[2]["id"], 3);
879
880        // Nested element ID must be > 3
881        assert!(upd.fields[1]["type"]["element-id"].as_i64().unwrap() > 3);
882    }
883}