use std::cell::{Cell, RefCell};
use sqlx::{Execute, FromRow, Pool, QueryBuilder, Sqlite, sqlite::SqliteRow};
use super::reader_common::{calculate_page_index, should_load_page};
use crate::BatchError;
use crate::core::item::{ItemReader, ItemReaderResult};
/// Item reader that pulls rows from a SQLite database and maps each one to `I`.
///
/// Rows are fetched into an in-memory buffer (optionally one page at a time via
/// `LIMIT`/`OFFSET`) and handed out one by one through [`ItemReader::read`].
pub struct SqliteRdbcItemReader<'a, I>
where
    I: for<'r> FromRow<'r, SqliteRow> + Send + Unpin + Clone,
{
    /// Connection pool the query is executed against.
    pub(crate) pool: Pool<Sqlite>,
    /// Base SQL query; paging clauses are appended to it when `page_size` is set.
    pub(crate) query: &'a str,
    /// Rows per page, or `None` to run the query without `LIMIT`/`OFFSET`.
    pub(crate) page_size: Option<i32>,
    /// Incremented by every successful `read`; used as the `OFFSET` when a page is loaded.
    pub(crate) offset: Cell<i32>,
    /// Rows returned by the most recently executed page query.
    pub(crate) buffer: RefCell<Vec<I>>,
}
impl<'a, I> SqliteRdbcItemReader<'a, I>
where
    for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
{
    /// Creates a reader that runs `query` against `pool`.
    ///
    /// When `page_size` is `Some(n)`, each page query appends
    /// ` LIMIT n OFFSET <offset>` to `query`; when it is `None`, `query` is
    /// executed as-is. The reader starts at offset 0 with an empty buffer.
    pub(crate) fn new(pool: Pool<Sqlite>, query: &'a str, page_size: Option<i32>) -> Self {
        Self {
            pool,
            query,
            page_size,
            offset: Cell::new(0),
            buffer: RefCell::new(vec![]),
        }
    }

    /// Executes the (optionally paged) query and replaces the buffer with its rows.
    ///
    /// This blocks the current thread on the async query via
    /// `tokio::task::block_in_place`, which requires a multi-threaded Tokio runtime.
    ///
    /// # Errors
    ///
    /// Returns [`BatchError::ItemReader`] if no Tokio runtime is available on the
    /// current thread, or if executing the query / mapping a row fails.
    fn read_page(&self) -> Result<(), BatchError> {
        let mut query_builder = QueryBuilder::<Sqlite>::new(self.query);
        if let Some(page_size) = self.page_size {
            // Both values are internal integers, not user input, so inlining them
            // into the SQL text is safe. `push` takes `impl Display`, so no
            // intermediate `String` is needed.
            query_builder
                .push(" LIMIT ")
                .push(page_size)
                .push(" OFFSET ")
                .push(self.offset.get());
        }
        let query = query_builder.build();

        // Resolve the runtime up front so that calling `read` outside a Tokio
        // runtime surfaces as a `BatchError` instead of a panic.
        let handle = tokio::runtime::Handle::try_current()
            .map_err(|e| BatchError::ItemReader(e.to_string()))?;

        let items = tokio::task::block_in_place(|| {
            handle.block_on(async {
                sqlx::query_as::<_, I>(query.sql())
                    .fetch_all(&self.pool)
                    .await
                    .map_err(|e| BatchError::ItemReader(e.to_string()))
            })
        })?;

        // Swap the freshly fetched page in with a single mutable borrow.
        *self.buffer.borrow_mut() = items;
        Ok(())
    }
}
impl<I> ItemReader<I> for SqliteRdbcItemReader<'_, I>
where
    for<'r> I: FromRow<'r, SqliteRow> + Send + Unpin + Clone,
{
    /// Returns a clone of the item at the current position, then advances the offset.
    ///
    /// A new page is fetched first whenever `should_load_page` reports that the
    /// position within the current page requires it. Yields `Ok(None)` when the
    /// buffer has no entry at that position. The offset advances even in that
    /// case, but not when fetching the page fails.
    fn read(&self) -> ItemReaderResult<I> {
        let current = self.offset.get();
        let page_index = calculate_page_index(current, self.page_size);

        if should_load_page(page_index) {
            self.read_page()?;
        }

        let item = self.buffer.borrow().get(page_index as usize).cloned();
        self.offset.set(current + 1);
        Ok(item)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::item::ItemReader;
    use sqlx::sqlite::SqlitePoolOptions;
    use sqlx::{FromRow, SqlitePool};

    #[derive(Clone, FromRow)]
    struct Row {
        id: i32,
        name: String,
    }

    /// Builds an in-memory pool containing an `items(id, name)` table filled with `rows`.
    ///
    /// Every connection to `sqlite::memory:` opens its own, separate database, so
    /// the pool is pinned to a single connection. Otherwise the table created here
    /// could be invisible to a query that happens to run on another connection.
    async fn pool_with_rows(rows: &[(i32, &str)]) -> SqlitePool {
        let pool = SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap();
        sqlx::query("CREATE TABLE items (id INTEGER, name TEXT)")
            .execute(&pool)
            .await
            .unwrap();
        for (id, name) in rows {
            sqlx::query("INSERT INTO items (id, name) VALUES (?, ?)")
                .bind(id)
                .bind(name)
                .execute(&pool)
                .await
                .unwrap();
        }
        pool
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_start_with_offset_zero_and_empty_buffer() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let reader = SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items", None);
        assert_eq!(reader.offset.get(), 0, "initial offset should be 0");
        assert!(
            reader.buffer.borrow().is_empty(),
            "initial buffer should be empty"
        );
        assert_eq!(reader.page_size, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_return_none_when_table_is_empty() {
        let pool = pool_with_rows(&[]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items", None);
        let result = reader.read().unwrap();
        assert!(result.is_none(), "empty table should yield None");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_without_pagination() {
        let pool = pool_with_rows(&[(1, "alice"), (2, "bob")]).await;
        let reader =
            SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items ORDER BY id", None);
        let first = reader.read().unwrap().expect("first item should exist");
        assert_eq!(first.name, "alice");
        let second = reader.read().unwrap().expect("second item should exist");
        assert_eq!(second.name, "bob");
        assert!(
            reader.read().unwrap().is_none(),
            "should return None after all items"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_advance_offset_on_each_read() {
        let pool = pool_with_rows(&[(1, "x"), (2, "y")]).await;
        let reader =
            SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items ORDER BY id", None);
        assert_eq!(reader.offset.get(), 0);
        reader.read().unwrap();
        assert_eq!(
            reader.offset.get(),
            1,
            "offset should increment after each read"
        );
        reader.read().unwrap();
        assert_eq!(reader.offset.get(), 2);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_all_items_with_pagination() {
        let pool = pool_with_rows(&[(1, "a"), (2, "b"), (3, "c"), (4, "d")]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(
            pool,
            "SELECT id, name FROM items ORDER BY id",
            Some(2),
        );
        let mut count = 0;
        while reader.read().unwrap().is_some() {
            count += 1;
        }
        assert_eq!(count, 4, "should read all 4 items across 2 pages");
    }

    /// The last page holds fewer rows than `page_size`. Items must still come back
    /// in query order, and the reader must stay exhausted afterwards.
    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_items_in_order_when_last_page_is_partial() {
        let pool = pool_with_rows(&[(1, "a"), (2, "b"), (3, "c")]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(
            pool,
            "SELECT id, name FROM items ORDER BY id",
            Some(2),
        );
        let names: Vec<String> = std::iter::from_fn(|| reader.read().unwrap())
            .map(|row| row.name)
            .collect();
        assert_eq!(names, ["a", "b", "c"]);
        assert!(
            reader.read().unwrap().is_none(),
            "reader should stay exhausted after the last item"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_single_item() {
        let pool = pool_with_rows(&[(42, "only")]).await;
        let reader = SqliteRdbcItemReader::<Row>::new(pool, "SELECT id, name FROM items", None);
        let item = reader
            .read()
            .unwrap()
            .expect("should return the single item");
        assert_eq!(item.id, 42);
        assert_eq!(item.name, "only");
        assert!(
            reader.read().unwrap().is_none(),
            "should return None after the only item"
        );
    }
}