use std::cell::{Cell, RefCell};
use sea_orm::{DatabaseConnection, DbErr, EntityTrait, FromQueryResult, PaginatorTrait, Select};
use crate::{
BatchError,
core::item::{ItemReader, ItemReaderResult},
};
/// Item reader that yields SeaORM entity models produced by a `SELECT` query.
///
/// Rows are fetched into an internal buffer, either all at once or one page at a time when a
/// page size is configured. [`ItemReader::read`] then hands out one cloned model per call.
/// `Cell`/`RefCell` let the reader advance its cursor through a shared `&self`.
pub struct OrmItemReader<'a, I>
where
    I: EntityTrait,
{
    /// Connection the query runs against; must be set before the first read.
    connection: Option<&'a DatabaseConnection>,
    /// Statement whose rows are read; must be set before the first read.
    query: Option<Select<I>>,
    /// Rows fetched per page, or `None` to load every matching row in one query.
    page_size: Option<u64>,
    /// Number of items handed out so far.
    offset: Cell<u64>,
    /// Zero-based index of the next page to fetch (only advanced when paginating).
    current_page: Cell<u64>,
    /// Rows of the most recently fetched page (or of the whole result when not paginating).
    buffer: RefCell<Vec<I::Model>>,
}
impl<'a, I> Default for OrmItemReader<'a, I>
where
I: EntityTrait,
I::Model: FromQueryResult + Send + Sync + Clone,
{
fn default() -> Self {
Self::new()
}
}
impl<'a, I> OrmItemReader<'a, I>
where
    I: EntityTrait,
    I::Model: FromQueryResult + Send + Sync + Clone,
{
    /// Creates a reader with no connection, no query and no pagination.
    ///
    /// A connection ([`Self::connection`]) and a query ([`Self::query`]) must be supplied
    /// before the first read; otherwise fetching a page returns [`BatchError::ItemReader`].
    pub fn new() -> Self {
        Self {
            connection: None,
            query: None,
            page_size: None,
            offset: Cell::new(0),
            buffer: RefCell::new(Vec::new()),
            current_page: Cell::new(0),
        }
    }

    /// Sets the database connection the query runs against.
    pub fn connection(mut self, connection: &'a DatabaseConnection) -> Self {
        self.connection = Some(connection);
        self
    }

    /// Sets the `SELECT` statement whose rows this reader yields.
    pub fn query(mut self, query: Select<I>) -> Self {
        // The buffer is already sized by `page_size()`, so it is not reallocated here.
        self.query = Some(query);
        self
    }

    /// Enables pagination so that rows are fetched `page_size` at a time.
    ///
    /// The internal buffer is pre-allocated to hold one page.
    pub fn page_size(mut self, page_size: u64) -> Self {
        self.page_size = Some(page_size);
        self.buffer = RefCell::new(Vec::with_capacity(page_size as usize));
        self
    }

    /// Fetches the next batch of rows and replaces the buffer's contents with it.
    ///
    /// With a page size, this fetches page `current_page` of `query` through SeaORM's
    /// paginator. Without one, it fetches every row the query matches in a single call.
    ///
    /// # Errors
    ///
    /// Propagates any [`DbErr`] returned by SeaORM.
    async fn read_page_async(
        &self,
        connection: &DatabaseConnection,
        query: Select<I>,
    ) -> Result<(), DbErr> {
        let results = match self.page_size {
            Some(page_size) => {
                let paginator = query.paginate(connection, page_size);
                paginator.fetch_page(self.current_page.get()).await?
            }
            None => query.all(connection).await?,
        };
        let mut buffer = self.buffer.borrow_mut();
        buffer.clear();
        buffer.extend(results);
        Ok(())
    }

    /// Runs [`Self::read_page_async`] to completion on the ambient Tokio runtime.
    ///
    /// # Errors
    ///
    /// Returns [`BatchError::ItemReader`] in any of these cases:
    /// - no connection or query is configured;
    /// - the caller is not inside a Tokio runtime;
    /// - the runtime is current-thread, where [`tokio::task::block_in_place`] would panic;
    /// - the SeaORM query fails.
    fn read_page(&self) -> Result<(), BatchError> {
        let connection = self.connection.ok_or_else(|| {
            BatchError::ItemReader("OrmItemReader requires a database connection".to_string())
        })?;
        let query = self
            .query
            .clone()
            .ok_or_else(|| BatchError::ItemReader("OrmItemReader requires a query".to_string()))?;
        let handle = tokio::runtime::Handle::try_current().map_err(|e| {
            BatchError::ItemReader(format!("OrmItemReader must run inside a Tokio runtime: {}", e))
        })?;
        // `block_in_place` panics on a current-thread runtime; report that as an error instead.
        if matches!(
            handle.runtime_flavor(),
            tokio::runtime::RuntimeFlavor::CurrentThread
        ) {
            return Err(BatchError::ItemReader(
                "OrmItemReader requires a multi-threaded Tokio runtime".to_string(),
            ));
        }
        tokio::task::block_in_place(|| handle.block_on(self.read_page_async(connection, query)))
            .map_err(|e| BatchError::ItemReader(format!("SeaORM query failed: {}", e)))
    }
}
impl<I> ItemReader<I::Model> for OrmItemReader<'_, I>
where
I: EntityTrait,
I::Model: FromQueryResult + Send + Sync + Clone,
{
fn read(&self) -> ItemReaderResult<I::Model> {
let index = if let Some(page_size) = self.page_size {
self.offset.get() % page_size
} else {
self.offset.get()
};
if index == 0 {
self.read_page()?
}
let buffer = self.buffer.borrow();
let result = buffer.get(index as usize);
match result {
Some(item) => {
self.offset.set(self.offset.get() + 1);
if let Some(page_size) = self.page_size
&& self.offset.get().is_multiple_of(page_size)
{
self.current_page.set(self.current_page.get() + 1);
}
Ok(Some(item.clone()))
}
None => {
Ok(None)
}
}
}
}
/// Fluent builder for [`OrmItemReader`].
///
/// The connection and the query are mandatory; the page size is optional.
pub struct OrmItemReaderBuilder<'a, I>
where
    I: EntityTrait,
{
    /// Database connection handed to the reader (required).
    connection: Option<&'a DatabaseConnection>,
    /// Statement handed to the reader (required).
    query: Option<Select<I>>,
    /// Optional page size; `None` leaves the reader unpaginated.
    page_size: Option<u64>,
}
impl<I> Default for OrmItemReaderBuilder<'_, I>
where
I: EntityTrait,
{
fn default() -> Self {
Self {
connection: None,
query: None,
page_size: None,
}
}
}
impl<'a, I> OrmItemReaderBuilder<'a, I>
where
    I: EntityTrait,
    I::Model: FromQueryResult + Send + Sync + Clone,
{
    /// Creates an empty builder; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the database connection the reader will query.
    pub fn connection(mut self, connection: &'a DatabaseConnection) -> Self {
        self.connection = Some(connection);
        self
    }

    /// Sets the `SELECT` statement whose rows the reader will yield.
    pub fn query(mut self, query: Select<I>) -> Self {
        self.query = Some(query);
        self
    }

    /// Enables pagination with `page_size` rows per fetched page.
    pub fn page_size(mut self, page_size: u64) -> Self {
        self.page_size = Some(page_size);
        self
    }

    /// Assembles the configured [`OrmItemReader`].
    ///
    /// # Panics
    ///
    /// Panics with "Database connection is required" if no connection was set. Otherwise,
    /// panics with "Query is required" if no query was set.
    pub fn build(self) -> OrmItemReader<'a, I> {
        let connection = self.connection.expect("Database connection is required");
        let query = self.query.expect("Query is required");
        let reader = OrmItemReader::new().connection(connection).query(query);
        match self.page_size {
            Some(size) => reader.page_size(size),
            None => reader,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::item::ItemReader;
    use sea_orm::{DatabaseBackend, MockDatabase, entity::prelude::*};

    /// Minimal `record` entity used as the reader's target in every test.
    #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
    #[sea_orm(table_name = "record")]
    pub struct Model {
        #[sea_orm(primary_key)]
        pub id: i32,
        pub name: String,
    }

    #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    impl ActiveModelBehavior for ActiveModel {}

    /// Builds a `record` row fixture.
    fn record(id: i32, name: &str) -> Model {
        Model {
            id,
            name: name.to_owned(),
        }
    }

    #[test]
    fn should_create_reader_with_default_state() {
        let reader = OrmItemReader::<Entity>::new();

        assert!(
            reader.connection.is_none(),
            "connection should start as None"
        );
        assert!(reader.page_size.is_none());
        assert_eq!(reader.offset.get(), 0);
        assert_eq!(reader.current_page.get(), 0);
        assert!(reader.buffer.borrow().is_empty());
    }

    #[test]
    fn should_set_page_size_via_method() {
        let reader = OrmItemReader::<Entity>::new().page_size(50);

        assert_eq!(reader.page_size, Some(50));
        assert_eq!(reader.buffer.borrow().capacity(), 50);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_return_none_when_database_is_empty() {
        let db = MockDatabase::new(DatabaseBackend::Sqlite)
            .append_query_results([Vec::<Model>::new()])
            .into_connection();
        let reader = OrmItemReader::<Entity>::new()
            .connection(&db)
            .query(Entity::find());

        assert!(
            reader.read().unwrap().is_none(),
            "empty DB should yield None"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_single_item_then_return_none() {
        let db = MockDatabase::new(DatabaseBackend::Sqlite)
            .append_query_results([vec![record(1, "Alice")]])
            .into_connection();
        let reader = OrmItemReader::<Entity>::new()
            .connection(&db)
            .query(Entity::find());

        let only = reader.read().unwrap().expect("first item should exist");
        assert_eq!(only.name, "Alice");
        assert_eq!(reader.offset.get(), 1);

        assert!(
            reader.read().unwrap().is_none(),
            "should return None after the only item"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_read_multiple_items_without_pagination() {
        let db = MockDatabase::new(DatabaseBackend::Sqlite)
            .append_query_results([vec![record(1, "Alice"), record(2, "Bob")]])
            .into_connection();
        let reader = OrmItemReader::<Entity>::new()
            .connection(&db)
            .query(Entity::find());

        for expected in ["Alice", "Bob"] {
            let item = reader.read().unwrap().unwrap();
            assert_eq!(item.name, expected);
        }
        assert!(reader.read().unwrap().is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn should_paginate_across_multiple_pages() {
        // One mocked result per page fetch; the trailing empty page ends the read.
        let db = MockDatabase::new(DatabaseBackend::Sqlite)
            .append_query_results([vec![record(1, "P1")], vec![record(2, "P2")], vec![]])
            .into_connection();
        let reader = OrmItemReader::<Entity>::new()
            .connection(&db)
            .query(Entity::find())
            .page_size(1);

        let page_one = reader.read().unwrap().unwrap();
        assert_eq!(page_one.name, "P1");
        assert_eq!(
            reader.current_page.get(),
            1,
            "page should advance after full page"
        );

        let page_two = reader.read().unwrap().unwrap();
        assert_eq!(page_two.name, "P2");

        assert!(
            reader.read().unwrap().is_none(),
            "should stop after all pages"
        );
    }

    #[test]
    fn should_create_builder_with_default_state() {
        let builder = OrmItemReaderBuilder::<Entity>::new();

        assert!(builder.connection.is_none());
        assert!(builder.query.is_none());
        assert!(builder.page_size.is_none());
    }

    #[test]
    fn should_set_page_size_on_builder() {
        let builder = OrmItemReaderBuilder::<Entity>::new().page_size(100);

        assert_eq!(builder.page_size, Some(100));
    }

    #[test]
    #[should_panic(expected = "Database connection is required")]
    fn should_panic_when_building_without_connection() {
        OrmItemReaderBuilder::<Entity>::new()
            .query(Entity::find())
            .build();
    }

    #[test]
    #[should_panic(expected = "Query is required")]
    fn should_panic_when_building_without_query() {
        let db = MockDatabase::new(DatabaseBackend::Sqlite).into_connection();
        OrmItemReaderBuilder::<Entity>::new()
            .connection(&db)
            .build();
    }

    #[test]
    fn should_build_reader_with_connection_and_query() {
        let db = MockDatabase::new(DatabaseBackend::Sqlite).into_connection();

        let reader = OrmItemReaderBuilder::<Entity>::new()
            .connection(&db)
            .query(Entity::find())
            .build();

        assert!(reader.connection.is_some());
        assert!(reader.page_size.is_none());
    }

    #[test]
    fn should_build_reader_with_page_size() {
        let db = MockDatabase::new(DatabaseBackend::Sqlite).into_connection();

        let reader = OrmItemReaderBuilder::<Entity>::new()
            .connection(&db)
            .query(Entity::find())
            .page_size(50)
            .build();

        assert_eq!(reader.page_size, Some(50));
    }
}