Skip to main content

relay_core_lib/rule/engine/
executor.rs

1use crate::rule::engine::actions;
2use crate::rule::engine::compiled::CompiledRule;
3use crate::rule::engine::compiler;
4use crate::rule::engine::matcher;
5use crate::rule::engine::state::{InMemoryRuleStateStore, RuleStateStore};
6use crate::rule::engine::validator;
7use crate::rule::model::{
8    Rule, RuleExecutionEvent, RuleGroup, RuleOutcome, RuleStage, RuleTermination, RuleTraceSummary,
9};
10use relay_core_api::flow::Flow;
11use relay_core_api::policy::ProxyPolicy;
12use std::collections::HashMap;
13use std::sync::Arc;
14
/// Per-invocation state produced by [`RuleEngine::execute`] and handed to every
/// action as it runs.
pub struct ExecutionContext {
    /// Events recorded during the run: one for each rule that failed stage
    /// validation and one for each rule whose filter matched. Rules that were
    /// skipped or did not match leave no entry.
    pub trace: Vec<RuleExecutionEvent>,
    /// String variables; starts empty for every `execute` call.
    /// NOTE(review): presumably written/read by actions to share values —
    /// confirm against `actions::execute_action`.
    pub variables: HashMap<String, String>,
    /// Clone of the engine's optional proxy policy handle.
    pub policy: Option<Arc<ProxyPolicy>>,
    /// Overall result of the run. Starts as `NoMatch`; becomes `Terminated`
    /// when an action terminates processing, or `Modified` when at least one
    /// rule executed fully without any termination.
    pub summary: RuleTraceSummary,
    /// Clone of the engine's rule state store handle.
    pub state_store: Arc<dyn RuleStateStore>,
}
22
/// Holds a pre-compiled, priority-ordered rule set and evaluates it against
/// flows one stage at a time.
#[derive(Debug)]
pub struct RuleEngine {
    /// Every rule passed to `new` (standalone rules plus rules of active
    /// groups), compiled and ordered by descending priority.
    compiled_rules: Vec<CompiledRule>,
    /// Optional proxy policy; cloned into each `ExecutionContext`.
    policy: Option<Arc<ProxyPolicy>>,
    /// State store handle; cloned into each `ExecutionContext`.
    state_store: Arc<dyn RuleStateStore>,
}
29
30impl RuleEngine {
31    pub fn new(
32        rules: Vec<Rule>,
33        rule_groups: Vec<RuleGroup>,
34        policy: Option<Arc<ProxyPolicy>>,
35        state_store: Option<Arc<dyn RuleStateStore>>,
36    ) -> Self {
37        let mut all_rules = Vec::new();
38
39        // Flatten rules
40        for rule in rules {
41            all_rules.push(rule);
42        }
43
44        // Flatten rule groups
45        for group in rule_groups {
46            if group.active {
47                for rule in group.rules {
48                    all_rules.push(rule);
49                }
50            }
51        }
52
53        // Sort by priority (descending)
54        all_rules.sort_by_key(|r| std::cmp::Reverse(r.priority));
55
56        // Compile all rules
57        let compiled_rules = all_rules.into_iter().map(compiler::compile_rule).collect();
58
59        Self {
60            compiled_rules,
61            policy,
62            state_store: state_store.unwrap_or_else(|| Arc::new(InMemoryRuleStateStore::new())),
63        }
64    }
65
66    pub fn has_rules_for_stage(&self, stage: RuleStage) -> bool {
67        self.compiled_rules
68            .iter()
69            .any(|r| r.original.active && r.original.stage == stage)
70    }
71
72    pub async fn execute(&self, stage: RuleStage, flow: &mut Flow) -> ExecutionContext {
73        let mut ctx = ExecutionContext {
74            trace: vec![],
75            variables: HashMap::new(),
76            policy: self.policy.clone(),
77            summary: RuleTraceSummary::NoMatch,
78            state_store: self.state_store.clone(),
79        };
80
81        let mut terminated = false;
82        let mut modified_rules = Vec::new();
83
84        for compiled_rule in &self.compiled_rules {
85            let rule = &compiled_rule.original;
86
87            if !rule.active {
88                continue;
89            }
90
91            if rule.stage != stage {
92                continue;
93            }
94
95            // Validate filter stage (using pre-compiled filter)
96            if !validator::validate_filter_stage(&compiled_rule.filter, &stage) {
97                ctx.trace.push(RuleExecutionEvent {
98                    rule_id: rule.id.clone(),
99                    stage: stage.clone(),
100                    matched: false,
101                    duration_us: 0,
102                    outcome: RuleOutcome::Failed(format!("Filter invalid for stage {:?}", stage)),
103                });
104                continue;
105            }
106
107            // Validate action stage
108            let mut actions_valid = true;
109            for action in &rule.actions {
110                if !validator::validate_action_stage(action, &stage) {
111                    ctx.trace.push(RuleExecutionEvent {
112                        rule_id: rule.id.clone(),
113                        stage: stage.clone(),
114                        matched: false,
115                        duration_us: 0,
116                        outcome: RuleOutcome::Failed(format!(
117                            "Action {:?} not allowed in stage {:?}",
118                            action, stage
119                        )),
120                    });
121                    actions_valid = false;
122                    break;
123                }
124            }
125            if !actions_valid {
126                continue;
127            }
128
129            // Match (using pre-compiled filter)
130            let start = std::time::Instant::now();
131            let matched = matcher::matches(&compiled_rule.filter, flow);
132
133            if matched {
134                let timeout_ms = rule.constraints.as_ref().and_then(|c| c.timeout_ms);
135
136                let action_execution = async {
137                    let mut rule_outcome = RuleOutcome::MatchedAndExecuted;
138                    let mut rule_terminated = false;
139
140                    for action in &rule.actions {
141                        match actions::execute_action(action, flow, &mut ctx).await {
142                            actions::ActionOutcome::Continue => {}
143                            actions::ActionOutcome::Terminated(reason) => {
144                                rule_outcome = RuleOutcome::MatchedAndTerminated;
145                                ctx.summary = RuleTraceSummary::Terminated {
146                                    rule_id: rule.id.clone(),
147                                    reason,
148                                };
149                                rule_terminated = true;
150                                break;
151                            }
152                            actions::ActionOutcome::Failed(err) => {
153                                rule_outcome = RuleOutcome::Failed(err);
154                                break;
155                            }
156                        }
157                    }
158                    (rule_outcome, rule_terminated)
159                };
160
161                let (rule_outcome, rule_terminated) = if let Some(ms) = timeout_ms {
162                    match tokio::time::timeout(
163                        std::time::Duration::from_millis(ms),
164                        action_execution,
165                    )
166                    .await
167                    {
168                        Ok(res) => res,
169                        Err(_) => (
170                            RuleOutcome::Failed(format!("Rule execution timed out after {}ms", ms)),
171                            false,
172                        ),
173                    }
174                } else {
175                    action_execution.await
176                };
177
178                if rule_terminated {
179                    terminated = true;
180                }
181
182                ctx.trace.push(RuleExecutionEvent {
183                    rule_id: rule.id.clone(),
184                    stage: stage.clone(),
185                    matched: true,
186                    duration_us: start.elapsed().as_micros() as u64,
187                    outcome: rule_outcome.clone(),
188                });
189
190                if matches!(rule_outcome, RuleOutcome::MatchedAndExecuted) {
191                    modified_rules.push(rule.id.clone());
192                }
193
194                if terminated {
195                    break;
196                }
197
198                if let RuleTermination::Stop = rule.termination {
199                    break;
200                }
201            }
202        }
203
204        if !terminated && !modified_rules.is_empty() {
205            ctx.summary = RuleTraceSummary::Modified {
206                rule_ids: modified_rules,
207            };
208        }
209
210        ctx
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::model::{
        Action, BodySource, Filter, Rule, RuleConstraints, RuleOutcome, RuleStage, RuleTermination,
        StringMatcher,
    };
    use chrono::Utc;
    use relay_core_api::flow::{Flow, HttpRequest, Layer, NetworkInfo, TransportProtocol};
    use url::Url;
    use uuid::Uuid;

    /// Builds a minimal plain-HTTP `GET http://example.com` flow with no
    /// response, headers, or body.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(relay_core_api::flow::HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: std::collections::HashMap::new(),
        }
    }

    /// Returns the failure message of the trace event at `index`, panicking
    /// when that event's outcome is anything other than `Failed`.
    fn failed_message(ctx: &ExecutionContext, index: usize) -> &str {
        match &ctx.trace[index].outcome {
            RuleOutcome::Failed(msg) => msg.as_str(),
            other => panic!("Expected Failed outcome, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_rule_timeout() {
        let rule = Rule {
            id: "test-rule-timeout".to_string(),
            name: "Test Rule Timeout".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![
                Action::Delay { ms: 200 }, // Delay 200ms
            ],
            constraints: Some(RuleConstraints {
                timeout_ms: Some(50), // Timeout 50ms
            }),
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in RequestHeaders stage
        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        // Expect 1 trace entry with Failed outcome due to timeout
        assert_eq!(ctx.trace.len(), 1);
        assert!(failed_message(&ctx, 0).contains("timed out"));
    }

    #[tokio::test]
    async fn test_filter_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-1".to_string(),
            name: "Test Rule 1".to_string(),
            active: true,
            stage: RuleStage::Connect, // Stage is Connect
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::Path(StringMatcher::Exact("/foo".to_string())), // Path filter invalid in Connect
            actions: vec![],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in Connect stage
        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        // Expect 1 trace entry with Failed outcome
        assert_eq!(ctx.trace.len(), 1);
        assert!(failed_message(&ctx, 0).contains("Filter invalid"));
    }

    #[tokio::test]
    async fn test_action_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-2".to_string(),
            name: "Test Rule 2".to_string(),
            active: true,
            stage: RuleStage::Connect, // Stage is Connect
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All, // Valid filter
            actions: vec![
                Action::SetResponseBody {
                    body: BodySource::Text("foo".to_string()),
                }, // Invalid action in Connect
            ],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in Connect stage
        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        // Expect 1 trace entry with Failed outcome
        assert_eq!(ctx.trace.len(), 1);
        let msg = failed_message(&ctx, 0);
        assert!(msg.contains("Action"));
        assert!(msg.contains("not allowed"));
    }

    #[test]
    fn test_has_rules_for_stage_respects_active_flag() {
        let inactive_rule = Rule {
            id: "inactive-rh".to_string(),
            name: "Inactive".to_string(),
            active: false,
            stage: RuleStage::RequestHeaders,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "k".to_string(),
                value: "v".to_string(),
            }],
            constraints: None,
        };
        let active_connect_rule = Rule {
            id: "active-connect".to_string(),
            name: "Active Connect".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Drop],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![inactive_rule, active_connect_rule], vec![], None, None);
        assert!(
            !engine.has_rules_for_stage(RuleStage::RequestHeaders),
            "inactive rules should not count for stage presence"
        );
        assert!(engine.has_rules_for_stage(RuleStage::Connect));
    }

    #[tokio::test]
    async fn test_execute_orders_rules_by_priority_descending() {
        let low = Rule {
            id: "low-pri".to_string(),
            name: "low".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "low".to_string(),
            }],
            constraints: None,
        };
        let high = Rule {
            id: "high-pri".to_string(),
            name: "high".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 100,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "high".to_string(),
            }],
            constraints: None,
        };
        // Rules are passed low-first on purpose: the engine must reorder them.
        let engine = RuleEngine::new(vec![low, high], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 2);
        assert_eq!(ctx.trace[0].rule_id, "high-pri");
        assert_eq!(ctx.trace[1].rule_id, "low-pri");
        assert_eq!(
            flow.tags,
            vec!["order:high".to_string(), "order:low".to_string()]
        );
    }
}