Skip to main content

ailake_query/
migration.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Embedding model migration for AI-Lake tables.
3//!
4//! Reads all chunks from a table, re-embeds them with a new model, and writes
5//! new files with the updated embedding column. Two strategies are supported:
6//!
7//! - `AtomicReplace`: replaces each file one at a time. Lower peak storage, but
8//!   during the migration window different shards may have different columns.
9//! - `DualWriteThenCutover`: writes new files containing both old and new columns,
10//!   then atomically replaces all old files. Higher peak storage, zero downtime.
11
12use std::sync::Arc;
13
14use ailake_catalog::{
15    make_data_file_entry, new_snapshot_id, CatalogProvider, DataFileEntry, NewSnapshot,
16    SnapshotOperation, TableIdent, VectorIndexInfo,
17};
18use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, VectorStoragePolicy};
19use ailake_file::{AilakeFileReader, AilakeFileWriter};
20use ailake_store::Store;
21use ailake_vec::compute_centroid_and_radius;
22use arrow_array::{Array, RecordBatch, StringArray};
23use tracing::info;
24
25pub type EmbedFn = Arc<dyn Fn(&[String]) -> AilakeResult<Vec<Vec<f32>>> + Send + Sync>;
26pub type ProgressFn = Arc<dyn Fn(MigrationProgress) + Send + Sync>;
27
/// Strategy used to swap migrated files into the table.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MigrationStrategy {
    /// Migrate and commit one file at a time. Peak extra storage stays low, but while
    /// the job runs some shards carry the old column and others the new one.
    AtomicReplace,
    /// Write every migrated file first (old files untouched), then commit a single
    /// `Replace` snapshot that swaps all old files for the new ones at once.
    /// Needs roughly 2× storage while running, but reads always see a consistent view.
    DualWriteThenCutover,
}
39
/// Progress snapshot handed to the `on_progress` callback.
#[derive(Clone, Debug)]
pub struct MigrationProgress {
    /// Number of files migrated so far.
    pub files_done: usize,
    /// Total number of files in this migration.
    pub files_total: usize,
    /// Rows re-embedded and written so far.
    pub rows_migrated: u64,
}
47
48/// Migrates embedding columns in an AI-Lake table to a new model.
49///
50/// Usage:
51/// ```ignore
52/// let job = MigrationJob {
53///     table: TableIdent::new("default", "docs"),
54///     old_column: "embedding".to_string(),
55///     new_column: "embedding_v2".to_string(),
56///     text_column: "chunk_text".to_string(),
57///     embed_fn: Arc::new(|texts| Ok(my_model.encode(texts))),
58///     strategy: MigrationStrategy::DualWriteThenCutover,
59///     batch_size: 10_000,
60///     new_model: Some(EmbeddingModelInfo::new("my-model-v2")),
61///     on_progress: None,
62/// };
63/// job.run(catalog, store).await?;
64/// ```
65pub struct MigrationJob {
66    pub table: TableIdent,
67    /// Name of the embedding column to replace (e.g., "embedding").
68    pub old_column: String,
69    /// Name to give the new embedding column (e.g., "embedding_v2").
70    /// Can be the same as `old_column` to do an in-place model upgrade.
71    pub new_column: String,
72    /// Column in the Parquet files that holds the text to re-embed.
73    /// Defaults to `chunk_text` (the `LlmContextSchema` canonical name).
74    pub text_column: String,
75    /// Callable that converts a slice of texts to embeddings.
76    /// Must return exactly `texts.len()` vectors, all of the same dimension.
77    pub embed_fn: EmbedFn,
78    pub strategy: MigrationStrategy,
79    /// How many rows to embed per `embed_fn` call. Tune based on model batch size.
80    pub batch_size: usize,
81    /// Metadata for the new embedding model — stored in Iceberg properties after migration.
82    pub new_model: Option<EmbeddingModelInfo>,
83    /// Optional callback called after each file completes.
84    pub on_progress: Option<ProgressFn>,
85}
86
87impl MigrationJob {
88    pub async fn run(
89        self,
90        catalog: Arc<dyn CatalogProvider>,
91        store: Arc<dyn Store>,
92    ) -> AilakeResult<()> {
93        match self.strategy {
94            MigrationStrategy::AtomicReplace => self.run_atomic_replace(catalog, store).await,
95            MigrationStrategy::DualWriteThenCutover => self.run_dual_write(catalog, store).await,
96        }
97    }
98
99    /// AtomicReplace: process and commit each file one at a time.
100    async fn run_atomic_replace(
101        &self,
102        catalog: Arc<dyn CatalogProvider>,
103        store: Arc<dyn Store>,
104    ) -> AilakeResult<()> {
105        let table_meta = catalog.load_table(&self.table).await?;
106        let old_files = catalog
107            .list_files(&self.table, table_meta.current_snapshot_id)
108            .await?;
109        let total = old_files.len();
110        let mut rows_migrated: u64 = 0;
111
112        // Derive new policy from table properties + new model info
113        let new_policy = self.new_policy_from_metadata(&table_meta.properties)?;
114
115        let mut parent_snap = table_meta.current_snapshot_id;
116
117        for (idx, old_entry) in old_files.iter().enumerate() {
118            let (batch, texts) = self
119                .read_file_texts(&old_entry.path, &store, &new_policy)
120                .await?;
121            let new_embeddings = self.embed_in_batches(&texts)?;
122
123            let new_entry = self
124                .write_new_file(&batch, &new_embeddings, &new_policy, &store, idx)
125                .await?;
126
127            rows_migrated += new_entry.record_count;
128
129            let snap_id = new_snapshot_id();
130            catalog
131                .commit_snapshot(
132                    &self.table,
133                    NewSnapshot {
134                        snapshot_id: snap_id,
135                        parent_snapshot_id: parent_snap,
136                        files: vec![new_entry],
137                        operation: SnapshotOperation::Replace,
138                        iceberg_schema: None,
139                        extra_properties: std::collections::HashMap::new(),
140                        bloom_filters: vec![],
141                        equality_delete_files: vec![],
142                    },
143                )
144                .await?;
145            parent_snap = Some(snap_id);
146
147            if let Some(cb) = &self.on_progress {
148                cb(MigrationProgress {
149                    files_done: idx + 1,
150                    files_total: total,
151                    rows_migrated,
152                });
153            }
154
155            info!(
156                "ailake migration: AtomicReplace {}/{} files done, {} rows migrated",
157                idx + 1,
158                total,
159                rows_migrated
160            );
161        }
162
163        Ok(())
164    }
165
166    /// DualWriteThenCutover: write all new files first, then commit one Replace snapshot.
167    async fn run_dual_write(
168        &self,
169        catalog: Arc<dyn CatalogProvider>,
170        store: Arc<dyn Store>,
171    ) -> AilakeResult<()> {
172        let table_meta = catalog.load_table(&self.table).await?;
173        let old_files = catalog
174            .list_files(&self.table, table_meta.current_snapshot_id)
175            .await?;
176        let total = old_files.len();
177        let mut rows_migrated: u64 = 0;
178
179        let new_policy = self.new_policy_from_metadata(&table_meta.properties)?;
180        let mut new_entries: Vec<DataFileEntry> = Vec::with_capacity(total);
181
182        for (idx, old_entry) in old_files.iter().enumerate() {
183            let (batch, texts) = self
184                .read_file_texts(&old_entry.path, &store, &new_policy)
185                .await?;
186            let new_embeddings = self.embed_in_batches(&texts)?;
187
188            let entry = self
189                .write_new_file(&batch, &new_embeddings, &new_policy, &store, idx)
190                .await?;
191
192            rows_migrated += entry.record_count;
193            new_entries.push(entry);
194
195            if let Some(cb) = &self.on_progress {
196                cb(MigrationProgress {
197                    files_done: idx + 1,
198                    files_total: total,
199                    rows_migrated,
200                });
201            }
202
203            info!(
204                "ailake migration: DualWrite phase {}/{} files ready",
205                idx + 1,
206                total
207            );
208        }
209
210        // Single atomic cutover: replace all old files with all new files.
211        let snap_id = new_snapshot_id();
212        catalog
213            .commit_snapshot(
214                &self.table,
215                NewSnapshot {
216                    snapshot_id: snap_id,
217                    parent_snapshot_id: table_meta.current_snapshot_id,
218                    files: new_entries,
219                    operation: SnapshotOperation::Replace,
220                    iceberg_schema: None,
221                    extra_properties: std::collections::HashMap::new(),
222                    bloom_filters: vec![],
223                    equality_delete_files: vec![],
224                },
225            )
226            .await?;
227
228        info!(
229            "ailake migration: DualWriteThenCutover complete — {} files, {} rows",
230            total, rows_migrated
231        );
232        Ok(())
233    }
234
235    /// Read Parquet bytes from store, decode the text column.
236    async fn read_file_texts(
237        &self,
238        path: &str,
239        store: &Arc<dyn Store>,
240        policy: &VectorStoragePolicy,
241    ) -> AilakeResult<(RecordBatch, Vec<String>)> {
242        let bytes = store.get(path).await?;
243        let reader = AilakeFileReader::new(bytes, &self.old_column, policy.dim);
244        let (batch, _) = reader.read_parquet()?;
245
246        let texts = extract_string_column(&batch, &self.text_column)?;
247        Ok((batch, texts))
248    }
249
250    /// Call embed_fn in chunks of batch_size.
251    fn embed_in_batches(&self, texts: &[String]) -> AilakeResult<Vec<Vec<f32>>> {
252        let mut all: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
253        for chunk in texts.chunks(self.batch_size) {
254            let mut chunk_vecs = (self.embed_fn)(chunk)?;
255            all.append(&mut chunk_vecs);
256        }
257        Ok(all)
258    }
259
260    /// Write a new AI-Lake file with the re-embedded vectors, return its catalog entry.
261    async fn write_new_file(
262        &self,
263        batch: &RecordBatch,
264        embeddings: &[Vec<f32>],
265        policy: &VectorStoragePolicy,
266        store: &Arc<dyn Store>,
267        idx: usize,
268    ) -> AilakeResult<DataFileEntry> {
269        let file_path = format!("data/migrated-{:05}.parquet", idx);
270
271        let writer = AilakeFileWriter::new(policy.clone());
272        let file_bytes = writer.write(batch, embeddings)?;
273        let file_size = file_bytes.len() as u64;
274
275        store.put(&file_path, file_bytes.clone()).await?;
276
277        let centroid = compute_centroid_and_radius(embeddings, policy.metric);
278        let reader = AilakeFileReader::new(file_bytes, &policy.column_name, policy.dim);
279        let header = reader.read_header()?;
280        let ailk_start = reader.ailk_offset()?;
281        let hnsw_abs = ailk_start + header.hnsw_offset;
282
283        Ok(make_data_file_entry(
284            &file_path,
285            embeddings.len() as u64,
286            file_size,
287            &centroid,
288            VectorIndexInfo {
289                column: &policy.column_name,
290                dim: policy.dim,
291                hnsw_offset: hnsw_abs,
292                hnsw_len: header.hnsw_len,
293            },
294        ))
295    }
296
297    /// Build the new `VectorStoragePolicy` from existing table properties,
298    /// overriding the column name and embedding model.
299    fn new_policy_from_metadata(
300        &self,
301        props: &std::collections::HashMap<String, String>,
302    ) -> AilakeResult<VectorStoragePolicy> {
303        use ailake_core::{VectorMetric, VectorPrecision};
304
305        let dim: u32 = props
306            .get("ailake.vector-dim")
307            .and_then(|s| s.parse().ok())
308            .ok_or_else(|| {
309                AilakeError::InvalidArgument("table missing ailake.vector-dim property".into())
310            })?;
311
312        let metric = match props
313            .get("ailake.vector-metric")
314            .map(|s| s.as_str())
315            .unwrap_or("cosine")
316        {
317            "euclidean" => VectorMetric::Euclidean,
318            "dotproduct" | "dot_product" => VectorMetric::DotProduct,
319            "normalizedcosine" | "normalized_cosine" => VectorMetric::NormalizedCosine,
320            _ => VectorMetric::Cosine,
321        };
322
323        let precision = match props
324            .get("ailake.vector-precision")
325            .map(|s| s.as_str())
326            .unwrap_or("f16")
327        {
328            "f32" => VectorPrecision::F32,
329            "i8" => VectorPrecision::I8,
330            _ => VectorPrecision::F16,
331        };
332
333        Ok(VectorStoragePolicy {
334            column_name: self.new_column.clone(),
335            dim,
336            metric,
337            precision,
338            pq: None,
339            keep_raw_for_reranking: true,
340            pre_normalize: props
341                .get("ailake.pre-normalize")
342                .map(|s| s == "true")
343                .unwrap_or(false),
344            hnsw_m: props.get("ailake.hnsw-m").and_then(|s| s.parse().ok()),
345            hnsw_ef_construction: props
346                .get("ailake.hnsw-ef-construction")
347                .and_then(|s| s.parse().ok()),
348            ivf_residual: false,
349            embedding_model: self.new_model.clone(),
350            modality: None,
351            partition_by: None,
352            partition_value: None,
353            partition_column_type: None,
354            partition_fields: vec![],
355        })
356    }
357}
358
359fn extract_string_column(batch: &RecordBatch, column_name: &str) -> AilakeResult<Vec<String>> {
360    let col = batch.column_by_name(column_name).ok_or_else(|| {
361        AilakeError::InvalidArgument(format!(
362            "text column '{}' not found in Parquet file; available: {}",
363            column_name,
364            batch
365                .schema()
366                .fields()
367                .iter()
368                .map(|f| f.name().as_str())
369                .collect::<Vec<_>>()
370                .join(", ")
371        ))
372    })?;
373
374    let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
375        AilakeError::InvalidArgument(format!(
376            "column '{}' is not a Utf8/String column",
377            column_name
378        ))
379    })?;
380
381    Ok((0..arr.len())
382        .map(|i| {
383            if arr.is_null(i) {
384                String::new()
385            } else {
386                arr.value(i).to_string()
387            }
388        })
389        .collect())
390}