use std::io::Write;
use std::num::NonZeroUsize;
use std::sync::Arc;
use polars_core::frame::DataFrame;
use polars_core::runtime::RAYON;
use polars_core::schema::Schema;
use polars_error::PolarsResult;
use polars_utils::pl_str::PlSmallStr;
use super::write_impl::{UTF8_BOM, csv_header, write};
use super::{QuoteStyle, SerializeOptions};
use crate::shared::SerWriter;
/// Builder-style CSV writer that serializes a `DataFrame` into any [`Write`] sink.
///
/// Configure it through the chained setter methods, then either call `finish`
/// to write a whole frame at once or `batched` to write it piece by piece.
#[must_use]
pub struct CsvWriter<W: Write> {
/// Sink that receives the BOM, the header and the serialized data.
buffer: W,
/// Serialization options; kept behind an `Arc` so they can be handed to the
/// `write` routine by cloning the pointer.
options: Arc<SerializeOptions>,
/// Whether a header line with the column names is written before the data.
header: bool,
/// Whether `UTF8_BOM` is written at the very start of the output.
bom: bool,
/// Batch size forwarded to the `write` routine (defaults to 1024 in `new`).
batch_size: NonZeroUsize,
/// Thread count forwarded to the `write` routine (defaults to the size of
/// the `RAYON` pool in `new`).
n_threads: usize,
}
impl<W> SerWriter<W> for CsvWriter<W>
where
W: Write,
{
fn new(buffer: W) -> Self {
let options = SerializeOptions::default();
CsvWriter {
buffer,
options: options.into(),
header: true,
bom: false,
batch_size: NonZeroUsize::new(1024).unwrap(),
n_threads: RAYON.current_num_threads(),
}
}
fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
if self.bom {
self.buffer.write_all(&UTF8_BOM)?;
}
let names = df
.get_column_names()
.into_iter()
.map(|x| x.as_str())
.collect::<Vec<_>>();
if self.header {
self.buffer
.write_all(&csv_header(names.as_slice(), &self.options)?)?;
}
write(
&mut self.buffer,
df,
self.batch_size.into(),
self.options.clone(),
self.n_threads,
)
}
}
impl<W> CsvWriter<W>
where
    W: Write,
{
    /// Returns a mutable view of the options, cloning them first if the `Arc`
    /// is currently shared (see `Arc::make_mut`).
    fn options_mut(&mut self) -> &mut SerializeOptions {
        Arc::make_mut(&mut self.options)
    }

    /// Chooses whether the UTF-8 BOM is written at the start of the output.
    pub fn include_bom(mut self, include_bom: bool) -> Self {
        self.bom = include_bom;
        self
    }

    /// Chooses whether a header line with the column names is written.
    pub fn include_header(mut self, include_header: bool) -> Self {
        self.header = include_header;
        self
    }

    /// Sets the separator byte stored in the serialize options.
    pub fn with_separator(mut self, separator: u8) -> Self {
        let opts = self.options_mut();
        opts.separator = separator;
        self
    }

    /// Sets the batch size that is forwarded to the write routine.
    pub fn with_batch_size(mut self, batch_size: NonZeroUsize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Sets the date format.
    ///
    /// Passing `None` leaves the currently configured value untouched.
    pub fn with_date_format(mut self, format: Option<PlSmallStr>) -> Self {
        if let Some(fmt) = format {
            self.options_mut().date_format = Some(fmt);
        }
        self
    }

    /// Sets the time format.
    ///
    /// Passing `None` leaves the currently configured value untouched.
    pub fn with_time_format(mut self, format: Option<PlSmallStr>) -> Self {
        if let Some(fmt) = format {
            self.options_mut().time_format = Some(fmt);
        }
        self
    }

    /// Sets the datetime format.
    ///
    /// Passing `None` leaves the currently configured value untouched.
    pub fn with_datetime_format(mut self, format: Option<PlSmallStr>) -> Self {
        if let Some(fmt) = format {
            self.options_mut().datetime_format = Some(fmt);
        }
        self
    }

    /// Sets the `float_scientific` option.
    ///
    /// Passing `None` leaves the currently configured value untouched.
    pub fn with_float_scientific(mut self, scientific: Option<bool>) -> Self {
        if let Some(flag) = scientific {
            self.options_mut().float_scientific = Some(flag);
        }
        self
    }

    /// Sets the `float_precision` option.
    ///
    /// Passing `None` leaves the currently configured value untouched.
    pub fn with_float_precision(mut self, precision: Option<usize>) -> Self {
        if let Some(digits) = precision {
            self.options_mut().float_precision = Some(digits);
        }
        self
    }

    /// Sets the `decimal_comma` option.
    pub fn with_decimal_comma(mut self, decimal_comma: bool) -> Self {
        let opts = self.options_mut();
        opts.decimal_comma = decimal_comma;
        self
    }

    /// Sets the quote character byte.
    pub fn with_quote_char(mut self, char: u8) -> Self {
        let opts = self.options_mut();
        opts.quote_char = char;
        self
    }

    /// Sets the string stored as the `null` option.
    pub fn with_null_value(mut self, null_value: PlSmallStr) -> Self {
        let opts = self.options_mut();
        opts.null = null_value;
        self
    }

    /// Sets the line terminator string.
    pub fn with_line_terminator(mut self, line_terminator: PlSmallStr) -> Self {
        let opts = self.options_mut();
        opts.line_terminator = line_terminator;
        self
    }

    /// Sets the quoting style.
    pub fn with_quote_style(mut self, quote_style: QuoteStyle) -> Self {
        let opts = self.options_mut();
        opts.quote_style = quote_style;
        self
    }

    /// Sets the thread count that is forwarded to the write routine.
    pub fn n_threads(mut self, n_threads: usize) -> Self {
        self.n_threads = n_threads;
        self
    }

    /// Converts this writer into a [`BatchedWriter`] that emits the output
    /// one `DataFrame` at a time.
    ///
    /// `schema` is cloned and kept by the batched writer. The BOM and header
    /// are marked as already written when they are disabled, so the batched
    /// writer never emits them in that case.
    pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
        // Read the flags before `self` is moved into the new struct.
        let has_written_bom = !self.bom;
        let has_written_header = !self.header;
        Ok(BatchedWriter {
            schema: schema.clone(),
            has_written_bom,
            has_written_header,
            writer: self,
        })
    }
}
/// CSV writer that emits its output incrementally, one `DataFrame` per
/// `write_batch` call; created through `CsvWriter::batched`.
pub struct BatchedWriter<W: Write> {
/// Configured writer whose buffer, options, batch size and thread count are
/// used for every batch.
writer: CsvWriter<W>,
/// `true` once the BOM has been written, or from the start when no BOM was
/// requested.
has_written_bom: bool,
/// `true` once the header has been written, or from the start when no header
/// was requested.
has_written_header: bool,
/// Schema supplied at construction; `finish` takes the header names from it
/// when the header is still pending.
schema: Schema,
}
impl<W: Write> BatchedWriter<W> {
    /// Writes the UTF-8 BOM if it is still pending.
    ///
    /// The flag is flipped before the write is attempted, so the BOM is
    /// attempted at most once.
    fn write_pending_bom(&mut self) -> PolarsResult<()> {
        if !self.has_written_bom {
            self.has_written_bom = true;
            self.writer.buffer.write_all(&UTF8_BOM)?;
        }
        Ok(())
    }

    /// Writes one batch of data.
    ///
    /// If the BOM and/or header are still pending they are written first; the
    /// header names are taken from the columns of `df`.
    ///
    /// # Errors
    ///
    /// Propagates failures from writing to the buffer, from `csv_header`, and
    /// from the `write` routine.
    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
        self.write_pending_bom()?;

        if !self.has_written_header {
            self.has_written_header = true;
            let names = df
                .get_column_names()
                .into_iter()
                .map(|name| name.as_str())
                .collect::<Vec<_>>();
            let header = csv_header(names.as_slice(), &self.writer.options)?;
            self.writer.buffer.write_all(&header)?;
        }

        write(
            &mut self.writer.buffer,
            df,
            self.writer.batch_size.into(),
            self.writer.options.clone(),
            self.writer.n_threads,
        )
    }

    /// Finalizes the output.
    ///
    /// If the BOM and/or header are still pending they are written now, with
    /// the header names taken from the stored schema. Nothing else is written.
    ///
    /// # Errors
    ///
    /// Propagates failures from writing to the buffer and from `csv_header`.
    pub fn finish(&mut self) -> PolarsResult<()> {
        self.write_pending_bom()?;

        if !self.has_written_header {
            self.has_written_header = true;
            let names = self
                .schema
                .iter_names()
                .map(|name| name.as_str())
                .collect::<Vec<_>>();
            let header = csv_header(&names, &self.writer.options)?;
            self.writer.buffer.write_all(&header)?;
        }

        Ok(())
    }
}