1use crate::rule::model::{Rule, RuleGroup, RuleStage, RuleOutcome, RuleExecutionEvent, RuleTermination, RuleTraceSummary};
2use crate::rule::engine::compiler;
3use crate::rule::engine::compiled::CompiledRule;
4use crate::rule::engine::matcher;
5use crate::rule::engine::validator;
6use crate::rule::engine::actions;
7use crate::rule::engine::state::{RuleStateStore, InMemoryRuleStateStore};
8use relay_core_api::flow::Flow;
9use std::collections::HashMap;
10use std::sync::Arc;
11use relay_core_api::policy::ProxyPolicy;
12
13pub struct ExecutionContext {
14 pub trace: Vec<RuleExecutionEvent>,
15 pub variables: HashMap<String, String>,
16 pub policy: Option<Arc<ProxyPolicy>>,
17 pub summary: RuleTraceSummary,
18 pub state_store: Arc<dyn RuleStateStore>,
19}
20
21#[derive(Debug)]
22pub struct RuleEngine {
23 compiled_rules: Vec<CompiledRule>,
24 policy: Option<Arc<ProxyPolicy>>,
25 state_store: Arc<dyn RuleStateStore>,
26}
27
28impl RuleEngine {
29 pub fn new(rules: Vec<Rule>, rule_groups: Vec<RuleGroup>, policy: Option<Arc<ProxyPolicy>>, state_store: Option<Arc<dyn RuleStateStore>>) -> Self {
30 let mut all_rules = Vec::new();
31
32 for rule in rules {
34 all_rules.push(rule);
35 }
36
37 for group in rule_groups {
39 if group.active {
40 for rule in group.rules {
41 all_rules.push(rule);
42 }
43 }
44 }
45
46 all_rules.sort_by_key(|r| std::cmp::Reverse(r.priority));
48
49 let compiled_rules = all_rules.into_iter()
51 .map(compiler::compile_rule)
52 .collect();
53
54 Self {
55 compiled_rules,
56 policy,
57 state_store: state_store.unwrap_or_else(|| Arc::new(InMemoryRuleStateStore::new()))
58 }
59 }
60
61 pub fn has_rules_for_stage(&self, stage: RuleStage) -> bool {
62 self.compiled_rules.iter().any(|r| r.original.active && r.original.stage == stage)
63 }
64
65 pub async fn execute(&self, stage: RuleStage, flow: &mut Flow) -> ExecutionContext {
66 let mut ctx = ExecutionContext {
67 trace: vec![],
68 variables: HashMap::new(),
69 policy: self.policy.clone(),
70 summary: RuleTraceSummary::NoMatch,
71 state_store: self.state_store.clone(),
72 };
73
74 let mut terminated = false;
75 let mut modified_rules = Vec::new();
76
77 for compiled_rule in &self.compiled_rules {
78 let rule = &compiled_rule.original;
79
80 if !rule.active {
81 continue;
82 }
83
84 if rule.stage != stage {
85 continue;
86 }
87
88 if !validator::validate_filter_stage(&compiled_rule.filter, &stage) {
90 ctx.trace.push(RuleExecutionEvent {
91 rule_id: rule.id.clone(),
92 stage: stage.clone(),
93 matched: false,
94 duration_us: 0,
95 outcome: RuleOutcome::Failed(format!("Filter invalid for stage {:?}", stage)),
96 });
97 continue;
98 }
99
100 let mut actions_valid = true;
102 for action in &rule.actions {
103 if !validator::validate_action_stage(action, &stage) {
104 ctx.trace.push(RuleExecutionEvent {
105 rule_id: rule.id.clone(),
106 stage: stage.clone(),
107 matched: false,
108 duration_us: 0,
109 outcome: RuleOutcome::Failed(format!("Action {:?} not allowed in stage {:?}", action, stage)),
110 });
111 actions_valid = false;
112 break;
113 }
114 }
115 if !actions_valid {
116 continue;
117 }
118
119 let start = std::time::Instant::now();
121 let matched = matcher::matches(&compiled_rule.filter, flow);
122
123 if matched {
124 let timeout_ms = rule.constraints.as_ref().and_then(|c| c.timeout_ms);
125
126 let action_execution = async {
127 let mut rule_outcome = RuleOutcome::MatchedAndExecuted;
128 let mut rule_terminated = false;
129
130 for action in &rule.actions {
131 match actions::execute_action(action, flow, &mut ctx).await {
132 actions::ActionOutcome::Continue => {},
133 actions::ActionOutcome::Terminated(reason) => {
134 rule_outcome = RuleOutcome::MatchedAndTerminated;
135 ctx.summary = RuleTraceSummary::Terminated {
136 rule_id: rule.id.clone(),
137 reason
138 };
139 rule_terminated = true;
140 break;
141 },
142 actions::ActionOutcome::Failed(err) => {
143 rule_outcome = RuleOutcome::Failed(err);
144 break;
145 }
146 }
147 }
148 (rule_outcome, rule_terminated)
149 };
150
151 let (rule_outcome, rule_terminated) = if let Some(ms) = timeout_ms {
152 match tokio::time::timeout(std::time::Duration::from_millis(ms), action_execution).await {
153 Ok(res) => res,
154 Err(_) => (RuleOutcome::Failed(format!("Rule execution timed out after {}ms", ms)), false)
155 }
156 } else {
157 action_execution.await
158 };
159
160 if rule_terminated {
161 terminated = true;
162 }
163
164 ctx.trace.push(RuleExecutionEvent {
165 rule_id: rule.id.clone(),
166 stage: stage.clone(),
167 matched: true,
168 duration_us: start.elapsed().as_micros() as u64,
169 outcome: rule_outcome.clone(),
170 });
171
172 if matches!(rule_outcome, RuleOutcome::MatchedAndExecuted) {
173 modified_rules.push(rule.id.clone());
174 }
175
176 if terminated {
177 break;
178 }
179
180 if let RuleTermination::Stop = rule.termination {
181 break;
182 }
183 }
184 }
185
186 if !terminated && !modified_rules.is_empty() {
187 ctx.summary = RuleTraceSummary::Modified { rule_ids: modified_rules };
188 }
189
190 ctx
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::model::{
        Action, BodySource, Filter, Rule, RuleConstraints, RuleOutcome, RuleStage, RuleTermination,
        RuleTraceSummary, StringMatcher,
    };
    use chrono::Utc;
    use relay_core_api::flow::{Flow, HttpRequest, Layer, NetworkInfo, TransportProtocol};
    use url::Url;
    use uuid::Uuid;

    /// Minimal plain-HTTP GET flow to `http://example.com` with no headers or body.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(relay_core_api::flow::HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: std::collections::HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_rule_timeout() {
        let rule = Rule {
            id: "test-rule-timeout".to_string(),
            name: "Test Rule Timeout".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Delay { ms: 200 }],
            constraints: Some(RuleConstraints {
                timeout_ms: Some(50),
            }),
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("timed out"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    #[tokio::test]
    async fn test_filter_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-1".to_string(),
            name: "Test Rule 1".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::Path(StringMatcher::Exact("/foo".to_string())),
            actions: vec![],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("Filter invalid"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    #[tokio::test]
    async fn test_action_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-2".to_string(),
            name: "Test Rule 2".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::SetResponseBody {
                body: BodySource::Text("foo".to_string()),
            }],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("Action"));
            assert!(msg.contains("not allowed"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    #[test]
    fn test_has_rules_for_stage_respects_active_flag() {
        let inactive_rule = Rule {
            id: "inactive-rh".to_string(),
            name: "Inactive".to_string(),
            active: false,
            stage: RuleStage::RequestHeaders,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "k".to_string(),
                value: "v".to_string(),
            }],
            constraints: None,
        };
        let active_connect_rule = Rule {
            id: "active-connect".to_string(),
            name: "Active Connect".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Drop],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![inactive_rule, active_connect_rule], vec![], None, None);
        assert!(
            !engine.has_rules_for_stage(RuleStage::RequestHeaders),
            "inactive rules should not count for stage presence"
        );
        assert!(engine.has_rules_for_stage(RuleStage::Connect));
    }

    #[tokio::test]
    async fn test_execute_orders_rules_by_priority_descending() {
        let low = Rule {
            id: "low-pri".to_string(),
            name: "low".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "low".to_string(),
            }],
            constraints: None,
        };
        let high = Rule {
            id: "high-pri".to_string(),
            name: "high".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 100,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "high".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![low, high], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 2);
        assert_eq!(ctx.trace[0].rule_id, "high-pri");
        assert_eq!(ctx.trace[1].rule_id, "low-pri");
        assert_eq!(flow.tags, vec!["order:high".to_string(), "order:low".to_string()]);
    }

    /// A matched rule with `RuleTermination::Stop` must prevent any
    /// lower-priority rule from running, while still counting as a
    /// modification in the summary.
    #[tokio::test]
    async fn test_stop_termination_halts_lower_priority_rules() {
        let stopper = Rule {
            id: "stopper".to_string(),
            name: "stopper".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 10,
            termination: RuleTermination::Stop,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "first".to_string(),
            }],
            constraints: None,
        };
        let later = Rule {
            id: "later".to_string(),
            name: "later".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "second".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![later, stopper], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        assert_eq!(ctx.trace[0].rule_id, "stopper");
        assert!(ctx.trace[0].matched);
        assert_eq!(flow.tags, vec!["order:first".to_string()]);
        assert!(
            matches!(
                &ctx.summary,
                RuleTraceSummary::Modified { rule_ids } if rule_ids == &vec!["stopper".to_string()]
            ),
            "Stop should still report the stopping rule as a modification"
        );
    }
}