1use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use ailake_catalog::{CatalogProvider, SnapshotId, TableIdent};
6use ailake_core::{AilakeResult, VectorStoragePolicy};
7use ailake_store::Store;
8use arrow_array::RecordBatch;
9use arrow_select::concat::concat_batches;
10
11use crate::writer::TableWriter;
12
/// Thresholds that decide when a [`MemTableWriter`] flushes its buffer.
#[derive(Debug, Clone)]
pub struct MemTableConfig {
    /// Flush once the writer's `buffered_bytes` counter reaches this value.
    ///
    /// The counter only accounts for embedding payload
    /// (`rows * dim * bytes_per_element`), not the Arrow batch columns.
    pub flush_size_bytes: usize,
    /// Flush once this many embeddings are buffered.
    pub flush_max_rows: usize,
    /// Minimum time since the last flush before `flush_if_due` flushes a
    /// non-empty buffer.
    pub flush_interval: Duration,
}
26
27impl Default for MemTableConfig {
28 fn default() -> Self {
29 Self {
30 flush_size_bytes: 64 * 1024 * 1024,
31 flush_max_rows: 100_000,
32 flush_interval: Duration::from_secs(30),
33 }
34 }
35}
36
/// Buffers record batches and their embeddings in memory and writes them to
/// `table` through a [`TableWriter`] when a [`MemTableConfig`] threshold is hit
/// or `flush` is called explicitly.
pub struct MemTableWriter {
    /// Catalog handle passed to each `TableWriter` created by `flush`.
    catalog: Arc<dyn CatalogProvider>,
    /// Store handle passed to each `TableWriter` created by `flush`.
    store: Arc<dyn Store>,
    /// Vector policy; `dim` and `precision` are used to estimate buffered bytes.
    policy: VectorStoragePolicy,
    /// Table that flushes are written to.
    table: TableIdent,
    /// Flush thresholds.
    config: MemTableConfig,

    /// Batches received since the last flush, in insertion order.
    pending_batches: Vec<RecordBatch>,
    /// Embeddings received since the last flush, in insertion order.
    /// NOTE(review): presumably one vector per batch row; this type does not check it.
    pending_embeddings: Vec<Vec<f32>>,
    /// Running estimate of buffered embedding bytes (batch columns are not counted).
    buffered_bytes: usize,
    /// Reference point for the `flush_interval` check in `flush_if_due`.
    last_flush: Instant,
}
71
72impl MemTableWriter {
73 pub fn new(
74 catalog: Arc<dyn CatalogProvider>,
75 store: Arc<dyn Store>,
76 policy: VectorStoragePolicy,
77 table: TableIdent,
78 config: MemTableConfig,
79 ) -> Self {
80 Self {
81 catalog,
82 store,
83 policy,
84 table,
85 config,
86 pending_batches: Vec::new(),
87 pending_embeddings: Vec::new(),
88 buffered_bytes: 0,
89 last_flush: Instant::now(),
90 }
91 }
92
93 pub async fn insert(
96 &mut self,
97 batch: &RecordBatch,
98 embeddings: &[Vec<f32>],
99 ) -> AilakeResult<Option<SnapshotId>> {
100 let row_bytes =
101 embeddings.len() * self.policy.dim as usize * self.policy.precision.bytes_per_element();
102
103 self.pending_batches.push(batch.clone());
104 self.pending_embeddings.extend_from_slice(embeddings);
105 self.buffered_bytes += row_bytes;
106
107 if self.buffered_bytes >= self.config.flush_size_bytes
108 || self.pending_embeddings.len() >= self.config.flush_max_rows
109 {
110 Ok(Some(self.flush().await?))
111 } else {
112 Ok(None)
113 }
114 }
115
116 pub async fn flush_if_due(&mut self) -> AilakeResult<Option<SnapshotId>> {
119 if self.pending_embeddings.is_empty() {
120 self.last_flush = Instant::now();
121 return Ok(None);
122 }
123 if self.last_flush.elapsed() >= self.config.flush_interval {
124 Ok(Some(self.flush().await?))
125 } else {
126 Ok(None)
127 }
128 }
129
130 pub async fn flush(&mut self) -> AilakeResult<SnapshotId> {
134 if self.pending_embeddings.is_empty() {
135 self.last_flush = Instant::now();
136 return Ok(0);
137 }
138
139 let merged = if self.pending_batches.len() == 1 {
141 self.pending_batches.remove(0)
142 } else {
143 concat_batches(&self.pending_batches[0].schema(), &self.pending_batches)
144 .map_err(|e| ailake_core::AilakeError::Arrow(e.to_string()))?
145 };
146 let embeddings = std::mem::take(&mut self.pending_embeddings);
147 self.pending_batches.clear();
148 self.buffered_bytes = 0;
149 self.last_flush = Instant::now();
150
151 let mut writer = TableWriter::new(
153 self.catalog.clone(),
154 self.store.clone(),
155 self.policy.clone(),
156 self.table.clone(),
157 );
158 writer.write_batch_deferred(&merged, &embeddings).await?;
159 let snap = writer.commit().await?;
160 Ok(snap)
161 }
162
163 pub fn buffered_rows(&self) -> usize {
165 self.pending_embeddings.len()
166 }
167
168 pub fn buffered_bytes(&self) -> usize {
170 self.buffered_bytes
171 }
172
173 pub fn is_full(&self) -> bool {
175 self.buffered_bytes >= self.config.flush_size_bytes
176 || self.pending_embeddings.len() >= self.config.flush_max_rows
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision};
    use ailake_store::{LocalStore, Store};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Four-dimensional F16 Euclidean policy shared by every test table.
    fn make_policy() -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim: 4,
            metric: VectorMetric::Euclidean,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Builds an `(id, text)` batch with one row per id and a constant text value.
    fn make_batch(ids: &[i32]) -> RecordBatch {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
        ];
        let id_column = Int32Array::from(ids.to_vec());
        let text_column = StringArray::from(vec!["chunk"; ids.len()]);

        RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(id_column), Arc::new(text_column)],
        )
        .unwrap()
    }

    /// Produces `n` deterministic vectors of length `dim`.
    fn make_embeddings(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for row in 0..n {
            let vector: Vec<f32> = (0..dim).map(|col| row as f32 + col as f32 * 0.1).collect();
            rows.push(vector);
        }
        rows
    }

    /// Registers `table` in `catalog` using the shared test policy.
    async fn setup_table(catalog: &HadoopCatalog, table: &TableIdent) {
        let properties = ailake_catalog::TableProperties {
            policy: make_policy(),
            extra: Default::default(),
        };
        catalog.create_table(table, &properties).await.unwrap();
    }

    /// Creates a temp-dir backed store and catalog, registers `default.<name>`,
    /// and returns a writer for it. The `TempDir` guard must be kept alive for
    /// as long as the writer is used.
    async fn new_mem_table(
        name: &str,
        config: MemTableConfig,
    ) -> (tempfile::TempDir, MemTableWriter) {
        let dir = tempfile::tempdir().unwrap();
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", name);
        setup_table(&catalog, &table).await;

        let writer = MemTableWriter::new(catalog, store, make_policy(), table, config);
        (dir, writer)
    }

    #[tokio::test]
    async fn mem_table_insert_and_flush() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = new_mem_table("test_mem", config).await;

        // Three inserts of five rows each stay below every threshold.
        for chunk in 0..3 {
            let ids: Vec<i32> = (chunk * 5..(chunk + 1) * 5).collect();
            let outcome = mt
                .insert(&make_batch(&ids), &make_embeddings(5, 4))
                .await
                .unwrap();
            assert!(outcome.is_none(), "should not auto-flush yet");
        }
        assert_eq!(mt.buffered_rows(), 15);

        let snap = mt.flush().await.unwrap();
        assert!(snap > 0, "snapshot id should be non-zero");
        assert_eq!(mt.buffered_rows(), 0, "buffer should be empty after flush");
        assert_eq!(mt.buffered_bytes(), 0);
    }

    #[tokio::test]
    async fn mem_table_auto_flush_on_row_limit() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 8,
            flush_interval: Duration::from_secs(60),
        };
        let (_dir, mut mt) = new_mem_table("test_auto", config).await;

        let embs = make_embeddings(5, 4);

        // Five rows: under the limit of eight.
        let first = make_batch(&[1, 2, 3, 4, 5]);
        assert!(mt.insert(&first, &embs).await.unwrap().is_none());

        // Ten rows: over the limit, so this insert must flush.
        let second = make_batch(&[6, 7, 8, 9, 10]);
        let snap = mt.insert(&second, &embs).await.unwrap();
        assert!(snap.is_some(), "should auto-flush when row limit exceeded");
        assert_eq!(mt.buffered_rows(), 0);
    }

    #[tokio::test]
    async fn mem_table_flush_if_due() {
        let config = MemTableConfig {
            flush_size_bytes: 1024 * 1024 * 1024,
            flush_max_rows: 1000,
            flush_interval: Duration::from_millis(1),
        };
        let (_dir, mut mt) = new_mem_table("test_due", config).await;

        mt.insert(&make_batch(&[1, 2, 3]), &make_embeddings(3, 4))
            .await
            .unwrap();

        // Wait longer than the 1 ms interval.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let snap = mt.flush_if_due().await.unwrap();
        assert!(snap.is_some(), "should flush because interval elapsed");
        assert_eq!(mt.buffered_rows(), 0);
    }
}