Skip to main content

relay_core_lib/rule/engine/
executor.rs

1use crate::rule::engine::actions;
2use crate::rule::engine::compiled::CompiledRule;
3use crate::rule::engine::compiler;
4use crate::rule::engine::matcher;
5use crate::rule::engine::state::{InMemoryRuleStateStore, RuleStateStore};
6use crate::rule::engine::validator;
7use crate::rule::model::{
8    Rule, RuleExecutionEvent, RuleGroup, RuleOutcome, RuleStage, RuleTermination, RuleTraceSummary,
9};
10use relay_core_api::flow::Flow;
11use relay_core_api::policy::ProxyPolicy;
12use std::collections::HashMap;
13use std::net::IpAddr;
14use std::sync::Arc;
15
/// Override produced at the Connect stage to redirect or adjust a connection target.
///
/// Rule actions store it in `ExecutionContext::connect_override`.
/// NOTE(review): the per-variant descriptions below are inferred from the variant and field
/// names — confirm against the transparent-proxy code that consumes them.
#[derive(Debug, Clone)]
pub enum ConnectOverride {
    /// Presumably: forward the connection to `host:port`.
    ForwardPort {
        host: String,
        port: u16,
    },
    /// Presumably: send the connection to `ip` instead of the original address.
    RedirectIp {
        ip: IpAddr,
    },
    /// Presumably: use `ttl` as the IP time-to-live for the connection.
    SetTtl {
        ttl: u8,
    },
}

impl ConnectOverride {
    /// Returns the port carried by a `ForwardPort` override.
    ///
    /// `RedirectIp` and `SetTtl` carry no port, so they yield `None`.
    pub fn port(&self) -> Option<u16> {
        if let ConnectOverride::ForwardPort { port, .. } = self {
            Some(*port)
        } else {
            None
        }
    }
}
32
33pub struct ExecutionContext {
34    pub trace: Vec<RuleExecutionEvent>,
35    pub variables: HashMap<String, String>,
36    pub policy: Option<Arc<ProxyPolicy>>,
37    pub summary: RuleTraceSummary,
38    pub state_store: Arc<dyn RuleStateStore>,
39    // RE2: Throttle bytes/sec set by Throttle action.
40    // Wiring into the body pipeline requires P1 streaming-first changes
41    // (ThrottleBody wrapping in http.rs body forwarding path).
42    pub throttle_bytes_per_sec: Option<u64>,
43    /// RE1: Connect-stage override for L3/L4 transparent proxy actions.
44    pub connect_override: Option<ConnectOverride>,
45}
46
47impl ExecutionContext {
48    pub fn is_terminated(&self) -> bool {
49        matches!(self.summary, RuleTraceSummary::Terminated { .. })
50    }
51}
52
53#[derive(Debug)]
54pub struct RuleEngine {
55    compiled_rules: Vec<CompiledRule>,
56    policy: Option<Arc<ProxyPolicy>>,
57    state_store: Arc<dyn RuleStateStore>,
58}
59
60impl RuleEngine {
61    pub fn new(
62        rules: Vec<Rule>,
63        rule_groups: Vec<RuleGroup>,
64        policy: Option<Arc<ProxyPolicy>>,
65        state_store: Option<Arc<dyn RuleStateStore>>,
66    ) -> Self {
67        let mut all_rules = Vec::new();
68
69        // Flatten rules
70        for rule in rules {
71            all_rules.push(rule);
72        }
73
74        // Flatten rule groups
75        for group in rule_groups {
76            if group.active {
77                for rule in group.rules {
78                    all_rules.push(rule);
79                }
80            }
81        }
82
83        // Sort by priority (descending)
84        all_rules.sort_by_key(|r| std::cmp::Reverse(r.priority));
85
86        // Compile all rules
87        let compiled_rules = all_rules.into_iter().map(compiler::compile_rule).collect();
88
89        Self {
90            compiled_rules,
91            policy,
92            state_store: state_store.unwrap_or_else(|| Arc::new(InMemoryRuleStateStore::new())),
93        }
94    }
95
96    pub fn has_rules_for_stage(&self, stage: RuleStage) -> bool {
97        self.compiled_rules
98            .iter()
99            .any(|r| r.original.active && r.original.stage == stage)
100    }
101
102    pub async fn execute(&self, stage: RuleStage, flow: &mut Flow) -> ExecutionContext {
103        let mut ctx = ExecutionContext {
104            trace: vec![],
105            variables: HashMap::new(),
106            policy: self.policy.clone(),
107            summary: RuleTraceSummary::NoMatch,
108            state_store: self.state_store.clone(),
109            throttle_bytes_per_sec: None,
110            connect_override: None,
111        };
112
113        let mut terminated = false;
114        let mut modified_rules = Vec::new();
115
116        for compiled_rule in &self.compiled_rules {
117            let rule = &compiled_rule.original;
118
119            if !rule.active {
120                continue;
121            }
122
123            if rule.stage != stage {
124                continue;
125            }
126
127            // Validate filter stage (using pre-compiled filter)
128            if !validator::validate_filter_stage(&compiled_rule.filter, &stage) {
129                ctx.trace.push(RuleExecutionEvent {
130                    rule_id: rule.id.clone(),
131                    stage: stage.clone(),
132                    matched: false,
133                    duration_us: 0,
134                    outcome: RuleOutcome::Failed(format!("Filter invalid for stage {:?}", stage)),
135                });
136                continue;
137            }
138
139            // Validate action stage
140            let mut actions_valid = true;
141            for action in &rule.actions {
142                if !validator::validate_action_stage(action, &stage) {
143                    ctx.trace.push(RuleExecutionEvent {
144                        rule_id: rule.id.clone(),
145                        stage: stage.clone(),
146                        matched: false,
147                        duration_us: 0,
148                        outcome: RuleOutcome::Failed(format!(
149                            "Action {:?} not allowed in stage {:?}",
150                            action, stage
151                        )),
152                    });
153                    actions_valid = false;
154                    break;
155                }
156            }
157            if !actions_valid {
158                continue;
159            }
160
161            // Match (using pre-compiled filter)
162            let start = std::time::Instant::now();
163            let matched = matcher::matches(&compiled_rule.filter, flow);
164
165            if matched {
166                let timeout_ms = rule.constraints.as_ref().and_then(|c| c.timeout_ms);
167
168                let action_execution = async {
169                    let mut rule_outcome = RuleOutcome::MatchedAndExecuted;
170                    let mut rule_terminated = false;
171
172                    for action in &rule.actions {
173                        match actions::execute_action(action, flow, &mut ctx).await {
174                            actions::ActionOutcome::Continue => {}
175                            actions::ActionOutcome::Terminated(reason) => {
176                                rule_outcome = RuleOutcome::MatchedAndTerminated;
177                                ctx.summary = RuleTraceSummary::Terminated {
178                                    rule_id: rule.id.clone(),
179                                    reason,
180                                };
181                                rule_terminated = true;
182                                break;
183                            }
184                            actions::ActionOutcome::Failed(err) => {
185                                rule_outcome = RuleOutcome::Failed(err);
186                                break;
187                            }
188                        }
189                    }
190                    (rule_outcome, rule_terminated)
191                };
192
193                let (rule_outcome, rule_terminated) = if let Some(ms) = timeout_ms {
194                    match tokio::time::timeout(
195                        std::time::Duration::from_millis(ms),
196                        action_execution,
197                    )
198                    .await
199                    {
200                        Ok(res) => res,
201                        Err(_) => (
202                            RuleOutcome::Failed(format!("Rule execution timed out after {}ms", ms)),
203                            false,
204                        ),
205                    }
206                } else {
207                    action_execution.await
208                };
209
210                if rule_terminated {
211                    terminated = true;
212                }
213
214                ctx.trace.push(RuleExecutionEvent {
215                    rule_id: rule.id.clone(),
216                    stage: stage.clone(),
217                    matched: true,
218                    duration_us: start.elapsed().as_micros() as u64,
219                    outcome: rule_outcome.clone(),
220                });
221
222                if matches!(rule_outcome, RuleOutcome::MatchedAndExecuted) {
223                    modified_rules.push(rule.id.clone());
224                }
225
226                if terminated {
227                    break;
228                }
229
230                if let RuleTermination::Stop = rule.termination {
231                    break;
232                }
233            }
234        }
235
236        if !terminated && !modified_rules.is_empty() {
237            ctx.summary = RuleTraceSummary::Modified {
238                rule_ids: modified_rules.clone(),
239            };
240        }
241
242        flow.rule_variables = ctx.variables.clone();
243        flow.matched_rules = modified_rules.clone();
244
245        ctx
246    }
247}
248
#[cfg(test)]
mod tests {
    //! Unit tests for `RuleEngine` stage validation, ordering, termination and timeouts.

    use super::*;
    use crate::rule::model::{
        Action, BodySource, Filter, Rule, RuleConstraints, RuleOutcome, RuleStage, RuleTermination,
        RuleTraceSummary, StringMatcher,
    };
    use chrono::Utc;
    use relay_core_api::flow::{Flow, HttpRequest, Layer, NetworkInfo, TransportProtocol};
    use url::Url;
    use uuid::Uuid;

    /// Minimal plain-HTTP GET flow to `http://example.com` with no rule state attached.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(relay_core_api::flow::HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: std::collections::HashMap::new(),
            resilience_trace: None,
            rule_variables: std::collections::HashMap::new(),
            matched_rules: vec![],
        }
    }

    #[tokio::test]
    async fn test_rule_timeout() {
        let rule = Rule {
            id: "test-rule-timeout".to_string(),
            name: "Test Rule Timeout".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![
                Action::Delay { ms: 200 }, // Delay 200ms
            ],
            constraints: Some(RuleConstraints {
                timeout_ms: Some(50), // Timeout 50ms
            }),
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in RequestHeaders stage
        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        // Expect 1 trace entry with Failed outcome due to timeout
        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("timed out"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    #[tokio::test]
    async fn test_filter_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-1".to_string(),
            name: "Test Rule 1".to_string(),
            active: true,
            stage: RuleStage::Connect, // Stage is Connect
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::Path(StringMatcher::Exact("/foo".to_string())), // Path filter invalid in Connect
            actions: vec![],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in Connect stage
        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        // Expect 1 trace entry with Failed outcome
        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("Filter invalid"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    #[tokio::test]
    async fn test_action_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-2".to_string(),
            name: "Test Rule 2".to_string(),
            active: true,
            stage: RuleStage::Connect, // Stage is Connect
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All, // Valid filter
            actions: vec![
                Action::SetResponseBody {
                    body: BodySource::Text("foo".to_string()),
                }, // Invalid action in Connect
            ],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        // Execute in Connect stage
        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        // Expect 1 trace entry with Failed outcome
        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("Action"));
            assert!(msg.contains("not allowed"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    #[test]
    fn test_has_rules_for_stage_respects_active_flag() {
        let inactive_rule = Rule {
            id: "inactive-rh".to_string(),
            name: "Inactive".to_string(),
            active: false,
            stage: RuleStage::RequestHeaders,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "k".to_string(),
                value: "v".to_string(),
            }],
            constraints: None,
        };
        let active_connect_rule = Rule {
            id: "active-connect".to_string(),
            name: "Active Connect".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Drop],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![inactive_rule, active_connect_rule], vec![], None, None);
        assert!(
            !engine.has_rules_for_stage(RuleStage::RequestHeaders),
            "inactive rules should not count for stage presence"
        );
        assert!(engine.has_rules_for_stage(RuleStage::Connect));
    }

    #[tokio::test]
    async fn test_execute_orders_rules_by_priority_descending() {
        let low = Rule {
            id: "low-pri".to_string(),
            name: "low".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "low".to_string(),
            }],
            constraints: None,
        };
        let high = Rule {
            id: "high-pri".to_string(),
            name: "high".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 100,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "high".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![low, high], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 2);
        assert_eq!(ctx.trace[0].rule_id, "high-pri");
        assert_eq!(ctx.trace[1].rule_id, "low-pri");
        assert_eq!(
            flow.tags,
            vec!["order:high".to_string(), "order:low".to_string()]
        );
    }

    /// A matched rule with `RuleTermination::Stop` must prevent lower-priority rules from
    /// running, and only the executed rule is reported as modified.
    #[tokio::test]
    async fn test_stop_termination_skips_lower_priority_rules() {
        let stopper = Rule {
            id: "stopper".to_string(),
            name: "stopper".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 100,
            termination: RuleTermination::Stop,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "high".to_string(),
            }],
            constraints: None,
        };
        let skipped = Rule {
            id: "skipped".to_string(),
            name: "skipped".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "low".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![skipped, stopper], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        assert_eq!(ctx.trace[0].rule_id, "stopper");
        assert_eq!(flow.tags, vec!["order:high".to_string()]);
        assert_eq!(flow.matched_rules, vec!["stopper".to_string()]);
        match &ctx.summary {
            RuleTraceSummary::Modified { rule_ids } => {
                assert_eq!(rule_ids, &vec!["stopper".to_string()]);
            }
            _ => panic!("Expected Modified summary"),
        }
    }
}