use std::{fs::File, io, path::Path, result};
use {
csv_core::{
self, WriteResult, Writer as CoreWriter,
WriterBuilder as CoreWriterBuilder,
},
serde::Serialize,
};
use crate::{
byte_record::ByteRecord,
error::{Error, ErrorKind, IntoInnerError, Result},
serializer::{serialize, serialize_header},
{QuoteStyle, Terminator},
};
/// Configuration for building a CSV [`Writer`].
///
/// CSV dialect options (delimiter, quoting, terminator, ...) are stored in
/// the wrapped csv_core builder; the remaining fields control behavior that
/// `Writer` implements itself.
#[derive(Debug)]
pub struct WriterBuilder {
    /// csv_core builder holding the CSV dialect settings.
    builder: CoreWriterBuilder,
    /// Size in bytes of the writer's internal output buffer.
    capacity: usize,
    /// When true, records may have differing numbers of fields.
    flexible: bool,
    /// When true, `Writer::serialize` attempts to write a header row.
    has_headers: bool,
}
impl Default for WriterBuilder {
    /// Returns the default configuration: an 8 KiB output buffer, strict
    /// (non-flexible) field counts, and header writing enabled for
    /// `serialize`. Dialect options start from csv_core's own defaults.
    fn default() -> WriterBuilder {
        const DEFAULT_CAPACITY: usize = 8 * 1024;
        let builder = CoreWriterBuilder::default();
        WriterBuilder {
            builder,
            capacity: DEFAULT_CAPACITY,
            flexible: false,
            has_headers: true,
        }
    }
}
impl WriterBuilder {
pub fn new() -> WriterBuilder {
WriterBuilder::default()
}
pub fn from_path<P: AsRef<Path>>(&self, path: P) -> Result<Writer<File>> {
Ok(Writer::new(self, File::create(path)?))
}
pub fn from_writer<W: io::Write>(&self, wtr: W) -> Writer<W> {
Writer::new(self, wtr)
}
pub fn delimiter(&mut self, delimiter: u8) -> &mut WriterBuilder {
self.builder.delimiter(delimiter);
self
}
pub fn has_headers(&mut self, yes: bool) -> &mut WriterBuilder {
self.has_headers = yes;
self
}
pub fn flexible(&mut self, yes: bool) -> &mut WriterBuilder {
self.flexible = yes;
self
}
pub fn terminator(&mut self, term: Terminator) -> &mut WriterBuilder {
self.builder.terminator(term.to_core());
self
}
pub fn quote_style(&mut self, style: QuoteStyle) -> &mut WriterBuilder {
self.builder.quote_style(style.to_core());
self
}
pub fn quote(&mut self, quote: u8) -> &mut WriterBuilder {
self.builder.quote(quote);
self
}
pub fn double_quote(&mut self, yes: bool) -> &mut WriterBuilder {
self.builder.double_quote(yes);
self
}
pub fn escape(&mut self, escape: u8) -> &mut WriterBuilder {
self.builder.escape(escape);
self
}
pub fn comment(&mut self, comment: Option<u8>) -> &mut WriterBuilder {
self.builder.comment(comment);
self
}
pub fn buffer_capacity(&mut self, capacity: usize) -> &mut WriterBuilder {
self.capacity = capacity;
self
}
}
/// A CSV writer that encodes records into an internal buffer and writes the
/// buffered bytes to an underlying `io::Write`.
#[derive(Debug)]
pub struct Writer<W: io::Write> {
    /// csv_core encoder for fields, delimiters and terminators.
    core: CoreWriter,
    /// Underlying writer; `None` only after `into_inner` has taken it.
    wtr: Option<W>,
    /// Encoded bytes not yet written to `wtr`.
    buf: Buffer,
    /// Bookkeeping for headers, field counts and panic safety.
    state: WriterState,
}
/// Mutable bookkeeping for a `Writer`, kept apart from the encoder and the
/// buffer.
#[derive(Debug)]
struct WriterState {
    /// Whether `serialize` still needs to try writing a header row.
    header: HeaderState,
    /// Copied from `WriterBuilder::flexible`; disables field-count checks.
    flexible: bool,
    /// Field count of the first terminated record; later records must
    /// match it unless `flexible` is set.
    first_field_count: Option<u64>,
    /// Number of fields written so far in the current record.
    fields_written: u64,
    /// Set while the underlying writer's `write_all` runs. If that call
    /// panics the flag stays set, and `Drop` then skips its final flush.
    panicked: bool,
}
/// Progress of header handling for `Writer::serialize`.
#[derive(Debug)]
enum HeaderState {
    /// Headers are enabled and the next `serialize` call should try to
    /// write one.
    Write,
    /// A header row was written.
    DidWrite,
    /// A header was attempted, but `serialize_header` reported that nothing
    /// was written.
    DidNotWrite,
    /// Headers are disabled (`WriterBuilder::has_headers(false)`).
    None,
}
/// Fixed-size output buffer: `buf[..len]` holds pending bytes and
/// `buf[len..]` is free space.
#[derive(Debug)]
struct Buffer {
    /// Backing storage, allocated at the configured capacity when the
    /// writer is created.
    buf: Vec<u8>,
    /// Number of filled bytes at the front of `buf`.
    len: usize,
}
impl<W: io::Write> Drop for Writer<W> {
    /// Makes a best-effort attempt to flush pending output.
    ///
    /// The flush is skipped when the inner writer has already been taken by
    /// `into_inner`, or when a previous write to it panicked (flushing again
    /// could panic a second time during unwinding).
    fn drop(&mut self) {
        let should_flush = !self.state.panicked && self.wtr.is_some();
        if should_flush {
            // `drop` cannot report errors; callers who need to observe them
            // should call `flush` or `into_inner` explicitly.
            let _unreported: io::Result<()> = self.flush();
        }
    }
}
impl Writer<File> {
    /// Creates (or truncates) the file at `path` and returns a CSV writer
    /// over it with the default configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created.
    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Writer<File>> {
        let builder = WriterBuilder::new();
        builder.from_path(path)
    }
}
impl<W: io::Write> Writer<W> {
    /// Smallest internal buffer size, in bytes, that a writer will allocate.
    ///
    /// A zero-length buffer can never make progress: every csv_core call
    /// reports `OutputFull`, `flush_buf` frees no space, and the retry loops
    /// in this impl spin forever. One byte is also too small for two-byte
    /// units such as the CRLF terminator.
    // NOTE(review): 2 assumes csv_core never needs more than two contiguous
    // free bytes to make progress (CRLF, doubled/escaped quote) — confirm.
    const MIN_BUFFER_CAPACITY: usize = 2;

    /// Builds a writer around `wtr` from the settings in `builder`.
    ///
    /// A configured capacity below `MIN_BUFFER_CAPACITY` is raised to that
    /// minimum, so a tiny buffer cannot hang the writer.
    fn new(builder: &WriterBuilder, wtr: W) -> Writer<W> {
        let header_state = if builder.has_headers {
            HeaderState::Write
        } else {
            HeaderState::None
        };
        let capacity = builder.capacity.max(Self::MIN_BUFFER_CAPACITY);
        Writer {
            core: builder.builder.build(),
            wtr: Some(wtr),
            buf: Buffer { buf: vec![0; capacity], len: 0 },
            state: WriterState {
                header: header_state,
                flexible: builder.flexible,
                first_field_count: None,
                fields_written: 0,
                panicked: false,
            },
        }
    }

    /// Wraps `wtr` in a CSV writer with the default configuration.
    pub fn from_writer(wtr: W) -> Writer<W> {
        WriterBuilder::new().from_writer(wtr)
    }

    /// Serializes `record` with serde and writes it as one CSV record.
    ///
    /// While the header state is still `Write`, a header row derived from
    /// `record` (via `serialize_header`) is attempted first. Once that
    /// attempt succeeds, its outcome is recorded so no later call tries
    /// again.
    ///
    /// # Errors
    ///
    /// Propagates serialization and I/O errors. Returns
    /// `ErrorKind::UnequalLengths` when the field count differs from the
    /// first record on a non-flexible writer.
    pub fn serialize<S: Serialize>(&mut self, record: S) -> Result<()> {
        if let HeaderState::Write = self.state.header {
            let wrote_header = serialize_header(self, &record)?;
            if wrote_header {
                self.write_terminator()?;
                self.state.header = HeaderState::DidWrite;
            } else {
                self.state.header = HeaderState::DidNotWrite;
            }
        }
        serialize(self, &record)?;
        self.write_terminator()?;
        Ok(())
    }

    /// Writes every item of `record` as a field, then ends the record.
    ///
    /// Fields written beforehand with `write_field` belong to the same
    /// record. An empty iterator (e.g. `None::<&[u8]>`) writes only the
    /// terminator, which finishes a record built with `write_field`.
    ///
    /// # Errors
    ///
    /// Returns I/O errors from flushing the buffer, or
    /// `ErrorKind::UnequalLengths` when the field count differs from the
    /// first record on a non-flexible writer.
    pub fn write_record<I, T>(&mut self, record: I) -> Result<()>
    where
        I: IntoIterator<Item = T>,
        T: AsRef<[u8]>,
    {
        for field in record.into_iter() {
            self.write_field_impl(field)?;
        }
        self.write_terminator()
    }

    /// Writes a `ByteRecord` as one CSV record.
    ///
    /// Produces the same output as `write_record(record)`. When the buffer
    /// has room for the worst-case encoding of the whole record, fields are
    /// copied straight into the buffer instead of going through csv_core's
    /// incremental state machine.
    ///
    /// # Errors
    ///
    /// Same as [`Writer::write_record`].
    #[inline(never)]
    pub fn write_byte_record(&mut self, record: &ByteRecord) -> Result<()> {
        // Records with no field bytes at all (e.g. a single empty field,
        // which must come out as `""`) need csv_core's handling.
        if record.as_slice().is_empty() {
            return self.write_record(record);
        }
        // Worst-case encoded size of the record:
        let upper_bound =
            // every data byte doubled (all of them quotes needing escapes),
            (2 * record.as_slice().len())
            // one delimiter between each pair of fields,
            + (record.len().saturating_sub(1))
            // an opening and a closing quote around every field,
            + (2 * record.len())
            // and up to two terminator bytes (CRLF).
            + 2;
        // The fast path below assumes it starts a fresh record: it writes no
        // leading delimiter and overwrites `fields_written`. If `write_field`
        // has already begun this record, defer to `write_record`, which
        // continues the record correctly.
        if self.state.fields_written > 0
            || self.buf.writable().len() < upper_bound
        {
            return self.write_record(record);
        }
        let mut first = true;
        for field in record.iter() {
            if !first {
                self.buf.writable()[0] = self.core.get_delimiter();
                self.buf.written(1);
            }
            first = false;
            if !self.core.should_quote(field) {
                self.buf.writable()[..field.len()].copy_from_slice(field);
                self.buf.written(field.len());
            } else {
                self.buf.writable()[0] = self.core.get_quote();
                self.buf.written(1);
                let (res, nin, nout) = csv_core::quote(
                    field,
                    self.buf.writable(),
                    self.core.get_quote(),
                    self.core.get_escape(),
                    self.core.get_double_quote(),
                );
                // The capacity check above guarantees the whole field fits.
                debug_assert!(res == WriteResult::InputEmpty);
                debug_assert!(nin == field.len());
                self.buf.written(nout);
                self.buf.writable()[0] = self.core.get_quote();
                self.buf.written(1);
            }
        }
        self.state.fields_written = record.len() as u64;
        self.write_terminator_into_buffer()
    }

    /// Writes one field of the current record, preceded by a delimiter if
    /// it is not the record's first field. The record is not terminated;
    /// finish it with `write_record` (an empty iterator writes only the
    /// terminator).
    ///
    /// # Errors
    ///
    /// Returns I/O errors from flushing the buffer.
    pub fn write_field<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
        self.write_field_impl(field)
    }

    /// Shared implementation of `write_field`/`write_record`: encodes the
    /// field through csv_core, flushing the buffer whenever it fills up.
    #[inline(always)]
    fn write_field_impl<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
        if self.state.fields_written > 0 {
            self.write_delimiter()?;
        }
        let mut field = field.as_ref();
        loop {
            let (res, nin, nout) = self.core.field(field, self.buf.writable());
            field = &field[nin..];
            self.buf.written(nout);
            match res {
                WriteResult::InputEmpty => {
                    self.state.fields_written += 1;
                    return Ok(());
                }
                WriteResult::OutputFull => self.flush_buf()?,
            }
        }
    }

    /// Writes all buffered bytes to the underlying writer, then flushes it.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from writing or flushing the underlying writer.
    pub fn flush(&mut self) -> io::Result<()> {
        self.flush_buf()?;
        self.wtr
            .as_mut()
            .expect("inner writer is only taken by into_inner")
            .flush()?;
        Ok(())
    }

    /// Writes all buffered bytes to the underlying writer without flushing
    /// it, then empties the buffer.
    ///
    /// `panicked` is raised around `write_all` so that, if the inner writer
    /// panics, `Drop` does not try to write to it again. On error the buffer
    /// is left unchanged.
    fn flush_buf(&mut self) -> io::Result<()> {
        self.state.panicked = true;
        let result = self
            .wtr
            .as_mut()
            .expect("inner writer is only taken by into_inner")
            .write_all(self.buf.readable());
        self.state.panicked = false;
        result?;
        self.buf.clear();
        Ok(())
    }

    /// Returns a reference to the underlying writer. Buffered bytes that
    /// have not been flushed are not yet visible in it.
    pub fn get_ref(&self) -> &W {
        self.wtr
            .as_ref()
            .expect("inner writer is only taken by into_inner")
    }

    /// Flushes the writer and returns the underlying writer.
    ///
    /// # Errors
    ///
    /// If flushing fails, returns an `IntoInnerError` that holds both this
    /// writer and the I/O error.
    pub fn into_inner(
        mut self,
    ) -> result::Result<W, IntoInnerError<Writer<W>>> {
        match self.flush() {
            Ok(()) => Ok(self.wtr.take().unwrap()),
            Err(err) => Err(IntoInnerError::new(self, err)),
        }
    }

    /// Writes a field delimiter through csv_core, flushing when full.
    fn write_delimiter(&mut self) -> Result<()> {
        loop {
            let (res, nout) = self.core.delimiter(self.buf.writable());
            self.buf.written(nout);
            match res {
                WriteResult::InputEmpty => return Ok(()),
                WriteResult::OutputFull => self.flush_buf()?,
            }
        }
    }

    /// Validates the field count, then writes the record terminator through
    /// csv_core (flushing when full) and resets the per-record field count.
    fn write_terminator(&mut self) -> Result<()> {
        self.check_field_count()?;
        loop {
            let (res, nout) = self.core.terminator(self.buf.writable());
            self.buf.written(nout);
            match res {
                WriteResult::InputEmpty => {
                    self.state.fields_written = 0;
                    return Ok(());
                }
                WriteResult::OutputFull => self.flush_buf()?,
            }
        }
    }

    /// Fast-path terminator: writes the terminator bytes directly into the
    /// buffer, bypassing csv_core.
    ///
    /// Callers must have ensured there are at least two free bytes.
    #[inline(never)]
    fn write_terminator_into_buffer(&mut self) -> Result<()> {
        self.check_field_count()?;
        match self.core.get_terminator() {
            csv_core::Terminator::CRLF => {
                self.buf.writable()[0] = b'\r';
                self.buf.writable()[1] = b'\n';
                self.buf.written(2);
            }
            csv_core::Terminator::Any(b) => {
                self.buf.writable()[0] = b;
                self.buf.written(1);
            }
            _ => unreachable!(),
        }
        self.state.fields_written = 0;
        Ok(())
    }

    /// On a non-flexible writer, records the first record's field count and
    /// rejects any later record whose count differs.
    ///
    /// A header row written by `serialize` counts as the first record.
    fn check_field_count(&mut self) -> Result<()> {
        if !self.state.flexible {
            match self.state.first_field_count {
                None => {
                    self.state.first_field_count =
                        Some(self.state.fields_written);
                }
                Some(expected) if expected != self.state.fields_written => {
                    return Err(Error::new(ErrorKind::UnequalLengths {
                        pos: None,
                        expected_len: expected,
                        len: self.state.fields_written,
                    }))
                }
                Some(_) => {}
            }
        }
        Ok(())
    }
}
impl Buffer {
    /// Marks the buffer as empty; the storage is kept for reuse.
    #[inline]
    fn clear(&mut self) {
        self.len = 0;
    }

    /// Records that the first `n` bytes of `writable()` have been filled.
    #[inline]
    fn written(&mut self, n: usize) {
        self.len += n;
    }

    /// Filled bytes waiting to be written out.
    #[inline]
    fn readable(&self) -> &[u8] {
        let end = self.len;
        &self.buf[..end]
    }

    /// Free tail of the buffer that the next write may fill.
    #[inline]
    fn writable(&mut self) -> &mut [u8] {
        let start = self.len;
        &mut self.buf[start..]
    }
}
// Unit tests for `Writer` and `WriterBuilder`. Many tests come in pairs: one
// writes through `write_record` (the csv_core-driven path), and a `raw_`
// twin writes through `write_byte_record` (which may take the buffer fast
// path). Both assert the same expected output.
#[cfg(test)]
mod tests {
    use std::io::{self, Write};
    use serde::{serde_if_integer128, Serialize};
    use crate::{
        byte_record::ByteRecord, error::ErrorKind, string_record::StringRecord,
    };
    use super::{Writer, WriterBuilder};

    // Consumes the writer (flushing it) and returns everything written as a
    // string. Panics if the flush fails or the output is not UTF-8.
    fn wtr_as_string(wtr: Writer<Vec<u8>>) -> String {
        String::from_utf8(wtr.into_inner().unwrap()).unwrap()
    }

    #[test]
    fn one_record() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_record(&["a", "b", "c"]).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_string_record() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_record(&StringRecord::from(vec!["a", "b", "c"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_byte_record() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn raw_one_byte_record() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_byte_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    // A record consisting of one empty field is written as `""`, so a
    // reader can tell it apart from an empty line.
    #[test]
    fn one_empty_record() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_record(&[""]).unwrap();
        assert_eq!(wtr_as_string(wtr), "\"\"\n");
    }

    #[test]
    fn raw_one_empty_record() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_byte_record(&ByteRecord::from(vec![""])).unwrap();
        assert_eq!(wtr_as_string(wtr), "\"\"\n");
    }

    #[test]
    fn two_empty_records() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_record(&[""]).unwrap();
        wtr.write_record(&[""]).unwrap();
        assert_eq!(wtr_as_string(wtr), "\"\"\n\"\"\n");
    }

    #[test]
    fn raw_two_empty_records() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_byte_record(&ByteRecord::from(vec![""])).unwrap();
        wtr.write_byte_record(&ByteRecord::from(vec![""])).unwrap();
        assert_eq!(wtr_as_string(wtr), "\"\"\n\"\"\n");
    }

    // A non-flexible writer rejects a record whose field count differs from
    // the first record's, reporting both counts and no position.
    #[test]
    fn unequal_records_bad() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        let err = wtr.write_record(&ByteRecord::from(vec!["a"])).unwrap_err();
        match *err.kind() {
            ErrorKind::UnequalLengths { ref pos, expected_len, len } => {
                assert!(pos.is_none());
                assert_eq!(expected_len, 3);
                assert_eq!(len, 1);
            }
            ref x => {
                panic!("expected UnequalLengths error, but got '{:?}'", x);
            }
        }
    }

    #[test]
    fn raw_unequal_records_bad() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.write_byte_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        let err =
            wtr.write_byte_record(&ByteRecord::from(vec!["a"])).unwrap_err();
        match *err.kind() {
            ErrorKind::UnequalLengths { ref pos, expected_len, len } => {
                assert!(pos.is_none());
                assert_eq!(expected_len, 3);
                assert_eq!(len, 1);
            }
            ref x => {
                panic!("expected UnequalLengths error, but got '{:?}'", x);
            }
        }
    }

    // With `flexible(true)`, records of differing lengths are accepted.
    #[test]
    fn unequal_records_ok() {
        let mut wtr = WriterBuilder::new().flexible(true).from_writer(vec![]);
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        wtr.write_record(&ByteRecord::from(vec!["a"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\na\n");
    }

    #[test]
    fn raw_unequal_records_ok() {
        let mut wtr = WriterBuilder::new().flexible(true).from_writer(vec![]);
        wtr.write_byte_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        wtr.write_byte_record(&ByteRecord::from(vec!["a"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\na\n");
    }

    // The underlying writer wraps every `write` call's data in `>` ... `<`
    // and records each `flush` call as `!`. With a 4-byte buffer, each
    // 4-byte record fills the buffer. The expected string shows that
    // buffer-full spills reach the inner writer without flushing it; only
    // the explicit `flush()` and `into_inner()` produce `!`.
    #[test]
    fn full_buffer_should_not_flush_underlying() {
        struct MarkWriteAndFlush(Vec<u8>);

        impl MarkWriteAndFlush {
            fn to_str(self) -> String {
                String::from_utf8(self.0).unwrap()
            }
        }

        impl Write for MarkWriteAndFlush {
            fn write(&mut self, data: &[u8]) -> io::Result<usize> {
                self.0.write(b">")?;
                let written = self.0.write(data)?;
                self.0.write(b"<")?;
                Ok(written)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.write(b"!")?;
                Ok(())
            }
        }

        let underlying = MarkWriteAndFlush(vec![]);
        let mut wtr =
            WriterBuilder::new().buffer_capacity(4).from_writer(underlying);
        wtr.write_byte_record(&ByteRecord::from(vec!["a", "b"])).unwrap();
        wtr.write_byte_record(&ByteRecord::from(vec!["c", "d"])).unwrap();
        wtr.flush().unwrap();
        wtr.write_byte_record(&ByteRecord::from(vec!["e", "f"])).unwrap();
        let got = wtr.into_inner().unwrap().to_str();
        assert_eq!(got, ">a,b\n<>c,d\n<!>e,f\n<!");
    }

    // By default, `serialize` writes a header row built from the struct's
    // field names before the first record.
    #[test]
    fn serialize_with_headers() {
        #[derive(Serialize)]
        struct Row {
            foo: i32,
            bar: f64,
            baz: bool,
        }

        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.serialize(Row { foo: 42, bar: 42.5, baz: true }).unwrap();
        assert_eq!(wtr_as_string(wtr), "foo,bar,baz\n42,42.5,true\n");
    }

    #[test]
    fn serialize_no_headers() {
        #[derive(Serialize)]
        struct Row {
            foo: i32,
            bar: f64,
            baz: bool,
        }

        let mut wtr =
            WriterBuilder::new().has_headers(false).from_writer(vec![]);
        wtr.serialize(Row { foo: 42, bar: 42.5, baz: true }).unwrap();
        assert_eq!(wtr_as_string(wtr), "42,42.5,true\n");
    }

    // 128-bit integers are serialized only where serde supports them; the
    // value used is one past i64::MAX.
    serde_if_integer128! {
        #[test]
        fn serialize_no_headers_128() {
            #[derive(Serialize)]
            struct Row {
                foo: i128,
                bar: f64,
                baz: bool,
            }

            let mut wtr =
                WriterBuilder::new().has_headers(false).from_writer(vec![]);
            wtr.serialize(Row {
                foo: 9_223_372_036_854_775_808,
                bar: 42.5,
                baz: true,
            }).unwrap();
            assert_eq!(wtr_as_string(wtr), "9223372036854775808,42.5,true\n");
        }
    }

    // Tuples have no field names, so no header row is written even though
    // headers are enabled.
    #[test]
    fn serialize_tuple() {
        let mut wtr = WriterBuilder::new().from_writer(vec![]);
        wtr.serialize((true, 1.3, "hi")).unwrap();
        assert_eq!(wtr_as_string(wtr), "true,1.3,hi\n");
    }

    // A field starting with the configured comment byte is quoted so a
    // reader will not mistake the line for a comment.
    #[test]
    fn comment_char_is_automatically_quoted() {
        let mut wtr =
            WriterBuilder::new().comment(Some(b'#')).from_writer(Vec::new());
        wtr.write_record(&["# comment", "another"]).unwrap();
        let buf = wtr.into_inner().unwrap();
        assert_eq!(String::from_utf8(buf).unwrap(), "\"# comment\",another\n");
    }
}