1use crate::rule::engine::actions;
2use crate::rule::engine::compiled::CompiledRule;
3use crate::rule::engine::compiler;
4use crate::rule::engine::matcher;
5use crate::rule::engine::state::{InMemoryRuleStateStore, RuleStateStore};
6use crate::rule::engine::validator;
7use crate::rule::model::{
8 Rule, RuleExecutionEvent, RuleGroup, RuleOutcome, RuleStage, RuleTermination, RuleTraceSummary,
9};
10use relay_core_api::flow::Flow;
11use relay_core_api::policy::ProxyPolicy;
12use std::collections::HashMap;
13use std::sync::Arc;
14
/// Mutable state threaded through one `RuleEngine::execute` call and handed
/// back to the caller once evaluation of the stage finishes.
pub struct ExecutionContext {
    /// Per-rule evaluation events. `RuleEngine::execute` appends one for each
    /// stage-validation failure and one for each rule whose filter matched;
    /// rules whose filter does not match add no event.
    pub trace: Vec<RuleExecutionEvent>,
    /// Starts empty for every execution. NOTE(review): presumably read/written
    /// by actions via `actions::execute_action` — confirm there.
    pub variables: HashMap<String, String>,
    /// The owning engine's proxy policy, cloned in for use during execution.
    pub policy: Option<Arc<ProxyPolicy>>,
    /// Overall result. `RuleEngine::execute` starts it at `NoMatch` and sets it
    /// to `Terminated` when an action terminates processing, or to `Modified`
    /// (listing fully executed rules) otherwise.
    pub summary: RuleTraceSummary,
    /// Rule state store shared with (cloned from) the owning engine.
    pub state_store: Arc<dyn RuleStateStore>,
}
22
/// Evaluates a fixed, pre-compiled set of rules against flows, one
/// [`RuleStage`] at a time.
#[derive(Debug)]
pub struct RuleEngine {
    /// Rules compiled once at construction, ordered by descending priority.
    compiled_rules: Vec<CompiledRule>,
    /// Policy cloned into every `ExecutionContext` this engine creates.
    policy: Option<Arc<ProxyPolicy>>,
    /// State store shared with every `ExecutionContext`; `RuleEngine::new`
    /// falls back to an `InMemoryRuleStateStore` when none is supplied.
    state_store: Arc<dyn RuleStateStore>,
}
29
30impl RuleEngine {
31 pub fn new(
32 rules: Vec<Rule>,
33 rule_groups: Vec<RuleGroup>,
34 policy: Option<Arc<ProxyPolicy>>,
35 state_store: Option<Arc<dyn RuleStateStore>>,
36 ) -> Self {
37 let mut all_rules = Vec::new();
38
39 for rule in rules {
41 all_rules.push(rule);
42 }
43
44 for group in rule_groups {
46 if group.active {
47 for rule in group.rules {
48 all_rules.push(rule);
49 }
50 }
51 }
52
53 all_rules.sort_by_key(|r| std::cmp::Reverse(r.priority));
55
56 let compiled_rules = all_rules.into_iter().map(compiler::compile_rule).collect();
58
59 Self {
60 compiled_rules,
61 policy,
62 state_store: state_store.unwrap_or_else(|| Arc::new(InMemoryRuleStateStore::new())),
63 }
64 }
65
66 pub fn has_rules_for_stage(&self, stage: RuleStage) -> bool {
67 self.compiled_rules
68 .iter()
69 .any(|r| r.original.active && r.original.stage == stage)
70 }
71
72 pub async fn execute(&self, stage: RuleStage, flow: &mut Flow) -> ExecutionContext {
73 let mut ctx = ExecutionContext {
74 trace: vec![],
75 variables: HashMap::new(),
76 policy: self.policy.clone(),
77 summary: RuleTraceSummary::NoMatch,
78 state_store: self.state_store.clone(),
79 };
80
81 let mut terminated = false;
82 let mut modified_rules = Vec::new();
83
84 for compiled_rule in &self.compiled_rules {
85 let rule = &compiled_rule.original;
86
87 if !rule.active {
88 continue;
89 }
90
91 if rule.stage != stage {
92 continue;
93 }
94
95 if !validator::validate_filter_stage(&compiled_rule.filter, &stage) {
97 ctx.trace.push(RuleExecutionEvent {
98 rule_id: rule.id.clone(),
99 stage: stage.clone(),
100 matched: false,
101 duration_us: 0,
102 outcome: RuleOutcome::Failed(format!("Filter invalid for stage {:?}", stage)),
103 });
104 continue;
105 }
106
107 let mut actions_valid = true;
109 for action in &rule.actions {
110 if !validator::validate_action_stage(action, &stage) {
111 ctx.trace.push(RuleExecutionEvent {
112 rule_id: rule.id.clone(),
113 stage: stage.clone(),
114 matched: false,
115 duration_us: 0,
116 outcome: RuleOutcome::Failed(format!(
117 "Action {:?} not allowed in stage {:?}",
118 action, stage
119 )),
120 });
121 actions_valid = false;
122 break;
123 }
124 }
125 if !actions_valid {
126 continue;
127 }
128
129 let start = std::time::Instant::now();
131 let matched = matcher::matches(&compiled_rule.filter, flow);
132
133 if matched {
134 let timeout_ms = rule.constraints.as_ref().and_then(|c| c.timeout_ms);
135
136 let action_execution = async {
137 let mut rule_outcome = RuleOutcome::MatchedAndExecuted;
138 let mut rule_terminated = false;
139
140 for action in &rule.actions {
141 match actions::execute_action(action, flow, &mut ctx).await {
142 actions::ActionOutcome::Continue => {}
143 actions::ActionOutcome::Terminated(reason) => {
144 rule_outcome = RuleOutcome::MatchedAndTerminated;
145 ctx.summary = RuleTraceSummary::Terminated {
146 rule_id: rule.id.clone(),
147 reason,
148 };
149 rule_terminated = true;
150 break;
151 }
152 actions::ActionOutcome::Failed(err) => {
153 rule_outcome = RuleOutcome::Failed(err);
154 break;
155 }
156 }
157 }
158 (rule_outcome, rule_terminated)
159 };
160
161 let (rule_outcome, rule_terminated) = if let Some(ms) = timeout_ms {
162 match tokio::time::timeout(
163 std::time::Duration::from_millis(ms),
164 action_execution,
165 )
166 .await
167 {
168 Ok(res) => res,
169 Err(_) => (
170 RuleOutcome::Failed(format!("Rule execution timed out after {}ms", ms)),
171 false,
172 ),
173 }
174 } else {
175 action_execution.await
176 };
177
178 if rule_terminated {
179 terminated = true;
180 }
181
182 ctx.trace.push(RuleExecutionEvent {
183 rule_id: rule.id.clone(),
184 stage: stage.clone(),
185 matched: true,
186 duration_us: start.elapsed().as_micros() as u64,
187 outcome: rule_outcome.clone(),
188 });
189
190 if matches!(rule_outcome, RuleOutcome::MatchedAndExecuted) {
191 modified_rules.push(rule.id.clone());
192 }
193
194 if terminated {
195 break;
196 }
197
198 if let RuleTermination::Stop = rule.termination {
199 break;
200 }
201 }
202 }
203
204 if !terminated && !modified_rules.is_empty() {
205 ctx.summary = RuleTraceSummary::Modified {
206 rule_ids: modified_rules,
207 };
208 }
209
210 ctx
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::model::{
        Action, BodySource, Filter, Rule, RuleConstraints, RuleOutcome, RuleStage, RuleTermination,
        RuleTraceSummary, StringMatcher,
    };
    use chrono::Utc;
    use relay_core_api::flow::{Flow, HttpRequest, Layer, NetworkInfo, TransportProtocol};
    use url::Url;
    use uuid::Uuid;

    /// Builds a minimal plain-HTTP `GET http://example.com` flow with no response.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(relay_core_api::flow::HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: std::collections::HashMap::new(),
        }
    }

    /// A rule whose actions outlast `timeout_ms` is traced as a timed-out failure.
    #[tokio::test]
    async fn test_rule_timeout() {
        let rule = Rule {
            id: "test-rule-timeout".to_string(),
            name: "Test Rule Timeout".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Delay { ms: 200 }],
            constraints: Some(RuleConstraints {
                timeout_ms: Some(50),
            }),
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("timed out"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    /// A filter that is not valid for the rule's stage is reported, not evaluated.
    #[tokio::test]
    async fn test_filter_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-1".to_string(),
            name: "Test Rule 1".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::Path(StringMatcher::Exact("/foo".to_string())),
            actions: vec![],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("Filter invalid"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    /// An action that is not allowed in the rule's stage is reported, not executed.
    #[tokio::test]
    async fn test_action_stage_validation_failure() {
        let rule = Rule {
            id: "test-rule-2".to_string(),
            name: "Test Rule 2".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::SetResponseBody {
                body: BodySource::Text("foo".to_string()),
            }],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        if let RuleOutcome::Failed(msg) = &ctx.trace[0].outcome {
            assert!(msg.contains("Action"));
            assert!(msg.contains("not allowed"));
        } else {
            panic!("Expected Failed outcome, got {:?}", ctx.trace[0].outcome);
        }
    }

    /// Inactive rules must not make a stage look populated.
    #[test]
    fn test_has_rules_for_stage_respects_active_flag() {
        let inactive_rule = Rule {
            id: "inactive-rh".to_string(),
            name: "Inactive".to_string(),
            active: false,
            stage: RuleStage::RequestHeaders,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "k".to_string(),
                value: "v".to_string(),
            }],
            constraints: None,
        };
        let active_connect_rule = Rule {
            id: "active-connect".to_string(),
            name: "Active Connect".to_string(),
            active: true,
            stage: RuleStage::Connect,
            priority: 10,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Drop],
            constraints: None,
        };

        let engine = RuleEngine::new(vec![inactive_rule, active_connect_rule], vec![], None, None);
        assert!(
            !engine.has_rules_for_stage(RuleStage::RequestHeaders),
            "inactive rules should not count for stage presence"
        );
        assert!(engine.has_rules_for_stage(RuleStage::Connect));
    }

    /// Higher-priority rules are evaluated (and applied) first.
    #[tokio::test]
    async fn test_execute_orders_rules_by_priority_descending() {
        let low = Rule {
            id: "low-pri".to_string(),
            name: "low".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "low".to_string(),
            }],
            constraints: None,
        };
        let high = Rule {
            id: "high-pri".to_string(),
            name: "high".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 100,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "high".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![low, high], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 2);
        assert_eq!(ctx.trace[0].rule_id, "high-pri");
        assert_eq!(ctx.trace[1].rule_id, "low-pri");
        assert_eq!(
            flow.tags,
            vec!["order:high".to_string(), "order:low".to_string()]
        );
    }

    /// A matched rule with `RuleTermination::Stop` prevents lower-priority rules
    /// from running, and still counts as a modification in the summary.
    #[tokio::test]
    async fn test_stop_termination_skips_lower_priority_rules() {
        let stopper = Rule {
            id: "stopper".to_string(),
            name: "stopper".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 10,
            termination: RuleTermination::Stop,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "first".to_string(),
            }],
            constraints: None,
        };
        let after = Rule {
            id: "after".to_string(),
            name: "after".to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 1,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![Action::Tag {
                key: "order".to_string(),
                value: "second".to_string(),
            }],
            constraints: None,
        };
        let engine = RuleEngine::new(vec![after, stopper], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 1);
        assert_eq!(ctx.trace[0].rule_id, "stopper");
        assert_eq!(flow.tags, vec!["order:first".to_string()]);
        if let RuleTraceSummary::Modified { rule_ids } = &ctx.summary {
            assert_eq!(rule_ids, &vec!["stopper".to_string()]);
        } else {
            panic!("Expected Modified summary");
        }
    }
}
441}