Skip to main content

relay_core_lib/rule/engine/
executor.rs

1use crate::rule::engine::actions;
2use crate::rule::engine::compiled::CompiledRule;
3use crate::rule::engine::compiler;
4use crate::rule::engine::matcher;
5use crate::rule::engine::state::{InMemoryRuleStateStore, RuleStateStore};
6use crate::rule::engine::validator;
7use crate::rule::model::{
8    Rule, RuleExecutionEvent, RuleGroup, RuleOutcome, RuleStage, RuleTermination, RuleTraceSummary,
9};
10use relay_core_api::flow::Flow;
11use relay_core_api::policy::ProxyPolicy;
12use std::collections::HashMap;
13use std::sync::Arc;
14
15pub struct ExecutionContext {
16    pub trace: Vec<RuleExecutionEvent>,
17    pub variables: HashMap<String, String>,
18    pub policy: Option<Arc<ProxyPolicy>>,
19    pub summary: RuleTraceSummary,
20    pub state_store: Arc<dyn RuleStateStore>,
21    // RE2: Throttle bytes/sec set by Throttle action.
22    // Wiring into the body pipeline requires P1 streaming-first changes
23    // (ThrottleBody wrapping in http.rs body forwarding path).
24    pub throttle_bytes_per_sec: Option<u64>,
25}
26
27#[derive(Debug)]
28pub struct RuleEngine {
29    compiled_rules: Vec<CompiledRule>,
30    policy: Option<Arc<ProxyPolicy>>,
31    state_store: Arc<dyn RuleStateStore>,
32}
33
34impl RuleEngine {
35    pub fn new(
36        rules: Vec<Rule>,
37        rule_groups: Vec<RuleGroup>,
38        policy: Option<Arc<ProxyPolicy>>,
39        state_store: Option<Arc<dyn RuleStateStore>>,
40    ) -> Self {
41        let mut all_rules = Vec::new();
42
43        // Flatten rules
44        for rule in rules {
45            all_rules.push(rule);
46        }
47
48        // Flatten rule groups
49        for group in rule_groups {
50            if group.active {
51                for rule in group.rules {
52                    all_rules.push(rule);
53                }
54            }
55        }
56
57        // Sort by priority (descending)
58        all_rules.sort_by_key(|r| std::cmp::Reverse(r.priority));
59
60        // Compile all rules
61        let compiled_rules = all_rules.into_iter().map(compiler::compile_rule).collect();
62
63        Self {
64            compiled_rules,
65            policy,
66            state_store: state_store.unwrap_or_else(|| Arc::new(InMemoryRuleStateStore::new())),
67        }
68    }
69
70    pub fn has_rules_for_stage(&self, stage: RuleStage) -> bool {
71        self.compiled_rules
72            .iter()
73            .any(|r| r.original.active && r.original.stage == stage)
74    }
75
76    pub async fn execute(&self, stage: RuleStage, flow: &mut Flow) -> ExecutionContext {
77        let mut ctx = ExecutionContext {
78            trace: vec![],
79            variables: HashMap::new(),
80            policy: self.policy.clone(),
81            summary: RuleTraceSummary::NoMatch,
82            state_store: self.state_store.clone(),
83            throttle_bytes_per_sec: None,
84        };
85
86        let mut terminated = false;
87        let mut modified_rules = Vec::new();
88
89        for compiled_rule in &self.compiled_rules {
90            let rule = &compiled_rule.original;
91
92            if !rule.active {
93                continue;
94            }
95
96            if rule.stage != stage {
97                continue;
98            }
99
100            // Validate filter stage (using pre-compiled filter)
101            if !validator::validate_filter_stage(&compiled_rule.filter, &stage) {
102                ctx.trace.push(RuleExecutionEvent {
103                    rule_id: rule.id.clone(),
104                    stage: stage.clone(),
105                    matched: false,
106                    duration_us: 0,
107                    outcome: RuleOutcome::Failed(format!("Filter invalid for stage {:?}", stage)),
108                });
109                continue;
110            }
111
112            // Validate action stage
113            let mut actions_valid = true;
114            for action in &rule.actions {
115                if !validator::validate_action_stage(action, &stage) {
116                    ctx.trace.push(RuleExecutionEvent {
117                        rule_id: rule.id.clone(),
118                        stage: stage.clone(),
119                        matched: false,
120                        duration_us: 0,
121                        outcome: RuleOutcome::Failed(format!(
122                            "Action {:?} not allowed in stage {:?}",
123                            action, stage
124                        )),
125                    });
126                    actions_valid = false;
127                    break;
128                }
129            }
130            if !actions_valid {
131                continue;
132            }
133
134            // Match (using pre-compiled filter)
135            let start = std::time::Instant::now();
136            let matched = matcher::matches(&compiled_rule.filter, flow);
137
138            if matched {
139                let timeout_ms = rule.constraints.as_ref().and_then(|c| c.timeout_ms);
140
141                let action_execution = async {
142                    let mut rule_outcome = RuleOutcome::MatchedAndExecuted;
143                    let mut rule_terminated = false;
144
145                    for action in &rule.actions {
146                        match actions::execute_action(action, flow, &mut ctx).await {
147                            actions::ActionOutcome::Continue => {}
148                            actions::ActionOutcome::Terminated(reason) => {
149                                rule_outcome = RuleOutcome::MatchedAndTerminated;
150                                ctx.summary = RuleTraceSummary::Terminated {
151                                    rule_id: rule.id.clone(),
152                                    reason,
153                                };
154                                rule_terminated = true;
155                                break;
156                            }
157                            actions::ActionOutcome::Failed(err) => {
158                                rule_outcome = RuleOutcome::Failed(err);
159                                break;
160                            }
161                        }
162                    }
163                    (rule_outcome, rule_terminated)
164                };
165
166                let (rule_outcome, rule_terminated) = if let Some(ms) = timeout_ms {
167                    match tokio::time::timeout(
168                        std::time::Duration::from_millis(ms),
169                        action_execution,
170                    )
171                    .await
172                    {
173                        Ok(res) => res,
174                        Err(_) => (
175                            RuleOutcome::Failed(format!("Rule execution timed out after {}ms", ms)),
176                            false,
177                        ),
178                    }
179                } else {
180                    action_execution.await
181                };
182
183                if rule_terminated {
184                    terminated = true;
185                }
186
187                ctx.trace.push(RuleExecutionEvent {
188                    rule_id: rule.id.clone(),
189                    stage: stage.clone(),
190                    matched: true,
191                    duration_us: start.elapsed().as_micros() as u64,
192                    outcome: rule_outcome.clone(),
193                });
194
195                if matches!(rule_outcome, RuleOutcome::MatchedAndExecuted) {
196                    modified_rules.push(rule.id.clone());
197                }
198
199                if terminated {
200                    break;
201                }
202
203                if let RuleTermination::Stop = rule.termination {
204                    break;
205                }
206            }
207        }
208
209        if !terminated && !modified_rules.is_empty() {
210            ctx.summary = RuleTraceSummary::Modified {
211                rule_ids: modified_rules,
212            };
213        }
214
215        ctx
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::model::{
        Action, BodySource, Filter, Rule, RuleConstraints, RuleOutcome, RuleStage, RuleTermination,
        StringMatcher,
    };
    use chrono::Utc;
    use relay_core_api::flow::{Flow, HttpRequest, Layer, NetworkInfo, TransportProtocol};
    use url::Url;
    use uuid::Uuid;

    /// Builds a minimal plain-HTTP `GET http://example.com` flow with no response.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(relay_core_api::flow::HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: std::collections::HashMap::new(),
        }
    }

    /// Returns the message carried by a `Failed` outcome and panics for any
    /// other outcome, so a wrong variant fails the test with a clear message.
    fn expect_failed(outcome: &RuleOutcome) -> &str {
        match outcome {
            RuleOutcome::Failed(msg) => msg.as_str(),
            other => panic!("Expected Failed outcome, got {:?}", other),
        }
    }

    /// A rule whose action outlasts its timeout is reported as failed.
    #[tokio::test]
    async fn test_rule_timeout() {
        let rule = Rule {
            id: "test-rule-timeout".to_string(),
            name: "Test Rule Timeout".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![
                Action::Delay { ms: 200 }, // Delay 200ms
            ],
            constraints: Some(RuleConstraints {
                timeout_ms: Some(50), // Timeout 50ms
            }),
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in RequestHeaders stage
        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        // Expect 1 trace entry with Failed outcome due to timeout
        assert_eq!(ctx.trace.len(), 1);
        let msg = expect_failed(&ctx.trace[0].outcome);
        assert!(msg.contains("timed out"));
    }

    /// A filter that is not valid for the rule's stage yields a failed trace entry.
    #[tokio::test]
    async fn test_filter_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-1".to_string(),
            name: "Test Rule 1".to_string(),
            active: true,
            stage: RuleStage::Connect, // Stage is Connect
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::Path(StringMatcher::Exact("/foo".to_string())), // Path filter invalid in Connect
            actions: vec![],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in Connect stage
        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        // Expect 1 trace entry with Failed outcome
        assert_eq!(ctx.trace.len(), 1);
        let msg = expect_failed(&ctx.trace[0].outcome);
        assert!(msg.contains("Filter invalid"));
    }

    /// An action that is not allowed in the rule's stage yields a failed trace entry.
    #[tokio::test]
    async fn test_action_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-2".to_string(),
            name: "Test Rule 2".to_string(),
            active: true,
            stage: RuleStage::Connect, // Stage is Connect
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All, // Valid filter
            actions: vec![
                Action::SetResponseBody {
                    body: BodySource::Text("foo".to_string()),
                }, // Invalid action in Connect
            ],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in Connect stage
        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        // Expect 1 trace entry with Failed outcome
        assert_eq!(ctx.trace.len(), 1);
        let msg = expect_failed(&ctx.trace[0].outcome);
        assert!(msg.contains("Action"));
        assert!(msg.contains("not allowed"));
    }

    /// Inactive rules must not make a stage look populated.
    #[test]
    fn test_has_rules_for_stage_respects_active_flag() {
        let inactive_rule = Rule {
            id: "inactive-rh".to_string(),
            name: "Inactive".to_string(),
            active: false,
            stage: RuleStage::RequestHeaders,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "k".to_string(),
                value: "v".to_string(),
            }],
            constraints: None,
        };
        let active_connect_rule = Rule {
            id: "active-connect".to_string(),
            name: "Active Connect".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Drop],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![inactive_rule, active_connect_rule], vec![], None, None);
        assert!(
            !engine.has_rules_for_stage(RuleStage::RequestHeaders),
            "inactive rules should not count for stage presence"
        );
        assert!(engine.has_rules_for_stage(RuleStage::Connect));
    }

    /// Rules run from highest to lowest priority regardless of insertion order.
    #[tokio::test]
    async fn test_execute_orders_rules_by_priority_descending() {
        let low = Rule {
            id: "low-pri".to_string(),
            name: "low".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "low".to_string(),
            }],
            constraints: None,
        };
        let high = Rule {
            id: "high-pri".to_string(),
            name: "high".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 100,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "high".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![low, high], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 2);
        assert_eq!(ctx.trace[0].rule_id, "high-pri");
        assert_eq!(ctx.trace[1].rule_id, "low-pri");
        assert_eq!(
            flow.tags,
            vec!["order:high".to_string(), "order:low".to_string()]
        );
    }
}