anubis_age/cli_common/
file_io.rs

1//! File I/O helpers for CLI binaries.
2
3use std::fmt;
4use std::fs::{File, OpenOptions};
5use std::io::{self, Read, Write};
6use std::path::Path;
7
8#[cfg(unix)]
9use std::os::unix::fs::OpenOptionsExt;
10
11use is_terminal::IsTerminal;
12use zeroize::Zeroize;
13
14use crate::{fl, util::LINE_ENDING, wfl, wlnfl};
15
/// Maximum number of bytes of text or unknown-format output that `StdoutWriter`
/// sends to a terminal before truncating the rest (20 × 80 bytes).
const SHORT_OUTPUT_LENGTH: usize = 20 * 80;
17
/// Errors produced by the file I/O helpers in this module.
///
/// The user-facing message for each variant is rendered by the `Display` impl through
/// the crate's `wfl!` / `wlnfl!` localization macros.
#[derive(Debug)]
enum FileError {
    /// Refused to write binary output to a terminal (raised by `OutputWriter::new`).
    DenyBinaryOutput,
    /// The output file already exists and overwriting was not allowed; holds the filename.
    DenyOverwriteFile(String),
    /// Output of unknown format was not valid UTF-8 while writing to a terminal.
    DetectedBinaryOutput,
    /// The output path has no parent component (e.g. an empty path); holds the filename.
    InvalidFilename(String),
    /// The directory that should contain the output file does not exist; holds the
    /// directory's path.
    MissingDirectory(String),
}
26
27impl fmt::Display for FileError {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        match self {
30            Self::DenyBinaryOutput => {
31                wlnfl!(f, "err-deny-binary-output")?;
32                wfl!(f, "rec-deny-binary-output")
33            }
34            Self::DenyOverwriteFile(filename) => {
35                wfl!(f, "err-deny-overwrite-file", filename = filename.as_str())
36            }
37            Self::DetectedBinaryOutput => {
38                wlnfl!(f, "err-detected-binary")?;
39                wfl!(f, "rec-detected-binary")
40            }
41            Self::InvalidFilename(filename) => {
42                wfl!(f, "err-invalid-filename", filename = filename.as_str())
43            }
44            Self::MissingDirectory(path) => wfl!(f, "err-missing-directory", path = path.as_str()),
45        }
46    }
47}
48
49impl std::error::Error for FileError {}
50
/// Wrapper around a [`File`] that also remembers the name it was opened with.
pub struct FileReader {
    /// The open file being read.
    inner: File,
    /// The filename given to [`InputReader::new`], returned by `InputReader::filename`.
    filename: String,
}
56
/// Wrapper around either a file or standard input.
///
/// Implements [`Read`] by forwarding to whichever source it wraps.
pub enum InputReader {
    /// Wrapper around a file.
    File(FileReader),
    /// Wrapper around standard input.
    Stdin(io::Stdin),
}
64
65impl InputReader {
66    /// Reads input from the given filename, or standard input if `None` or `Some("-")`.
67    pub fn new(input: Option<String>) -> io::Result<Self> {
68        if let Some(filename) = input {
69            // Respect the Unix convention that "-" as an input filename
70            // parameter is an explicit request to use standard input.
71            if filename != "-" {
72                return Ok(InputReader::File(FileReader {
73                    inner: File::open(&filename)?,
74                    filename,
75                }));
76            }
77        }
78
79        Ok(InputReader::Stdin(io::stdin()))
80    }
81
82    /// Returns true if this input is from a terminal, and a user is likely typing it.
83    pub fn is_terminal(&self) -> bool {
84        matches!(self, Self::Stdin(_)) && io::stdin().is_terminal()
85    }
86
87    pub(crate) fn filename(&self) -> Option<&str> {
88        if let Self::File(f) = self {
89            Some(&f.filename)
90        } else {
91            None
92        }
93    }
94}
95
96impl Read for InputReader {
97    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
98        match self {
99            InputReader::File(f) => f.inner.read(buf),
100            InputReader::Stdin(handle) => handle.read(buf),
101        }
102    }
103}
104
/// A stdout writer that optionally buffers the entire output before writing.
///
/// The buffered variant keeps all output in memory until [`Write::flush`] is called
/// (including from `Drop`). Its memory is zeroized when the buffer grows and after a
/// successful flush.
#[derive(Debug)]
enum StdoutBuffer {
    /// Writes go straight to standard output.
    Direct(io::Stdout),
    /// Writes accumulate here until the next flush.
    Buffered(Vec<u8>),
}
111
112impl StdoutBuffer {
113    fn direct() -> Self {
114        Self::Direct(io::stdout())
115    }
116
117    fn buffered() -> Self {
118        Self::Buffered(Vec::with_capacity(8 * 1024 * 1024))
119    }
120}
121
122impl Write for StdoutBuffer {
123    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
124        match self {
125            StdoutBuffer::Direct(w) => w.write(data),
126            StdoutBuffer::Buffered(buf) => {
127                // If we need to re-allocate the buffer, do so manually so we can zeroize.
128                if buf.len() + data.len() > buf.capacity() {
129                    let mut new_buf = Vec::with_capacity(std::cmp::max(
130                        buf.capacity() * 2,
131                        buf.capacity() + data.len(),
132                    ));
133                    new_buf.extend_from_slice(buf);
134                    buf.zeroize();
135                    *buf = new_buf;
136                }
137
138                buf.extend_from_slice(data);
139                Ok(data.len())
140            }
141        }
142    }
143
144    fn flush(&mut self) -> io::Result<()> {
145        match self {
146            StdoutBuffer::Direct(w) => w.flush(),
147            StdoutBuffer::Buffered(buf) => {
148                let mut w = io::stdout();
149                w.write_all(buf)?;
150                buf.zeroize();
151                buf.clear();
152                w.flush()
153            }
154        }
155    }
156}
157
158impl Drop for StdoutBuffer {
159    fn drop(&mut self) {
160        // Destructors should not panic, so we ignore a failed flush.
161        let _ = self.flush();
162    }
163}
164
/// The data format being written out.
///
/// `OutputWriter::new` uses it to decide whether output may go to a terminal at all.
/// [`StdoutWriter`] uses it to decide whether terminal output is validated as UTF-8
/// and truncated.
#[derive(Debug)]
pub enum OutputFormat {
    /// Binary data that should not be sent to a TTY by default.
    Binary,
    /// Text data that is acceptable to send to a TTY.
    Text,
    /// Unknown data format; try to avoid sending binary data to a TTY.
    Unknown,
}
175
/// Writer that wraps standard output to handle TTYs nicely.
///
/// When standard output is a terminal, text and unknown-format output is truncated
/// after `SHORT_OUTPUT_LENGTH` bytes. Unknown-format output that is not valid UTF-8 is
/// rejected.
#[derive(Debug)]
pub struct StdoutWriter {
    /// Destination for the output; buffered in memory when both input and output are
    /// terminals (see `StdoutWriter::new`).
    inner: StdoutBuffer,
    /// Number of bytes written so far, checked against `SHORT_OUTPUT_LENGTH`.
    count: usize,
    /// Format of the data being written.
    format: OutputFormat,
    /// Whether standard output was judged to be attended by a user.
    is_tty: bool,
    /// Whether the truncation notice has already been written.
    truncated: bool,
}
185
186impl StdoutWriter {
187    fn new(format: OutputFormat, is_tty: bool, input_is_tty: bool) -> Self {
188        StdoutWriter {
189            // If the input comes from a TTY and the output will go to a TTY, buffer the
190            // output so it doesn't get in the way of typing the input.
191            inner: if input_is_tty && is_tty {
192                StdoutBuffer::buffered()
193            } else {
194                StdoutBuffer::direct()
195            },
196            count: 0,
197            format,
198            is_tty,
199            truncated: false,
200        }
201    }
202}
203
204impl Write for StdoutWriter {
205    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
206        if self.is_tty {
207            if let OutputFormat::Unknown = self.format {
208                // Don't send unprintable output to TTY
209                if std::str::from_utf8(data).is_err() {
210                    return Err(io::Error::new(
211                        io::ErrorKind::InvalidInput,
212                        FileError::DetectedBinaryOutput,
213                    ));
214                }
215            }
216
217            let to_write = if let OutputFormat::Binary = self.format {
218                // Only occurs if the user has explicitly forced stdout, so don't truncate.
219                data.len()
220            } else {
221                // Drop output if we've truncated already, or need to.
222                if self.truncated || self.count == SHORT_OUTPUT_LENGTH {
223                    if !self.truncated {
224                        self.inner.write_all(LINE_ENDING.as_bytes())?;
225                        self.inner.write_all(b"[")?;
226                        self.inner.write_all(fl!("cli-truncated-tty").as_bytes())?;
227                        self.inner.write_all(b"]")?;
228                        self.inner.write_all(LINE_ENDING.as_bytes())?;
229                        self.truncated = true;
230                    }
231
232                    return io::sink().write(data);
233                }
234
235                let mut to_write = SHORT_OUTPUT_LENGTH - self.count;
236                if to_write > data.len() {
237                    to_write = data.len();
238                }
239                to_write
240            };
241
242            let mut ret = self.inner.write(&data[..to_write])?;
243            self.count += to_write;
244
245            if let OutputFormat::Binary = self.format {
246                // Only occurs if the user has explicitly forced stdout, so don't truncate.
247            } else {
248                // If we have reached the output limit with data to spare,
249                // truncate and drop the remainder.
250                if self.count == SHORT_OUTPUT_LENGTH && data.len() > to_write {
251                    if !self.truncated {
252                        self.inner.write_all(LINE_ENDING.as_bytes())?;
253                        self.inner.write_all(b"[")?;
254                        self.inner.write_all(fl!("cli-truncated-tty").as_bytes())?;
255                        self.inner.write_all(b"]")?;
256                        self.inner.write_all(LINE_ENDING.as_bytes())?;
257                        self.truncated = true;
258                    }
259                    ret += io::sink().write(&data[to_write..])?;
260                }
261            }
262
263            Ok(ret)
264        } else {
265            self.inner.write(data)
266        }
267    }
268
269    fn flush(&mut self) -> io::Result<()> {
270        self.inner.flush()
271    }
272}
273
/// A lazy [`File`] that is not opened until the first call to [`Write::write`] or
/// [`Write::flush`].
#[derive(Debug)]
pub struct LazyFile {
    /// Path of the file to open.
    filename: String,
    /// If `true`, an existing file is truncated; otherwise opening fails if it exists.
    allow_overwrite: bool,
    /// Permission bits applied if the file is created (Unix only).
    #[cfg(unix)]
    mode: u32,
    /// `None` until the first open attempt; afterwards the cached result of that attempt.
    file: Option<io::Result<File>>,
}
284
285impl LazyFile {
286    fn get_file(&mut self) -> io::Result<&mut File> {
287        let filename = &self.filename;
288
289        if self.file.is_none() {
290            let mut options = OpenOptions::new();
291            options.write(true);
292            if self.allow_overwrite {
293                options.create(true).truncate(true);
294            } else {
295                // In addition to the check in `OutputWriter::new`, we enforce this at
296                // file opening time to avoid a race condition with the file being
297                // separately created between `OutputWriter` construction and usage.
298                options.create_new(true);
299            }
300
301            #[cfg(unix)]
302            options.mode(self.mode);
303
304            self.file = Some(options.open(filename));
305        }
306
307        self.file
308            .as_mut()
309            .unwrap()
310            .as_mut()
311            .map_err(|e| io::Error::new(e.kind(), format!("Failed to open file '{}'", filename)))
312    }
313}
314
315impl io::Write for LazyFile {
316    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
317        self.get_file()?.write(buf)
318    }
319
320    fn flush(&mut self) -> io::Result<()> {
321        self.get_file()?.flush()
322    }
323}
324
/// Wrapper around either a file or standard output.
///
/// Constructed with [`OutputWriter::new`]. Implements [`Write`] by forwarding to the
/// wrapped writer.
#[derive(Debug)]
pub enum OutputWriter {
    /// Wrapper around a file, opened lazily on the first write or flush.
    File(LazyFile),
    /// Wrapper around standard output.
    Stdout(StdoutWriter),
}
333
334impl OutputWriter {
335    /// Constructs a new `OutputWriter`.
336    ///
337    /// Writes to the file at path `output`, or standard output if `output` is `None` or
338    /// `Some("-")`.
339    ///
340    /// If `allow_overwrite` is `true`, the file at path `output` will be overwritten if
341    /// it exists. This option has no effect if `output` is `None` or `Some("-")`.
342    pub fn new(
343        output: Option<String>,
344        allow_overwrite: bool,
345        mut format: OutputFormat,
346        _mode: u32,
347        input_is_tty: bool,
348    ) -> io::Result<Self> {
349        let is_tty = console::user_attended();
350        if let Some(filename) = output {
351            // Respect the Unix convention that "-" as an output filename
352            // parameter is an explicit request to use standard output.
353            if filename != "-" {
354                let file_path = Path::new(&filename);
355
356                // Provide a better error if the filename is invalid, or the directory
357                // containing the file does not exist (we don't automatically create
358                // directories).
359                if let Some(dir_path) = file_path.parent() {
360                    if !(dir_path == Path::new("") || dir_path.exists()) {
361                        return Err(io::Error::new(
362                            io::ErrorKind::NotFound,
363                            FileError::MissingDirectory(dir_path.display().to_string()),
364                        ));
365                    }
366                } else {
367                    return Err(io::Error::new(
368                        io::ErrorKind::NotFound,
369                        FileError::InvalidFilename(filename),
370                    ));
371                }
372
373                // We open the file lazily, but as we don't want the caller to assume
374                // this, we eagerly confirm that the file does not exist if we can't
375                // overwrite it.
376                if !allow_overwrite && file_path.exists() {
377                    return Err(io::Error::new(
378                        io::ErrorKind::AlreadyExists,
379                        FileError::DenyOverwriteFile(filename),
380                    ));
381                }
382
383                return Ok(OutputWriter::File(LazyFile {
384                    filename,
385                    allow_overwrite,
386                    #[cfg(unix)]
387                    mode: _mode,
388                    file: None,
389                }));
390            } else {
391                // User explicitly requested stdout; force the format to binary so that we
392                // don't try to parse it as UTF-8 in StdoutWriter and perhaps reject it.
393                format = OutputFormat::Binary;
394            }
395        } else if is_tty {
396            if let OutputFormat::Binary = format {
397                // If output == Some("-") then this error is skipped.
398                return Err(io::Error::new(
399                    io::ErrorKind::Other,
400                    FileError::DenyBinaryOutput,
401                ));
402            }
403        }
404
405        Ok(OutputWriter::Stdout(StdoutWriter::new(
406            format,
407            is_tty,
408            input_is_tty,
409        )))
410    }
411
412    /// Returns true if this output is to a terminal, and a user will likely see it.
413    pub fn is_terminal(&self) -> bool {
414        match self {
415            OutputWriter::File(..) => false,
416            OutputWriter::Stdout(w) => w.is_tty,
417        }
418    }
419}
420
421impl Write for OutputWriter {
422    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
423        match self {
424            OutputWriter::File(f) => f.write(data),
425            OutputWriter::Stdout(handle) => handle.write(data),
426        }
427    }
428
429    fn flush(&mut self) -> io::Result<()> {
430        match self {
431            OutputWriter::File(f) => f.flush(),
432            OutputWriter::Stdout(handle) => handle.flush(),
433        }
434    }
435}
436
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(unix)]
    use std::io::{self, Write};

    #[cfg(unix)]
    use super::{OutputFormat, OutputWriter};

    /// Builds a text `OutputWriter` for `/dev/null`, which always exists on Unix.
    #[cfg(unix)]
    fn dev_null_writer(allow_overwrite: bool) -> io::Result<OutputWriter> {
        OutputWriter::new(
            Some(String::from("/dev/null")),
            allow_overwrite,
            OutputFormat::Text,
            0o600,
            false,
        )
    }

    #[cfg(unix)]
    #[test]
    fn lazy_existing_file_allow_overwrite() {
        let mut writer = dev_null_writer(true).expect("overwriting /dev/null is allowed");
        writer.flush().expect("flushing /dev/null succeeds");
    }

    #[cfg(unix)]
    #[test]
    fn lazy_existing_file_forbid_overwrite() {
        match dev_null_writer(false) {
            Ok(_) => panic!("overwriting /dev/null should have been refused"),
            Err(e) => assert_eq!(e.kind(), io::ErrorKind::AlreadyExists),
        }
    }
}