1use rable::{Node, NodeKind};
2
3use crate::allowlists;
4
/// Coarse classification of a shell redirection operator.
///
/// Produced by [`redirect_info`] from the operator text stored on a
/// `NodeKind::Redirect` node.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RedirectOp {
    /// The `>` operator.
    Write,
    /// The `>>` operator.
    Append,
    /// The `<` and `<<<` operators.
    Read,
    /// The `&>` and `>&` operators.
    FdDup,
    /// Any operator text not covered by the other variants.
    Other,
}
19
20#[must_use]
22pub fn command_name_from_words(words: &[Node]) -> Option<&str> {
23 words.first().and_then(word_value)
24}
25
26#[must_use]
28pub fn command_name(node: &Node) -> Option<&str> {
29 let NodeKind::Command { words, .. } = &node.kind else {
30 return None;
31 };
32 command_name_from_words(words)
33}
34
35#[must_use]
37pub fn command_args_from_words(words: &[Node]) -> Vec<String> {
38 words.iter().skip(1).map(node_text).collect()
39}
40
41#[must_use]
43pub fn command_args(node: &Node) -> Vec<String> {
44 let NodeKind::Command { words, .. } = &node.kind else {
45 return Vec::new();
46 };
47 command_args_from_words(words)
48}
49
50#[must_use]
52pub fn redirect_info(node: &Node) -> Option<(RedirectOp, String)> {
53 let NodeKind::Redirect { op, target, .. } = &node.kind else {
54 return None;
55 };
56 let redirect_op = match op.as_str() {
57 ">" => RedirectOp::Write,
58 ">>" => RedirectOp::Append,
59 "<" | "<<<" => RedirectOp::Read,
60 "&>" | ">&" => RedirectOp::FdDup,
61 _ => RedirectOp::Other,
62 };
63 Some((redirect_op, node_text(target)))
64}
65
66#[must_use]
71pub fn has_expansions(node: &Node) -> bool {
72 has_expansions_kind(&node.kind)
73}
74
75#[must_use]
77pub fn has_expansions_in_slices(words: &[Node], redirects: &[Node]) -> bool {
78 words.iter().any(has_expansions) || redirects.iter().any(has_expansions)
79}
80
81#[must_use]
87pub const fn is_expansion_node(kind: &NodeKind) -> bool {
88 matches!(
89 kind,
90 NodeKind::CommandSubstitution { .. }
91 | NodeKind::ProcessSubstitution { .. }
92 | NodeKind::ParamExpansion { .. }
93 | NodeKind::ParamIndirect { .. }
94 | NodeKind::ParamLength { .. }
95 | NodeKind::AnsiCQuote { .. }
96 | NodeKind::LocaleString { .. }
97 | NodeKind::ArithmeticExpansion { .. }
98 | NodeKind::BraceExpansion { .. }
99 )
100}
101
102fn has_expansions_kind(kind: &NodeKind) -> bool {
103 if is_expansion_node(kind) {
104 return true;
105 }
106 match kind {
107 NodeKind::Word { value, parts, .. } => {
108 if parts.is_empty() {
116 has_shell_expansion_pattern(value)
117 } else {
118 parts.iter().any(has_expansions)
119 }
120 }
121 NodeKind::Command {
122 words, redirects, ..
123 } => has_expansions_in_slices(words, redirects),
124 NodeKind::Pipeline { commands, .. } => commands.iter().any(has_expansions),
125 NodeKind::List { items } => items.iter().any(|item| has_expansions(&item.command)),
126 NodeKind::Redirect { target, .. } => has_expansions(target),
127 NodeKind::If {
128 condition,
129 then_body,
130 else_body,
131 ..
132 } => {
133 has_expansions(condition)
134 || has_expansions(then_body)
135 || else_body.as_deref().is_some_and(has_expansions)
136 }
137 NodeKind::Subshell { body, .. } | NodeKind::BraceGroup { body, .. } => has_expansions(body),
138 NodeKind::HereDoc {
139 content, quoted, ..
140 } => !quoted && has_shell_expansion_pattern(content),
141 _ => false,
142 }
143}
144
/// Textual check for shell expansion syntax in `s`.
///
/// Returns `true` when `s` contains a backtick anywhere, or a `$` that is
/// immediately followed by `(`, `{`, `'`, `"`, `_` or an ASCII letter.
/// A `$` followed by anything else (for example a digit, as in `$5`) or
/// sitting at the very end of the string does not count.
#[must_use]
pub fn has_shell_expansion_pattern(s: &str) -> bool {
    let bytes = s.as_bytes();
    if bytes.contains(&b'`') {
        return true;
    }
    // Look at every adjacent byte pair for `$` plus a triggering follower.
    bytes.windows(2).any(|pair| match pair {
        &[b'$', follower] => {
            matches!(follower, b'(' | b'{' | b'\'' | b'"' | b'_') || follower.is_ascii_alphabetic()
        }
        _ => false,
    })
}
170
/// Reports whether `target` is exactly `/dev/null`, `/dev/stdout` or
/// `/dev/stderr`.
///
/// The comparison is a plain string match; no path normalisation is done.
#[must_use]
pub fn is_safe_redirect_target(target: &str) -> bool {
    const SAFE_TARGETS: [&str; 3] = ["/dev/null", "/dev/stdout", "/dev/stderr"];
    SAFE_TARGETS.contains(&target)
}
176
177#[must_use]
180pub fn has_unsafe_file_redirect(node: &Node) -> bool {
181 let NodeKind::Command { redirects, .. } = &node.kind else {
182 return false;
183 };
184 redirects.iter().any(|r| {
185 let Some((op, target)) = redirect_info(r) else {
186 return false;
187 };
188 matches!(op, RedirectOp::Write | RedirectOp::Append) && !is_safe_redirect_target(&target)
189 })
190}
191
192#[must_use]
194pub fn is_harmless_fallback(node: &Node) -> bool {
195 let Some(name) = command_name(node) else {
196 return false;
197 };
198 matches!(name, "true" | "false" | ":" | "echo" | "printf")
199}
200
201fn node_text(node: &Node) -> String {
203 if let NodeKind::Word { value, .. } = &node.kind {
204 strip_quotes(value)
205 } else {
206 String::new()
207 }
208}
209
210const fn word_value(node: &Node) -> Option<&str> {
212 if let NodeKind::Word { value, .. } = &node.kind {
213 Some(value.as_str())
214 } else {
215 None
216 }
217}
218
/// Trims surrounding whitespace from `s` and removes one layer of enclosing
/// quotes, if present.
///
/// Recognised wrappers, tried in this order: `"..."`, `'...'`, `$'...'` and
/// `$"..."`. The opening and closing delimiters must be distinct characters
/// of the input, so a lone `"` or `'` (or a bare `$'` / `$"`) is returned
/// unchanged rather than being treated as an empty quoted string. Text that
/// matches none of the wrappers is returned trimmed but otherwise untouched.
fn strip_quotes(s: &str) -> String {
    /// Returns the text between `open` and `close` when `text` is wrapped in
    /// both. Stripping the prefix first guarantees the closing delimiter is
    /// looked for only in what remains, so the two can never overlap.
    fn unwrap_pair<'a>(text: &'a str, open: &str, close: char) -> Option<&'a str> {
        text.strip_prefix(open)?.strip_suffix(close)
    }

    let trimmed = s.trim();
    unwrap_pair(trimmed, "\"", '"')
        .or_else(|| unwrap_pair(trimmed, "'", '\''))
        .or_else(|| unwrap_pair(trimmed, "$'", '\''))
        .or_else(|| unwrap_pair(trimmed, "$\"", '"'))
        .unwrap_or(trimmed)
        .to_owned()
}
233
234#[must_use]
241pub fn is_safe_heredoc_substitution(command: &Node) -> bool {
242 let NodeKind::Command {
243 words, redirects, ..
244 } = &command.kind
245 else {
246 return false;
247 };
248 let Some(name) = command_name_from_words(words) else {
249 return false;
250 };
251 if !allowlists::is_simple_safe(name) {
252 return false;
253 }
254 if redirects.is_empty() {
255 return false;
256 }
257 let all_quoted_heredocs = redirects
258 .iter()
259 .all(|r| matches!(&r.kind, NodeKind::HereDoc { quoted, .. } if *quoted));
260 if !all_quoted_heredocs {
261 return false;
262 }
263 !words.iter().any(has_expansions)
264}
265
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use crate::parser::BashParser;

    use super::*;

    /// Parses `source` with a fresh `BashParser` and returns the nodes it
    /// produced.
    fn parse_first(source: &str) -> Vec<Node> {
        let mut parser = BashParser::new().unwrap();
        parser.parse(source).unwrap()
    }

    /// Depth-first search for the first `Command` node, descending through
    /// pipelines and list items alike.
    fn find_command(nodes: &[Node]) -> Option<&Node> {
        nodes.iter().find_map(|node| match &node.kind {
            NodeKind::Command { .. } => Some(node),
            NodeKind::Pipeline { commands, .. } => find_command(commands),
            // Each list item is searched as a one-element slice so nested
            // pipelines/lists are handled by the same recursion.
            NodeKind::List { items } => items
                .iter()
                .find_map(|item| find_command(std::slice::from_ref(&item.command))),
            _ => None,
        })
    }

    /// Asserts that the first node parsed from `source` reports an expansion.
    #[track_caller]
    fn assert_expands(source: &str) {
        let nodes = parse_first(source);
        assert!(
            has_expansions(&nodes[0]),
            "expected an expansion in {source:?}"
        );
    }

    /// Parses `source`, expects a top-level command, and returns the
    /// `redirect_info` of its first redirect.
    fn first_redirect(source: &str) -> (RedirectOp, String) {
        let nodes = parse_first(source);
        let NodeKind::Command { redirects, .. } = &nodes[0].kind else {
            unreachable!("expected Command node");
        };
        redirect_info(&redirects[0]).unwrap()
    }

    // --- command name / arguments ---

    #[test]
    fn extract_command_name() {
        let nodes = parse_first("git status");
        let cmd = find_command(&nodes).unwrap();
        assert_eq!(command_name(cmd), Some("git"));
    }

    #[test]
    fn extract_command_args() {
        let nodes = parse_first("git commit -m 'hello world'");
        let cmd = find_command(&nodes).unwrap();
        let args = command_args(cmd);
        assert!(args.contains(&"commit".to_owned()));
        assert!(args.contains(&"-m".to_owned()));
    }

    // --- expansion detection on parsed input ---

    #[test]
    fn detect_command_substitution() {
        assert_expands("echo $(whoami)");
    }

    #[test]
    fn no_expansions_in_literal() {
        let nodes = parse_first("echo hello");
        let cmd = find_command(&nodes).unwrap();
        assert!(!has_expansions(cmd));
    }

    // --- redirects ---

    #[test]
    fn redirect_write() {
        let (op, target) = first_redirect("echo foo > output.txt");
        assert_eq!(op, RedirectOp::Write);
        assert_eq!(target, "output.txt");
    }

    #[test]
    fn redirect_append() {
        let (op, target) = first_redirect("echo foo >> log.txt");
        assert_eq!(op, RedirectOp::Append);
        assert_eq!(target, "log.txt");
    }

    // --- one test per expansion form ---

    #[test]
    fn detect_param_expansion() {
        assert_expands("echo ${HOME}");
    }

    #[test]
    fn detect_simple_var_expansion() {
        assert_expands("echo $HOME");
    }

    #[test]
    fn detect_param_length() {
        assert_expands("echo ${#var}");
    }

    #[test]
    fn detect_param_indirect() {
        assert_expands("echo ${!ref}");
    }

    #[test]
    fn detect_ansi_c_quote() {
        assert_expands("echo $'\\x41'");
    }

    #[test]
    fn detect_locale_string() {
        assert_expands("echo $\"hello\"");
    }

    #[test]
    fn detect_arithmetic_expansion_inline() {
        assert_expands("echo $((1+1))");
    }

    #[test]
    fn detect_brace_expansion() {
        assert_expands("echo {a,b,c}");
    }

    #[test]
    fn detect_brace_expansion_range() {
        assert_expands("echo {1..10}");
    }

    // --- strip_quotes ---

    #[test]
    fn strip_ansi_c_quotes() {
        assert_eq!(strip_quotes("$'hello'"), "hello");
    }

    #[test]
    fn strip_locale_quotes() {
        assert_eq!(strip_quotes("$\"hello\""), "hello");
    }

    #[test]
    fn strip_regular_quotes_unchanged() {
        assert_eq!(strip_quotes("'hello'"), "hello");
        assert_eq!(strip_quotes("\"hello\""), "hello");
        assert_eq!(strip_quotes("hello"), "hello");
    }

    // --- has_shell_expansion_pattern ---

    #[test]
    fn expansion_pattern_detects_dollar_var() {
        assert!(has_shell_expansion_pattern("$HOME"));
        assert!(has_shell_expansion_pattern("hello $USER world"));
        assert!(has_shell_expansion_pattern("$_private"));
    }

    #[test]
    fn expansion_pattern_detects_braced() {
        assert!(has_shell_expansion_pattern("${HOME}"));
    }

    #[test]
    fn expansion_pattern_detects_command_sub() {
        assert!(has_shell_expansion_pattern("$(whoami)"));
        assert!(has_shell_expansion_pattern("`whoami`"));
    }

    #[test]
    fn expansion_pattern_detects_ansi_c() {
        assert!(has_shell_expansion_pattern("$'hello'"));
    }

    #[test]
    fn expansion_pattern_no_false_positive() {
        assert!(!has_shell_expansion_pattern("hello world"));
        assert!(!has_shell_expansion_pattern("price is $5"));
        assert!(!has_shell_expansion_pattern(""));
    }
}