Skip to main content

relay_core_runtime/
rule.rs

1use regex::Regex;
2use relay_core_api::flow::{Flow, Layer};
3use relay_core_lib::rule::{
4    Action, BodySource, Filter, Rule, RuleStage, RuleTermination, StringMatcher,
5};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Legacy intercept rule, serializable with serde.
///
/// It is either evaluated directly against a flow via [`InterceptRule::matches`] or
/// converted into engine [`Rule`]s (each carrying [`Action::Inspect`]) via
/// [`InterceptRule::to_rules`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterceptRule {
    /// Rule identifier; also the base id of the rules produced by `to_rules`.
    pub id: String,
    /// Inactive rules never match and convert to no engine rules.
    pub active: bool,
    /// URL pattern: used as a regex when it compiles, otherwise as a plain substring.
    pub url_pattern: String,
    /// Optional method restriction; `None` accepts any method.
    pub method: Option<String>,
    /// Phase the rule applies to (stringly typed; unknown values convert to no rules).
    pub phase: String, // "request", "response", "both", "ws_message"
}
17
/// Input for [`build_intercept_rules`]: everything needed to emit inspect rules.
#[derive(Debug, Clone)]
pub struct InterceptRuleConfig {
    /// Id of the generated rule; suffixed with `-0`, `-1`, ... when the phase expands
    /// into several stages.
    pub rule_id: String,
    /// When `false`, no rules are generated at all.
    pub active: bool,
    /// URL pattern: regex when it compiles, otherwise substring match.
    pub url_pattern: String,
    /// Optional method restriction combined with the URL filter.
    pub method: Option<String>,
    /// One of "request", "response", "both", "ws_message"; anything else yields no rules.
    pub phase: String,
    /// Display name copied onto every generated rule.
    pub name: String,
    /// Priority copied onto every generated rule.
    pub priority: i32,
    /// Termination behaviour copied onto every generated rule.
    pub termination: RuleTermination,
}
29
/// Input for [`build_mock_response_rule`]: describes a canned response for matching URLs.
#[derive(Debug, Clone)]
pub struct MockResponseRuleConfig {
    /// Id of the generated rule.
    pub rule_id: String,
    /// URL pattern: regex when it compiles, otherwise substring match.
    pub url_pattern: String,
    /// Display name of the generated rule.
    pub name: String,
    /// Status code of the mocked response.
    pub status: u16,
    /// Value placed in the mocked response's `Content-Type` header.
    pub content_type: String,
    /// Response body text; an empty string means the mock carries no body.
    pub body: String,
}
39
40impl InterceptRule {
41    pub fn matches(&self, flow: &Flow, phase: &str) -> bool {
42        if !self.active {
43            return false;
44        }
45
46        // Phase matching logic
47        if self.phase == "both" {
48            if phase == "ws_message" {
49                return false;
50            }
51        } else if self.phase != phase {
52            return false;
53        }
54
55        let url = match &flow.layer {
56            Layer::Http(http) => Some(http.request.url.to_string()),
57            Layer::WebSocket(ws) => Some(ws.handshake_request.url.to_string()),
58            _ => None,
59        };
60
61        let method = match &flow.layer {
62            Layer::Http(http) => Some(http.request.method.to_string()),
63            Layer::WebSocket(ws) => Some(ws.handshake_request.method.to_string()),
64            _ => None,
65        };
66
67        let url_str = url.as_deref().unwrap_or("");
68        let method_str = method.as_deref().unwrap_or("");
69
70        if let Some(m) = &self.method
71            && !m.eq_ignore_ascii_case(method_str)
72        {
73            return false;
74        }
75
76        if let Ok(re) = Regex::new(&self.url_pattern) {
77            if re.is_match(url_str) {
78                return true;
79            }
80        } else if url_str.contains(&self.url_pattern) {
81            return true;
82        }
83
84        false
85    }
86
87    pub fn to_rules(&self) -> Vec<Rule> {
88        build_intercept_rules(InterceptRuleConfig {
89            rule_id: self.id.clone(),
90            active: self.active,
91            url_pattern: self.url_pattern.clone(),
92            method: self.method.clone(),
93            phase: self.phase.clone(),
94            name: format!("Legacy Rule {}", self.id),
95            priority: 0,
96            termination: RuleTermination::Continue,
97        })
98    }
99}
100
101pub fn build_intercept_rules(config: InterceptRuleConfig) -> Vec<Rule> {
102    if !config.active {
103        return vec![];
104    }
105
106    let InterceptRuleConfig {
107        rule_id,
108        url_pattern,
109        method,
110        phase,
111        name,
112        priority,
113        termination,
114        ..
115    } = config;
116
117    let stages = match phase.as_str() {
118        "request" => vec![RuleStage::RequestHeaders],
119        "response" => vec![RuleStage::ResponseHeaders],
120        "ws_message" => vec![RuleStage::WebSocketMessage],
121        "both" => vec![RuleStage::RequestHeaders, RuleStage::ResponseHeaders],
122        _ => return vec![],
123    };
124
125    let url_filter = Filter::Url(build_url_matcher(url_pattern));
126    let filter = if let Some(method) = method {
127        Filter::And(vec![
128            url_filter,
129            Filter::Method(StringMatcher::Exact(method)),
130        ])
131    } else {
132        url_filter
133    };
134
135    let stages_len = stages.len();
136    stages
137        .into_iter()
138        .enumerate()
139        .map(|(i, stage)| Rule {
140            id: if stages_len > 1 {
141                format!("{}-{}", rule_id, i)
142            } else {
143                rule_id.clone()
144            },
145            name: name.clone(),
146            active: true,
147            stage,
148            priority,
149            termination: termination.clone(),
150            filter: filter.clone(),
151            actions: vec![Action::Inspect],
152            constraints: None,
153        })
154        .collect()
155}
156
157pub fn build_mock_response_rule(config: MockResponseRuleConfig) -> Rule {
158    let mut headers = HashMap::new();
159    headers.insert("Content-Type".to_string(), config.content_type);
160
161    Rule {
162        id: config.rule_id,
163        name: config.name,
164        active: true,
165        stage: RuleStage::RequestHeaders,
166        priority: 200,
167        termination: RuleTermination::Stop,
168        filter: Filter::Url(build_url_matcher(config.url_pattern)),
169        actions: vec![Action::MockResponse {
170            status: config.status,
171            headers,
172            body: if config.body.is_empty() {
173                None
174            } else {
175                Some(BodySource::Text(config.body))
176            },
177        }],
178        constraints: None,
179    }
180}
181
182fn build_url_matcher(url_pattern: String) -> StringMatcher {
183    if Regex::new(&url_pattern).is_ok() {
184        StringMatcher::Regex(url_pattern)
185    } else {
186        StringMatcher::Contains(url_pattern)
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::{
        InterceptRule, InterceptRuleConfig, MockResponseRuleConfig, build_intercept_rules,
        build_mock_response_rule, build_url_matcher,
    };
    use chrono::Utc;
    use relay_core_api::flow::{
        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol, WebSocketLayer,
    };
    use relay_core_lib::rule::{
        Action, BodySource, Filter, RuleStage, RuleTermination, StringMatcher,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Plain-TCP connection info shared by every sample flow.
    fn sample_network() -> NetworkInfo {
        NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        }
    }

    /// A bodiless HTTP/1.1 `GET` request for `url`; panics if `url` does not parse.
    fn sample_request(url: &str) -> HttpRequest {
        HttpRequest {
            method: "GET".to_string(),
            url: Url::parse(url).expect("url"),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            cookies: vec![],
            query: vec![],
        }
    }

    /// Wraps `layer` in a fresh flow with empty tags, metadata and rule state.
    fn sample_flow(layer: Layer) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: sample_network(),
            layer,
            tags: vec![],
            meta: HashMap::new(),
            resilience_trace: None,
            rule_variables: HashMap::new(),
            matched_rules: vec![],
        }
    }

    fn sample_http_flow(url: &str) -> Flow {
        sample_flow(Layer::Http(HttpLayer {
            request: sample_request(url),
            response: None,
            error: None,
        }))
    }

    fn sample_ws_flow(url: &str) -> Flow {
        sample_flow(Layer::WebSocket(WebSocketLayer {
            handshake_request: sample_request(url),
            handshake_response: HttpResponse {
                status: 101,
                status_text: "Switching Protocols".to_string(),
                version: "HTTP/1.1".to_string(),
                headers: vec![],
                body: None,
                timing: ResponseTiming {
                    time_to_first_byte: None,
                    time_to_last_byte: None,
                    connect_time_ms: None,
                    ssl_time_ms: None,
                },
                cookies: vec![],
            },
            messages: vec![],
            closed: false,
        }))
    }

    #[test]
    fn test_to_rules_inactive_returns_empty() {
        let r = InterceptRule {
            id: "legacy-inactive".to_string(),
            active: false,
            url_pattern: "example.com".to_string(),
            method: None,
            phase: "request".to_string(),
        };
        assert!(r.to_rules().is_empty());
    }

    #[test]
    fn test_to_rules_invalid_phase_returns_empty() {
        let r = InterceptRule {
            id: "legacy-invalid".to_string(),
            active: true,
            url_pattern: "example.com".to_string(),
            method: None,
            phase: "not-a-phase".to_string(),
        };
        assert!(r.to_rules().is_empty());
    }

    #[test]
    fn test_to_rules_both_phase_generates_two_stages_with_suffix_ids() {
        let r = InterceptRule {
            id: "legacy-both".to_string(),
            active: true,
            url_pattern: "example.com".to_string(),
            method: Some("POST".to_string()),
            phase: "both".to_string(),
        };
        let rules = r.to_rules();
        assert_eq!(rules.len(), 2);
        assert_eq!(rules[0].id, "legacy-both-0");
        assert_eq!(rules[1].id, "legacy-both-1");
        assert_eq!(rules[0].stage, RuleStage::RequestHeaders);
        assert_eq!(rules[1].stage, RuleStage::ResponseHeaders);

        for rule in rules {
            match rule.filter {
                Filter::And(filters) => assert_eq!(filters.len(), 2),
                other => panic!("expected And filter for method+url, got {:?}", other),
            }
        }
    }

    #[test]
    fn test_matches_both_phase_excludes_ws_message_phase() {
        let r = InterceptRule {
            id: "legacy-both-match".to_string(),
            active: true,
            url_pattern: "example.com".to_string(),
            method: None,
            phase: "both".to_string(),
        };
        let http = sample_http_flow("http://example.com/path");
        let ws = sample_ws_flow("ws://example.com/socket");
        assert!(r.matches(&http, "request"));
        assert!(r.matches(&http, "response"));
        assert!(!r.matches(&ws, "ws_message"));
    }

    #[test]
    fn test_matches_invalid_regex_falls_back_to_contains() {
        let r = InterceptRule {
            id: "legacy-invalid-regex".to_string(),
            active: true,
            url_pattern: "[".to_string(),
            method: None,
            phase: "request".to_string(),
        };
        let flow_hit = sample_http_flow("http://example.com/x[1]");
        let flow_miss = sample_http_flow("http://example.com/x");
        assert!(r.matches(&flow_hit, "request"));
        assert!(!r.matches(&flow_miss, "request"));
    }

    #[test]
    fn test_matches_method_is_case_insensitive_and_inactive_never_matches() {
        let mut r = InterceptRule {
            id: "legacy-method".to_string(),
            active: true,
            url_pattern: "example.com".to_string(),
            method: Some("get".to_string()),
            phase: "request".to_string(),
        };
        let flow = sample_http_flow("http://example.com/path");
        assert!(r.matches(&flow, "request"));
        assert!(!r.matches(&flow, "response"));

        r.method = Some("POST".to_string());
        assert!(!r.matches(&flow, "request"));

        r.method = None;
        r.active = false;
        assert!(!r.matches(&flow, "request"));
    }

    #[test]
    fn test_build_url_matcher_prefers_regex_when_valid() {
        match build_url_matcher("^https://".to_string()) {
            StringMatcher::Regex(pattern) => assert_eq!(pattern, "^https://"),
            other => panic!("expected regex matcher, got {:?}", other),
        }
        match build_url_matcher("(".to_string()) {
            StringMatcher::Contains(pattern) => assert_eq!(pattern, "("),
            other => panic!("expected contains matcher, got {:?}", other),
        }
    }

    #[test]
    fn test_build_intercept_rules_preserves_stop_and_priority() {
        let rules = build_intercept_rules(InterceptRuleConfig {
            rule_id: "probe-breakpoint".to_string(),
            active: true,
            url_pattern: "example.com".to_string(),
            method: None,
            phase: "both".to_string(),
            name: "probe-intercept:example.com".to_string(),
            priority: 100,
            termination: RuleTermination::Stop,
        });

        assert_eq!(rules.len(), 2);
        assert_eq!(rules[0].id, "probe-breakpoint-0");
        assert_eq!(rules[1].id, "probe-breakpoint-1");
        assert_eq!(rules[0].priority, 100);
        assert!(matches!(rules[0].termination, RuleTermination::Stop));
        assert_eq!(rules[0].name, "probe-intercept:example.com");
    }

    #[test]
    fn test_build_intercept_rules_invalid_regex_falls_back_to_contains() {
        let rules = build_intercept_rules(InterceptRuleConfig {
            rule_id: "api-breakpoint".to_string(),
            active: true,
            url_pattern: "[".to_string(),
            method: Some("POST".to_string()),
            phase: "request".to_string(),
            name: "api-intercept:[".to_string(),
            priority: 100,
            termination: RuleTermination::Stop,
        });

        assert_eq!(rules.len(), 1);
        match &rules[0].filter {
            Filter::And(filters) => {
                assert!(matches!(
                    filters[0],
                    Filter::Url(StringMatcher::Contains(_))
                ));
                assert!(matches!(
                    filters[1],
                    Filter::Method(StringMatcher::Exact(_))
                ));
            }
            other => panic!("expected And filter for method+url, got {:?}", other),
        }
    }

    #[test]
    fn test_build_mock_response_rule_sets_mock_action_and_headers() {
        let rule = build_mock_response_rule(MockResponseRuleConfig {
            rule_id: "mock-rule".to_string(),
            url_pattern: "example.com".to_string(),
            name: "mock".to_string(),
            status: 201,
            content_type: "application/json".to_string(),
            body: "{\"ok\":true}".to_string(),
        });

        assert_eq!(rule.id, "mock-rule");
        assert_eq!(rule.stage, RuleStage::RequestHeaders);
        assert!(matches!(rule.termination, RuleTermination::Stop));
        match &rule.filter {
            Filter::Url(StringMatcher::Regex(pattern)) => assert_eq!(pattern, "example.com"),
            other => panic!("expected regex url filter, got {:?}", other),
        }
        match &rule.actions[0] {
            Action::MockResponse {
                status,
                headers,
                body,
            } => {
                assert_eq!(*status, 201);
                assert_eq!(
                    headers.get("Content-Type").map(String::as_str),
                    Some("application/json")
                );
                match body {
                    Some(BodySource::Text(text)) => assert_eq!(text, "{\"ok\":true}"),
                    other => panic!("expected text body, got {:?}", other),
                }
            }
            other => panic!("expected mock response action, got {:?}", other),
        }
    }
}