1use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tracing::info;
5
6use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
7use ailake_core::{AilakeResult, VectorStoragePolicy};
8use ailake_store::Store;
9use arrow_array::RecordBatch;
10use arrow_select::concat::concat_batches;
11
12use crate::writer::TableWriter;
13
/// Thresholds that decide when a `MemTableWriter` writes its buffer out.
///
/// The byte and row limits are checked on every insert; the interval is checked
/// whenever the writer is asked whether a time-based flush is due.
#[derive(Debug, Clone)]
pub struct MemTableConfig {
    /// Flush once the estimated size of the buffered embeddings reaches this many bytes.
    pub flush_size_bytes: usize,
    /// Flush once this many rows (one per embedding) are buffered.
    pub flush_max_rows: usize,
    /// Minimum time since the last flush before a time-based flush is performed.
    pub flush_interval: Duration,
}

impl Default for MemTableConfig {
    /// 64 MiB of embeddings, 100 000 rows, or a 30 second interval.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        let flush_interval = Duration::from_secs(30);
        Self {
            flush_interval,
            flush_max_rows: 100_000,
            flush_size_bytes: 64 * MIB,
        }
    }
}
37
38pub struct MemTableWriter {
60 catalog: Arc<dyn CatalogProvider>,
61 store: Arc<dyn Store>,
62 policy: VectorStoragePolicy,
63 table: TableIdent,
64 config: MemTableConfig,
65
66 pending_batches: Vec<RecordBatch>,
68 pending_embeddings: Vec<Vec<f32>>,
69 buffered_bytes: usize,
70 last_flush: Instant,
71}
72
73impl MemTableWriter {
74 pub fn new(
75 catalog: Arc<dyn CatalogProvider>,
76 store: Arc<dyn Store>,
77 policy: VectorStoragePolicy,
78 table: TableIdent,
79 config: MemTableConfig,
80 ) -> Self {
81 Self {
82 catalog,
83 store,
84 policy,
85 table,
86 config,
87 pending_batches: Vec::new(),
88 pending_embeddings: Vec::new(),
89 buffered_bytes: 0,
90 last_flush: Instant::now(),
91 }
92 }
93
94 pub async fn insert(
97 &mut self,
98 batch: &RecordBatch,
99 embeddings: &[Vec<f32>],
100 ) -> AilakeResult<Option<SnapshotId>> {
101 let row_bytes =
102 embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
103
104 self.pending_batches.push(batch.clone());
105 self.pending_embeddings.extend_from_slice(embeddings);
106 self.buffered_bytes += row_bytes;
107
108 if self.buffered_bytes >= self.config.flush_size_bytes
109 || self.pending_embeddings.len() >= self.config.flush_max_rows
110 {
111 info!(
112 "ailake: MemTable auto-flush triggered — {} rows / {} bytes buffered",
113 self.pending_embeddings.len(),
114 self.buffered_bytes
115 );
116 Ok(Some(self.flush().await?))
117 } else {
118 Ok(None)
119 }
120 }
121
122 pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
125 if self.pending_embeddings.is_empty() {
126 self.last_flush = Instant::now();
127 return Ok(None);
128 }
129 if self.last_flush.elapsed() >= self.config.flush_interval {
130 info!(
131 "ailake: MemTable time-based flush — {} rows / {} bytes buffered (interval={}s)",
132 self.pending_embeddings.len(),
133 self.buffered_bytes,
134 self.config.flush_interval.as_secs()
135 );
136 Ok(Some(self.flush().await?))
137 } else {
138 Ok(None)
139 }
140 }
141
142 pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
146 if self.pending_embeddings.is_empty() {
147 self.last_flush = Instant::now();
148 return Ok(0);
149 }
150
151 let merged = if self.pending_batches.len() == 1 {
153 self.pending_batches.remove(0)
154 } else {
155 concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
156 .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
157 };
158 let embeddings = std::mem::take(&mut self.pending_embeddings);
159 self.pending_batches.clear();
160 self.buffered_bytes = 0;
161 self.last_flush = Instant::now();
162
163 let mut writer = TableWriter::new(
165 self.catalog.clone(),
166 self.store.clone(),
167 self.policy.clone(),
168 self.table.clone(),
169 );
170 writer.write_batch_deferred(&merged, &embeddings).await?;
171 let snap = writer.commit().await?;
172 Ok(snap)
173 }
174
175 pub fn buffered_rows(&self) -> usize {
177 self.pending_embeddings.len()
178 }
179
180 pub fn buffered_bytes(&self) -> usize {
182 self.buffered_bytes
183 }
184
185 pub fn is_full(&self) -> bool {
187 self.buffered_bytes >= self.config.flush_size_bytes
188 || self.pending_embeddings.len() >= self.config.flush_max_rows
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Storage policy shared by every test table: 4-dim F16 vectors, Euclidean metric.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            ivf_residual: false,
            embedding_model: None,
            modality: None,
        }
    }

    /// Builds an `(id Int32, text Utf8)` batch with one row per id; every text is "chunk".
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ]));
        let id_column = Int32Array::from(ids.to_vec());
        let text_column = StringArray::from(vec!["chunk"; ids.len()]);
        RecordBatch::try_new(schema, vec![Arc::new(id_column), Arc::new(text_column)]).unwrap()
    }

    /// `n` deterministic vectors of width `dim`: component `d` of vector `i` is `i + 0.1 * d`.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut vectors = Vec::with_capacity(n);
        for i in 0..n {
            vectors.push((0..dim).map(|d| i as f32 + d as f32 * 0.1).collect());
        }
        vectors
    }

    /// Creates a temp warehouse, registers `default.<name>` in it and returns a writer
    /// for that table. Keep the returned `TempDir` alive for the whole test.
    async fn open_writer(
        name: &str,
        config: MemTableConfig,
    ) -> (tempfile::TempDir, MemTableWriter) {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", name);
        let props = ailake_catalog::TableProperties {
            policy: make_policy(),
            extra: Default::default(),
        };
        catalog.create_table(&table, &props).await.unwrap();

        let writer = MemTableWriter::new(catalog, store, make_policy(), table, config);
        (dir, writer)
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = open_writer("test_mem", config).await;

        for chunk in 0..3 {
            let ids: Vec<i32> = (chunk * 5..(chunk + 1) * 5).collect();
            let outcome = mt
                .insert(&make_batch(&ids), &make_embeddings(5, 4))
                .await
                .unwrap();
            assert!(outcome.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 8,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = open_writer("test_auto", config).await;
        let embs = make_embeddings(5, 4);

        // 5 rows: still under the 8-row limit.
        let first = mt.insert(&make_batch(&[1, 2, 3, 4, 5]), &embs).await.unwrap();
        assert!(first.is_none());

        // 10 rows: crosses the limit and flushes within the same call.
        let snap = mt.insert(&make_batch(&[6, 7, 8, 9, 10]), &embs).await.unwrap();
        assert!(snap.is_some(), "should auto-flush when row limit exceeded");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_millis(1),
        };
        let (_dir, mut mt) = open_writer("test_due", config).await;

        mt.insert(&make_batch(&[1, 2, 3]), &make_embeddings(3, 4))
            .await
            .unwrap();

        // Comfortably longer than the 1 ms flush interval.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}