Skip to main content

ailake_query/
mem_table.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
5use ailake_core::{AilakeResult, VectorStoragePolicy};
6use ailake_store::Store;
7use arrow_array::RecordBatch;
8use arrow_select::concat::concat_batches;
9
10use crate::writer::TableWriter;
11
/// Tuning knobs for [`MemTableWriter`].
///
/// A flush is triggered by whichever threshold is reached first.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemTableConfig {
    /// Flush once the estimated embedding bytes in the buffer reach this
    /// threshold (the check is `>=`, evaluated after each insert).
    /// Default: 64 MiB.
    pub flush_size_bytes: usize,
    /// Flush once the buffered row count reaches this value (`>=`),
    /// regardless of byte size.
    /// Default: 100,000 rows.
    pub flush_max_rows: usize,
    /// Time since the last flush after which `flush_if_due` flushes a
    /// non-empty buffer. The timer also restarts whenever `flush_if_due`
    /// observes an empty buffer, so this approximates the age of the oldest
    /// buffered row only when `flush_if_due` is polled regularly.
    /// Default: 30 seconds.
    pub flush_interval: Duration,
}
25
26impl Default for MemTableConfig {
27    fn default() -> Self {
28        Self {
29            flush_size_bytes: 64 * 1024 * 1024,
30            flush_max_rows: 100_000,
31            flush_interval: Duration::from_secs(30),
32        }
33    }
34}
35
36/// In-memory write buffer that batches small inserts before persisting.
37///
38/// Problem: streaming pipelines (Flink, Spark Streaming) emit small
39/// RecordBatches every few seconds. Calling `write_batch_deferred` on each
40/// micro-batch creates many tiny Parquet files and triggers repeated HNSW
41/// builds — both are expensive.
42///
43/// Solution: buffer rows in RAM, flush to a single Parquet shard only when
44/// the buffer reaches the configured size/row/time threshold. The deferred
45/// HNSW build runs once per flush, not once per micro-batch.
46///
47/// # Usage
48///
49/// ```ignore
50/// let mut mt = MemTableWriter::new(catalog, store, policy, table, MemTableConfig::default());
51/// loop {
52///     mt.insert(&batch, &embeddings).await.unwrap();
53///     mt.flush_if_due().await.unwrap();
54/// }
55/// mt.flush().await.unwrap();
56/// ```
57pub struct MemTableWriter {
58    catalog: Arc<dyn CatalogProvider>,
59    store: Arc<dyn Store>,
60    policy: VectorStoragePolicy,
61    table: TableIdent,
62    config: MemTableConfig,
63
64    // Accumulated micro-batches waiting for flush
65    pending_batches: Vec<RecordBatch>,
66    pending_embeddings: Vec<Vec<f32>>,
67    buffered_bytes: usize,
68    last_flush: Instant,
69}
70
71impl MemTableWriter {
72    pub fn new(
73        catalog: Arc<dyn CatalogProvider>,
74        store: Arc<dyn Store>,
75        policy: VectorStoragePolicy,
76        table: TableIdent,
77        config: MemTableConfig,
78    ) -> Self {
79        Self {
80            catalog,
81            store,
82            policy,
83            table,
84            config,
85            pending_batches: Vec::new(),
86            pending_embeddings: Vec::new(),
87            buffered_bytes: 0,
88            last_flush: Instant::now(),
89        }
90    }
91
92    /// Buffer a micro-batch. Flushes automatically if size or row threshold exceeded.
93    /// Returns `Some(snapshot_id)` when an automatic flush occurred, `None` otherwise.
94    pub async fn insert(
95        &mut self,
96        batch: &RecordBatch,
97        embeddings: &[Vec<f32>],
98    ) -> AilakeResult<Option<SnapshotId>> {
99        let row_bytes =
100            embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
101
102        self.pending_batches.push(batch.clone());
103        self.pending_embeddings.extend_from_slice(embeddings);
104        self.buffered_bytes += row_bytes;
105
106        if self.buffered_bytes >= self.config.flush_size_bytes
107            || self.pending_embeddings.len() >= self.config.flush_max_rows
108        {
109            Ok(Some(self.flush().await?))
110        } else {
111            Ok(None)
112        }
113    }
114
115    /// Flush if `flush_interval` has elapsed since the last flush.
116    /// Returns `Some(snapshot_id)` when a flush occurred, `None` otherwise.
117    pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
118        if self.pending_embeddings.is_empty() {
119            self.last_flush = Instant::now();
120            return Ok(None);
121        }
122        if self.last_flush.elapsed() >= self.config.flush_interval {
123            Ok(Some(self.flush().await?))
124        } else {
125            Ok(None)
126        }
127    }
128
129    /// Flush all buffered data immediately, even if thresholds are not met.
130    /// Returns the committed `SnapshotId`. Calling on an empty buffer is a no-op
131    /// and returns `SnapshotId::default()`.
132    pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
133        if self.pending_embeddings.is_empty() {
134            self.last_flush = Instant::now();
135            return Ok(0);
136        }
137
138        // Concatenate accumulated micro-batches into one shard batch.
139        let merged = if self.pending_batches.len() == 1 {
140            self.pending_batches.remove(0)
141        } else {
142            concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
143                .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
144        };
145        let embeddings = std::mem::take(&mut self.pending_embeddings);
146        self.pending_batches.clear();
147        self.buffered_bytes = 0;
148        self.last_flush = Instant::now();
149
150        // Delegate to TableWriter: Parquet-only write + deferred HNSW build.
151        let mut writer = TableWriter::new(
152            self.catalog.clone(),
153            self.store.clone(),
154            self.policy.clone(),
155            self.table.clone(),
156        );
157        writer.write_batch_deferred(&merged, &embeddings).await?;
158        let snap = writer.commit().await?;
159        Ok(snap)
160    }
161
162    /// Number of rows currently in the buffer.
163    pub fn buffered_rows(&self) -> usize {
164        self.pending_embeddings.len()
165    }
166
167    /// Estimated byte size of the embedding data currently in the buffer.
168    pub fn buffered_bytes(&self) -> usize {
169        self.buffered_bytes
170    }
171
172    /// True when a size or row-count threshold has been reached.
173    pub fn is_full(&self) -> bool {
174        self.buffered_bytes >= self.config.flush_size_bytes
175            || self.pending_embeddings.len() >= self.config.flush_max_rows
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// 4-dim, F16, Euclidean policy shared by every test table.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Two-column (`id`, `text`) batch with one row per id.
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ]));
        let texts: Vec<&str> = ids.iter().map(|_| "chunk").collect();
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids.to_vec())),
                Arc::new(StringArray::from(texts)),
            ],
        )
        .unwrap()
    }

    /// `n` deterministic `dim`-wide vectors.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| (0..dim).map(|d| i as f32 + d as f32 * 0.1).collect())
            .collect()
    }

    async fn setup_table(catalog: &HadoopCatalog, table: &TableIdent) {
        catalog
            .create_table(
                table,
                &ailake_catalog::TableProperties {
                    policy: make_policy(),
                    extra: Default::default(),
                },
            )
            .await
            .unwrap();
    }

    /// Create table `default.<name>` in a fresh temp warehouse and return a
    /// writer for it. Keep the returned `TempDir` alive for the whole test,
    /// because dropping it removes the warehouse directory.
    async fn new_writer(
        name: &'static str,
        config: MemTableConfig,
    ) -> (tempfile::TempDir, MemTableWriter) {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", name);
        setup_table(&catalog, &table).await;
        let mt = MemTableWriter::new(catalog, store, make_policy(), table, config);
        (dir, mt)
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = new_writer("test_mem", config).await;

        for i in 0..3 {
            let ids: Vec<i32> = (i * 5..(i + 1) * 5).collect();
            let batch = make_batch(&ids);
            let embs = make_embeddings(5, 4);
            let snap = mt.insert(&batch, &embs).await.unwrap();
            assert!(snap.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);
        assert_eq!(
            mt.buffered_bytes(),
            15 * 4 * VectorPrecision::F16.bytes_per_element()
        );
        assert!(!mt.is_full());

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_flush_empty_is_noop() {
        let (_dir, mut mt) = new_writer("test_empty", MemTableConfig::default()).await;

        assert_eq!(mt.flush().await.unwrap(), SnapshotId::default());
        assert!(mt.flush_if_due().await.unwrap().is_none());
        assert_eq!(mt.buffered_rows(), 0);
        assert_eq!(mt.buffered_bytes(), 0);
        assert!(!mt.is_full());
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 8,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = new_writer("test_auto", config).await;

        let batch = make_batch(&[1, 2, 3, 4, 5]);
        let embs = make_embeddings(5, 4);
        assert!(mt.insert(&batch, &embs).await.unwrap().is_none());
        assert!(!mt.is_full());

        let batch2 = make_batch(&[6, 7, 8, 9, 10]);
        let snap = mt.insert(&batch2, &embs).await.unwrap();
        assert!(snap.is_some(), "should auto-flush when row limit reached");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_millis(1),
        };
        let (_dir, mut mt) = new_writer("test_due", config).await;

        let batch = make_batch(&[1, 2, 3]);
        let embs = make_embeddings(3, 4);
        mt.insert(&batch, &embs).await.unwrap();

        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}