Skip to main content

ailake_query/
mem_table.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tracing::info;
5
6use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
7use ailake_core::{AilakeResult, VectorStoragePolicy};
8use ailake_store::Store;
9use arrow_array::RecordBatch;
10use arrow_select::concat::concat_batches;
11
12use crate::writer::TableWriter;
13
/// Flush thresholds for `MemTableWriter`.
///
/// Each field is an independent trigger: the byte and row limits are checked
/// on every insert, the interval is checked by `flush_if_due`.
#[derive(Clone, Debug)]
pub struct MemTableConfig {
    /// Byte budget for buffered embedding data; reaching it triggers a flush.
    /// Default: 64 MiB.
    pub flush_size_bytes: usize,
    /// Row budget for the buffer; reaching it triggers a flush no matter how
    /// few bytes are buffered.
    /// Default: 100,000 rows.
    pub flush_max_rows: usize,
    /// How long buffered data may wait before `flush_if_due` writes it out.
    /// Default: 30 seconds.
    pub flush_interval: Duration,
}
27
28impl Default for MemTableConfig {
29    fn default() -> Self {
30        Self {
31            flush_size_bytes: 64 * 1024 * 1024,
32            flush_max_rows: 100_000,
33            flush_interval: Duration::from_secs(30),
34        }
35    }
36}
37
38/// In-memory write buffer that batches small inserts before persisting.
39///
40/// Problem: streaming pipelines (Flink, Spark Streaming) emit small
41/// RecordBatches every few seconds. Calling `write_batch_deferred` on each
42/// micro-batch creates many tiny Parquet files and triggers repeated HNSW
43/// builds — both are expensive.
44///
45/// Solution: buffer rows in RAM, flush to a single Parquet shard only when
46/// the buffer reaches the configured size/row/time threshold. The deferred
47/// HNSW build runs once per flush, not once per micro-batch.
48///
49/// # Usage
50///
51/// ```ignore
52/// let mut mt = MemTableWriter::new(catalog, store, policy, table, MemTableConfig::default());
53/// loop {
54///     mt.insert(&batch, &embeddings).await.unwrap();
55///     mt.flush_if_due().await.unwrap();
56/// }
57/// mt.flush().await.unwrap();
58/// ```
59pub struct MemTableWriter {
60    catalog: Arc<dyn CatalogProvider>,
61    store: Arc<dyn Store>,
62    policy: VectorStoragePolicy,
63    table: TableIdent,
64    config: MemTableConfig,
65
66    // Accumulated micro-batches waiting for flush
67    pending_batches: Vec<RecordBatch>,
68    pending_embeddings: Vec<Vec<f32>>,
69    buffered_bytes: usize,
70    last_flush: Instant,
71}
72
73impl MemTableWriter {
74    pub fn new(
75        catalog: Arc<dyn CatalogProvider>,
76        store: Arc<dyn Store>,
77        policy: VectorStoragePolicy,
78        table: TableIdent,
79        config: MemTableConfig,
80    ) -> Self {
81        Self {
82            catalog,
83            store,
84            policy,
85            table,
86            config,
87            pending_batches: Vec::new(),
88            pending_embeddings: Vec::new(),
89            buffered_bytes: 0,
90            last_flush: Instant::now(),
91        }
92    }
93
94    /// Buffer a micro-batch. Flushes automatically if size or row threshold exceeded.
95    /// Returns `Some(snapshot_id)` when an automatic flush occurred, `None` otherwise.
96    pub async fn insert(
97        &mut self,
98        batch: &RecordBatch,
99        embeddings: &[Vec<f32>],
100    ) -> AilakeResult<Option<SnapshotId>> {
101        let row_bytes =
102            embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
103
104        self.pending_batches.push(batch.clone());
105        self.pending_embeddings.extend_from_slice(embeddings);
106        self.buffered_bytes += row_bytes;
107
108        if self.buffered_bytes >= self.config.flush_size_bytes
109            || self.pending_embeddings.len() >= self.config.flush_max_rows
110        {
111            info!(
112                "ailake: MemTable auto-flush triggered — {} rows / {} bytes buffered",
113                self.pending_embeddings.len(),
114                self.buffered_bytes
115            );
116            Ok(Some(self.flush().await?))
117        } else {
118            Ok(None)
119        }
120    }
121
122    /// Flush if `flush_interval` has elapsed since the last flush.
123    /// Returns `Some(snapshot_id)` when a flush occurred, `None` otherwise.
124    pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
125        if self.pending_embeddings.is_empty() {
126            self.last_flush = Instant::now();
127            return Ok(None);
128        }
129        if self.last_flush.elapsed() >= self.config.flush_interval {
130            info!(
131                "ailake: MemTable time-based flush — {} rows / {} bytes buffered (interval={}s)",
132                self.pending_embeddings.len(),
133                self.buffered_bytes,
134                self.config.flush_interval.as_secs()
135            );
136            Ok(Some(self.flush().await?))
137        } else {
138            Ok(None)
139        }
140    }
141
142    /// Flush all buffered data immediately, even if thresholds are not met.
143    /// Returns the committed `SnapshotId`. Calling on an empty buffer is a no-op
144    /// and returns `SnapshotId::default()`.
145    pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
146        if self.pending_embeddings.is_empty() {
147            self.last_flush = Instant::now();
148            return Ok(0);
149        }
150
151        // Concatenate accumulated micro-batches into one shard batch.
152        let merged = if self.pending_batches.len() == 1 {
153            self.pending_batches.remove(0)
154        } else {
155            concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
156                .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
157        };
158        let embeddings = std::mem::take(&mut self.pending_embeddings);
159        self.pending_batches.clear();
160        self.buffered_bytes = 0;
161        self.last_flush = Instant::now();
162
163        // Delegate to TableWriter: Parquet-only write + deferred HNSW build.
164        let mut writer = TableWriter::new(
165            self.catalog.clone(),
166            self.store.clone(),
167            self.policy.clone(),
168            self.table.clone(),
169        );
170        writer.write_batch_deferred(&merged, &embeddings).await?;
171        let snap = writer.commit().await?;
172        Ok(snap)
173    }
174
175    /// Number of rows currently in the buffer.
176    pub fn buffered_rows(&self) -> usize {
177        self.pending_embeddings.len()
178    }
179
180    /// Estimated byte size of the embedding data currently in the buffer.
181    pub fn buffered_bytes(&self) -> usize {
182        self.buffered_bytes
183    }
184
185    /// True when a size or row-count threshold has been reached.
186    pub fn is_full(&self) -> bool {
187        self.buffered_bytes >= self.config.flush_size_bytes
188            || self.pending_embeddings.len() >= self.config.flush_max_rows
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Storage policy shared by every test: 4-dimensional F16 vectors.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            rabitq: None,
        }
    }

    /// Builds an `(id, text)` batch with one row per entry in `ids`.
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let id_field = Field::new("id", DataType::Int32, false);
        let text_field = Field::new("text", DataType::Utf8, false);
        let schema = Arc::new(Schema::new(vec![id_field, text_field]));

        let id_col = Int32Array::from(ids.to_vec());
        let text_col = StringArray::from(vec!["chunk"; ids.len()]);

        RecordBatch::try_new(schema, vec![Arc::new(id_col), Arc::new(text_col)]).unwrap()
    }

    /// Builds `n` deterministic vectors of width `dim`.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f32> = (0..dim).map(|d| i as f32 + d as f32 * 0.1).collect();
            rows.push(row);
        }
        rows
    }

    /// Registers `table` in `catalog` with the shared test policy.
    async fn setup_table(catalog: &HadoopCatalog, table: &TableIdent) {
        let props = ailake_catalog::TableProperties {
            policy: make_policy(),
            extra: Default::default(),
        };
        catalog.create_table(table, &props).await.unwrap();
    }

    /// Creates a throwaway warehouse containing `default.<table_name>` and
    /// returns a writer for it. The `TempDir` guard is handed back so the
    /// directory stays on disk for as long as the caller keeps it.
    async fn open_writer(
        table_name: &str,
        config: MemTableConfig,
    ) -> (tempfile::TempDir, MemTableWriter) {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", table_name);
        setup_table(&catalog, &table).await;

        let writer = MemTableWriter::new(catalog, store, make_policy(), table, config);
        (dir, writer)
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let (_dir, mut mt) = open_writer(
            "test_mem",
            MemTableConfig {
                flush_size_bytes: 1024 * 1024,
                flush_max_rows: 1000,
                flush_interval: Duration::from_secs(60),
            },
        )
        .await;

        // Three micro-batches of five rows each, all below every threshold.
        for start in (0..15).step_by(5) {
            let ids: Vec<i32> = (start..start + 5).collect();
            let batch = make_batch(&ids);
            let embs = make_embeddings(5, 4);
            let snap = mt.insert(&batch, &embs).await.unwrap();
            assert!(snap.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let (_dir, mut mt) = open_writer(
            "test_auto",
            MemTableConfig {
                flush_size_bytes: 1024 * 1024 * 1024,
                flush_max_rows: 8,
                flush_interval: Duration::from_secs(60),
            },
        )
        .await;

        let embs = make_embeddings(5, 4);

        // Five rows: still under the eight-row limit.
        let first = make_batch(&[1, 2, 3, 4, 5]);
        assert!(mt.insert(&first, &embs).await.unwrap().is_none());

        // Ten rows in total: the limit is crossed and the insert flushes.
        let second = make_batch(&[6, 7, 8, 9, 10]);
        let snap = mt.insert(&second, &embs).await.unwrap();
        assert!(snap.is_some(), "should auto-flush when row limit exceeded");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let (_dir, mut mt) = open_writer(
            "test_due",
            MemTableConfig {
                flush_size_bytes: 1024 * 1024 * 1024,
                flush_max_rows: 1000,
                flush_interval: Duration::from_millis(1),
            },
        )
        .await;

        let batch = make_batch(&[1, 2, 3]);
        let embs = make_embeddings(3, 4);
        mt.insert(&batch, &embs).await.unwrap();

        // Wait longer than the 1 ms interval so the time trigger is armed.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}