flowbuilder_yaml/
expression.rs

1use anyhow::{Context, Result};
2use regex::Regex;
3use std::collections::HashMap;
4
5/// 表达式求值器,用于处理工作流中的变量和表达式
6#[derive(Clone)]
7pub struct ExpressionEvaluator {
8    env_vars: HashMap<String, String>,
9    flow_vars: HashMap<String, serde_yaml::Value>,
10    context_vars: HashMap<String, serde_yaml::Value>,
11}
12
13impl ExpressionEvaluator {
14    /// 创建新的表达式求值器
15    pub fn new() -> Self {
16        Self {
17            env_vars: HashMap::new(),
18            flow_vars: HashMap::new(),
19            context_vars: HashMap::new(),
20        }
21    }
22
23    /// 设置环境变量
24    pub fn set_env_vars(&mut self, env_vars: HashMap<String, String>) {
25        self.env_vars = env_vars;
26    }
27
28    /// 设置流程变量
29    pub fn set_flow_vars(
30        &mut self,
31        flow_vars: HashMap<String, serde_yaml::Value>,
32    ) {
33        self.flow_vars = flow_vars;
34    }
35
36    /// 设置上下文变量
37    pub fn set_context_var<S: AsRef<str>>(
38        &mut self,
39        key: S,
40        value: serde_yaml::Value,
41    ) {
42        self.context_vars.insert(key.as_ref().to_string(), value);
43    }
44
45    /// 获取上下文变量
46    pub fn get_context_var<S: AsRef<str>>(
47        &self,
48        key: S,
49    ) -> Option<&serde_yaml::Value> {
50        self.context_vars.get(key.as_ref())
51    }
52
53    /// 求值表达式字符串
54    pub fn evaluate(&self, expression: &str) -> Result<serde_yaml::Value> {
55        // 处理环境变量引用: ${{ env.VAR_NAME }}
56        let env_regex = Regex::new(r"\$\{\{\s*env\.(\w+)\s*\}\}")
57            .context("Failed to compile env regex")?;
58
59        let mut result = expression.to_string();
60
61        for cap in env_regex.captures_iter(expression) {
62            let var_name = &cap[1];
63            if let Some(env_value) = self.env_vars.get(var_name) {
64                result = result.replace(&cap[0], env_value);
65            } else {
66                return Err(anyhow::anyhow!(
67                    "Environment variable not found: {}",
68                    var_name
69                ));
70            }
71        }
72
73        // 处理流程变量引用: ${{ vars.VAR_NAME }}
74        let vars_regex = Regex::new(r"\$\{\{\s*vars\.(\w+)\s*\}\}")
75            .context("Failed to compile vars regex")?;
76
77        for cap in vars_regex.captures_iter(&result.clone()) {
78            let var_name = &cap[1];
79            if let Some(var_value) = self.flow_vars.get(var_name) {
80                let value_str = self.yaml_value_to_string(var_value);
81                result = result.replace(&cap[0], &value_str);
82            } else {
83                return Err(anyhow::anyhow!(
84                    "Flow variable not found: {}",
85                    var_name
86                ));
87            }
88        }
89
90        // 处理任务输出引用: ${task.action.outputs.field}
91        let output_regex = Regex::new(r"\$\{([^}]+)\}")
92            .context("Failed to compile output regex")?;
93
94        for cap in output_regex.captures_iter(&result.clone()) {
95            let path = &cap[1];
96            if let Some(value) = self.resolve_context_path(path)? {
97                let value_str = self.yaml_value_to_string(&value);
98                result = result.replace(&cap[0], &value_str);
99            } else {
100                return Err(anyhow::anyhow!(
101                    "Context path not found: {}",
102                    path
103                ));
104            }
105        }
106
107        // 尝试解析结果为合适的类型
108        self.parse_result(&result)
109    }
110
111    /// 求值条件表达式,返回布尔值
112    pub fn evaluate_condition(&self, condition: &str) -> Result<bool> {
113        let result = self.evaluate(condition)?;
114
115        match result {
116            serde_yaml::Value::Bool(b) => Ok(b),
117            serde_yaml::Value::String(s) => {
118                // 简单的条件解析
119                if s.contains("==") {
120                    self.evaluate_equality(&s)
121                } else if s.contains("!=") {
122                    self.evaluate_inequality(&s)
123                } else if s.contains("&&") {
124                    self.evaluate_and(&s)
125                } else if s.contains("||") {
126                    self.evaluate_or(&s)
127                } else {
128                    // 非空字符串视为 true
129                    Ok(!s.is_empty())
130                }
131            }
132            serde_yaml::Value::Number(n) => {
133                if let Some(i) = n.as_i64() {
134                    Ok(i != 0)
135                } else if let Some(f) = n.as_f64() {
136                    Ok(f != 0.0)
137                } else {
138                    Ok(false)
139                }
140            }
141            serde_yaml::Value::Null => Ok(false),
142            _ => Ok(true),
143        }
144    }
145
146    fn resolve_context_path(
147        &self,
148        path: &str,
149    ) -> Result<Option<serde_yaml::Value>> {
150        let parts: Vec<&str> = path.split('.').collect();
151
152        if parts.len() < 2 {
153            return Ok(None);
154        }
155
156        // 尝试从上下文变量中解析路径
157        let key = parts.join(".");
158        if let Some(value) = self.context_vars.get(&key) {
159            return Ok(Some(value.clone()));
160        }
161
162        // 尝试按路径结构解析
163        let mut current = None;
164        for (ctx_key, ctx_value) in &self.context_vars {
165            if ctx_key.starts_with(parts[0]) {
166                current = Some(ctx_value.clone());
167                break;
168            }
169        }
170
171        Ok(current)
172    }
173
174    fn yaml_value_to_string(&self, value: &serde_yaml::Value) -> String {
175        match value {
176            serde_yaml::Value::String(s) => s.clone(),
177            serde_yaml::Value::Number(n) => n.to_string(),
178            serde_yaml::Value::Bool(b) => b.to_string(),
179            serde_yaml::Value::Null => "null".to_string(),
180            _ => serde_yaml::to_string(value).unwrap_or_default(),
181        }
182    }
183
184    fn parse_result(&self, result: &str) -> Result<serde_yaml::Value> {
185        // 尝试解析为数字
186        if let Ok(i) = result.parse::<i64>() {
187            return Ok(serde_yaml::Value::Number(serde_yaml::Number::from(i)));
188        }
189
190        if let Ok(f) = result.parse::<f64>() {
191            return Ok(serde_yaml::Value::Number(serde_yaml::Number::from(f)));
192        }
193
194        // 尝试解析为布尔值
195        if let Ok(b) = result.parse::<bool>() {
196            return Ok(serde_yaml::Value::Bool(b));
197        }
198
199        // 默认为字符串
200        Ok(serde_yaml::Value::String(result.to_string()))
201    }
202
203    fn evaluate_equality(&self, expr: &str) -> Result<bool> {
204        let parts: Vec<&str> = expr.splitn(2, "==").collect();
205        if parts.len() != 2 {
206            return Ok(false);
207        }
208
209        let left = parts[0].trim();
210        let right = parts[1].trim();
211        Ok(left == right)
212    }
213
214    fn evaluate_inequality(&self, expr: &str) -> Result<bool> {
215        let parts: Vec<&str> = expr.splitn(2, "!=").collect();
216        if parts.len() != 2 {
217            return Ok(false);
218        }
219
220        let left = parts[0].trim();
221        let right = parts[1].trim();
222        Ok(left != right)
223    }
224
225    fn evaluate_and(&self, expr: &str) -> Result<bool> {
226        let parts: Vec<&str> = expr.split("&&").collect();
227        for part in parts {
228            if !self.evaluate_condition(part.trim())? {
229                return Ok(false);
230            }
231        }
232        Ok(true)
233    }
234
235    fn evaluate_or(&self, expr: &str) -> Result<bool> {
236        let parts: Vec<&str> = expr.split("||").collect();
237        for part in parts {
238            if self.evaluate_condition(part.trim())? {
239                return Ok(true);
240            }
241        }
242        Ok(false)
243    }
244}
245
246impl Default for ExpressionEvaluator {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evaluator whose flow variables are exactly `vars`.
    fn with_flow_vars(vars: Vec<(&str, serde_yaml::Value)>) -> ExpressionEvaluator {
        let mut evaluator = ExpressionEvaluator::new();
        evaluator.set_flow_vars(
            vars.into_iter()
                .map(|(name, value)| (name.to_string(), value))
                .collect(),
        );
        evaluator
    }

    #[test]
    fn test_env_var_evaluation() {
        let mut evaluator = ExpressionEvaluator::new();
        let mut env_vars = HashMap::new();
        env_vars.insert("TEST_VAR".to_string(), "test_value".to_string());
        evaluator.set_env_vars(env_vars);

        let result = evaluator.evaluate("${{ env.TEST_VAR }}").unwrap();
        assert_eq!(result, serde_yaml::Value::String("test_value".to_string()));
    }

    #[test]
    fn test_missing_env_var_is_error() {
        let evaluator = ExpressionEvaluator::new();
        assert!(evaluator.evaluate("${{ env.MISSING }}").is_err());
    }

    #[test]
    fn test_flow_var_evaluation() {
        let evaluator = with_flow_vars(vec![(
            "name",
            serde_yaml::Value::String("FlowBuilder".to_string()),
        )]);

        let result = evaluator.evaluate("${{ vars.name }}").unwrap();
        assert_eq!(
            result,
            serde_yaml::Value::String("FlowBuilder".to_string())
        );
    }

    #[test]
    fn test_missing_flow_var_is_error() {
        let evaluator = ExpressionEvaluator::new();
        assert!(evaluator.evaluate("${{ vars.missing }}").is_err());
    }

    #[test]
    fn test_numeric_flow_var_is_coerced_to_number() {
        let evaluator = with_flow_vars(vec![(
            "count",
            serde_yaml::Value::Number(serde_yaml::Number::from(3i64)),
        )]);

        let result = evaluator.evaluate("${{ vars.count }}").unwrap();
        assert_eq!(
            result,
            serde_yaml::Value::Number(serde_yaml::Number::from(3i64))
        );
    }

    #[test]
    fn test_context_path_exact_key() {
        let mut evaluator = ExpressionEvaluator::new();
        evaluator.set_context_var(
            "fetch.outputs.body",
            serde_yaml::Value::String("ok".to_string()),
        );

        let result = evaluator.evaluate("result=${fetch.outputs.body}").unwrap();
        assert_eq!(result, serde_yaml::Value::String("result=ok".to_string()));
    }

    #[test]
    fn test_missing_context_path_is_error() {
        let evaluator = ExpressionEvaluator::new();
        assert!(evaluator.evaluate("${fetch.outputs.body}").is_err());
    }

    #[test]
    fn test_condition_evaluation() {
        let evaluator = ExpressionEvaluator::new();

        assert!(evaluator.evaluate_condition("true").unwrap());
        assert!(!evaluator.evaluate_condition("false").unwrap());
        assert!(evaluator.evaluate_condition("test == test").unwrap());
        assert!(!evaluator.evaluate_condition("test != test").unwrap());
    }

    #[test]
    fn test_logical_operators() {
        let evaluator = ExpressionEvaluator::new();

        assert!(evaluator.evaluate_condition("true && true").unwrap());
        assert!(!evaluator.evaluate_condition("true && false").unwrap());
        assert!(evaluator.evaluate_condition("false || true").unwrap());
        assert!(!evaluator.evaluate_condition("false || false").unwrap());
        assert!(!evaluator.evaluate_condition("0").unwrap());
        assert!(!evaluator.evaluate_condition("").unwrap());
    }

    #[test]
    fn test_condition_with_flow_var() {
        let evaluator = with_flow_vars(vec![(
            "stage",
            serde_yaml::Value::String("prod".to_string()),
        )]);

        assert!(evaluator
            .evaluate_condition("${{ vars.stage }} == prod")
            .unwrap());
        assert!(!evaluator
            .evaluate_condition("${{ vars.stage }} != prod")
            .unwrap());
    }
}