1use std::sync::atomic::{AtomicUsize, Ordering};
27use std::sync::Arc;
28
29use arrow::array::{ArrayRef, BinaryBuilder, Float64Builder, Int64Builder, StringBuilder};
30use arrow::datatypes::{DataType, Field, Schema};
31use arrow::record_batch::RecordBatch;
32use tracing::debug;
33
34use crate::error::RusqliteOltpError;
35
/// Maximum number of rows per emitted Arrow `RecordBatch` when converting
/// SQLite query results (see `query_to_arrow`).
const ARROW_BATCH_ROWS: usize = 8192;

/// Number of leading rows buffered before the Arrow schema is locked.
/// Each column's type is inferred from its first non-NULL value within this
/// window; a column that stays NULL for the whole warmup falls back to Utf8.
const SCHEMA_WARMUP_ROWS: usize = 32_768;
66
/// SQLite-backed OLTP engine that returns query results as Arrow
/// `RecordBatch`es.
///
/// Writes go through a single dedicated connection; reads are distributed
/// round-robin over a small pool of reader connections.
pub struct RusqliteEngine {
    // Filesystem path of the database file; retained so `new_connection`
    // can open further connections to the same database.
    db_path: String,
    // Dedicated connection used by `execute` / `execute_batch`.
    write_conn: rhei_tokio_rusqlite::Connection,
    // Reader connections, selected round-robin via `next_read_conn`.
    read_pool: Vec<rhei_tokio_rusqlite::Connection>,
    // Monotonically increasing counter used to pick the next reader.
    read_idx: AtomicUsize,
}
79
80impl RusqliteEngine {
81 pub async fn new_local(path: &str, read_pool_size: usize) -> Result<Self, RusqliteOltpError> {
86 let write_conn = rhei_tokio_rusqlite::Connection::open(path).await?;
87
88 write_conn
90 .call(|conn| {
91 conn.execute_batch(
92 "PRAGMA journal_mode=WAL;
93 PRAGMA synchronous=NORMAL;
94 PRAGMA busy_timeout=5000;",
95 )?;
96 Ok(())
97 })
98 .await?;
99
100 let pool_size = read_pool_size.max(1);
101 let mut read_pool = Vec::with_capacity(pool_size);
102 for _ in 0..pool_size {
103 let read_conn = rhei_tokio_rusqlite::Connection::open(path).await?;
104 read_conn
105 .call(|conn| {
106 conn.execute_batch("PRAGMA busy_timeout=5000;")?;
107 Ok(())
108 })
109 .await?;
110 read_pool.push(read_conn);
111 }
112
113 Ok(Self {
114 db_path: path.to_string(),
115 write_conn,
116 read_pool,
117 read_idx: AtomicUsize::new(0),
118 })
119 }
120
121 pub async fn new_connection(
126 &self,
127 ) -> Result<rhei_tokio_rusqlite::Connection, RusqliteOltpError> {
128 let conn = rhei_tokio_rusqlite::Connection::open(&self.db_path).await?;
129 conn.call(|c| {
130 c.execute_batch("PRAGMA busy_timeout=5000;")?;
131 Ok(())
132 })
133 .await?;
134 Ok(conn)
135 }
136
137 pub fn connection(&self) -> rhei_tokio_rusqlite::Connection {
140 self.write_conn.clone()
141 }
142
143 fn next_read_conn(&self) -> &rhei_tokio_rusqlite::Connection {
144 let idx = self.read_idx.fetch_add(1, Ordering::Relaxed) % self.read_pool.len();
145 &self.read_pool[idx]
146 }
147}
148
149fn json_to_rusqlite(val: &serde_json::Value) -> rusqlite::types::Value {
151 match val {
152 serde_json::Value::Null => rusqlite::types::Value::Null,
153 serde_json::Value::Bool(b) => rusqlite::types::Value::Integer(if *b { 1 } else { 0 }),
154 serde_json::Value::Number(n) => {
155 if let Some(i) = n.as_i64() {
156 rusqlite::types::Value::Integer(i)
157 } else if let Some(f) = n.as_f64() {
158 rusqlite::types::Value::Real(f)
159 } else {
160 rusqlite::types::Value::Text(n.to_string())
161 }
162 }
163 serde_json::Value::String(s) => rusqlite::types::Value::Text(s.clone()),
164 other => rusqlite::types::Value::Text(other.to_string()),
165 }
166}
167
168fn build_batch(
170 chunk: &[Vec<rusqlite::types::Value>],
171 schema: &Arc<Schema>,
172) -> Result<RecordBatch, rhei_tokio_rusqlite::Error> {
173 let col_count = schema.fields().len();
174 let mut columns: Vec<ArrayRef> = Vec::with_capacity(col_count);
175
176 for col_idx in 0..col_count {
177 let dt = schema.field(col_idx).data_type();
178 let array: ArrayRef = match dt {
179 DataType::Int64 => {
180 let mut b = Int64Builder::with_capacity(chunk.len());
181 for row in chunk {
182 match &row[col_idx] {
183 rusqlite::types::Value::Integer(i) => b.append_value(*i),
184 _ => b.append_null(),
185 }
186 }
187 Arc::new(b.finish())
188 }
189 DataType::Float64 => {
190 let mut b = Float64Builder::with_capacity(chunk.len());
191 for row in chunk {
192 match &row[col_idx] {
193 rusqlite::types::Value::Real(f) => b.append_value(*f),
194 rusqlite::types::Value::Integer(i) => b.append_value(*i as f64),
195 _ => b.append_null(),
196 }
197 }
198 Arc::new(b.finish())
199 }
200 DataType::Binary => {
201 let mut b = BinaryBuilder::with_capacity(chunk.len(), chunk.len() * 16);
202 for row in chunk {
203 match &row[col_idx] {
204 rusqlite::types::Value::Blob(bytes) => b.append_value(bytes.as_slice()),
205 _ => b.append_null(),
206 }
207 }
208 Arc::new(b.finish())
209 }
210 _ => {
211 let mut b = StringBuilder::with_capacity(chunk.len(), chunk.len() * 8);
213 for row in chunk {
214 match &row[col_idx] {
215 rusqlite::types::Value::Text(s) => b.append_value(s.as_str()),
216 rusqlite::types::Value::Integer(i) => b.append_value(i.to_string()),
217 rusqlite::types::Value::Real(f) => b.append_value(f.to_string()),
218 rusqlite::types::Value::Null => b.append_null(),
219 rusqlite::types::Value::Blob(_) => b.append_null(),
220 }
221 }
222 Arc::new(b.finish())
223 }
224 };
225 columns.push(array);
226 }
227
228 RecordBatch::try_new(Arc::clone(schema), columns)
229 .map_err(|e| rhei_tokio_rusqlite::Error::Other(format!("Arrow error: {e}")))
230}
231
/// Streams a SQL query into Arrow `RecordBatch`es of at most
/// `ARROW_BATCH_ROWS` rows each.
///
/// SQLite columns are dynamically typed, so the Arrow schema is inferred
/// from the data: up to `SCHEMA_WARMUP_ROWS` rows are buffered first, and
/// each column's type is taken from its first non-NULL value inside that
/// window. Columns that are NULL throughout the warmup fall back to Utf8.
///
/// NOTE(review): once the schema is locked, later values of a mismatched
/// storage class are coerced or nulled by `build_batch` (e.g. a Text value
/// in an Int64-typed column becomes null) — confirm this is acceptable for
/// mixed-type columns.
fn query_to_arrow(
    conn: &mut rusqlite::Connection,
    sql: &str,
    params: &[rusqlite::types::Value],
) -> Result<Vec<RecordBatch>, rhei_tokio_rusqlite::Error> {
    let mut stmt = conn.prepare(sql)?;

    // Statements producing no columns yield no batches.
    let col_count = stmt.column_count();
    if col_count == 0 {
        return Ok(vec![]);
    }

    // Copy column names out before `query` takes a mutable borrow of `stmt`.
    let col_names: Vec<String> = stmt
        .column_names()
        .into_iter()
        .map(|s| s.to_string())
        .collect();

    let params_refs: Vec<&dyn rusqlite::types::ToSql> = params
        .iter()
        .map(|v| v as &dyn rusqlite::types::ToSql)
        .collect();

    let mut rows = stmt.query(params_refs.as_slice())?;

    // Phase 1: buffer warmup rows while inferring one DataType per column.
    // A hint is set on the first non-NULL value and never revised afterwards.
    let mut hints: Vec<Option<DataType>> = vec![None; col_count];
    let mut warmup: Vec<Vec<rusqlite::types::Value>> = Vec::with_capacity(SCHEMA_WARMUP_ROWS);

    while warmup.len() < SCHEMA_WARMUP_ROWS {
        match rows.next()? {
            None => break,
            Some(row) => {
                let mut vals = Vec::with_capacity(col_count);
                for (col_idx, hint) in hints.iter_mut().enumerate() {
                    let val: rusqlite::types::Value = row.get(col_idx)?;
                    if hint.is_none() {
                        *hint = match &val {
                            rusqlite::types::Value::Integer(_) => Some(DataType::Int64),
                            rusqlite::types::Value::Real(_) => Some(DataType::Float64),
                            rusqlite::types::Value::Text(_) => Some(DataType::Utf8),
                            rusqlite::types::Value::Blob(_) => Some(DataType::Binary),
                            // NULL carries no type information; keep waiting.
                            rusqlite::types::Value::Null => None,
                        };
                    }
                    vals.push(val);
                }
                warmup.push(vals);
            }
        }
    }

    // Lock the schema: unhinted (all-NULL) columns default to Utf8, and
    // every field is nullable.
    let schema = Arc::new(Schema::new(
        hints
            .into_iter()
            .zip(col_names.iter())
            .map(|(hint, name)| {
                let dt = hint.unwrap_or(DataType::Utf8);
                Field::new(name, dt, true)
            })
            .collect::<Vec<_>>(),
    ));

    let mut batches: Vec<RecordBatch> = Vec::new();
    let mut chunk: Vec<Vec<rusqlite::types::Value>> = Vec::with_capacity(ARROW_BATCH_ROWS);

    // Converts the accumulated chunk into a RecordBatch and resets it;
    // no-op on an empty chunk.
    let flush_chunk = |chunk: &mut Vec<Vec<rusqlite::types::Value>>,
                       schema: &Arc<Schema>,
                       batches: &mut Vec<RecordBatch>|
     -> Result<(), rhei_tokio_rusqlite::Error> {
        if chunk.is_empty() {
            return Ok(());
        }
        batches.push(build_batch(chunk, schema)?);
        chunk.clear();
        Ok(())
    };

    // Phase 2: re-emit the buffered warmup rows in batch-sized chunks...
    for row_vals in warmup {
        chunk.push(row_vals);
        if chunk.len() >= ARROW_BATCH_ROWS {
            flush_chunk(&mut chunk, &schema, &mut batches)?;
        }
    }

    // ...then stream the remainder of the result set.
    while let Some(row) = rows.next()? {
        let mut vals = Vec::with_capacity(col_count);
        for i in 0..col_count {
            let val: rusqlite::types::Value = row.get(i)?;
            vals.push(val);
        }
        chunk.push(vals);

        if chunk.len() >= ARROW_BATCH_ROWS {
            flush_chunk(&mut chunk, &schema, &mut batches)?;
        }
    }

    // Flush any final partial chunk.
    flush_chunk(&mut chunk, &schema, &mut batches)?;

    // Always return at least one (possibly empty) batch so callers can
    // observe the schema even for an empty result set.
    if batches.is_empty() {
        batches.push(RecordBatch::new_empty(Arc::clone(&schema)));
    }

    Ok(batches)
}
367
368impl rhei_core::OltpEngine for RusqliteEngine {
369 type Error = RusqliteOltpError;
370
371 async fn query(
383 &self,
384 sql: &str,
385 params: &[serde_json::Value],
386 ) -> Result<Vec<RecordBatch>, Self::Error> {
387 debug!(sql, params_count = params.len(), "OLTP rusqlite query");
388 let rusqlite_params: Vec<rusqlite::types::Value> =
389 params.iter().map(json_to_rusqlite).collect();
390 let sql_owned = sql.to_string();
391
392 let conn = self.next_read_conn();
393 let batches = conn
394 .call(move |c| query_to_arrow(c, &sql_owned, &rusqlite_params))
395 .await?;
396 Ok(batches)
397 }
398
399 async fn execute(&self, sql: &str, params: &[serde_json::Value]) -> Result<u64, Self::Error> {
410 debug!(sql, params_count = params.len(), "OLTP rusqlite execute");
411 let rusqlite_params: Vec<rusqlite::types::Value> =
412 params.iter().map(json_to_rusqlite).collect();
413 let sql_owned = sql.to_string();
414
415 let rows_affected = self
416 .write_conn
417 .call(move |c| {
418 let params_refs: Vec<&dyn rusqlite::types::ToSql> = rusqlite_params
419 .iter()
420 .map(|v| v as &dyn rusqlite::types::ToSql)
421 .collect();
422 let changed = c.execute(&sql_owned, params_refs.as_slice())?;
423 Ok(changed as u64)
424 })
425 .await?;
426 Ok(rows_affected)
427 }
428
429 async fn execute_batch(
440 &self,
441 statements: &[(String, Vec<serde_json::Value>)],
442 ) -> Result<(), Self::Error> {
443 debug!(count = statements.len(), "OLTP rusqlite execute_batch");
444 let stmts: Vec<(String, Vec<rusqlite::types::Value>)> = statements
446 .iter()
447 .map(|(sql, params)| {
448 let rusqlite_params: Vec<rusqlite::types::Value> =
449 params.iter().map(json_to_rusqlite).collect();
450 (sql.clone(), rusqlite_params)
451 })
452 .collect();
453
454 self.write_conn
455 .call(move |c| {
456 let tx = c.transaction()?;
457 for (sql, params) in &stmts {
458 let params_refs: Vec<&dyn rusqlite::types::ToSql> = params
459 .iter()
460 .map(|v| v as &dyn rusqlite::types::ToSql)
461 .collect();
462 tx.execute(sql, params_refs.as_slice())?;
463 }
464 tx.commit()?;
465 Ok(())
466 })
467 .await?;
468 Ok(())
469 }
470
471 async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
479 let tbl = table_name.to_string();
480 let conn = self.next_read_conn();
481 let exists = conn
482 .call(move |c| {
483 let count: i64 = c.query_row(
484 "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?1",
485 rusqlite::params![tbl],
486 |row| row.get(0),
487 )?;
488 Ok(count > 0)
489 })
490 .await?;
491 Ok(exists)
492 }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Int64Array};
    use rhei_core::OltpEngine;

    // End-to-end smoke test: create table, insert, select, table_exists.
    #[tokio::test]
    async fn test_basic_crud() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 2)
            .await
            .unwrap();

        engine
            .execute(
                "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)",
                &[],
            )
            .await
            .unwrap();

        engine
            .execute(
                "INSERT INTO users (id, name, age) VALUES (?1, ?2, ?3)",
                &[
                    serde_json::json!(1),
                    serde_json::json!("Alice"),
                    serde_json::json!(30),
                ],
            )
            .await
            .unwrap();

        // One row fits in one batch; all three columns must come through.
        let batches = engine.query("SELECT * FROM users", &[]).await.unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[0].num_columns(), 3);

        assert!(engine.table_exists("users").await.unwrap());
        assert!(!engine.table_exists("nonexistent").await.unwrap());
    }

    // Verifies execute_batch applies all statements in one transaction.
    #[tokio::test]
    async fn test_execute_batch() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test_batch.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 1)
            .await
            .unwrap();

        engine
            .execute("CREATE TABLE items (id INTEGER PRIMARY KEY, val TEXT)", &[])
            .await
            .unwrap();

        let stmts: Vec<(String, Vec<serde_json::Value>)> = vec![
            (
                "INSERT INTO items (id, val) VALUES (?1, ?2)".to_string(),
                vec![serde_json::json!(1), serde_json::json!("a")],
            ),
            (
                "INSERT INTO items (id, val) VALUES (?1, ?2)".to_string(),
                vec![serde_json::json!(2), serde_json::json!("b")],
            ),
            (
                "INSERT INTO items (id, val) VALUES (?1, ?2)".to_string(),
                vec![serde_json::json!(3), serde_json::json!("c")],
            ),
        ];

        engine.execute_batch(&stmts).await.unwrap();

        // COUNT(*) comes back as a single Int64 column; all 3 inserts landed.
        let batches = engine
            .query("SELECT COUNT(*) FROM items", &[])
            .await
            .unwrap();
        let count = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);
        assert_eq!(count, 3);
    }

    // 20_000 rows exceed ARROW_BATCH_ROWS (8192), so the result must be
    // split across multiple batches that all share one schema.
    #[tokio::test]
    async fn test_chunked_query() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test_chunks.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 1)
            .await
            .unwrap();

        engine
            .execute("CREATE TABLE big (id INTEGER PRIMARY KEY, val TEXT)", &[])
            .await
            .unwrap();

        let stmts: Vec<(String, Vec<serde_json::Value>)> = (0..20_000u64)
            .map(|i| {
                (
                    "INSERT INTO big (id, val) VALUES (?1, ?2)".to_string(),
                    vec![serde_json::json!(i), serde_json::json!(format!("v{i}"))],
                )
            })
            .collect();
        engine.execute_batch(&stmts).await.unwrap();

        let batches = engine.query("SELECT * FROM big", &[]).await.unwrap();

        assert!(
            batches.len() > 1,
            "expected multiple RecordBatches, got {}",
            batches.len()
        );

        // No rows lost or duplicated across chunk boundaries.
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 20_000, "total row count mismatch");

        // Every batch must carry the same (locked) schema.
        let schema = batches[0].schema();
        for (i, batch) in batches.iter().enumerate() {
            assert_eq!(batch.schema(), schema, "schema mismatch on batch {i}");
        }
    }

    // A connection from new_connection must see tables created through the
    // engine's write connection (same database file).
    #[tokio::test]
    async fn test_new_connection() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test_conn.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 1)
            .await
            .unwrap();

        engine
            .execute("CREATE TABLE t (id INTEGER PRIMARY KEY)", &[])
            .await
            .unwrap();

        let conn = engine.new_connection().await.unwrap();
        let exists: bool = conn
            .call(|c| {
                let count: i64 = c.query_row(
                    "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='t'",
                    [],
                    |row| row.get(0),
                )?;
                Ok(count > 0)
            })
            .await
            .unwrap();
        assert!(exists);
    }

    // Regression test for schema inference on sparse columns: the first
    // 10_000 'score' values are NULL (fewer than SCHEMA_WARMUP_ROWS), so the
    // integers arriving afterwards-but-within-warmup must still drive the
    // column to Int64 rather than the Utf8 fallback.
    #[tokio::test]
    async fn test_sparse_column_type_inference() {
        use arrow::array::Int64Array;
        use arrow::datatypes::DataType;

        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test_sparse.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 1)
            .await
            .unwrap();

        engine
            .execute(
                "CREATE TABLE sparse (id INTEGER PRIMARY KEY, score INTEGER)",
                &[],
            )
            .await
            .unwrap();

        // NULL_ROWS + INT_ROWS stays inside the 32_768-row warmup window.
        const NULL_ROWS: usize = 10_000;
        const INT_ROWS: usize = 100;

        let null_stmts: Vec<(String, Vec<serde_json::Value>)> = (0..NULL_ROWS)
            .map(|i| {
                (
                    "INSERT INTO sparse (id, score) VALUES (?1, NULL)".to_string(),
                    vec![serde_json::json!(i)],
                )
            })
            .collect();
        engine.execute_batch(&null_stmts).await.unwrap();

        let int_stmts: Vec<(String, Vec<serde_json::Value>)> = (NULL_ROWS..NULL_ROWS + INT_ROWS)
            .map(|i| {
                (
                    "INSERT INTO sparse (id, score) VALUES (?1, ?2)".to_string(),
                    vec![serde_json::json!(i), serde_json::json!(i as i64)],
                )
            })
            .collect();
        engine.execute_batch(&int_stmts).await.unwrap();

        let batches = engine.query("SELECT * FROM sparse", &[]).await.unwrap();
        assert!(!batches.is_empty(), "expected at least one batch");

        // All batches share the locked schema.
        let schema = batches[0].schema();
        for (idx, batch) in batches.iter().enumerate() {
            assert_eq!(batch.schema(), schema, "schema mismatch on batch {idx}");
        }

        let score_field = schema
            .field_with_name("score")
            .expect("score field missing");
        assert_eq!(
            score_field.data_type(),
            &DataType::Int64,
            "sparse column 'score' should be Int64; \
             was it mis-typed as Utf8 due to early schema lock?"
        );

        // Row count is preserved end to end.
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, NULL_ROWS + INT_ROWS);

        // Every non-null score must be one of the values inserted with ids
        // >= NULL_ROWS, and at least one such value must be present.
        let mut found_int = false;
        for batch in &batches {
            let score_col = batch.column_by_name("score").expect("score column missing");
            if let Some(arr) = score_col.as_any().downcast_ref::<Int64Array>() {
                for row in 0..arr.len() {
                    if arr.is_valid(row) {
                        assert!(
                            arr.value(row) >= NULL_ROWS as i64,
                            "unexpected integer value {}",
                            arr.value(row)
                        );
                        found_int = true;
                    }
                }
            }
        }
        assert!(
            found_int,
            "no non-null Int64 values found in 'score' column"
        );
    }
}