use std::cell::{Cell, RefCell};
use sqlx::{Execute, FromRow, MySql, Pool, QueryBuilder, mysql::MySqlRow};
use super::reader_common::{calculate_page_index, should_load_page};
use crate::BatchError;
use crate::core::item::{ItemReader, ItemReaderResult};
pub struct MySqlRdbcItemReader<'a, I>
where
for<'r> I: FromRow<'r, MySqlRow> + Send + Unpin + Clone,
{
pub(crate) pool: Pool<MySql>,
pub(crate) query: &'a str,
pub(crate) page_size: Option<i32>,
pub(crate) offset: Cell<i32>,
pub(crate) buffer: RefCell<Vec<I>>,
pub(crate) keyset_column: Option<String>,
#[allow(clippy::type_complexity)]
pub(crate) keyset_key: Option<Box<dyn Fn(&I) -> String>>,
pub(crate) last_cursor: RefCell<Option<String>>,
}
impl<'a, I> MySqlRdbcItemReader<'a, I>
where
for<'r> I: FromRow<'r, MySqlRow> + Send + Unpin + Clone,
{
#[allow(clippy::type_complexity)]
pub fn new(
pool: Pool<MySql>,
query: &'a str,
page_size: Option<i32>,
keyset_column: Option<String>,
keyset_key: Option<Box<dyn Fn(&I) -> String>>,
) -> Self {
Self {
pool,
query,
page_size,
offset: Cell::new(0),
buffer: RefCell::new(vec![]),
keyset_column,
keyset_key,
last_cursor: RefCell::new(None),
}
}
fn read_page(&self) -> Result<(), BatchError> {
let mut query_builder = QueryBuilder::<MySql>::new(self.query);
if let Some(page_size) = self.page_size {
if let Some(ref col) = self.keyset_column {
let last = self.last_cursor.borrow();
if let Some(ref cursor_val) = *last {
let escaped = cursor_val.replace('\'', "''");
query_builder.push(format!(" WHERE {} > '{}'", col, escaped));
}
query_builder.push(format!(" ORDER BY {} LIMIT {}", col, page_size));
} else {
query_builder.push(format!(" LIMIT {} OFFSET {}", page_size, self.offset.get()));
}
}
let query = query_builder.build();
let items = tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
sqlx::query_as::<_, I>(query.sql())
.fetch_all(&self.pool)
.await
.map_err(|e| BatchError::ItemReader(e.to_string()))
})
})?;
self.buffer.borrow_mut().clear();
self.buffer.borrow_mut().extend(items);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use sqlx::MySqlPool;
#[derive(Clone)]
struct Dummy;
impl<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> for Dummy {
fn from_row(_row: &'r sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Dummy)
}
}
fn reader_with_keyset(keyset: bool) -> MySqlRdbcItemReader<'static, Dummy> {
let pool = MySqlPool::connect_lazy("mysql://root:root@localhost/test")
.expect("lazy pool creation should not fail");
let (col, key): (Option<String>, Option<Box<dyn Fn(&Dummy) -> String>>) = if keyset {
(
Some("id".to_string()),
Some(Box::new(|_: &Dummy| "0".to_string())),
)
} else {
(None, None)
};
MySqlRdbcItemReader::new(pool, "SELECT 1", Some(10), col, key)
}
#[tokio::test(flavor = "multi_thread")]
async fn should_initialize_without_keyset() {
let reader = reader_with_keyset(false);
assert!(reader.keyset_column.is_none(), "no keyset column expected");
assert!(reader.keyset_key.is_none(), "no keyset key fn expected");
assert!(
reader.last_cursor.borrow().is_none(),
"cursor must start as None"
);
assert_eq!(reader.offset.get(), 0, "initial offset should be 0");
assert!(
reader.buffer.borrow().is_empty(),
"buffer should start empty"
);
assert_eq!(reader.page_size, Some(10));
}
#[tokio::test(flavor = "multi_thread")]
async fn should_initialize_with_keyset_column_and_none_cursor() {
let reader = reader_with_keyset(true);
assert_eq!(
reader.keyset_column.as_deref(),
Some("id"),
"keyset column should be stored"
);
assert!(
reader.keyset_key.is_some(),
"keyset key fn should be stored"
);
assert!(
reader.last_cursor.borrow().is_none(),
"cursor must start as None before first read"
);
}
}
impl<I> ItemReader<I> for MySqlRdbcItemReader<'_, I>
where
for<'r> I: FromRow<'r, MySqlRow> + Send + Unpin + Clone,
{
fn read(&self) -> ItemReaderResult<I> {
let index = calculate_page_index(self.offset.get(), self.page_size);
if should_load_page(index) {
self.read_page()?;
}
let result = self.buffer.borrow().get(index as usize).cloned();
if let (Some(item), Some(key_fn)) = (&result, &self.keyset_key) {
*self.last_cursor.borrow_mut() = Some(key_fn(item));
}
self.offset.set(self.offset.get() + 1);
Ok(result)
}
}