1use std::process::Command;
6
/// Backtick-escapes the characters of `s` that this module treats as special.
///
/// Every backtick, double quote and dollar sign is prefixed with a backtick;
/// all other characters are copied through unchanged.
pub fn escape_powershell_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        // A single pass suffices: the backtick inserted here is never re-examined.
        if matches!(ch, '`' | '"' | '$') {
            escaped.push('`');
        }
        escaped.push(ch);
    }
    escaped
}
12
13use serde::{Deserialize, Serialize};
18
/// Kind of a single pipeline element.
///
/// NOTE(review): the variant names match PowerShell AST class names
/// (`CommandAst`, `CommandExpressionAst`, `ParenExpressionAst`); presumably
/// they are meant to mirror them — confirm against the consumers of this type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PipelineElementType {
    /// Assigned by `parse_command_element` when its `is_first` flag is set,
    /// and used for the placeholder built by `create_empty_command`.
    CommandAst,
    /// Assigned by `parse_command_element` when its `is_first` flag is not set.
    CommandExpressionAst,
    /// Not produced by the parsing code reviewed here.
    ParenExpressionAst,
}
26
/// Classification of one element (command name or argument) of a command.
///
/// The per-variant notes describe the text heuristics in
/// `classify_argument_element`, which is what assigns these values here.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CommandElementType {
    /// Token that looks like it opens or closes a `{ ... }` block.
    ScriptBlock,
    /// Token that begins with `$(` / `$@` or contains `$(` / `@(`.
    SubExpression,
    /// Token that both starts and ends with a double quote.
    ExpandableString,
    /// Token that contains both a `.` and a `(`.
    MemberInvocation,
    /// Token that begins with `$` and is not one of the sub-expression forms.
    Variable,
    /// Fallback classification; also used for the command-name slot.
    StringConstant,
    /// Token accepted by `is_powershell_parameter`.
    Parameter,
    /// Not produced by the parsing code reviewed here.
    Other,
}
39
/// Kind of a parsed statement.
///
/// NOTE(review): the variant names match PowerShell AST statement class
/// names; presumably they are meant to mirror them — confirm with consumers.
/// Only the variants called out below are produced by `detect_statement_type`;
/// the others are not assigned by the code reviewed here.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StatementType {
    /// Default when no keyword or assignment heuristic matches.
    PipelineAst,
    PipelineChainAst,
    /// Chosen when the text contains `=` but not `==`.
    AssignmentStatementAst,
    /// Chosen for the `if` keyword.
    IfStatementAst,
    /// Chosen for the `for` keyword.
    ForStatementAst,
    /// Chosen for the `foreach` keyword or the `%{` shorthand.
    ForEachStatementAst,
    /// Chosen for the `while` keyword.
    WhileStatementAst,
    /// Chosen for the `do` keyword.
    DoWhileStatementAst,
    DoUntilStatementAst,
    /// Chosen for the `switch` keyword.
    SwitchStatementAst,
    /// Chosen for the `try` keyword.
    TryStatementAst,
    TrapStatementAst,
    /// Chosen for the `function` keyword.
    FunctionDefinitionAst,
    DataStatementAst,
    UnknownStatementAst,
}
59
/// A typed fragment nested inside a command element.
///
/// NOTE(review): nothing in the code reviewed here constructs this type;
/// the field meanings below are read off the names — confirm with the producer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CommandElementChild {
    /// Classification of the fragment.
    pub element_type: CommandElementType,
    /// Text of the fragment (presumably its source text — TODO confirm).
    pub text: String,
}
66
/// A redirection attached to a command or pipeline.
///
/// NOTE(review): nothing in the code reviewed here constructs this type;
/// the field meanings below are read off the names — confirm with the producer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedRedirection {
    /// Presumably the redirected source (stream) — TODO confirm.
    pub from: String,
    /// Presumably the redirection target — TODO confirm.
    pub to: String,
    /// Presumably `true` for stream-merging forms such as `2>&1` — TODO confirm.
    pub is_merging: bool,
}
74
/// One command extracted from a statement.
///
/// Field notes describe how `parse_command_element` and
/// `create_empty_command` populate the struct.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedCommandElement {
    /// First whitespace-separated token of `text`; empty for the placeholder
    /// produced when there is no token at all.
    pub name: String,
    /// Label from `classify_command_name`: `"cmdlet"`, `"application"` or `"unknown"`.
    pub name_type: String,
    /// Pipeline element kind assigned by the parser.
    pub element_type: PipelineElementType,
    /// Remaining whitespace-separated tokens after `name`.
    pub args: Vec<String>,
    /// The command text the element was parsed from.
    pub text: String,
    /// Element classifications: one entry for the name slot followed by one per
    /// argument; `None` for the empty placeholder.
    pub element_types: Option<Vec<CommandElementType>>,
    /// Always `None` in the parser reviewed here. Presumably one optional child
    /// list per element — TODO confirm with the producer that fills it.
    pub children: Option<Vec<Option<Vec<CommandElementChild>>>>,
    /// Always `None` in the parser reviewed here.
    pub redirections: Option<Vec<ParsedRedirection>>,
}
87
/// A group of commands together with their redirections.
///
/// NOTE(review): nothing in the code reviewed here constructs or consumes this
/// type; the field meanings are read off the names — confirm with its users.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PipelineSegment {
    /// Commands belonging to the segment.
    pub commands: Vec<ParsedCommandElement>,
    /// Redirections belonging to the segment.
    pub redirections: Vec<ParsedRedirection>,
    /// Presumably commands found nested inside the segment — TODO confirm.
    pub nested_commands: Option<Vec<ParsedCommandElement>>,
}
95
/// One statement of a parsed command line.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedStatement {
    /// Statement kind as guessed by `detect_statement_type`.
    pub statement_type: StatementType,
    /// Commands found in the statement, in source order. The parser only emits
    /// statements for which this list is non-empty.
    pub commands: Vec<ParsedCommandElement>,
}
102
/// Result of `parse_powershell_command`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedPowerShellCommand {
    /// `true` when at least one statement was produced.
    pub valid: bool,
    /// Parsed statements, in source order.
    pub statements: Vec<ParsedStatement>,
    /// `Some("Empty command")` for blank input; `None` otherwise.
    pub error: Option<String>,
}
110
111pub fn is_powershell_parameter(arg: &str, element_type: Option<&CommandElementType>) -> bool {
113 if let Some(et) = element_type {
114 return *et == CommandElementType::Parameter;
115 }
116 arg.starts_with('-')
118 || arg.starts_with('/')
119 || arg.starts_with('–')
120 || arg.starts_with('—')
121 || arg.starts_with('―')
122}
123
/// Leading characters that mark a parameter token: ASCII hyphen, en dash,
/// em dash, horizontal bar and forward slash. This is the same set that
/// `is_powershell_parameter` tests when no element type is supplied.
pub const PS_TOKENIZER_DASH_CHARS: &[char] = &['-', '–', '—', '―', '/'];
126
127pub fn parse_powershell_command(command: &str) -> ParsedPowerShellCommand {
129 let trimmed = command.trim();
130
131 if trimmed.is_empty() {
132 return ParsedPowerShellCommand {
133 valid: false,
134 statements: vec![],
135 error: Some("Empty command".to_string()),
136 };
137 }
138
139 let statement_strs: Vec<&str> = trimmed
141 .split(|c| c == ';' || c == '\n')
142 .filter(|s| !s.trim().is_empty())
143 .collect();
144
145 let mut statements = Vec::new();
146
147 for stmt_str in statement_strs {
148 let statement_type = detect_statement_type(stmt_str);
149
150 let pipeline_strs: Vec<&str> = stmt_str.split('|').collect();
152 let mut commands = Vec::new();
153
154 for (idx, pipeline_str) in pipeline_strs.iter().enumerate() {
155 let pipeline_trimmed = pipeline_str.trim();
156 if pipeline_trimmed.is_empty() {
157 continue;
158 }
159
160 let parts: Vec<&str> = pipeline_trimmed
162 .split(|c| c == '&')
163 .filter(|s| !s.trim().is_empty())
164 .collect();
165
166 for part in parts {
167 let part_trimmed = part.trim();
168 if part_trimmed.is_empty() {
169 continue;
170 }
171
172 let cmd = parse_command_element(part_trimmed, idx == 0);
173 commands.push(cmd);
174 }
175 }
176
177 if !commands.is_empty() {
178 statements.push(ParsedStatement {
179 statement_type,
180 commands,
181 });
182 }
183 }
184
185 ParsedPowerShellCommand {
186 valid: !statements.is_empty(),
187 statements,
188 error: None,
189 }
190}
191
192fn detect_statement_type(cmd: &str) -> StatementType {
194 let lower = cmd.to_lowercase();
195
196 if lower.contains(" if ") || lower.starts_with("if ") {
197 StatementType::IfStatementAst
198 } else if lower.contains(" foreach ") || lower.starts_with("foreach ") || lower.contains("%{") {
199 StatementType::ForEachStatementAst
200 } else if lower.contains(" for ") || lower.starts_with("for ") {
201 StatementType::ForStatementAst
202 } else if lower.contains(" while ") || lower.starts_with("while ") {
203 StatementType::WhileStatementAst
204 } else if lower.contains(" do ") || lower.starts_with("do ") {
205 StatementType::DoWhileStatementAst
206 } else if lower.contains(" switch ") || lower.starts_with("switch ") {
207 StatementType::SwitchStatementAst
208 } else if lower.contains(" try ") || lower.starts_with("try ") {
209 StatementType::TryStatementAst
210 } else if lower.contains(" function ") || lower.starts_with("function ") {
211 StatementType::FunctionDefinitionAst
212 } else if lower.contains('=') && !lower.contains("==") {
213 StatementType::AssignmentStatementAst
214 } else {
215 StatementType::PipelineAst
216 }
217}
218
219fn parse_command_element(text: &str, is_first: bool) -> ParsedCommandElement {
221 let parts: Vec<&str> = text.split_whitespace().collect();
222
223 if parts.is_empty() {
224 return create_empty_command(text.to_string());
225 }
226
227 let name = parts[0].to_string();
228 let args: Vec<String> = parts[1..].iter().map(|s| s.to_string()).collect();
229 let name_type = classify_command_name(&name);
230 let element_type = if is_first {
231 PipelineElementType::CommandAst
232 } else {
233 PipelineElementType::CommandExpressionAst
234 };
235 let element_types = Some(determine_element_types(&args));
236
237 ParsedCommandElement {
238 name,
239 name_type,
240 element_type,
241 args,
242 text: text.to_string(),
243 element_types,
244 children: None,
245 redirections: None,
246 }
247}
248
249fn create_empty_command(text: String) -> ParsedCommandElement {
251 ParsedCommandElement {
252 name: String::new(),
253 name_type: "unknown".to_string(),
254 element_type: PipelineElementType::CommandAst,
255 args: vec![],
256 text,
257 element_types: None,
258 children: None,
259 redirections: None,
260 }
261}
262
/// Labels a command name as `"cmdlet"`, `"application"` or `"unknown"`.
///
/// Checks run on the lower-cased name, in this order:
/// 1. any `-` in the name: `"cmdlet"`;
/// 2. any `\`, `/` or `.` in the name, or a name from the built-in list of
///    well-known executables: `"application"`;
/// 3. otherwise `"unknown"`.
///
/// Because the hyphen check comes first, a path such as `./my-script.ps1`
/// is labelled `"cmdlet"`.
fn classify_command_name(name: &str) -> String {
    const KNOWN_EXECUTABLES: [&str; 10] = [
        "git", "gh", "docker", "npm", "node", "python", "make", "tar", "curl", "wget",
    ];

    let lowered = name.to_lowercase();
    let label = if lowered.contains('-') {
        "cmdlet"
    } else if lowered.contains(['\\', '/', '.']) || KNOWN_EXECUTABLES.contains(&lowered.as_str())
    {
        "application"
    } else {
        "unknown"
    };

    label.to_string()
}
287
288fn determine_element_types(args: &[String]) -> Vec<CommandElementType> {
290 let mut types = vec![CommandElementType::StringConstant];
291
292 for arg in args {
293 let et = classify_argument_element(arg);
294 types.push(et);
295 }
296
297 types
298}
299
300fn classify_argument_element(arg: &str) -> CommandElementType {
302 let trimmed = arg.trim();
303
304 if trimmed.starts_with('$') && trimmed.len() > 1 {
306 let second = trimmed.chars().nth(1);
307 if second == Some('(') || second == Some('@') {
309 return CommandElementType::SubExpression;
310 }
311 if second == Some('_') || second.is_some_and(|c| c.is_alphabetic()) {
313 return CommandElementType::Variable;
314 }
315 return CommandElementType::Variable;
316 }
317
318 if trimmed.starts_with('{')
321 || trimmed.ends_with('}')
322 || trimmed.contains("{ ")
323 || trimmed.contains(" }")
324 || trimmed.contains("{}")
325 {
326 return CommandElementType::ScriptBlock;
327 }
328
329 if trimmed.starts_with("$(")
331 || trimmed.starts_with("@(")
332 || trimmed.contains("$(")
333 || trimmed.contains("@(")
334 {
335 return CommandElementType::SubExpression;
336 }
337
338 if trimmed.starts_with('"') && trimmed.ends_with('"') {
340 return CommandElementType::ExpandableString;
341 }
342
343 if trimmed.contains('.') && trimmed.contains('(') {
345 return CommandElementType::MemberInvocation;
346 }
347
348 if is_powershell_parameter(trimmed, None) {
350 return CommandElementType::Parameter;
351 }
352
353 CommandElementType::StringConstant
354}
355
356pub fn derive_security_flags(parsed: &ParsedPowerShellCommand) -> SecurityFlags {
358 let mut flags = SecurityFlags::default();
359
360 for statement in &parsed.statements {
361 for cmd in &statement.commands {
362 if let Some(ref types) = cmd.element_types {
364 for et in types {
365 match et {
366 CommandElementType::ScriptBlock => flags.has_script_blocks = true,
367 CommandElementType::SubExpression => flags.has_sub_expressions = true,
368 CommandElementType::ExpandableString => flags.has_expandable_strings = true,
369 CommandElementType::MemberInvocation => flags.has_member_invocations = true,
370 CommandElementType::Variable => flags.has_variables = true,
371 _ => {}
372 }
373 }
374 }
375
376 for arg in &cmd.args {
378 if arg.starts_with('$') && arg.len() > 1 {
380 let second = arg.chars().nth(1);
381 if second == Some('(') || second == Some('@') {
383 flags.has_sub_expressions = true;
384 } else {
385 flags.has_variables = true;
386 }
387 }
388 if arg.contains('{') || arg.contains('}') {
390 flags.has_script_blocks = true;
391 }
392 if arg.contains("$(") || arg.contains("@(") {
394 flags.has_sub_expressions = true;
395 }
396 if arg.starts_with('"') && arg.ends_with('"') {
398 flags.has_expandable_strings = true;
399 }
400 if arg.contains('=') && !arg.starts_with('-') {
402 flags.has_assignments = true;
403 }
404 }
405
406 let text = &cmd.text;
409 if text.starts_with('$') && text.len() > 1 && !text.contains(' ') {
410 flags.has_variables = true;
412 }
413 if text.contains('{') || text.contains('}') {
415 flags.has_script_blocks = true;
416 }
417 if text.contains("$(") || text.contains("@(") {
419 flags.has_sub_expressions = true;
420 }
421 }
422 }
423
424 flags
425}
426
/// Coarse security-relevant facts about a parsed command, as produced by
/// `derive_security_flags`. All flags default to `false`.
#[derive(Debug, Clone, Default)]
pub struct SecurityFlags {
    /// A script block (`{` / `}`) was seen.
    pub has_script_blocks: bool,
    /// A sub-expression (`$(` / `@(`) was seen.
    pub has_sub_expressions: bool,
    /// A double-quoted token was seen.
    pub has_expandable_strings: bool,
    /// An element classified as a member invocation was seen.
    pub has_member_invocations: bool,
    /// Presumably: splatting (`@variable`) was seen — NOTE(review): confirm
    /// which code path raises this flag.
    pub has_splatting: bool,
    /// A token containing `=` that is not a `-`-prefixed parameter was seen.
    pub has_assignments: bool,
    /// Presumably: the stop-parsing token (`--%`) was seen — NOTE(review):
    /// confirm which code path raises this flag.
    pub has_stop_parsing: bool,
    /// A `$`-prefixed variable reference was seen.
    pub has_variables: bool,
}
439
/// Prepares (but does not start) a `pwsh` process that runs `script`.
///
/// The command line is `pwsh -NoProfile -NonInteractive -Command <script>`;
/// `script` is passed as a single argument.
pub fn build_powershell_command(script: &str) -> Command {
    let mut process = Command::new("pwsh");
    process
        .arg("-NoProfile")
        .arg("-NonInteractive")
        .arg("-Command")
        .arg(script);
    process
}
450
451pub fn build_powershell_command_utf8(script: &str) -> Command {
453 let full_script = format!(
454 "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; {}",
455 script
456 );
457 build_powershell_command(&full_script)
458}
459
/// Reports whether `pwsh --version` can be run and exits successfully.
///
/// Failure to launch the process (for example, because `pwsh` is not on the
/// search path) is reported as `false`.
pub fn is_powershell_available() -> bool {
    match Command::new("pwsh").arg("--version").output() {
        Ok(output) => output.status.success(),
        Err(_) => false,
    }
}
468
/// Returns the standard output of `pwsh --version`.
///
/// Yields `None` when the process cannot be launched, exits unsuccessfully,
/// or prints bytes that are not valid UTF-8. The output is returned exactly
/// as printed, without trimming.
pub fn get_powershell_version() -> Option<String> {
    let output = Command::new("pwsh").arg("--version").output().ok()?;
    if !output.status.success() {
        return None;
    }
    String::from_utf8(output.stdout).ok()
}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` and asserts that the parser accepted it.
    fn parse_valid(source: &str) -> ParsedPowerShellCommand {
        let parsed = parse_powershell_command(source);
        assert!(parsed.valid);
        parsed
    }

    /// First command of the first statement.
    fn first_command(parsed: &ParsedPowerShellCommand) -> &ParsedCommandElement {
        &parsed.statements[0].commands[0]
    }

    /// Whether the first command lists `wanted` among its element types.
    fn first_command_has(parsed: &ParsedPowerShellCommand, wanted: CommandElementType) -> bool {
        first_command(parsed)
            .element_types
            .as_deref()
            .is_some_and(|kinds| kinds.contains(&wanted))
    }

    #[test]
    fn test_parse_simple_command() {
        let parsed = parse_valid("Get-Content file.txt");
        assert_eq!(parsed.statements.len(), 1);
        assert_eq!(first_command(&parsed).name, "Get-Content");
    }

    #[test]
    fn test_parse_command_with_args() {
        let parsed = parse_valid("Remove-Item -Path test.txt -Recurse -Force");
        let cmd = first_command(&parsed);
        assert_eq!(cmd.name, "Remove-Item");
        assert!(cmd.args.contains(&"-Path".to_string()));
    }

    #[test]
    fn test_parse_pipeline() {
        let parsed = parse_valid("Get-Content file.txt | Select-String pattern");
        assert_eq!(parsed.statements[0].commands.len(), 2);
    }

    #[test]
    fn test_parse_compound_statements() {
        let parsed = parse_valid("$var = 1; Get-Content file.txt");
        assert_eq!(parsed.statements.len(), 2);
    }

    #[test]
    fn test_detect_variables() {
        let parsed = parse_valid("Write-Host $env:SECRET");
        assert!(first_command_has(&parsed, CommandElementType::Variable));
    }

    #[test]
    fn test_detect_script_blocks() {
        let parsed = parse_valid("Where-Object { $_.Name }");
        assert!(first_command_has(&parsed, CommandElementType::ScriptBlock));
    }

    #[test]
    fn test_detect_subexpression() {
        let parsed = parse_valid("Invoke-Expression $(malicious)");
        assert!(first_command_has(&parsed, CommandElementType::SubExpression));
    }

    #[test]
    fn test_classify_cmdlet() {
        for name in ["Get-Content", "Remove-Item"] {
            assert_eq!(classify_command_name(name), "cmdlet");
        }
    }

    #[test]
    fn test_classify_application() {
        for name in ["git", "./script.ps1"] {
            assert_eq!(classify_command_name(name), "application");
        }
    }

    #[test]
    fn test_is_powershell_parameter() {
        for token in ["-Path", "-Recurse", "/C"] {
            assert!(is_powershell_parameter(token, None));
        }
        assert!(!is_powershell_parameter("file.txt", None));
    }

    #[test]
    fn test_derive_security_flags_variables() {
        let flags = derive_security_flags(&parse_powershell_command("$env:SECRET | Write-Host"));
        assert!(flags.has_variables);
    }

    #[test]
    fn test_derive_security_flags_script_blocks() {
        let flags = derive_security_flags(&parse_powershell_command(
            "Get-Process | Where-Object { $_.CPU }",
        ));
        assert!(flags.has_script_blocks);
    }

    #[test]
    fn test_derive_security_flags_subexpression() {
        let flags =
            derive_security_flags(&parse_powershell_command("Invoke-Expression $(malicious)"));
        assert!(flags.has_sub_expressions);
    }

    #[test]
    fn test_derive_security_flags_assignment() {
        let flags =
            derive_security_flags(&parse_powershell_command("$result = Get-Content file.txt"));
        assert!(flags.has_assignments);
    }

    #[test]
    fn test_empty_command() {
        assert!(!parse_powershell_command("").valid);
    }

    #[test]
    fn test_member_invocation() {
        parse_valid("$obj.Method()");
    }

    #[test]
    fn test_parse_alias() {
        let parsed = parse_valid("gc file.txt");
        assert_eq!(first_command(&parsed).name, "gc");
    }
}