1use std::{
2 fs::File,
3 io::{self, BufReader, BufWriter, Read, Write},
4 path::Path,
5};
6
7use anyhow::{Context, Result, anyhow};
8use csv::QuoteStyle;
9use encoding_rs::{Encoding, UTF_8};
10
/// Comma: chosen for `.csv` output paths and as the input delimiter for any
/// path that is not `.tsv`.
pub const DEFAULT_CSV_DELIMITER: u8 = b',';
/// Horizontal tab: chosen for `.tsv` paths (extension matched case-insensitively).
pub const DEFAULT_TSV_DELIMITER: u8 = b'\t';
13
/// Returns `true` when `path` is the conventional `-` placeholder for
/// stdin/stdout.
///
/// The comparison is component-wise (as `Path` equality is), so `-/` also
/// matches while `./-` and `--` do not.
pub fn is_dash(path: &Path) -> bool {
    let dash = Path::new("-");
    path.components().eq(dash.components())
}
17
18pub fn resolve_encoding(label: Option<&str>) -> Result<&'static Encoding> {
19 if let Some(value) = label {
20 Encoding::for_label(value.trim().as_bytes())
21 .ok_or_else(|| anyhow!("Unknown encoding '{value}'"))
22 } else {
23 Ok(UTF_8)
24 }
25}
26
27pub fn resolve_input_delimiter(path: &Path, provided: Option<u8>) -> u8 {
28 provided.unwrap_or_else(|| match path.extension().and_then(|ext| ext.to_str()) {
29 Some(ext) if ext.eq_ignore_ascii_case("tsv") => DEFAULT_TSV_DELIMITER,
30 _ => DEFAULT_CSV_DELIMITER,
31 })
32}
33
34pub fn resolve_output_delimiter(path: Option<&Path>, provided: Option<u8>, fallback: u8) -> u8 {
35 if let Some(delim) = provided {
36 return delim;
37 }
38 if let Some(path) = path {
39 match path.extension().and_then(|ext| ext.to_str()) {
40 Some(ext) if ext.eq_ignore_ascii_case("tsv") => return DEFAULT_TSV_DELIMITER,
41 Some(ext) if ext.eq_ignore_ascii_case("csv") => return DEFAULT_CSV_DELIMITER,
42 _ => {}
43 }
44 }
45 fallback
46}
47
48pub fn open_csv_reader<R>(reader: R, delimiter: u8, has_headers: bool) -> csv::Reader<R>
49where
50 R: Read,
51{
52 let mut builder = csv::ReaderBuilder::new();
53 builder
54 .has_headers(has_headers)
55 .delimiter(delimiter)
56 .double_quote(true)
57 .flexible(false);
58 builder.from_reader(reader)
59}
60
61pub fn open_csv_reader_from_path(
62 path: &Path,
63 delimiter: u8,
64 has_headers: bool,
65) -> Result<csv::Reader<Box<dyn Read>>> {
66 let reader: Box<dyn Read> = if is_dash(path) {
67 Box::new(std::io::stdin().lock())
68 } else {
69 Box::new(BufReader::new(
70 File::open(path).with_context(|| format!("Opening input file {path:?}"))?,
71 ))
72 };
73 Ok(open_csv_reader(reader, delimiter, has_headers))
74}
75
76pub fn open_seekable_csv_reader(
77 path: &Path,
78 delimiter: u8,
79 has_headers: bool,
80) -> Result<csv::Reader<BufReader<File>>> {
81 let reader =
82 BufReader::new(File::open(path).with_context(|| format!("Opening input file {path:?}"))?);
83 Ok(open_csv_reader(reader, delimiter, has_headers))
84}
85
86pub fn open_csv_writer(
87 path: Option<&Path>,
88 delimiter: u8,
89 encoding: &'static Encoding,
90) -> Result<csv::Writer<Box<dyn Write>>> {
91 let base: Box<dyn Write> = match path {
92 Some(p) if !is_dash(p) => Box::new(BufWriter::new(
93 File::create(p).with_context(|| format!("Creating output file {p:?}"))?,
94 )),
95 _ => Box::new(std::io::stdout()),
96 };
97
98 let writer: Box<dyn Write> = if encoding == UTF_8 {
99 base
100 } else {
101 Box::new(TranscodingWriter::new(base, encoding))
102 };
103
104 let mut builder = csv::WriterBuilder::new();
105 builder
106 .delimiter(delimiter)
107 .quote_style(QuoteStyle::Necessary)
108 .double_quote(true);
109 Ok(builder.from_writer(writer))
110}
111
112pub fn decode_bytes(bytes: &[u8], encoding: &'static Encoding) -> Result<String> {
113 let (text, _, had_errors) = encoding.decode(bytes);
114 if had_errors {
115 Err(anyhow!(
116 "Failed to decode text with encoding {}",
117 encoding.name()
118 ))
119 } else {
120 Ok(text.into_owned())
121 }
122}
123
124pub fn decode_record(record: &csv::ByteRecord, encoding: &'static Encoding) -> Result<Vec<String>> {
125 record
126 .iter()
127 .map(|field| decode_bytes(field, encoding))
128 .collect()
129}
130
131pub fn decode_headers(
132 record: &csv::ByteRecord,
133 encoding: &'static Encoding,
134) -> Result<Vec<String>> {
135 decode_record(record, encoding)
136}
137
138pub fn reader_headers<R>(
139 reader: &mut csv::Reader<R>,
140 encoding: &'static Encoding,
141) -> Result<Vec<String>>
142where
143 R: Read,
144{
145 let headers = reader.byte_headers()?.clone();
146 decode_headers(&headers, encoding)
147}
148
149struct TranscodingWriter<W: Write> {
150 inner: W,
151 encoding: &'static Encoding,
152 buffer: Vec<u8>,
153}
154
155impl<W: Write> TranscodingWriter<W> {
156 fn new(inner: W, encoding: &'static Encoding) -> Self {
157 Self {
158 inner,
159 encoding,
160 buffer: Vec::new(),
161 }
162 }
163
164 fn flush_buffer(&mut self, force: bool) -> io::Result<()> {
165 let mut idx = 0;
166 while idx < self.buffer.len() {
167 match std::str::from_utf8(&self.buffer[idx..]) {
168 Ok(valid) => {
169 let text = valid.to_owned();
170 self.encode_and_write(&text)?;
171 self.buffer.clear();
172 return Ok(());
173 }
174 Err(err) => {
175 if let Some(error_len) = err.error_len() {
176 return Err(io::Error::new(
177 io::ErrorKind::InvalidData,
178 format!("Invalid UTF-8 sequence in output stream ({error_len} bytes)"),
179 ));
180 }
181 let valid_up_to = err.valid_up_to();
182 if valid_up_to > 0 {
183 let valid_slice = &self.buffer[idx..idx + valid_up_to];
184 let text = unsafe { std::str::from_utf8_unchecked(valid_slice).to_owned() };
185 self.encode_and_write(&text)?;
186 self.buffer.drain(..idx + valid_up_to);
187 idx = 0;
188 continue;
189 }
190 if force {
191 return Err(io::Error::new(
192 io::ErrorKind::InvalidData,
193 "Incomplete UTF-8 sequence at end of output stream",
194 ));
195 } else {
196 return Ok(());
197 }
198 }
199 }
200 }
201 if force && !self.buffer.is_empty() {
202 let text = String::from_utf8(self.buffer.clone()).map_err(|_| {
203 io::Error::new(
204 io::ErrorKind::InvalidData,
205 "Invalid UTF-8 sequence at end of output stream",
206 )
207 })?;
208 self.encode_and_write(&text)?;
209 self.buffer.clear();
210 }
211 Ok(())
212 }
213
214 fn encode_and_write(&mut self, text: &str) -> io::Result<()> {
215 let (encoded, _output_encoding, had_errors) = self.encoding.encode(text);
216 if had_errors {
217 return Err(io::Error::new(
218 io::ErrorKind::InvalidData,
219 format!("Failed to encode text using {}", self.encoding.name()),
220 ));
221 }
222 self.inner.write_all(encoded.as_ref())
223 }
224}
225
226impl<W: Write> Write for TranscodingWriter<W> {
227 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
228 self.buffer.extend_from_slice(buf);
229 self.flush_buffer(false)?;
230 Ok(buf.len())
231 }
232
233 fn flush(&mut self) -> io::Result<()> {
234 self.flush_buffer(true)?;
235 self.inner.flush()
236 }
237}