csv_managed/
io_utils.rs

1use std::{
2    fs::File,
3    io::{self, BufReader, BufWriter, Read, Write},
4    path::Path,
5};
6
7use anyhow::{Context, Result, anyhow};
8use csv::QuoteStyle;
9use encoding_rs::{Encoding, UTF_8};
10
/// Field separator used for comma-separated input and output (an ASCII comma).
pub const DEFAULT_CSV_DELIMITER: u8 = b',';
/// Field separator used for tab-separated input and output (an ASCII horizontal tab).
pub const DEFAULT_TSV_DELIMITER: u8 = b'\t';
13
/// Reports whether `path` is the single-dash placeholder `-`.
///
/// The reader and writer constructors in this module use this check to decide
/// between a real file and the process's standard input/output streams.
pub fn is_dash(path: &Path) -> bool {
    let stdio_marker: &Path = Path::new("-");
    stdio_marker == path
}
17
18pub fn resolve_encoding(label: Option<&str>) -> Result<&'static Encoding> {
19    if let Some(value) = label {
20        Encoding::for_label(value.trim().as_bytes())
21            .ok_or_else(|| anyhow!("Unknown encoding '{value}'"))
22    } else {
23        Ok(UTF_8)
24    }
25}
26
27pub fn resolve_input_delimiter(path: &Path, provided: Option<u8>) -> u8 {
28    provided.unwrap_or_else(|| match path.extension().and_then(|ext| ext.to_str()) {
29        Some(ext) if ext.eq_ignore_ascii_case("tsv") => DEFAULT_TSV_DELIMITER,
30        _ => DEFAULT_CSV_DELIMITER,
31    })
32}
33
34pub fn resolve_output_delimiter(path: Option<&Path>, provided: Option<u8>, fallback: u8) -> u8 {
35    if let Some(delim) = provided {
36        return delim;
37    }
38    if let Some(path) = path {
39        match path.extension().and_then(|ext| ext.to_str()) {
40            Some(ext) if ext.eq_ignore_ascii_case("tsv") => return DEFAULT_TSV_DELIMITER,
41            Some(ext) if ext.eq_ignore_ascii_case("csv") => return DEFAULT_CSV_DELIMITER,
42            _ => {}
43        }
44    }
45    fallback
46}
47
48pub fn open_csv_reader<R>(reader: R, delimiter: u8, has_headers: bool) -> csv::Reader<R>
49where
50    R: Read,
51{
52    let mut builder = csv::ReaderBuilder::new();
53    builder
54        .has_headers(has_headers)
55        .delimiter(delimiter)
56        .double_quote(true)
57        .flexible(false);
58    builder.from_reader(reader)
59}
60
61pub fn open_csv_reader_from_path(
62    path: &Path,
63    delimiter: u8,
64    has_headers: bool,
65) -> Result<csv::Reader<Box<dyn Read>>> {
66    let reader: Box<dyn Read> = if is_dash(path) {
67        Box::new(std::io::stdin().lock())
68    } else {
69        Box::new(BufReader::new(
70            File::open(path).with_context(|| format!("Opening input file {path:?}"))?,
71        ))
72    };
73    Ok(open_csv_reader(reader, delimiter, has_headers))
74}
75
76pub fn open_seekable_csv_reader(
77    path: &Path,
78    delimiter: u8,
79    has_headers: bool,
80) -> Result<csv::Reader<BufReader<File>>> {
81    let reader =
82        BufReader::new(File::open(path).with_context(|| format!("Opening input file {path:?}"))?);
83    Ok(open_csv_reader(reader, delimiter, has_headers))
84}
85
86pub fn open_csv_writer(
87    path: Option<&Path>,
88    delimiter: u8,
89    encoding: &'static Encoding,
90) -> Result<csv::Writer<Box<dyn Write>>> {
91    let base: Box<dyn Write> = match path {
92        Some(p) if !is_dash(p) => Box::new(BufWriter::new(
93            File::create(p).with_context(|| format!("Creating output file {p:?}"))?,
94        )),
95        _ => Box::new(std::io::stdout()),
96    };
97
98    let writer: Box<dyn Write> = if encoding == UTF_8 {
99        base
100    } else {
101        Box::new(TranscodingWriter::new(base, encoding))
102    };
103
104    let mut builder = csv::WriterBuilder::new();
105    builder
106        .delimiter(delimiter)
107        .quote_style(QuoteStyle::Necessary)
108        .double_quote(true);
109    Ok(builder.from_writer(writer))
110}
111
112pub fn decode_bytes(bytes: &[u8], encoding: &'static Encoding) -> Result<String> {
113    let (text, _, had_errors) = encoding.decode(bytes);
114    if had_errors {
115        Err(anyhow!(
116            "Failed to decode text with encoding {}",
117            encoding.name()
118        ))
119    } else {
120        Ok(text.into_owned())
121    }
122}
123
124pub fn decode_record(record: &csv::ByteRecord, encoding: &'static Encoding) -> Result<Vec<String>> {
125    record
126        .iter()
127        .map(|field| decode_bytes(field, encoding))
128        .collect()
129}
130
131pub fn decode_headers(
132    record: &csv::ByteRecord,
133    encoding: &'static Encoding,
134) -> Result<Vec<String>> {
135    decode_record(record, encoding)
136}
137
138pub fn reader_headers<R>(
139    reader: &mut csv::Reader<R>,
140    encoding: &'static Encoding,
141) -> Result<Vec<String>>
142where
143    R: Read,
144{
145    let headers = reader.byte_headers()?.clone();
146    decode_headers(&headers, encoding)
147}
148
149struct TranscodingWriter<W: Write> {
150    inner: W,
151    encoding: &'static Encoding,
152    buffer: Vec<u8>,
153}
154
155impl<W: Write> TranscodingWriter<W> {
156    fn new(inner: W, encoding: &'static Encoding) -> Self {
157        Self {
158            inner,
159            encoding,
160            buffer: Vec::new(),
161        }
162    }
163
164    fn flush_buffer(&mut self, force: bool) -> io::Result<()> {
165        let mut idx = 0;
166        while idx < self.buffer.len() {
167            match std::str::from_utf8(&self.buffer[idx..]) {
168                Ok(valid) => {
169                    let text = valid.to_owned();
170                    self.encode_and_write(&text)?;
171                    self.buffer.clear();
172                    return Ok(());
173                }
174                Err(err) => {
175                    if let Some(error_len) = err.error_len() {
176                        return Err(io::Error::new(
177                            io::ErrorKind::InvalidData,
178                            format!("Invalid UTF-8 sequence in output stream ({error_len} bytes)"),
179                        ));
180                    }
181                    let valid_up_to = err.valid_up_to();
182                    if valid_up_to > 0 {
183                        let valid_slice = &self.buffer[idx..idx + valid_up_to];
184                        let text = unsafe { std::str::from_utf8_unchecked(valid_slice).to_owned() };
185                        self.encode_and_write(&text)?;
186                        self.buffer.drain(..idx + valid_up_to);
187                        idx = 0;
188                        continue;
189                    }
190                    if force {
191                        return Err(io::Error::new(
192                            io::ErrorKind::InvalidData,
193                            "Incomplete UTF-8 sequence at end of output stream",
194                        ));
195                    } else {
196                        return Ok(());
197                    }
198                }
199            }
200        }
201        if force && !self.buffer.is_empty() {
202            let text = String::from_utf8(self.buffer.clone()).map_err(|_| {
203                io::Error::new(
204                    io::ErrorKind::InvalidData,
205                    "Invalid UTF-8 sequence at end of output stream",
206                )
207            })?;
208            self.encode_and_write(&text)?;
209            self.buffer.clear();
210        }
211        Ok(())
212    }
213
214    fn encode_and_write(&mut self, text: &str) -> io::Result<()> {
215        let (encoded, _output_encoding, had_errors) = self.encoding.encode(text);
216        if had_errors {
217            return Err(io::Error::new(
218                io::ErrorKind::InvalidData,
219                format!("Failed to encode text using {}", self.encoding.name()),
220            ));
221        }
222        self.inner.write_all(encoded.as_ref())
223    }
224}
225
226impl<W: Write> Write for TranscodingWriter<W> {
227    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
228        self.buffer.extend_from_slice(buf);
229        self.flush_buffer(false)?;
230        Ok(buf.len())
231    }
232
233    fn flush(&mut self) -> io::Result<()> {
234        self.flush_buffer(true)?;
235        self.inner.flush()
236    }
237}