use std::result;
use csv_core::{self, WriteResult, Writer as CoreWriter};
use futures::io::{self, AsyncWrite, AsyncWriteExt};
use crate::AsyncWriterBuilder;
use crate::byte_record::ByteRecord;
use crate::error::{Error, ErrorKind, IntoInnerError, Result};
impl AsyncWriterBuilder {
    /// Builds an [`AsyncWriter`] that encodes CSV into `wtr`, using the options
    /// configured on this builder.
    ///
    /// The builder is only borrowed, so it can be used to create further writers.
    pub fn from_writer<W: AsyncWrite + Unpin>(&self, wtr: W) -> AsyncWriter<W> {
        AsyncWriter::<W>::new(self, wtr)
    }
}
/// A buffered CSV writer over an asynchronous byte sink `W`.
///
/// Encoded records are staged in an in-memory buffer and handed to `W` when
/// the buffer runs out of space, on `flush`, or on `into_inner`. If the writer
/// is dropped while it still owns `W`, the `Drop` impl makes a best-effort,
/// blocking flush.
#[derive(Debug)]
pub struct AsyncWriter<W: AsyncWrite + Unpin> {
    /// Incremental `csv_core` encoder; also the source of the delimiter,
    /// quote, escape and terminator settings.
    core: CoreWriter,
    /// The underlying sink; `None` only after `into_inner` has taken it.
    wtr: Option<W>,
    /// Staging area for encoded bytes not yet written to `wtr`.
    buf: Buffer,
    /// Per-record bookkeeping (field counts, flexibility, in-flight write flag).
    state: WriterState,
}
/// Bookkeeping that `AsyncWriter` keeps alongside its encoder and buffer.
#[derive(Debug)]
struct WriterState {
    /// When true, records may have differing numbers of fields.
    flexible: bool,
    /// Field count of the first record terminated by a non-flexible writer;
    /// every later record must match it.
    first_field_count: Option<u64>,
    /// Number of fields written so far in the current, not yet terminated record.
    fields_written: u64,
    /// Raised while a buffer write to the underlying sink is in flight. If it
    /// is still raised when the writer is dropped (the write panicked or its
    /// future was dropped before completing), the drop-time flush is skipped.
    panicked: bool,
}
/// Fixed-capacity byte buffer: `buf[..len]` holds staged output and
/// `buf[len..]` is free space.
#[derive(Debug)]
struct Buffer {
    /// Backing storage, allocated once (zero-filled) at the configured capacity.
    buf: Vec<u8>,
    /// Number of filled bytes at the start of `buf`.
    len: usize,
}
impl<W: AsyncWrite + Unpin> Drop for AsyncWriter<W> {
    /// Best-effort flush of any buffered CSV data when the writer goes away.
    ///
    /// Does nothing in two cases:
    /// - the underlying writer was already handed back by `into_inner`;
    /// - an earlier buffer write never completed (the `panicked` flag is still raised).
    ///
    /// Any error from the flush is discarded.
    ///
    /// NOTE(review): the async flush is driven with
    /// `futures::executor::block_on`, which blocks the current thread. If `W`
    /// can only make progress through an executor or reactor running on this
    /// same thread, this may stall. Calling `flush`/`into_inner` explicitly
    /// before dropping avoids relying on it.
    fn drop(&mut self) {
        if self.wtr.is_none() || self.state.panicked {
            return;
        }
        // `drop` cannot report failures, so the result is intentionally ignored.
        let _ = futures::executor::block_on(self.flush());
    }
}
impl<W: AsyncWrite + Unpin> AsyncWriter<W> {
    /// Creates a writer that encodes CSV into `wtr` using the configuration
    /// held by `builder`.
    ///
    /// The staging buffer is allocated up front with
    /// `builder.get_buffer_capacity()` bytes and is never resized.
    ///
    /// NOTE(review): with a very small capacity (e.g. 0), the flush-and-retry
    /// loops below could never make progress on non-empty output. Confirm
    /// that the builder rejects or clamps such values.
    fn new(builder: &AsyncWriterBuilder, wtr: W) -> AsyncWriter<W> {
        AsyncWriter {
            core: builder.get_core_builder_ref().build(),
            wtr: Some(wtr),
            buf: Buffer { buf: vec![0; builder.get_buffer_capacity()], len: 0 },
            state: WriterState {
                flexible: builder.is_flexible(),
                first_field_count: None,
                fields_written: 0,
                panicked: false,
            },
        }
    }

    /// Creates a writer over `wtr` with the default builder configuration.
    ///
    /// Equivalent to `AsyncWriterBuilder::new().from_writer(wtr)`.
    pub fn from_writer(wtr: W) -> AsyncWriter<W> {
        AsyncWriterBuilder::new().from_writer(wtr)
    }

    /// Writes each field of `record`, then the record terminator.
    ///
    /// Fields are appended to the current record. Any fields already written
    /// with [`write_field`](Self::write_field) therefore become the leading
    /// fields of this record.
    ///
    /// Output is staged in the internal buffer. The underlying writer is only
    /// written to when the buffer runs out of space, or on
    /// [`flush`](Self::flush).
    ///
    /// # Errors
    ///
    /// - Returns an error if writing the full buffer to the underlying writer fails.
    /// - Returns an `ErrorKind::UnequalLengths` error if the writer is not
    ///   flexible and this record's field count differs from the first record's.
    ///   In that case the fields have already been buffered, and no terminator
    ///   is written.
    pub async fn write_record<I, T>(&mut self, record: I) -> Result<()>
    where
        I: IntoIterator<Item = T>,
        T: AsRef<[u8]>,
    {
        for field in record {
            self.write_field_impl(field).await?;
        }
        self.write_terminator().await
    }

    /// Writes a complete [`ByteRecord`], followed by the record terminator.
    ///
    /// Produces the same output as [`write_record`](Self::write_record).
    /// There is a fast path when both of these hold:
    /// - no record is currently in progress;
    /// - the encoded record is guaranteed to fit in the buffer's free space.
    ///
    /// In that case the record is copied straight into the buffer instead of
    /// going through the incremental `csv_core` writer.
    ///
    /// # Errors
    ///
    /// Same as [`write_record`](Self::write_record).
    #[inline(never)]
    pub async fn write_byte_record(&mut self, record: &ByteRecord) -> Result<()> {
        // The fast path bypasses `self.core`, so it is only valid at a record
        // boundary. Once `write_field` (or a failed `write_record`) has started
        // a record, two things hold:
        // - the core writer may still owe the closing quote of the last field;
        // - the next field needs a leading delimiter.
        // The general path handles both.
        //
        // Records with no field bytes at all also take the general path. A
        // record holding one empty field must come out as `""`, not as a
        // blank line.
        if self.state.fields_written > 0 || record.as_slice().is_empty() {
            return self.write_record(record).await;
        }

        // Worst-case size of the encoded record:
        let upper_bound =
            // every data byte doubled or escaped,
            (2 * record.as_slice().len())
            // one delimiter between each pair of fields,
            + record.len().saturating_sub(1)
            // an opening and a closing quote around every field,
            + (2 * record.len())
            // and a terminator of at most two bytes (CRLF).
            + 2;
        if self.buf.writable().len() < upper_bound {
            return self.write_record(record).await;
        }

        let delimiter = self.core.get_delimiter();
        let quote = self.core.get_quote();
        for (i, field) in record.iter().enumerate() {
            if i > 0 {
                self.buf.writable()[0] = delimiter;
                self.buf.written(1);
            }
            if !self.core.should_quote(field) {
                self.buf.writable()[..field.len()].copy_from_slice(field);
                self.buf.written(field.len());
            } else {
                self.buf.writable()[0] = quote;
                self.buf.written(1);
                let (res, nin, nout) = csv_core::quote(
                    field,
                    self.buf.writable(),
                    quote,
                    self.core.get_escape(),
                    self.core.get_double_quote(),
                );
                // `upper_bound` reserved room for the whole quoted field.
                debug_assert!(res == WriteResult::InputEmpty);
                debug_assert!(nin == field.len());
                self.buf.written(nout);
                self.buf.writable()[0] = quote;
                self.buf.written(1);
            }
        }
        self.state.fields_written = record.len() as u64;
        self.write_terminator_into_buffer()
    }

    /// Writes a single field into the current record.
    ///
    /// A delimiter is inserted first unless this is the record's first field.
    /// The record stays open until a terminator is written, for example by
    /// calling [`write_record`](Self::write_record) with an empty iterator.
    ///
    /// # Errors
    ///
    /// Returns an error if writing the full buffer to the underlying writer fails.
    pub async fn write_field<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
        self.write_field_impl(field).await
    }

    /// Shared implementation of [`write_field`](Self::write_field).
    ///
    /// Emits the leading delimiter (if any), then feeds the field through
    /// `self.core`. Whenever the buffer fills up, it is flushed to the
    /// underlying writer and encoding resumes.
    #[inline(always)]
    async fn write_field_impl<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
        if self.state.fields_written > 0 {
            self.write_delimiter().await?;
        }
        let mut remaining = field.as_ref();
        loop {
            let (res, nin, nout) = self.core.field(remaining, self.buf.writable());
            remaining = &remaining[nin..];
            self.buf.written(nout);
            match res {
                WriteResult::InputEmpty => {
                    self.state.fields_written += 1;
                    return Ok(());
                }
                WriteResult::OutputFull => self.flush_buf().await?,
            }
        }
    }

    /// Writes all buffered bytes to the underlying writer, then flushes the
    /// underlying writer itself.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from writing to or flushing the underlying writer.
    pub async fn flush(&mut self) -> io::Result<()> {
        self.flush_buf().await?;
        self.wtr
            .as_mut()
            .expect("underlying writer is only taken by into_inner, which consumes the AsyncWriter")
            .flush()
            .await
    }

    /// Writes every buffered byte to the underlying writer and empties the buffer.
    ///
    /// `state.panicked` stays raised for the duration of the write. If the
    /// write panics, or this future is dropped before it finishes, the flag
    /// remains set and `Drop` will not attempt another flush.
    ///
    /// NOTE(review): on error the buffer is kept intact. However, `write_all`
    /// may already have delivered a prefix of it, so a retry could send that
    /// prefix twice.
    async fn flush_buf(&mut self) -> io::Result<()> {
        self.state.panicked = true;
        let result = self
            .wtr
            .as_mut()
            .expect("underlying writer is only taken by into_inner, which consumes the AsyncWriter")
            .write_all(self.buf.readable())
            .await;
        self.state.panicked = false;
        result?;
        self.buf.clear();
        Ok(())
    }

    /// Flushes everything and returns the underlying writer.
    ///
    /// # Errors
    ///
    /// If flushing fails, the I/O error is returned together with this writer
    /// inside [`IntoInnerError`], so the caller can retry or recover.
    pub async fn into_inner(
        mut self,
    ) -> result::Result<W, IntoInnerError<AsyncWriter<W>>> {
        match self.flush().await {
            Ok(()) => Ok(self
                .wtr
                .take()
                .expect("underlying writer is present until into_inner takes it")),
            Err(err) => Err(IntoInnerError::new(self, err)),
        }
    }

    /// Emits a field delimiter via `self.core`, flushing the buffer as
    /// needed until it fits.
    async fn write_delimiter(&mut self) -> Result<()> {
        loop {
            let (res, nout) = self.core.delimiter(self.buf.writable());
            self.buf.written(nout);
            match res {
                WriteResult::InputEmpty => return Ok(()),
                WriteResult::OutputFull => self.flush_buf().await?,
            }
        }
    }

    /// Validates the record's field count, then emits the record terminator
    /// via `self.core`, flushing the buffer as needed.
    ///
    /// On success the per-record field counter is reset.
    async fn write_terminator(&mut self) -> Result<()> {
        self.check_field_count()?;
        loop {
            let (res, nout) = self.core.terminator(self.buf.writable());
            self.buf.written(nout);
            match res {
                WriteResult::InputEmpty => {
                    self.state.fields_written = 0;
                    return Ok(());
                }
                WriteResult::OutputFull => self.flush_buf().await?,
            }
        }
    }

    /// Validates the field count, then writes the terminator bytes directly
    /// into the buffer without flushing.
    ///
    /// The caller must guarantee at least two free bytes; `write_byte_record`
    /// reserves them in its `upper_bound`. Otherwise the indexing below panics.
    #[inline(never)]
    fn write_terminator_into_buffer(&mut self) -> Result<()> {
        self.check_field_count()?;
        match self.core.get_terminator() {
            csv_core::Terminator::CRLF => {
                self.buf.writable()[0] = b'\r';
                self.buf.writable()[1] = b'\n';
                self.buf.written(2);
            }
            csv_core::Terminator::Any(b) => {
                self.buf.writable()[0] = b;
                self.buf.written(1);
            }
            // Any other `Terminator` variant is assumed never to be configured.
            _ => unreachable!(),
        }
        self.state.fields_written = 0;
        Ok(())
    }

    /// For a non-flexible writer, records the first record's field count and
    /// rejects later records with a different count.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::UnequalLengths` (with `pos: None`) on a mismatch.
    fn check_field_count(&mut self) -> Result<()> {
        if !self.state.flexible {
            match self.state.first_field_count {
                None => {
                    self.state.first_field_count =
                        Some(self.state.fields_written);
                }
                Some(expected) if expected != self.state.fields_written => {
                    return Err(Error::new(ErrorKind::UnequalLengths {
                        pos: None,
                        expected_len: expected,
                        len: self.state.fields_written,
                    }))
                }
                Some(_) => {}
            }
        }
        Ok(())
    }
}
impl Buffer {
    /// The filled prefix: bytes staged but not yet written out.
    #[inline]
    fn readable(&self) -> &[u8] {
        let (filled, _) = self.buf.split_at(self.len);
        filled
    }

    /// The unused tail of the buffer, available for new output.
    #[inline]
    fn writable(&mut self) -> &mut [u8] {
        let (_, free) = self.buf.split_at_mut(self.len);
        free
    }

    /// Marks the next `n` bytes of free space as filled.
    #[inline]
    fn written(&mut self, n: usize) {
        self.len += n;
    }

    /// Discards all staged bytes; the allocated capacity is kept.
    #[inline]
    fn clear(&mut self) {
        self.len = 0;
    }
}
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use async_std::task;
    use futures::io;

    use crate::byte_record::ByteRecord;
    use crate::error::ErrorKind;
    use crate::string_record::StringRecord;

    use super::{AsyncWriter, AsyncWriterBuilder};

    /// Finishes `wtr` via `into_inner` and decodes everything it wrote as UTF-8.
    async fn wtr_as_string(wtr: AsyncWriter<Vec<u8>>) -> String {
        let bytes = wtr.into_inner().await.unwrap();
        String::from_utf8(bytes).unwrap()
    }

    /// A writer with the default builder settings over an in-memory byte vector.
    fn default_writer() -> AsyncWriter<Vec<u8>> {
        AsyncWriterBuilder::new().from_writer(vec![])
    }

    /// A flexible writer (records may differ in length) over an in-memory byte vector.
    fn flexible_writer() -> AsyncWriter<Vec<u8>> {
        AsyncWriterBuilder::new().flexible(true).from_writer(vec![])
    }

    /// Asserts that `kind` is `UnequalLengths` with no position and the given counts.
    fn assert_unequal_lengths(kind: &ErrorKind, want_expected: u64, want_len: u64) {
        match kind {
            ErrorKind::UnequalLengths { pos, expected_len, len } => {
                assert!(pos.is_none());
                assert_eq!(*expected_len, want_expected);
                assert_eq!(*len, want_len);
            }
            other => panic!("expected UnequalLengths error, but got '{:?}'", other),
        }
    }

    #[test]
    fn one_record() {
        task::block_on(async {
            let mut writer = default_writer();
            writer.write_record(&["a", "b", "c"]).await.unwrap();
            assert_eq!(wtr_as_string(writer).await, "a,b,c\n");
        });
    }

    #[test]
    fn one_string_record() {
        task::block_on(async {
            let mut writer = default_writer();
            let record = StringRecord::from(vec!["a", "b", "c"]);
            writer.write_record(&record).await.unwrap();
            assert_eq!(wtr_as_string(writer).await, "a,b,c\n");
        });
    }

    #[test]
    fn one_byte_record() {
        task::block_on(async {
            let mut writer = default_writer();
            let record = ByteRecord::from(vec!["a", "b", "c"]);
            writer.write_record(&record).await.unwrap();
            assert_eq!(wtr_as_string(writer).await, "a,b,c\n");
        });
    }

    #[test]
    fn raw_one_byte_record() {
        task::block_on(async {
            let mut writer = default_writer();
            let record = ByteRecord::from(vec!["a", "b", "c"]);
            writer.write_byte_record(&record).await.unwrap();
            assert_eq!(wtr_as_string(writer).await, "a,b,c\n");
        });
    }

    #[test]
    fn one_empty_record() {
        task::block_on(async {
            let mut writer = default_writer();
            writer.write_record(&[""]).await.unwrap();
            assert_eq!(wtr_as_string(writer).await, "\"\"\n");
        });
    }

    #[test]
    fn raw_one_empty_record() {
        task::block_on(async {
            let mut writer = default_writer();
            let record = ByteRecord::from(vec![""]);
            writer.write_byte_record(&record).await.unwrap();
            assert_eq!(wtr_as_string(writer).await, "\"\"\n");
        });
    }

    #[test]
    fn two_empty_records() {
        task::block_on(async {
            let mut writer = default_writer();
            for _ in 0..2 {
                writer.write_record(&[""]).await.unwrap();
            }
            assert_eq!(wtr_as_string(writer).await, "\"\"\n\"\"\n");
        });
    }

    #[test]
    fn raw_two_empty_records() {
        task::block_on(async {
            let mut writer = default_writer();
            for _ in 0..2 {
                writer.write_byte_record(&ByteRecord::from(vec![""])).await.unwrap();
            }
            assert_eq!(wtr_as_string(writer).await, "\"\"\n\"\"\n");
        });
    }

    #[test]
    fn unequal_records_bad() {
        task::block_on(async {
            let mut writer = default_writer();
            writer.write_record(&ByteRecord::from(vec!["a", "b", "c"])).await.unwrap();
            let err = writer.write_record(&ByteRecord::from(vec!["a"])).await.unwrap_err();
            assert_unequal_lengths(err.kind(), 3, 1);
        });
    }

    #[test]
    fn raw_unequal_records_bad() {
        task::block_on(async {
            let mut writer = default_writer();
            writer.write_byte_record(&ByteRecord::from(vec!["a", "b", "c"])).await.unwrap();
            let err =
                writer.write_byte_record(&ByteRecord::from(vec!["a"])).await.unwrap_err();
            assert_unequal_lengths(err.kind(), 3, 1);
        });
    }

    #[test]
    fn unequal_records_ok() {
        task::block_on(async {
            let mut writer = flexible_writer();
            writer.write_record(&ByteRecord::from(vec!["a", "b", "c"])).await.unwrap();
            writer.write_record(&ByteRecord::from(vec!["a"])).await.unwrap();
            assert_eq!(wtr_as_string(writer).await, "a,b,c\na\n");
        });
    }

    #[test]
    fn raw_unequal_records_ok() {
        task::block_on(async {
            let mut writer = flexible_writer();
            writer.write_byte_record(&ByteRecord::from(vec!["a", "b", "c"])).await.unwrap();
            writer.write_byte_record(&ByteRecord::from(vec!["a"])).await.unwrap();
            assert_eq!(wtr_as_string(writer).await, "a,b,c\na\n");
        });
    }

    #[test]
    fn full_buffer_should_not_flush_underlying() {
        /// Sink that logs every `poll_write` call as `>bytes<` and every flush as `!`.
        #[derive(Debug)]
        struct MarkWriteAndFlush(Vec<u8>);

        impl MarkWriteAndFlush {
            fn into_string(self) -> String {
                String::from_utf8(self.0).unwrap()
            }
        }

        impl io::AsyncWrite for MarkWriteAndFlush {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _: &mut Context,
                buf: &[u8],
            ) -> Poll<Result<usize, io::Error>> {
                let log = &mut self.0;
                log.push(b'>');
                log.extend_from_slice(buf);
                log.push(b'<');
                Poll::Ready(Ok(buf.len()))
            }

            fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), io::Error>> {
                self.0.push(b'!');
                Poll::Ready(Ok(()))
            }

            fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
                self.poll_flush(cx)
            }
        }

        task::block_on(async {
            let mut writer = AsyncWriterBuilder::new()
                .buffer_capacity(4)
                .from_writer(MarkWriteAndFlush(vec![]));
            writer.write_byte_record(&ByteRecord::from(vec!["a", "b"])).await.unwrap();
            writer.write_byte_record(&ByteRecord::from(vec!["c", "d"])).await.unwrap();
            writer.flush().await.unwrap();
            writer.write_byte_record(&ByteRecord::from(vec!["e", "f"])).await.unwrap();
            let log = writer.into_inner().await.unwrap().into_string();
            assert_eq!(log, ">a,b\n<>c,d\n<!>e,f\n<!");
        });
    }
}