Skip to main content

ailake_query/
migration.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Embedding model migration for AI-Lake tables.
3//!
4//! Reads all chunks from a table, re-embeds them with a new model, and writes
5//! new files with the updated embedding column. Two strategies are supported:
6//!
7//! - `AtomicReplace`: replaces each file one at a time. Lower peak storage, but
8//!   during the migration window different shards may have different columns.
9//! - `DualWriteThenCutover`: writes new files containing both old and new columns,
10//!   then atomically replaces all old files. Higher peak storage, zero downtime.
11
12use std::sync::Arc;
13
14use ailake_catalog::{
15    make_data_file_entry, new_snapshot_id, CatalogProvider, DataFileEntry, NewSnapshot,
16    SnapshotOperation, TableIdent, VectorIndexInfo,
17};
18use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, VectorStoragePolicy};
19use ailake_file::{AilakeFileReader, AilakeFileWriter};
20use ailake_store::Store;
21use ailake_vec::compute_centroid_and_radius;
22use arrow_array::{Array, RecordBatch, StringArray};
23use tracing::info;
24
25pub type EmbedFn = Arc<dyn Fn(&[String]) -> AilakeResult<Vec<Vec<f32>>> + Send + Sync>;
26pub type ProgressFn = Arc<dyn Fn(MigrationProgress) + Send + Sync>;
27
/// Strategy used to swap migrated files into the table.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MigrationStrategy {
    /// Migrate and commit one file at a time.
    ///
    /// Needs little extra storage, but while the migration runs some files carry the old
    /// embedding column and others already carry the new one.
    AtomicReplace,
    /// Write every migrated file first, leaving the old files untouched, then commit a
    /// single `Replace` snapshot that swaps the whole file set at once.
    ///
    /// Peak storage roughly doubles during the migration, but readers always see a
    /// consistent view of the table.
    DualWriteThenCutover,
}
39
/// Snapshot of migration progress handed to the `on_progress` callback.
#[derive(Clone, Debug)]
pub struct MigrationProgress {
    /// Files finished so far, including the one that just completed.
    pub files_done: usize,
    /// Total number of data files this migration processes.
    pub files_total: usize,
    /// Running total of rows written into migrated files.
    pub rows_migrated: u64,
}
47
48/// Migrates embedding columns in an AI-Lake table to a new model.
49///
50/// Usage:
51/// ```ignore
52/// let job = MigrationJob {
53///     table: TableIdent::new("default", "docs"),
54///     old_column: "embedding".to_string(),
55///     new_column: "embedding_v2".to_string(),
56///     text_column: "chunk_text".to_string(),
57///     embed_fn: Arc::new(|texts| Ok(my_model.encode(texts))),
58///     strategy: MigrationStrategy::DualWriteThenCutover,
59///     batch_size: 10_000,
60///     new_model: Some(EmbeddingModelInfo::new("my-model-v2")),
61///     on_progress: None,
62/// };
63/// job.run(catalog, store).await?;
64/// ```
65pub struct MigrationJob {
66    pub table: TableIdent,
67    /// Name of the embedding column to replace (e.g., "embedding").
68    pub old_column: String,
69    /// Name to give the new embedding column (e.g., "embedding_v2").
70    /// Can be the same as `old_column` to do an in-place model upgrade.
71    pub new_column: String,
72    /// Column in the Parquet files that holds the text to re-embed.
73    /// Defaults to `chunk_text` (the `LlmContextSchema` canonical name).
74    pub text_column: String,
75    /// Callable that converts a slice of texts to embeddings.
76    /// Must return exactly `texts.len()` vectors, all of the same dimension.
77    pub embed_fn: EmbedFn,
78    pub strategy: MigrationStrategy,
79    /// How many rows to embed per `embed_fn` call. Tune based on model batch size.
80    pub batch_size: usize,
81    /// Metadata for the new embedding model — stored in Iceberg properties after migration.
82    pub new_model: Option<EmbeddingModelInfo>,
83    /// Optional callback called after each file completes.
84    pub on_progress: Option<ProgressFn>,
85}
86
87impl MigrationJob {
88    pub async fn run(
89        self,
90        catalog: Arc<dyn CatalogProvider>,
91        store: Arc<dyn Store>,
92    ) -> AilakeResult<()> {
93        match self.strategy {
94            MigrationStrategy::AtomicReplace => self.run_atomic_replace(catalog, store).await,
95            MigrationStrategy::DualWriteThenCutover => self.run_dual_write(catalog, store).await,
96        }
97    }
98
99    /// AtomicReplace: process and commit each file one at a time.
100    async fn run_atomic_replace(
101        &self,
102        catalog: Arc<dyn CatalogProvider>,
103        store: Arc<dyn Store>,
104    ) -> AilakeResult<()> {
105        let table_meta = catalog.load_table(&self.table).await?;
106        let old_files = catalog
107            .list_files(&self.table, table_meta.current_snapshot_id)
108            .await?;
109        let total = old_files.len();
110        let mut rows_migrated: u64 = 0;
111
112        // Derive new policy from table properties + new model info
113        let new_policy = self.new_policy_from_metadata(&table_meta.properties)?;
114
115        let mut parent_snap = table_meta.current_snapshot_id;
116
117        for (idx, old_entry) in old_files.iter().enumerate() {
118            let (batch, texts) = self
119                .read_file_texts(&old_entry.path, &store, &new_policy)
120                .await?;
121            let new_embeddings = self.embed_in_batches(&texts)?;
122
123            let new_entry = self
124                .write_new_file(&batch, &new_embeddings, &new_policy, &store, idx)
125                .await?;
126
127            rows_migrated += new_entry.record_count;
128
129            let snap_id = new_snapshot_id();
130            catalog
131                .commit_snapshot(
132                    &self.table,
133                    NewSnapshot {
134                        snapshot_id: snap_id,
135                        parent_snapshot_id: parent_snap,
136                        files: vec![new_entry],
137                        operation: SnapshotOperation::Replace,
138                        iceberg_schema: None,
139                        extra_properties: std::collections::HashMap::new(),
140                    },
141                )
142                .await?;
143            parent_snap = Some(snap_id);
144
145            if let Some(cb) = &self.on_progress {
146                cb(MigrationProgress {
147                    files_done: idx + 1,
148                    files_total: total,
149                    rows_migrated,
150                });
151            }
152
153            info!(
154                "ailake migration: AtomicReplace {}/{} files done, {} rows migrated",
155                idx + 1,
156                total,
157                rows_migrated
158            );
159        }
160
161        Ok(())
162    }
163
164    /// DualWriteThenCutover: write all new files first, then commit one Replace snapshot.
165    async fn run_dual_write(
166        &self,
167        catalog: Arc<dyn CatalogProvider>,
168        store: Arc<dyn Store>,
169    ) -> AilakeResult<()> {
170        let table_meta = catalog.load_table(&self.table).await?;
171        let old_files = catalog
172            .list_files(&self.table, table_meta.current_snapshot_id)
173            .await?;
174        let total = old_files.len();
175        let mut rows_migrated: u64 = 0;
176
177        let new_policy = self.new_policy_from_metadata(&table_meta.properties)?;
178        let mut new_entries: Vec<DataFileEntry> = Vec::with_capacity(total);
179
180        for (idx, old_entry) in old_files.iter().enumerate() {
181            let (batch, texts) = self
182                .read_file_texts(&old_entry.path, &store, &new_policy)
183                .await?;
184            let new_embeddings = self.embed_in_batches(&texts)?;
185
186            let entry = self
187                .write_new_file(&batch, &new_embeddings, &new_policy, &store, idx)
188                .await?;
189
190            rows_migrated += entry.record_count;
191            new_entries.push(entry);
192
193            if let Some(cb) = &self.on_progress {
194                cb(MigrationProgress {
195                    files_done: idx + 1,
196                    files_total: total,
197                    rows_migrated,
198                });
199            }
200
201            info!(
202                "ailake migration: DualWrite phase {}/{} files ready",
203                idx + 1,
204                total
205            );
206        }
207
208        // Single atomic cutover: replace all old files with all new files.
209        let snap_id = new_snapshot_id();
210        catalog
211            .commit_snapshot(
212                &self.table,
213                NewSnapshot {
214                    snapshot_id: snap_id,
215                    parent_snapshot_id: table_meta.current_snapshot_id,
216                    files: new_entries,
217                    operation: SnapshotOperation::Replace,
218                    iceberg_schema: None,
219                    extra_properties: std::collections::HashMap::new(),
220                },
221            )
222            .await?;
223
224        info!(
225            "ailake migration: DualWriteThenCutover complete — {} files, {} rows",
226            total, rows_migrated
227        );
228        Ok(())
229    }
230
231    /// Read Parquet bytes from store, decode the text column.
232    async fn read_file_texts(
233        &self,
234        path: &str,
235        store: &Arc<dyn Store>,
236        policy: &VectorStoragePolicy,
237    ) -> AilakeResult<(RecordBatch, Vec<String>)> {
238        let bytes = store.get(path).await?;
239        let reader = AilakeFileReader::new(bytes, &self.old_column, policy.dim);
240        let (batch, _) = reader.read_parquet()?;
241
242        let texts = extract_string_column(&batch, &self.text_column)?;
243        Ok((batch, texts))
244    }
245
246    /// Call embed_fn in chunks of batch_size.
247    fn embed_in_batches(&self, texts: &[String]) -> AilakeResult<Vec<Vec<f32>>> {
248        let mut all: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
249        for chunk in texts.chunks(self.batch_size) {
250            let mut chunk_vecs = (self.embed_fn)(chunk)?;
251            all.append(&mut chunk_vecs);
252        }
253        Ok(all)
254    }
255
256    /// Write a new AI-Lake file with the re-embedded vectors, return its catalog entry.
257    async fn write_new_file(
258        &self,
259        batch: &RecordBatch,
260        embeddings: &[Vec<f32>],
261        policy: &VectorStoragePolicy,
262        store: &Arc<dyn Store>,
263        idx: usize,
264    ) -> AilakeResult<DataFileEntry> {
265        let file_path = format!("data/migrated-{:05}.parquet", idx);
266
267        let writer = AilakeFileWriter::new(policy.clone());
268        let file_bytes = writer.write(batch, embeddings)?;
269        let file_size = file_bytes.len() as u64;
270
271        store.put(&file_path, file_bytes.clone()).await?;
272
273        let centroid = compute_centroid_and_radius(embeddings, policy.metric);
274        let reader = AilakeFileReader::new(file_bytes, &policy.column_name, policy.dim);
275        let header = reader.read_header()?;
276        let ailk_start = reader.ailk_offset()?;
277        let hnsw_abs = ailk_start + header.hnsw_offset;
278
279        Ok(make_data_file_entry(
280            &file_path,
281            embeddings.len() as u64,
282            file_size,
283            &centroid,
284            VectorIndexInfo {
285                column: &policy.column_name,
286                dim: policy.dim,
287                hnsw_offset: hnsw_abs,
288                hnsw_len: header.hnsw_len,
289            },
290        ))
291    }
292
293    /// Build the new `VectorStoragePolicy` from existing table properties,
294    /// overriding the column name and embedding model.
295    fn new_policy_from_metadata(
296        &self,
297        props: &std::collections::HashMap<String, String>,
298    ) -> AilakeResult<VectorStoragePolicy> {
299        use ailake_core::{VectorMetric, VectorPrecision};
300
301        let dim: u32 = props
302            .get("ailake.vector-dim")
303            .and_then(|s| s.parse().ok())
304            .ok_or_else(|| {
305                AilakeError::InvalidArgument("table missing ailake.vector-dim property".into())
306            })?;
307
308        let metric = match props
309            .get("ailake.vector-metric")
310            .map(|s| s.as_str())
311            .unwrap_or("cosine")
312        {
313            "euclidean" => VectorMetric::Euclidean,
314            "dotproduct" | "dot_product" => VectorMetric::DotProduct,
315            "normalizedcosine" | "normalized_cosine" => VectorMetric::NormalizedCosine,
316            _ => VectorMetric::Cosine,
317        };
318
319        let precision = match props
320            .get("ailake.vector-precision")
321            .map(|s| s.as_str())
322            .unwrap_or("f16")
323        {
324            "f32" => VectorPrecision::F32,
325            "i8" => VectorPrecision::I8,
326            _ => VectorPrecision::F16,
327        };
328
329        Ok(VectorStoragePolicy {
330            column_name: self.new_column.clone(),
331            dim,
332            metric,
333            precision,
334            pq: None,
335            keep_raw_for_reranking: true,
336            pre_normalize: props
337                .get("ailake.pre-normalize")
338                .map(|s| s == "true")
339                .unwrap_or(false),
340            hnsw_m: props.get("ailake.hnsw-m").and_then(|s| s.parse().ok()),
341            hnsw_ef_construction: props
342                .get("ailake.hnsw-ef-construction")
343                .and_then(|s| s.parse().ok()),
344            ivf_residual: false,
345            embedding_model: self.new_model.clone(),
346            modality: None,
347        })
348    }
349}
350
351fn extract_string_column(batch: &RecordBatch, column_name: &str) -> AilakeResult<Vec<String>> {
352    let col = batch.column_by_name(column_name).ok_or_else(|| {
353        AilakeError::InvalidArgument(format!(
354            "text column '{}' not found in Parquet file; available: {}",
355            column_name,
356            batch
357                .schema()
358                .fields()
359                .iter()
360                .map(|f| f.name().as_str())
361                .collect::<Vec<_>>()
362                .join(", ")
363        ))
364    })?;
365
366    let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
367        AilakeError::InvalidArgument(format!(
368            "column '{}' is not a Utf8/String column",
369            column_name
370        ))
371    })?;
372
373    Ok((0..arr.len())
374        .map(|i| {
375            if arr.is_null(i) {
376                String::new()
377            } else {
378                arr.value(i).to_string()
379            }
380        })
381        .collect())
382}