1use rable::{Node, NodeKind};
2
3use crate::allowlists;
4
/// Coarse classification of a shell redirection operator.
///
/// Values are produced by [`redirect_info`] from the operator text stored on
/// a redirect node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RedirectOp {
    /// The `>` operator.
    Write,
    /// The `>>` operator.
    Append,
    /// The `<` or `<<<` operator.
    Read,
    /// The `&>` or `>&` operator.
    FdDup,
    /// Any operator text not covered by the variants above.
    Other,
}
19
20#[must_use]
22pub fn command_name_from_words(words: &[Node]) -> Option<&str> {
23 words.first().and_then(word_value)
24}
25
26#[must_use]
28pub fn command_name(node: &Node) -> Option<&str> {
29 let NodeKind::Command { words, .. } = &node.kind else {
30 return None;
31 };
32 command_name_from_words(words)
33}
34
35#[must_use]
37pub fn command_args_from_words(words: &[Node]) -> Vec<String> {
38 words.iter().skip(1).map(node_text).collect()
39}
40
41#[must_use]
43pub fn command_args(node: &Node) -> Vec<String> {
44 let NodeKind::Command { words, .. } = &node.kind else {
45 return Vec::new();
46 };
47 command_args_from_words(words)
48}
49
50#[must_use]
52pub fn redirect_info(node: &Node) -> Option<(RedirectOp, String)> {
53 let NodeKind::Redirect { op, target, .. } = &node.kind else {
54 return None;
55 };
56 let redirect_op = match op.as_str() {
57 ">" => RedirectOp::Write,
58 ">>" => RedirectOp::Append,
59 "<" | "<<<" => RedirectOp::Read,
60 "&>" | ">&" => RedirectOp::FdDup,
61 _ => RedirectOp::Other,
62 };
63 Some((redirect_op, node_text(target)))
64}
65
66#[must_use]
71pub fn has_expansions(node: &Node) -> bool {
72 has_expansions_kind(&node.kind)
73}
74
75#[must_use]
77pub fn has_expansions_in_slices(words: &[Node], redirects: &[Node]) -> bool {
78 words.iter().any(has_expansions) || redirects.iter().any(has_expansions)
79}
80
81#[must_use]
87pub const fn is_expansion_node(kind: &NodeKind) -> bool {
88 matches!(
89 kind,
90 NodeKind::CommandSubstitution { .. }
91 | NodeKind::ProcessSubstitution { .. }
92 | NodeKind::ParamExpansion { .. }
93 | NodeKind::ParamIndirect { .. }
94 | NodeKind::ParamLength { .. }
95 | NodeKind::AnsiCQuote { .. }
96 | NodeKind::LocaleString { .. }
97 | NodeKind::ArithmeticExpansion { .. }
98 | NodeKind::BraceExpansion { .. }
99 )
100}
101
102fn has_expansions_kind(kind: &NodeKind) -> bool {
103 if is_expansion_node(kind) {
104 return true;
105 }
106 match kind {
107 NodeKind::Word { value, parts, .. } => {
108 if parts.is_empty() {
116 has_shell_expansion_pattern(value)
117 } else {
118 parts.iter().any(has_expansions)
119 }
120 }
121 NodeKind::Command {
122 words, redirects, ..
123 } => has_expansions_in_slices(words, redirects),
124 NodeKind::Pipeline { commands, .. } => commands.iter().any(has_expansions),
125 NodeKind::List { items } => items.iter().any(|item| has_expansions(&item.command)),
126 NodeKind::Redirect { target, .. } => has_expansions(target),
127 NodeKind::If {
128 condition,
129 then_body,
130 else_body,
131 ..
132 } => {
133 has_expansions(condition)
134 || has_expansions(then_body)
135 || else_body.as_deref().is_some_and(has_expansions)
136 }
137 NodeKind::Subshell { body, .. } | NodeKind::BraceGroup { body, .. } => has_expansions(body),
138 NodeKind::HereDoc {
139 content, quoted, ..
140 } => !quoted && has_shell_expansion_pattern(content),
141 _ => false,
142 }
143}
144
/// Textual scan for shell expansion syntax in raw text.
///
/// Returns `true` when `s` contains a backtick anywhere, or a `$` that is
/// immediately followed by `(`, `{`, `'`, `"`, `_` or an ASCII letter.
/// A `$` followed by anything else (a digit, whitespace, end of input, ...)
/// is not treated as an expansion. Quoting and backslash escapes in `s` are
/// not interpreted.
#[must_use]
pub fn has_shell_expansion_pattern(s: &str) -> bool {
    let bytes = s.as_bytes();
    bytes.iter().enumerate().any(|(idx, &byte)| match byte {
        b'`' => true,
        b'$' => bytes.get(idx + 1).is_some_and(|&follow| {
            matches!(follow, b'(' | b'{' | b'\'' | b'"' | b'_') || follow.is_ascii_alphabetic()
        }),
        _ => false,
    })
}
170
/// Reports whether `target` is one of the redirect targets treated as safe:
/// `/dev/null`, `/dev/stdout` or `/dev/stderr`.
///
/// The comparison is an exact string match; no path normalization is done.
#[must_use]
pub fn is_safe_redirect_target(target: &str) -> bool {
    const SAFE_TARGETS: [&str; 3] = ["/dev/null", "/dev/stdout", "/dev/stderr"];
    SAFE_TARGETS.contains(&target)
}
176
177#[must_use]
180pub fn has_unsafe_file_redirect(node: &Node) -> bool {
181 let NodeKind::Command { redirects, .. } = &node.kind else {
182 return false;
183 };
184 redirects.iter().any(|r| {
185 let Some((op, target)) = redirect_info(r) else {
186 return false;
187 };
188 matches!(op, RedirectOp::Write | RedirectOp::Append) && !is_safe_redirect_target(&target)
189 })
190}
191
192#[must_use]
194pub fn is_harmless_fallback(node: &Node) -> bool {
195 let Some(name) = command_name(node) else {
196 return false;
197 };
198 matches!(name, "true" | "false" | ":" | "echo" | "printf")
199}
200
201fn node_text(node: &Node) -> String {
203 if let NodeKind::Word { value, .. } = &node.kind {
204 strip_quotes(value)
205 } else {
206 String::new()
207 }
208}
209
210const fn word_value(node: &Node) -> Option<&str> {
212 if let NodeKind::Word { value, .. } = &node.kind {
213 Some(value.as_str())
214 } else {
215 None
216 }
217}
218
/// Removes one layer of surrounding shell quoting from `s`.
///
/// The input is trimmed first. If the trimmed text is wrapped in a matching
/// pair of `"…"`, `'…'`, `$'…'` or `$"…"` delimiters, the text between them is
/// returned; otherwise the trimmed text is returned unchanged. Escapes inside
/// the quotes are not interpreted.
///
/// A delimiter is only stripped when distinct opening and closing delimiters
/// are both present, so a lone `"` or `'` is returned as-is rather than
/// slicing out of range and panicking.
fn strip_quotes(s: &str) -> String {
    // (opening delimiter, closing delimiter); the openers are mutually
    // exclusive, so the lookup order does not affect the result.
    const QUOTE_PAIRS: [(&str, &str); 4] = [("\"", "\""), ("'", "'"), ("$'", "'"), ("$\"", "\"")];

    let s = s.trim();
    QUOTE_PAIRS
        .iter()
        // `strip_suffix` runs on what is left after the opener was removed,
        // so a single character can never serve as both delimiters.
        .find_map(|&(open, close)| s.strip_prefix(open)?.strip_suffix(close))
        .unwrap_or(s)
        .to_owned()
}
233
234#[must_use]
255pub fn is_safe_heredoc_substitution(command: &Node) -> bool {
256 let NodeKind::Command {
257 words, redirects, ..
258 } = &command.kind
259 else {
260 return false;
261 };
262 let Some(name) = command_name_from_words(words) else {
263 return false;
264 };
265 if !allowlists::is_simple_safe(name) {
266 return false;
267 }
268 if redirects.is_empty() {
269 return false;
270 }
271 let all_quoted_heredocs = redirects
272 .iter()
273 .all(|r| matches!(&r.kind, NodeKind::HereDoc { quoted, .. } if *quoted));
274 if !all_quoted_heredocs {
275 return false;
276 }
277 !words.iter().any(has_expansions)
278}
279
// Unit tests for the AST helpers. They parse real shell snippets through the
// crate's `BashParser`, so they also pin how the parser represents each
// construct. `unwrap` is allowed here: a failure should abort the test.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use crate::parser::BashParser;

    use super::*;

    /// Parses `source` with a fresh parser and returns every node it yields.
    ///
    /// Panics if the parser cannot be constructed or the source fails to parse.
    fn parse_first(source: &str) -> Vec<Node> {
        let mut parser = BashParser::new().unwrap();
        parser.parse(source).unwrap()
    }

    /// Finds the first `Command` node in `nodes`.
    ///
    /// Pipelines are searched recursively. For a `List`, only the items'
    /// direct commands are checked (see `find_command_refs`).
    fn find_command(nodes: &[Node]) -> Option<&Node> {
        for node in nodes {
            match &node.kind {
                NodeKind::Command { .. } => return Some(node),
                NodeKind::Pipeline { commands, .. } => {
                    if let Some(cmd) = find_command(commands) {
                        return Some(cmd);
                    }
                }
                NodeKind::List { items } => {
                    // List items wrap their node in `.command`; gather references
                    // so they can be searched without cloning.
                    let nodes: Vec<&Node> = items.iter().map(|i| &i.command).collect();
                    if let Some(cmd) = find_command_refs(&nodes) {
                        return Some(cmd);
                    }
                }
                _ => {}
            }
        }
        None
    }

    /// Returns the first node in `nodes` that is itself a `Command`.
    ///
    /// Does not descend into nested structures.
    fn find_command_refs<'a>(nodes: &[&'a Node]) -> Option<&'a Node> {
        for node in nodes {
            if matches!(node.kind, NodeKind::Command { .. }) {
                return Some(node);
            }
        }
        None
    }

    // --- Command name and argument extraction ---

    #[test]
    fn extract_command_name() {
        let nodes = parse_first("git status");
        let cmd = find_command(&nodes).unwrap();
        assert_eq!(command_name(cmd), Some("git"));
    }

    #[test]
    fn extract_command_args() {
        let nodes = parse_first("git commit -m 'hello world'");
        let cmd = find_command(&nodes).unwrap();
        let args = command_args(cmd);
        assert!(args.contains(&"commit".to_owned()));
        assert!(args.contains(&"-m".to_owned()));
    }

    // --- Expansion detection on parsed commands ---

    #[test]
    fn detect_command_substitution() {
        let nodes = parse_first("echo $(whoami)");
        assert!(has_expansions(&nodes[0]));
    }

    #[test]
    fn no_expansions_in_literal() {
        let nodes = parse_first("echo hello");
        let cmd = find_command(&nodes).unwrap();
        assert!(!has_expansions(cmd));
    }

    // --- Redirect classification ---

    #[test]
    fn redirect_write() {
        let nodes = parse_first("echo foo > output.txt");
        let NodeKind::Command { redirects, .. } = &nodes[0].kind else {
            unreachable!("expected Command node");
        };
        let (op, target) = redirect_info(&redirects[0]).unwrap();
        assert_eq!(op, RedirectOp::Write);
        assert_eq!(target, "output.txt");
    }

    #[test]
    fn redirect_append() {
        let nodes = parse_first("echo foo >> log.txt");
        let NodeKind::Command { redirects, .. } = &nodes[0].kind else {
            unreachable!("expected Command node");
        };
        let (op, target) = redirect_info(&redirects[0]).unwrap();
        assert_eq!(op, RedirectOp::Append);
        assert_eq!(target, "log.txt");
    }

    // --- One test per expansion form that `has_expansions` must report ---

    #[test]
    fn detect_param_expansion() {
        let nodes = parse_first("echo ${HOME}");
        assert!(has_expansions(&nodes[0]));
    }

    #[test]
    fn detect_simple_var_expansion() {
        let nodes = parse_first("echo $HOME");
        assert!(has_expansions(&nodes[0]));
    }

    #[test]
    fn detect_param_length() {
        let nodes = parse_first("echo ${#var}");
        assert!(has_expansions(&nodes[0]));
    }

    #[test]
    fn detect_param_indirect() {
        let nodes = parse_first("echo ${!ref}");
        assert!(has_expansions(&nodes[0]));
    }

    #[test]
    fn detect_ansi_c_quote() {
        let nodes = parse_first("echo $'\\x41'");
        assert!(has_expansions(&nodes[0]));
    }

    #[test]
    fn detect_locale_string() {
        let nodes = parse_first("echo $\"hello\"");
        assert!(has_expansions(&nodes[0]));
    }

    #[test]
    fn detect_arithmetic_expansion_inline() {
        let nodes = parse_first("echo $((1+1))");
        assert!(has_expansions(&nodes[0]));
    }

    #[test]
    fn detect_brace_expansion() {
        let nodes = parse_first("echo {a,b,c}");
        assert!(has_expansions(&nodes[0]));
    }

    #[test]
    fn detect_brace_expansion_range() {
        let nodes = parse_first("echo {1..10}");
        assert!(has_expansions(&nodes[0]));
    }

    // --- `strip_quotes` on raw strings (no parser involved) ---

    #[test]
    fn strip_ansi_c_quotes() {
        assert_eq!(strip_quotes("$'hello'"), "hello");
    }

    #[test]
    fn strip_locale_quotes() {
        assert_eq!(strip_quotes("$\"hello\""), "hello");
    }

    #[test]
    fn strip_regular_quotes_unchanged() {
        assert_eq!(strip_quotes("'hello'"), "hello");
        assert_eq!(strip_quotes("\"hello\""), "hello");
        assert_eq!(strip_quotes("hello"), "hello");
    }

    // --- `has_shell_expansion_pattern` on raw strings (no parser involved) ---

    #[test]
    fn expansion_pattern_detects_dollar_var() {
        assert!(has_shell_expansion_pattern("$HOME"));
        assert!(has_shell_expansion_pattern("hello $USER world"));
        assert!(has_shell_expansion_pattern("$_private"));
    }

    #[test]
    fn expansion_pattern_detects_braced() {
        assert!(has_shell_expansion_pattern("${HOME}"));
    }

    #[test]
    fn expansion_pattern_detects_command_sub() {
        assert!(has_shell_expansion_pattern("$(whoami)"));
        assert!(has_shell_expansion_pattern("`whoami`"));
    }

    #[test]
    fn expansion_pattern_detects_ansi_c() {
        assert!(has_shell_expansion_pattern("$'hello'"));
    }

    // A `$` followed by a digit is deliberately not reported as an expansion.
    #[test]
    fn expansion_pattern_no_false_positive() {
        assert!(!has_shell_expansion_pattern("hello world"));
        assert!(!has_shell_expansion_pattern("price is $5"));
        assert!(!has_shell_expansion_pattern(""));
    }
}