use std::io::{BufWriter, IntoInnerError, Write};
use crate::decoders::decoders::Encoding;
/// Buffered writer that emits rows of delimiter-separated fields.
///
/// Fields are written exactly as given: nothing in this type quotes or
/// escapes delimiter or line-break bytes that occur inside a field.
pub struct CsvWriter<'a, W>
where
    W: Write,
{
    /// Buffered handle around the destination writer.
    writer: BufWriter<W>,
    /// Single byte written between two adjacent fields of a row.
    delimiter: u8,
    /// Byte sequence written after the last field of every row.
    line_break: &'a [u8],
    /// Encoder applied to each field by `write_row_encoded`.
    encoder: Encoding,
}
impl<'a, W: Write> CsvWriter<'a, W> {
    /// Creates a writer around `inner` with a 64 KiB internal buffer.
    ///
    /// * `delimiter` - byte written between adjacent fields.
    /// * `line_break` - bytes written after the last field of each row.
    /// * `encoding` - encoder used by [`write_row_encoded`](Self::write_row_encoded).
    pub fn new(inner: W, delimiter: u8, line_break: &'a [u8], encoding: Encoding) -> Self {
        Self {
            writer: BufWriter::with_capacity(64 * 1024, inner),
            delimiter,
            line_break,
            encoder: encoding,
        }
    }

    /// Writes one row of raw byte fields, separated by the delimiter and
    /// terminated by the line break.
    ///
    /// Fields are written verbatim (no quoting or escaping). An empty
    /// `fields` slice produces just the line break.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error reported by the underlying writer; part of
    /// the row may already have been buffered at that point.
    pub fn write_row(&mut self, fields: &[&[u8]]) -> std::io::Result<()> {
        self.write_delimited(fields)
    }

    /// Writes one row, passing every field through the configured encoder
    /// before it is written.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error reported by the underlying writer.
    pub fn write_row_encoded(&mut self, fields: &[&str]) -> std::io::Result<()> {
        for (i, field) in fields.iter().enumerate() {
            if i > 0 {
                self.writer.write_all(&[self.delimiter])?;
            }
            let encoded = self.encoder.encode(field);
            self.writer.write_all(&encoded)?;
        }
        self.writer.write_all(self.line_break)
    }

    /// Writes one row of string fields using their UTF-8 bytes as-is,
    /// bypassing the configured encoder.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error reported by the underlying writer.
    pub fn write_row_fast(&mut self, fields: &[&str]) -> std::io::Result<()> {
        self.write_delimited(fields)
    }

    /// Writes one row of raw byte fields, assembling the whole row in a stack
    /// buffer first so the buffered writer receives it in a single call.
    ///
    /// Despite the name, no explicit SIMD is used here. Rows that do not fit
    /// in the staging buffer are written through [`write_row`](Self::write_row)
    /// instead; the bytes produced are the same either way.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error reported by the underlying writer.
    pub fn write_row_simd(&mut self, fields: &[&[u8]]) -> std::io::Result<()> {
        const STAGING_CAPACITY: usize = 4096;

        // Total row size: every field, one delimiter between neighbours, and
        // the line break. Saturating arithmetic keeps the check sound even if
        // the caller passes the same huge slice many times.
        let separators = fields.len().saturating_sub(1);
        let row_len = fields.iter().fold(
            separators.saturating_add(self.line_break.len()),
            |total, field| total.saturating_add(field.len()),
        );

        // The staging buffer has a fixed size; an oversized row used to index
        // past its end and panic. Fall back to the direct path instead.
        if row_len > STAGING_CAPACITY {
            return self.write_row(fields);
        }

        let mut tmp = [0u8; STAGING_CAPACITY];
        let mut cursor = 0usize;
        for (i, field) in fields.iter().enumerate() {
            if i > 0 {
                tmp[cursor] = self.delimiter;
                cursor += 1;
            }
            copy_bytes(&mut tmp[cursor..], field);
            cursor += field.len();
        }
        copy_bytes(&mut tmp[cursor..], self.line_break);
        cursor += self.line_break.len();

        self.writer.write_all(&tmp[..cursor])
    }

    /// Flushes buffered data and returns the underlying writer.
    ///
    /// The explicit flush is best effort and its result is discarded, because
    /// the return type has no way to carry a plain I/O error. Unwrapping the
    /// `BufWriter` writes out any remaining buffered bytes itself and reports
    /// a failure to do so through the returned `IntoInnerError`.
    pub fn flush_and_get(mut self) -> Result<W, IntoInnerError<BufWriter<W>>> {
        _ = self.flush();
        self.writer.into_inner()
    }

    /// Flushes all buffered bytes to the underlying writer.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported while flushing.
    pub fn flush(&mut self) -> std::io::Result<()> {
        self.writer.flush()
    }

    /// Shared row writer: emits `fields` separated by the delimiter and
    /// followed by the line break, without any transformation.
    fn write_delimited<T: AsRef<[u8]>>(&mut self, fields: &[T]) -> std::io::Result<()> {
        for (i, field) in fields.iter().enumerate() {
            if i > 0 {
                self.writer.write_all(&[self.delimiter])?;
            }
            self.writer.write_all(field.as_ref())?;
        }
        self.writer.write_all(self.line_break)
    }
}
/// Copies all of `src` into the first `src.len()` bytes of `dest`, leaving
/// the rest of `dest` untouched.
///
/// # Panics
///
/// Panics if `dest` is shorter than `src`.
#[inline(always)]
fn copy_bytes(dest: &mut [u8], src: &[u8]) {
    let (head, _rest) = dest.split_at_mut(src.len());
    head.copy_from_slice(src);
}
/// Unit tests that write rows into an in-memory cursor and compare the bytes
/// that come out.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Views the bytes collected in `bytes` as UTF-8 text.
    fn as_utf8_str(bytes: &Cursor<Vec<u8>>) -> &str {
        std::str::from_utf8(bytes.get_ref()).expect("Output is not valid UTF-8")
    }

    /// Raw byte fields are joined with the delimiter and ended by the line break.
    #[test]
    fn test_write_row() {
        let mut csv = CsvWriter::new(Cursor::new(Vec::new()), b',', b"\n", Encoding::Windows1252);
        let row: [&[u8]; 3] = [b"hello", b"world", b"csv"];

        csv.write_row(&row).expect("Failed to write row");
        csv.flush().expect("Failed to flush writer");

        let sink = csv.writer.into_inner().expect("Failed to recover buffer");
        assert_eq!(as_utf8_str(&sink), "hello,world,csv\n");
    }

    /// String fields are written as-is with a custom delimiter and CRLF line break.
    #[test]
    fn test_write_row_fast() {
        let mut csv = CsvWriter::new(Cursor::new(Vec::new()), b';', b"\r\n", Encoding::Windows1252);
        let row = ["fast", "simple", "write"];

        csv.write_row_fast(&row).expect("Failed to write fast row");
        csv.flush().expect("Failed to flush writer");

        let sink = csv.writer.into_inner().expect("Failed to recover buffer");
        assert_eq!(as_utf8_str(&sink), "fast;simple;write\r\n");
    }

    /// The staged writer produces the same layout as the plain one.
    #[test]
    fn test_write_row_simd() {
        let mut csv = CsvWriter::new(Cursor::new(Vec::new()), b'\t', b"\n", Encoding::Windows1252);
        let row: [&[u8]; 3] = [b"one", b"two", b"three"];

        csv.write_row_simd(&row).expect("Failed to write simd row");
        csv.flush().expect("Failed to flush writer");

        let sink = csv.writer.into_inner().expect("Failed to recover buffer");
        assert_eq!(as_utf8_str(&sink), "one\ttwo\tthree\n");
    }

    /// A row without fields still emits the line break.
    #[test]
    fn test_write_empty_row() {
        let mut csv = CsvWriter::new(Cursor::new(Vec::new()), b',', b"\n", Encoding::Windows1252);
        let row: [&[u8]; 0] = [];

        csv.write_row(&row).expect("Failed to write empty row");
        csv.flush().expect("Failed to flush writer");

        let sink = csv.writer.into_inner().expect("Failed to recover buffer");
        assert_eq!(as_utf8_str(&sink), "\n");
    }
}