1use std::collections::VecDeque;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tracing::info;
6
7use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
8use ailake_core::{AilakeResult, VectorStoragePolicy};
9use ailake_store::Store;
10use arrow_array::RecordBatch;
11use arrow_select::concat::concat_batches;
12
13use crate::writer::TableWriter;
14
/// Thresholds that decide when a [`MemTableWriter`] flushes its buffer.
#[derive(Clone, Debug)]
pub struct MemTableConfig {
    /// Byte threshold: `MemTableWriter::insert` flushes once its buffered-byte
    /// estimate reaches this value.
    pub flush_size_bytes: usize,
    /// Row threshold: `MemTableWriter::insert` flushes once this many
    /// embeddings are buffered.
    pub flush_max_rows: usize,
    /// Minimum time since the last flush before `MemTableWriter::flush_if_due`
    /// writes out buffered rows.
    pub flush_interval: Duration,
}
28
29impl Default for MemTableConfig {
30 fn default() -> Self {
31 Self {
32 flush_size_bytes: 64 * 1024 * 1024,
33 flush_max_rows: 100_000,
34 flush_interval: Duration::from_secs(30),
35 }
36 }
37}
38
39pub struct MemTableWriter {
61 catalog: Arc<dyn CatalogProvider>,
62 store: Arc<dyn Store>,
63 policy: VectorStoragePolicy,
64 table: TableIdent,
65 config: MemTableConfig,
66
67 pending_batches: Vec<RecordBatch>,
69 pending_embeddings: Vec<Vec<f32>>,
70 buffered_bytes: usize,
71 last_flush: Instant,
72}
73
74impl MemTableWriter {
75 pub fn new(
76 catalog: Arc<dyn CatalogProvider>,
77 store: Arc<dyn Store>,
78 policy: VectorStoragePolicy,
79 table: TableIdent,
80 config: MemTableConfig,
81 ) -> Self {
82 Self {
83 catalog,
84 store,
85 policy,
86 table,
87 config,
88 pending_batches: Vec::new(),
89 pending_embeddings: Vec::new(),
90 buffered_bytes: 0,
91 last_flush: Instant::now(),
92 }
93 }
94
95 pub async fn insert(
98 &mut self,
99 batch: &RecordBatch,
100 embeddings: &[Vec<f32>],
101 ) -> AilakeResult<Option<SnapshotId>> {
102 let row_bytes =
103 embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
104
105 self.pending_batches.push(batch.clone());
106 self.pending_embeddings.extend_from_slice(embeddings);
107 self.buffered_bytes += row_bytes;
108
109 if self.buffered_bytes >= self.config.flush_size_bytes
110 || self.pending_embeddings.len() >= self.config.flush_max_rows
111 {
112 info!(
113 "ailake: MemTable auto-flush triggered — {} rows / {} bytes buffered",
114 self.pending_embeddings.len(),
115 self.buffered_bytes
116 );
117 Ok(Some(self.flush().await?))
118 } else {
119 Ok(None)
120 }
121 }
122
123 pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
126 if self.pending_embeddings.is_empty() {
127 self.last_flush = Instant::now();
128 return Ok(None);
129 }
130 if self.last_flush.elapsed() >= self.config.flush_interval {
131 info!(
132 "ailake: MemTable time-based flush — {} rows / {} bytes buffered (interval={}s)",
133 self.pending_embeddings.len(),
134 self.buffered_bytes,
135 self.config.flush_interval.as_secs()
136 );
137 Ok(Some(self.flush().await?))
138 } else {
139 Ok(None)
140 }
141 }
142
143 pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
147 if self.pending_embeddings.is_empty() {
148 self.last_flush = Instant::now();
149 return Ok(0);
150 }
151
152 let merged = if self.pending_batches.len() == 1 {
154 self.pending_batches.remove(0)
155 } else {
156 concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
157 .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
158 };
159 let embeddings = std::mem::take(&mut self.pending_embeddings);
160 self.pending_batches.clear();
161 self.buffered_bytes = 0;
162 self.last_flush = Instant::now();
163
164 let mut writer = TableWriter::new(
166 self.catalog.clone(),
167 self.store.clone(),
168 self.policy.clone(),
169 self.table.clone(),
170 );
171 writer.write_batch_deferred(&merged, &embeddings).await?;
172 let snap = writer.commit().await?;
173 Ok(snap)
174 }
175
176 pub fn buffered_rows(&self) -> usize {
178 self.pending_embeddings.len()
179 }
180
181 pub fn buffered_bytes(&self) -> usize {
183 self.buffered_bytes
184 }
185
186 pub fn is_full(&self) -> bool {
188 self.buffered_bytes >= self.config.flush_size_bytes
189 || self.pending_embeddings.len() >= self.config.flush_max_rows
190 }
191}
192
/// One item held in a [`WorkingMemoryBuffer`].
#[derive(Clone, Debug)]
pub struct WorkingMemoryEntry {
    /// The stored text.
    pub text: String,
    /// Embedding vector used for similarity search.
    pub embedding: Vec<f32>,
    /// Importance score; `WorkingMemoryBuffer::push` clamps it to `[0.0, 1.0]`.
    pub importance: f32,
}
205
206pub struct WorkingMemoryBuffer {
231 max_rows: usize,
232 entries: VecDeque<WorkingMemoryEntry>,
233}
234
235impl WorkingMemoryBuffer {
236 pub fn new(max_rows: usize) -> Self {
238 Self {
239 max_rows: max_rows.max(1),
240 entries: VecDeque::with_capacity(max_rows.min(4096)),
241 }
242 }
243
244 pub fn push(&mut self, text: impl Into<String>, embedding: Vec<f32>, importance: f32) {
246 if self.entries.len() >= self.max_rows {
247 self.entries.pop_front();
248 }
249 self.entries.push_back(WorkingMemoryEntry {
250 text: text.into(),
251 embedding,
252 importance: importance.clamp(0.0, 1.0),
253 });
254 }
255
256 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(f32, &WorkingMemoryEntry)> {
259 if self.entries.is_empty() || top_k == 0 {
260 return vec![];
261 }
262 let q_norm = l2_norm(query);
263 let mut scored: Vec<(f32, usize)> = self
264 .entries
265 .iter()
266 .enumerate()
267 .map(|(i, e)| {
268 let v_norm = l2_norm(&e.embedding);
269 let dot: f32 = query.iter().zip(&e.embedding).map(|(a, b)| a * b).sum();
270 let cos_sim = if q_norm * v_norm < f32::EPSILON {
271 0.0
272 } else {
273 dot / (q_norm * v_norm)
274 };
275 (1.0 - cos_sim, i)
277 })
278 .collect();
279
280 scored.sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
281 scored.truncate(top_k);
282 scored
283 .into_iter()
284 .map(|(dist, idx)| (dist, &self.entries[idx]))
285 .collect()
286 }
287
288 pub async fn drain_to_table(&mut self, writer: &mut TableWriter) -> AilakeResult<()> {
293 use arrow_array::StringArray;
294 use arrow_schema::{DataType, Field, Schema};
295
296 if self.entries.is_empty() {
297 return Ok(());
298 }
299
300 let texts: Vec<&str> = self.entries.iter().map(|e| e.text.as_str()).collect();
301 let embeddings: Vec<Vec<f32>> = self.entries.iter().map(|e| e.embedding.clone()).collect();
302 let importance_vals: Vec<f32> = self.entries.iter().map(|e| e.importance).collect();
303
304 let schema = Arc::new(Schema::new(vec![
305 Field::new("chunk_text", DataType::Utf8, false),
306 Field::new("importance_score", DataType::Float32, false),
307 ]));
308 let batch = arrow_array::RecordBatch::try_new(
309 schema,
310 vec![
311 Arc::new(StringArray::from(texts)) as _,
312 Arc::new(arrow_array::Float32Array::from(importance_vals)) as _,
313 ],
314 )
315 .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?;
316
317 writer.write_batch(&batch, &embeddings).await?;
318 self.entries.clear();
319 Ok(())
320 }
321
322 pub fn len(&self) -> usize {
323 self.entries.len()
324 }
325
326 pub fn is_empty(&self) -> bool {
327 self.entries.is_empty()
328 }
329
330 pub fn is_full(&self) -> bool {
332 self.entries.len() >= self.max_rows
333 }
334}
335
/// Euclidean (L2) norm of `v`: the square root of the sum of squared
/// components.
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
339
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Store and catalog backed by a temporary directory.
    ///
    /// `_dir` is declared last so the directory is removed only after the
    /// catalog and store have been dropped.
    struct Warehouse {
        catalog: Arc<HadoopCatalog>,
        store: Arc<dyn Store>,
        _dir: tempfile::TempDir,
    }

    /// Opens a fresh local store with a Hadoop catalog rooted at "warehouse".
    fn open_warehouse() -> Warehouse {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        Warehouse {
            catalog,
            store,
            _dir: dir,
        }
    }

    /// Policy used by the mem-table tests: 4-dimensional, Euclidean, F16.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            ivf_residual: false,
            embedding_model: None,
            modality: None,
            partition_by: None,
            partition_value: None,
            partition_column_type: None,
            partition_fields: vec![],
        }
    }

    /// Builds an `(id, text)` batch with one row per id; every text is "chunk".
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let texts: Vec<&str> = vec!["chunk"; ids.len()];
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids.to_vec())),
                Arc::new(StringArray::from(texts)),
            ],
        )
        .unwrap()
    }

    /// Generates `n` deterministic embeddings of length `dim`.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|row| {
                (0..dim)
                    .map(|col| row as f32 + col as f32 * 0.1)
                    .collect()
            })
            .collect()
    }

    /// Creates `table` in `catalog` with the given policy (format version 2).
    async fn create_table_with(
        catalog: &HadoopCatalog,
        table: &TableIdent,
        policy: VectorStoragePolicy,
    ) {
        let props = ailake_catalog::TableProperties {
            policy,
            extra: Default::default(),
            format_version: 2,
            partition_column_type: None,
        };
        catalog.create_table(table, &props).await.unwrap();
    }

    /// Creates `table` in `catalog` using `make_policy()`.
    async fn setup_table(catalog: &HadoopCatalog, table: &TableIdent) {
        create_table_with(catalog, table, make_policy()).await;
    }

    /// Builds a `MemTableWriter` over the warehouse using `make_policy()`.
    fn mem_table(env: &Warehouse, table: &TableIdent, config: MemTableConfig) -> MemTableWriter {
        MemTableWriter::new(
            env.catalog.clone(),
            env.store.clone(),
            make_policy(),
            table.clone(),
            config,
        )
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let env = open_warehouse();
        let table = TableIdent::new("default", "test_mem");
        setup_table(&env.catalog, &table).await;

        let mut mt = mem_table(
            &env,
            &table,
            MemTableConfig {
                flush_size_bytes: 1024 * 1024,
                flush_max_rows: 1000,
                flush_interval: Duration::from_secs(60),
            },
        );

        // Three inserts of five rows each stay below every threshold.
        for chunk in 0..3 {
            let ids: Vec<i32> = (chunk * 5..(chunk + 1) * 5).collect();
            let outcome = mt
                .insert(&make_batch(&ids), &make_embeddings(5, 4))
                .await
                .unwrap();
            assert!(outcome.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let env = open_warehouse();
        let table = TableIdent::new("default", "test_auto");
        setup_table(&env.catalog, &table).await;

        let mut mt = mem_table(
            &env,
            &table,
            MemTableConfig {
                flush_size_bytes: 1024 * 1024 * 1024,
                flush_max_rows: 8,
                flush_interval: Duration::from_secs(60),
            },
        );

        let embs = make_embeddings(5, 4);

        // Five rows: below the eight-row limit.
        let first = mt
            .insert(&make_batch(&[1, 2, 3, 4, 5]), &embs)
            .await
            .unwrap();
        assert!(first.is_none());

        // Ten rows: the limit is crossed, so this insert flushes.
        let snap = mt
            .insert(&make_batch(&[6, 7, 8, 9, 10]), &embs)
            .await
            .unwrap();
        assert!(snap.is_some(), "should auto-flush when row limit exceeded");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[test]
    fn working_memory_evicts_oldest() {
        let mut wm = WorkingMemoryBuffer::new(3);
        wm.push("a", vec![1.0, 0.0], 1.0);
        wm.push("b", vec![0.0, 1.0], 1.0);
        wm.push("c", vec![1.0, 1.0], 1.0);
        assert_eq!(wm.len(), 3);
        assert!(wm.is_full());

        // A fourth push must drop "a" and keep the length at three.
        wm.push("d", vec![0.5, 0.5], 1.0);
        assert_eq!(wm.len(), 3);
        assert!(!wm.entries.iter().any(|e| e.text == "a"));
        assert!(wm.entries.iter().any(|e| e.text == "d"));
    }

    #[test]
    fn working_memory_search_ranks_similar_first() {
        let mut wm = WorkingMemoryBuffer::new(10);
        wm.push("near", vec![1.0, 0.0, 0.0], 1.0);
        wm.push("far", vec![0.0, 1.0, 0.0], 1.0);
        wm.push("very far", vec![0.0, 0.0, 1.0], 1.0);

        let results = wm.search(&[1.0, 0.0, 0.0], 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].1.text, "near");
        assert!(results[0].0 < results[1].0);
    }

    #[test]
    fn working_memory_search_empty() {
        let wm = WorkingMemoryBuffer::new(10);
        assert!(wm.search(&[1.0, 0.0], 5).is_empty());
    }

    #[tokio::test]
    async fn working_memory_drain_to_table() {
        let env = open_warehouse();
        let table = TableIdent::new("default", "test_wm_drain");

        // Same as `make_policy()` except for dimension, metric and raw retention.
        let policy = VectorStoragePolicy {
            dim: 3,
            metric: VectorMetric::Cosine,
            keep_raw_for_reranking: false,
            ..make_policy()
        };
        create_table_with(&env.catalog, &table, policy.clone()).await;

        let mut wm = WorkingMemoryBuffer::new(5);
        wm.push("memory one", vec![1.0, 0.0, 0.0], 0.9);
        wm.push("memory two", vec![0.0, 1.0, 0.0], 0.5);

        let mut writer = TableWriter::new(
            Arc::clone(&env.catalog) as Arc<dyn ailake_catalog::CatalogProvider>,
            Arc::clone(&env.store) as Arc<dyn ailake_store::Store>,
            policy,
            table,
        );
        wm.drain_to_table(&mut writer).await.unwrap();
        assert!(wm.is_empty());

        let snap = writer.commit().await.unwrap();
        assert!(snap > 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let env = open_warehouse();
        let table = TableIdent::new("default", "test_due");
        setup_table(&env.catalog, &table).await;

        let mut mt = mem_table(
            &env,
            &table,
            MemTableConfig {
                flush_size_bytes: 1024 * 1024 * 1024,
                flush_max_rows: 1000,
                flush_interval: Duration::from_millis(1),
            },
        );

        mt.insert(&make_batch(&[1, 2, 3]), &make_embeddings(3, 4))
            .await
            .unwrap();

        // Wait past the 1 ms interval so the time-based flush is due.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}