use std::{cell::RefCell, fs::File, io::Write, marker::PhantomData, path::Path};
use csv::{Writer, WriterBuilder};
use serde::Serialize;
use crate::{
BatchError,
core::item::{ItemWriter, ItemWriterResult},
};
/// [`ItemWriter`] that serializes items of type `O` as CSV records into the sink `W`.
///
/// The underlying `csv::Writer` lives in a `RefCell` because the `ItemWriter`
/// methods take `&self` while serializing and flushing need `&mut` access to the
/// encoder. A consequence is that this type is not `Sync`.
///
/// Instances are normally obtained from `CsvItemWriterBuilder`.
pub struct CsvItemWriter<O, W>
where
    W: Write,
{
    /// CSV encoder wrapping the output sink.
    writer: RefCell<Writer<W>>,
    /// Binds the item type `O` to this writer without storing any `O` values.
    _phantom: PhantomData<O>,
}
impl<O: Serialize, W: Write> ItemWriter<O> for CsvItemWriter<O, W> {
fn write(&self, items: &[O]) -> ItemWriterResult {
for item in items.iter() {
let result = self.writer.borrow_mut().serialize(item);
if result.is_err() {
let error = result.err().unwrap();
return Err(BatchError::ItemWriter(error.to_string()));
}
}
Ok(())
}
fn flush(&self) -> ItemWriterResult {
let result = self.writer.borrow_mut().flush();
match result {
Ok(()) => Ok(()),
Err(error) => Err(BatchError::ItemWriter(error.to_string())),
}
}
}
/// Fluent builder for [`CsvItemWriter`].
///
/// Defaults: `,` as the field delimiter and no header row.
pub struct CsvItemWriterBuilder<O> {
    /// Single-byte field separator handed to the CSV encoder.
    delimiter: u8,
    /// Whether a header row is emitted before the first record.
    has_headers: bool,
    /// Carries the item type `O` without storing values of it.
    _pd: PhantomData<O>,
}

// Implemented by hand rather than derived for two reasons:
// - `#[derive(Default)]` would require `O: Default`, a bound the builder never needs.
// - The derive would zero the delimiter to a NUL byte instead of the `,` that `new()` uses.
impl<O> Default for CsvItemWriterBuilder<O> {
    fn default() -> Self {
        Self {
            delimiter: b',',
            has_headers: false,
            _pd: PhantomData,
        }
    }
}
impl<O> CsvItemWriterBuilder<O> {
pub fn new() -> Self {
Self {
delimiter: b',',
has_headers: false,
_pd: PhantomData,
}
}
pub fn delimiter(mut self, delimiter: u8) -> Self {
self.delimiter = delimiter;
self
}
pub fn has_headers(mut self, yes: bool) -> Self {
self.has_headers = yes;
self
}
pub fn from_path<W: AsRef<Path>>(self, path: W) -> CsvItemWriter<O, File> {
let writer = WriterBuilder::new()
.flexible(false) .has_headers(self.has_headers)
.delimiter(self.delimiter)
.from_path(path);
CsvItemWriter {
writer: RefCell::new(writer.unwrap()),
_phantom: PhantomData,
}
}
pub fn from_writer<W: Write>(self, wtr: W) -> CsvItemWriter<O, W> {
let wtr = WriterBuilder::new()
.flexible(false) .has_headers(self.has_headers)
.delimiter(self.delimiter)
.from_writer(wtr);
CsvItemWriter {
writer: RefCell::new(wtr),
_phantom: PhantomData,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::item::ItemWriter;
    use serde::Serialize;

    // Minimal two-column record used by most tests.
    #[derive(Serialize, Clone)]
    struct Row {
        name: String,
        value: u32,
    }

    // Shorthand constructor for a `Row`.
    fn row(name: &str, value: u32) -> Row {
        Row {
            name: name.to_string(),
            value,
        }
    }

    // The standard two-record fixture: `alpha,1` followed by `beta,2`.
    fn sample_rows() -> Vec<Row> {
        vec![row("alpha", 1), row("beta", 2)]
    }

    // Writes `rows` through a writer produced by `builder` into an in-memory
    // buffer, flushes it, and returns the buffer decoded as UTF-8.
    fn render(builder: CsvItemWriterBuilder<Row>, rows: &[Row]) -> String {
        let mut buf = Vec::new();
        {
            // Scoped so the writer's `&mut buf` borrow ends before `buf` is read.
            let writer = builder.from_writer(&mut buf);
            writer.write(rows).unwrap();
            ItemWriter::<Row>::flush(&writer).unwrap();
        }
        String::from_utf8(buf).unwrap()
    }

    #[test]
    fn should_write_records_with_headers() {
        let out = render(
            CsvItemWriterBuilder::<Row>::new().has_headers(true),
            &sample_rows(),
        );
        assert!(out.contains("name,value"), "header row missing: {out}");
        assert!(out.contains("alpha,1"), "first data row missing: {out}");
        assert!(out.contains("beta,2"), "second data row missing: {out}");
    }

    #[test]
    fn should_write_records_without_headers() {
        let out = render(
            CsvItemWriterBuilder::<Row>::new().has_headers(false),
            &sample_rows(),
        );
        assert!(!out.contains("name"), "header row should be absent: {out}");
        assert!(
            out.contains("alpha,1"),
            "data row missing from headerless output: {out}"
        );
    }

    #[test]
    fn should_write_with_custom_delimiter() {
        let builder = CsvItemWriterBuilder::<Row>::new()
            .has_headers(true)
            .delimiter(b';');
        let out = render(builder, &sample_rows());
        assert!(
            out.contains("name;value"),
            "semicolon header missing: {out}"
        );
        assert!(out.contains("alpha;1"), "semicolon data missing: {out}");
    }

    #[test]
    fn should_write_empty_chunk_without_error() {
        let out = render(CsvItemWriterBuilder::<Row>::new().has_headers(true), &[]);
        assert!(
            out.is_empty(),
            "writing an empty chunk should produce no output, got: {out:?}"
        );
    }

    #[test]
    fn should_write_single_record() {
        let out = render(
            CsvItemWriterBuilder::<Row>::new().has_headers(false),
            &[row("only", 99)],
        );
        assert!(out.contains("only,99"), "single record missing: {out}");
    }

    #[test]
    fn should_return_error_when_serialization_fails() {
        use serde::ser;

        // Item whose `Serialize` impl always reports an error.
        #[derive(Clone)]
        struct FailSerialize;

        impl Serialize for FailSerialize {
            fn serialize<S: serde::Serializer>(&self, _s: S) -> Result<S::Ok, S::Error> {
                Err(ser::Error::custom("intentional failure"))
            }
        }

        let mut sink = Vec::new();
        let writer = CsvItemWriterBuilder::<FailSerialize>::new().from_writer(&mut sink);
        let outcome = writer.write(&[FailSerialize]);
        assert!(outcome.is_err(), "should fail when serialization fails");
        if !matches!(outcome, Err(BatchError::ItemWriter(_))) {
            panic!("expected ItemWriter error, got {outcome:?}");
        }
    }

    #[test]
    fn should_return_error_when_flush_fails_on_io() {
        use std::io;

        // Sink that accepts every write but refuses to flush.
        struct FailFlushWriter;

        impl Write for FailFlushWriter {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                Ok(buf.len())
            }

            fn flush(&mut self) -> io::Result<()> {
                Err(io::Error::new(io::ErrorKind::Other, "flush failed"))
            }
        }

        let inner = WriterBuilder::new().from_writer(FailFlushWriter);
        let csv_writer = CsvItemWriter::<Row, FailFlushWriter> {
            writer: RefCell::new(inner),
            _phantom: PhantomData,
        };
        let outcome = ItemWriter::<Row>::flush(&csv_writer);
        assert!(
            outcome.is_err(),
            "flush should fail when underlying writer fails"
        );
        if !matches!(outcome, Err(BatchError::ItemWriter(_))) {
            panic!("expected ItemWriter error, got {outcome:?}");
        }
    }

    #[test]
    fn should_write_to_file() {
        use std::fs;
        use tempfile::NamedTempFile;

        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        {
            let writer = CsvItemWriterBuilder::<Row>::new()
                .has_headers(true)
                .from_path(&path);
            writer.write(&sample_rows()).unwrap();
            ItemWriter::<Row>::flush(&writer).unwrap();
            // `writer` is dropped at the end of this scope, releasing the file.
        }
        let content = fs::read_to_string(&path).unwrap();
        assert!(content.contains("name,value"), "file header missing");
        assert!(content.contains("alpha,1"), "file data missing");
    }
}