Skip to main content

ailake_query/
writer.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::atomic::{AtomicU32, Ordering};
3use std::sync::Arc;
4
5use ailake_catalog::{
6    encode_centroid_b64, make_data_file_entry, make_data_file_entry_indexing,
7    make_multi_column_data_file_entry, new_snapshot_id, CatalogProvider, DataFileEntry,
8    ExtraVectorIndex, IcebergSchemaUpdate, IndexStatus, NewSnapshot, SnapshotId, SnapshotOperation,
9    TableIdent, TableProperties, VectorIndexInfo,
10};
11use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, VectorStoragePolicy};
12use ailake_file::{AilakeFileReader, AilakeFileWriter, IndexType, VectorColumnBatch};
13use ailake_index::{IvfPqCodebook, IvfPqConfig};
14use ailake_store::Store;
15use ailake_vec::compute_centroid_and_radius;
16use arrow_array::RecordBatch;
17use arrow_schema::SchemaRef;
18use bytes::Bytes;
19use serde_json;
20use tracing::{error, info, warn};
21
22/// One vector column for a multi-column write batch.
23pub struct MultiVectorBatch<'a> {
24    pub policy: VectorStoragePolicy,
25    pub embeddings: &'a [Vec<f32>],
26}
27
28pub struct TableWriter {
29    catalog: Arc<dyn CatalogProvider>,
30    store: Arc<dyn Store>,
31    policy: VectorStoragePolicy,
32    table: TableIdent,
33    part_counter: Arc<AtomicU32>,
34    pending_files: Vec<DataFileEntry>,
35    parent_snapshot_id: Option<SnapshotId>,
36    /// Arrow schema captured from the first write_batch call; used to populate
37    /// Iceberg schema fields and schema.name-mapping.default on commit.
38    captured_schema: Option<SchemaRef>,
39    /// Extra vector column policies from write_batch_multi (columns beyond primary).
40    extra_vec_policies: Vec<VectorStoragePolicy>,
41    /// IVF-PQ codebook trained on the first shard and reused for all subsequent shards.
42    /// Ensures cross-shard ADC distances are comparable — no reranking needed.
43    cached_ivf_codebook: Option<Arc<IvfPqCodebook>>,
44    /// Shared codebook cell for deferred IVF-PQ builds. Cloneable Arc so each
45    /// background task can access it; OnceCell guarantees training runs exactly once.
46    deferred_ivf_codebook: Arc<tokio::sync::OnceCell<IvfPqCodebook>>,
47}
48
49impl TableWriter {
50    pub fn new(
51        catalog: Arc<dyn CatalogProvider>,
52        store: Arc<dyn Store>,
53        policy: VectorStoragePolicy,
54        table: TableIdent,
55    ) -> Self {
56        Self {
57            catalog,
58            store,
59            policy,
60            table,
61            part_counter: Arc::new(AtomicU32::new(0)),
62            pending_files: Vec::new(),
63            parent_snapshot_id: None,
64            captured_schema: None,
65            extra_vec_policies: Vec::new(),
66            cached_ivf_codebook: None,
67            deferred_ivf_codebook: Arc::new(tokio::sync::OnceCell::new()),
68        }
69    }
70
71    pub fn with_parent_snapshot(mut self, id: SnapshotId) -> Self {
72        self.parent_snapshot_id = Some(id);
73        self
74    }
75
76    /// Write batch as Parquet-only immediately, build HNSW in background.
77    ///
78    /// Returns after the Parquet file is persisted (~LanceDB write speed).
79    /// A tokio task runs concurrently to build the HNSW index, rewrite the
80    /// file with the AILK section, and update the catalog entry.
81    ///
82    /// During the build window, `SearchSession` serves this shard via flat scan
83    /// (brute-force, exact) instead of HNSW. The transition is automatic once
84    /// the background task commits the updated manifest entry.
85    pub async fn write_batch_deferred(
86        &mut self,
87        batch: &RecordBatch,
88        embeddings: &[Vec<f32>],
89    ) -> AilakeResult<()> {
90        self.validate_embedding_dim(embeddings)?;
91        if self.captured_schema.is_none() {
92            self.captured_schema = Some(batch.schema());
93        }
94        let part_num = self.part_counter.fetch_add(1, Ordering::SeqCst);
95        let file_path = format!("data/part-{:05}.parquet", part_num);
96
97        // Fast path: persist Parquet without HNSW.
98        let file_writer = AilakeFileWriter::new(self.policy.clone());
99        let parquet_bytes = file_writer.write_parquet_only(batch, embeddings)?;
100        let file_size = parquet_bytes.len() as u64;
101        self.store.put(&file_path, parquet_bytes).await?;
102
103        // Centroid needed immediately for geometric pruning during the build window.
104        let centroid = compute_centroid_and_radius(embeddings, self.policy.metric);
105        let mut entry = make_data_file_entry_indexing(
106            &file_path,
107            embeddings.len() as u64,
108            file_size,
109            &centroid,
110            &self.policy.column_name,
111            self.policy.dim,
112        );
113        entry.embedding_model = self
114            .policy
115            .embedding_model
116            .as_ref()
117            .map(|m| m.to_property_value());
118        self.pending_files.push(entry);
119
120        // Spawn background HNSW build (fire-and-forget; errors are logged).
121        let store = self.store.clone();
122        let catalog = self.catalog.clone();
123        let policy = self.policy.clone();
124        let table = self.table.clone();
125        let fp = file_path.clone();
126        tokio::spawn(async move {
127            if let Err(e) = build_and_patch_index(store, catalog, policy, table, fp).await {
128                error!(
129                    "ailake: deferred HNSW build failed — file is indexed as Parquet-only until \
130                     next compaction rebuilds the index: {}",
131                    e
132                );
133            }
134        });
135
136        Ok(())
137    }
138
139    /// Write batch as Parquet-only immediately; train IVF-PQ index in background.
140    ///
141    /// The first shard trains the shared codebook (k-means). All subsequent shards
142    /// reuse it via `OnceCell` — build is O(n) assign+encode, not O(n×k) k-means.
143    /// Returns after Parquet is persisted. Index transitions Indexing → Ready async.
144    pub async fn write_batch_ivf_pq_deferred(
145        &mut self,
146        batch: &RecordBatch,
147        embeddings: &[Vec<f32>],
148        ivf_config: IvfPqConfig,
149    ) -> AilakeResult<()> {
150        if self.captured_schema.is_none() {
151            self.captured_schema = Some(batch.schema());
152        }
153        let part_num = self.part_counter.fetch_add(1, Ordering::SeqCst);
154        let file_path = format!("data/part-{:05}.parquet", part_num);
155
156        let file_writer = AilakeFileWriter::new(self.policy.clone());
157        let parquet_bytes = file_writer.write_parquet_only(batch, embeddings)?;
158        let file_size = parquet_bytes.len() as u64;
159        self.store.put(&file_path, parquet_bytes).await?;
160
161        let centroid = compute_centroid_and_radius(embeddings, self.policy.metric);
162        let mut entry = make_data_file_entry_indexing(
163            &file_path,
164            embeddings.len() as u64,
165            file_size,
166            &centroid,
167            &self.policy.column_name,
168            self.policy.dim,
169        );
170        entry.embedding_model = self
171            .policy
172            .embedding_model
173            .as_ref()
174            .map(|m| m.to_property_value());
175        self.pending_files.push(entry);
176
177        let store = self.store.clone();
178        let catalog = self.catalog.clone();
179        let policy = self.policy.clone();
180        let table = self.table.clone();
181        let fp = file_path.clone();
182        let codebook_cell = self.deferred_ivf_codebook.clone();
183        tokio::spawn(async move {
184            if let Err(e) = build_ivf_pq_and_patch_index(
185                store,
186                catalog,
187                policy,
188                table,
189                fp,
190                ivf_config,
191                codebook_cell,
192            )
193            .await
194            {
195                error!(
196                    "ailake: deferred IVF-PQ build failed — file is indexed as Parquet-only until \
197                     next compaction rebuilds the index: {}",
198                    e
199                );
200            }
201        });
202
203        Ok(())
204    }
205
206    /// Idempotent variant of `write_batch`.
207    ///
208    /// Before any I/O, checks if `batch_id` already appears in the current
209    /// snapshot. If it does, this is a no-op — safe for Airflow/Kestra retries.
210    /// If not found, writes the batch and tags the `DataFileEntry` with `batch_id`
211    /// so future retries can detect it.
212    ///
213    /// `commit()` is likewise a no-op when `pending_files` is empty.
214    pub async fn write_batch_idempotent(
215        &mut self,
216        batch: &RecordBatch,
217        embeddings: &[Vec<f32>],
218        batch_id: &str,
219    ) -> AilakeResult<()> {
220        let existing = self.catalog.list_files(&self.table, None).await?;
221        if existing
222            .iter()
223            .any(|f| f.batch_id.as_deref() == Some(batch_id))
224        {
225            return Ok(());
226        }
227        self.write_batch_with_id(batch, embeddings, Some(batch_id.to_string()))
228            .await
229    }
230
231    /// Write a batch to a new AI-Lake file and stage it for commit.
232    /// Validates that provided embeddings match the table's configured dimension.
233    /// Returns `ModelMismatch` error when dim differs — prevents silently mixing
234    /// incompatible vectors (same error type used across write paths for consistency).
235    fn validate_embedding_dim(&self, embeddings: &[Vec<f32>]) -> AilakeResult<()> {
236        if let Some(first) = embeddings.first() {
237            let actual = first.len() as u32;
238            if actual != self.policy.dim {
239                let table_model = self
240                    .policy
241                    .embedding_model
242                    .as_ref()
243                    .map(|m| m.to_property_value())
244                    .unwrap_or_else(|| format!("dim={}", self.policy.dim));
245                return Err(AilakeError::ModelMismatch {
246                    table_model,
247                    table_dim: self.policy.dim,
248                    batch_model: format!("dim={}", actual),
249                    batch_dim: actual,
250                });
251            }
252        }
253        Ok(())
254    }
255
256    pub async fn write_batch(
257        &mut self,
258        batch: &RecordBatch,
259        embeddings: &[Vec<f32>],
260    ) -> AilakeResult<()> {
261        self.write_batch_with_id(batch, embeddings, None).await
262    }
263
264    async fn write_batch_with_id(
265        &mut self,
266        batch: &RecordBatch,
267        embeddings: &[Vec<f32>],
268        batch_id: Option<String>,
269    ) -> AilakeResult<()> {
270        self.validate_embedding_dim(embeddings)?;
271        if self.captured_schema.is_none() {
272            self.captured_schema = Some(batch.schema());
273        }
274        let part_num = self.part_counter.fetch_add(1, Ordering::SeqCst);
275        let file_path = format!("data/part-{:05}.parquet", part_num);
276
277        // Write AI-Lake file
278        let file_writer = AilakeFileWriter::new(self.policy.clone());
279        let file_bytes: Bytes = file_writer.write(batch, embeddings)?;
280        let file_size = file_bytes.len() as u64;
281
282        // Store the file
283        self.store.put(&file_path, file_bytes.clone()).await?;
284
285        // Compute centroid for catalog entry
286        let centroid = compute_centroid_and_radius(embeddings, self.policy.metric);
287
288        // Read back the HNSW offsets from the written file
289        let reader = ailake_file::AilakeFileReader::new(
290            file_bytes,
291            &self.policy.column_name,
292            self.policy.dim,
293        );
294        let header = reader.read_header()?;
295        let ailk_start = reader.ailk_offset()?;
296        let hnsw_abs_offset = ailk_start + header.hnsw_offset;
297        let hnsw_len = header.hnsw_len;
298
299        let mut entry = make_data_file_entry(
300            &file_path,
301            embeddings.len() as u64,
302            file_size,
303            &centroid,
304            VectorIndexInfo {
305                column: &self.policy.column_name,
306                dim: self.policy.dim,
307                hnsw_offset: hnsw_abs_offset,
308                hnsw_len,
309            },
310        );
311        entry.batch_id = batch_id;
312        entry.embedding_model = self
313            .policy
314            .embedding_model
315            .as_ref()
316            .map(|m| m.to_property_value());
317        self.pending_files.push(entry);
318        Ok(())
319    }
320
321    /// Write batch, auto-selecting the index based on detected hardware.
322    ///
323    /// Picks IVF-PQ when a CUDA GPU or ≥8 CPU cores are present AND the batch
324    /// has ≥5 000 vectors. Falls back to HNSW for weaker / local hardware.
325    /// Uses `IvfPqConfig::for_dataset` to scale nlist with dataset size.
326    pub async fn write_batch_auto(
327        &mut self,
328        batch: &RecordBatch,
329        embeddings: &[Vec<f32>],
330    ) -> AilakeResult<()> {
331        let profile = ailake_index::HardwareProfile::detect();
332        if profile.recommend_ivf_pq(embeddings.len()) {
333            let mut ivf_config =
334                ailake_index::IvfPqConfig::for_dataset(self.policy.dim as usize, embeddings.len());
335            if self.policy.ivf_residual {
336                ivf_config = ivf_config.with_residual();
337            }
338            self.write_batch_ivf_pq(batch, embeddings, ivf_config).await
339        } else {
340            self.write_batch(batch, embeddings).await
341        }
342    }
343
344    /// Write batch, auto-selecting the index based on detected hardware — deferred variant.
345    ///
346    /// Same hardware detection as `write_batch_auto`: picks IVF-PQ when a CUDA GPU or
347    /// ≥8 CPU cores are present AND the batch has ≥5 000 vectors; falls back to HNSW.
348    ///
349    /// Unlike `write_batch_auto`, the index is built in a background tokio task:
350    /// - Parquet is persisted immediately (~200k vec/s, same as write_parquet_only).
351    /// - HNSW or IVF-PQ index built asynchronously; shard served via flat scan meanwhile.
352    ///
353    /// Use this when ingest throughput matters more than immediate searchability.
354    pub async fn write_batch_auto_deferred(
355        &mut self,
356        batch: &RecordBatch,
357        embeddings: &[Vec<f32>],
358    ) -> AilakeResult<()> {
359        let profile = ailake_index::HardwareProfile::detect();
360        if profile.recommend_ivf_pq(embeddings.len()) {
361            let mut ivf_config =
362                ailake_index::IvfPqConfig::for_dataset(self.policy.dim as usize, embeddings.len());
363            if self.policy.ivf_residual {
364                ivf_config = ivf_config.with_residual();
365            }
366            self.write_batch_ivf_pq_deferred(batch, embeddings, ivf_config)
367                .await
368        } else {
369            self.write_batch_deferred(batch, embeddings).await
370        }
371    }
372
373    /// Write batch with IVF-PQ index built synchronously (no background task).
374    ///
375    /// Smaller index than HNSW; better for S3 sequential-scan workloads.
376    pub async fn write_batch_ivf_pq(
377        &mut self,
378        batch: &RecordBatch,
379        embeddings: &[Vec<f32>],
380        ivf_config: IvfPqConfig,
381    ) -> AilakeResult<()> {
382        if self.captured_schema.is_none() {
383            self.captured_schema = Some(batch.schema());
384        }
385        let part_num = self.part_counter.fetch_add(1, Ordering::SeqCst);
386        let file_path = format!("data/part-{:05}.parquet", part_num);
387
388        // Train codebook once on the first shard; all subsequent shards reuse it.
389        // This makes cross-shard ADC distances comparable, eliminating the need
390        // for exact reranking during multi-shard search.
391        if self.cached_ivf_codebook.is_none() {
392            let codebook = tokio::task::spawn_blocking({
393                let embeddings = embeddings.to_vec();
394                let metric = self.policy.metric;
395                let config = ivf_config.clone();
396                move || ailake_index::IvfPqIndex::train_codebook(&embeddings, metric, &config)
397            })
398            .await
399            .map_err(|e| ailake_core::AilakeError::Store(format!("spawn_blocking panic: {e}")))??;
400            self.cached_ivf_codebook = Some(Arc::new(codebook));
401        }
402        // SAFETY: set to Some in the block above (either pre-existing or just trained).
403        let codebook = self
404            .cached_ivf_codebook
405            .as_ref()
406            .expect("IVF-PQ codebook must be Some after training block")
407            .clone();
408
409        let file_writer = AilakeFileWriter::new(self.policy.clone())
410            .with_index_type(IndexType::IvfPq(ivf_config))
411            .with_shared_ivf_codebook(codebook);
412        let file_bytes: Bytes = file_writer.write(batch, embeddings)?;
413        let file_size = file_bytes.len() as u64;
414
415        self.store.put(&file_path, file_bytes.clone()).await?;
416
417        let centroid = compute_centroid_and_radius(embeddings, self.policy.metric);
418
419        let reader = ailake_file::AilakeFileReader::new(
420            file_bytes,
421            &self.policy.column_name,
422            self.policy.dim,
423        );
424        let header = reader.read_header()?;
425        let ailk_start = reader.ailk_offset()?;
426        let index_abs_offset = ailk_start + header.hnsw_offset;
427        let index_len = header.hnsw_len;
428
429        let mut entry = make_data_file_entry(
430            &file_path,
431            embeddings.len() as u64,
432            file_size,
433            &centroid,
434            VectorIndexInfo {
435                column: &self.policy.column_name,
436                dim: self.policy.dim,
437                hnsw_offset: index_abs_offset,
438                hnsw_len: index_len,
439            },
440        );
441        entry.embedding_model = self
442            .policy
443            .embedding_model
444            .as_ref()
445            .map(|m| m.to_property_value());
446        self.pending_files.push(entry);
447        Ok(())
448    }
449
450    /// Write a batch with multiple vector columns into a single AI-Lake file.
451    ///
452    /// The first entry in `columns` is treated as the primary column (used for
453    /// geometric pruning). Additional columns each get their own HNSW section.
454    pub async fn write_batch_multi(
455        &mut self,
456        batch: &RecordBatch,
457        columns: &[MultiVectorBatch<'_>],
458    ) -> AilakeResult<()> {
459        use ailake_core::AilakeError;
460        if self.captured_schema.is_none() {
461            self.captured_schema = Some(batch.schema());
462        }
463        if self.extra_vec_policies.is_empty() && columns.len() > 1 {
464            self.extra_vec_policies = columns[1..].iter().map(|c| c.policy.clone()).collect();
465        }
466
467        if columns.is_empty() {
468            return Err(AilakeError::InvalidArgument(
469                "write_batch_multi requires at least one column".into(),
470            ));
471        }
472
473        let part_num = self.part_counter.fetch_add(1, Ordering::SeqCst);
474        let file_path = format!("data/part-{:05}.parquet", part_num);
475
476        let col_batches: Vec<VectorColumnBatch<'_>> = columns
477            .iter()
478            .map(|c| VectorColumnBatch {
479                policy: &c.policy,
480                embeddings: c.embeddings,
481            })
482            .collect();
483
484        let primary_policy = &columns[0].policy;
485        let file_writer = AilakeFileWriter::new(primary_policy.clone());
486        let file_bytes: Bytes = file_writer.write_multi(batch, &col_batches)?;
487        let file_size = file_bytes.len() as u64;
488
489        self.store.put(&file_path, file_bytes.clone()).await?;
490
491        // Primary centroid for pruning
492        let primary_centroid =
493            compute_centroid_and_radius(columns[0].embeddings, primary_policy.metric);
494
495        // Read primary AILK header for offsets
496        let reader = ailake_file::AilakeFileReader::new(
497            file_bytes.clone(),
498            &primary_policy.column_name,
499            primary_policy.dim,
500        );
501        let primary_ailk_start = reader.ailk_offset()?;
502        let primary_header = {
503            use ailake_file::HEADER_SIZE;
504            let start = primary_ailk_start as usize;
505            let hdr_bytes: &[u8; HEADER_SIZE] = file_bytes[start..start + HEADER_SIZE]
506                .try_into()
507                .map_err(|_| AilakeError::NotAnAilakeFile)?;
508            ailake_file::AilakeHeader::from_bytes(hdr_bytes)?
509        };
510        let primary_hnsw_abs = primary_ailk_start + primary_header.hnsw_offset;
511
512        // Extra column index metadata
513        let mut extra: Vec<ExtraVectorIndex> = Vec::new();
514        for col in columns.iter().skip(1) {
515            let col_ailk_start = reader.ailk_offset_for_column(&col.policy.column_name)?;
516            let col_header = {
517                use ailake_file::HEADER_SIZE;
518                let start = col_ailk_start as usize;
519                let hdr_bytes: &[u8; HEADER_SIZE] = file_bytes[start..start + HEADER_SIZE]
520                    .try_into()
521                    .map_err(|_| AilakeError::NotAnAilakeFile)?;
522                ailake_file::AilakeHeader::from_bytes(hdr_bytes)?
523            };
524            let col_centroid = compute_centroid_and_radius(col.embeddings, col.policy.metric);
525            extra.push(ExtraVectorIndex {
526                column: col.policy.column_name.clone(),
527                dim: col.policy.dim,
528                hnsw_offset: col_ailk_start + col_header.hnsw_offset,
529                hnsw_len: col_header.hnsw_len,
530                centroid_b64: Some(encode_centroid_b64(&col_centroid)),
531                radius: Some(col_centroid.radius),
532            });
533        }
534
535        let mut entry = make_multi_column_data_file_entry(
536            &file_path,
537            columns[0].embeddings.len() as u64,
538            file_size,
539            &primary_centroid,
540            VectorIndexInfo {
541                column: &primary_policy.column_name,
542                dim: primary_policy.dim,
543                hnsw_offset: primary_hnsw_abs,
544                hnsw_len: primary_header.hnsw_len,
545            },
546            &extra,
547        );
548        entry.embedding_model = self
549            .policy
550            .embedding_model
551            .as_ref()
552            .map(|m| m.to_property_value());
553        self.pending_files.push(entry);
554        Ok(())
555    }
556
557    /// Commit all staged files as a new Iceberg snapshot.
558    ///
559    /// No-op when `pending_files` is empty (e.g., all `write_batch_idempotent`
560    /// calls were skipped because their `batch_id` was already committed).
561    /// Returns the current snapshot id in that case (or 0 if no snapshot exists yet).
562    pub async fn commit(mut self) -> AilakeResult<SnapshotId> {
563        if self.pending_files.is_empty() {
564            let current = self
565                .catalog
566                .load_table(&self.table)
567                .await
568                .ok()
569                .and_then(|m| m.current_snapshot_id)
570                .unwrap_or(0);
571            return Ok(current);
572        }
573        let iceberg_schema = self
574            .captured_schema
575            .as_deref()
576            .map(|s| arrow_schema_to_iceberg_update(s, &self.policy, &self.extra_vec_policies));
577        let snapshot = NewSnapshot {
578            snapshot_id: new_snapshot_id(),
579            parent_snapshot_id: self.parent_snapshot_id,
580            files: std::mem::take(&mut self.pending_files),
581            operation: SnapshotOperation::Append,
582            iceberg_schema,
583        };
584        self.catalog.commit_snapshot(&self.table, snapshot).await
585    }
586
587    /// Create a table if it doesn't exist, then return a writer for it.
588    pub async fn create_or_open(
589        catalog: Arc<dyn CatalogProvider>,
590        store: Arc<dyn Store>,
591        policy: VectorStoragePolicy,
592        table: TableIdent,
593    ) -> AilakeResult<Self> {
594        // Try to load; if not found, create
595        match catalog.load_table(&table).await {
596            Ok(existing_meta) => {
597                // Warn when writing with a different model name into an existing table.
598                // Dim mismatch is a hard error caught at write_batch time; name divergence
599                // is softer — same dim, different model (e.g. fine-tune vs base) — warn only.
600                if let Some(incoming) = &policy.embedding_model {
601                    if let Some(stored_val) = existing_meta
602                        .properties
603                        .get(EmbeddingModelInfo::property_key())
604                    {
605                        let stored = EmbeddingModelInfo::from_property_value(stored_val);
606                        if stored.name != incoming.name {
607                            warn!(
608                                "ailake: embedding model name changed: table has '{}', writing with '{}' \
609                                 (dim={}). Vectors may be incompatible for similarity search.",
610                                stored.name, incoming.name, policy.dim
611                            );
612                        }
613                    }
614                }
615            }
616            Err(_) => {
617                catalog
618                    .create_table(
619                        &table,
620                        &TableProperties {
621                            policy: policy.clone(),
622                            extra: std::collections::HashMap::new(),
623                        },
624                    )
625                    .await?;
626            }
627        }
628        Ok(Self::new(catalog, store, policy, table))
629    }
630}
631
632/// Convert an Arrow schema to an Iceberg schema update for catalog commits.
633///
634/// Top-level field IDs are assigned sequentially (1-based) and match the
635/// `PARQUET:field_id` stamps written by `ParquetVectorWriter`. Nested element
636/// IDs (inside List/Struct/Map) are assigned after all top-level IDs are
637/// pre-reserved, so they never collide with Parquet column field IDs.
638fn arrow_schema_to_iceberg_update(
639    schema: &arrow_schema::Schema,
640    policy: &VectorStoragePolicy,
641    extra_vec_policies: &[VectorStoragePolicy],
642) -> IcebergSchemaUpdate {
643    let bytes_per_dim = policy.precision.bytes_per_element() as u32;
644    let vec_fixed_len = policy.dim * bytes_per_dim;
645
646    // Collect all vector column names that will appear in the final schema.
647    let has_primary_in_batch = schema
648        .fields()
649        .iter()
650        .any(|f| f.name() == &policy.column_name);
651    let vec_cols: Vec<(String, u32)> = {
652        let mut v = Vec::new();
653        if !has_primary_in_batch {
654            v.push((policy.column_name.clone(), vec_fixed_len));
655        }
656        for ep in extra_vec_policies {
657            let ep_fixed_len = ep.dim * ep.precision.bytes_per_element() as u32;
658            if !schema.fields().iter().any(|f| f.name() == &ep.column_name) {
659                v.push((ep.column_name.clone(), ep_fixed_len));
660            }
661        }
662        v
663    };
664
665    // Total top-level columns = batch fields + appended vec columns.
666    let top_level_count = schema.fields().len() + vec_cols.len();
667    // Nested element IDs start after all top-level IDs are pre-reserved.
668    let mut nested_id = top_level_count as i32;
669
670    let mut fields: Vec<serde_json::Value> = Vec::new();
671    let mut name_mapping: Vec<serde_json::Value> = Vec::new();
672
673    for (idx, field) in schema.fields().iter().enumerate() {
674        let field_id = (idx + 1) as i32;
675        let iceberg_type = arrow_type_to_iceberg(field.data_type(), &mut nested_id);
676        fields.push(serde_json::json!({
677            "id": field_id,
678            "name": field.name(),
679            "required": false,
680            "type": iceberg_type,
681        }));
682        name_mapping.push(serde_json::json!({
683            "field-id": field_id,
684            "names": [field.name()],
685        }));
686    }
687
688    // Append vector columns that live outside the RecordBatch schema.
689    for (i, (col_name, fixed_len)) in vec_cols.iter().enumerate() {
690        let field_id = (schema.fields().len() + 1 + i) as i32;
691        fields.push(serde_json::json!({
692            "id": field_id,
693            "name": col_name,
694            "required": false,
695            "type": format!("fixed[{fixed_len}]"),
696        }));
697        name_mapping.push(serde_json::json!({
698            "field-id": field_id,
699            "names": [col_name],
700        }));
701    }
702
703    let last_column_id = nested_id;
704    let name_mapping_json = serde_json::to_string(&name_mapping).unwrap_or_else(|_| "[]".into());
705
706    IcebergSchemaUpdate {
707        fields,
708        last_column_id,
709        name_mapping_json,
710    }
711}
712
713/// Map an Arrow DataType to an Iceberg schema type value (string or JSON object).
714///
715/// `nested_id` is a shared counter for generating unique element/field IDs inside
716/// List, Struct, and Map types. It must start beyond all pre-reserved top-level IDs.
717fn arrow_type_to_iceberg(dt: &arrow_schema::DataType, nested_id: &mut i32) -> serde_json::Value {
718    use arrow_schema::DataType;
719    match dt {
720        DataType::Boolean => serde_json::json!("boolean"),
721        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::UInt8 | DataType::UInt16 => {
722            serde_json::json!("int")
723        }
724        DataType::Int64 | DataType::UInt32 | DataType::UInt64 => serde_json::json!("long"),
725        DataType::Float16 | DataType::Float32 => serde_json::json!("float"),
726        DataType::Float64 => serde_json::json!("double"),
727        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => serde_json::json!("string"),
728        DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
729            serde_json::json!("binary")
730        }
731        DataType::Date32 | DataType::Date64 => serde_json::json!("date"),
732        // Timestamp with timezone → timestamptz; without → timestamp.
733        DataType::Timestamp(_, Some(_)) => serde_json::json!("timestamptz"),
734        DataType::Timestamp(_, None) => serde_json::json!("timestamp"),
735        DataType::Time32(_) | DataType::Time64(_) => serde_json::json!("time"),
736        DataType::FixedSizeBinary(n) => serde_json::json!(format!("fixed[{n}]")),
737        DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => {
738            serde_json::json!(format!("decimal({p}, {s})"))
739        }
740        DataType::List(inner)
741        | DataType::LargeList(inner)
742        | DataType::ListView(inner)
743        | DataType::FixedSizeList(inner, _) => {
744            *nested_id += 1;
745            let element_id = *nested_id;
746            let element_type = arrow_type_to_iceberg(inner.data_type(), nested_id);
747            serde_json::json!({
748                "type": "list",
749                "element-id": element_id,
750                "element": element_type,
751                "element-required": !inner.is_nullable(),
752            })
753        }
754        DataType::Struct(arrow_fields) => {
755            let struct_fields: Vec<serde_json::Value> = arrow_fields
756                .iter()
757                .map(|f| {
758                    *nested_id += 1;
759                    let fid = *nested_id;
760                    let ftype = arrow_type_to_iceberg(f.data_type(), nested_id);
761                    serde_json::json!({
762                        "id": fid,
763                        "name": f.name(),
764                        "required": !f.is_nullable(),
765                        "type": ftype,
766                    })
767                })
768                .collect();
769            serde_json::json!({ "type": "struct", "fields": struct_fields })
770        }
771        DataType::Map(entries, _) => {
772            // Arrow Map is List<Struct<key: K, value: V>>.
773            *nested_id += 1;
774            let key_id = *nested_id;
775            *nested_id += 1;
776            let val_id = *nested_id;
777            if let DataType::Struct(kv_fields) = entries.data_type() {
778                let key_f = kv_fields
779                    .iter()
780                    .find(|f| f.name() == "key" || f.name() == "keys");
781                let val_f = kv_fields
782                    .iter()
783                    .find(|f| f.name() == "value" || f.name() == "values");
784                let key_type = key_f
785                    .map(|f| arrow_type_to_iceberg(f.data_type(), nested_id))
786                    .unwrap_or(serde_json::json!("binary"));
787                let val_type = val_f
788                    .map(|f| arrow_type_to_iceberg(f.data_type(), nested_id))
789                    .unwrap_or(serde_json::json!("binary"));
790                let val_required = val_f.map(|f| !f.is_nullable()).unwrap_or(false);
791                serde_json::json!({
792                    "type": "map",
793                    "key-id": key_id,
794                    "key": key_type,
795                    "value-id": val_id,
796                    "value": val_type,
797                    "value-required": val_required,
798                })
799            } else {
800                serde_json::json!("binary")
801            }
802        }
803        _ => serde_json::json!("binary"),
804    }
805}
806
807/// Background task: reads a Parquet-only shard, builds full AILK file, patches catalog.
808async fn build_and_patch_index(
809    store: Arc<dyn Store>,
810    catalog: Arc<dyn CatalogProvider>,
811    policy: VectorStoragePolicy,
812    table: TableIdent,
813    file_path: String,
814) -> AilakeResult<()> {
815    // Read the Parquet-only bytes already stored.
816    let parquet_bytes = store.get(&file_path).await?;
817    let reader = AilakeFileReader::new(parquet_bytes, &policy.column_name, policy.dim);
818    let (batch, embeddings) = reader.read_parquet()?;
819
820    // Build the full AILK file (Parquet + HNSW) — CPU-intensive; run on blocking pool
821    // so the tokio async threads aren't starved when many shards build concurrently.
822    let full_bytes = tokio::task::spawn_blocking({
823        let policy = policy.clone();
824        move || {
825            let file_writer = AilakeFileWriter::new(policy);
826            file_writer.write(&batch, &embeddings)
827        }
828    })
829    .await
830    .map_err(|e| ailake_core::AilakeError::Store(format!("spawn_blocking panic: {e}")))??;
831
832    // Extract HNSW offsets from the newly written file.
833    let full_reader = AilakeFileReader::new(full_bytes.clone(), &policy.column_name, policy.dim);
834    let header = full_reader.read_header()?;
835    let ailk_start = full_reader.ailk_offset()?;
836    let hnsw_abs_offset = ailk_start + header.hnsw_offset;
837    let hnsw_len = header.hnsw_len;
838
839    // Overwrite the Parquet-only file with the full AILK version.
840    store.put(&file_path, full_bytes).await?;
841
842    // Wait for the initial writer commit to appear (HNSW builds can finish before
843    // the main write loop calls commit_snapshot, so the catalog has no snapshot yet).
844    for _ in 0..120u32 {
845        match catalog.load_table(&table).await {
846            Ok(meta) if meta.current_snapshot_id.is_some() => break,
847            _ => tokio::time::sleep(std::time::Duration::from_millis(500)).await,
848        }
849    }
850
851    // Update the catalog with CAS-like retry to handle concurrent background tasks.
852    // Multiple tasks can race on commit_snapshot(Replace): the last writer wins and
853    // may overwrite a sibling task's Ready status. Retry until we confirm our file
854    // is marked Ready in the current snapshot.
855    for attempt in 0..50u32 {
856        let table_meta = catalog.load_table(&table).await?;
857        let parent_snapshot_id = table_meta.current_snapshot_id;
858        let mut files = catalog.list_files(&table, None).await?;
859
860        // Already marked Ready by a previous successful attempt.
861        if files
862            .iter()
863            .any(|f| f.path == file_path && f.index_status == IndexStatus::Ready)
864        {
865            break;
866        }
867
868        for f in &mut files {
869            if f.path == file_path {
870                f.hnsw_offset = Some(hnsw_abs_offset);
871                f.hnsw_len = Some(hnsw_len);
872                f.index_status = IndexStatus::Ready;
873                break;
874            }
875        }
876        catalog
877            .commit_snapshot(
878                &table,
879                NewSnapshot {
880                    snapshot_id: new_snapshot_id(),
881                    parent_snapshot_id,
882                    files,
883                    operation: SnapshotOperation::Replace,
884                    iceberg_schema: None,
885                },
886            )
887            .await?;
888
889        // Brief yield so sibling tasks can commit, then verify our change survived.
890        tokio::time::sleep(std::time::Duration::from_millis(10 + attempt as u64 * 5)).await;
891
892        let verify = catalog.list_files(&table, None).await?;
893        if verify
894            .iter()
895            .any(|f| f.path == file_path && f.index_status == IndexStatus::Ready)
896        {
897            break;
898        }
899        // Another task overwrote us — retry.
900    }
901
902    info!(
903        "ailake: deferred HNSW index built for {} (offset={}, len={})",
904        file_path, hnsw_abs_offset, hnsw_len
905    );
906    Ok(())
907}
908
909/// Background task: train IVF-PQ (using shared codebook) and patch catalog entry.
910///
911/// The OnceCell guarantees that k-means training runs exactly once across all
912/// concurrent background tasks — subsequent tasks skip directly to assign+encode.
913async fn build_ivf_pq_and_patch_index(
914    store: Arc<dyn Store>,
915    catalog: Arc<dyn CatalogProvider>,
916    policy: VectorStoragePolicy,
917    table: TableIdent,
918    file_path: String,
919    ivf_config: IvfPqConfig,
920    codebook_cell: Arc<tokio::sync::OnceCell<IvfPqCodebook>>,
921) -> AilakeResult<()> {
922    let parquet_bytes = store.get(&file_path).await?;
923    let reader = AilakeFileReader::new(parquet_bytes, &policy.column_name, policy.dim);
924    let (batch, embeddings) = reader.read_parquet()?;
925
926    // Get or train the shared codebook. First task trains; all others await the result.
927    let codebook = codebook_cell
928        .get_or_try_init(|| async {
929            let vecs = embeddings.clone();
930            let metric = policy.metric;
931            let cfg = ivf_config.clone();
932            tokio::task::spawn_blocking(move || {
933                ailake_index::IvfPqIndex::train_codebook(&vecs, metric, &cfg)
934            })
935            .await
936            .map_err(|e| ailake_core::AilakeError::Store(format!("spawn_blocking panic: {e}")))?
937        })
938        .await?;
939
940    let full_bytes = tokio::task::spawn_blocking({
941        let policy = policy.clone();
942        let codebook = codebook.clone();
943        move || {
944            let file_writer = AilakeFileWriter::new(policy)
945                .with_index_type(IndexType::IvfPq(ivf_config))
946                .with_shared_ivf_codebook(Arc::new(codebook));
947            file_writer.write(&batch, &embeddings)
948        }
949    })
950    .await
951    .map_err(|e| ailake_core::AilakeError::Store(format!("spawn_blocking panic: {e}")))??;
952
953    let full_reader = AilakeFileReader::new(full_bytes.clone(), &policy.column_name, policy.dim);
954    let header = full_reader.read_header()?;
955    let ailk_start = full_reader.ailk_offset()?;
956    let hnsw_abs_offset = ailk_start + header.hnsw_offset;
957    let hnsw_len = header.hnsw_len;
958
959    store.put(&file_path, full_bytes).await?;
960
961    // Wait for initial commit to appear then patch IndexStatus::Ready (same CAS loop as HNSW).
962    for _ in 0..120u32 {
963        match catalog.load_table(&table).await {
964            Ok(meta) if meta.current_snapshot_id.is_some() => break,
965            _ => tokio::time::sleep(std::time::Duration::from_millis(500)).await,
966        }
967    }
968
969    for attempt in 0..50u32 {
970        let table_meta = catalog.load_table(&table).await?;
971        let parent_snapshot_id = table_meta.current_snapshot_id;
972        let mut files = catalog.list_files(&table, None).await?;
973
974        if files
975            .iter()
976            .any(|f| f.path == file_path && f.index_status == IndexStatus::Ready)
977        {
978            break;
979        }
980
981        for f in &mut files {
982            if f.path == file_path {
983                f.hnsw_offset = Some(hnsw_abs_offset);
984                f.hnsw_len = Some(hnsw_len);
985                f.index_status = IndexStatus::Ready;
986                break;
987            }
988        }
989        catalog
990            .commit_snapshot(
991                &table,
992                NewSnapshot {
993                    snapshot_id: new_snapshot_id(),
994                    parent_snapshot_id,
995                    files,
996                    operation: SnapshotOperation::Replace,
997                    iceberg_schema: None,
998                },
999            )
1000            .await?;
1001
1002        tokio::time::sleep(std::time::Duration::from_millis(10 + attempt as u64 * 5)).await;
1003
1004        let verify = catalog.list_files(&table, None).await?;
1005        if verify
1006            .iter()
1007            .any(|f| f.path == file_path && f.index_status == IndexStatus::Ready)
1008        {
1009            break;
1010        }
1011    }
1012
1013    info!(
1014        "ailake: deferred IVF-PQ index built for {} (offset={}, len={})",
1015        file_path, hnsw_abs_offset, hnsw_len
1016    );
1017    Ok(())
1018}
1019
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_core::{VectorMetric, VectorPrecision};
    use arrow_schema::{DataType, Field, Schema, TimeUnit};

    /// Cosine / F16 policy fixture for a vector column called `col` with `dim` dimensions.
    fn test_policy(col: &str, dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: col.to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            ivf_residual: false,
            embedding_model: None,
        }
    }

    /// Iceberg schema update for `schema` with a single (primary) vector policy.
    fn iceberg_update(schema: &Schema, pol: &VectorStoragePolicy) -> IcebergSchemaUpdate {
        arrow_schema_to_iceberg_update(schema, pol, &[])
    }

    /// Nullable `List<Utf8>` type with a nullable element field named "item".
    fn utf8_list() -> DataType {
        DataType::List(std::sync::Arc::new(Field::new(
            "item",
            DataType::Utf8,
            true,
        )))
    }

    /// Names of the top-level fields in an update, in order.
    fn field_names(upd: &IcebergSchemaUpdate) -> Vec<&str> {
        upd.fields
            .iter()
            .map(|f| f["name"].as_str().unwrap())
            .collect()
    }

    #[test]
    fn simple_schema_produces_correct_fields() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ]);
        let upd = iceberg_update(&schema, &test_policy("embedding", 8));

        assert_eq!(upd.fields.len(), 3);

        let int_field = &upd.fields[0];
        assert_eq!(int_field["id"], 1);
        assert_eq!(int_field["type"], "int");

        let text_field = &upd.fields[1];
        assert_eq!(text_field["id"], 2);
        assert_eq!(text_field["type"], "string");

        let vec_field = &upd.fields[2];
        assert_eq!(vec_field["id"], 3);
        assert_eq!(vec_field["type"], "fixed[16]"); // dim=8, F16=2 bytes

        let name_mapping: Vec<serde_json::Value> =
            serde_json::from_str(&upd.name_mapping_json).unwrap();
        assert_eq!(name_mapping.len(), 3);
        let vec_mapping = &name_mapping[2];
        assert_eq!(vec_mapping["field-id"], 3);
        assert_eq!(vec_mapping["names"][0], "embedding");
        assert_eq!(upd.last_column_id, 3);
    }

    #[test]
    fn timestamp_without_tz_maps_to_timestamp_not_timestamptz() {
        let naive = DataType::Timestamp(TimeUnit::Microsecond, None);
        let zoned = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
        let schema = Schema::new(vec![
            Field::new("created_at", naive, true),
            Field::new("updated_at", zoned, true),
        ]);
        let upd = iceberg_update(&schema, &test_policy("vec", 4));

        assert_eq!(upd.fields[0]["type"], "timestamp");
        assert_eq!(upd.fields[1]["type"], "timestamptz");
    }

    #[test]
    fn list_type_produces_iceberg_list_object() {
        let schema = Schema::new(vec![Field::new("tags", utf8_list(), true)]);
        let upd = iceberg_update(&schema, &test_policy("vec", 4));

        let list_type = &upd.fields[0]["type"];
        assert_eq!(list_type["type"], "list");
        assert_eq!(list_type["element"], "string");
        // element-id must be > top-level field count (2: tags + vec)
        let element_id = list_type["element-id"].as_i64().unwrap();
        assert!(element_id > 2);
    }

    #[test]
    fn struct_type_produces_nested_fields() {
        let members = vec![
            Field::new("key", DataType::Utf8, false),
            Field::new("val", DataType::Int64, false),
        ];
        let schema = Schema::new(vec![Field::new(
            "meta",
            DataType::Struct(members.into()),
            true,
        )]);
        let upd = iceberg_update(&schema, &test_policy("vec", 4));

        let struct_type = &upd.fields[0]["type"];
        assert_eq!(struct_type["type"], "struct");

        let nested = struct_type["fields"].as_array().unwrap();
        assert_eq!(nested.len(), 2);
        assert_eq!(nested[0]["name"], "key");
        assert_eq!(nested[0]["type"], "string");
        assert_eq!(nested[1]["name"], "val");
        assert_eq!(nested[1]["type"], "long");
        // Nested IDs must be > top-level count (2: meta + vec)
        let first_nested_id = nested[0]["id"].as_i64().unwrap();
        assert!(first_nested_id > 2);
    }

    #[test]
    fn no_duplicate_vec_column_when_already_in_batch() {
        // If for some reason the vec column is in the batch schema, don't add it twice.
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("embedding", DataType::FixedSizeBinary(16), false),
        ]);
        let upd = iceberg_update(&schema, &test_policy("embedding", 8));

        assert_eq!(upd.fields.len(), 2, "should not add embedding twice");
        let names = field_names(&upd);
        assert_eq!(names.iter().filter(|&&n| n == "embedding").count(), 1);
    }

    #[test]
    fn multi_vec_policies_all_appended() {
        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
        let primary = test_policy("embedding", 4);
        let extra = vec![test_policy("context_embedding", 4)];
        let upd = arrow_schema_to_iceberg_update(&schema, &primary, &extra);

        assert_eq!(upd.fields.len(), 3); // id + embedding + context_embedding
        let names = field_names(&upd);
        assert!(names.contains(&"embedding"));
        assert!(names.contains(&"context_embedding"));
    }

    #[test]
    fn top_level_field_ids_match_parquet_stamp_sequence() {
        // Top-level IDs must be 1, 2, ..., N regardless of nested element IDs.
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("tags", utf8_list(), true),
        ]);
        let upd = iceberg_update(&schema, &test_policy("vec", 4));

        // Top-level: id=1, tags=2, vec=3
        assert_eq!(upd.fields[0]["id"], 1);
        assert_eq!(upd.fields[1]["id"], 2);
        assert_eq!(upd.fields[2]["id"], 3);

        // Nested element ID must be > 3
        let element_id = upd.fields[1]["type"]["element-id"].as_i64().unwrap();
        assert!(element_id > 3);
    }

    /// Smoke-test write_batch_auto_deferred: verifies that it completes without error
    /// and stages a pending file entry (index built asynchronously in background).
    #[tokio::test]
    async fn write_batch_auto_deferred_stages_file() {
        use ailake_catalog::{HadoopCatalog, TableIdent};
        use ailake_store::LocalStore;

        // Local, throw-away store rooted in a temp directory.
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path().to_str().unwrap();
        let store: std::sync::Arc<dyn ailake_store::Store> =
            std::sync::Arc::new(LocalStore::new(root));
        let catalog = std::sync::Arc::new(HadoopCatalog::new(std::sync::Arc::clone(&store), ""));

        let mut writer = TableWriter::create_or_open(
            catalog,
            store,
            test_policy("embedding", 4),
            TableIdent::new("default", "t"),
        )
        .await
        .unwrap();

        // Single-row batch with one text column and a matching 4-dim embedding.
        let text_column: std::sync::Arc<arrow_array::StringArray> =
            std::sync::Arc::new(arrow_array::StringArray::from(vec!["hello"]));
        let batch_schema =
            std::sync::Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
        let batch = arrow_array::RecordBatch::try_new(batch_schema, vec![text_column]).unwrap();
        let embeddings = vec![vec![1.0f32, 0.0, 0.0, 0.0]];

        writer
            .write_batch_auto_deferred(&batch, &embeddings)
            .await
            .unwrap();

        // One pending file should be staged even before commit.
        assert_eq!(writer.pending_files.len(), 1);
    }
}