// ailake_query/mem_table.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
6use ailake_core::{AilakeResult, VectorStoragePolicy};
7use ailake_store::Store;
8use arrow_array::RecordBatch;
9use arrow_select::concat::concat_batches;
10
11use crate::writer::TableWriter;
12
/// Flush thresholds for `MemTableWriter`.
///
/// Each field is an independent trigger: the size and row limits are checked
/// on every insert, the interval only when `flush_if_due` is called.
#[derive(Debug, Clone)]
pub struct MemTableConfig {
    /// Embedding-byte budget; reaching it triggers a flush.
    /// Default: 64 MiB.
    pub flush_size_bytes: usize,
    /// Row budget; reaching it triggers a flush no matter how few bytes
    /// are buffered.
    /// Default: 100,000 rows.
    pub flush_max_rows: usize,
    /// How long buffered rows may wait before `flush_if_due` flushes them.
    /// Default: 30 seconds.
    pub flush_interval: Duration,
}

impl Default for MemTableConfig {
    /// Returns the documented defaults: 64 MiB, 100,000 rows, 30 seconds.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        const DEFAULT_MAX_ROWS: usize = 100_000;
        const DEFAULT_INTERVAL_SECS: u64 = 30;

        MemTableConfig {
            flush_size_bytes: 64 * MIB,
            flush_max_rows: DEFAULT_MAX_ROWS,
            flush_interval: Duration::from_secs(DEFAULT_INTERVAL_SECS),
        }
    }
}
36
37/// In-memory write buffer that batches small inserts before persisting.
38///
39/// Problem: streaming pipelines (Flink, Spark Streaming) emit small
40/// RecordBatches every few seconds. Calling `write_batch_deferred` on each
41/// micro-batch creates many tiny Parquet files and triggers repeated HNSW
42/// builds — both are expensive.
43///
44/// Solution: buffer rows in RAM, flush to a single Parquet shard only when
45/// the buffer reaches the configured size/row/time threshold. The deferred
46/// HNSW build runs once per flush, not once per micro-batch.
47///
48/// # Usage
49///
50/// ```ignore
51/// let mut mt = MemTableWriter::new(catalog, store, policy, table, MemTableConfig::default());
52/// loop {
53///     mt.insert(&batch, &embeddings).await.unwrap();
54///     mt.flush_if_due().await.unwrap();
55/// }
56/// mt.flush().await.unwrap();
57/// ```
58pub struct MemTableWriter {
59    catalog: Arc<dyn CatalogProvider>,
60    store: Arc<dyn Store>,
61    policy: VectorStoragePolicy,
62    table: TableIdent,
63    config: MemTableConfig,
64
65    // Accumulated micro-batches waiting for flush
66    pending_batches: Vec<RecordBatch>,
67    pending_embeddings: Vec<Vec<f32>>,
68    buffered_bytes: usize,
69    last_flush: Instant,
70}
71
72impl MemTableWriter {
73    pub fn new(
74        catalog: Arc<dyn CatalogProvider>,
75        store: Arc<dyn Store>,
76        policy: VectorStoragePolicy,
77        table: TableIdent,
78        config: MemTableConfig,
79    ) -> Self {
80        Self {
81            catalog,
82            store,
83            policy,
84            table,
85            config,
86            pending_batches: Vec::new(),
87            pending_embeddings: Vec::new(),
88            buffered_bytes: 0,
89            last_flush: Instant::now(),
90        }
91    }
92
93    /// Buffer a micro-batch. Flushes automatically if size or row threshold exceeded.
94    /// Returns `Some(snapshot_id)` when an automatic flush occurred, `None` otherwise.
95    pub async fn insert(
96        &mut self,
97        batch: &RecordBatch,
98        embeddings: &[Vec<f32>],
99    ) -> AilakeResult<Option<SnapshotId>> {
100        let row_bytes =
101            embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
102
103        self.pending_batches.push(batch.clone());
104        self.pending_embeddings.extend_from_slice(embeddings);
105        self.buffered_bytes += row_bytes;
106
107        if self.buffered_bytes >= self.config.flush_size_bytes
108            || self.pending_embeddings.len() >= self.config.flush_max_rows
109        {
110            Ok(Some(self.flush().await?))
111        } else {
112            Ok(None)
113        }
114    }
115
116    /// Flush if `flush_interval` has elapsed since the last flush.
117    /// Returns `Some(snapshot_id)` when a flush occurred, `None` otherwise.
118    pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
119        if self.pending_embeddings.is_empty() {
120            self.last_flush = Instant::now();
121            return Ok(None);
122        }
123        if self.last_flush.elapsed() >= self.config.flush_interval {
124            Ok(Some(self.flush().await?))
125        } else {
126            Ok(None)
127        }
128    }
129
130    /// Flush all buffered data immediately, even if thresholds are not met.
131    /// Returns the committed `SnapshotId`. Calling on an empty buffer is a no-op
132    /// and returns `SnapshotId::default()`.
133    pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
134        if self.pending_embeddings.is_empty() {
135            self.last_flush = Instant::now();
136            return Ok(0);
137        }
138
139        // Concatenate accumulated micro-batches into one shard batch.
140        let merged = if self.pending_batches.len() == 1 {
141            self.pending_batches.remove(0)
142        } else {
143            concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
144                .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
145        };
146        let embeddings = std::mem::take(&mut self.pending_embeddings);
147        self.pending_batches.clear();
148        self.buffered_bytes = 0;
149        self.last_flush = Instant::now();
150
151        // Delegate to TableWriter: Parquet-only write + deferred HNSW build.
152        let mut writer = TableWriter::new(
153            self.catalog.clone(),
154            self.store.clone(),
155            self.policy.clone(),
156            self.table.clone(),
157        );
158        writer.write_batch_deferred(&merged, &embeddings).await?;
159        let snap = writer.commit().await?;
160        Ok(snap)
161    }
162
163    /// Number of rows currently in the buffer.
164    pub fn buffered_rows(&self) -> usize {
165        self.pending_embeddings.len()
166    }
167
168    /// Estimated byte size of the embedding data currently in the buffer.
169    pub fn buffered_bytes(&self) -> usize {
170        self.buffered_bytes
171    }
172
173    /// True when a size or row-count threshold has been reached.
174    pub fn is_full(&self) -> bool {
175        self.buffered_bytes >= self.config.flush_size_bytes
176            || self.pending_embeddings.len() >= self.config.flush_max_rows
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Storage policy shared by every test: 4-dimensional F16 vectors.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Builds an `(id, text)` batch with one row per id; every text is "chunk".
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ];
        let id_column = Int32Array::from(ids.to_vec());
        let text_column = StringArray::from(vec!["chunk"; ids.len()]);
        RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(id_column), Arc::new(text_column)],
        )
        .unwrap()
    }

    /// Builds `n` vectors of length `dim`; element `d` of row `i` is `i + 0.1 * d`.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for row in 0..n {
            let vector: Vec<f32> = (0..dim)
                .map(|col| row as f32 + col as f32 * 0.1)
                .collect();
            rows.push(vector);
        }
        rows
    }

    /// Registers `table` in `catalog` with the shared test policy.
    async fn setup_table(catalog: &HadoopCatalog, table: &TableIdent) {
        let props = ailake_catalog::TableProperties {
            policy: make_policy(),
            extra: Default::default(),
        };
        catalog.create_table(table, &props).await.unwrap();
    }

    /// Creates a local warehouse in a fresh temp dir, registers
    /// `default.<name>` and returns a writer for it.
    ///
    /// The `TempDir` guard is returned as well so the directory outlives
    /// the writer; bind it to a named variable in the test.
    async fn new_mem_table(
        name: &str,
        config: MemTableConfig,
    ) -> (tempfile::TempDir, MemTableWriter) {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", name);
        setup_table(&catalog, &table).await;

        let writer = MemTableWriter::new(catalog, store, make_policy(), table, config);
        (dir, writer)
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let limits = MemTableConfig {
            flush_size_bytes: 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = new_mem_table("test_mem", limits).await;

        // Three micro-batches of five rows each stay below every threshold.
        for i in 0..3 {
            let ids: Vec<i32> = (i * 5..(i + 1) * 5).collect();
            let outcome = mt
                .insert(&make_batch(&ids), &make_embeddings(5, 4))
                .await
                .unwrap();
            assert!(outcome.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let limits = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 8,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = new_mem_table("test_auto", limits).await;
        let embs = make_embeddings(5, 4);

        // 5 rows: below the limit of 8.
        let first = make_batch(&[1, 2, 3, 4, 5]);
        assert!(mt.insert(&first, &embs).await.unwrap().is_none());

        // 10 rows: the limit is crossed, so the insert itself flushes.
        let second = make_batch(&[6, 7, 8, 9, 10]);
        let snap = mt.insert(&second, &embs).await.unwrap();
        assert!(snap.is_some(), "should auto-flush when row limit exceeded");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let limits = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_millis(1),
        };
        let (_dir, mut mt) = new_mem_table("test_due", limits).await;

        mt.insert(&make_batch(&[1, 2, 3]), &make_embeddings(3, 4))
            .await
            .unwrap();

        // Wait longer than the 1 ms interval so the age trigger is due.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}