Skip to main content

ailake_query/
mem_table.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::collections::VecDeque;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tracing::info;
6
7use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
8use ailake_core::{AilakeResult, VectorStoragePolicy};
9use ailake_store::Store;
10use arrow_array::RecordBatch;
11use arrow_select::concat::concat_batches;
12
13use crate::writer::TableWriter;
14
/// Flush thresholds for `MemTableWriter`.
///
/// Each field is an independent trigger: the size and row limits are checked
/// on every `insert`, the interval is checked by `flush_if_due`.
#[derive(Debug, Clone)]
pub struct MemTableConfig {
    /// Byte budget for buffered embeddings. A flush is triggered once the
    /// running estimate reaches this value.
    /// Default: 64 MiB.
    pub flush_size_bytes: usize,
    /// Row budget. A flush is triggered once this many rows are buffered,
    /// independent of their byte size.
    /// Default: 100,000 rows.
    pub flush_max_rows: usize,
    /// Elapsed time after which `flush_if_due` flushes whatever is buffered.
    /// Default: 30 seconds.
    pub flush_interval: Duration,
}

impl Default for MemTableConfig {
    /// 64 MiB of embeddings, 100,000 rows, or 30 seconds — whichever is hit first.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        Self {
            flush_interval: Duration::from_secs(30),
            flush_max_rows: 100_000,
            flush_size_bytes: 64 * MIB,
        }
    }
}
38
39/// In-memory write buffer that batches small inserts before persisting.
40///
41/// Problem: streaming pipelines (Flink, Spark Streaming) emit small
42/// RecordBatches every few seconds. Calling `write_batch_deferred` on each
43/// micro-batch creates many tiny Parquet files and triggers repeated HNSW
44/// builds — both are expensive.
45///
46/// Solution: buffer rows in RAM, flush to a single Parquet shard only when
47/// the buffer reaches the configured size/row/time threshold. The deferred
48/// HNSW build runs once per flush, not once per micro-batch.
49///
50/// # Usage
51///
52/// ```ignore
53/// let mut mt = MemTableWriter::new(catalog, store, policy, table, MemTableConfig::default());
54/// loop {
55///     mt.insert(&batch, &embeddings).await.unwrap();
56///     mt.flush_if_due().await.unwrap();
57/// }
58/// mt.flush().await.unwrap();
59/// ```
60pub struct MemTableWriter {
61    catalog: Arc<dyn CatalogProvider>,
62    store: Arc<dyn Store>,
63    policy: VectorStoragePolicy,
64    table: TableIdent,
65    config: MemTableConfig,
66
67    // Accumulated micro-batches waiting for flush
68    pending_batches: Vec<RecordBatch>,
69    pending_embeddings: Vec<Vec<f32>>,
70    buffered_bytes: usize,
71    last_flush: Instant,
72}
73
74impl MemTableWriter {
75    pub fn new(
76        catalog: Arc<dyn CatalogProvider>,
77        store: Arc<dyn Store>,
78        policy: VectorStoragePolicy,
79        table: TableIdent,
80        config: MemTableConfig,
81    ) -> Self {
82        Self {
83            catalog,
84            store,
85            policy,
86            table,
87            config,
88            pending_batches: Vec::new(),
89            pending_embeddings: Vec::new(),
90            buffered_bytes: 0,
91            last_flush: Instant::now(),
92        }
93    }
94
95    /// Buffer a micro-batch. Flushes automatically if size or row threshold exceeded.
96    /// Returns `Some(snapshot_id)` when an automatic flush occurred, `None` otherwise.
97    pub async fn insert(
98        &mut self,
99        batch: &RecordBatch,
100        embeddings: &[Vec<f32>],
101    ) -> AilakeResult<Option<SnapshotId>> {
102        let row_bytes =
103            embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
104
105        self.pending_batches.push(batch.clone());
106        self.pending_embeddings.extend_from_slice(embeddings);
107        self.buffered_bytes += row_bytes;
108
109        if self.buffered_bytes >= self.config.flush_size_bytes
110            || self.pending_embeddings.len() >= self.config.flush_max_rows
111        {
112            info!(
113                "ailake: MemTable auto-flush triggered — {} rows / {} bytes buffered",
114                self.pending_embeddings.len(),
115                self.buffered_bytes
116            );
117            Ok(Some(self.flush().await?))
118        } else {
119            Ok(None)
120        }
121    }
122
123    /// Flush if `flush_interval` has elapsed since the last flush.
124    /// Returns `Some(snapshot_id)` when a flush occurred, `None` otherwise.
125    pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
126        if self.pending_embeddings.is_empty() {
127            self.last_flush = Instant::now();
128            return Ok(None);
129        }
130        if self.last_flush.elapsed() >= self.config.flush_interval {
131            info!(
132                "ailake: MemTable time-based flush — {} rows / {} bytes buffered (interval={}s)",
133                self.pending_embeddings.len(),
134                self.buffered_bytes,
135                self.config.flush_interval.as_secs()
136            );
137            Ok(Some(self.flush().await?))
138        } else {
139            Ok(None)
140        }
141    }
142
143    /// Flush all buffered data immediately, even if thresholds are not met.
144    /// Returns the committed `SnapshotId`. Calling on an empty buffer is a no-op
145    /// and returns `SnapshotId::default()`.
146    pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
147        if self.pending_embeddings.is_empty() {
148            self.last_flush = Instant::now();
149            return Ok(0);
150        }
151
152        // Concatenate accumulated micro-batches into one shard batch.
153        let merged = if self.pending_batches.len() == 1 {
154            self.pending_batches.remove(0)
155        } else {
156            concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
157                .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
158        };
159        let embeddings = std::mem::take(&mut self.pending_embeddings);
160        self.pending_batches.clear();
161        self.buffered_bytes = 0;
162        self.last_flush = Instant::now();
163
164        // Delegate to TableWriter: Parquet-only write + deferred HNSW build.
165        let mut writer = TableWriter::new(
166            self.catalog.clone(),
167            self.store.clone(),
168            self.policy.clone(),
169            self.table.clone(),
170        );
171        writer.write_batch_deferred(&merged, &embeddings).await?;
172        let snap = writer.commit().await?;
173        Ok(snap)
174    }
175
176    /// Number of rows currently in the buffer.
177    pub fn buffered_rows(&self) -> usize {
178        self.pending_embeddings.len()
179    }
180
181    /// Estimated byte size of the embedding data currently in the buffer.
182    pub fn buffered_bytes(&self) -> usize {
183        self.buffered_bytes
184    }
185
186    /// True when a size or row-count threshold has been reached.
187    pub fn is_full(&self) -> bool {
188        self.buffered_bytes >= self.config.flush_size_bytes
189            || self.pending_embeddings.len() >= self.config.flush_max_rows
190    }
191}
192
// ─── WorkingMemoryBuffer ──────────────────────────────────────────────────────
194
/// One remembered item held by a `WorkingMemoryBuffer`.
#[derive(Debug, Clone)]
pub struct WorkingMemoryEntry {
    /// The remembered text (a chunk of text or a tool call summary).
    pub text: String,
    /// Vector compared against the query during similarity search.
    pub embedding: Vec<f32>,
    /// Agent-assigned importance score, intended range 0.0–1.0 and meant to
    /// feed hybrid scoring. `WorkingMemoryBuffer::search` does not read it.
    pub importance: f32,
}
205
206/// Bounded in-memory buffer for agent short-term memory.
207///
208/// Stores the N most recent entries (text + embedding). When full, the oldest
209/// entry is evicted on each `push`. Supports brute-force flat scan (`search`)
210/// and draining all entries to an AI-Lake table (`drain_to_table`).
211///
212/// # Cascade pattern
213///
214/// Short-term agents use only `WorkingMemoryBuffer`. Long-term agents cascade:
215/// 1. `search` queries the buffer first (fast, recent).
216/// 2. When full, `drain_to_table` persists to AI-Lake; continue searching via
217///    `ailake::search` for historical context.
218///
219/// # Example
220///
221/// ```ignore
222/// let mut wm = WorkingMemoryBuffer::new(100);
223/// wm.push("Meeting notes: …", embedding, 0.8);
224/// let top3 = wm.search(&query_vec, 3);
225/// if wm.is_full() {
226///     wm.drain_to_table(&mut writer).await?;
227///     writer.commit().await?;
228/// }
229/// ```
230pub struct WorkingMemoryBuffer {
231    max_rows: usize,
232    entries: VecDeque<WorkingMemoryEntry>,
233}
234
235impl WorkingMemoryBuffer {
236    /// Create buffer with at most `max_rows` entries. Evicts oldest on overflow.
237    pub fn new(max_rows: usize) -> Self {
238        Self {
239            max_rows: max_rows.max(1),
240            entries: VecDeque::with_capacity(max_rows.min(4096)),
241        }
242    }
243
244    /// Add entry. If at capacity, evicts the oldest entry (FIFO).
245    pub fn push(&mut self, text: impl Into<String>, embedding: Vec<f32>, importance: f32) {
246        if self.entries.len() >= self.max_rows {
247            self.entries.pop_front();
248        }
249        self.entries.push_back(WorkingMemoryEntry {
250            text: text.into(),
251            embedding,
252            importance: importance.clamp(0.0, 1.0),
253        });
254    }
255
256    /// Brute-force cosine similarity scan. Returns `(distance, entry)` pairs
257    /// sorted ascending (smallest distance = most similar). Distances in `[0, 2]`.
258    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(f32, &WorkingMemoryEntry)> {
259        if self.entries.is_empty() || top_k == 0 {
260            return vec![];
261        }
262        let q_norm = l2_norm(query);
263        let mut scored: Vec<(f32, usize)> = self
264            .entries
265            .iter()
266            .enumerate()
267            .map(|(i, e)| {
268                let v_norm = l2_norm(&e.embedding);
269                let dot: f32 = query.iter().zip(&e.embedding).map(|(a, b)| a * b).sum();
270                let cos_sim = if q_norm * v_norm < f32::EPSILON {
271                    0.0
272                } else {
273                    dot / (q_norm * v_norm)
274                };
275                // cosine distance: lower = more similar
276                (1.0 - cos_sim, i)
277            })
278            .collect();
279
280        scored.sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
281        scored.truncate(top_k);
282        scored
283            .into_iter()
284            .map(|(dist, idx)| (dist, &self.entries[idx]))
285            .collect()
286    }
287
288    /// Write all buffered entries to an AI-Lake table and clear the buffer.
289    ///
290    /// Creates a RecordBatch with a `chunk_text` column. Call `writer.commit()`
291    /// after to persist the snapshot.
292    pub async fn drain_to_table(&mut self, writer: &mut TableWriter) -> AilakeResult<()> {
293        use arrow_array::StringArray;
294        use arrow_schema::{DataType, Field, Schema};
295
296        if self.entries.is_empty() {
297            return Ok(());
298        }
299
300        let texts: Vec<&str> = self.entries.iter().map(|e| e.text.as_str()).collect();
301        let embeddings: Vec<Vec<f32>> = self.entries.iter().map(|e| e.embedding.clone()).collect();
302        let importance_vals: Vec<f32> = self.entries.iter().map(|e| e.importance).collect();
303
304        let schema = Arc::new(Schema::new(vec![
305            Field::new("chunk_text", DataType::Utf8, false),
306            Field::new("importance_score", DataType::Float32, false),
307        ]));
308        let batch = arrow_array::RecordBatch::try_new(
309            schema,
310            vec![
311                Arc::new(StringArray::from(texts)) as _,
312                Arc::new(arrow_array::Float32Array::from(importance_vals)) as _,
313            ],
314        )
315        .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?;
316
317        writer.write_batch(&batch, &embeddings).await?;
318        self.entries.clear();
319        Ok(())
320    }
321
322    pub fn len(&self) -> usize {
323        self.entries.len()
324    }
325
326    pub fn is_empty(&self) -> bool {
327        self.entries.is_empty()
328    }
329
330    /// True when the buffer has reached its configured capacity.
331    pub fn is_full(&self) -> bool {
332        self.entries.len() >= self.max_rows
333    }
334}
335
/// Euclidean (L2) length of `v`: the square root of the sum of squared
/// components. An empty slice has length zero.
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
339
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Everything a test needs to talk to a freshly created table.
    struct TestEnv {
        /// Held only to keep the temporary directory alive for the test.
        _dir: tempfile::TempDir,
        store: Arc<dyn Store>,
        catalog: Arc<HadoopCatalog>,
        table: TableIdent,
    }

    /// Policy used by the `MemTableWriter` tests: 4-dim, Euclidean, F16.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            ivf_residual: false,
            embedding_model: None,
            modality: None,
            partition_by: None,
            partition_value: None,
            partition_column_type: None,
            partition_fields: vec![],
        }
    }

    /// Two-column batch (`id`, `text`) with one row per id; every text is "chunk".
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let id_col = Int32Array::from(ids.to_vec());
        let text_col = StringArray::from(vec!["chunk"; ids.len()]);
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ]);
        RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(id_col), Arc::new(text_col)],
        )
        .unwrap()
    }

    /// `n` deterministic vectors of length `dim`.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for row in 0..n {
            let vector: Vec<f32> = (0..dim)
                .map(|col| row as f32 + col as f32 * 0.1)
                .collect();
            rows.push(vector);
        }
        rows
    }

    /// Local store + Hadoop catalog in a temp dir, with `default.<table_name>`
    /// already created under `policy`.
    async fn test_env(table_name: &str, policy: VectorStoragePolicy) -> TestEnv {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", table_name);

        let props = ailake_catalog::TableProperties {
            policy,
            extra: Default::default(),
            format_version: 2,
            partition_column_type: None,
        };
        catalog.create_table(&table, &props).await.unwrap();

        TestEnv {
            _dir: dir,
            store,
            catalog,
            table,
        }
    }

    /// `MemTableWriter` over the environment's table, using `make_policy()`.
    fn mem_writer(env: &TestEnv, config: MemTableConfig) -> MemTableWriter {
        MemTableWriter::new(
            env.catalog.clone(),
            env.store.clone(),
            make_policy(),
            env.table.clone(),
            config,
        )
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let env = test_env("test_mem", make_policy()).await;
        let mut mt = mem_writer(
            &env,
            MemTableConfig {
                flush_size_bytes: 1024 * 1024,
                flush_max_rows: 1000,
                flush_interval: Duration::from_secs(60),
            },
        );

        // Three micro-batches of five rows each stay below every threshold.
        for start in [0, 5, 10] {
            let ids: Vec<i32> = (start..start + 5).collect();
            let outcome = mt
                .insert(&make_batch(&ids), &make_embeddings(5, 4))
                .await
                .unwrap();
            assert!(outcome.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let env = test_env("test_auto", make_policy()).await;
        let mut mt = mem_writer(
            &env,
            MemTableConfig {
                flush_size_bytes: 1024 * 1024 * 1024,
                flush_max_rows: 8,
                flush_interval: Duration::from_secs(60),
            },
        );
        let embs = make_embeddings(5, 4);

        // 5 rows: still under the 8-row limit.
        let first = mt.insert(&make_batch(&[1, 2, 3, 4, 5]), &embs).await.unwrap();
        assert!(first.is_none());

        // 10 rows: limit crossed, insert must flush on its own.
        let second = mt
            .insert(&make_batch(&[6, 7, 8, 9, 10]), &embs)
            .await
            .unwrap();
        assert!(second.is_some(), "should auto-flush when row limit exceeded");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[test]
    fn working_memory_evicts_oldest() {
        let mut wm = WorkingMemoryBuffer::new(3);
        for (text, embedding) in [
            ("a", vec![1.0, 0.0]),
            ("b", vec![0.0, 1.0]),
            ("c", vec![1.0, 1.0]),
        ] {
            wm.push(text, embedding, 1.0);
        }
        assert_eq!(wm.len(), 3);
        assert!(wm.is_full());

        wm.push("d", vec![0.5, 0.5], 1.0);
        assert_eq!(wm.len(), 3);
        // "a" should be evicted
        let has_text = |needle: &str| wm.entries.iter().any(|e| e.text == needle);
        assert!(!has_text("a"));
        assert!(has_text("d"));
    }

    #[test]
    fn working_memory_search_ranks_similar_first() {
        let mut wm = WorkingMemoryBuffer::new(10);
        wm.push("near", vec![1.0, 0.0, 0.0], 1.0);
        wm.push("far", vec![0.0, 1.0, 0.0], 1.0);
        wm.push("very far", vec![0.0, 0.0, 1.0], 1.0);

        let results = wm.search(&[1.0, 0.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        // "near" should have smallest cosine distance
        let (best_dist, best_entry) = &results[0];
        assert_eq!(best_entry.text, "near");
        assert!(*best_dist < results[1].0);
    }

    #[test]
    fn working_memory_search_empty() {
        let wm = WorkingMemoryBuffer::new(10);
        let hits = wm.search(&[1.0, 0.0], 5);
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn working_memory_drain_to_table() {
        // Same as `make_policy()` except for a 3-dim cosine space without raw vectors.
        let policy = VectorStoragePolicy {
            dim: 3,
            metric: ailake_core::VectorMetric::Cosine,
            keep_raw_for_reranking: false,
            ..make_policy()
        };
        let env = test_env("test_wm_drain", policy.clone()).await;

        let mut wm = WorkingMemoryBuffer::new(5);
        wm.push("memory one", vec![1.0, 0.0, 0.0], 0.9);
        wm.push("memory two", vec![0.0, 1.0, 0.0], 0.5);

        let mut writer = TableWriter::new(
            Arc::clone(&env.catalog) as Arc<dyn ailake_catalog::CatalogProvider>,
            Arc::clone(&env.store) as Arc<dyn ailake_store::Store>,
            policy,
            env.table.clone(),
        );
        wm.drain_to_table(&mut writer).await.unwrap();
        assert!(wm.is_empty());

        let snap = writer.commit().await.unwrap();
        assert!(snap > 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let env = test_env("test_due", make_policy()).await;
        let mut mt = mem_writer(
            &env,
            MemTableConfig {
                flush_size_bytes: 1024 * 1024 * 1024,
                flush_max_rows: 1000,
                flush_interval: Duration::from_millis(1),
            },
        );

        mt.insert(&make_batch(&[1, 2, 3]), &make_embeddings(3, 4))
            .await
            .unwrap();

        // Wait past the 1 ms interval so the buffered rows are overdue.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}