// rhei_oltp_rusqlite/engine.rs
1//! [`RusqliteEngine`]: the async SQLite OLTP engine.
2//!
3//! This module provides the central OLTP component of `rhei-oltp-rusqlite`.
4//! It wraps `rusqlite` via [`rhei_tokio_rusqlite`] to expose a fully async
5//! interface without any `unsafe` code.
6//!
7//! ## Connection topology
8//!
9//! | Connection | Count | Purpose |
10//! |------------|-------|---------|
11//! | Write | 1 | All `INSERT` / `UPDATE` / `DELETE` / DDL |
12//! | Read pool | N (default 4) | Concurrent `SELECT` via round-robin |
13//! | CDC | 1 (separate) | `_rhei_cdc_log` polling by [`RusqliteCdcProducer`][crate::RusqliteCdcProducer] |
14//!
15//! The write connection is opened with WAL mode, `synchronous=NORMAL`, and
16//! `busy_timeout=5000`.  Read pool connections only set `busy_timeout`.
17//!
18//! ## Arrow output
19//!
20//! [`rhei_core::OltpEngine::query`] returns `Vec<RecordBatch>`.  Because
21//! SQLite is dynamically typed, the Arrow schema is inferred from actual
22//! values rather than declared column types.  A warm-up window of
23//! `SCHEMA_WARMUP_ROWS` rows is buffered before the schema is committed, so
24//! sparse columns (NULL for many rows, then non-null) are typed correctly.
25
26use std::sync::atomic::{AtomicUsize, Ordering};
27use std::sync::Arc;
28
29use arrow::array::{ArrayRef, BinaryBuilder, Float64Builder, Int64Builder, StringBuilder};
30use arrow::datatypes::{DataType, Field, Schema};
31use arrow::record_batch::RecordBatch;
32use tracing::debug;
33
34use crate::error::RusqliteOltpError;
35
/// Maximum number of rows per Arrow RecordBatch produced by [`query_to_arrow`].
///
/// Row buffers are flushed into a batch whenever they reach this size, so a
/// large result set streams as many bounded batches rather than one giant
/// allocation.
const ARROW_BATCH_ROWS: usize = 8192;
38
/// Number of rows buffered for type-inference warm-up before the Arrow schema
/// is committed.
///
/// ### Why this exists
///
/// SQLite columns have no enforced declared type — every cell is dynamically
/// typed. `query_to_arrow` must infer the Arrow `DataType` for each column by
/// inspecting actual values. If we commit to a schema after only the first
/// chunk (8 192 rows), *sparse* columns that happen to be NULL in those rows
/// are permanently mis-typed as `Utf8`, even if they contain `Int64` values
/// further into the result set.
///
/// ### The fix: warm-up window
///
/// We accumulate up to `SCHEMA_WARMUP_ROWS` rows before choosing a schema.
/// Each column's type is resolved as soon as its first non-null value appears.
/// Any column still fully NULL after the warm-up window defaults to `Utf8`
/// (safe because there is no later non-null evidence to contradict it).
/// Once the schema is locked we flush the buffered rows in `ARROW_BATCH_ROWS`
/// chunks and continue streaming the remainder of the cursor.
///
/// ### Trade-off
///
/// For a typical row of ~100 bytes the warm-up buffer peaks at ~3 MB — tiny
/// relative to a 10 M-row dataset. Queries that return fewer rows than
/// `SCHEMA_WARMUP_ROWS` incur no extra overhead (schema is locked at EOF).
///
/// Currently 4 × [`ARROW_BATCH_ROWS`], so the warm-up flushes as exactly four
/// full batches when the result set is large enough.
const SCHEMA_WARMUP_ROWS: usize = 32_768;
66
/// Rusqlite-backed OLTP engine.
///
/// Uses a single dedicated write connection for DML/DDL and a round-robin pool
/// of read connections for queries. WAL mode is enabled on startup so readers
/// don't block the writer.
pub struct RusqliteEngine {
    /// The path used to open the database (stored for creating new connections).
    db_path: String,
    /// Dedicated connection for all DML/DDL statements; the single writer.
    write_conn: rhei_tokio_rusqlite::Connection,
    /// Connections used for queries, dispatched round-robin by `next_read_conn`.
    read_pool: Vec<rhei_tokio_rusqlite::Connection>,
    /// Monotonically increasing counter; `% read_pool.len()` picks the next
    /// read connection.
    read_idx: AtomicUsize,
}
79
80impl RusqliteEngine {
81    /// Create a new engine backed by a local SQLite database file.
82    ///
83    /// `read_pool_size` controls how many concurrent read connections are
84    /// maintained. Values below 1 are clamped to 1.
85    pub async fn new_local(path: &str, read_pool_size: usize) -> Result<Self, RusqliteOltpError> {
86        let write_conn = rhei_tokio_rusqlite::Connection::open(path).await?;
87
88        // Enable WAL mode so read connections don't block the writer.
89        write_conn
90            .call(|conn| {
91                conn.execute_batch(
92                    "PRAGMA journal_mode=WAL;
93                     PRAGMA synchronous=NORMAL;
94                     PRAGMA busy_timeout=5000;",
95                )?;
96                Ok(())
97            })
98            .await?;
99
100        let pool_size = read_pool_size.max(1);
101        let mut read_pool = Vec::with_capacity(pool_size);
102        for _ in 0..pool_size {
103            let read_conn = rhei_tokio_rusqlite::Connection::open(path).await?;
104            read_conn
105                .call(|conn| {
106                    conn.execute_batch("PRAGMA busy_timeout=5000;")?;
107                    Ok(())
108                })
109                .await?;
110            read_pool.push(read_conn);
111        }
112
113        Ok(Self {
114            db_path: path.to_string(),
115            write_conn,
116            read_pool,
117            read_idx: AtomicUsize::new(0),
118        })
119    }
120
121    /// Create a fresh, independent connection to the same database file.
122    ///
123    /// Used to give `RusqliteCdcProducer` its own connection that doesn't compete
124    /// with application reads/writes.
125    pub async fn new_connection(
126        &self,
127    ) -> Result<rhei_tokio_rusqlite::Connection, RusqliteOltpError> {
128        let conn = rhei_tokio_rusqlite::Connection::open(&self.db_path).await?;
129        conn.call(|c| {
130            c.execute_batch("PRAGMA busy_timeout=5000;")?;
131            Ok(())
132        })
133        .await?;
134        Ok(conn)
135    }
136
137    /// Returns a clone of the write connection handle. Used by CDC trigger
138    /// setup/teardown (DDL).
139    pub fn connection(&self) -> rhei_tokio_rusqlite::Connection {
140        self.write_conn.clone()
141    }
142
143    fn next_read_conn(&self) -> &rhei_tokio_rusqlite::Connection {
144        let idx = self.read_idx.fetch_add(1, Ordering::Relaxed) % self.read_pool.len();
145        &self.read_pool[idx]
146    }
147}
148
149/// Convert a serde_json::Value to a rusqlite Value suitable for binding.
150fn json_to_rusqlite(val: &serde_json::Value) -> rusqlite::types::Value {
151    match val {
152        serde_json::Value::Null => rusqlite::types::Value::Null,
153        serde_json::Value::Bool(b) => rusqlite::types::Value::Integer(if *b { 1 } else { 0 }),
154        serde_json::Value::Number(n) => {
155            if let Some(i) = n.as_i64() {
156                rusqlite::types::Value::Integer(i)
157            } else if let Some(f) = n.as_f64() {
158                rusqlite::types::Value::Real(f)
159            } else {
160                rusqlite::types::Value::Text(n.to_string())
161            }
162        }
163        serde_json::Value::String(s) => rusqlite::types::Value::Text(s.clone()),
164        other => rusqlite::types::Value::Text(other.to_string()),
165    }
166}
167
168/// Build one Arrow [`RecordBatch`] from a row-major chunk of rusqlite values.
169fn build_batch(
170    chunk: &[Vec<rusqlite::types::Value>],
171    schema: &Arc<Schema>,
172) -> Result<RecordBatch, rhei_tokio_rusqlite::Error> {
173    let col_count = schema.fields().len();
174    let mut columns: Vec<ArrayRef> = Vec::with_capacity(col_count);
175
176    for col_idx in 0..col_count {
177        let dt = schema.field(col_idx).data_type();
178        let array: ArrayRef = match dt {
179            DataType::Int64 => {
180                let mut b = Int64Builder::with_capacity(chunk.len());
181                for row in chunk {
182                    match &row[col_idx] {
183                        rusqlite::types::Value::Integer(i) => b.append_value(*i),
184                        _ => b.append_null(),
185                    }
186                }
187                Arc::new(b.finish())
188            }
189            DataType::Float64 => {
190                let mut b = Float64Builder::with_capacity(chunk.len());
191                for row in chunk {
192                    match &row[col_idx] {
193                        rusqlite::types::Value::Real(f) => b.append_value(*f),
194                        rusqlite::types::Value::Integer(i) => b.append_value(*i as f64),
195                        _ => b.append_null(),
196                    }
197                }
198                Arc::new(b.finish())
199            }
200            DataType::Binary => {
201                let mut b = BinaryBuilder::with_capacity(chunk.len(), chunk.len() * 16);
202                for row in chunk {
203                    match &row[col_idx] {
204                        rusqlite::types::Value::Blob(bytes) => b.append_value(bytes.as_slice()),
205                        _ => b.append_null(),
206                    }
207                }
208                Arc::new(b.finish())
209            }
210            _ => {
211                // Utf8 and all-null fallback
212                let mut b = StringBuilder::with_capacity(chunk.len(), chunk.len() * 8);
213                for row in chunk {
214                    match &row[col_idx] {
215                        rusqlite::types::Value::Text(s) => b.append_value(s.as_str()),
216                        rusqlite::types::Value::Integer(i) => b.append_value(i.to_string()),
217                        rusqlite::types::Value::Real(f) => b.append_value(f.to_string()),
218                        rusqlite::types::Value::Null => b.append_null(),
219                        rusqlite::types::Value::Blob(_) => b.append_null(),
220                    }
221                }
222                Arc::new(b.finish())
223            }
224        };
225        columns.push(array);
226    }
227
228    RecordBatch::try_new(Arc::clone(schema), columns)
229        .map_err(|e| rhei_tokio_rusqlite::Error::Other(format!("Arrow error: {e}")))
230}
231
/// Execute a query inside a `conn.call()` closure and convert results to Arrow
/// [`RecordBatch`]es.
///
/// Rows are accumulated into a row buffer and flushed into Arrow
/// [`RecordBatch`]es every [`ARROW_BATCH_ROWS`] rows so that large result sets
/// never materialise as a single in-memory allocation.
///
/// ### Schema inference
///
/// SQLite is dynamically typed, so the Arrow schema cannot be derived from the
/// column declarations alone — it must be inferred from the actual values.
/// To handle *sparse* columns (NULL for many consecutive rows, non-null later),
/// we use a **warm-up window**: the first [`SCHEMA_WARMUP_ROWS`] rows are
/// buffered. Each column's Arrow type is resolved as soon as its first non-null
/// value appears. Any column still fully NULL when the window is exhausted or
/// the cursor is drained defaults to `Utf8`. The schema is then locked and the
/// buffered rows are flushed in `ARROW_BATCH_ROWS`-sized chunks; subsequent
/// rows stream through normally. See [`SCHEMA_WARMUP_ROWS`] for the rationale.
fn query_to_arrow(
    conn: &mut rusqlite::Connection,
    sql: &str,
    params: &[rusqlite::types::Value],
) -> Result<Vec<RecordBatch>, rhei_tokio_rusqlite::Error> {
    let mut stmt = conn.prepare(sql)?;

    let col_count = stmt.column_count();
    if col_count == 0 {
        // Statement produces no columns (e.g., a write statement called via query).
        return Ok(vec![]);
    }

    // Capture column names before `stmt` is mutably borrowed by `query()`.
    let col_names: Vec<String> = stmt
        .column_names()
        .into_iter()
        .map(|s| s.to_string())
        .collect();

    // rusqlite binds through `&dyn ToSql`; build the borrow vector once.
    let params_refs: Vec<&dyn rusqlite::types::ToSql> = params
        .iter()
        .map(|v| v as &dyn rusqlite::types::ToSql)
        .collect();

    let mut rows = stmt.query(params_refs.as_slice())?;

    // --- Phase 1: warm-up buffer -------------------------------------------------
    // Accumulate up to SCHEMA_WARMUP_ROWS rows while tracking per-column type hints.
    // `None` means "still fully NULL for this column".
    let mut hints: Vec<Option<DataType>> = vec![None; col_count];
    let mut warmup: Vec<Vec<rusqlite::types::Value>> = Vec::with_capacity(SCHEMA_WARMUP_ROWS);

    while warmup.len() < SCHEMA_WARMUP_ROWS {
        match rows.next()? {
            None => break,
            Some(row) => {
                let mut vals = Vec::with_capacity(col_count);
                for (col_idx, hint) in hints.iter_mut().enumerate() {
                    let val: rusqlite::types::Value = row.get(col_idx)?;
                    // Lock in the type as soon as we see a non-null value.
                    // Later values of a different runtime type do NOT revise the
                    // hint; mismatches are handled per-cell in `build_batch`.
                    if hint.is_none() {
                        *hint = match &val {
                            rusqlite::types::Value::Integer(_) => Some(DataType::Int64),
                            rusqlite::types::Value::Real(_) => Some(DataType::Float64),
                            rusqlite::types::Value::Text(_) => Some(DataType::Utf8),
                            rusqlite::types::Value::Blob(_) => Some(DataType::Binary),
                            rusqlite::types::Value::Null => None,
                        };
                    }
                    vals.push(val);
                }
                warmup.push(vals);
            }
        }
    }

    // Commit the schema: any column still None (all-NULL in warmup) defaults to Utf8.
    // Every field is nullable (`true`): SQLite allows NULL in any column.
    let schema = Arc::new(Schema::new(
        hints
            .into_iter()
            .zip(col_names.iter())
            .map(|(hint, name)| {
                let dt = hint.unwrap_or(DataType::Utf8);
                Field::new(name, dt, true)
            })
            .collect::<Vec<_>>(),
    ));

    // --- Phase 2: flush warmup buffer + stream remainder ------------------------
    let mut batches: Vec<RecordBatch> = Vec::new();
    let mut chunk: Vec<Vec<rusqlite::types::Value>> = Vec::with_capacity(ARROW_BATCH_ROWS);

    // Helper: flush `chunk` into a batch and clear it.
    let flush_chunk = |chunk: &mut Vec<Vec<rusqlite::types::Value>>,
                       schema: &Arc<Schema>,
                       batches: &mut Vec<RecordBatch>|
     -> Result<(), rhei_tokio_rusqlite::Error> {
        if chunk.is_empty() {
            return Ok(());
        }
        batches.push(build_batch(chunk, schema)?);
        chunk.clear();
        Ok(())
    };

    // Flush the warm-up rows first, so output batches stay in cursor order.
    for row_vals in warmup {
        chunk.push(row_vals);
        if chunk.len() >= ARROW_BATCH_ROWS {
            flush_chunk(&mut chunk, &schema, &mut batches)?;
        }
    }

    // Stream the remainder of the cursor.
    while let Some(row) = rows.next()? {
        let mut vals = Vec::with_capacity(col_count);
        for i in 0..col_count {
            let val: rusqlite::types::Value = row.get(i)?;
            vals.push(val);
        }
        chunk.push(vals);

        if chunk.len() >= ARROW_BATCH_ROWS {
            flush_chunk(&mut chunk, &schema, &mut batches)?;
        }
    }

    // Flush any remaining rows.
    flush_chunk(&mut chunk, &schema, &mut batches)?;

    if batches.is_empty() {
        // Query returned zero rows: emit one empty batch so callers always get a schema.
        batches.push(RecordBatch::new_empty(Arc::clone(&schema)));
    }

    Ok(batches)
}
367
368impl rhei_core::OltpEngine for RusqliteEngine {
369    type Error = RusqliteOltpError;
370
371    /// Execute a read-only SQL query and return the results as Arrow
372    /// [`RecordBatch`]es.
373    ///
374    /// Dispatches to the next connection in the round-robin read pool.
375    /// The Arrow schema is inferred from actual cell values (see the
376    /// `SCHEMA_WARMUP_ROWS` constant for the sparse-column strategy).
377    ///
378    /// # Errors
379    ///
380    /// Returns [`RusqliteOltpError`] if statement preparation, execution, or
381    /// Arrow conversion fails.
382    async fn query(
383        &self,
384        sql: &str,
385        params: &[serde_json::Value],
386    ) -> Result<Vec<RecordBatch>, Self::Error> {
387        debug!(sql, params_count = params.len(), "OLTP rusqlite query");
388        let rusqlite_params: Vec<rusqlite::types::Value> =
389            params.iter().map(json_to_rusqlite).collect();
390        let sql_owned = sql.to_string();
391
392        let conn = self.next_read_conn();
393        let batches = conn
394            .call(move |c| query_to_arrow(c, &sql_owned, &rusqlite_params))
395            .await?;
396        Ok(batches)
397    }
398
399    /// Execute a single write statement (`INSERT`, `UPDATE`, `DELETE`, or DDL)
400    /// and return the number of rows affected.
401    ///
402    /// Always dispatches to the single dedicated write connection.
403    ///
404    /// # Errors
405    ///
406    /// Returns [`RusqliteOltpError`] if the statement fails (e.g., constraint
407    /// violation, syntax error, or the database is locked beyond
408    /// `busy_timeout`).
409    async fn execute(&self, sql: &str, params: &[serde_json::Value]) -> Result<u64, Self::Error> {
410        debug!(sql, params_count = params.len(), "OLTP rusqlite execute");
411        let rusqlite_params: Vec<rusqlite::types::Value> =
412            params.iter().map(json_to_rusqlite).collect();
413        let sql_owned = sql.to_string();
414
415        let rows_affected = self
416            .write_conn
417            .call(move |c| {
418                let params_refs: Vec<&dyn rusqlite::types::ToSql> = rusqlite_params
419                    .iter()
420                    .map(|v| v as &dyn rusqlite::types::ToSql)
421                    .collect();
422                let changed = c.execute(&sql_owned, params_refs.as_slice())?;
423                Ok(changed as u64)
424            })
425            .await?;
426        Ok(rows_affected)
427    }
428
429    /// Execute a batch of write statements inside a single `BEGIN`…`COMMIT`
430    /// transaction.
431    ///
432    /// All statements succeed or the entire transaction is rolled back.
433    /// Dispatches to the dedicated write connection.
434    ///
435    /// # Errors
436    ///
437    /// Returns [`RusqliteOltpError`] if any statement fails or the transaction
438    /// cannot be committed.
439    async fn execute_batch(
440        &self,
441        statements: &[(String, Vec<serde_json::Value>)],
442    ) -> Result<(), Self::Error> {
443        debug!(count = statements.len(), "OLTP rusqlite execute_batch");
444        // Clone the statements so we can move them into the closure
445        let stmts: Vec<(String, Vec<rusqlite::types::Value>)> = statements
446            .iter()
447            .map(|(sql, params)| {
448                let rusqlite_params: Vec<rusqlite::types::Value> =
449                    params.iter().map(json_to_rusqlite).collect();
450                (sql.clone(), rusqlite_params)
451            })
452            .collect();
453
454        self.write_conn
455            .call(move |c| {
456                let tx = c.transaction()?;
457                for (sql, params) in &stmts {
458                    let params_refs: Vec<&dyn rusqlite::types::ToSql> = params
459                        .iter()
460                        .map(|v| v as &dyn rusqlite::types::ToSql)
461                        .collect();
462                    tx.execute(sql, params_refs.as_slice())?;
463                }
464                tx.commit()?;
465                Ok(())
466            })
467            .await?;
468        Ok(())
469    }
470
471    /// Return `true` if a table named `table_name` exists in `sqlite_master`.
472    ///
473    /// Uses a read-pool connection.
474    ///
475    /// # Errors
476    ///
477    /// Returns [`RusqliteOltpError`] if the metadata query fails.
478    async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
479        let tbl = table_name.to_string();
480        let conn = self.next_read_conn();
481        let exists = conn
482            .call(move |c| {
483                let count: i64 = c.query_row(
484                    "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?1",
485                    rusqlite::params![tbl],
486                    |row| row.get(0),
487                )?;
488                Ok(count > 0)
489            })
490            .await?;
491        Ok(exists)
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Int64Array};
    use rhei_core::OltpEngine;

    #[tokio::test]
    async fn test_basic_crud() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 2)
            .await
            .unwrap();

        // Create table
        engine
            .execute(
                "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)",
                &[],
            )
            .await
            .unwrap();

        // Insert
        engine
            .execute(
                "INSERT INTO users (id, name, age) VALUES (?1, ?2, ?3)",
                &[
                    serde_json::json!(1),
                    serde_json::json!("Alice"),
                    serde_json::json!(30),
                ],
            )
            .await
            .unwrap();

        // Query
        let batches = engine.query("SELECT * FROM users", &[]).await.unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[0].num_columns(), 3);

        // table_exists
        assert!(engine.table_exists("users").await.unwrap());
        assert!(!engine.table_exists("nonexistent").await.unwrap());
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test_batch.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 1)
            .await
            .unwrap();

        engine
            .execute("CREATE TABLE items (id INTEGER PRIMARY KEY, val TEXT)", &[])
            .await
            .unwrap();

        let stmts: Vec<(String, Vec<serde_json::Value>)> = vec![
            (
                "INSERT INTO items (id, val) VALUES (?1, ?2)".to_string(),
                vec![serde_json::json!(1), serde_json::json!("a")],
            ),
            (
                "INSERT INTO items (id, val) VALUES (?1, ?2)".to_string(),
                vec![serde_json::json!(2), serde_json::json!("b")],
            ),
            (
                "INSERT INTO items (id, val) VALUES (?1, ?2)".to_string(),
                vec![serde_json::json!(3), serde_json::json!("c")],
            ),
        ];

        engine.execute_batch(&stmts).await.unwrap();

        let batches = engine
            .query("SELECT COUNT(*) FROM items", &[])
            .await
            .unwrap();
        let count = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);
        assert_eq!(count, 3);
    }

    #[tokio::test]
    async fn test_chunked_query() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test_chunks.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 1)
            .await
            .unwrap();

        engine
            .execute("CREATE TABLE big (id INTEGER PRIMARY KEY, val TEXT)", &[])
            .await
            .unwrap();

        // Insert 20 000 rows via a single execute_batch to keep the test fast.
        let stmts: Vec<(String, Vec<serde_json::Value>)> = (0..20_000u64)
            .map(|i| {
                (
                    "INSERT INTO big (id, val) VALUES (?1, ?2)".to_string(),
                    vec![serde_json::json!(i), serde_json::json!(format!("v{i}"))],
                )
            })
            .collect();
        engine.execute_batch(&stmts).await.unwrap();

        let batches = engine.query("SELECT * FROM big", &[]).await.unwrap();

        // With ARROW_BATCH_ROWS = 8192 and 20 000 rows we expect 3 batches
        // (8192 + 8192 + 3616).
        assert!(
            batches.len() > 1,
            "expected multiple RecordBatches, got {}",
            batches.len()
        );

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 20_000, "total row count mismatch");

        // Verify all batches use the same schema.
        let schema = batches[0].schema();
        for (i, batch) in batches.iter().enumerate() {
            assert_eq!(batch.schema(), schema, "schema mismatch on batch {i}");
        }
    }

    #[tokio::test]
    async fn test_new_connection() {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test_conn.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 1)
            .await
            .unwrap();

        engine
            .execute("CREATE TABLE t (id INTEGER PRIMARY KEY)", &[])
            .await
            .unwrap();

        // new_connection should be able to see the table
        let conn = engine.new_connection().await.unwrap();
        let exists: bool = conn
            .call(|c| {
                let count: i64 = c.query_row(
                    "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='t'",
                    [],
                    |row| row.get(0),
                )?;
                Ok(count > 0)
            })
            .await
            .unwrap();
        assert!(exists);
    }

    /// Regression test for sparse-column type inference.
    ///
    /// Inserts 10 000 rows with `score IS NULL` followed by rows with integer
    /// values for `score`. Before the warm-up fix, the schema was locked after
    /// the first 8 192 rows, all of which had a NULL `score` — permanently
    /// mis-typing the column as `Utf8`. The correct type is `Int64`.
    #[tokio::test]
    async fn test_sparse_column_type_inference() {
        // NOTE: `Int64Array` comes from the module-level import above and
        // `DataType` from `use super::*`; no function-local imports needed.
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("test_sparse.db");
        let engine = RusqliteEngine::new_local(path.to_str().unwrap(), 1)
            .await
            .unwrap();

        engine
            .execute(
                "CREATE TABLE sparse (id INTEGER PRIMARY KEY, score INTEGER)",
                &[],
            )
            .await
            .unwrap();

        // Insert 10 000 NULL-score rows, then 100 rows with integer scores.
        // With the old code the first ARROW_BATCH_ROWS (8 192) rows were all-NULL
        // for `score`, locking it as Utf8. We need at least SCHEMA_WARMUP_ROWS
        // (32 768) non-null rows to exercise the boundary, but 10 100 rows is
        // sufficient to demonstrate the old bug (NULL for the first two 8 192-row
        // chunks, then non-null) and confirm the fix works.
        const NULL_ROWS: usize = 10_000;
        const INT_ROWS: usize = 100;

        let null_stmts: Vec<(String, Vec<serde_json::Value>)> = (0..NULL_ROWS)
            .map(|i| {
                (
                    "INSERT INTO sparse (id, score) VALUES (?1, NULL)".to_string(),
                    vec![serde_json::json!(i)],
                )
            })
            .collect();
        engine.execute_batch(&null_stmts).await.unwrap();

        let int_stmts: Vec<(String, Vec<serde_json::Value>)> = (NULL_ROWS..NULL_ROWS + INT_ROWS)
            .map(|i| {
                (
                    "INSERT INTO sparse (id, score) VALUES (?1, ?2)".to_string(),
                    vec![serde_json::json!(i), serde_json::json!(i as i64)],
                )
            })
            .collect();
        engine.execute_batch(&int_stmts).await.unwrap();

        let batches = engine.query("SELECT * FROM sparse", &[]).await.unwrap();
        assert!(!batches.is_empty(), "expected at least one batch");

        // All batches must share the same schema.
        let schema = batches[0].schema();
        for (idx, batch) in batches.iter().enumerate() {
            assert_eq!(batch.schema(), schema, "schema mismatch on batch {idx}");
        }

        // `score` column must be Int64, not Utf8.
        let score_field = schema
            .field_with_name("score")
            .expect("score field missing");
        assert_eq!(
            score_field.data_type(),
            &DataType::Int64,
            "sparse column 'score' should be Int64; \
             was it mis-typed as Utf8 due to early schema lock?"
        );

        // Spot-check: the non-null integer values should be readable as Int64.
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, NULL_ROWS + INT_ROWS);

        // Find a batch that contains the integer rows and verify a value.
        let mut found_int = false;
        for batch in &batches {
            let score_col = batch.column_by_name("score").expect("score column missing");
            if let Some(arr) = score_col.as_any().downcast_ref::<Int64Array>() {
                for row in 0..arr.len() {
                    if arr.is_valid(row) {
                        // The integer value must be >= NULL_ROWS (our insert offset).
                        assert!(
                            arr.value(row) >= NULL_ROWS as i64,
                            "unexpected integer value {}",
                            arr.value(row)
                        );
                        found_int = true;
                    }
                }
            }
        }
        assert!(
            found_int,
            "no non-null Int64 values found in 'score' column"
        );
    }
}
759}