// relay_core_lib/rule/engine/validator.rs

use crate::rule::engine::compiled::CompiledFilter;
use crate::rule::model::{Action, RuleStage};

4pub fn validate_filter_stage(filter: &CompiledFilter, stage: &RuleStage) -> bool {
5 match filter {
6 CompiledFilter::All => true,
7 CompiledFilter::SrcIp(_) | CompiledFilter::DstPort(_) | CompiledFilter::Protocol(_) | CompiledFilter::TransparentMode(_) => {
8 true }
10 CompiledFilter::Url(_) | CompiledFilter::Host(_) | CompiledFilter::Path(_) | CompiledFilter::Method(_) | CompiledFilter::RequestHeader { .. } => {
11 !matches!(stage, RuleStage::Connect)
12 }
13 CompiledFilter::ResponseHeader { .. } | CompiledFilter::StatusCode(_) => {
14 !matches!(stage, RuleStage::Connect | RuleStage::RequestHeaders | RuleStage::RequestBody)
15 }
16 CompiledFilter::ResponseBody(_) => {
17 matches!(stage, RuleStage::ResponseBody)
18 }
19 CompiledFilter::WebSocketMessage(_) => {
20 matches!(stage, RuleStage::WebSocketMessage)
21 }
22 CompiledFilter::And(filters) | CompiledFilter::Or(filters) => {
23 filters.iter().all(|f| validate_filter_stage(f, stage))
24 }
25 CompiledFilter::Not(f) => validate_filter_stage(f, stage),
26 CompiledFilter::Invalid => false,
27 }
28}
29
30pub fn validate_action_stage(action: &Action, stage: &RuleStage) -> bool {
31 match stage {
32 RuleStage::Connect => matches!(action, Action::Drop | Action::Abort | Action::Delay { .. } | Action::Throttle { .. } |
33 Action::Tag { .. } | Action::SetVariable { .. } | Action::Inspect | Action::RateLimit { .. } |
34 Action::RedirectIp { .. } | Action::SetTtl { .. } | Action::ForwardPort { .. }),
35 RuleStage::RequestHeaders => matches!(action, Action::Drop | Action::Abort | Action::Delay { .. } | Action::Throttle { .. } |
36 Action::Tag { .. } | Action::SetVariable { .. } | Action::Inspect | Action::RateLimit { .. } |
37 Action::MockResponse { .. } | Action::MapLocal { .. } | Action::MapRemote { .. } |
38 Action::Redirect { .. } |
39 Action::AddRequestHeader { .. } | Action::UpdateRequestHeader { .. } | Action::DeleteRequestHeader { .. } |
40 Action::SetRequestMethod { .. } | Action::SetRequestUrl { .. } |
41 Action::SetRequestBody { .. }),
42 RuleStage::RequestBody => matches!(action, Action::Drop | Action::Abort | Action::Delay { .. } | Action::Throttle { .. } |
43 Action::Tag { .. } | Action::SetVariable { .. } | Action::Inspect | Action::RateLimit { .. } |
44 Action::SetRequestBody { .. } | Action::TransformRequestBody { .. } |
45 Action::MockResponse { .. }),
47 RuleStage::ResponseHeaders => matches!(action, Action::Drop | Action::Abort | Action::Delay { .. } | Action::Throttle { .. } |
48 Action::Tag { .. } | Action::SetVariable { .. } | Action::Inspect | Action::RateLimit { .. } |
49 Action::AddResponseHeader { .. } | Action::UpdateResponseHeader { .. } | Action::DeleteResponseHeader { .. } |
50 Action::SetResponseStatus { .. } | Action::SetResponseBody { .. }),
51 RuleStage::ResponseBody => matches!(action, Action::Drop | Action::Abort | Action::Delay { .. } | Action::Throttle { .. } |
52 Action::Tag { .. } | Action::SetVariable { .. } | Action::Inspect | Action::RateLimit { .. } |
53 Action::SetResponseBody { .. } | Action::TransformResponseBody { .. }),
54 RuleStage::WebSocketMessage => matches!(action, Action::Drop | Action::Abort | Action::Delay { .. } | Action::Throttle { .. } |
55 Action::Tag { .. } | Action::SetVariable { .. } | Action::Inspect | Action::RateLimit { .. } |
56 Action::MockWebSocketMessage { .. } | Action::DropWebSocketMessage),
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::{validate_action_stage, validate_filter_stage};
    use crate::rule::engine::compiled::{CompiledFilter, CompiledStringMatcher};
    use crate::rule::model::{Action, BodySource, RuleStage, WebSocketDirection};

    /// Every `RuleStage` variant, for tests that sweep across all stages.
    fn all_stages() -> [RuleStage; 6] {
        [
            RuleStage::Connect,
            RuleStage::RequestHeaders,
            RuleStage::RequestBody,
            RuleStage::ResponseHeaders,
            RuleStage::ResponseBody,
            RuleStage::WebSocketMessage,
        ]
    }

    #[test]
    fn test_validate_filter_stage_response_filters() {
        let f = CompiledFilter::ResponseHeader {
            name: "Content-Type".to_string(),
            value: Some(CompiledStringMatcher::Contains("json".to_string())),
        };
        assert!(!validate_filter_stage(&f, &RuleStage::Connect));
        assert!(!validate_filter_stage(&f, &RuleStage::RequestHeaders));
        assert!(!validate_filter_stage(&f, &RuleStage::RequestBody));
        assert!(validate_filter_stage(&f, &RuleStage::ResponseHeaders));
        assert!(validate_filter_stage(&f, &RuleStage::ResponseBody));
    }

    #[test]
    fn test_validate_filter_stage_response_body_only() {
        let f = CompiledFilter::ResponseBody(CompiledStringMatcher::Contains("err".to_string()));
        assert!(!validate_filter_stage(&f, &RuleStage::ResponseHeaders));
        assert!(validate_filter_stage(&f, &RuleStage::ResponseBody));
        assert!(!validate_filter_stage(&f, &RuleStage::WebSocketMessage));
    }

    #[test]
    fn test_validate_filter_stage_websocket_only() {
        let f = CompiledFilter::WebSocketMessage(CompiledStringMatcher::Contains("ping".to_string()));
        assert!(validate_filter_stage(&f, &RuleStage::WebSocketMessage));
        assert!(!validate_filter_stage(&f, &RuleStage::RequestHeaders));
        assert!(!validate_filter_stage(&f, &RuleStage::ResponseBody));
    }

    #[test]
    fn test_validate_filter_stage_all_and_invalid_are_stage_independent() {
        let stages = all_stages();
        for stage in &stages {
            assert!(validate_filter_stage(&CompiledFilter::All, stage));
            assert!(!validate_filter_stage(&CompiledFilter::Invalid, stage));
        }
    }

    #[test]
    fn test_validate_filter_stage_composite_requires_all_members_valid() {
        let valid_ws = CompiledFilter::WebSocketMessage(CompiledStringMatcher::Contains("x".to_string()));
        let invalid_in_ws = CompiledFilter::ResponseBody(CompiledStringMatcher::Contains("y".to_string()));
        let and_filter = CompiledFilter::And(vec![valid_ws.clone(), invalid_in_ws.clone()]);
        let or_filter = CompiledFilter::Or(vec![valid_ws.clone(), invalid_in_ws]);
        let not_filter = CompiledFilter::Not(Box::new(valid_ws));

        assert!(
            !validate_filter_stage(&and_filter, &RuleStage::WebSocketMessage),
            "AND should fail when one child invalid for stage"
        );
        assert!(
            !validate_filter_stage(&or_filter, &RuleStage::WebSocketMessage),
            "OR currently enforces all children stage-valid"
        );
        assert!(validate_filter_stage(&not_filter, &RuleStage::WebSocketMessage));
    }

    #[test]
    fn test_validate_action_stage_representative_matrix() {
        let set_body = Action::SetRequestBody {
            body: BodySource::Text("x".to_string()),
        };
        let set_status = Action::SetResponseStatus { status: 418 };
        let mock_ws = Action::MockWebSocketMessage {
            direction: WebSocketDirection::Incoming,
            message: "pong".to_string(),
        };

        assert!(validate_action_stage(&set_body, &RuleStage::RequestHeaders));
        assert!(validate_action_stage(&set_body, &RuleStage::RequestBody));
        assert!(!validate_action_stage(&set_body, &RuleStage::Connect));
        assert!(!validate_action_stage(&set_body, &RuleStage::ResponseHeaders));

        assert!(validate_action_stage(&set_status, &RuleStage::ResponseHeaders));
        assert!(!validate_action_stage(&set_status, &RuleStage::RequestHeaders));

        assert!(validate_action_stage(&mock_ws, &RuleStage::WebSocketMessage));
        assert!(!validate_action_stage(&mock_ws, &RuleStage::RequestHeaders));
    }

    #[test]
    fn test_validate_action_stage_common_and_websocket_only_actions() {
        let stages = all_stages();
        for stage in &stages {
            // Stage-independent actions are accepted everywhere.
            assert!(validate_action_stage(&Action::Drop, stage));
            assert!(validate_action_stage(&Action::Inspect, stage));
            // Dropping a WebSocket message is only meaningful at that stage.
            assert_eq!(
                validate_action_stage(&Action::DropWebSocketMessage, stage),
                matches!(stage, RuleStage::WebSocketMessage)
            );
        }
    }
}