use sea_orm::{ActiveModelTrait, DatabaseConnection, DbErr, EntityTrait, InsertResult};
use std::marker::PhantomData;
use crate::{
BatchError,
core::item::{ItemWriter, ItemWriterResult},
};
/// Item writer that persists SeaORM active models of type `O` through a borrowed
/// [`DatabaseConnection`].
pub struct OrmItemWriter<'a, O: ActiveModelTrait + Send> {
    /// Connection every batch insert is executed against.
    connection: &'a DatabaseConnection,
    /// Binds the writer to its active-model type without storing a value of it.
    _phantom: PhantomData<O>,
}
impl<'a, O> OrmItemWriter<'a, O>
where
    O: ActiveModelTrait + Send,
{
    /// Creates a writer that inserts active models through `connection`.
    pub fn new(connection: &'a DatabaseConnection) -> Self {
        Self {
            connection,
            _phantom: PhantomData,
        }
    }

    /// Inserts every model in `active_models` with one
    /// `insert_many(...).exec(...)` call on the model's entity.
    async fn insert_batch_async(&self, active_models: Vec<O>) -> Result<InsertResult<O>, DbErr> {
        <O as ActiveModelTrait>::Entity::insert_many(active_models)
            .exec(self.connection)
            .await
    }

    /// Synchronously inserts `active_models` by driving
    /// [`Self::insert_batch_async`] to completion on the ambient Tokio runtime.
    ///
    /// `tokio::task::block_in_place` is used so the calling worker can block.
    /// Tokio documents that it panics when called from a current-thread runtime,
    /// so callers inside Tokio need the multi-threaded runtime flavor.
    ///
    /// # Errors
    ///
    /// Returns [`BatchError::ItemWriter`] when no Tokio runtime is available on
    /// the calling thread, or when the database rejects the insert.
    fn insert_batch(&self, active_models: Vec<O>) -> Result<(), BatchError> {
        // `Handle::current()` panics outside a runtime; report that as a writer error instead.
        let handle = tokio::runtime::Handle::try_current().map_err(|err| {
            Self::writer_error(format!(
                "Failed to insert batch to database: no Tokio runtime available: {}",
                err
            ))
        })?;

        let result =
            tokio::task::block_in_place(|| handle.block_on(self.insert_batch_async(active_models)));

        match result {
            Ok(_insert_result) => {
                log::debug!("Successfully inserted batch to database");
                Ok(())
            }
            Err(db_err) => Err(Self::writer_error(format!(
                "Failed to insert batch to database: {}",
                db_err
            ))),
        }
    }

    /// Logs `error_msg` at error level and wraps it in [`BatchError::ItemWriter`].
    fn writer_error(error_msg: String) -> BatchError {
        log::error!("{}", error_msg);
        BatchError::ItemWriter(error_msg)
    }
}
impl<O> ItemWriter<O> for OrmItemWriter<'_, O>
where
    O: ActiveModelTrait + Send,
{
    /// Writes the whole chunk `items` as one batch insert.
    ///
    /// An empty chunk returns `Ok(())` without touching the database.
    fn write(&self, items: &[O]) -> ItemWriterResult {
        let count = items.len();
        log::debug!("Writing {} active models to database", count);

        if count == 0 {
            log::debug!("No items to write, skipping database operation");
            return Ok(());
        }

        // The insert consumes its models, so the borrowed chunk is cloned into an owned batch.
        self.insert_batch(items.iter().cloned().collect())?;

        log::info!("Successfully wrote {} active models to database", count);
        Ok(())
    }

    /// Nothing is buffered between writes, so flushing only logs.
    fn flush(&self) -> ItemWriterResult {
        log::debug!("Flush called on ORM writer (no-op)");
        Ok(())
    }

    /// No resources are acquired on open; this only logs.
    fn open(&self) -> ItemWriterResult {
        log::debug!("Opened ORM writer");
        Ok(())
    }

    /// No resources are released on close; this only logs.
    fn close(&self) -> ItemWriterResult {
        log::debug!("Closed ORM writer");
        Ok(())
    }
}
/// Builder for [`OrmItemWriter`].
///
/// `Default` is implemented by hand instead of derived: `#[derive(Default)]`
/// would require `O: Default`, even though only a `PhantomData<O>` is stored.
pub struct OrmItemWriterBuilder<'a, O>
where
    O: ActiveModelTrait + Send,
{
    /// Connection for the built writer; `None` until `connection()` is called.
    connection: Option<&'a DatabaseConnection>,
    /// Binds the builder to its active-model type without storing a value of it.
    _phantom: PhantomData<O>,
}

impl<O> Default for OrmItemWriterBuilder<'_, O>
where
    O: ActiveModelTrait + Send,
{
    /// Returns a builder with no connection configured.
    fn default() -> Self {
        Self {
            connection: None,
            _phantom: PhantomData,
        }
    }
}
impl<'a, O> OrmItemWriterBuilder<'a, O>
where
O: ActiveModelTrait + Send,
{
pub fn new() -> Self {
Self {
connection: None,
_phantom: PhantomData,
}
}
pub fn connection(mut self, connection: &'a DatabaseConnection) -> Self {
self.connection = Some(connection);
self
}
pub fn build(self) -> OrmItemWriter<'a, O> {
let connection = self
.connection
.expect("Database connection is required. Call .connection() before .build()");
OrmItemWriter::new(connection)
}
}
// Unit tests for `OrmItemWriter` and `OrmItemWriterBuilder`, using a minimal
// SeaORM entity and `MockDatabase` connections instead of a real database.
#[cfg(test)]
mod tests {
use super::*;
use sea_orm::{
ActiveValue::{NotSet, Set},
entity::prelude::*,
};
// Minimal entity for the tests. `DeriveEntityModel` generates the `Entity`,
// `Column`, `PrimaryKey` and `ActiveModel` types from this `Model`.
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "test_entity")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub name: String,
}
// The entity has no relations, but the derive requires a `Relation` enum.
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}
// Compile-time check that the generated `ActiveModel` satisfies the writer's
// bounds, and that a fresh builder has no connection.
#[test]
fn test_simplified_trait_bounds_compilation() {
fn _verify_bounds<A>()
where
A: ActiveModelTrait + Send,
{
}
_verify_bounds::<ActiveModel>();
let _builder = OrmItemWriterBuilder::<ActiveModel>::new();
assert!(_builder.connection.is_none());
}
// Checks that a concrete active model value meets the writer's trait bounds.
#[test]
fn test_active_model_creation() {
let active_model = ActiveModel {
id: NotSet,
name: Set("Test".to_owned()),
};
fn check_traits<A>(_: A)
where
A: ActiveModelTrait + Send,
{
}
check_traits(active_model);
}
// Runs under plain `#[test]` (no Tokio runtime). Writing an empty slice must
// therefore return before `insert_batch` needs a runtime handle. The mock
// database has no exec results queued.
#[test]
fn test_write_empty_slice_skips_database_operation() {
use sea_orm::{DatabaseBackend, MockDatabase};
let db = MockDatabase::new(DatabaseBackend::Sqlite).into_connection();
let writer = OrmItemWriter::<ActiveModel>::new(&db);
assert!(writer.open().is_ok());
assert!(writer.flush().is_ok());
assert!(writer.write(&[]).is_ok());
assert!(writer.close().is_ok());
}
// Builder path: `connection()` followed by `build()` yields a usable writer.
#[test]
fn should_build_writer_via_builder_with_connection() {
use sea_orm::{DatabaseBackend, MockDatabase};
let db = MockDatabase::new(DatabaseBackend::Sqlite).into_connection();
let writer = OrmItemWriterBuilder::<ActiveModel>::new()
.connection(&db)
.build();
assert!(writer.open().is_ok());
assert!(writer.flush().is_ok());
assert!(writer.close().is_ok());
}
// Happy path against one queued mock exec result. The multi-threaded flavor
// is used because `insert_batch` calls `tokio::task::block_in_place`.
#[tokio::test(flavor = "multi_thread")]
async fn should_write_active_models_to_mock_database() {
use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult};
let db = MockDatabase::new(DatabaseBackend::Sqlite)
.append_exec_results([MockExecResult {
last_insert_id: 1,
rows_affected: 1,
}])
.into_connection();
let writer = OrmItemWriter::<ActiveModel>::new(&db);
let items = vec![ActiveModel {
id: NotSet,
name: Set("Alice".to_owned()),
}];
let result = writer.write(&items);
assert!(
result.is_ok(),
"write should succeed with mock DB: {result:?}"
);
}
// Failure path: a queued exec error must surface as `BatchError::ItemWriter`
// carrying the "Failed to insert..." message built in `insert_batch`.
#[tokio::test(flavor = "multi_thread")]
async fn should_return_error_when_database_insert_fails() {
use crate::BatchError;
use sea_orm::{DatabaseBackend, DbErr, MockDatabase};
let db = MockDatabase::new(DatabaseBackend::Sqlite)
.append_exec_errors([DbErr::Custom("insert failed".to_owned())])
.into_connection();
let writer = OrmItemWriter::<ActiveModel>::new(&db);
let items = vec![ActiveModel {
id: NotSet,
name: Set("Fail".to_owned()),
}];
let result = writer.write(&items);
assert!(
result.is_err(),
"write should fail when database returns error"
);
match result {
Err(BatchError::ItemWriter(msg)) => {
assert!(msg.contains("Failed to insert"), "unexpected error: {msg}")
}
other => panic!("expected ItemWriter error, got {other:?}"),
}
}
}