// ailake_query/migration.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Embedding model migration for AI-Lake tables.
//!
//! Reads every data file of a table, re-embeds its text column with a new model,
//! and writes new files carrying the updated embedding column. Two strategies are
//! supported:
//!
//! - `AtomicReplace`: migrates and commits one file at a time. Lower peak storage,
//!   but during the migration window some files carry the new embeddings while
//!   others still carry the old ones.
//! - `DualWriteThenCutover`: writes every new file first while readers keep using
//!   the old ones, then swaps all of them in with a single snapshot. Higher peak
//!   storage, zero downtime.

use std::sync::Arc;

use ailake_catalog::{
    make_data_file_entry, new_snapshot_id, CatalogProvider, DataFileEntry, NewSnapshot,
    SnapshotOperation, TableIdent, VectorIndexInfo,
};
use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, VectorStoragePolicy};
use ailake_file::{AilakeFileReader, AilakeFileWriter};
use ailake_store::Store;
use ailake_vec::compute_centroid_and_radius;
use arrow_array::{Array, LargeStringArray, RecordBatch, StringArray};
use tracing::info;

25pub type EmbedFn = Arc<dyn Fn(&[String]) -> AilakeResult<Vec<Vec<f32>>> + Send + Sync>;
26pub type ProgressFn = Arc<dyn Fn(MigrationProgress) + Send + Sync>;
27
/// Controls how migrated files are committed to the catalog.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MigrationStrategy {
    /// Migrate and commit one file at a time, each in its own snapshot.
    ///
    /// Keeps peak storage low, but while the job runs some files hold the new
    /// embedding column and others still hold the old one.
    AtomicReplace,
    /// Write every new file up front (old files untouched), then swap them all in
    /// with a single `Replace` snapshot.
    ///
    /// Peak storage is roughly 2× during migration, but readers always see a
    /// consistent view.
    DualWriteThenCutover,
}

/// Cumulative migration progress, handed to the `on_progress` callback after each
/// data file has been migrated.
#[derive(Clone, Debug)]
pub struct MigrationProgress {
    /// Data files migrated so far.
    pub files_done: usize,
    /// Data files the job will process in total.
    pub files_total: usize,
    /// Rows written to migrated files so far.
    pub rows_migrated: u64,
}

48/// Migrates embedding columns in an AI-Lake table to a new model.
49///
50/// Usage:
51/// ```ignore
52/// let job = MigrationJob {
53///     table: TableIdent::new("default", "docs"),
54///     old_column: "embedding".to_string(),
55///     new_column: "embedding_v2".to_string(),
56///     text_column: "chunk_text".to_string(),
57///     embed_fn: Arc::new(|texts| Ok(my_model.encode(texts))),
58///     strategy: MigrationStrategy::DualWriteThenCutover,
59///     batch_size: 10_000,
60///     new_model: Some(EmbeddingModelInfo::new("my-model-v2")),
61///     on_progress: None,
62/// };
63/// job.run(catalog, store).await?;
64/// ```
65pub struct MigrationJob {
66    pub table: TableIdent,
67    /// Name of the embedding column to replace (e.g., "embedding").
68    pub old_column: String,
69    /// Name to give the new embedding column (e.g., "embedding_v2").
70    /// Can be the same as `old_column` to do an in-place model upgrade.
71    pub new_column: String,
72    /// Column in the Parquet files that holds the text to re-embed.
73    /// Defaults to `chunk_text` (the `LlmContextSchema` canonical name).
74    pub text_column: String,
75    /// Callable that converts a slice of texts to embeddings.
76    /// Must return exactly `texts.len()` vectors, all of the same dimension.
77    pub embed_fn: EmbedFn,
78    pub strategy: MigrationStrategy,
79    /// How many rows to embed per `embed_fn` call. Tune based on model batch size.
80    pub batch_size: usize,
81    /// Metadata for the new embedding model — stored in Iceberg properties after migration.
82    pub new_model: Option<EmbeddingModelInfo>,
83    /// Optional callback called after each file completes.
84    pub on_progress: Option<ProgressFn>,
85}
86
87impl MigrationJob {
88    pub async fn run(
89        self,
90        catalog: Arc<dyn CatalogProvider>,
91        store: Arc<dyn Store>,
92    ) -> AilakeResult<()> {
93        match self.strategy {
94            MigrationStrategy::AtomicReplace => self.run_atomic_replace(catalog, store).await,
95            MigrationStrategy::DualWriteThenCutover => self.run_dual_write(catalog, store).await,
96        }
97    }
98
99    /// AtomicReplace: process and commit each file one at a time.
100    async fn run_atomic_replace(
101        &self,
102        catalog: Arc<dyn CatalogProvider>,
103        store: Arc<dyn Store>,
104    ) -> AilakeResult<()> {
105        let table_meta = catalog.load_table(&self.table).await?;
106        let old_files = catalog
107            .list_files(&self.table, table_meta.current_snapshot_id)
108            .await?;
109        let total = old_files.len();
110        let mut rows_migrated: u64 = 0;
111
112        // Derive new policy from table properties + new model info
113        let new_policy = self.new_policy_from_metadata(&table_meta.properties)?;
114
115        let mut parent_snap = table_meta.current_snapshot_id;
116
117        for (idx, old_entry) in old_files.iter().enumerate() {
118            let (batch, texts) = self
119                .read_file_texts(&old_entry.path, &store, &new_policy)
120                .await?;
121            let new_embeddings = self.embed_in_batches(&texts)?;
122
123            let new_entry = self
124                .write_new_file(&batch, &new_embeddings, &new_policy, &store, idx)
125                .await?;
126
127            rows_migrated += new_entry.record_count;
128
129            let snap_id = new_snapshot_id();
130            catalog
131                .commit_snapshot(
132                    &self.table,
133                    NewSnapshot {
134                        snapshot_id: snap_id,
135                        parent_snapshot_id: parent_snap,
136                        files: vec![new_entry],
137                        operation: SnapshotOperation::Replace,
138                        iceberg_schema: None,
139                    },
140                )
141                .await?;
142            parent_snap = Some(snap_id);
143
144            if let Some(cb) = &self.on_progress {
145                cb(MigrationProgress {
146                    files_done: idx + 1,
147                    files_total: total,
148                    rows_migrated,
149                });
150            }
151
152            info!(
153                "ailake migration: AtomicReplace {}/{} files done, {} rows migrated",
154                idx + 1,
155                total,
156                rows_migrated
157            );
158        }
159
160        Ok(())
161    }
162
163    /// DualWriteThenCutover: write all new files first, then commit one Replace snapshot.
164    async fn run_dual_write(
165        &self,
166        catalog: Arc<dyn CatalogProvider>,
167        store: Arc<dyn Store>,
168    ) -> AilakeResult<()> {
169        let table_meta = catalog.load_table(&self.table).await?;
170        let old_files = catalog
171            .list_files(&self.table, table_meta.current_snapshot_id)
172            .await?;
173        let total = old_files.len();
174        let mut rows_migrated: u64 = 0;
175
176        let new_policy = self.new_policy_from_metadata(&table_meta.properties)?;
177        let mut new_entries: Vec<DataFileEntry> = Vec::with_capacity(total);
178
179        for (idx, old_entry) in old_files.iter().enumerate() {
180            let (batch, texts) = self
181                .read_file_texts(&old_entry.path, &store, &new_policy)
182                .await?;
183            let new_embeddings = self.embed_in_batches(&texts)?;
184
185            let entry = self
186                .write_new_file(&batch, &new_embeddings, &new_policy, &store, idx)
187                .await?;
188
189            rows_migrated += entry.record_count;
190            new_entries.push(entry);
191
192            if let Some(cb) = &self.on_progress {
193                cb(MigrationProgress {
194                    files_done: idx + 1,
195                    files_total: total,
196                    rows_migrated,
197                });
198            }
199
200            info!(
201                "ailake migration: DualWrite phase {}/{} files ready",
202                idx + 1,
203                total
204            );
205        }
206
207        // Single atomic cutover: replace all old files with all new files.
208        let snap_id = new_snapshot_id();
209        catalog
210            .commit_snapshot(
211                &self.table,
212                NewSnapshot {
213                    snapshot_id: snap_id,
214                    parent_snapshot_id: table_meta.current_snapshot_id,
215                    files: new_entries,
216                    operation: SnapshotOperation::Replace,
217                    iceberg_schema: None,
218                },
219            )
220            .await?;
221
222        info!(
223            "ailake migration: DualWriteThenCutover complete — {} files, {} rows",
224            total, rows_migrated
225        );
226        Ok(())
227    }
228
229    /// Read Parquet bytes from store, decode the text column.
230    async fn read_file_texts(
231        &self,
232        path: &str,
233        store: &Arc<dyn Store>,
234        policy: &VectorStoragePolicy,
235    ) -> AilakeResult<(RecordBatch, Vec<String>)> {
236        let bytes = store.get(path).await?;
237        let reader = AilakeFileReader::new(bytes, &self.old_column, policy.dim);
238        let (batch, _) = reader.read_parquet()?;
239
240        let texts = extract_string_column(&batch, &self.text_column)?;
241        Ok((batch, texts))
242    }
243
244    /// Call embed_fn in chunks of batch_size.
245    fn embed_in_batches(&self, texts: &[String]) -> AilakeResult<Vec<Vec<f32>>> {
246        let mut all: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
247        for chunk in texts.chunks(self.batch_size) {
248            let mut chunk_vecs = (self.embed_fn)(chunk)?;
249            all.append(&mut chunk_vecs);
250        }
251        Ok(all)
252    }
253
254    /// Write a new AI-Lake file with the re-embedded vectors, return its catalog entry.
255    async fn write_new_file(
256        &self,
257        batch: &RecordBatch,
258        embeddings: &[Vec<f32>],
259        policy: &VectorStoragePolicy,
260        store: &Arc<dyn Store>,
261        idx: usize,
262    ) -> AilakeResult<DataFileEntry> {
263        let file_path = format!("data/migrated-{:05}.parquet", idx);
264
265        let writer = AilakeFileWriter::new(policy.clone());
266        let file_bytes = writer.write(batch, embeddings)?;
267        let file_size = file_bytes.len() as u64;
268
269        store.put(&file_path, file_bytes.clone()).await?;
270
271        let centroid = compute_centroid_and_radius(embeddings, policy.metric);
272        let reader = AilakeFileReader::new(file_bytes, &policy.column_name, policy.dim);
273        let header = reader.read_header()?;
274        let ailk_start = reader.ailk_offset()?;
275        let hnsw_abs = ailk_start + header.hnsw_offset;
276
277        Ok(make_data_file_entry(
278            &file_path,
279            embeddings.len() as u64,
280            file_size,
281            &centroid,
282            VectorIndexInfo {
283                column: &policy.column_name,
284                dim: policy.dim,
285                hnsw_offset: hnsw_abs,
286                hnsw_len: header.hnsw_len,
287            },
288        ))
289    }
290
291    /// Build the new `VectorStoragePolicy` from existing table properties,
292    /// overriding the column name and embedding model.
293    fn new_policy_from_metadata(
294        &self,
295        props: &std::collections::HashMap<String, String>,
296    ) -> AilakeResult<VectorStoragePolicy> {
297        use ailake_core::{VectorMetric, VectorPrecision};
298
299        let dim: u32 = props
300            .get("ailake.vector-dim")
301            .and_then(|s| s.parse().ok())
302            .ok_or_else(|| {
303                AilakeError::InvalidArgument("table missing ailake.vector-dim property".into())
304            })?;
305
306        let metric = match props
307            .get("ailake.vector-metric")
308            .map(|s| s.as_str())
309            .unwrap_or("cosine")
310        {
311            "euclidean" => VectorMetric::Euclidean,
312            "dotproduct" | "dot_product" => VectorMetric::DotProduct,
313            "normalizedcosine" | "normalized_cosine" => VectorMetric::NormalizedCosine,
314            _ => VectorMetric::Cosine,
315        };
316
317        let precision = match props
318            .get("ailake.vector-precision")
319            .map(|s| s.as_str())
320            .unwrap_or("f16")
321        {
322            "f32" => VectorPrecision::F32,
323            "i8" => VectorPrecision::I8,
324            _ => VectorPrecision::F16,
325        };
326
327        Ok(VectorStoragePolicy {
328            column_name: self.new_column.clone(),
329            dim,
330            metric,
331            precision,
332            pq: None,
333            keep_raw_for_reranking: true,
334            pre_normalize: props
335                .get("ailake.pre-normalize")
336                .map(|s| s == "true")
337                .unwrap_or(false),
338            hnsw_m: props.get("ailake.hnsw-m").and_then(|s| s.parse().ok()),
339            hnsw_ef_construction: props
340                .get("ailake.hnsw-ef-construction")
341                .and_then(|s| s.parse().ok()),
342            ivf_residual: false,
343            embedding_model: self.new_model.clone(),
344        })
345    }
346}
347
348fn extract_string_column(batch: &RecordBatch, column_name: &str) -> AilakeResult<Vec<String>> {
349    let col = batch.column_by_name(column_name).ok_or_else(|| {
350        AilakeError::InvalidArgument(format!(
351            "text column '{}' not found in Parquet file; available: {}",
352            column_name,
353            batch
354                .schema()
355                .fields()
356                .iter()
357                .map(|f| f.name().as_str())
358                .collect::<Vec<_>>()
359                .join(", ")
360        ))
361    })?;
362
363    let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
364        AilakeError::InvalidArgument(format!(
365            "column '{}' is not a Utf8/String column",
366            column_name
367        ))
368    })?;
369
370    Ok((0..arr.len())
371        .map(|i| {
372            if arr.is_null(i) {
373                String::new()
374            } else {
375                arr.value(i).to_string()
376            }
377        })
378        .collect())
379}