1use crate::rule::engine::actions;
2use crate::rule::engine::compiled::CompiledRule;
3use crate::rule::engine::compiler;
4use crate::rule::engine::matcher;
5use crate::rule::engine::state::{InMemoryRuleStateStore, RuleStateStore};
6use crate::rule::engine::validator;
7use crate::rule::model::{
8 Rule, RuleExecutionEvent, RuleGroup, RuleOutcome, RuleStage, RuleTermination, RuleTraceSummary,
9};
10use relay_core_api::flow::Flow;
11use relay_core_api::policy::ProxyPolicy;
12use std::collections::HashMap;
13use std::net::IpAddr;
14use std::sync::Arc;
15
/// Override for how an upstream connection is made.
///
/// Carried in `ExecutionContext::connect_override`; the engine starts every
/// run with no override (`None`). The per-variant descriptions below are
/// inferred from variant names — NOTE(review): confirm against the connector.
#[derive(Debug, Clone)]
pub enum ConnectOverride {
    /// Forward to `host:port` (presumably instead of the original destination).
    ForwardPort { host: String, port: u16 },
    /// Connect to `ip` (presumably instead of the resolved address).
    RedirectIp { ip: IpAddr },
    /// Presumably the IP time-to-live to use for the connection.
    SetTtl { ttl: u8 },
}

impl ConnectOverride {
    /// Returns the target port of a [`ConnectOverride::ForwardPort`] override.
    ///
    /// Every other variant returns `None`.
    pub fn port(&self) -> Option<u16> {
        match self {
            ConnectOverride::ForwardPort { port, .. } => Some(*port),
            // Variants are listed explicitly (no `_` arm) so that adding a new
            // variant that carries a port fails to compile until handled here.
            ConnectOverride::RedirectIp { .. } | ConnectOverride::SetTtl { .. } => None,
        }
    }
}
32
33pub struct ExecutionContext {
34 pub trace: Vec<RuleExecutionEvent>,
35 pub variables: HashMap<String, String>,
36 pub policy: Option<Arc<ProxyPolicy>>,
37 pub summary: RuleTraceSummary,
38 pub state_store: Arc<dyn RuleStateStore>,
39 pub throttle_bytes_per_sec: Option<u64>,
43 pub connect_override: Option<ConnectOverride>,
45}
46
47impl ExecutionContext {
48 pub fn is_terminated(&self) -> bool {
49 matches!(self.summary, RuleTraceSummary::Terminated { .. })
50 }
51}
52
53#[derive(Debug)]
54pub struct RuleEngine {
55 compiled_rules: Vec<CompiledRule>,
56 policy: Option<Arc<ProxyPolicy>>,
57 state_store: Arc<dyn RuleStateStore>,
58}
59
60impl RuleEngine {
61 pub fn new(
62 rules: Vec<Rule>,
63 rule_groups: Vec<RuleGroup>,
64 policy: Option<Arc<ProxyPolicy>>,
65 state_store: Option<Arc<dyn RuleStateStore>>,
66 ) -> Self {
67 let mut all_rules = Vec::new();
68
69 for rule in rules {
71 all_rules.push(rule);
72 }
73
74 for group in rule_groups {
76 if group.active {
77 for rule in group.rules {
78 all_rules.push(rule);
79 }
80 }
81 }
82
83 all_rules.sort_by_key(|r| std::cmp::Reverse(r.priority));
85
86 let compiled_rules = all_rules.into_iter().map(compiler::compile_rule).collect();
88
89 Self {
90 compiled_rules,
91 policy,
92 state_store: state_store.unwrap_or_else(|| Arc::new(InMemoryRuleStateStore::new())),
93 }
94 }
95
96 pub fn has_rules_for_stage(&self, stage: RuleStage) -> bool {
97 self.compiled_rules
98 .iter()
99 .any(|r| r.original.active && r.original.stage == stage)
100 }
101
102 pub async fn execute(&self, stage: RuleStage, flow: &mut Flow) -> ExecutionContext {
103 let mut ctx = ExecutionContext {
104 trace: vec![],
105 variables: HashMap::new(),
106 policy: self.policy.clone(),
107 summary: RuleTraceSummary::NoMatch,
108 state_store: self.state_store.clone(),
109 throttle_bytes_per_sec: None,
110 connect_override: None,
111 };
112
113 let mut terminated = false;
114 let mut modified_rules = Vec::new();
115
116 for compiled_rule in &self.compiled_rules {
117 let rule = &compiled_rule.original;
118
119 if !rule.active {
120 continue;
121 }
122
123 if rule.stage != stage {
124 continue;
125 }
126
127 if !validator::validate_filter_stage(&compiled_rule.filter, &stage) {
129 ctx.trace.push(RuleExecutionEvent {
130 rule_id: rule.id.clone(),
131 stage: stage.clone(),
132 matched: false,
133 duration_us: 0,
134 outcome: RuleOutcome::Failed(format!("Filter invalid for stage {:?}", stage)),
135 });
136 continue;
137 }
138
139 let mut actions_valid = true;
141 for action in &rule.actions {
142 if !validator::validate_action_stage(action, &stage) {
143 ctx.trace.push(RuleExecutionEvent {
144 rule_id: rule.id.clone(),
145 stage: stage.clone(),
146 matched: false,
147 duration_us: 0,
148 outcome: RuleOutcome::Failed(format!(
149 "Action {:?} not allowed in stage {:?}",
150 action, stage
151 )),
152 });
153 actions_valid = false;
154 break;
155 }
156 }
157 if !actions_valid {
158 continue;
159 }
160
161 let start = std::time::Instant::now();
163 let matched = matcher::matches(&compiled_rule.filter, flow);
164
165 if matched {
166 let timeout_ms = rule.constraints.as_ref().and_then(|c| c.timeout_ms);
167
168 let action_execution = async {
169 let mut rule_outcome = RuleOutcome::MatchedAndExecuted;
170 let mut rule_terminated = false;
171
172 for action in &rule.actions {
173 match actions::execute_action(action, flow, &mut ctx).await {
174 actions::ActionOutcome::Continue => {}
175 actions::ActionOutcome::Terminated(reason) => {
176 rule_outcome = RuleOutcome::MatchedAndTerminated;
177 ctx.summary = RuleTraceSummary::Terminated {
178 rule_id: rule.id.clone(),
179 reason,
180 };
181 rule_terminated = true;
182 break;
183 }
184 actions::ActionOutcome::Failed(err) => {
185 rule_outcome = RuleOutcome::Failed(err);
186 break;
187 }
188 }
189 }
190 (rule_outcome, rule_terminated)
191 };
192
193 let (rule_outcome, rule_terminated) = if let Some(ms) = timeout_ms {
194 match tokio::time::timeout(
195 std::time::Duration::from_millis(ms),
196 action_execution,
197 )
198 .await
199 {
200 Ok(res) => res,
201 Err(_) => (
202 RuleOutcome::Failed(format!("Rule execution timed out after {}ms", ms)),
203 false,
204 ),
205 }
206 } else {
207 action_execution.await
208 };
209
210 if rule_terminated {
211 terminated = true;
212 }
213
214 ctx.trace.push(RuleExecutionEvent {
215 rule_id: rule.id.clone(),
216 stage: stage.clone(),
217 matched: true,
218 duration_us: start.elapsed().as_micros() as u64,
219 outcome: rule_outcome.clone(),
220 });
221
222 if matches!(rule_outcome, RuleOutcome::MatchedAndExecuted) {
223 modified_rules.push(rule.id.clone());
224 }
225
226 if terminated {
227 break;
228 }
229
230 if let RuleTermination::Stop = rule.termination {
231 break;
232 }
233 }
234 }
235
236 if !terminated && !modified_rules.is_empty() {
237 ctx.summary = RuleTraceSummary::Modified {
238 rule_ids: modified_rules.clone(),
239 };
240 }
241
242 flow.rule_variables = ctx.variables.clone();
243 flow.matched_rules = modified_rules.clone();
244
245 ctx
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::model::{
        Action, BodySource, Filter, Rule, RuleConstraints, RuleOutcome, RuleStage, RuleTermination,
        StringMatcher,
    };
    use chrono::Utc;
    use relay_core_api::flow::{Flow, HttpRequest, Layer, NetworkInfo, TransportProtocol};
    use url::Url;
    use uuid::Uuid;

    /// Builds a plain HTTP/1.1 `GET http://example.com` flow from
    /// 127.0.0.1:12345 to 1.1.1.1:80 with no response, tags or rule state.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(relay_core_api::flow::HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: std::collections::HashMap::new(),
            resilience_trace: None,
            rule_variables: std::collections::HashMap::new(),
            matched_rules: vec![],
        }
    }

    /// Base rule for tests. Individual tests override only the fields they
    /// care about via struct-update syntax.
    ///
    /// Defaults:
    /// - active, at `RequestHeaders`, priority 0
    /// - `Continue` termination
    /// - match-all filter
    /// - no actions and no constraints
    fn test_rule(id: &str) -> Rule {
        Rule {
            id: id.to_string(),
            name: id.to_string(),
            active: true,
            stage: RuleStage::RequestHeaders,
            priority: 0,
            termination: RuleTermination::Continue,
            filter: Filter::All,
            actions: vec![],
            constraints: None,
        }
    }

    /// Shorthand for an `Action::Tag { key, value }`.
    fn tag(key: &str, value: &str) -> Action {
        Action::Tag {
            key: key.to_string(),
            value: value.to_string(),
        }
    }

    /// Returns the message of a `RuleOutcome::Failed`, panicking on any other outcome.
    fn failure_message(outcome: &RuleOutcome) -> &str {
        match outcome {
            RuleOutcome::Failed(msg) => msg,
            other => panic!("Expected Failed outcome, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_rule_timeout() {
        // The delay outlasts the rule's timeout, so execution must be cut short.
        let rule = Rule {
            actions: vec![Action::Delay { ms: 200 }],
            constraints: Some(RuleConstraints {
                timeout_ms: Some(50),
            }),
            ..test_rule("test-rule-timeout")
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        assert!(failure_message(&ctx.trace[0].outcome).contains("timed out"));
    }

    #[tokio::test]
    async fn test_filter_stage_validation_failure() {
        // A path filter cannot be evaluated at the Connect stage.
        let rule = Rule {
            stage: RuleStage::Connect,
            filter: Filter::Path(StringMatcher::Exact("/foo".to_string())),
            ..test_rule("test-rule-1")
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        assert!(failure_message(&ctx.trace[0].outcome).contains("Filter invalid"));
    }

    #[tokio::test]
    async fn test_action_stage_validation_failure() {
        // A response-body action is not allowed at the Connect stage.
        let rule = Rule {
            stage: RuleStage::Connect,
            actions: vec![Action::SetResponseBody {
                body: BodySource::Text("foo".to_string()),
            }],
            ..test_rule("test-rule-2")
        };

        let engine = RuleEngine::new(vec![rule], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::Connect, &mut flow).await;

        assert_eq!(ctx.trace.len(), 1);
        let msg = failure_message(&ctx.trace[0].outcome);
        assert!(msg.contains("Action"));
        assert!(msg.contains("not allowed"));
    }

    #[test]
    fn test_has_rules_for_stage_respects_active_flag() {
        let inactive_rule = Rule {
            active: false,
            priority: 10,
            actions: vec![tag("k", "v")],
            ..test_rule("inactive-rh")
        };
        let active_connect_rule = Rule {
            stage: RuleStage::Connect,
            priority: 10,
            actions: vec![Action::Drop],
            ..test_rule("active-connect")
        };

        let engine = RuleEngine::new(vec![inactive_rule, active_connect_rule], vec![], None, None);
        assert!(
            !engine.has_rules_for_stage(RuleStage::RequestHeaders),
            "inactive rules should not count for stage presence"
        );
        assert!(engine.has_rules_for_stage(RuleStage::Connect));
    }

    #[tokio::test]
    async fn test_execute_orders_rules_by_priority_descending() {
        let low = Rule {
            priority: 1,
            actions: vec![tag("order", "low")],
            ..test_rule("low-pri")
        };
        let high = Rule {
            priority: 100,
            actions: vec![tag("order", "high")],
            ..test_rule("high-pri")
        };
        let engine = RuleEngine::new(vec![low, high], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 2);
        assert_eq!(ctx.trace[0].rule_id, "high-pri");
        assert_eq!(ctx.trace[1].rule_id, "low-pri");
        assert_eq!(
            flow.tags,
            vec!["order:high".to_string(), "order:low".to_string()]
        );
        assert_eq!(
            flow.matched_rules,
            vec!["high-pri".to_string(), "low-pri".to_string()]
        );
    }

    #[tokio::test]
    async fn test_stop_termination_skips_lower_priority_rules() {
        let high = Rule {
            priority: 100,
            termination: RuleTermination::Stop,
            actions: vec![tag("order", "high")],
            ..test_rule("high-pri")
        };
        let low = Rule {
            priority: 1,
            actions: vec![tag("order", "low")],
            ..test_rule("low-pri")
        };
        let engine = RuleEngine::new(vec![low, high], vec![], None, None);
        let mut flow = create_test_flow();

        let ctx = engine.execute(RuleStage::RequestHeaders, &mut flow).await;
        assert_eq!(ctx.trace.len(), 1);
        assert_eq!(ctx.trace[0].rule_id, "high-pri");
        assert_eq!(flow.tags, vec!["order:high".to_string()]);
        assert_eq!(flow.matched_rules, vec!["high-pri".to_string()]);
        assert!(!ctx.is_terminated());
    }
}