1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
5use ailake_core::{AilakeResult, VectorStoragePolicy};
6use ailake_store::Store;
7use arrow_array::RecordBatch;
8use arrow_select::concat::concat_batches;
9
10use crate::writer::TableWriter;
11
/// Thresholds that decide when a [`MemTableWriter`] flushes its buffer.
///
/// The size and row limits are checked on insert; the interval is checked by
/// [`MemTableWriter::flush_if_due`].
#[derive(Debug, Clone)]
pub struct MemTableConfig {
    /// Buffered byte count at or above which the buffer counts as full.
    ///
    /// The writer's byte accounting covers embedding payload only.
    pub flush_size_bytes: usize,

    /// Buffered row count at or above which the buffer counts as full.
    pub flush_max_rows: usize,

    /// Minimum time since the last flush before a time-based flush fires.
    pub flush_interval: Duration,
}
25
26impl Default for MemTableConfig {
27 fn default() -> Self {
28 Self {
29 flush_size_bytes: 64 * 1024 * 1024,
30 flush_max_rows: 100_000,
31 flush_interval: Duration::from_secs(30),
32 }
33 }
34}
35
36pub struct MemTableWriter {
58 catalog: Arc<dyn CatalogProvider>,
59 store: Arc<dyn Store>,
60 policy: VectorStoragePolicy,
61 table: TableIdent,
62 config: MemTableConfig,
63
64 pending_batches: Vec<RecordBatch>,
66 pending_embeddings: Vec<Vec<f32>>,
67 buffered_bytes: usize,
68 last_flush: Instant,
69}
70
71impl MemTableWriter {
72 pub fn new(
73 catalog: Arc<dyn CatalogProvider>,
74 store: Arc<dyn Store>,
75 policy: VectorStoragePolicy,
76 table: TableIdent,
77 config: MemTableConfig,
78 ) -> Self {
79 Self {
80 catalog,
81 store,
82 policy,
83 table,
84 config,
85 pending_batches: Vec::new(),
86 pending_embeddings: Vec::new(),
87 buffered_bytes: 0,
88 last_flush: Instant::now(),
89 }
90 }
91
92 pub async fn insert(
95 &mut self,
96 batch: &RecordBatch,
97 embeddings: &[Vec<f32>],
98 ) -> AilakeResult<Option<SnapshotId>> {
99 let row_bytes =
100 embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
101
102 self.pending_batches.push(batch.clone());
103 self.pending_embeddings.extend_from_slice(embeddings);
104 self.buffered_bytes += row_bytes;
105
106 if self.buffered_bytes >= self.config.flush_size_bytes
107 || self.pending_embeddings.len() >= self.config.flush_max_rows
108 {
109 Ok(Some(self.flush().await?))
110 } else {
111 Ok(None)
112 }
113 }
114
115 pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
118 if self.pending_embeddings.is_empty() {
119 self.last_flush = Instant::now();
120 return Ok(None);
121 }
122 if self.last_flush.elapsed() >= self.config.flush_interval {
123 Ok(Some(self.flush().await?))
124 } else {
125 Ok(None)
126 }
127 }
128
129 pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
133 if self.pending_embeddings.is_empty() {
134 self.last_flush = Instant::now();
135 return Ok(0);
136 }
137
138 let merged = if self.pending_batches.len() == 1 {
140 self.pending_batches.remove(0)
141 } else {
142 concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
143 .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
144 };
145 let embeddings = std::mem::take(&mut self.pending_embeddings);
146 self.pending_batches.clear();
147 self.buffered_bytes = 0;
148 self.last_flush = Instant::now();
149
150 let mut writer = TableWriter::new(
152 self.catalog.clone(),
153 self.store.clone(),
154 self.policy.clone(),
155 self.table.clone(),
156 );
157 writer.write_batch_deferred(&merged, &embeddings).await?;
158 let snap = writer.commit().await?;
159 Ok(snap)
160 }
161
162 pub fn buffered_rows(&self) -> usize {
164 self.pending_embeddings.len()
165 }
166
167 pub fn buffered_bytes(&self) -> usize {
169 self.buffered_bytes
170 }
171
172 pub fn is_full(&self) -> bool {
174 self.buffered_bytes >= self.config.flush_size_bytes
175 || self.pending_embeddings.len() >= self.config.flush_max_rows
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Policy used by every test: 4-dimensional F16 vectors, Euclidean metric.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: String::from("embedding"),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Builds an `(id, text)` batch with one row per entry in `ids`.
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ];
        let id_column = Int32Array::from(ids.to_vec());
        let text_column = StringArray::from(vec!["chunk"; ids.len()]);

        RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(id_column), Arc::new(text_column)],
        )
        .unwrap()
    }

    /// Builds `n` deterministic vectors of length `dim`.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for row in 0..n {
            let mut vector = Vec::with_capacity(dim);
            for component in 0..dim {
                vector.push(row as f32 + component as f32 * 0.1);
            }
            rows.push(vector);
        }
        rows
    }

    /// Registers `table` in `catalog` with the shared test policy.
    async fn setup_table(catalog: &HadoopCatalog, table: &TableIdent) {
        let properties = ailake_catalog::TableProperties {
            policy: make_policy(),
            extra: Default::default(),
        };
        catalog.create_table(table, &properties).await.unwrap();
    }

    /// Creates a temp-dir backed store and catalog, registers `table`, and
    /// returns a writer for it. The returned `TempDir` must be kept alive for
    /// the duration of the test.
    async fn open_writer(
        table: TableIdent,
        config: MemTableConfig,
    ) -> (tempfile::TempDir, MemTableWriter) {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        setup_table(&catalog, &table).await;

        let writer = MemTableWriter::new(catalog, store, make_policy(), table, config);
        (dir, writer)
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = open_writer(TableIdent::new("default", "test_mem"), config).await;

        // Three inserts of five rows each stay well under both thresholds.
        for start in (0..15).step_by(5) {
            let ids: Vec<i32> = (start..start + 5).collect();
            let outcome = mt
                .insert(&make_batch(&ids), &make_embeddings(5, 4))
                .await
                .unwrap();
            assert!(outcome.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 8,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = open_writer(TableIdent::new("default", "test_auto"), config).await;

        let embs = make_embeddings(5, 4);

        // Five rows: below the limit of eight.
        let first = mt.insert(&make_batch(&[1, 2, 3, 4, 5]), &embs).await.unwrap();
        assert!(first.is_none());

        // Ten rows in total: crosses the limit and must flush.
        let snap = mt.insert(&make_batch(&[6, 7, 8, 9, 10]), &embs).await.unwrap();
        assert!(snap.is_some(), "should auto-flush when row limit exceeded");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_millis(1),
        };
        let (_dir, mut mt) = open_writer(TableIdent::new("default", "test_due"), config).await;

        mt.insert(&make_batch(&[1, 2, 3]), &make_embeddings(3, 4))
            .await
            .unwrap();

        // Wait longer than the 1 ms interval so the time-based flush is due.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}