ssh_vault/vault/dio.rs

1use std::fs::{File, OpenOptions};
2use std::io::{self, IsTerminal, Read, Write};
3
/// Where CLI input is read from: standard input or an already-opened file.
pub enum InputSource {
    Stdin,
    File(File),
}

impl InputSource {
    /// Build an input source from an optional path.
    ///
    /// `None` and `"-"` both select standard input; any other value is opened
    /// as a file for reading.
    ///
    /// # Errors
    ///
    /// Returns an error if the provided file cannot be opened.
    pub fn new(input: Option<String>) -> io::Result<Self> {
        match input {
            Some(path) if path != "-" => File::open(path).map(Self::File),
            _ => Ok(Self::Stdin),
        }
    }

    /// Returns `true` only when reading from standard input and that stream is
    /// attached to a terminal. File inputs always report `false`.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        match self {
            Self::Stdin => io::stdin().is_terminal(),
            Self::File(_) => false,
        }
    }
}

impl Read for InputSource {
    /// Forward the read to whichever underlying stream this source wraps.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            Self::File(file) => file.read(buf),
            Self::Stdin => io::stdin().read(buf),
        }
    }
}
40
/// Where CLI output is written: standard output or a file.
///
/// Files are opened without truncation so callers can inspect existing
/// contents with [`OutputDestination::is_empty`] before deciding to clear them
/// with [`OutputDestination::truncate`].
pub enum OutputDestination {
    Stdout,
    File(File),
}

impl OutputDestination {
    /// Create a new output destination from an optional file path.
    ///
    /// `None` or `"-"` selects standard output. Any other value is opened for
    /// writing and created if it does not exist. An existing file is
    /// deliberately *not* truncated here.
    ///
    /// # Errors
    ///
    /// Returns an error if the provided file cannot be created or opened.
    // `create(true)` without `truncate(true)` is intentional: callers check
    // `is_empty()` first so a non-empty file is not clobbered by accident.
    #[allow(clippy::suspicious_open_options)]
    pub fn new(output: Option<String>) -> io::Result<Self> {
        if let Some(filename) = output {
            // Use a file if the filename is not "-" (stdout)
            if filename != "-" {
                return Ok(Self::File(
                    OpenOptions::new().write(true).create(true).open(filename)?,
                ));
            }
        }

        Ok(Self::Stdout)
    }

    /// Truncate the underlying file (if any) to zero length and move the write
    /// position back to the start.
    ///
    /// `File::set_len` does not touch the cursor. Without the rewind, a write
    /// issued after truncating a file that already had data written through
    /// this handle would land at the old offset and leave a zero-filled gap.
    /// For standard output this is a no-op.
    ///
    /// # Errors
    ///
    /// Returns an error if truncation or seeking fails.
    pub fn truncate(&self) -> io::Result<()> {
        match self {
            Self::File(file) => {
                file.set_len(0)?;
                // `Seek` is implemented for `&File`, so the cursor can be
                // rewound through the shared reference we hold here.
                let mut handle: &File = file;
                io::Seek::rewind(&mut handle)
            }
            Self::Stdout => Ok(()),
        }
    }

    /// Check whether the output destination is empty.
    ///
    /// Callers use this to avoid overwriting a non-empty file. Standard output
    /// is always reported as empty.
    ///
    /// # Errors
    ///
    /// Returns an error if file metadata cannot be read.
    pub fn is_empty(&self) -> io::Result<bool> {
        match self {
            // Propagate metadata failures instead of silently reporting
            // "not empty", matching the documented `# Errors` contract.
            Self::File(file) => Ok(file.metadata()?.len() == 0),
            Self::Stdout => Ok(true),
        }
    }
}

impl Write for OutputDestination {
    /// Forward the write to standard output or the wrapped file.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            Self::Stdout => io::stdout().write(buf),
            Self::File(file) => file.write(buf),
        }
    }

    /// Flush standard output or the wrapped file.
    fn flush(&mut self) -> io::Result<()> {
        match self {
            Self::Stdout => io::stdout().flush(),
            Self::File(file) => file.flush(),
        }
    }
}
108
109/// Configure input and output sources for CLI commands.
110///
111/// # Errors
112///
113/// Returns an error if either the input or output file cannot be opened.
114pub fn setup_io(
115    input: Option<String>,
116    output: Option<String>,
117) -> io::Result<(InputSource, OutputDestination)> {
118    let input = InputSource::new(input)?;
119    let output = OutputDestination::new(output)?;
120
121    Ok((input, output))
122}
123
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_setup_io() {
        // Whether stdin is a terminal depends on how the test runner was
        // launched (CI, piped or redirected stdin, IDE runners), so compare
        // against the real stream instead of assuming a TTY.
        let stdin_is_tty = io::stdin().is_terminal();

        let (input, output) = setup_io(None, None).unwrap();
        assert!(matches!(input, InputSource::Stdin));
        assert_eq!(input.is_terminal(), stdin_is_tty);
        assert!(matches!(output, OutputDestination::Stdout));

        let (input, output) = setup_io(Some("-".to_string()), None).unwrap();
        assert!(matches!(input, InputSource::Stdin));
        assert_eq!(input.is_terminal(), stdin_is_tty);
        assert!(matches!(output, OutputDestination::Stdout));

        let rs = setup_io(Some("noneexistent".to_string()), None);
        assert!(rs.is_err());
    }

    #[test]
    fn test_setup_io_file() {
        let output_file = NamedTempFile::new().unwrap();

        let (input, output) = setup_io(Some("Cargo.toml".to_string()), None).unwrap();
        assert!(!input.is_terminal());
        assert!(matches!(output, OutputDestination::Stdout));

        let (input, output) =
            setup_io(Some("Cargo.toml".to_string()), Some("-".to_string())).unwrap();
        assert!(!input.is_terminal());
        assert!(matches!(output, OutputDestination::Stdout));

        let (input, output) = setup_io(
            Some("Cargo.toml".to_string()),
            Some(output_file.path().to_str().unwrap().to_string()),
        )
        .unwrap();
        assert!(!input.is_terminal());
        assert!(matches!(output, OutputDestination::File(_)));

        // File is directory
        let rs = setup_io(Some("Cargo.toml".to_string()), Some("/".to_string()));
        assert!(rs.is_err());
    }

    #[test]
    fn test_input_source() {
        let mut input = InputSource::new(Some("Cargo.toml".to_string())).unwrap();
        let mut buf = [0; 1024];
        let n = input.read(&mut buf).unwrap();
        assert!(n > 0);

        let rs = InputSource::new(Some("noneexistent".to_string()));
        assert!(rs.is_err());
    }

    #[test]
    fn test_output_destination() {
        let mut output = OutputDestination::new(Some("-".to_string())).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);

        let mut output = OutputDestination::new(None).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);

        let output_file = NamedTempFile::new().unwrap();
        let mut output =
            OutputDestination::new(Some(output_file.path().to_str().unwrap().to_string())).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);
    }

    #[test]
    fn test_output_destination_stdout() {
        // Stdout is always considered empty and truncating it is a no-op.
        let output = OutputDestination::new(None).unwrap();
        assert!(output.is_empty().unwrap());
        output.truncate().unwrap();
    }

    #[test]
    fn test_output_destination_truncate() {
        let mut output_file = NamedTempFile::new().unwrap();
        let mut output =
            OutputDestination::new(Some(output_file.path().to_str().unwrap().to_string())).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);

        output.truncate().unwrap();
        let mut buf = [0; 1024];
        let n = output_file.read(&mut buf).unwrap();
        assert_eq!(n, 0);
    }

    #[test]
    fn test_output_destination_is_empty() {
        let output_file = NamedTempFile::new().unwrap();
        let mut output =
            OutputDestination::new(Some(output_file.path().to_str().unwrap().to_string())).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);

        let is_empty = output.is_empty().unwrap();
        assert!(!is_empty);

        output.truncate().unwrap();
        let is_empty = output.is_empty().unwrap();
        assert!(is_empty);
    }
}