Skip to main content

ailake_query/
mem_table.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tracing::info;
5
6use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
7use ailake_core::{AilakeResult, VectorStoragePolicy};
8use ailake_store::Store;
9use arrow_array::RecordBatch;
10use arrow_select::concat::concat_batches;
11
12use crate::writer::TableWriter;
13
/// Flush thresholds for [`MemTableWriter`].
///
/// The size and row-count thresholds are checked on every `MemTableWriter::insert`.
/// The time threshold is only checked when the caller invokes
/// `MemTableWriter::flush_if_due`. Whichever threshold is reached first triggers the flush.
#[derive(Debug, Clone)]
pub struct MemTableConfig {
    /// Flush once the estimated embedding payload reaches (`>=`) this many bytes.
    /// The estimate is rows × `dim` × bytes per element of the policy's precision.
    /// Default: 64 MiB.
    pub flush_size_bytes: usize,
    /// Flush once this many rows are buffered (`>=`), regardless of byte size.
    /// Default: 100,000 rows.
    pub flush_max_rows: usize,
    /// Time threshold: once this much time has elapsed on the writer's flush timer,
    /// `flush_if_due` flushes a non-empty buffer. It is not enforced in the
    /// background; the caller must call `flush_if_due` periodically.
    /// Default: 30 seconds.
    pub flush_interval: Duration,
}
27
28impl Default for MemTableConfig {
29    fn default() -> Self {
30        Self {
31            flush_size_bytes: 64 * 1024 * 1024,
32            flush_max_rows: 100_000,
33            flush_interval: Duration::from_secs(30),
34        }
35    }
36}
37
38/// In-memory write buffer that batches small inserts before persisting.
39///
40/// Problem: streaming pipelines (Flink, Spark Streaming) emit small
41/// RecordBatches every few seconds. Calling `write_batch_deferred` on each
42/// micro-batch creates many tiny Parquet files and triggers repeated HNSW
43/// builds — both are expensive.
44///
45/// Solution: buffer rows in RAM, flush to a single Parquet shard only when
46/// the buffer reaches the configured size/row/time threshold. The deferred
47/// HNSW build runs once per flush, not once per micro-batch.
48///
49/// # Usage
50///
51/// ```ignore
52/// let mut mt = MemTableWriter::new(catalog, store, policy, table, MemTableConfig::default());
53/// loop {
54///     mt.insert(&batch, &embeddings).await.unwrap();
55///     mt.flush_if_due().await.unwrap();
56/// }
57/// mt.flush().await.unwrap();
58/// ```
59pub struct MemTableWriter {
60    catalog: Arc<dyn CatalogProvider>,
61    store: Arc<dyn Store>,
62    policy: VectorStoragePolicy,
63    table: TableIdent,
64    config: MemTableConfig,
65
66    // Accumulated micro-batches waiting for flush
67    pending_batches: Vec<RecordBatch>,
68    pending_embeddings: Vec<Vec<f32>>,
69    buffered_bytes: usize,
70    last_flush: Instant,
71}
72
73impl MemTableWriter {
74    pub fn new(
75        catalog: Arc<dyn CatalogProvider>,
76        store: Arc<dyn Store>,
77        policy: VectorStoragePolicy,
78        table: TableIdent,
79        config: MemTableConfig,
80    ) -> Self {
81        Self {
82            catalog,
83            store,
84            policy,
85            table,
86            config,
87            pending_batches: Vec::new(),
88            pending_embeddings: Vec::new(),
89            buffered_bytes: 0,
90            last_flush: Instant::now(),
91        }
92    }
93
94    /// Buffer a micro-batch. Flushes automatically if size or row threshold exceeded.
95    /// Returns `Some(snapshot_id)` when an automatic flush occurred, `None` otherwise.
96    pub async fn insert(
97        &mut self,
98        batch: &RecordBatch,
99        embeddings: &[Vec<f32>],
100    ) -> AilakeResult<Option<SnapshotId>> {
101        let row_bytes =
102            embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
103
104        self.pending_batches.push(batch.clone());
105        self.pending_embeddings.extend_from_slice(embeddings);
106        self.buffered_bytes += row_bytes;
107
108        if self.buffered_bytes >= self.config.flush_size_bytes
109            || self.pending_embeddings.len() >= self.config.flush_max_rows
110        {
111            info!(
112                "ailake: MemTable auto-flush triggered — {} rows / {} bytes buffered",
113                self.pending_embeddings.len(),
114                self.buffered_bytes
115            );
116            Ok(Some(self.flush().await?))
117        } else {
118            Ok(None)
119        }
120    }
121
122    /// Flush if `flush_interval` has elapsed since the last flush.
123    /// Returns `Some(snapshot_id)` when a flush occurred, `None` otherwise.
124    pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
125        if self.pending_embeddings.is_empty() {
126            self.last_flush = Instant::now();
127            return Ok(None);
128        }
129        if self.last_flush.elapsed() >= self.config.flush_interval {
130            info!(
131                "ailake: MemTable time-based flush — {} rows / {} bytes buffered (interval={}s)",
132                self.pending_embeddings.len(),
133                self.buffered_bytes,
134                self.config.flush_interval.as_secs()
135            );
136            Ok(Some(self.flush().await?))
137        } else {
138            Ok(None)
139        }
140    }
141
142    /// Flush all buffered data immediately, even if thresholds are not met.
143    /// Returns the committed `SnapshotId`. Calling on an empty buffer is a no-op
144    /// and returns `SnapshotId::default()`.
145    pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
146        if self.pending_embeddings.is_empty() {
147            self.last_flush = Instant::now();
148            return Ok(0);
149        }
150
151        // Concatenate accumulated micro-batches into one shard batch.
152        let merged = if self.pending_batches.len() == 1 {
153            self.pending_batches.remove(0)
154        } else {
155            concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
156                .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
157        };
158        let embeddings = std::mem::take(&mut self.pending_embeddings);
159        self.pending_batches.clear();
160        self.buffered_bytes = 0;
161        self.last_flush = Instant::now();
162
163        // Delegate to TableWriter: Parquet-only write + deferred HNSW build.
164        let mut writer = TableWriter::new(
165            self.catalog.clone(),
166            self.store.clone(),
167            self.policy.clone(),
168            self.table.clone(),
169        );
170        writer.write_batch_deferred(&merged, &embeddings).await?;
171        let snap = writer.commit().await?;
172        Ok(snap)
173    }
174
175    /// Number of rows currently in the buffer.
176    pub fn buffered_rows(&self) -> usize {
177        self.pending_embeddings.len()
178    }
179
180    /// Estimated byte size of the embedding data currently in the buffer.
181    pub fn buffered_bytes(&self) -> usize {
182        self.buffered_bytes
183    }
184
185    /// True when a size or row-count threshold has been reached.
186    pub fn is_full(&self) -> bool {
187        self.buffered_bytes >= self.config.flush_size_bytes
188            || self.pending_embeddings.len() >= self.config.flush_max_rows
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// 4-dimensional f16 Euclidean policy with no PQ and no HNSW overrides.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
        }
    }

    /// Two-column batch (`id`, `text`) with one row per entry of `ids`.
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ]);
        let id_col = Int32Array::from(ids.to_vec());
        let text_col = StringArray::from(vec!["chunk"; ids.len()]);
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_col), Arc::new(text_col)])
            .unwrap()
    }

    /// `n` deterministic vectors of length `dim`; component `d` of row `i` is `i + 0.1 * d`.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            rows.push((0..dim).map(|d| i as f32 + d as f32 * 0.1).collect());
        }
        rows
    }

    /// Registers `table` in `catalog` using the test policy.
    async fn setup_table(catalog: &HadoopCatalog, table: &TableIdent) {
        let props = ailake_catalog::TableProperties {
            policy: make_policy(),
            extra: Default::default(),
        };
        catalog.create_table(table, &props).await.unwrap();
    }

    /// Creates table `default.<name>` in a fresh temporary warehouse and wraps it in a writer.
    /// Keep the returned `TempDir` alive for as long as the writer is used.
    async fn open_writer(
        name: &str,
        config: MemTableConfig,
    ) -> (tempfile::TempDir, MemTableWriter) {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", name);
        setup_table(&catalog, &table).await;
        let writer = MemTableWriter::new(catalog, store, make_policy(), table, config);
        (dir, writer)
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = open_writer("test_mem", config).await;

        for start in (0..3).map(|i| i * 5) {
            let ids: Vec<i32> = (start..start + 5).collect();
            let outcome = mt
                .insert(&make_batch(&ids), &make_embeddings(5, 4))
                .await
                .unwrap();
            assert!(outcome.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 8,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = open_writer("test_auto", config).await;
        let embs = make_embeddings(5, 4);

        let first = mt.insert(&make_batch(&[1, 2, 3, 4, 5]), &embs).await.unwrap();
        assert!(first.is_none());

        let second = mt.insert(&make_batch(&[6, 7, 8, 9, 10]), &embs).await.unwrap();
        assert!(second.is_some(), "should auto-flush when row limit exceeded");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_millis(1),
        };
        let (_dir, mut mt) = open_writer("test_due", config).await;

        mt.insert(&make_batch(&[1, 2, 3]), &make_embeddings(3, 4))
            .await
            .unwrap();

        // Outlast the 1 ms interval so the time-based flush becomes due.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}