1use crate::rule::engine::actions;
2use crate::rule::engine::compiled::CompiledRule;
3use crate::rule::engine::compiler;
4use crate::rule::engine::matcher;
5use crate::rule::engine::state::{InMemoryRuleStateStore, RuleStateStore};
6use crate::rule::engine::validator;
7use crate::rule::model::{
8 Rule, RuleExecutionEvent, RuleGroup, RuleOutcome, RuleStage, RuleTermination, RuleTraceSummary,
9};
10use relay_core_api::flow::Flow;
11use relay_core_api::policy::ProxyPolicy;
12use std::collections::HashMap;
13use std::sync::Arc;
14
/// Mutable state threaded through a single [`RuleEngine::execute`] call.
///
/// `execute` builds a fresh context for every call, passes it mutably to each
/// executed action, and returns it to the caller when evaluation finishes.
pub struct ExecutionContext {
    /// Per-rule events in evaluation order.
    ///
    /// Contains one `matched: false` / `Failed` event for each rule rejected by
    /// stage validation, and one `matched: true` event for each rule whose filter
    /// matched. Inactive rules, rules for other stages, and rules whose filter did
    /// not match leave no event.
    pub trace: Vec<RuleExecutionEvent>,
    /// String key/value map, initialised empty by `execute`.
    ///
    /// `execute` itself never writes it. Actions receive `&mut ExecutionContext`
    /// and may populate it.
    pub variables: HashMap<String, String>,
    /// Clone of the engine's proxy policy handle, if one was configured.
    pub policy: Option<Arc<ProxyPolicy>>,
    /// Overall result of the execution.
    ///
    /// - Starts as `NoMatch`.
    /// - Becomes `Terminated` when an action terminates the flow.
    /// - Otherwise becomes `Modified` (listing rules that ran all their actions) if any rule did so.
    pub summary: RuleTraceSummary,
    /// Shared handle to the engine's rule state store.
    pub state_store: Arc<dyn RuleStateStore>,
    /// Initialised to `None` by `execute`.
    ///
    /// NOTE(review): presumably set by a throttling action — confirm in `actions`.
    pub throttle_bytes_per_sec: Option<u64>,
}
26
/// Evaluates pre-compiled, priority-ordered rules against flows at a given [`RuleStage`].
#[derive(Debug)]
pub struct RuleEngine {
    /// Compiled rules, sorted by descending priority when the engine is built.
    compiled_rules: Vec<CompiledRule>,
    /// Policy handle cloned into every [`ExecutionContext`] this engine creates.
    policy: Option<Arc<ProxyPolicy>>,
    /// State store shared with every [`ExecutionContext`] this engine creates.
    state_store: Arc<dyn RuleStateStore>,
}
33
34impl RuleEngine {
35 pub fn new(
36 rules: Vec<Rule>,
37 rule_groups: Vec<RuleGroup>,
38 policy: Option<Arc<ProxyPolicy>>,
39 state_store: Option<Arc<dyn RuleStateStore>>,
40 ) -> Self {
41 let mut all_rules = Vec::new();
42
43 for rule in rules {
45 all_rules.push(rule);
46 }
47
48 for group in rule_groups {
50 if group.active {
51 for rule in group.rules {
52 all_rules.push(rule);
53 }
54 }
55 }
56
57 all_rules.sort_by_key(|r| std::cmp::Reverse(r.priority));
59
60 let compiled_rules = all_rules.into_iter().map(compiler::compile_rule).collect();
62
63 Self {
64 compiled_rules,
65 policy,
66 state_store: state_store.unwrap_or_else(|| Arc::new(InMemoryRuleStateStore::new())),
67 }
68 }
69
70 pub fn has_rules_for_stage(&self, stage: RuleStage) -> bool {
71 self.compiled_rules
72 .iter()
73 .any(|r| r.original.active && r.original.stage == stage)
74 }
75
76 pub async fn execute(&self, stage: RuleStage, flow: &mut Flow) -> ExecutionContext {
77 let mut ctx = ExecutionContext {
78 trace: vec![],
79 variables: HashMap::new(),
80 policy: self.policy.clone(),
81 summary: RuleTraceSummary::NoMatch,
82 state_store: self.state_store.clone(),
83 throttle_bytes_per_sec: None,
84 };
85
86 let mut terminated = false;
87 let mut modified_rules = Vec::new();
88
89 for compiled_rule in &self.compiled_rules {
90 let rule = &compiled_rule.original;
91
92 if !rule.active {
93 continue;
94 }
95
96 if rule.stage != stage {
97 continue;
98 }
99
100 if !validator::validate_filter_stage(&compiled_rule.filter, &stage) {
102 ctx.trace.push(RuleExecutionEvent {
103 rule_id: rule.id.clone(),
104 stage: stage.clone(),
105 matched: false,
106 duration_us: 0,
107 outcome: RuleOutcome::Failed(format!("Filter invalid for stage {:?}", stage)),
108 });
109 continue;
110 }
111
112 let mut actions_valid = true;
114 for action in &rule.actions {
115 if !validator::validate_action_stage(action, &stage) {
116 ctx.trace.push(RuleExecutionEvent {
117 rule_id: rule.id.clone(),
118 stage: stage.clone(),
119 matched: false,
120 duration_us: 0,
121 outcome: RuleOutcome::Failed(format!(
122 "Action {:?} not allowed in stage {:?}",
123 action, stage
124 )),
125 });
126 actions_valid = false;
127 break;
128 }
129 }
130 if !actions_valid {
131 continue;
132 }
133
134 let start = std::time::Instant::now();
136 let matched = matcher::matches(&compiled_rule.filter, flow);
137
138 if matched {
139 let timeout_ms = rule.constraints.as_ref().and_then(|c| c.timeout_ms);
140
141 let action_execution = async {
142 let mut rule_outcome = RuleOutcome::MatchedAndExecuted;
143 let mut rule_terminated = false;
144
145 for action in &rule.actions {
146 match actions::execute_action(action, flow, &mut ctx).await {
147 actions::ActionOutcome::Continue => {}
148 actions::ActionOutcome::Terminated(reason) => {
149 rule_outcome = RuleOutcome::MatchedAndTerminated;
150 ctx.summary = RuleTraceSummary::Terminated {
151 rule_id: rule.id.clone(),
152 reason,
153 };
154 rule_terminated = true;
155 break;
156 }
157 actions::ActionOutcome::Failed(err) => {
158 rule_outcome = RuleOutcome::Failed(err);
159 break;
160 }
161 }
162 }
163 (rule_outcome, rule_terminated)
164 };
165
166 let (rule_outcome, rule_terminated) = if let Some(ms) = timeout_ms {
167 match tokio::time::timeout(
168 std::time::Duration::from_millis(ms),
169 action_execution,
170 )
171 .await
172 {
173 Ok(res) => res,
174 Err(_) => (
175 RuleOutcome::Failed(format!("Rule execution timed out after {}ms", ms)),
176 false,
177 ),
178 }
179 } else {
180 action_execution.await
181 };
182
183 if rule_terminated {
184 terminated = true;
185 }
186
187 ctx.trace.push(RuleExecutionEvent {
188 rule_id: rule.id.clone(),
189 stage: stage.clone(),
190 matched: true,
191 duration_us: start.elapsed().as_micros() as u64,
192 outcome: rule_outcome.clone(),
193 });
194
195 if matches!(rule_outcome, RuleOutcome::MatchedAndExecuted) {
196 modified_rules.push(rule.id.clone());
197 }
198
199 if terminated {
200 break;
201 }
202
203 if let RuleTermination::Stop = rule.termination {
204 break;
205 }
206 }
207 }
208
209 if !terminated && !modified_rules.is_empty() {
210 ctx.summary = RuleTraceSummary::Modified {
211 rule_ids: modified_rules,
212 };
213 }
214
215 ctx
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::model::{
        Action, BodySource, Filter, Rule, RuleConstraints, RuleOutcome, RuleStage, RuleTermination,
        StringMatcher,
    };
    use chrono::Utc;
    use relay_core_api::flow::{Flow, HttpRequest, Layer, NetworkInfo, TransportProtocol};
    use url::Url;
    use uuid::Uuid;

    /// Builds a minimal plain-HTTP `GET http://example.com` flow with no response.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(relay_core_api::flow::HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: std::collections::HashMap::new(),
        }
    }

    /// Asserts that exactly one trace event was recorded and returns its failure message.
    ///
    /// Panics, failing the calling test, if the trace does not hold exactly one event
    /// or that event's outcome is not `Failed`.
    fn sole_failure_message(ctx: &ExecutionContext) -> &str {
        assert_eq!(ctx.trace.len(), 1);
        match &ctx.trace[0].outcome {
            RuleOutcome::Failed(msg) => msg.as_str(),
            other => panic!("Expected Failed outcome, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_rule_timeout() {
        // The delay deliberately exceeds the rule's timeout.
        let rule = Rule {
            id: "test-rule-timeout".to_string(),
            name: "Test Rule Timeout".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Delay { ms: 200 }],
            constraints: Some(RuleConstraints {
                timeout_ms: Some(50),
            }),
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        assert!(sole_failure_message(&ctx).contains("timed out"));
    }

    #[tokio::test]
    async fn test_filter_stage_validation_failure() {
        // A path filter cannot be evaluated at the Connect stage.
        let rule = Rule {
            id: "test-rule-1".to_string(),
            name: "Test Rule 1".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::Path(StringMatcher::Exact("/foo".to_string())),
            actions: vec![],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        assert!(sole_failure_message(&ctx).contains("Filter invalid"));
    }

    #[tokio::test]
    async fn test_action_stage_validation_failure() {
        // Rewriting the response body is not allowed at the Connect stage.
        let rule = Rule {
            id: "test-rule-2".to_string(),
            name: "Test Rule 2".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::SetResponseBody {
                body: BodySource::Text("foo".to_string()),
            }],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        let msg = sole_failure_message(&ctx);
        assert!(msg.contains("Action"));
        assert!(msg.contains("not allowed"));
    }

    #[test]
    fn test_has_rules_for_stage_respects_active_flag() {
        let inactive_rule = Rule {
            id: "inactive-rh".to_string(),
            name: "Inactive".to_string(),
            active: false,
            stage: RuleStage::RequestHeaders,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "k".to_string(),
                value: "v".to_string(),
            }],
            constraints: None,
        };
        let active_connect_rule = Rule {
            id: "active-connect".to_string(),
            name: "Active Connect".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Drop],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![inactive_rule, active_connect_rule], vec![], None, None);
        assert!(
            !engine.has_rules_for_stage(RuleStage::RequestHeaders),
            "inactive rules should not count for stage presence"
        );
        assert!(engine.has_rules_for_stage(RuleStage::Connect));
    }

    #[tokio::test]
    async fn test_execute_orders_rules_by_priority_descending() {
        let low = Rule {
            id: "low-pri".to_string(),
            name: "low".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "low".to_string(),
            }],
            constraints: None,
        };
        let high = Rule {
            id: "high-pri".to_string(),
            name: "high".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 100,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "high".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![low, high], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 2);
        assert_eq!(ctx.trace[0].rule_id, "high-pri");
        assert_eq!(ctx.trace[1].rule_id, "low-pri");
        assert_eq!(
            flow.tags,
            vec!["order:high".to_string(), "order:low".to_string()]
        );
    }
}