1use std::collections::BTreeMap;
2use std::io::Write;
3
4use crate::{CSVError, CSVErrorType, Dialect};
5
6pub struct CSVWriter<T>
10where
11 T: Write + Sized,
12{
13 pub(crate) output: T,
14 pub(crate) columns: Option<Vec<String>>,
15 pub(crate) dialect: Dialect,
16 pub(crate) wrote_header: bool,
17}
18
19impl<T: Write + Sized> CSVWriter<T> {
20 #[must_use]
23 pub fn new(output: T) -> Self {
24 CSVWriter {
25 output,
26 columns: None,
27 dialect: Dialect::default(),
28 wrote_header: false,
29 }
30 }
31
32 #[must_use]
35 pub fn with_dialect(self, dialect: Dialect) -> Self {
36 CSVWriter { dialect, ..self }
37 }
38
39 #[must_use]
42 pub fn with_column_names(self, columns: &[&str]) -> Self {
43 let columns: Vec<String> = columns.iter().map(ToString::to_string).collect();
44 CSVWriter {
45 columns: Some(columns),
46 ..self
47 }
48 }
49
50 pub fn write_header(&mut self) -> Result<(), CSVError> {
53 if self.wrote_header {
54 return Ok(());
55 }
56 let Some(cols) = &self.columns else {
57 self.wrote_header = true;
58 return Ok(());
59 };
60
61 let line = self.make_line(cols);
62 self.output.write_all(line.as_bytes())?;
63 self.wrote_header = true;
64 Ok(())
65 }
66
67 #[must_use]
71 pub(crate) fn make_line(&self, fields: &[String]) -> String {
72 let line = fields.join(self.dialect.get_field_separators());
73 format!("{line}{}", self.dialect.get_line_separators())
74 }
75
76 pub fn write_line<R: AsRef<str>>(&mut self, fields: &[R]) -> Result<(), CSVError> {
80 self.write_header()?;
81 let fields: Vec<String> = fields.iter().map(|f| f.as_ref().to_string()).collect();
82 let line = self.make_line(fields.as_slice());
83 self.output.write_all(line.as_bytes())?;
84 Ok(())
85 }
86
87 pub fn write_fields<K: AsRef<str>, V: AsRef<str>>(
96 &mut self,
97 fields: &BTreeMap<K, V>,
98 ) -> Result<(), CSVError> {
99 self.write_header()?;
100 let Some(cols) = &self.columns else {
101 return CSVError::err(
102 CSVErrorType::MissingHeaderError,
103 "No header columns specified".to_string(),
104 );
105 };
106 let mut out = Vec::new();
107 for col in cols {
108 out.push(
109 fields
110 .iter()
111 .find_map(|(k, v)| {
112 if col == k.as_ref() {
113 return Some(String::from(v.as_ref()));
114 }
115 None
116 })
117 .unwrap_or_default(),
118 );
119 }
120 let line = self.make_line(&out);
121 self.output.write_all(line.as_bytes())?;
122 Ok(())
123 }
124}