use csv::{ReaderBuilder, StringRecordsIntoIter, Terminator, Trim};
use serde::de::DeserializeOwned;
use std::{cell::RefCell, fs::File, io::Read, marker::PhantomData, path::Path};
use crate::{
core::item::{ItemReader, ItemReaderResult},
error::BatchError,
};
pub struct CsvItemReader<R: Read> {
records: RefCell<StringRecordsIntoIter<R>>,
}
impl<I: DeserializeOwned, R: Read> ItemReader<I> for CsvItemReader<R> {
    /// Pulls the next CSV record and deserializes it into an `I`.
    ///
    /// Returns `Ok(None)` once the input is exhausted. A record that the CSV
    /// layer fails to parse, and a record that cannot be deserialized into
    /// `I`, are both reported as [`BatchError::ItemReader`] carrying the
    /// underlying error message.
    ///
    /// NOTE(review): no header record is passed to `deserialize`, so fields
    /// are filled by column position even when the reader was built with
    /// `has_headers(true)`. Confirm that positional mapping is intended.
    fn read(&self) -> ItemReaderResult<I> {
        let next_record = self.records.borrow_mut().next();
        match next_record {
            None => Ok(None),
            Some(Err(csv_error)) => Err(BatchError::ItemReader(csv_error.to_string())),
            Some(Ok(record)) => record
                .deserialize(None)
                .map(Some)
                .map_err(|de_error| BatchError::ItemReader(de_error.to_string())),
        }
    }
}
#[derive(Default)]
pub struct CsvItemReaderBuilder<I> {
delimiter: u8,
terminator: Terminator,
has_headers: bool,
_pd: PhantomData<I>,
}
impl<I> CsvItemReaderBuilder<I> {
pub fn new() -> Self {
Self {
delimiter: b',',
terminator: Terminator::CRLF,
has_headers: false,
_pd: PhantomData,
}
}
pub fn delimiter(mut self, delimiter: u8) -> Self {
self.delimiter = delimiter;
self
}
pub fn terminator(mut self, terminator: Terminator) -> Self {
self.terminator = terminator;
self
}
pub fn has_headers(mut self, yes: bool) -> Self {
self.has_headers = yes;
self
}
pub fn from_reader<R: Read>(self, rdr: R) -> CsvItemReader<R> {
let rdr = ReaderBuilder::new()
.trim(Trim::All) .delimiter(self.delimiter)
.terminator(self.terminator)
.has_headers(self.has_headers)
.flexible(false) .from_reader(rdr);
let records = rdr.into_records();
CsvItemReader {
records: RefCell::new(records),
}
}
pub fn from_path<R: AsRef<Path>>(self, path: R) -> CsvItemReader<File> {
let rdr = ReaderBuilder::new()
.trim(Trim::All) .delimiter(self.delimiter)
.terminator(self.terminator)
.has_headers(self.has_headers)
.flexible(false) .from_path(path);
let records = rdr.unwrap().into_records();
CsvItemReader {
records: RefCell::new(records),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::item::ItemReader;
    use csv::StringRecord;
    use serde::Deserialize;
    use std::error::Error;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Item type shared by most tests; field order matches the CSV column order.
    #[derive(Debug, Deserialize, PartialEq)]
    struct City {
        city: String,
        country: String,
        pop: u32,
    }

    #[test]
    fn should_parse_string_records_with_headers() -> Result<(), Box<dyn Error>> {
        let data = "city,country,pop\nBoston,United States,4628910\nConcord,United States,42695";
        let reader = CsvItemReaderBuilder::<City>::new()
            .has_headers(true)
            .delimiter(b',')
            .from_reader(data.as_bytes());
        let records = reader
            .records
            .into_inner()
            .collect::<Result<Vec<StringRecord>, csv::Error>>()?;
        assert_eq!(
            records,
            vec![
                vec!["Boston", "United States", "4628910"],
                vec!["Concord", "United States", "42695"],
            ]
        );
        Ok(())
    }

    #[test]
    fn test_deserialize_typed_records() -> Result<(), Box<dyn Error>> {
        let data = "city,country,pop\nBoston,United States,4628910\nConcord,United States,42695";
        let reader = CsvItemReaderBuilder::<City>::new()
            .has_headers(true)
            .from_reader(data.as_bytes());
        let record1: City = reader.read()?.unwrap();
        assert_eq!(
            record1,
            City {
                city: "Boston".to_string(),
                country: "United States".to_string(),
                pop: 4628910,
            }
        );
        let record2: City = reader.read()?.unwrap();
        assert_eq!(
            record2,
            City {
                city: "Concord".to_string(),
                country: "United States".to_string(),
                pop: 42695,
            }
        );
        assert!(ItemReader::<City>::read(&reader)?.is_none());
        Ok(())
    }

    #[test]
    fn test_read_from_file() -> Result<(), Box<dyn Error>> {
        let mut temp_file = NamedTempFile::new()?;
        let csv_content = "city,country,pop\nParis,France,2161000\nLyon,France,513275";
        temp_file.write_all(csv_content.as_bytes())?;
        let reader = CsvItemReaderBuilder::<City>::new()
            .has_headers(true)
            .from_path(temp_file.path());
        let city1: City = reader.read()?.unwrap();
        let city2: City = reader.read()?.unwrap();
        assert_eq!(city1.city, "Paris");
        assert_eq!(city2.city, "Lyon");
        assert_eq!(city1.pop, 2161000);
        assert_eq!(city2.pop, 513275);
        assert!(ItemReader::<City>::read(&reader)?.is_none());
        Ok(())
    }

    #[test]
    fn test_different_csv_formats() -> Result<(), Box<dyn Error>> {
        let data = "city;country;pop\nBerlin;Germany;3645000\nMunich;Germany;1472000";
        let reader = CsvItemReaderBuilder::<City>::new()
            .has_headers(true)
            .delimiter(b';')
            .terminator(Terminator::Any(b'\n'))
            .from_reader(data.as_bytes());
        let city1: City = reader.read()?.unwrap();
        let city2: City = reader.read()?.unwrap();
        assert_eq!(city1.city, "Berlin");
        assert_eq!(city2.city, "Munich");
        assert_eq!(city1.country, "Germany");
        assert_eq!(city2.country, "Germany");
        assert_eq!(city2.pop, 1472000);
        assert!(ItemReader::<City>::read(&reader)?.is_none());
        Ok(())
    }

    #[test]
    fn test_no_headers() -> Result<(), Box<dyn Error>> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Record {
            field1: String,
            field2: String,
            field3: u32,
        }
        let data = "Tokyo,Japan,13960000\nOsaka,Japan,2691000";
        let reader = CsvItemReaderBuilder::<Record>::new()
            .has_headers(false)
            .from_reader(data.as_bytes());
        let record1: Record = ItemReader::<Record>::read(&reader)?.unwrap();
        let record2: Record = ItemReader::<Record>::read(&reader)?.unwrap();
        assert_eq!(
            record1,
            Record {
                field1: "Tokyo".to_string(),
                field2: "Japan".to_string(),
                field3: 13960000,
            }
        );
        assert_eq!(
            record2,
            Record {
                field1: "Osaka".to_string(),
                field2: "Japan".to_string(),
                field3: 2691000,
            }
        );
        assert!(ItemReader::<Record>::read(&reader)?.is_none());
        Ok(())
    }

    #[test]
    fn test_trims_whitespace_around_fields() -> Result<(), Box<dyn Error>> {
        let data = "city,country,pop\n  Rome , Italy ,  2873000 \n";
        let reader = CsvItemReaderBuilder::<City>::new()
            .has_headers(true)
            .from_reader(data.as_bytes());
        let city: City = reader.read()?.unwrap();
        assert_eq!(
            city,
            City {
                city: "Rome".to_string(),
                country: "Italy".to_string(),
                pop: 2873000,
            }
        );
        assert!(ItemReader::<City>::read(&reader)?.is_none());
        Ok(())
    }

    #[test]
    fn test_deserialization_error() {
        let data = "city,country,pop\nMilan,Italy,not_a_number";
        let reader = CsvItemReaderBuilder::<City>::new()
            .has_headers(true)
            .from_reader(data.as_bytes());
        let result = ItemReader::<City>::read(&reader);
        assert!(matches!(result, Err(BatchError::ItemReader(_))));
    }

    #[test]
    fn test_rejects_rows_with_wrong_field_count() {
        let data = "city,country,pop\nRome,Italy";
        let reader = CsvItemReaderBuilder::<City>::new()
            .has_headers(true)
            .from_reader(data.as_bytes());
        let result = ItemReader::<City>::read(&reader);
        assert!(matches!(result, Err(BatchError::ItemReader(_))));
    }

    #[test]
    fn test_empty_file() -> Result<(), Box<dyn Error>> {
        let data = "";
        let reader = CsvItemReaderBuilder::<City>::new()
            .has_headers(false)
            .from_reader(data.as_bytes());
        assert!(ItemReader::<City>::read(&reader)?.is_none());
        Ok(())
    }

    #[test]
    fn test_headers_only() -> Result<(), Box<dyn Error>> {
        let data = "city,country,pop";
        let reader = CsvItemReaderBuilder::<City>::new()
            .has_headers(true)
            .from_reader(data.as_bytes());
        assert!(ItemReader::<City>::read(&reader)?.is_none());
        Ok(())
    }
}