Skip to main content

relay_core_runtime/
rule.rs

1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use relay_core_api::flow::{Flow, Layer};
5use relay_core_lib::rule::{Action, BodySource, Filter, Rule, RuleStage, RuleTermination, StringMatcher};
6
/// Legacy intercept (breakpoint) rule, serialisable with serde.
///
/// It can be evaluated directly against a flow with [`InterceptRule::matches`] or converted
/// into rule-engine rules with [`InterceptRule::to_rules`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterceptRule {
    /// Rule identifier; also the base id of the rules produced by `to_rules`.
    pub id: String,
    /// Inactive rules never match and convert to no rules.
    pub active: bool,
    /// URL pattern, used as a regular expression or, when it does not compile as one,
    /// as a plain substring.
    pub url_pattern: String,
    /// Optional request method to require; `matches` compares it ignoring ASCII case.
    pub method: Option<String>,
    /// One of "request", "response", "both", "ws_message".
    /// `to_rules` expands "both" into request-headers and response-headers stages.
    pub phase: String,
}
15
/// Input to [`build_intercept_rules`], describing an intercept (inspect) rule.
#[derive(Debug, Clone)]
pub struct InterceptRuleConfig {
    /// Id of the generated rule; suffixed with `-0` / `-1` when `phase` is "both".
    pub rule_id: String,
    /// When `false`, `build_intercept_rules` returns no rules.
    pub active: bool,
    /// URL pattern: a regex, or a substring match when it is not a valid regex.
    pub url_pattern: String,
    /// Optional method, turned into an exact-match method filter.
    pub method: Option<String>,
    /// "request", "response", "both" or "ws_message"; anything else yields no rules.
    pub phase: String,
    /// Name copied onto every generated rule.
    pub name: String,
    /// Priority copied onto every generated rule.
    pub priority: i32,
    /// Termination behaviour copied onto every generated rule.
    pub termination: RuleTermination,
}
27
/// Input to [`build_mock_response_rule`], describing a canned response for matching URLs.
#[derive(Debug, Clone)]
pub struct MockResponseRuleConfig {
    /// Id of the generated rule.
    pub rule_id: String,
    /// URL pattern: a regex, or a substring match when it is not a valid regex.
    pub url_pattern: String,
    /// Name of the generated rule.
    pub name: String,
    /// Status code of the mocked response.
    pub status: u16,
    /// Value placed in the mocked response's `Content-Type` header.
    pub content_type: String,
    /// Text body of the mocked response; an empty string means no body.
    pub body: String,
}
37
38impl InterceptRule {
39    pub fn matches(&self, flow: &Flow, phase: &str) -> bool {
40        if !self.active {
41            return false;
42        }
43
44        // Phase matching logic
45        if self.phase == "both" {
46            if phase == "ws_message" {
47                return false;
48            }
49        } else if self.phase != phase {
50            return false;
51        }
52
53        let url = match &flow.layer {
54            Layer::Http(http) => Some(http.request.url.to_string()),
55            Layer::WebSocket(ws) => Some(ws.handshake_request.url.to_string()),
56            _ => None,
57        };
58
59        let method = match &flow.layer {
60            Layer::Http(http) => Some(http.request.method.to_string()),
61            Layer::WebSocket(ws) => Some(ws.handshake_request.method.to_string()),
62            _ => None,
63        };
64
65        let url_str = url.as_deref().unwrap_or("");
66        let method_str = method.as_deref().unwrap_or("");
67
68        if let Some(m) = &self.method
69            && !m.eq_ignore_ascii_case(method_str) {
70                return false;
71            }
72
73        if let Ok(re) = Regex::new(&self.url_pattern) {
74            if re.is_match(url_str) {
75                return true;
76            }
77        } else if url_str.contains(&self.url_pattern) {
78             return true;
79        }
80
81        false
82    }
83
84    pub fn to_rules(&self) -> Vec<Rule> {
85        build_intercept_rules(InterceptRuleConfig {
86            rule_id: self.id.clone(),
87            active: self.active,
88            url_pattern: self.url_pattern.clone(),
89            method: self.method.clone(),
90            phase: self.phase.clone(),
91            name: format!("Legacy Rule {}", self.id),
92            priority: 0,
93            termination: RuleTermination::Continue,
94        })
95    }
96}
97
98pub fn build_intercept_rules(config: InterceptRuleConfig) -> Vec<Rule> {
99    if !config.active {
100        return vec![];
101    }
102
103    let InterceptRuleConfig {
104        rule_id,
105        url_pattern,
106        method,
107        phase,
108        name,
109        priority,
110        termination,
111        ..
112    } = config;
113
114    let stages = match phase.as_str() {
115        "request" => vec![RuleStage::RequestHeaders],
116        "response" => vec![RuleStage::ResponseHeaders],
117        "ws_message" => vec![RuleStage::WebSocketMessage],
118        "both" => vec![RuleStage::RequestHeaders, RuleStage::ResponseHeaders],
119        _ => return vec![],
120    };
121
122    let url_filter = Filter::Url(build_url_matcher(url_pattern));
123    let filter = if let Some(method) = method {
124        Filter::And(vec![url_filter, Filter::Method(StringMatcher::Exact(method))])
125    } else {
126        url_filter
127    };
128
129    let stages_len = stages.len();
130    stages
131        .into_iter()
132        .enumerate()
133        .map(|(i, stage)| Rule {
134            id: if stages_len > 1 {
135                format!("{}-{}", rule_id, i)
136            } else {
137                rule_id.clone()
138            },
139            name: name.clone(),
140            active: true,
141            stage,
142            priority,
143            termination: termination.clone(),
144            filter: filter.clone(),
145            actions: vec![Action::Inspect],
146            constraints: None,
147        })
148        .collect()
149}
150
151pub fn build_mock_response_rule(config: MockResponseRuleConfig) -> Rule {
152    let mut headers = HashMap::new();
153    headers.insert("Content-Type".to_string(), config.content_type);
154
155    Rule {
156        id: config.rule_id,
157        name: config.name,
158        active: true,
159        stage: RuleStage::RequestHeaders,
160        priority: 200,
161        termination: RuleTermination::Stop,
162        filter: Filter::Url(build_url_matcher(config.url_pattern)),
163        actions: vec![Action::MockResponse {
164            status: config.status,
165            headers,
166            body: if config.body.is_empty() {
167                None
168            } else {
169                Some(BodySource::Text(config.body))
170            },
171        }],
172        constraints: None,
173    }
174}
175
176fn build_url_matcher(url_pattern: String) -> StringMatcher {
177    if Regex::new(&url_pattern).is_ok() {
178        StringMatcher::Regex(url_pattern)
179    } else {
180        StringMatcher::Contains(url_pattern)
181    }
182}
183
184#[cfg(test)]
185mod tests {
186    use super::{
187        InterceptRule, InterceptRuleConfig, MockResponseRuleConfig, build_intercept_rules,
188        build_mock_response_rule,
189    };
190    use chrono::Utc;
191    use relay_core_api::flow::{
192        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
193        TransportProtocol, WebSocketLayer,
194    };
195    use relay_core_lib::rule::{Action, BodySource, Filter, RuleStage, RuleTermination, StringMatcher};
196    use std::collections::HashMap;
197    use url::Url;
198    use uuid::Uuid;
199
200    fn sample_http_flow(url: &str) -> Flow {
201        Flow {
202            id: Uuid::new_v4(),
203            start_time: Utc::now(),
204            end_time: None,
205            network: NetworkInfo {
206                client_ip: "127.0.0.1".to_string(),
207                client_port: 12345,
208                server_ip: "1.1.1.1".to_string(),
209                server_port: 80,
210                protocol: TransportProtocol::TCP,
211                tls: false,
212                tls_version: None,
213                sni: None,
214            },
215            layer: Layer::Http(HttpLayer {
216                request: HttpRequest {
217                    method: "GET".to_string(),
218                    url: Url::parse(url).expect("url"),
219                    version: "HTTP/1.1".to_string(),
220                    headers: vec![],
221                    body: None,
222                    cookies: vec![],
223                    query: vec![],
224                },
225                response: None,
226                error: None,
227            }),
228            tags: vec![],
229            meta: HashMap::new(),
230        }
231    }
232
233    fn sample_ws_flow(url: &str) -> Flow {
234        Flow {
235            id: Uuid::new_v4(),
236            start_time: Utc::now(),
237            end_time: None,
238            network: NetworkInfo {
239                client_ip: "127.0.0.1".to_string(),
240                client_port: 12345,
241                server_ip: "1.1.1.1".to_string(),
242                server_port: 80,
243                protocol: TransportProtocol::TCP,
244                tls: false,
245                tls_version: None,
246                sni: None,
247            },
248            layer: Layer::WebSocket(WebSocketLayer {
249                handshake_request: HttpRequest {
250                    method: "GET".to_string(),
251                    url: Url::parse(url).expect("url"),
252                    version: "HTTP/1.1".to_string(),
253                    headers: vec![],
254                    body: None,
255                    cookies: vec![],
256                    query: vec![],
257                },
258                handshake_response: HttpResponse {
259                    status: 101,
260                    status_text: "Switching Protocols".to_string(),
261                    version: "HTTP/1.1".to_string(),
262                    headers: vec![],
263                    body: None,
264                    timing: ResponseTiming {
265                        time_to_first_byte: None,
266                        time_to_last_byte: None,
267                    },
268                    cookies: vec![],
269                },
270                messages: vec![],
271                closed: false,
272            }),
273            tags: vec![],
274            meta: HashMap::new(),
275        }
276    }
277
278    #[test]
279    fn test_to_rules_inactive_returns_empty() {
280        let r = InterceptRule {
281            id: "legacy-inactive".to_string(),
282            active: false,
283            url_pattern: "example.com".to_string(),
284            method: None,
285            phase: "request".to_string(),
286        };
287        assert!(r.to_rules().is_empty());
288    }
289
290    #[test]
291    fn test_to_rules_invalid_phase_returns_empty() {
292        let r = InterceptRule {
293            id: "legacy-invalid".to_string(),
294            active: true,
295            url_pattern: "example.com".to_string(),
296            method: None,
297            phase: "not-a-phase".to_string(),
298        };
299        assert!(r.to_rules().is_empty());
300    }
301
302    #[test]
303    fn test_to_rules_both_phase_generates_two_stages_with_suffix_ids() {
304        let r = InterceptRule {
305            id: "legacy-both".to_string(),
306            active: true,
307            url_pattern: "example.com".to_string(),
308            method: Some("POST".to_string()),
309            phase: "both".to_string(),
310        };
311        let rules = r.to_rules();
312        assert_eq!(rules.len(), 2);
313        assert_eq!(rules[0].id, "legacy-both-0");
314        assert_eq!(rules[1].id, "legacy-both-1");
315        assert_eq!(rules[0].stage, RuleStage::RequestHeaders);
316        assert_eq!(rules[1].stage, RuleStage::ResponseHeaders);
317
318        for rule in rules {
319            match rule.filter {
320                Filter::And(filters) => assert_eq!(filters.len(), 2),
321                other => panic!("expected And filter for method+url, got {:?}", other),
322            }
323        }
324    }
325
326    #[test]
327    fn test_matches_both_phase_excludes_ws_message_phase() {
328        let r = InterceptRule {
329            id: "legacy-both-match".to_string(),
330            active: true,
331            url_pattern: "example.com".to_string(),
332            method: None,
333            phase: "both".to_string(),
334        };
335        let http = sample_http_flow("http://example.com/path");
336        let ws = sample_ws_flow("ws://example.com/socket");
337        assert!(r.matches(&http, "request"));
338        assert!(r.matches(&http, "response"));
339        assert!(!r.matches(&ws, "ws_message"));
340    }
341
342    #[test]
343    fn test_matches_invalid_regex_falls_back_to_contains() {
344        let r = InterceptRule {
345            id: "legacy-invalid-regex".to_string(),
346            active: true,
347            url_pattern: "[".to_string(),
348            method: None,
349            phase: "request".to_string(),
350        };
351        let flow_hit = sample_http_flow("http://example.com/x[1]");
352        let flow_miss = sample_http_flow("http://example.com/x");
353        assert!(r.matches(&flow_hit, "request"));
354        assert!(!r.matches(&flow_miss, "request"));
355    }
356
357    #[test]
358    fn test_build_intercept_rules_preserves_stop_and_priority() {
359        let rules = build_intercept_rules(InterceptRuleConfig {
360            rule_id: "probe-breakpoint".to_string(),
361            active: true,
362            url_pattern: "example.com".to_string(),
363            method: None,
364            phase: "both".to_string(),
365            name: "probe-intercept:example.com".to_string(),
366            priority: 100,
367            termination: RuleTermination::Stop,
368        });
369
370        assert_eq!(rules.len(), 2);
371        assert_eq!(rules[0].id, "probe-breakpoint-0");
372        assert_eq!(rules[1].id, "probe-breakpoint-1");
373        assert_eq!(rules[0].priority, 100);
374        assert!(matches!(rules[0].termination, RuleTermination::Stop));
375        assert_eq!(rules[0].name, "probe-intercept:example.com");
376    }
377
378    #[test]
379    fn test_build_intercept_rules_invalid_regex_falls_back_to_contains() {
380        let rules = build_intercept_rules(InterceptRuleConfig {
381            rule_id: "api-breakpoint".to_string(),
382            active: true,
383            url_pattern: "[".to_string(),
384            method: Some("POST".to_string()),
385            phase: "request".to_string(),
386            name: "api-intercept:[".to_string(),
387            priority: 100,
388            termination: RuleTermination::Stop,
389        });
390
391        assert_eq!(rules.len(), 1);
392        match &rules[0].filter {
393            Filter::And(filters) => {
394                assert!(matches!(filters[0], Filter::Url(StringMatcher::Contains(_))));
395                assert!(matches!(filters[1], Filter::Method(StringMatcher::Exact(_))));
396            }
397            other => panic!("expected And filter for method+url, got {:?}", other),
398        }
399    }
400
401    #[test]
402    fn test_build_mock_response_rule_sets_mock_action_and_headers() {
403        let rule = build_mock_response_rule(MockResponseRuleConfig {
404            rule_id: "mock-rule".to_string(),
405            url_pattern: "example.com".to_string(),
406            name: "mock".to_string(),
407            status: 201,
408            content_type: "application/json".to_string(),
409            body: "{\"ok\":true}".to_string(),
410        });
411
412        assert_eq!(rule.id, "mock-rule");
413        assert_eq!(rule.stage, RuleStage::RequestHeaders);
414        assert!(matches!(rule.termination, RuleTermination::Stop));
415        match &rule.filter {
416            Filter::Url(StringMatcher::Regex(pattern)) => assert_eq!(pattern, "example.com"),
417            other => panic!("expected regex url filter, got {:?}", other),
418        }
419        match &rule.actions[0] {
420            Action::MockResponse { status, headers, body } => {
421                assert_eq!(*status, 201);
422                assert_eq!(headers.get("Content-Type").map(String::as_str), Some("application/json"));
423                match body {
424                    Some(BodySource::Text(text)) => assert_eq!(text, "{\"ok\":true}"),
425                    other => panic!("expected text body, got {:?}", other),
426                }
427            }
428            other => panic!("expected mock response action, got {:?}", other),
429        }
430    }
431}