Skip to main content

relay_core_runtime/
rule.rs

1use regex::Regex;
2use relay_core_api::flow::{Flow, Layer};
3use relay_core_lib::rule::{
4    Action, BodySource, Filter, Rule, RuleStage, RuleTermination, StringMatcher,
5};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct InterceptRule {
11    pub id: String,
12    pub active: bool,
13    pub url_pattern: String,
14    pub method: Option<String>,
15    pub phase: String, // "request", "response", "both", "ws_message"
16}
17
18#[derive(Debug, Clone)]
19pub struct InterceptRuleConfig {
20    pub rule_id: String,
21    pub active: bool,
22    pub url_pattern: String,
23    pub method: Option<String>,
24    pub phase: String,
25    pub name: String,
26    pub priority: i32,
27    pub termination: RuleTermination,
28}
29
/// Input for [`build_mock_response_rule`].
#[derive(Debug, Clone)]
pub struct MockResponseRuleConfig {
    /// Id of the generated rule.
    pub rule_id: String,
    /// URL pattern: a regex when it compiles, otherwise a substring.
    pub url_pattern: String,
    /// Display name of the generated rule.
    pub name: String,
    /// HTTP status code of the mocked response.
    pub status: u16,
    /// Value for the mocked response's `Content-Type` header.
    pub content_type: String,
    /// Text body of the mocked response; an empty string means "no body".
    pub body: String,
}
39
40impl InterceptRule {
41    pub fn matches(&self, flow: &Flow, phase: &str) -> bool {
42        if !self.active {
43            return false;
44        }
45
46        // Phase matching logic
47        if self.phase == "both" {
48            if phase == "ws_message" {
49                return false;
50            }
51        } else if self.phase != phase {
52            return false;
53        }
54
55        let url = match &flow.layer {
56            Layer::Http(http) => Some(http.request.url.to_string()),
57            Layer::WebSocket(ws) => Some(ws.handshake_request.url.to_string()),
58            _ => None,
59        };
60
61        let method = match &flow.layer {
62            Layer::Http(http) => Some(http.request.method.to_string()),
63            Layer::WebSocket(ws) => Some(ws.handshake_request.method.to_string()),
64            _ => None,
65        };
66
67        let url_str = url.as_deref().unwrap_or("");
68        let method_str = method.as_deref().unwrap_or("");
69
70        if let Some(m) = &self.method
71            && !m.eq_ignore_ascii_case(method_str)
72        {
73            return false;
74        }
75
76        if let Ok(re) = Regex::new(&self.url_pattern) {
77            if re.is_match(url_str) {
78                return true;
79            }
80        } else if url_str.contains(&self.url_pattern) {
81            return true;
82        }
83
84        false
85    }
86
87    pub fn to_rules(&self) -> Vec<Rule> {
88        build_intercept_rules(InterceptRuleConfig {
89            rule_id: self.id.clone(),
90            active: self.active,
91            url_pattern: self.url_pattern.clone(),
92            method: self.method.clone(),
93            phase: self.phase.clone(),
94            name: format!("Legacy Rule {}", self.id),
95            priority: 0,
96            termination: RuleTermination::Continue,
97        })
98    }
99}
100
101pub fn build_intercept_rules(config: InterceptRuleConfig) -> Vec<Rule> {
102    if !config.active {
103        return vec![];
104    }
105
106    let InterceptRuleConfig {
107        rule_id,
108        url_pattern,
109        method,
110        phase,
111        name,
112        priority,
113        termination,
114        ..
115    } = config;
116
117    let stages = match phase.as_str() {
118        "request" => vec![RuleStage::RequestHeaders],
119        "response" => vec![RuleStage::ResponseHeaders],
120        "ws_message" => vec![RuleStage::WebSocketMessage],
121        "both" => vec![RuleStage::RequestHeaders, RuleStage::ResponseHeaders],
122        _ => return vec![],
123    };
124
125    let url_filter = Filter::Url(build_url_matcher(url_pattern));
126    let filter = if let Some(method) = method {
127        Filter::And(vec![
128            url_filter,
129            Filter::Method(StringMatcher::Exact(method)),
130        ])
131    } else {
132        url_filter
133    };
134
135    let stages_len = stages.len();
136    stages
137        .into_iter()
138        .enumerate()
139        .map(|(i, stage)| Rule {
140            id: if stages_len > 1 {
141                format!("{}-{}", rule_id, i)
142            } else {
143                rule_id.clone()
144            },
145            name: name.clone(),
146            active: true,
147            stage,
148            priority,
149            termination: termination.clone(),
150            filter: filter.clone(),
151            actions: vec![Action::Inspect],
152            constraints: None,
153        })
154        .collect()
155}
156
157pub fn build_mock_response_rule(config: MockResponseRuleConfig) -> Rule {
158    let mut headers = HashMap::new();
159    headers.insert("Content-Type".to_string(), config.content_type);
160
161    Rule {
162        id: config.rule_id,
163        name: config.name,
164        active: true,
165        stage: RuleStage::RequestHeaders,
166        priority: 200,
167        termination: RuleTermination::Stop,
168        filter: Filter::Url(build_url_matcher(config.url_pattern)),
169        actions: vec![Action::MockResponse {
170            status: config.status,
171            headers,
172            body: if config.body.is_empty() {
173                None
174            } else {
175                Some(BodySource::Text(config.body))
176            },
177        }],
178        constraints: None,
179    }
180}
181
182fn build_url_matcher(url_pattern: String) -> StringMatcher {
183    if Regex::new(&url_pattern).is_ok() {
184        StringMatcher::Regex(url_pattern)
185    } else {
186        StringMatcher::Contains(url_pattern)
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::{
        InterceptRule, InterceptRuleConfig, MockResponseRuleConfig, build_intercept_rules,
        build_mock_response_rule,
    };
    use chrono::Utc;
    use relay_core_api::flow::{
        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol, WebSocketLayer,
    };
    use relay_core_lib::rule::{
        Action, BodySource, Filter, RuleStage, RuleTermination, StringMatcher,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Plain-TCP connection from a loopback client to port 80.
    fn sample_network() -> NetworkInfo {
        NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        }
    }

    /// Bare `GET <url> HTTP/1.1` request with no headers, body or cookies.
    fn sample_request(url: &str) -> HttpRequest {
        HttpRequest {
            method: "GET".to_string(),
            url: Url::parse(url).expect("url"),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            cookies: vec![],
            query: vec![],
        }
    }

    /// Wraps `layer` in a fresh, still-open flow with no tags or metadata.
    fn sample_flow(layer: Layer) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: sample_network(),
            layer,
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    fn sample_http_flow(url: &str) -> Flow {
        sample_flow(Layer::Http(HttpLayer {
            request: sample_request(url),
            response: None,
            error: None,
        }))
    }

    fn sample_ws_flow(url: &str) -> Flow {
        sample_flow(Layer::WebSocket(WebSocketLayer {
            handshake_request: sample_request(url),
            handshake_response: HttpResponse {
                status: 101,
                status_text: "Switching Protocols".to_string(),
                version: "HTTP/1.1".to_string(),
                headers: vec![],
                body: None,
                timing: ResponseTiming {
                    time_to_first_byte: None,
                    time_to_last_byte: None,
                    connect_time_ms: None,
                    ssl_time_ms: None,
                },
                cookies: vec![],
            },
            messages: vec![],
            closed: false,
        }))
    }

    /// Active legacy rule with the given matching criteria.
    fn legacy_rule(
        id: &str,
        url_pattern: &str,
        method: Option<&str>,
        phase: &str,
    ) -> InterceptRule {
        InterceptRule {
            id: id.to_string(),
            active: true,
            url_pattern: url_pattern.to_string(),
            method: method.map(str::to_string),
            phase: phase.to_string(),
        }
    }

    #[test]
    fn test_to_rules_inactive_returns_empty() {
        let r = InterceptRule {
            active: false,
            ..legacy_rule("legacy-inactive", "example.com", None, "request")
        };
        assert!(r.to_rules().is_empty());
    }

    #[test]
    fn test_to_rules_invalid_phase_returns_empty() {
        let r = legacy_rule("legacy-invalid", "example.com", None, "not-a-phase");
        assert!(r.to_rules().is_empty());
    }

    #[test]
    fn test_to_rules_single_phase_keeps_id_and_legacy_defaults() {
        let r = legacy_rule("legacy-single", "example.com", None, "request");
        let rules = r.to_rules();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].id, "legacy-single");
        assert_eq!(rules[0].name, "Legacy Rule legacy-single");
        assert_eq!(rules[0].stage, RuleStage::RequestHeaders);
        assert_eq!(rules[0].priority, 0);
        assert!(rules[0].active);
        assert!(matches!(rules[0].termination, RuleTermination::Continue));
        assert!(matches!(rules[0].actions[0], Action::Inspect));
        assert!(matches!(
            rules[0].filter,
            Filter::Url(StringMatcher::Regex(_))
        ));
    }

    #[test]
    fn test_to_rules_both_phase_generates_two_stages_with_suffix_ids() {
        let r = legacy_rule("legacy-both", "example.com", Some("POST"), "both");
        let rules = r.to_rules();
        assert_eq!(rules.len(), 2);
        assert_eq!(rules[0].id, "legacy-both-0");
        assert_eq!(rules[1].id, "legacy-both-1");
        assert_eq!(rules[0].stage, RuleStage::RequestHeaders);
        assert_eq!(rules[1].stage, RuleStage::ResponseHeaders);

        for rule in rules {
            match rule.filter {
                Filter::And(filters) => assert_eq!(filters.len(), 2),
                other => panic!("expected And filter for method+url, got {:?}", other),
            }
        }
    }

    #[test]
    fn test_matches_inactive_rule_never_matches() {
        let r = InterceptRule {
            active: false,
            ..legacy_rule("legacy-off", "example.com", None, "request")
        };
        assert!(!r.matches(&sample_http_flow("http://example.com/path"), "request"));
    }

    #[test]
    fn test_matches_phase_must_equal_requested_phase() {
        let flow = sample_http_flow("http://example.com/path");
        let req = legacy_rule("legacy-req", "example.com", None, "request");
        let resp = legacy_rule("legacy-resp", "example.com", None, "response");
        assert!(req.matches(&flow, "request"));
        assert!(!req.matches(&flow, "response"));
        assert!(resp.matches(&flow, "response"));
        assert!(!resp.matches(&flow, "request"));
    }

    #[test]
    fn test_matches_method_is_case_insensitive() {
        let flow = sample_http_flow("http://example.com/path");
        let lower = legacy_rule("legacy-get", "example.com", Some("get"), "request");
        let other = legacy_rule("legacy-post", "example.com", Some("POST"), "request");
        assert!(lower.matches(&flow, "request"));
        assert!(!other.matches(&flow, "request"));
    }

    #[test]
    fn test_matches_ws_message_phase_uses_handshake_url() {
        let r = legacy_rule("legacy-ws", "example.com", None, "ws_message");
        let ws = sample_ws_flow("ws://example.com/socket");
        assert!(r.matches(&ws, "ws_message"));
        assert!(!r.matches(&ws, "request"));
    }

    #[test]
    fn test_matches_both_phase_excludes_ws_message_phase() {
        let r = legacy_rule("legacy-both-match", "example.com", None, "both");
        let http = sample_http_flow("http://example.com/path");
        let ws = sample_ws_flow("ws://example.com/socket");
        assert!(r.matches(&http, "request"));
        assert!(r.matches(&http, "response"));
        assert!(!r.matches(&ws, "ws_message"));
    }

    #[test]
    fn test_matches_invalid_regex_falls_back_to_contains() {
        let r = legacy_rule("legacy-invalid-regex", "[", None, "request");
        let flow_hit = sample_http_flow("http://example.com/x[1]");
        let flow_miss = sample_http_flow("http://example.com/x");
        assert!(r.matches(&flow_hit, "request"));
        assert!(!r.matches(&flow_miss, "request"));
    }

    #[test]
    fn test_build_intercept_rules_preserves_stop_and_priority() {
        let rules = build_intercept_rules(InterceptRuleConfig {
            rule_id: "probe-breakpoint".to_string(),
            active: true,
            url_pattern: "example.com".to_string(),
            method: None,
            phase: "both".to_string(),
            name: "probe-intercept:example.com".to_string(),
            priority: 100,
            termination: RuleTermination::Stop,
        });

        assert_eq!(rules.len(), 2);
        assert_eq!(rules[0].id, "probe-breakpoint-0");
        assert_eq!(rules[1].id, "probe-breakpoint-1");
        assert_eq!(rules[0].priority, 100);
        assert!(matches!(rules[0].termination, RuleTermination::Stop));
        assert_eq!(rules[0].name, "probe-intercept:example.com");
    }

    #[test]
    fn test_build_intercept_rules_invalid_regex_falls_back_to_contains() {
        let rules = build_intercept_rules(InterceptRuleConfig {
            rule_id: "api-breakpoint".to_string(),
            active: true,
            url_pattern: "[".to_string(),
            method: Some("POST".to_string()),
            phase: "request".to_string(),
            name: "api-intercept:[".to_string(),
            priority: 100,
            termination: RuleTermination::Stop,
        });

        assert_eq!(rules.len(), 1);
        match &rules[0].filter {
            Filter::And(filters) => {
                assert!(matches!(
                    filters[0],
                    Filter::Url(StringMatcher::Contains(_))
                ));
                assert!(matches!(
                    filters[1],
                    Filter::Method(StringMatcher::Exact(_))
                ));
            }
            other => panic!("expected And filter for method+url, got {:?}", other),
        }
    }

    #[test]
    fn test_build_mock_response_rule_sets_mock_action_and_headers() {
        let rule = build_mock_response_rule(MockResponseRuleConfig {
            rule_id: "mock-rule".to_string(),
            url_pattern: "example.com".to_string(),
            name: "mock".to_string(),
            status: 201,
            content_type: "application/json".to_string(),
            body: "{\"ok\":true}".to_string(),
        });

        assert_eq!(rule.id, "mock-rule");
        assert_eq!(rule.stage, RuleStage::RequestHeaders);
        assert!(matches!(rule.termination, RuleTermination::Stop));
        match &rule.filter {
            Filter::Url(StringMatcher::Regex(pattern)) => assert_eq!(pattern, "example.com"),
            other => panic!("expected regex url filter, got {:?}", other),
        }
        match &rule.actions[0] {
            Action::MockResponse {
                status,
                headers,
                body,
            } => {
                assert_eq!(*status, 201);
                assert_eq!(
                    headers.get("Content-Type").map(String::as_str),
                    Some("application/json")
                );
                match body {
                    Some(BodySource::Text(text)) => assert_eq!(text, "{\"ok\":true}"),
                    other => panic!("expected text body, got {:?}", other),
                }
            }
            other => panic!("expected mock response action, got {:?}", other),
        }
    }

    #[test]
    fn test_build_mock_response_rule_empty_body_yields_no_body() {
        let rule = build_mock_response_rule(MockResponseRuleConfig {
            rule_id: "mock-empty".to_string(),
            url_pattern: "example.com".to_string(),
            name: "mock-empty".to_string(),
            status: 204,
            content_type: "text/plain".to_string(),
            body: String::new(),
        });

        match &rule.actions[0] {
            Action::MockResponse {
                status,
                headers,
                body,
            } => {
                assert_eq!(*status, 204);
                assert_eq!(
                    headers.get("Content-Type").map(String::as_str),
                    Some("text/plain")
                );
                assert!(body.is_none(), "expected no body, got {:?}", body);
            }
            other => panic!("expected mock response action, got {:?}", other),
        }
    }
}