1use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tracing::info;
5
6use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
7use ailake_core::{AilakeResult, VectorStoragePolicy};
8use ailake_store::Store;
9use arrow_array::RecordBatch;
10use arrow_select::concat::concat_batches;
11
12use crate::writer::TableWriter;
13
/// Thresholds that decide when a [`MemTableWriter`] flushes its buffer.
///
/// A flush is triggered as soon as any one of the limits is reached.
#[derive(Debug, Clone)]
pub struct MemTableConfig {
    /// Flush once the writer's tracked buffered byte count reaches this value.
    pub flush_size_bytes: usize,
    /// Flush once this many embedding rows are buffered.
    pub flush_max_rows: usize,
    /// Longest time buffered rows may wait before `flush_if_due` flushes them.
    pub flush_interval: Duration,
}
27
28impl Default for MemTableConfig {
29 fn default() -> Self {
30 Self {
31 flush_size_bytes: 64 * 1024 * 1024,
32 flush_max_rows: 100_000,
33 flush_interval: Duration::from_secs(30),
34 }
35 }
36}
37
38pub struct MemTableWriter {
60 catalog: Arc<dyn CatalogProvider>,
61 store: Arc<dyn Store>,
62 policy: VectorStoragePolicy,
63 table: TableIdent,
64 config: MemTableConfig,
65
66 pending_batches: Vec<RecordBatch>,
68 pending_embeddings: Vec<Vec<f32>>,
69 buffered_bytes: usize,
70 last_flush: Instant,
71}
72
73impl MemTableWriter {
74 pub fn new(
75 catalog: Arc<dyn CatalogProvider>,
76 store: Arc<dyn Store>,
77 policy: VectorStoragePolicy,
78 table: TableIdent,
79 config: MemTableConfig,
80 ) -> Self {
81 Self {
82 catalog,
83 store,
84 policy,
85 table,
86 config,
87 pending_batches: Vec::new(),
88 pending_embeddings: Vec::new(),
89 buffered_bytes: 0,
90 last_flush: Instant::now(),
91 }
92 }
93
94 pub async fn insert(
97 &mut self,
98 batch: &RecordBatch,
99 embeddings: &[Vec<f32>],
100 ) -> AilakeResult<Option<SnapshotId>> {
101 let row_bytes =
102 embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
103
104 self.pending_batches.push(batch.clone());
105 self.pending_embeddings.extend_from_slice(embeddings);
106 self.buffered_bytes += row_bytes;
107
108 if self.buffered_bytes >= self.config.flush_size_bytes
109 || self.pending_embeddings.len() >= self.config.flush_max_rows
110 {
111 info!(
112 "ailake: MemTable auto-flush triggered — {} rows / {} bytes buffered",
113 self.pending_embeddings.len(),
114 self.buffered_bytes
115 );
116 Ok(Some(self.flush().await?))
117 } else {
118 Ok(None)
119 }
120 }
121
122 pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
125 if self.pending_embeddings.is_empty() {
126 self.last_flush = Instant::now();
127 return Ok(None);
128 }
129 if self.last_flush.elapsed() >= self.config.flush_interval {
130 info!(
131 "ailake: MemTable time-based flush — {} rows / {} bytes buffered (interval={}s)",
132 self.pending_embeddings.len(),
133 self.buffered_bytes,
134 self.config.flush_interval.as_secs()
135 );
136 Ok(Some(self.flush().await?))
137 } else {
138 Ok(None)
139 }
140 }
141
142 pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
146 if self.pending_embeddings.is_empty() {
147 self.last_flush = Instant::now();
148 return Ok(0);
149 }
150
151 let merged = if self.pending_batches.len() == 1 {
153 self.pending_batches.remove(0)
154 } else {
155 concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
156 .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
157 };
158 let embeddings = std::mem::take(&mut self.pending_embeddings);
159 self.pending_batches.clear();
160 self.buffered_bytes = 0;
161 self.last_flush = Instant::now();
162
163 let mut writer = TableWriter::new(
165 self.catalog.clone(),
166 self.store.clone(),
167 self.policy.clone(),
168 self.table.clone(),
169 );
170 writer.write_batch_deferred(&merged, &embeddings).await?;
171 let snap = writer.commit().await?;
172 Ok(snap)
173 }
174
175 pub fn buffered_rows(&self) -> usize {
177 self.pending_embeddings.len()
178 }
179
180 pub fn buffered_bytes(&self) -> usize {
182 self.buffered_bytes
183 }
184
185 pub fn is_full(&self) -> bool {
187 self.buffered_bytes >= self.config.flush_size_bytes
188 || self.pending_embeddings.len() >= self.config.flush_max_rows
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Storage policy shared by every test: 4-dimensional F16 vectors.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            rabitq: None,
        }
    }

    /// Builds an `(id, text)` batch with one row per entry in `ids`.
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let id_column = Int32Array::from(ids.to_vec());
        let text_column = StringArray::from(vec!["chunk"; ids.len()]);
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ]);

        RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(id_column), Arc::new(text_column)],
        )
        .unwrap()
    }

    /// Produces `n` deterministic vectors of length `dim`.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut vectors = Vec::with_capacity(n);
        for row in 0..n {
            let mut vector = Vec::with_capacity(dim);
            for col in 0..dim {
                vector.push(row as f32 + col as f32 * 0.1);
            }
            vectors.push(vector);
        }
        vectors
    }

    /// Creates `table` in `catalog` using the shared test policy.
    async fn setup_table(catalog: &HadoopCatalog, table: &TableIdent) {
        let properties = ailake_catalog::TableProperties {
            policy: make_policy(),
            extra: Default::default(),
        };
        catalog.create_table(table, &properties).await.unwrap();
    }

    /// Creates a temp-dir backed table named `name` and a writer for it.
    ///
    /// The returned `TempDir` must be kept alive for the duration of the test.
    async fn open_writer(
        name: &str,
        config: MemTableConfig,
    ) -> (tempfile::TempDir, MemTableWriter) {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", name);
        setup_table(&catalog, &table).await;

        let writer = MemTableWriter::new(
            catalog.clone(),
            store.clone(),
            make_policy(),
            table.clone(),
            config,
        );
        (dir, writer)
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = open_writer("test_mem", config).await;

        for chunk in 0..3 {
            let ids: Vec<i32> = (chunk * 5..(chunk + 1) * 5).collect();
            let outcome = mt
                .insert(&make_batch(&ids), &make_embeddings(5, 4))
                .await
                .unwrap();
            assert!(outcome.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 8,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = open_writer("test_auto", config).await;

        let embs = make_embeddings(5, 4);

        let first = mt
            .insert(&make_batch(&[1, 2, 3, 4, 5]), &embs)
            .await
            .unwrap();
        assert!(first.is_none());

        let second = mt
            .insert(&make_batch(&[6, 7, 8, 9, 10]), &embs)
            .await
            .unwrap();
        assert!(second.is_some(), "should auto-flush when row limit exceeded");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_millis(1),
        };
        let (_dir, mut mt) = open_writer("test_due", config).await;

        mt.insert(&make_batch(&[1, 2, 3]), &make_embeddings(3, 4))
            .await
            .unwrap();

        // Wait past the 1 ms interval so the time-based flush is due.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}