1use std::process::Command;
6
/// Escapes `s` so it can be embedded inside a PowerShell double-quoted
/// (expandable) string literal without changing its meaning.
///
/// Each character PowerShell treats specially inside `"..."` is prefixed with
/// the backtick escape character:
/// - `` ` `` — the escape character itself (handled in the same single pass,
///   so escapes added for other characters are never re-escaped);
/// - `$` — variable and sub-expression expansion;
/// - `"` plus the typographic quotes `“` (U+201C), `”` (U+201D) and `„`
///   (U+201E), which the PowerShell tokenizer also accepts as double-quote
///   delimiters and which would otherwise terminate the string early.
///
/// The returned text does not include the surrounding quotes.
pub fn escape_powershell_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        if matches!(c, '`' | '$' | '"' | '\u{201C}' | '\u{201D}' | '\u{201E}') {
            escaped.push('`');
        }
        escaped.push(c);
    }
    escaped
}
12
13use serde::{Deserialize, Serialize};
18
/// Kind of a single pipeline element, named after the PowerShell AST node
/// type it stands in for.
///
/// The heuristic parser in this file only ever assigns `CommandAst` or
/// `CommandExpressionAst`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PipelineElementType {
    /// Mirrors PowerShell's `CommandAst` (a command invocation).
    CommandAst,
    /// Mirrors PowerShell's `CommandExpressionAst` (an expression used as a
    /// pipeline element).
    CommandExpressionAst,
    /// Mirrors PowerShell's `ParenExpressionAst`.
    /// NOTE(review): not produced by the parser in this file.
    ParenExpressionAst,
}
26
/// Heuristic classification of one token of a command (its name or an
/// argument), used to derive [`SecurityFlags`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CommandElementType {
    /// A `{ ... }` script block, or a token carrying one of its braces.
    ScriptBlock,
    /// A `$( ... )` / `@( ... )` sub-expression.
    SubExpression,
    /// A double-quoted string, in which PowerShell expands variables.
    ExpandableString,
    /// A member or method invocation such as `$obj.Method()`.
    MemberInvocation,
    /// A `$name` variable reference.
    Variable,
    /// A plain literal. The command-name slot is always given this type.
    StringConstant,
    /// A `-Name` style parameter.
    Parameter,
    /// Anything else. Not produced by the heuristic classifier in this file.
    Other,
}
39
/// Kind of a top-level statement, named after the PowerShell AST statement
/// node types.
///
/// In this module the value is guessed from keywords in the statement text,
/// so it is a hint rather than an exact parse result.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StatementType {
    /// A plain pipeline (the fallback when no keyword matches).
    PipelineAst,
    /// A pipeline chain joined by `&&` / `||`.
    PipelineChainAst,
    /// An assignment such as `$x = 1`.
    AssignmentStatementAst,
    IfStatementAst,
    ForStatementAst,
    ForEachStatementAst,
    WhileStatementAst,
    DoWhileStatementAst,
    DoUntilStatementAst,
    SwitchStatementAst,
    TryStatementAst,
    TrapStatementAst,
    FunctionDefinitionAst,
    DataStatementAst,
    UnknownStatementAst,
}
59
/// A nested element within a command element, tagged with its heuristic type.
///
/// NOTE(review): the parser in this file never populates children
/// (`ParsedCommandElement::children` is always `None`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CommandElementChild {
    /// Heuristic type of this child element.
    pub element_type: CommandElementType,
    /// Source text of this child element.
    pub text: String,
}
66
/// A stream redirection attached to a command or pipeline.
///
/// NOTE(review): the parser in this file never fills redirections in
/// (`ParsedCommandElement::redirections` is always `None`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedRedirection {
    /// Presumably the source stream (e.g. `2` in `2>&1`). TODO confirm.
    pub from: String,
    /// Presumably the target stream or file path. TODO confirm.
    pub to: String,
    /// Presumably `true` for stream merges such as `2>&1`. TODO confirm.
    pub is_merging: bool,
}
74
/// One element of a pipeline: a command name plus its whitespace-split
/// arguments, as produced by the heuristic parser.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedCommandElement {
    /// First whitespace-delimited token of the element (empty for an empty element).
    pub name: String,
    /// Coarse class of `name`: `"cmdlet"`, `"application"` or `"unknown"`.
    pub name_type: String,
    /// AST-style kind of this pipeline element.
    pub element_type: PipelineElementType,
    /// Remaining whitespace-delimited tokens after `name`. Quoted strings
    /// that contain spaces are split across several entries.
    pub args: Vec<String>,
    /// The element's original (trimmed) source text.
    pub text: String,
    /// One entry for the command name, followed by one entry per `args` item
    /// in order; `None` for an empty element.
    pub element_types: Option<Vec<CommandElementType>>,
    /// Nested elements. Always `None` from the parser in this file.
    pub children: Option<Vec<Option<Vec<CommandElementChild>>>>,
    /// Redirections on this element. Always `None` from the parser in this file.
    pub redirections: Option<Vec<ParsedRedirection>>,
}
87
/// A pipeline together with its redirections and any commands nested in it.
///
/// NOTE(review): not constructed by the parser in this file. It is
/// presumably filled in by a richer parser or by consumers; confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PipelineSegment {
    /// Commands in pipeline order.
    pub commands: Vec<ParsedCommandElement>,
    /// Redirections applied to the pipeline.
    pub redirections: Vec<ParsedRedirection>,
    /// Commands found nested inside this pipeline, if any were collected.
    pub nested_commands: Option<Vec<ParsedCommandElement>>,
}
95
/// One statement of the input and the pipeline elements parsed from it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedStatement {
    /// Keyword-based guess at the statement kind.
    pub statement_type: StatementType,
    /// Pipeline elements of the statement in source order (never empty when
    /// produced by `parse_powershell_command`).
    pub commands: Vec<ParsedCommandElement>,
}
102
/// Result of `parse_powershell_command`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedPowerShellCommand {
    /// `true` when at least one statement with commands was parsed.
    pub valid: bool,
    /// Parsed statements in source order.
    pub statements: Vec<ParsedStatement>,
    /// Failure description when parsing fails (e.g. `"Empty command"` for
    /// blank input); `None` on success.
    pub error: Option<String>,
}
110
111pub fn is_powershell_parameter(arg: &str, element_type: Option<&CommandElementType>) -> bool {
113 if let Some(et) = element_type {
114 return *et == CommandElementType::Parameter;
115 }
116 arg.starts_with('-') || arg.starts_with('/') ||
118 arg.starts_with('–') || arg.starts_with('—') || arg.starts_with('―')
119}
120
/// Characters accepted as a parameter prefix: the dash characters
/// PowerShell's tokenizer treats as `-` (U+002D hyphen-minus, U+2013 en dash,
/// U+2014 em dash, U+2015 horizontal bar), followed by `/` for Windows-style
/// switches.
pub const PS_TOKENIZER_DASH_CHARS: &[char] = &[
    '-',        // hyphen-minus
    '\u{2013}', // en dash
    '\u{2014}', // em dash
    '\u{2015}', // horizontal bar
    '/',        // Windows-style switch prefix
];
123
124pub fn parse_powershell_command(command: &str) -> ParsedPowerShellCommand {
126 let trimmed = command.trim();
127
128 if trimmed.is_empty() {
129 return ParsedPowerShellCommand {
130 valid: false,
131 statements: vec![],
132 error: Some("Empty command".to_string()),
133 };
134 }
135
136 let statement_strs: Vec<&str> = trimmed.split(|c| c == ';' || c == '\n')
138 .filter(|s| !s.trim().is_empty())
139 .collect();
140
141 let mut statements = Vec::new();
142
143 for stmt_str in statement_strs {
144 let statement_type = detect_statement_type(stmt_str);
145
146 let pipeline_strs: Vec<&str> = stmt_str.split('|').collect();
148 let mut commands = Vec::new();
149
150 for (idx, pipeline_str) in pipeline_strs.iter().enumerate() {
151 let pipeline_trimmed = pipeline_str.trim();
152 if pipeline_trimmed.is_empty() {
153 continue;
154 }
155
156 let parts: Vec<&str> = pipeline_trimmed
158 .split(|c| c == '&')
159 .filter(|s| !s.trim().is_empty())
160 .collect();
161
162 for part in parts {
163 let part_trimmed = part.trim();
164 if part_trimmed.is_empty() {
165 continue;
166 }
167
168 let cmd = parse_command_element(part_trimmed, idx == 0);
169 commands.push(cmd);
170 }
171 }
172
173 if !commands.is_empty() {
174 statements.push(ParsedStatement {
175 statement_type,
176 commands,
177 });
178 }
179 }
180
181 ParsedPowerShellCommand {
182 valid: !statements.is_empty(),
183 statements,
184 error: None,
185 }
186}
187
188fn detect_statement_type(cmd: &str) -> StatementType {
190 let lower = cmd.to_lowercase();
191
192 if lower.contains(" if ") || lower.starts_with("if ") {
193 StatementType::IfStatementAst
194 } else if lower.contains(" foreach ") || lower.starts_with("foreach ") || lower.contains("%{") {
195 StatementType::ForEachStatementAst
196 } else if lower.contains(" for ") || lower.starts_with("for ") {
197 StatementType::ForStatementAst
198 } else if lower.contains(" while ") || lower.starts_with("while ") {
199 StatementType::WhileStatementAst
200 } else if lower.contains(" do ") || lower.starts_with("do ") {
201 StatementType::DoWhileStatementAst
202 } else if lower.contains(" switch ") || lower.starts_with("switch ") {
203 StatementType::SwitchStatementAst
204 } else if lower.contains(" try ") || lower.starts_with("try ") {
205 StatementType::TryStatementAst
206 } else if lower.contains(" function ") || lower.starts_with("function ") {
207 StatementType::FunctionDefinitionAst
208 } else if lower.contains('=') && !lower.contains("==") {
209 StatementType::AssignmentStatementAst
210 } else {
211 StatementType::PipelineAst
212 }
213}
214
215fn parse_command_element(text: &str, is_first: bool) -> ParsedCommandElement {
217 let parts: Vec<&str> = text.split_whitespace().collect();
218
219 if parts.is_empty() {
220 return create_empty_command(text.to_string());
221 }
222
223 let name = parts[0].to_string();
224 let args: Vec<String> = parts[1..].iter().map(|s| s.to_string()).collect();
225 let name_type = classify_command_name(&name);
226 let element_type = if is_first {
227 PipelineElementType::CommandAst
228 } else {
229 PipelineElementType::CommandExpressionAst
230 };
231 let element_types = Some(determine_element_types(&args));
232
233 ParsedCommandElement {
234 name,
235 name_type,
236 element_type,
237 args,
238 text: text.to_string(),
239 element_types,
240 children: None,
241 redirections: None,
242 }
243}
244
245fn create_empty_command(text: String) -> ParsedCommandElement {
247 ParsedCommandElement {
248 name: String::new(),
249 name_type: "unknown".to_string(),
250 element_type: PipelineElementType::CommandAst,
251 args: vec![],
252 text,
253 element_types: None,
254 children: None,
255 redirections: None,
256 }
257}
258
/// Coarsely classifies a command name as `"cmdlet"`, `"application"` or
/// `"unknown"`. Matching is case-insensitive.
///
/// Rules, in order:
/// 1. Explicit filesystem locations are applications: anything containing
///    `/`, a drive path (`C:\...`), or a `.`/`..`/`~\` relative prefix.
/// 2. The *leaf* is the part after the last `\`. For a module-qualified
///    cmdlet such as `Microsoft.PowerShell.Management\Get-ChildItem`, the
///    leaf is the cmdlet itself.
/// 3. A leaf with a file extension (`my-tool.exe`, `run.ps1`) is an
///    application, even if the file name contains a dash.
/// 4. A leaf in `Verb-Noun` form (contains `-`) is a cmdlet.
/// 5. Any other backslash path is an application.
/// 6. A few well-known external tools are applications; everything else is
///    `"unknown"` (e.g. aliases such as `gc`).
fn classify_command_name(name: &str) -> String {
    const KNOWN_APPLICATIONS: [&str; 10] = [
        "git", "gh", "docker", "npm", "node", "python", "make", "tar", "curl", "wget",
    ];

    let lower = name.to_lowercase();

    let mut head = lower.chars();
    let is_drive_path = matches!(
        (head.next(), head.next(), head.next()),
        (Some(drive), Some(':'), Some('\\')) if drive.is_ascii_alphabetic()
    );
    if lower.contains('/') || is_drive_path || lower.starts_with('.') || lower.starts_with("~\\") {
        return "application".to_string();
    }

    let leaf = lower.rsplit('\\').next().unwrap_or(lower.as_str());
    if leaf.contains('.') {
        return "application".to_string();
    }
    if leaf.contains('-') {
        return "cmdlet".to_string();
    }
    if lower.contains('\\') {
        return "application".to_string();
    }

    if KNOWN_APPLICATIONS.contains(&lower.as_str()) {
        return "application".to_string();
    }

    "unknown".to_string()
}
281
282fn determine_element_types(args: &[String]) -> Vec<CommandElementType> {
284 let mut types = vec![CommandElementType::StringConstant];
285
286 for arg in args {
287 let et = classify_argument_element(arg);
288 types.push(et);
289 }
290
291 types
292}
293
294fn classify_argument_element(arg: &str) -> CommandElementType {
296 let trimmed = arg.trim();
297
298 if trimmed.starts_with('$') && trimmed.len() > 1 {
300 let second = trimmed.chars().nth(1);
301 if second == Some('(') || second == Some('@') {
303 return CommandElementType::SubExpression;
304 }
305 if second == Some('_') || second.is_some_and(|c| c.is_alphabetic()) {
307 return CommandElementType::Variable;
308 }
309 return CommandElementType::Variable;
310 }
311
312 if trimmed.starts_with('{') || trimmed.ends_with('}') ||
315 trimmed.contains("{ ") || trimmed.contains(" }") || trimmed.contains("{}") {
316 return CommandElementType::ScriptBlock;
317 }
318
319 if trimmed.starts_with("$(") || trimmed.starts_with("@(") || trimmed.contains("$(") || trimmed.contains("@(") {
321 return CommandElementType::SubExpression;
322 }
323
324 if trimmed.starts_with('"') && trimmed.ends_with('"') {
326 return CommandElementType::ExpandableString;
327 }
328
329 if trimmed.contains('.') && trimmed.contains('(') {
331 return CommandElementType::MemberInvocation;
332 }
333
334 if is_powershell_parameter(trimmed, None) {
336 return CommandElementType::Parameter;
337 }
338
339 CommandElementType::StringConstant
340}
341
342pub fn derive_security_flags(parsed: &ParsedPowerShellCommand) -> SecurityFlags {
344 let mut flags = SecurityFlags::default();
345
346 for statement in &parsed.statements {
347 for cmd in &statement.commands {
348 if let Some(ref types) = cmd.element_types {
350 for et in types {
351 match et {
352 CommandElementType::ScriptBlock => flags.has_script_blocks = true,
353 CommandElementType::SubExpression => flags.has_sub_expressions = true,
354 CommandElementType::ExpandableString => flags.has_expandable_strings = true,
355 CommandElementType::MemberInvocation => flags.has_member_invocations = true,
356 CommandElementType::Variable => flags.has_variables = true,
357 _ => {}
358 }
359 }
360 }
361
362 for arg in &cmd.args {
364 if arg.starts_with('$') && arg.len() > 1 {
366 let second = arg.chars().nth(1);
367 if second == Some('(') || second == Some('@') {
369 flags.has_sub_expressions = true;
370 } else {
371 flags.has_variables = true;
372 }
373 }
374 if arg.contains('{') || arg.contains('}') {
376 flags.has_script_blocks = true;
377 }
378 if arg.contains("$(") || arg.contains("@(") {
380 flags.has_sub_expressions = true;
381 }
382 if arg.starts_with('"') && arg.ends_with('"') {
384 flags.has_expandable_strings = true;
385 }
386 if arg.contains('=') && !arg.starts_with('-') {
388 flags.has_assignments = true;
389 }
390 }
391
392 let text = &cmd.text;
395 if text.starts_with('$') && text.len() > 1 && !text.contains(' ') {
396 flags.has_variables = true;
398 }
399 if text.contains('{') || text.contains('}') {
401 flags.has_script_blocks = true;
402 }
403 if text.contains("$(") || text.contains("@(") {
405 flags.has_sub_expressions = true;
406 }
407 }
408 }
409
410 flags
411}
412
/// Coarse risk indicators computed by `derive_security_flags` from a
/// [`ParsedPowerShellCommand`].
///
/// Every field starts `false` (via `Default`) and is only ever set to `true`.
#[derive(Debug, Clone, Default)]
pub struct SecurityFlags {
    /// A `{` or `}` (script block) was seen.
    pub has_script_blocks: bool,
    /// A `$( ... )` / `@( ... )` sub-expression (or a `$@` token) was seen.
    pub has_sub_expressions: bool,
    /// A double-quoted (variable-expanding) string was seen.
    pub has_expandable_strings: bool,
    /// A member or method invocation (e.g. `$obj.Method()`) was seen.
    pub has_member_invocations: bool,
    /// NOTE(review): presumably marks `@var` splatting. Confirm which
    /// detector sets it.
    pub has_splatting: bool,
    /// An argument containing `=` that is not a `-Param` token was seen.
    pub has_assignments: bool,
    /// NOTE(review): presumably marks the stop-parsing token `--%`. Confirm
    /// which detector sets it.
    pub has_stop_parsing: bool,
    /// A `$name` variable reference was seen.
    pub has_variables: bool,
}
425
/// Builds (but does not spawn) a [`Command`] that runs `script` with the
/// `pwsh` executable.
///
/// `-NoProfile` skips profile scripts and `-NonInteractive` prevents
/// prompts. `script` is passed verbatim as the single argument after `-Command`.
pub fn build_powershell_command(script: &str) -> Command {
    const LEADING_ARGS: [&str; 3] = ["-NoProfile", "-NonInteractive", "-Command"];

    let mut pwsh = Command::new("pwsh");
    pwsh.args(LEADING_ARGS).arg(script);
    pwsh
}
436
437pub fn build_powershell_command_utf8(script: &str) -> Command {
439 let full_script = format!(
440 "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; {}",
441 script
442 );
443 build_powershell_command(&full_script)
444}
445
/// Reports whether `pwsh --version` can be launched and exits successfully.
///
/// The version output is captured, not echoed. Any failure to spawn the
/// process (for example, `pwsh` is not found) yields `false`.
pub fn is_powershell_available() -> bool {
    match Command::new("pwsh").arg("--version").output() {
        Ok(output) => output.status.success(),
        Err(_) => false,
    }
}
454
/// Returns the output of `pwsh --version` with surrounding whitespace
/// removed, including the trailing newline the executable prints.
///
/// Returns `None` in any of these cases:
/// - `pwsh` cannot be spawned;
/// - it exits unsuccessfully;
/// - its output is not valid UTF-8;
/// - its output is empty or whitespace-only.
pub fn get_powershell_version() -> Option<String> {
    let output = Command::new("pwsh").arg("--version").output().ok()?;
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8(output.stdout).ok()?;
    let version = stdout.trim();
    (!version.is_empty()).then(|| version.to_string())
}
469
// Unit tests for the heuristic PowerShell parser, the name/argument
// classifiers and `derive_security_flags`.
#[cfg(test)]
mod tests {
    use super::*;

    // One cmdlet with one argument parses into a single statement whose first
    // command carries the cmdlet name.
    #[test]
    fn test_parse_simple_command() {
        let result = parse_powershell_command("Get-Content file.txt");
        assert!(result.valid);
        assert_eq!(result.statements.len(), 1);
        assert_eq!(result.statements[0].commands[0].name, "Get-Content");
    }

    // Parameters are kept verbatim (dash included) in `args`.
    #[test]
    fn test_parse_command_with_args() {
        let result = parse_powershell_command("Remove-Item -Path test.txt -Recurse -Force");
        assert!(result.valid);
        let cmd = &result.statements[0].commands[0];
        assert_eq!(cmd.name, "Remove-Item");
        assert!(cmd.args.contains(&"-Path".to_string()));
    }

    // `|` splits one statement into multiple pipeline elements.
    #[test]
    fn test_parse_pipeline() {
        let result = parse_powershell_command("Get-Content file.txt | Select-String pattern");
        assert!(result.valid);
        assert_eq!(result.statements[0].commands.len(), 2);
    }

    // `;` separates statements.
    #[test]
    fn test_parse_compound_statements() {
        let result = parse_powershell_command("$var = 1; Get-Content file.txt");
        assert!(result.valid);
        assert_eq!(result.statements.len(), 2);
    }

    // A `$name` argument is classified as a Variable element.
    #[test]
    fn test_detect_variables() {
        let result = parse_powershell_command("Write-Host $env:SECRET");
        assert!(result.valid);
        let types = &result.statements[0].commands[0].element_types;
        assert!(types.as_ref().map(|t| t.iter().any(|et| *et == CommandElementType::Variable)).unwrap_or(false));
    }

    // Brace tokens are classified as ScriptBlock elements.
    #[test]
    fn test_detect_script_blocks() {
        let result = parse_powershell_command("Where-Object { $_.Name }");
        assert!(result.valid);
        let types = &result.statements[0].commands[0].element_types;
        assert!(types.as_ref().map(|t| t.iter().any(|et| *et == CommandElementType::ScriptBlock)).unwrap_or(false));
    }

    // `$( ... )` arguments are classified as SubExpression elements.
    #[test]
    fn test_detect_subexpression() {
        let result = parse_powershell_command("Invoke-Expression $(malicious)");
        assert!(result.valid);
        let types = &result.statements[0].commands[0].element_types;
        assert!(types.as_ref().map(|t| t.iter().any(|et| *et == CommandElementType::SubExpression)).unwrap_or(false));
    }

    // Verb-Noun names are cmdlets.
    #[test]
    fn test_classify_cmdlet() {
        assert_eq!(classify_command_name("Get-Content"), "cmdlet");
        assert_eq!(classify_command_name("Remove-Item"), "cmdlet");
    }

    // Known external tools and path-like names are applications.
    #[test]
    fn test_classify_application() {
        assert_eq!(classify_command_name("git"), "application");
        assert_eq!(classify_command_name("./script.ps1"), "application");
    }

    // Without an element type, the leading character decides; `/` counts too.
    #[test]
    fn test_is_powershell_parameter() {
        assert!(is_powershell_parameter("-Path", None));
        assert!(is_powershell_parameter("-Recurse", None));
        assert!(is_powershell_parameter("/C", None));
        assert!(!is_powershell_parameter("file.txt", None));
    }

    // A bare variable as the first pipeline element sets `has_variables`.
    #[test]
    fn test_derive_security_flags_variables() {
        let parsed = parse_powershell_command("$env:SECRET | Write-Host");
        let flags = derive_security_flags(&parsed);
        assert!(flags.has_variables);
    }

    #[test]
    fn test_derive_security_flags_script_blocks() {
        let parsed = parse_powershell_command("Get-Process | Where-Object { $_.CPU }");
        let flags = derive_security_flags(&parsed);
        assert!(flags.has_script_blocks);
    }

    #[test]
    fn test_derive_security_flags_subexpression() {
        let parsed = parse_powershell_command("Invoke-Expression $(malicious)");
        let flags = derive_security_flags(&parsed);
        assert!(flags.has_sub_expressions);
    }

    // The standalone `=` argument marks an assignment.
    #[test]
    fn test_derive_security_flags_assignment() {
        let parsed = parse_powershell_command("$result = Get-Content file.txt");
        let flags = derive_security_flags(&parsed);
        assert!(flags.has_assignments);
    }

    // Blank input is rejected.
    #[test]
    fn test_empty_command() {
        let result = parse_powershell_command("");
        assert!(!result.valid);
    }

    // A method call as the whole command still parses as valid.
    #[test]
    fn test_member_invocation() {
        let result = parse_powershell_command("$obj.Method()");
        assert!(result.valid);
    }

    // Aliases are kept as-is; no alias resolution is performed.
    #[test]
    fn test_parse_alias() {
        let result = parse_powershell_command("gc file.txt");
        assert!(result.valid);
        assert_eq!(result.statements[0].commands[0].name, "gc");
    }
}