net_shell/vars/
mod.rs

1use std::collections::HashMap;
2use regex::Regex;
3use anyhow::{Result, Context};
4use crate::models::{ExtractRule, ExecutionResult};
5
6/// 变量管理器
7#[derive(Debug, Clone)]
8pub struct VariableManager {
9    variables: HashMap<String, String>,
10}
11
12impl VariableManager {
13    /// 创建新的变量管理器
14    pub fn new(initial_variables: Option<HashMap<String, String>>) -> Self {
15        Self {
16            variables: initial_variables.unwrap_or_default(),
17        }
18    }
19
20    /// 替换字符串中的变量占位符
21    pub fn replace_variables(&self, content: &str) -> String {
22        let mut result = content.to_string();
23        
24        // 替换 {{ variable_name }} 格式的变量
25        for (key, value) in &self.variables {
26            let placeholder = format!("{{{{ {} }}}}", key);
27            result = result.replace(&placeholder, value);
28        }
29        
30        result
31    }
32
33    /// 从执行结果中提取变量
34    pub fn extract_variables(&mut self, extract_rules: &[ExtractRule], execution_result: &ExecutionResult) -> Result<()> {
35        for rule in extract_rules {
36            let source_content = match rule.source.as_str() {
37                "stdout" => &execution_result.stdout,
38                "stderr" => &execution_result.stderr,
39                "exit_code" => &execution_result.exit_code.to_string(),
40                _ => {
41                    return Err(anyhow::anyhow!("Unknown extract source: {}", rule.source));
42                }
43            };
44
45            // 检查是否启用级联模式
46            if rule.cascade {
47                // 级联模式:前一个正则的匹配结果作为下一个正则的输入
48                self.extract_with_cascade(rule, source_content)?;
49            } else {
50                // 普通模式:尝试多个正则表达式,按顺序匹配直到成功
51                self.extract_with_fallback(rule, source_content)?;
52            }
53        }
54        
55        Ok(())
56    }
57
58    /// 级联模式提取:前一个正则的匹配结果作为下一个正则的输入
59    /// 约定:始终获取第一个捕获组(第一个括号)的内容
60    fn extract_with_cascade(&mut self, rule: &ExtractRule, source_content: &str) -> Result<()> {
61        let mut current_content = source_content.to_string();
62        let mut extracted_value = None;
63
64        for (pattern_index, pattern) in rule.patterns.iter().enumerate() {
65            let regex = Regex::new(pattern)
66                .context(format!("Invalid regex pattern {} for rule '{}': {}", pattern_index + 1, rule.name, pattern))?;
67            
68            if let Some(captures) = regex.captures(&current_content) {
69                // 约定:始终获取第一个捕获组(第一个括号)的内容
70                let matched_value = if let Some(value) = captures.get(1) {
71                    value.as_str().to_string()
72                } else {
73                    // 如果没有捕获组,记录警告并使用完整匹配
74                    tracing::warn!("Pattern {} for rule '{}' has no capture groups, using full match: {}", 
75                                  pattern_index + 1, rule.name, pattern);
76                    if let Some(full_match) = captures.get(0) {
77                        full_match.as_str().to_string()
78                    } else {
79                        continue;
80                    }
81                };
82                
83                if pattern_index == rule.patterns.len() - 1 {
84                    // 最后一个正则,保存最终结果
85                    extracted_value = Some(matched_value);
86                    break;
87                } else {
88                    // 不是最后一个正则,将匹配结果作为下一个正则的输入
89                    current_content = matched_value;
90                }
91            } else {
92                // 当前正则没有匹配,级联失败
93                tracing::debug!("Cascade failed at pattern {} for rule '{}': no match", pattern_index + 1, rule.name);
94                break;
95            }
96        }
97
98        if let Some(value) = extracted_value {
99            self.variables.insert(rule.name.clone(), value.clone());
100            tracing::debug!("Cascade extraction successful for rule '{}': {}", rule.name, value);
101        } else {
102            tracing::debug!("Cascade extraction failed for rule '{}'", rule.name);
103        }
104
105        Ok(())
106    }
107
108    /// 普通模式提取:尝试多个正则表达式,按顺序匹配直到成功
109    /// 约定:始终获取第一个捕获组(第一个括号)的内容
110    fn extract_with_fallback(&mut self, rule: &ExtractRule, source_content: &str) -> Result<()> {
111        let mut extracted = false;
112        
113        for (pattern_index, pattern) in rule.patterns.iter().enumerate() {
114            let regex = Regex::new(pattern)
115                .context(format!("Invalid regex pattern {} for rule '{}': {}", pattern_index + 1, rule.name, pattern))?;
116            
117            if let Some(captures) = regex.captures(source_content) {
118                // 约定:始终获取第一个捕获组(第一个括号)的内容
119                if let Some(value) = captures.get(1) {
120                    self.variables.insert(rule.name.clone(), value.as_str().to_string());
121                    extracted = true;
122                    tracing::debug!("Fallback extraction successful for rule '{}' with pattern {}: {}", rule.name, pattern_index + 1, value.as_str());
123                    break; // 找到匹配就停止尝试其他模式
124                } else {
125                    // 如果没有捕获组,记录警告
126                    tracing::warn!("Pattern {} for rule '{}' has no capture groups: {}", 
127                                  pattern_index + 1, rule.name, pattern);
128                }
129            }
130        }
131        
132        // 可选:记录未匹配的规则(用于调试)
133        if !extracted {
134            tracing::debug!("No pattern matched for rule '{}' in source '{}'", rule.name, rule.source);
135        }
136
137        Ok(())
138    }
139
140    /// 获取当前所有变量
141    pub fn get_variables(&self) -> &HashMap<String, String> {
142        &self.variables
143    }
144
145    /// 移除变量
146    pub fn remove_variable(&mut self, key: &str) {
147        self.variables.remove(key);
148    }
149
150    /// 设置变量
151    pub fn set_variable(&mut self, key: String, value: String) {
152        self.variables.insert(key, value);
153    }
154
155    /// 获取变量值
156    pub fn get_variable(&self, key: &str) -> Option<&String> {
157        self.variables.get(key)
158    }
159}