// treemd/input.rs

//! Input handling for stdin and file sources
//!
//! Provides robust stdin reading with UTF-8 validation and format detection.
//! Includes security limits to prevent denial-of-service via large inputs.
5
6use std::io::{self, BufRead, IsTerminal};
7use std::path::Path;
8
/// Bytes per mebibyte; both size limits below are expressed as multiples of it.
const BYTES_PER_MIB: usize = 1024 * 1024;

/// Largest total input accepted, in bytes (100 MB) - guards against memory exhaustion
const MAX_INPUT_SIZE: usize = 100 * BYTES_PER_MIB;

/// Longest single line accepted, in bytes (10 MB) - guards against single-line attacks
const MAX_LINE_SIZE: usize = 10 * BYTES_PER_MIB;
14
/// Where treemd's input came from, together with the text that was read.
///
/// Both variants carry the already-loaded content, not a path or handle.
#[derive(Debug)]
pub enum InputSource {
    /// Text loaded from a file (the file's contents, not its path).
    File(String),
    /// Text read from standard input.
    Stdin(String),
}
21
/// Failures that can be reported while obtaining input.
#[derive(Debug)]
pub enum InputError {
    /// An underlying I/O operation failed; wraps the original error.
    Io(io::Error),
    /// The input was not valid UTF-8.
    Utf8Error,
    /// No input was provided (nothing was read).
    EmptyInput,
    /// No file was specified and stdin is a terminal rather than a pipe.
    NoTty,
    /// The total input exceeded the size limit; carries the byte count seen.
    InputTooLarge(usize),
    /// A single line exceeded the line limit; carries the byte count reported
    /// for that line.
    LineTooLong(usize),
}
32
33impl std::fmt::Display for InputError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            InputError::Io(e) => write!(f, "I/O error: {}", e),
37            InputError::Utf8Error => write!(f, "Invalid UTF-8 in input"),
38            InputError::EmptyInput => write!(f, "Empty input provided"),
39            InputError::NoTty => {
40                write!(f, "No file specified and stdin is not being piped")
41            }
42            InputError::InputTooLarge(size) => {
43                write!(
44                    f,
45                    "Input too large: {} bytes (max {} MB)",
46                    size,
47                    MAX_INPUT_SIZE / (1024 * 1024)
48                )
49            }
50            InputError::LineTooLong(size) => {
51                write!(
52                    f,
53                    "Line too long: {} bytes (max {} MB)",
54                    size,
55                    MAX_LINE_SIZE / (1024 * 1024)
56                )
57            }
58        }
59    }
60}
61
62impl std::error::Error for InputError {}
63
64impl From<io::Error> for InputError {
65    fn from(e: io::Error) -> Self {
66        InputError::Io(e)
67    }
68}
69
/// Report whether stdin is connected to something other than a terminal.
///
/// Returns `true` when stdin is not a terminal (for example a pipe or a
/// redirected file) and `false` when it is.
pub fn is_stdin_piped() -> bool {
    let attached_to_terminal = io::stdin().is_terminal();
    !attached_to_terminal
}
74
75/// Read input from stdin with proper error handling
76///
77/// Implements best practices from Rust stdin handling guides:
78/// - Line-by-line buffered reading for performance
79/// - UTF-8 validation
80/// - Proper error propagation
81/// - Size limits to prevent DoS attacks
82pub fn read_stdin() -> Result<String, InputError> {
83    let stdin = io::stdin();
84    let mut handle = stdin.lock();
85    let mut buffer = String::new();
86    let mut total_size = 0usize;
87    let mut line_buffer = String::new();
88
89    loop {
90        line_buffer.clear();
91        let bytes_read = handle.read_line(&mut line_buffer)?;
92
93        // EOF reached
94        if bytes_read == 0 {
95            break;
96        }
97
98        // Check line size limit
99        if line_buffer.len() > MAX_LINE_SIZE {
100            return Err(InputError::LineTooLong(line_buffer.len()));
101        }
102
103        // Check total size limit
104        total_size = total_size.saturating_add(bytes_read);
105        if total_size > MAX_INPUT_SIZE {
106            return Err(InputError::InputTooLarge(total_size));
107        }
108
109        buffer.push_str(&line_buffer);
110    }
111
112    // Validate UTF-8 (String already enforces this, but explicit check)
113    if buffer.is_empty() {
114        return Err(InputError::EmptyInput);
115    }
116
117    Ok(buffer)
118}
119
120/// Determine input source based on arguments and stdin state
121///
122/// Priority:
123/// 1. If file path is exactly "-", read from stdin
124/// 2. If file path is provided, use file
125/// 3. If no file and stdin is piped, read from stdin
126/// 4. Otherwise, error (no input available)
127pub fn determine_input_source(file_path: Option<&Path>) -> Result<InputSource, InputError> {
128    match file_path {
129        Some(path) if path == Path::new("-") => {
130            // Explicit stdin via "-"
131            let content = read_stdin()?;
132            Ok(InputSource::Stdin(content))
133        }
134        Some(path) => {
135            // File path provided
136            let content = std::fs::read_to_string(path).map_err(InputError::Io)?;
137            Ok(InputSource::File(content))
138        }
139        None if is_stdin_piped() => {
140            // No file, but stdin is piped
141            let content = read_stdin()?;
142            Ok(InputSource::Stdin(content))
143        }
144        None => {
145            // No file and stdin is TTY (user error)
146            Err(InputError::NoTty)
147        }
148    }
149}
150
151/// Process input and return content ready for markdown parsing
152///
153/// Supports:
154/// - Raw markdown (passed through)
155/// - Plain text (wrapped in markdown heading)
156pub fn process_input(source: InputSource) -> Result<String, Box<dyn std::error::Error>> {
157    let content = match source {
158        InputSource::File(c) | InputSource::Stdin(c) => c,
159    };
160
161    // Check if content looks like markdown (has headings)
162    if content.trim_start().starts_with('#') || content.contains("\n#") {
163        // Markdown content, pass through
164        Ok(content)
165    } else {
166        // Plain text - wrap in a document heading for basic viewing
167        let mut markdown = String::from("# Input\n\n");
168        markdown.push_str(&content);
169        Ok(markdown)
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Content that already contains headings must come back untouched.
    #[test]
    fn test_process_markdown_input() {
        let original = String::from("# Title\n\nContent here\n\n## Section\n");

        let processed = process_input(InputSource::Stdin(original.clone())).unwrap();

        assert_eq!(processed, original);
    }

    /// Heading-free text gets the `# Input` heading prepended.
    #[test]
    fn test_process_plain_text() {
        let plain = "Just some plain text\nwith multiple lines";

        let processed = process_input(InputSource::Stdin(plain.to_owned())).unwrap();

        assert!(processed.starts_with("# Input\n\n"));
        assert!(processed.contains("Just some plain text"));
    }
}