Skip to main content

ai_agent/utils/
powershell.rs

1//! PowerShell execution utilities.
2//!
3//! This module provides PowerShell command parsing, execution, and security analysis.
4
5use std::process::Command;
6
/// Escape a string so it can be embedded inside a double-quoted PowerShell
/// string literal.
///
/// A backtick is inserted in front of every character that would otherwise be
/// interpreted inside `"..."`:
///
/// * the backtick itself (PowerShell's escape character),
/// * `$` (variable / subexpression expansion),
/// * `"` and the typographic double quotes U+201C, U+201D and U+201E, which
///   the PowerShell tokenizer accepts as string delimiters just like `"`.
///
/// The surrounding quotes are NOT added; the caller wraps the result.
pub fn escape_powershell_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());

    for ch in s.chars() {
        // Single pass: each special character gets exactly one backtick, so
        // backticks inserted here are never escaped a second time.
        if matches!(
            ch,
            '`' | '$' | '"' | '\u{201C}' | '\u{201D}' | '\u{201E}'
        ) {
            escaped.push('`');
        }
        escaped.push(ch);
    }

    escaped
}
12
13// ============================================================================
14// PowerShell Parser - AST types and parsing functions
15// ============================================================================
16
17use serde::{Deserialize, Serialize};
18
/// Kind of element occupying one position in a pipeline.
///
/// NOTE(review): the variant names look like PowerShell AST class names;
/// confirm against the producer of this data before relying on that.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PipelineElementType {
    /// A command invocation.
    CommandAst,
    /// An expression used as a pipeline element.
    CommandExpressionAst,
    /// A parenthesized expression.
    ParenExpressionAst,
}
26
/// Command element type (AST node classification).
///
/// One value is recorded per element of a command: the command name first,
/// followed by each argument.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CommandElementType {
    /// Script block, e.g. `{ ... }`.
    ScriptBlock,
    /// Subexpression, e.g. `$(...)` or `@(...)`.
    SubExpression,
    /// Double-quoted string.
    ExpandableString,
    /// Member invocation such as `.Method()`.
    MemberInvocation,
    /// Variable reference such as `$var`.
    Variable,
    /// Plain (bareword) string; also used for the command name itself.
    StringConstant,
    /// Command parameter such as `-Path`.
    Parameter,
    /// Anything not covered by the other variants.
    Other,
}
39
/// Statement type.
///
/// NOTE(review): the variant names look like PowerShell AST statement class
/// names; confirm against the producer of this data.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StatementType {
    /// Plain pipeline of one or more commands.
    PipelineAst,
    /// Presumably pipelines joined by chain operators; verify against caller.
    PipelineChainAst,
    /// Assignment such as `$x = ...`.
    AssignmentStatementAst,
    /// `if` statement.
    IfStatementAst,
    /// `for` loop.
    ForStatementAst,
    /// `foreach` loop.
    ForEachStatementAst,
    /// `while` loop.
    WhileStatementAst,
    /// `do ... while` loop.
    DoWhileStatementAst,
    /// `do ... until` loop.
    DoUntilStatementAst,
    /// `switch` statement.
    SwitchStatementAst,
    /// `try` statement.
    TryStatementAst,
    /// `trap` statement.
    TrapStatementAst,
    /// Function definition.
    FunctionDefinitionAst,
    /// `data` statement.
    DataStatementAst,
    /// Statement of an unrecognized kind.
    UnknownStatementAst,
}
59
/// A child node of a command element.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CommandElementChild {
    /// Classification of this child node.
    pub element_type: CommandElementType,
    /// Source text of this child node.
    pub text: String,
}
66
/// A redirection attached to a command or pipeline.
///
/// NOTE(review): nothing in the visible part of this file constructs this
/// type, so the field descriptions below are assumptions to be confirmed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedRedirection {
    /// Presumably the stream being redirected; verify against producer.
    pub from: String,
    /// Presumably the redirection target; verify against producer.
    pub to: String,
    /// Presumably true for stream-merging redirections; verify against producer.
    pub is_merging: bool,
}
74
/// A command invocation within a pipeline.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedCommandElement {
    /// Command name: the first whitespace-separated token of `text`
    /// (empty for an empty command).
    pub name: String,
    /// Name classification string: `"cmdlet"`, `"application"` or `"unknown"`.
    pub name_type: String,
    /// Kind of pipeline element this command represents.
    pub element_type: PipelineElementType,
    /// Remaining whitespace-separated tokens after the name.
    pub args: Vec<String>,
    /// Source text this element was parsed from.
    pub text: String,
    /// Per-element classification: entry 0 describes the command name, the
    /// following entries describe `args` in order. `None` for an empty command.
    pub element_types: Option<Vec<CommandElementType>>,
    /// Optional child nodes per element; the parser in this file always
    /// leaves this as `None`.
    pub children: Option<Vec<Option<Vec<CommandElementChild>>>>,
    /// Optional redirections; the parser in this file always leaves this as
    /// `None`.
    pub redirections: Option<Vec<ParsedRedirection>>,
}
87
/// Pipeline segment.
///
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PipelineSegment {
    /// Commands that make up this segment.
    pub commands: Vec<ParsedCommandElement>,
    /// Redirections attached to this segment.
    pub redirections: Vec<ParsedRedirection>,
    /// Presumably commands found nested inside the segment (e.g. in script
    /// blocks); verify against producer.
    pub nested_commands: Option<Vec<ParsedCommandElement>>,
}
95
/// A statement in the PowerShell command.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedStatement {
    /// Statement kind inferred for this statement.
    pub statement_type: StatementType,
    /// Every command found in the statement, in source order.
    pub commands: Vec<ParsedCommandElement>,
}
102
/// Complete parsed PowerShell command.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParsedPowerShellCommand {
    /// True when at least one statement was parsed.
    pub valid: bool,
    /// Parsed statements, in source order.
    pub statements: Vec<ParsedStatement>,
    /// Human-readable reason the command could not be parsed, when available
    /// (e.g. `"Empty command"` for blank input).
    pub error: Option<String>,
}
110
111/// Check if a string is a PowerShell parameter
112pub fn is_powershell_parameter(arg: &str, element_type: Option<&CommandElementType>) -> bool {
113    if let Some(et) = element_type {
114        return *et == CommandElementType::Parameter;
115    }
116    // Check for common parameter prefixes
117    arg.starts_with('-') || arg.starts_with('/') ||
118    arg.starts_with('–') || arg.starts_with('—') || arg.starts_with('―')
119}
120
/// Characters accepted as a parameter prefix: ASCII hyphen, en dash (U+2013),
/// em dash (U+2014), horizontal bar (U+2015) and, despite the name, the
/// forward slash. This is the same set `is_powershell_parameter` checks when
/// no element type is given.
pub const PS_TOKENIZER_DASH_CHARS: &[char] = &['-', '–', '—', '―', '/'];
123
124/// Parse a PowerShell command string into components
125pub fn parse_powershell_command(command: &str) -> ParsedPowerShellCommand {
126    let trimmed = command.trim();
127
128    if trimmed.is_empty() {
129        return ParsedPowerShellCommand {
130            valid: false,
131            statements: vec![],
132            error: Some("Empty command".to_string()),
133        };
134    }
135
136    // Split by statement separators: ; and newlines
137    let statement_strs: Vec<&str> = trimmed.split(|c| c == ';' || c == '\n')
138        .filter(|s| !s.trim().is_empty())
139        .collect();
140
141    let mut statements = Vec::new();
142
143    for stmt_str in statement_strs {
144        let statement_type = detect_statement_type(stmt_str);
145
146        // Parse pipeline elements (split by |)
147        let pipeline_strs: Vec<&str> = stmt_str.split('|').collect();
148        let mut commands = Vec::new();
149
150        for (idx, pipeline_str) in pipeline_strs.iter().enumerate() {
151            let pipeline_trimmed = pipeline_str.trim();
152            if pipeline_trimmed.is_empty() {
153                continue;
154            }
155
156            // Split by operators: &, && (but NOT | since we already split by | for pipeline)
157            let parts: Vec<&str> = pipeline_trimmed
158                .split(|c| c == '&')
159                .filter(|s| !s.trim().is_empty())
160                .collect();
161
162            for part in parts {
163                let part_trimmed = part.trim();
164                if part_trimmed.is_empty() {
165                    continue;
166                }
167
168                let cmd = parse_command_element(part_trimmed, idx == 0);
169                commands.push(cmd);
170            }
171        }
172
173        if !commands.is_empty() {
174            statements.push(ParsedStatement {
175                statement_type,
176                commands,
177            });
178        }
179    }
180
181    ParsedPowerShellCommand {
182        valid: !statements.is_empty(),
183        statements,
184        error: None,
185    }
186}
187
188/// Detect statement type from keywords
189fn detect_statement_type(cmd: &str) -> StatementType {
190    let lower = cmd.to_lowercase();
191
192    if lower.contains(" if ") || lower.starts_with("if ") {
193        StatementType::IfStatementAst
194    } else if lower.contains(" foreach ") || lower.starts_with("foreach ") || lower.contains("%{") {
195        StatementType::ForEachStatementAst
196    } else if lower.contains(" for ") || lower.starts_with("for ") {
197        StatementType::ForStatementAst
198    } else if lower.contains(" while ") || lower.starts_with("while ") {
199        StatementType::WhileStatementAst
200    } else if lower.contains(" do ") || lower.starts_with("do ") {
201        StatementType::DoWhileStatementAst
202    } else if lower.contains(" switch ") || lower.starts_with("switch ") {
203        StatementType::SwitchStatementAst
204    } else if lower.contains(" try ") || lower.starts_with("try ") {
205        StatementType::TryStatementAst
206    } else if lower.contains(" function ") || lower.starts_with("function ") {
207        StatementType::FunctionDefinitionAst
208    } else if lower.contains('=') && !lower.contains("==") {
209        StatementType::AssignmentStatementAst
210    } else {
211        StatementType::PipelineAst
212    }
213}
214
215/// Parse a single command element
216fn parse_command_element(text: &str, is_first: bool) -> ParsedCommandElement {
217    let parts: Vec<&str> = text.split_whitespace().collect();
218
219    if parts.is_empty() {
220        return create_empty_command(text.to_string());
221    }
222
223    let name = parts[0].to_string();
224    let args: Vec<String> = parts[1..].iter().map(|s| s.to_string()).collect();
225    let name_type = classify_command_name(&name);
226    let element_type = if is_first {
227        PipelineElementType::CommandAst
228    } else {
229        PipelineElementType::CommandExpressionAst
230    };
231    let element_types = Some(determine_element_types(&args));
232
233    ParsedCommandElement {
234        name,
235        name_type,
236        element_type,
237        args,
238        text: text.to_string(),
239        element_types,
240        children: None,
241        redirections: None,
242    }
243}
244
245/// Create an empty command
246fn create_empty_command(text: String) -> ParsedCommandElement {
247    ParsedCommandElement {
248        name: String::new(),
249        name_type: "unknown".to_string(),
250        element_type: PipelineElementType::CommandAst,
251        args: vec![],
252        text,
253        element_types: None,
254        children: None,
255        redirections: None,
256    }
257}
258
/// Classify a command name as `"cmdlet"`, `"application"` or `"unknown"`.
///
/// Rules, applied in order (matching is case-insensitive):
///
/// 1. A name that is explicitly a path is an application, even when it
///    contains a hyphen. A name is an explicit path when it contains `/`,
///    starts with `.` or `\`, or starts with a drive root such as `C:\`.
///    So `./my-script.ps1` is an application, not a cmdlet.
/// 2. Any other name containing `-` is treated as a Verb-Noun cmdlet. This
///    keeps module-qualified names such as `Module\Get-Thing` as cmdlets.
/// 3. A name containing `\` or `.` is an application.
/// 4. A name in the built-in list of common external tools is an application.
/// 5. Everything else is unknown.
fn classify_command_name(name: &str) -> String {
    // Common external commands that have neither a path nor an extension.
    const EXTERNAL_COMMANDS: [&str; 10] = [
        "git", "gh", "docker", "npm", "node", "python", "make", "tar", "curl", "wget",
    ];

    let lower = name.to_lowercase();

    // `X:\` or `X:/` at the very start of the name.
    let mut leading = lower.chars();
    let drive_rooted = matches!(
        (leading.next(), leading.next(), leading.next()),
        (Some(drive), Some(':'), Some('\\' | '/')) if drive.is_ascii_alphabetic()
    );

    let is_explicit_path = lower.contains('/')
        || lower.starts_with(|c: char| c == '.' || c == '\\')
        || drive_rooted;

    // Paths must be recognised before the hyphen test, otherwise file names
    // like `my-script.ps1` are mistaken for Verb-Noun cmdlets.
    if is_explicit_path {
        return "application".to_string();
    }

    if lower.contains('-') {
        return "cmdlet".to_string();
    }

    if lower.contains('\\') || lower.contains('.') {
        return "application".to_string();
    }

    if EXTERNAL_COMMANDS.contains(&lower.as_str()) {
        return "application".to_string();
    }

    "unknown".to_string()
}
281
282/// Determine element types for arguments
283fn determine_element_types(args: &[String]) -> Vec<CommandElementType> {
284    let mut types = vec![CommandElementType::StringConstant];
285
286    for arg in args {
287        let et = classify_argument_element(arg);
288        types.push(et);
289    }
290
291    types
292}
293
294/// Classify an argument's element type
295fn classify_argument_element(arg: &str) -> CommandElementType {
296    let trimmed = arg.trim();
297
298    // Check for variable $var (includes $_.Name, $env:VAR, etc.)
299    if trimmed.starts_with('$') && trimmed.len() > 1 {
300        let second = trimmed.chars().nth(1);
301        // $() is subexpression, $var is variable
302        if second == Some('(') || second == Some('@') {
303            return CommandElementType::SubExpression;
304        }
305        // Check if it's $_.Property (member access with variable)
306        if second == Some('_') || second.is_some_and(|c| c.is_alphabetic()) {
307            return CommandElementType::Variable;
308        }
309        return CommandElementType::Variable;
310    }
311
312    // Check for script block {} - exact match OR contains script block content
313    // Handles: {}, { code }, { $_.Name }, and partial tokens like { and }
314    if trimmed.starts_with('{') || trimmed.ends_with('}') ||
315       trimmed.contains("{ ") || trimmed.contains(" }") || trimmed.contains("{}") {
316        return CommandElementType::ScriptBlock;
317    }
318
319    // Check for subexpression $() - exact match OR contains it
320    if trimmed.starts_with("$(") || trimmed.starts_with("@(") || trimmed.contains("$(") || trimmed.contains("@(") {
321        return CommandElementType::SubExpression;
322    }
323
324    // Check for expandable string
325    if trimmed.starts_with('"') && trimmed.ends_with('"') {
326        return CommandElementType::ExpandableString;
327    }
328
329    // Check for member invocation .Method()
330    if trimmed.contains('.') && trimmed.contains('(') {
331        return CommandElementType::MemberInvocation;
332    }
333
334    // Check for parameter
335    if is_powershell_parameter(trimmed, None) {
336        return CommandElementType::Parameter;
337    }
338
339    CommandElementType::StringConstant
340}
341
342/// Derive security flags from parsed command
343pub fn derive_security_flags(parsed: &ParsedPowerShellCommand) -> SecurityFlags {
344    let mut flags = SecurityFlags::default();
345
346    for statement in &parsed.statements {
347        for cmd in &statement.commands {
348            // Check element_types first
349            if let Some(ref types) = cmd.element_types {
350                for et in types {
351                    match et {
352                        CommandElementType::ScriptBlock => flags.has_script_blocks = true,
353                        CommandElementType::SubExpression => flags.has_sub_expressions = true,
354                        CommandElementType::ExpandableString => flags.has_expandable_strings = true,
355                        CommandElementType::MemberInvocation => flags.has_member_invocations = true,
356                        CommandElementType::Variable => flags.has_variables = true,
357                        _ => {}
358                    }
359                }
360            }
361
362            // Also check raw args for variables and script blocks (handles edge cases)
363            for arg in &cmd.args {
364                // Check for variables: $var, $env:VAR, $_.prop
365                if arg.starts_with('$') && arg.len() > 1 {
366                    let second = arg.chars().nth(1);
367                    // $(), @() are subexpressions
368                    if second == Some('(') || second == Some('@') {
369                        flags.has_sub_expressions = true;
370                    } else {
371                        flags.has_variables = true;
372                    }
373                }
374                // Check for script blocks
375                if arg.contains('{') || arg.contains('}') {
376                    flags.has_script_blocks = true;
377                }
378                // Check for subexpressions
379                if arg.contains("$(") || arg.contains("@(") {
380                    flags.has_sub_expressions = true;
381                }
382                // Check for expandable strings
383                if arg.starts_with('"') && arg.ends_with('"') {
384                    flags.has_expandable_strings = true;
385                }
386                // Check for assignments
387                if arg.contains('=') && !arg.starts_with('-') {
388                    flags.has_assignments = true;
389                }
390            }
391
392            // Also check cmd.text (the full command text) for edge cases
393            // e.g., "$env:SECRET" becomes a command with name="$env:SECRET", args=[]
394            let text = &cmd.text;
395            if text.starts_with('$') && text.len() > 1 && !text.contains(' ') {
396                // Single token like $env:SECRET - it's a variable reference
397                flags.has_variables = true;
398            }
399            // Check for script blocks in text
400            if text.contains('{') || text.contains('}') {
401                flags.has_script_blocks = true;
402            }
403            // Check for subexpressions in text
404            if text.contains("$(") || text.contains("@(") {
405                flags.has_sub_expressions = true;
406            }
407        }
408    }
409
410    flags
411}
412
/// Security flags derived from parsing.
///
/// Every flag defaults to `false`; `derive_security_flags` only ever turns
/// flags on.
#[derive(Debug, Clone, Default)]
pub struct SecurityFlags {
    /// A script block (`{` / `}`) was seen.
    pub has_script_blocks: bool,
    /// A subexpression (`$(` or `@(`) was seen.
    pub has_sub_expressions: bool,
    /// A double-quoted string was seen.
    pub has_expandable_strings: bool,
    /// An element classified as a member invocation was seen.
    pub has_member_invocations: bool,
    /// Intended to signal splatting (`@name`) — NOTE(review): confirm which
    /// code path sets this.
    pub has_splatting: bool,
    /// An argument containing `=` that does not start with `-` was seen.
    pub has_assignments: bool,
    /// Intended to signal the stop-parsing token — NOTE(review): confirm
    /// which code path sets this.
    pub has_stop_parsing: bool,
    /// A `$`-prefixed variable reference was seen.
    pub has_variables: bool,
}
425
426// ============================================================================
427// Shell execution functions
428// ============================================================================
429
/// Build a `pwsh` invocation that runs `script`.
///
/// The command is configured as
/// `pwsh -NoProfile -NonInteractive -Command <script>` and returned without
/// being spawned. No output encoding is configured here; see
/// `build_powershell_command_utf8` for that.
pub fn build_powershell_command(script: &str) -> Command {
    let mut pwsh = Command::new("pwsh");
    pwsh.arg("-NoProfile")
        .arg("-NonInteractive")
        .arg("-Command")
        .arg(script);
    pwsh
}
436
437/// Build a PowerShell command with UTF-8 output encoding
438pub fn build_powershell_command_utf8(script: &str) -> Command {
439    let full_script = format!(
440        "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; {}",
441        script
442    );
443    build_powershell_command(&full_script)
444}
445
/// Report whether `pwsh` can be run on this machine.
///
/// Runs `pwsh --version` and returns true only when the process could be
/// started and exited successfully. Any failure to launch counts as
/// "not available".
pub fn is_powershell_available() -> bool {
    let probe = Command::new("pwsh").arg("--version").output();
    matches!(probe, Ok(result) if result.status.success())
}
454
/// Get the PowerShell version string reported by `pwsh --version`.
///
/// # Returns
///
/// The trimmed standard output of `pwsh --version`, or `None` when the
/// process cannot be started, exits unsuccessfully, prints non-UTF-8 output,
/// or prints nothing but whitespace.
pub fn get_powershell_version() -> Option<String> {
    let output = Command::new("pwsh").arg("--version").output().ok()?;

    if !output.status.success() {
        return None;
    }

    let stdout = String::from_utf8(output.stdout).ok()?;

    // The process terminates its output with a newline; callers want the
    // bare version text.
    let version = stdout.trim();
    (!version.is_empty()).then(|| version.to_string())
}
469
// Unit tests for the textual parser, the name/argument classifiers and the
// security-flag derivation. None of them spawns a PowerShell process.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Parsing: statements, pipelines and arguments ---

    #[test]
    fn test_parse_simple_command() {
        let result = parse_powershell_command("Get-Content file.txt");
        assert!(result.valid);
        assert_eq!(result.statements.len(), 1);
        assert_eq!(result.statements[0].commands[0].name, "Get-Content");
    }

    #[test]
    fn test_parse_command_with_args() {
        let result = parse_powershell_command("Remove-Item -Path test.txt -Recurse -Force");
        assert!(result.valid);
        let cmd = &result.statements[0].commands[0];
        assert_eq!(cmd.name, "Remove-Item");
        assert!(cmd.args.contains(&"-Path".to_string()));
    }

    // A `|` yields one command per pipeline element within a single statement.
    #[test]
    fn test_parse_pipeline() {
        let result = parse_powershell_command("Get-Content file.txt | Select-String pattern");
        assert!(result.valid);
        assert_eq!(result.statements[0].commands.len(), 2);
    }

    // A `;` separates statements.
    #[test]
    fn test_parse_compound_statements() {
        let result = parse_powershell_command("$var = 1; Get-Content file.txt");
        assert!(result.valid);
        assert_eq!(result.statements.len(), 2);
    }

    // --- Element-type classification of arguments ---

    #[test]
    fn test_detect_variables() {
        let result = parse_powershell_command("Write-Host $env:SECRET");
        assert!(result.valid);
        let types = &result.statements[0].commands[0].element_types;
        assert!(types.as_ref().map(|t| t.iter().any(|et| *et == CommandElementType::Variable)).unwrap_or(false));
    }

    // Whitespace tokenization splits the block into `{`, `$_.Name` and `}`;
    // the lone braces are what get classified as script blocks.
    #[test]
    fn test_detect_script_blocks() {
        let result = parse_powershell_command("Where-Object { $_.Name }");
        assert!(result.valid);
        let types = &result.statements[0].commands[0].element_types;
        assert!(types.as_ref().map(|t| t.iter().any(|et| *et == CommandElementType::ScriptBlock)).unwrap_or(false));
    }

    #[test]
    fn test_detect_subexpression() {
        let result = parse_powershell_command("Invoke-Expression $(malicious)");
        assert!(result.valid);
        let types = &result.statements[0].commands[0].element_types;
        assert!(types.as_ref().map(|t| t.iter().any(|et| *et == CommandElementType::SubExpression)).unwrap_or(false));
    }

    // --- Command-name classification ---

    #[test]
    fn test_classify_cmdlet() {
        assert_eq!(classify_command_name("Get-Content"), "cmdlet");
        assert_eq!(classify_command_name("Remove-Item"), "cmdlet");
    }

    #[test]
    fn test_classify_application() {
        assert_eq!(classify_command_name("git"), "application");
        assert_eq!(classify_command_name("./script.ps1"), "application");
    }

    #[test]
    fn test_is_powershell_parameter() {
        assert!(is_powershell_parameter("-Path", None));
        assert!(is_powershell_parameter("-Recurse", None));
        assert!(is_powershell_parameter("/C", None));
        assert!(!is_powershell_parameter("file.txt", None));
    }

    // --- Security flags ---

    // The variable is the command name here, so it is found via the command
    // text rather than via the arguments.
    #[test]
    fn test_derive_security_flags_variables() {
        let parsed = parse_powershell_command("$env:SECRET | Write-Host");
        let flags = derive_security_flags(&parsed);
        assert!(flags.has_variables);
    }

    #[test]
    fn test_derive_security_flags_script_blocks() {
        let parsed = parse_powershell_command("Get-Process | Where-Object { $_.CPU }");
        let flags = derive_security_flags(&parsed);
        assert!(flags.has_script_blocks);
    }

    #[test]
    fn test_derive_security_flags_subexpression() {
        let parsed = parse_powershell_command("Invoke-Expression $(malicious)");
        let flags = derive_security_flags(&parsed);
        assert!(flags.has_sub_expressions);
    }

    // The `=` token ends up as an argument of the `$result` command.
    #[test]
    fn test_derive_security_flags_assignment() {
        let parsed = parse_powershell_command("$result = Get-Content file.txt");
        let flags = derive_security_flags(&parsed);
        assert!(flags.has_assignments);
    }

    // --- Edge cases ---

    #[test]
    fn test_empty_command() {
        let result = parse_powershell_command("");
        assert!(!result.valid);
    }

    // Only checks that the input parses; the classification is not asserted.
    #[test]
    fn test_member_invocation() {
        let result = parse_powershell_command("$obj.Method()");
        assert!(result.valid);
    }

    #[test]
    fn test_parse_alias() {
        let result = parse_powershell_command("gc file.txt");
        assert!(result.valid);
        assert_eq!(result.statements[0].commands[0].name, "gc");
    }
}