use serde::{Deserialize, Serialize};
use spring_batch_rs::{
BatchError,
core::{
item::{ItemProcessor, PassThroughProcessor},
job::{Job, JobBuilder},
step::StepBuilder,
},
item::{
csv::csv_reader::CsvItemReaderBuilder,
rdbc::{DatabaseItemBinder, RdbcItemReaderBuilder, RdbcItemWriterBuilder},
xml::xml_writer::XmlItemWriterBuilder,
},
};
use sqlx::{FromRow, PgPool, Postgres, query_builder::Separated};
use std::{
env,
fs::File,
io::{BufReader, BufWriter, Write},
time::Instant,
};
/// One money transfer as it flows through the pipeline: read from CSV,
/// enriched by `TransactionProcessor`, written to Postgres, exported to XML.
///
/// `FromRow` maps Postgres columns to fields by name for the Step-2 reader;
/// serde drives both the CSV parsing and the XML serialization.
#[derive(Debug, Clone, Deserialize, Serialize, FromRow)]
struct Transaction {
    transaction_id: String,
    amount: f64,
    currency: String,
    // Kept as text (e.g. "2024-06-15T12:00:00Z"), not a native DB timestamp —
    // matches the VARCHAR(25) column created in main().
    #[serde(rename = "timestamp")]
    timestamp: String,
    account_from: String,
    account_to: String,
    status: String,
    // Absent from the generated CSV — serde defaults it to 0.0 on read; the
    // processor fills it in before the row reaches Postgres.
    #[serde(default)]
    amount_eur: f64,
}
/// Stateless processor: converts `amount` to EUR and normalises the
/// CANCELLED status to FAILED (see the `ItemProcessor` impl below).
#[derive(Default)]
struct TransactionProcessor;
impl ItemProcessor<Transaction, Transaction> for TransactionProcessor {
    /// Enriches a transaction: fills `amount_eur` from a fixed FX table and
    /// maps the CANCELLED status to FAILED. Never filters (always `Some`).
    fn process(&self, item: &Transaction) -> Result<Option<Transaction>, BatchError> {
        // Fixed conversion rates to EUR; anything else (incl. EUR) is 1.0.
        let fx = if item.currency == "USD" {
            0.92
        } else if item.currency == "GBP" {
            1.17
        } else {
            1.0
        };

        let mut enriched = item.clone();
        if enriched.status == "CANCELLED" {
            enriched.status = "FAILED".to_string();
        }
        // Round the converted amount to two decimal places (cents).
        enriched.amount_eur = (enriched.amount * fx * 100.0).round() / 100.0;
        Ok(Some(enriched))
    }
}
/// Binds a `Transaction`'s fields to the positional parameters of the
/// Postgres INSERT built by the RDBC writer.
struct TransactionBinder;
impl DatabaseItemBinder<Transaction, Postgres> for TransactionBinder {
    // NOTE: binds are positional — the order below must exactly match the
    // `add_column` sequence used when building the writer in `run_step1`.
    fn bind(&self, item: &Transaction, mut q: Separated<Postgres, &str>) {
        q.push_bind(item.transaction_id.clone());
        q.push_bind(item.amount);
        q.push_bind(item.currency.clone());
        q.push_bind(item.timestamp.clone());
        q.push_bind(item.account_from.clone());
        q.push_bind(item.account_to.clone());
        q.push_bind(item.status.clone());
        q.push_bind(item.amount_eur);
    }
}
// Value domains for the synthetic data generator (`generate_csv`).
const CURRENCIES: [&str; 3] = ["USD", "EUR", "GBP"];
const STATUSES: [&str; 4] = ["PENDING", "COMPLETED", "FAILED", "CANCELLED"];
// Number of CSV rows generated and pushed through the whole pipeline.
const TOTAL_RECORDS: u64 = 10_000_000;
/// Writes `count` deterministic pseudo-random transaction rows (plus a header
/// line) to the CSV file at `path`. Output is identical across runs: the
/// generator is a fixed-seed 64-bit LCG.
fn generate_csv(path: &str, count: u64) -> Result<(), BatchError> {
    // Advance the LCG one step and return its high 31 bits (the low bits of
    // an LCG are poor quality, so we discard them).
    fn next_u32(state: &mut u64) -> u32 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        (*state >> 33) as u32
    }

    let file = File::create(path)
        .map_err(|e| BatchError::ItemWriter(format!("Cannot create CSV: {}", e)))?;
    // Large write buffer: we emit millions of short lines.
    let mut out = BufWriter::with_capacity(256 * 1024, file);
    writeln!(
        out,
        "transaction_id,amount,currency,timestamp,account_from,account_to,status"
    )
    .map_err(|e| BatchError::ItemWriter(e.to_string()))?;

    let mut state: u64 = 0xDEAD_BEEF_CAFE_BABE;
    for row in 0..count {
        let r1 = next_u32(&mut state);
        let r2 = next_u32(&mut state);
        writeln!(
            out,
            "TXN-{:010},{:.2},{},2024-{:02}-{:02}T{:02}:{:02}:{:02}Z,\
             ACC-{:08},ACC-{:08},{}",
            row + 1,
            ((r1 % 9_999_999) + 100) as f64 / 100.0, // amount in [1.00, 100000.98]
            CURRENCIES[(r1 % 3) as usize],
            r1 % 12 + 1,    // month
            r2 % 28 + 1,    // day (capped at 28 so every month is valid)
            r1 % 24,        // hour
            r2 % 60,        // minute
            r1 % 60,        // second
            r1 % 1_000_000, // account_from
            r2 % 1_000_000, // account_to
            STATUSES[(r2 % 4) as usize],
        )
        .map_err(|e| BatchError::ItemWriter(e.to_string()))?;
    }

    // BufWriter::drop swallows flush errors — flush explicitly and surface them.
    out.flush()
        .map_err(|e| BatchError::ItemWriter(e.to_string()))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use spring_batch_rs::core::item::ItemProcessor;

    /// Builds a transaction with fixed ids/timestamp; `amount_eur` starts at 0.0
    /// so every non-zero value observed in a test was set by the processor.
    fn make_transaction(currency: &str, amount: f64, status: &str) -> Transaction {
        Transaction {
            transaction_id: "TXN-0000000001".to_string(),
            amount,
            currency: currency.to_string(),
            timestamp: "2024-06-15T12:00:00Z".to_string(),
            account_from: "ACC-00000001".to_string(),
            account_to: "ACC-00000002".to_string(),
            status: status.to_string(),
            amount_eur: 0.0,
        }
    }

    // NOTE: `process` returns Result<Option<Transaction>, BatchError>, so each
    // test unwraps BOTH layers — the Result, then the Option (the processor
    // never filters, so `Some` is guaranteed). The original single `.unwrap()`
    // left an Option and did not compile.

    #[test]
    fn should_convert_usd_to_eur() {
        let processor = TransactionProcessor;
        let input = make_transaction("USD", 1000.0, "COMPLETED");
        let result = processor.process(&input).unwrap().unwrap();
        assert_eq!(result.amount_eur, 920.0, "USD 1000 * 0.92 = EUR 920");
        assert_eq!(result.currency, "USD", "currency field must not change");
    }

    #[test]
    fn should_convert_gbp_to_eur() {
        let processor = TransactionProcessor;
        let input = make_transaction("GBP", 100.0, "COMPLETED");
        let result = processor.process(&input).unwrap().unwrap();
        assert_eq!(result.amount_eur, 117.0, "GBP 100 * 1.17 = EUR 117");
    }

    #[test]
    fn should_keep_eur_unchanged() {
        let processor = TransactionProcessor;
        let input = make_transaction("EUR", 500.0, "PENDING");
        let result = processor.process(&input).unwrap().unwrap();
        assert_eq!(result.amount_eur, 500.0, "EUR passthrough: rate = 1.0");
    }

    #[test]
    fn should_normalise_cancelled_to_failed() {
        let processor = TransactionProcessor;
        let input = make_transaction("EUR", 100.0, "CANCELLED");
        let result = processor.process(&input).unwrap().unwrap();
        assert_eq!(
            result.status, "FAILED",
            "CANCELLED must be mapped to FAILED"
        );
    }

    #[test]
    fn should_preserve_other_statuses() {
        let processor = TransactionProcessor;
        for status in &["PENDING", "COMPLETED", "FAILED"] {
            let input = make_transaction("EUR", 100.0, status);
            let result = processor.process(&input).unwrap().unwrap();
            assert_eq!(
                &result.status, status,
                "status '{}' must not be changed",
                status
            );
        }
    }

    #[test]
    fn should_round_amount_eur_to_two_decimals() {
        let processor = TransactionProcessor;
        // 333.33 * 0.92 = 306.6636 → rounds to 306.66.
        let input = make_transaction("USD", 333.33, "COMPLETED");
        let result = processor.process(&input).unwrap().unwrap();
        assert!(
            (result.amount_eur - 306.66_f64).abs() < 1e-9,
            "amount_eur must be rounded to 2 decimals, got {}",
            result.amount_eur
        );
    }

    #[test]
    fn should_generate_csv_with_correct_header_and_row_count() {
        use std::io::Read;
        let path = std::env::temp_dir().join("bench_smoke_test.csv");
        generate_csv(path.to_str().unwrap(), 5).unwrap();
        let mut content = String::new();
        File::open(&path)
            .unwrap()
            .read_to_string(&mut content)
            .unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(
            lines[0], "transaction_id,amount,currency,timestamp,account_from,account_to,status",
            "CSV header mismatch"
        );
        assert_eq!(
            lines.len(),
            6,
            "header + 5 data rows expected, got {}",
            lines.len()
        );
    }
}
/// Step 1: stream the CSV through `TransactionProcessor` into the
/// `transactions` table, in chunks of 1 000. Returns the number of records
/// the step reports as written.
fn run_step1(pool: &PgPool, csv_path: &str) -> Result<u64, BatchError> {
    log::info!("[Step 1] CSV → PostgreSQL …");
    let started = Instant::now();

    // Buffered CSV source.
    let input = File::open(csv_path)
        .map_err(|e| BatchError::ItemReader(format!("Cannot open CSV: {}", e)))?;
    let reader = CsvItemReaderBuilder::<Transaction>::new()
        .has_headers(true)
        .from_reader(BufReader::with_capacity(64 * 1024, input));

    // Postgres sink. Column order here must match TransactionBinder::bind.
    let binder = TransactionBinder;
    let writer = RdbcItemWriterBuilder::<Transaction>::new()
        .postgres(pool)
        .table("transactions")
        .add_column("transaction_id")
        .add_column("amount")
        .add_column("currency")
        .add_column("timestamp")
        .add_column("account_from")
        .add_column("account_to")
        .add_column("status")
        .add_column("amount_eur")
        .postgres_binder(&binder)
        .build_postgres();

    let processor = TransactionProcessor;
    let step = StepBuilder::new("csv-to-postgres")
        .chunk::<Transaction, Transaction>(1_000)
        .reader(&reader)
        .processor(&processor)
        .writer(&writer)
        .build();

    let job = JobBuilder::new().start(&step).build();
    job.run()?;

    let exec = job
        .get_step_execution("csv-to-postgres")
        .expect("step 'csv-to-postgres' must exist after job.run()");
    let secs = started.elapsed().as_secs_f64();
    eprintln!(
        "[Step 1] Done — {} records written in {:.1}s ({:.0} rec/s)",
        exec.write_count,
        secs,
        exec.write_count as f64 / secs
    );
    if exec.read_error_count > 0 || exec.write_error_count > 0 {
        eprintln!(
            "[Step 1] WARNING: {} read errors, {} write errors skipped — throughput may be understated",
            exec.read_error_count, exec.write_error_count
        );
    }
    Ok(exec.write_count as u64)
}
/// Step 2: page the `transactions` table back out and stream it to an XML
/// file, in chunks of 1 000. Returns the number of records the step reports
/// as written.
fn run_step2(pool: &PgPool, xml_path: &str) -> Result<u64, BatchError> {
    log::info!("[Step 2] PostgreSQL → XML …");
    let started = Instant::now();

    // Paged reader; ORDER BY keeps the export deterministic across runs.
    let reader = RdbcItemReaderBuilder::<Transaction>::new()
        .postgres(pool.clone())
        .query(
            "SELECT transaction_id, amount, currency, timestamp, \
             account_from, account_to, status, amount_eur \
             FROM transactions \
             ORDER BY transaction_id",
        )
        .with_page_size(1_000)
        .build_postgres();

    // XML sink: <transactions> root with one <transaction> element per row.
    let writer = XmlItemWriterBuilder::<Transaction>::new()
        .root_tag("transactions")
        .item_tag("transaction")
        .from_path(xml_path)
        .map_err(|e| BatchError::ItemWriter(e.to_string()))?;

    // No enrichment on the way out — plain pass-through.
    let processor = PassThroughProcessor::<Transaction>::new();
    let step = StepBuilder::new("postgres-to-xml")
        .chunk::<Transaction, Transaction>(1_000)
        .reader(&reader)
        .processor(&processor)
        .writer(&writer)
        .build();

    let job = JobBuilder::new().start(&step).build();
    job.run()?;

    let exec = job
        .get_step_execution("postgres-to-xml")
        .expect("step 'postgres-to-xml' must exist after job.run()");
    let secs = started.elapsed().as_secs_f64();
    eprintln!(
        "[Step 2] Done — {} records written in {:.1}s ({:.0} rec/s)",
        exec.write_count,
        secs,
        exec.write_count as f64 / secs
    );
    if exec.read_error_count > 0 || exec.write_error_count > 0 {
        eprintln!(
            "[Step 2] WARNING: {} read errors, {} write errors skipped — throughput may be understated",
            exec.read_error_count, exec.write_error_count
        );
    }
    Ok(exec.write_count as u64)
}
/// Entry point: generates the CSV fixture, runs both pipeline steps, and
/// prints a summary based on the counts the steps actually report.
///
/// Fix over the previous version: the `u64` write counts returned by
/// `run_step1`/`run_step2` were discarded and the summary unconditionally
/// claimed `TOTAL_RECORDS` were processed, silently overstating throughput
/// whenever records were skipped. The counts are now captured and reported,
/// with a warning when they diverge.
#[tokio::main]
async fn main() -> Result<(), BatchError> {
    env_logger::init();

    // Connection string and file locations are overridable via the environment.
    let db_url = env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgresql://postgres:postgres@localhost:5432/benchmark".to_string());
    let csv_path = env::var("CSV_PATH").unwrap_or_else(|_| {
        std::env::temp_dir()
            .join("transactions.csv")
            .to_string_lossy()
            .into_owned()
    });
    let xml_path = env::var("XML_PATH").unwrap_or_else(|_| {
        std::env::temp_dir()
            .join("transactions_export.xml")
            .to_string_lossy()
            .into_owned()
    });

    eprintln!("╔══════════════════════════════════════════════════════════╗");
    eprintln!("║ Spring Batch RS — 10M Transaction Benchmark ║");
    eprintln!("╚══════════════════════════════════════════════════════════╝");
    eprintln!();
    eprintln!("DB : {}", db_url);
    eprintln!("CSV : {}", csv_path);
    eprintln!("XML : {}", xml_path);
    eprintln!();

    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(10)
        .connect(&db_url)
        .await
        .map_err(|e| BatchError::Step(format!("DB connect failed: {}", e)))?;

    // Idempotent schema setup, then a clean slate so reruns measure identical work.
    sqlx::query(
        "CREATE TABLE IF NOT EXISTS transactions (
            transaction_id VARCHAR(36) PRIMARY KEY,
            amount DOUBLE PRECISION NOT NULL,
            currency VARCHAR(3) NOT NULL,
            timestamp VARCHAR(25) NOT NULL,
            account_from VARCHAR(15) NOT NULL,
            account_to VARCHAR(15) NOT NULL,
            status VARCHAR(15) NOT NULL,
            amount_eur DOUBLE PRECISION NOT NULL DEFAULT 0.0
        )",
    )
    .execute(&pool)
    .await
    .map_err(|e| BatchError::Step(format!("Schema creation failed: {}", e)))?;
    sqlx::query("TRUNCATE TABLE transactions")
        .execute(&pool)
        .await
        .map_err(|e| BatchError::Step(format!("Truncate failed: {}", e)))?;

    eprintln!(
        "[Generate] Writing {} rows to {} …",
        TOTAL_RECORDS, csv_path
    );
    let t_gen = Instant::now();
    generate_csv(&csv_path, TOTAL_RECORDS)?;
    eprintln!("[Generate] Done in {:.1}s", t_gen.elapsed().as_secs_f64());
    eprintln!();

    let t_total = Instant::now();
    // Capture the per-step write counts instead of discarding them.
    let loaded = run_step1(&pool, &csv_path)?;
    eprintln!();
    let exported = run_step2(&pool, &xml_path)?;
    eprintln!();
    let total_secs = t_total.elapsed().as_secs_f64();

    eprintln!("╔══════════════════════════════════════════════════════════╗");
    eprintln!("║ BENCHMARK SUMMARY ║");
    eprintln!("╠══════════════════════════════════════════════════════════╣");
    eprintln!("║ Total pipeline duration : {:.1}s", total_secs);
    eprintln!("║ Records processed : {}", loaded);
    eprintln!(
        "║ Average throughput : {:.0} rec/s",
        loaded as f64 / total_secs
    );
    eprintln!("╚══════════════════════════════════════════════════════════╝");
    // Surface any divergence between generated, loaded, and exported counts.
    if loaded != TOTAL_RECORDS || exported != loaded {
        eprintln!(
            "WARNING: generated {} rows but loaded {} and exported {} — see step warnings above",
            TOTAL_RECORDS, loaded, exported
        );
    }
    eprintln!();
    eprintln!("Hint: measure peak RSS with:");
    eprintln!(" /usr/bin/time -v cargo run --release --example benchmark_csv_postgres_xml \\");
    eprintln!(" --features csv,xml,rdbc-postgres 2>&1 | grep 'Maximum resident'");
    Ok(())
}