Skip to main content

ai_agent/utils/
powershell.rs

1//! PowerShell execution utilities.
2//!
3//! This module provides PowerShell command parsing, execution, and security analysis.
4
5use std::process::Command;
6
/// Escape `s` for embedding inside a PowerShell double-quoted (expandable) string.
///
/// Each character that is special inside a double-quoted string gets the
/// PowerShell escape character (backtick) prepended:
/// - `` ` `` — the escape character itself,
/// - `$` — variable / sub-expression expansion,
/// - `"` and the typographic double quotes `“` (U+201C), `”` (U+201D) and
///   `„` (U+201E), which PowerShell also accepts as string delimiters. If they
///   were left unescaped they could terminate the string early.
///
/// Not suitable for single-quoted strings, where backticks are literal.
pub fn escape_powershell_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        if matches!(c, '`' | '"' | '$' | '\u{201C}' | '\u{201D}' | '\u{201E}') {
            escaped.push('`');
        }
        escaped.push(c);
    }
    escaped
}
12
13// ============================================================================
14// PowerShell Parser - AST types and parsing functions
15// ============================================================================
16
17use serde::{Deserialize, Serialize};
18
19/// Pipeline element type
20#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub enum PipelineElementType {
22    CommandAst,
23    CommandExpressionAst,
24    ParenExpressionAst,
25}
26
27/// Command element type (AST node classification)
28#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
29pub enum CommandElementType {
30    ScriptBlock,
31    SubExpression,
32    ExpandableString,
33    MemberInvocation,
34    Variable,
35    StringConstant,
36    Parameter,
37    Other,
38}
39
40/// Statement type
41#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
42pub enum StatementType {
43    PipelineAst,
44    PipelineChainAst,
45    AssignmentStatementAst,
46    IfStatementAst,
47    ForStatementAst,
48    ForEachStatementAst,
49    WhileStatementAst,
50    DoWhileStatementAst,
51    DoUntilStatementAst,
52    SwitchStatementAst,
53    TryStatementAst,
54    TrapStatementAst,
55    FunctionDefinitionAst,
56    DataStatementAst,
57    UnknownStatementAst,
58}
59
60/// A child node of a command element
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub struct CommandElementChild {
63    pub element_type: CommandElementType,
64    pub text: String,
65}
66
67/// Redirection
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
69pub struct ParsedRedirection {
70    pub from: String,
71    pub to: String,
72    pub is_merging: bool,
73}
74
75/// A command invocation within a pipeline
76#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
77pub struct ParsedCommandElement {
78    pub name: String,
79    pub name_type: String,
80    pub element_type: PipelineElementType,
81    pub args: Vec<String>,
82    pub text: String,
83    pub element_types: Option<Vec<CommandElementType>>,
84    pub children: Option<Vec<Option<Vec<CommandElementChild>>>>,
85    pub redirections: Option<Vec<ParsedRedirection>>,
86}
87
88/// Pipeline segment
89#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
90pub struct PipelineSegment {
91    pub commands: Vec<ParsedCommandElement>,
92    pub redirections: Vec<ParsedRedirection>,
93    pub nested_commands: Option<Vec<ParsedCommandElement>>,
94}
95
96/// A statement in the PowerShell command
97#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
98pub struct ParsedStatement {
99    pub statement_type: StatementType,
100    pub commands: Vec<ParsedCommandElement>,
101}
102
103/// Complete parsed PowerShell command
104#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
105pub struct ParsedPowerShellCommand {
106    pub valid: bool,
107    pub statements: Vec<ParsedStatement>,
108    pub error: Option<String>,
109}
110
111/// Check if a string is a PowerShell parameter
112pub fn is_powershell_parameter(arg: &str, element_type: Option<&CommandElementType>) -> bool {
113    if let Some(et) = element_type {
114        return *et == CommandElementType::Parameter;
115    }
116    // Check for common parameter prefixes
117    arg.starts_with('-')
118        || arg.starts_with('/')
119        || arg.starts_with('–')
120        || arg.starts_with('—')
121        || arg.starts_with('―')
122}
123
/// Characters accepted as a parameter prefix.
///
/// This is the same set that `is_powershell_parameter` checks: the ASCII
/// hyphen-minus, three Unicode dashes, and `/` (Windows-style switches).
pub const PS_TOKENIZER_DASH_CHARS: &[char] = &[
    '-',        // hyphen-minus
    '\u{2013}', // en dash
    '\u{2014}', // em dash
    '\u{2015}', // horizontal bar
    '/',        // Windows-style switch prefix
];
126
127/// Parse a PowerShell command string into components
128pub fn parse_powershell_command(command: &str) -> ParsedPowerShellCommand {
129    let trimmed = command.trim();
130
131    if trimmed.is_empty() {
132        return ParsedPowerShellCommand {
133            valid: false,
134            statements: vec![],
135            error: Some("Empty command".to_string()),
136        };
137    }
138
139    // Split by statement separators: ; and newlines
140    let statement_strs: Vec<&str> = trimmed
141        .split(|c| c == ';' || c == '\n')
142        .filter(|s| !s.trim().is_empty())
143        .collect();
144
145    let mut statements = Vec::new();
146
147    for stmt_str in statement_strs {
148        let statement_type = detect_statement_type(stmt_str);
149
150        // Parse pipeline elements (split by |)
151        let pipeline_strs: Vec<&str> = stmt_str.split('|').collect();
152        let mut commands = Vec::new();
153
154        for (idx, pipeline_str) in pipeline_strs.iter().enumerate() {
155            let pipeline_trimmed = pipeline_str.trim();
156            if pipeline_trimmed.is_empty() {
157                continue;
158            }
159
160            // Split by operators: &, && (but NOT | since we already split by | for pipeline)
161            let parts: Vec<&str> = pipeline_trimmed
162                .split(|c| c == '&')
163                .filter(|s| !s.trim().is_empty())
164                .collect();
165
166            for part in parts {
167                let part_trimmed = part.trim();
168                if part_trimmed.is_empty() {
169                    continue;
170                }
171
172                let cmd = parse_command_element(part_trimmed, idx == 0);
173                commands.push(cmd);
174            }
175        }
176
177        if !commands.is_empty() {
178            statements.push(ParsedStatement {
179                statement_type,
180                commands,
181            });
182        }
183    }
184
185    ParsedPowerShellCommand {
186        valid: !statements.is_empty(),
187        statements,
188        error: None,
189    }
190}
191
192/// Detect statement type from keywords
193fn detect_statement_type(cmd: &str) -> StatementType {
194    let lower = cmd.to_lowercase();
195
196    if lower.contains(" if ") || lower.starts_with("if ") {
197        StatementType::IfStatementAst
198    } else if lower.contains(" foreach ") || lower.starts_with("foreach ") || lower.contains("%{") {
199        StatementType::ForEachStatementAst
200    } else if lower.contains(" for ") || lower.starts_with("for ") {
201        StatementType::ForStatementAst
202    } else if lower.contains(" while ") || lower.starts_with("while ") {
203        StatementType::WhileStatementAst
204    } else if lower.contains(" do ") || lower.starts_with("do ") {
205        StatementType::DoWhileStatementAst
206    } else if lower.contains(" switch ") || lower.starts_with("switch ") {
207        StatementType::SwitchStatementAst
208    } else if lower.contains(" try ") || lower.starts_with("try ") {
209        StatementType::TryStatementAst
210    } else if lower.contains(" function ") || lower.starts_with("function ") {
211        StatementType::FunctionDefinitionAst
212    } else if lower.contains('=') && !lower.contains("==") {
213        StatementType::AssignmentStatementAst
214    } else {
215        StatementType::PipelineAst
216    }
217}
218
219/// Parse a single command element
220fn parse_command_element(text: &str, is_first: bool) -> ParsedCommandElement {
221    let parts: Vec<&str> = text.split_whitespace().collect();
222
223    if parts.is_empty() {
224        return create_empty_command(text.to_string());
225    }
226
227    let name = parts[0].to_string();
228    let args: Vec<String> = parts[1..].iter().map(|s| s.to_string()).collect();
229    let name_type = classify_command_name(&name);
230    let element_type = if is_first {
231        PipelineElementType::CommandAst
232    } else {
233        PipelineElementType::CommandExpressionAst
234    };
235    let element_types = Some(determine_element_types(&args));
236
237    ParsedCommandElement {
238        name,
239        name_type,
240        element_type,
241        args,
242        text: text.to_string(),
243        element_types,
244        children: None,
245        redirections: None,
246    }
247}
248
249/// Create an empty command
250fn create_empty_command(text: String) -> ParsedCommandElement {
251    ParsedCommandElement {
252        name: String::new(),
253        name_type: "unknown".to_string(),
254        element_type: PipelineElementType::CommandAst,
255        args: vec![],
256        text,
257        element_types: None,
258        children: None,
259        redirections: None,
260    }
261}
262
/// Classify a command name as `"cmdlet"`, `"application"` or `"unknown"`.
///
/// Heuristic rules, checked in this order (case-insensitive):
/// 1. A name that looks like a path or file name (contains `\`, `/` or `.`)
///    is an application/script. This runs first so that paths containing a
///    dash, such as `./deploy-prod.ps1` or `C:\tools\my-app.exe`, are not
///    mistaken for `Verb-Noun` cmdlets.
/// 2. A name containing `-` is assumed to be a `Verb-Noun` cmdlet.
/// 3. A small set of well-known external tools are applications.
/// 4. Everything else (aliases such as `gc`, functions, ...) is `"unknown"`.
///
/// Module-qualified names such as `Microsoft.PowerShell.Management\Get-ChildItem`
/// are reported as `"application"` by rule 1. This errs on the side of treating
/// an ambiguous name as an external program.
fn classify_command_name(name: &str) -> String {
    const KNOWN_APPLICATIONS: [&str; 10] = [
        "git", "gh", "docker", "npm", "node", "python", "make", "tar", "curl", "wget",
    ];

    let lower = name.to_lowercase();

    let kind = if lower.contains(|c: char| matches!(c, '\\' | '/' | '.')) {
        "application"
    } else if lower.contains('-') {
        "cmdlet"
    } else if KNOWN_APPLICATIONS.contains(&lower.as_str()) {
        "application"
    } else {
        "unknown"
    };

    kind.to_string()
}
287
288/// Determine element types for arguments
289fn determine_element_types(args: &[String]) -> Vec<CommandElementType> {
290    let mut types = vec![CommandElementType::StringConstant];
291
292    for arg in args {
293        let et = classify_argument_element(arg);
294        types.push(et);
295    }
296
297    types
298}
299
300/// Classify an argument's element type
301fn classify_argument_element(arg: &str) -> CommandElementType {
302    let trimmed = arg.trim();
303
304    // Check for variable $var (includes $_.Name, $env:VAR, etc.)
305    if trimmed.starts_with('$') && trimmed.len() > 1 {
306        let second = trimmed.chars().nth(1);
307        // $() is subexpression, $var is variable
308        if second == Some('(') || second == Some('@') {
309            return CommandElementType::SubExpression;
310        }
311        // Check if it's $_.Property (member access with variable)
312        if second == Some('_') || second.is_some_and(|c| c.is_alphabetic()) {
313            return CommandElementType::Variable;
314        }
315        return CommandElementType::Variable;
316    }
317
318    // Check for script block {} - exact match OR contains script block content
319    // Handles: {}, { code }, { $_.Name }, and partial tokens like { and }
320    if trimmed.starts_with('{')
321        || trimmed.ends_with('}')
322        || trimmed.contains("{ ")
323        || trimmed.contains(" }")
324        || trimmed.contains("{}")
325    {
326        return CommandElementType::ScriptBlock;
327    }
328
329    // Check for subexpression $() - exact match OR contains it
330    if trimmed.starts_with("$(")
331        || trimmed.starts_with("@(")
332        || trimmed.contains("$(")
333        || trimmed.contains("@(")
334    {
335        return CommandElementType::SubExpression;
336    }
337
338    // Check for expandable string
339    if trimmed.starts_with('"') && trimmed.ends_with('"') {
340        return CommandElementType::ExpandableString;
341    }
342
343    // Check for member invocation .Method()
344    if trimmed.contains('.') && trimmed.contains('(') {
345        return CommandElementType::MemberInvocation;
346    }
347
348    // Check for parameter
349    if is_powershell_parameter(trimmed, None) {
350        return CommandElementType::Parameter;
351    }
352
353    CommandElementType::StringConstant
354}
355
356/// Derive security flags from parsed command
357pub fn derive_security_flags(parsed: &ParsedPowerShellCommand) -> SecurityFlags {
358    let mut flags = SecurityFlags::default();
359
360    for statement in &parsed.statements {
361        for cmd in &statement.commands {
362            // Check element_types first
363            if let Some(ref types) = cmd.element_types {
364                for et in types {
365                    match et {
366                        CommandElementType::ScriptBlock => flags.has_script_blocks = true,
367                        CommandElementType::SubExpression => flags.has_sub_expressions = true,
368                        CommandElementType::ExpandableString => flags.has_expandable_strings = true,
369                        CommandElementType::MemberInvocation => flags.has_member_invocations = true,
370                        CommandElementType::Variable => flags.has_variables = true,
371                        _ => {}
372                    }
373                }
374            }
375
376            // Also check raw args for variables and script blocks (handles edge cases)
377            for arg in &cmd.args {
378                // Check for variables: $var, $env:VAR, $_.prop
379                if arg.starts_with('$') && arg.len() > 1 {
380                    let second = arg.chars().nth(1);
381                    // $(), @() are subexpressions
382                    if second == Some('(') || second == Some('@') {
383                        flags.has_sub_expressions = true;
384                    } else {
385                        flags.has_variables = true;
386                    }
387                }
388                // Check for script blocks
389                if arg.contains('{') || arg.contains('}') {
390                    flags.has_script_blocks = true;
391                }
392                // Check for subexpressions
393                if arg.contains("$(") || arg.contains("@(") {
394                    flags.has_sub_expressions = true;
395                }
396                // Check for expandable strings
397                if arg.starts_with('"') && arg.ends_with('"') {
398                    flags.has_expandable_strings = true;
399                }
400                // Check for assignments
401                if arg.contains('=') && !arg.starts_with('-') {
402                    flags.has_assignments = true;
403                }
404            }
405
406            // Also check cmd.text (the full command text) for edge cases
407            // e.g., "$env:SECRET" becomes a command with name="$env:SECRET", args=[]
408            let text = &cmd.text;
409            if text.starts_with('$') && text.len() > 1 && !text.contains(' ') {
410                // Single token like $env:SECRET - it's a variable reference
411                flags.has_variables = true;
412            }
413            // Check for script blocks in text
414            if text.contains('{') || text.contains('}') {
415                flags.has_script_blocks = true;
416            }
417            // Check for subexpressions in text
418            if text.contains("$(") || text.contains("@(") {
419                flags.has_sub_expressions = true;
420            }
421        }
422    }
423
424    flags
425}
426
/// Security-relevant features detected in a parsed PowerShell command.
///
/// Returned by `derive_security_flags`. Every field starts as `false`
/// (`Default`). Because the struct only holds `bool`s, it is cheap to copy
/// and can be compared directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SecurityFlags {
    /// Script block braces (`{ ... }`) were detected.
    pub has_script_blocks: bool,
    /// Sub-expressions (`$(...)`, `@(...)`) were detected.
    pub has_sub_expressions: bool,
    /// Double-quoted (expandable) strings were detected.
    pub has_expandable_strings: bool,
    /// Method calls / member invocations were detected.
    pub has_member_invocations: bool,
    /// Splatting (`@params`) was detected.
    pub has_splatting: bool,
    /// Assignments (`=`) were detected.
    pub has_assignments: bool,
    /// The stop-parsing token (`--%`) was detected.
    pub has_stop_parsing: bool,
    /// Variable references (`$var`, `$env:VAR`, ...) were detected.
    pub has_variables: bool,
}
439
440// ============================================================================
441// Shell execution functions
442// ============================================================================
443
/// Build a `pwsh` command that runs `script` non-interactively.
///
/// The process is configured with `-NoProfile` and `-NonInteractive`, and
/// `script` is passed as the single argument following `-Command`. No output
/// encoding is configured here; use `build_powershell_command_utf8` when
/// UTF-8 console output is required.
///
/// The returned `Command` is only configured, not spawned.
pub fn build_powershell_command(script: &str) -> Command {
    let mut cmd = Command::new("pwsh");
    cmd.args(["-NoProfile", "-NonInteractive", "-Command", script]);
    cmd
}
450
451/// Build a PowerShell command with UTF-8 output encoding
452pub fn build_powershell_command_utf8(script: &str) -> Command {
453    let full_script = format!(
454        "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; {}",
455        script
456    );
457    build_powershell_command(&full_script)
458}
459
/// Return `true` if `pwsh --version` can be launched and exits successfully.
///
/// Any failure to spawn the process (e.g. `pwsh` not found) yields `false`.
pub fn is_powershell_available() -> bool {
    match Command::new("pwsh").arg("--version").output() {
        Ok(output) => output.status.success(),
        Err(_) => false,
    }
}
468
/// Return the output of `pwsh --version`, with surrounding whitespace removed.
///
/// The trailing newline that the program prints is stripped, so the value can
/// be displayed or compared directly.
///
/// Returns `None` in any of these cases:
/// - `pwsh` cannot be spawned,
/// - it exits unsuccessfully,
/// - it prints non-UTF-8 output,
/// - it prints only whitespace.
pub fn get_powershell_version() -> Option<String> {
    let output = Command::new("pwsh").arg("--version").output().ok()?;
    if !output.status.success() {
        return None;
    }

    let stdout = String::from_utf8(output.stdout).ok()?;
    let version = stdout.trim();
    (!version.is_empty()).then(|| version.to_string())
}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether the first command of the first statement has an element of kind `wanted`.
    fn first_command_has(parsed: &ParsedPowerShellCommand, wanted: CommandElementType) -> bool {
        parsed.statements[0].commands[0]
            .element_types
            .as_deref()
            .is_some_and(|types| types.contains(&wanted))
    }

    #[test]
    fn test_parse_simple_command() {
        let parsed = parse_powershell_command("Get-Content file.txt");
        assert!(parsed.valid);
        assert_eq!(parsed.statements.len(), 1);
        assert_eq!(parsed.statements[0].commands[0].name, "Get-Content");
    }

    #[test]
    fn test_parse_command_with_args() {
        let parsed = parse_powershell_command("Remove-Item -Path test.txt -Recurse -Force");
        assert!(parsed.valid);
        let first = &parsed.statements[0].commands[0];
        assert_eq!(first.name, "Remove-Item");
        assert!(first.args.iter().any(|arg| arg == "-Path"));
    }

    #[test]
    fn test_parse_pipeline() {
        let parsed = parse_powershell_command("Get-Content file.txt | Select-String pattern");
        assert!(parsed.valid);
        assert_eq!(parsed.statements[0].commands.len(), 2);
    }

    #[test]
    fn test_parse_compound_statements() {
        let parsed = parse_powershell_command("$var = 1; Get-Content file.txt");
        assert!(parsed.valid);
        assert_eq!(parsed.statements.len(), 2);
    }

    #[test]
    fn test_detect_variables() {
        let parsed = parse_powershell_command("Write-Host $env:SECRET");
        assert!(parsed.valid);
        assert!(first_command_has(&parsed, CommandElementType::Variable));
    }

    #[test]
    fn test_detect_script_blocks() {
        let parsed = parse_powershell_command("Where-Object { $_.Name }");
        assert!(parsed.valid);
        assert!(first_command_has(&parsed, CommandElementType::ScriptBlock));
    }

    #[test]
    fn test_detect_subexpression() {
        let parsed = parse_powershell_command("Invoke-Expression $(malicious)");
        assert!(parsed.valid);
        assert!(first_command_has(&parsed, CommandElementType::SubExpression));
    }

    #[test]
    fn test_classify_cmdlet() {
        for name in ["Get-Content", "Remove-Item"] {
            assert_eq!(classify_command_name(name), "cmdlet");
        }
    }

    #[test]
    fn test_classify_application() {
        for name in ["git", "./script.ps1"] {
            assert_eq!(classify_command_name(name), "application");
        }
    }

    #[test]
    fn test_is_powershell_parameter() {
        for param in ["-Path", "-Recurse", "/C"] {
            assert!(is_powershell_parameter(param, None));
        }
        assert!(!is_powershell_parameter("file.txt", None));
    }

    #[test]
    fn test_derive_security_flags_variables() {
        let flags = derive_security_flags(&parse_powershell_command("$env:SECRET | Write-Host"));
        assert!(flags.has_variables);
    }

    #[test]
    fn test_derive_security_flags_script_blocks() {
        let flags = derive_security_flags(&parse_powershell_command(
            "Get-Process | Where-Object { $_.CPU }",
        ));
        assert!(flags.has_script_blocks);
    }

    #[test]
    fn test_derive_security_flags_subexpression() {
        let flags = derive_security_flags(&parse_powershell_command("Invoke-Expression $(malicious)"));
        assert!(flags.has_sub_expressions);
    }

    #[test]
    fn test_derive_security_flags_assignment() {
        let flags = derive_security_flags(&parse_powershell_command("$result = Get-Content file.txt"));
        assert!(flags.has_assignments);
    }

    #[test]
    fn test_empty_command() {
        assert!(!parse_powershell_command("").valid);
    }

    #[test]
    fn test_member_invocation() {
        assert!(parse_powershell_command("$obj.Method()").valid);
    }

    #[test]
    fn test_parse_alias() {
        let parsed = parse_powershell_command("gc file.txt");
        assert!(parsed.valid);
        assert_eq!(parsed.statements[0].commands[0].name, "gc");
    }
}