Skip to main content

relay_core_lib/rule/engine/
executor.rs

1use crate::rule::model::{Rule, RuleGroup, RuleStage, RuleOutcome, RuleExecutionEvent, RuleTermination, RuleTraceSummary};
2use crate::rule::engine::compiler;
3use crate::rule::engine::compiled::CompiledRule;
4use crate::rule::engine::matcher;
5use crate::rule::engine::validator;
6use crate::rule::engine::actions;
7use crate::rule::engine::state::{RuleStateStore, InMemoryRuleStateStore};
8use relay_core_api::flow::Flow;
9use std::collections::HashMap;
10use std::sync::Arc;
11use relay_core_api::policy::ProxyPolicy;
12
13pub struct ExecutionContext {
14    pub trace: Vec<RuleExecutionEvent>,
15    pub variables: HashMap<String, String>,
16    pub policy: Option<Arc<ProxyPolicy>>,
17    pub summary: RuleTraceSummary,
18    pub state_store: Arc<dyn RuleStateStore>,
19}
20
21#[derive(Debug)]
22pub struct RuleEngine {
23    compiled_rules: Vec<CompiledRule>,
24    policy: Option<Arc<ProxyPolicy>>,
25    state_store: Arc<dyn RuleStateStore>,
26}
27
28impl RuleEngine {
29    pub fn new(rules: Vec<Rule>, rule_groups: Vec<RuleGroup>, policy: Option<Arc<ProxyPolicy>>, state_store: Option<Arc<dyn RuleStateStore>>) -> Self {
30        let mut all_rules = Vec::new();
31        
32        // Flatten rules
33        for rule in rules {
34            all_rules.push(rule);
35        }
36        
37        // Flatten rule groups
38        for group in rule_groups {
39            if group.active {
40                for rule in group.rules {
41                    all_rules.push(rule);
42                }
43            }
44        }
45        
46        // Sort by priority (descending)
47        all_rules.sort_by_key(|r| std::cmp::Reverse(r.priority));
48        
49        // Compile all rules
50        let compiled_rules = all_rules.into_iter()
51            .map(compiler::compile_rule)
52            .collect();
53            
54        Self { 
55            compiled_rules, 
56            policy,
57            state_store: state_store.unwrap_or_else(|| Arc::new(InMemoryRuleStateStore::new()))
58        }
59    }
60
61    pub fn has_rules_for_stage(&self, stage: RuleStage) -> bool {
62        self.compiled_rules.iter().any(|r| r.original.active && r.original.stage == stage)
63    }
64
65    pub async fn execute(&self, stage: RuleStage, flow: &mut Flow) -> ExecutionContext {
66        let mut ctx = ExecutionContext { 
67            trace: vec![],
68            variables: HashMap::new(),
69            policy: self.policy.clone(),
70            summary: RuleTraceSummary::NoMatch,
71            state_store: self.state_store.clone(),
72        };
73        
74        let mut terminated = false;
75        let mut modified_rules = Vec::new();
76
77        for compiled_rule in &self.compiled_rules {
78            let rule = &compiled_rule.original;
79
80            if !rule.active {
81                continue;
82            }
83
84            if rule.stage != stage {
85                continue;
86            }
87            
88            // Validate filter stage (using pre-compiled filter)
89            if !validator::validate_filter_stage(&compiled_rule.filter, &stage) {
90                 ctx.trace.push(RuleExecutionEvent {
91                    rule_id: rule.id.clone(),
92                    stage: stage.clone(),
93                    matched: false,
94                    duration_us: 0,
95                    outcome: RuleOutcome::Failed(format!("Filter invalid for stage {:?}", stage)),
96                });
97                continue;
98            }
99            
100            // Validate action stage
101            let mut actions_valid = true;
102            for action in &rule.actions {
103                if !validator::validate_action_stage(action, &stage) {
104                     ctx.trace.push(RuleExecutionEvent {
105                        rule_id: rule.id.clone(),
106                        stage: stage.clone(),
107                        matched: false,
108                        duration_us: 0,
109                        outcome: RuleOutcome::Failed(format!("Action {:?} not allowed in stage {:?}", action, stage)),
110                    });
111                    actions_valid = false;
112                    break;
113                }
114            }
115            if !actions_valid {
116                continue;
117            }
118
119            // Match (using pre-compiled filter)
120            let start = std::time::Instant::now();
121            let matched = matcher::matches(&compiled_rule.filter, flow);
122            
123            if matched {
124                 let timeout_ms = rule.constraints.as_ref().and_then(|c| c.timeout_ms);
125                 
126                 let action_execution = async {
127                     let mut rule_outcome = RuleOutcome::MatchedAndExecuted;
128                     let mut rule_terminated = false;
129                     
130                     for action in &rule.actions {
131                        match actions::execute_action(action, flow, &mut ctx).await {
132                            actions::ActionOutcome::Continue => {},
133                            actions::ActionOutcome::Terminated(reason) => {
134                                 rule_outcome = RuleOutcome::MatchedAndTerminated;
135                                 ctx.summary = RuleTraceSummary::Terminated { 
136                                     rule_id: rule.id.clone(), 
137                                     reason 
138                                 };
139                                 rule_terminated = true;
140                                 break; 
141                            },
142                            actions::ActionOutcome::Failed(err) => {
143                                 rule_outcome = RuleOutcome::Failed(err);
144                                 break;
145                            }
146                        }
147                     }
148                     (rule_outcome, rule_terminated)
149                 };
150
151                 let (rule_outcome, rule_terminated) = if let Some(ms) = timeout_ms {
152                     match tokio::time::timeout(std::time::Duration::from_millis(ms), action_execution).await {
153                         Ok(res) => res,
154                         Err(_) => (RuleOutcome::Failed(format!("Rule execution timed out after {}ms", ms)), false)
155                     }
156                 } else {
157                     action_execution.await
158                 };
159                 
160                 if rule_terminated {
161                     terminated = true;
162                 }
163                 
164                 ctx.trace.push(RuleExecutionEvent {
165                    rule_id: rule.id.clone(),
166                    stage: stage.clone(),
167                    matched: true,
168                    duration_us: start.elapsed().as_micros() as u64,
169                    outcome: rule_outcome.clone(),
170                });
171                
172                if matches!(rule_outcome, RuleOutcome::MatchedAndExecuted) {
173                     modified_rules.push(rule.id.clone());
174                }
175
176                if terminated {
177                    break;
178                }
179                
180                if let RuleTermination::Stop = rule.termination {
181                    break;
182                }
183            }
184        }
185        
186        if !terminated && !modified_rules.is_empty() {
187            ctx.summary = RuleTraceSummary::Modified { rule_ids: modified_rules };
188        }
189        
190        ctx
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::model::{Action, Filter, Rule, RuleStage, RuleTermination, RuleOutcome, StringMatcher, BodySource, RuleConstraints};
    use relay_core_api::flow::{Flow, Layer, NetworkInfo, TransportProtocol, HttpRequest};
    use uuid::Uuid;
    use chrono::Utc;
    use url::Url;

    /// Minimal plain-HTTP GET flow to `http://example.com` with no body, tags, or response.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(relay_core_api::flow::HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: std::collections::HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_rule_timeout() {
        let rule = Rule {
            id: "test-rule-timeout".to_string(),
            name: "Test Rule Timeout".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![
                Action::Delay { ms: 200 } // Delay 200ms
            ],
            constraints: Some(RuleConstraints {
                timeout_ms: Some(50), // Timeout 50ms
            }),
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in RequestHeaders stage
        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        // Expect 1 trace entry with Failed outcome due to timeout
        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("timed out"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    #[tokio::test]
    async fn test_filter_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-1".to_string(),
            name: "Test Rule 1".to_string(),
            active: true,
            stage: RuleStage::Connect, // Stage is Connect
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::Path(StringMatcher::Exact("/foo".to_string())), // Path filter invalid in Connect
            actions: vec![],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in Connect stage
        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        // Expect 1 trace entry with Failed outcome
        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("Filter invalid"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    #[tokio::test]
    async fn test_action_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-2".to_string(),
            name: "Test Rule 2".to_string(),
            active: true,
            stage: RuleStage::Connect, // Stage is Connect
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All, // Valid filter
            actions: vec![
                Action::SetResponseBody { body: BodySource::Text("foo".to_string()) } // Invalid action in Connect
            ],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in Connect stage
        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        // Expect 1 trace entry with Failed outcome
        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("Action"));
            assert!(msg.contains("not allowed"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    #[test]
    fn test_has_rules_for_stage_respects_active_flag() {
        let inactive_rule = Rule {
            id: "inactive-rh".to_string(),
            name: "Inactive".to_string(),
            active: false,
            stage: RuleStage::RequestHeaders,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "k".to_string(),
                value: "v".to_string(),
            }],
            constraints: None,
        };
        let active_connect_rule = Rule {
            id: "active-connect".to_string(),
            name: "Active Connect".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Drop],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![inactive_rule, active_connect_rule], vec![], None, None);
        assert!(
            !engine.has_rules_for_stage(RuleStage::RequestHeaders),
            "inactive rules should not count for stage presence"
        );
        assert!(engine.has_rules_for_stage(RuleStage::Connect));
    }

    #[tokio::test]
    async fn test_execute_orders_rules_by_priority_descending() {
        let low = Rule {
            id: "low-pri".to_string(),
            name: "low".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "low".to_string(),
            }],
            constraints: None,
        };
        let high = Rule {
            id: "high-pri".to_string(),
            name: "high".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 100,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "high".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![low, high], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 2);
        assert_eq!(ctx.trace[0].rule_id, "high-pri");
        assert_eq!(ctx.trace[1].rule_id, "low-pri");
        assert_eq!(flow.tags, vec!["order:high".to_string(), "order:low".to_string()]);
    }

    #[tokio::test]
    async fn test_stop_termination_skips_remaining_rules() {
        let stopper = Rule {
            id: "stop-rule".to_string(),
            name: "stop".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 100,
            termination: RuleTermination::Stop,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "high".to_string(),
            }],
            constraints: None,
        };
        let skipped = Rule {
            id: "skipped-rule".to_string(),
            name: "skipped".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "low".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![skipped, stopper], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        // The Stop rule runs first (higher priority) and halts evaluation after it matches.
        assert_eq!(ctx.trace.len(), 1);
        assert_eq!(ctx.trace[0].rule_id, "stop-rule");
        assert_eq!(flow.tags, vec!["order:high".to_string()]);
        match &ctx.summary {
            RuleTraceSummary::Modified { rule_ids } => {
                assert_eq!(rule_ids, &vec!["stop-rule".to_string()]);
            }
            _ => panic!("expected a Modified summary"),
        }
    }
}