Skip to main content

relay_core_runtime/
modification.rs

1use relay_core_api::flow::{BodyData, Flow, HttpResponse, Layer, ResponseTiming, WebSocketMessage};
2use relay_core_api::modification::FlowModification;
3use relay_core_lib::InterceptionResult;
4use url::Url;
5
6// Re-export API modification types at relay_core_runtime::modification::
7pub use relay_core_api::modification::{FlowQuery, FlowSummary};
8
9/// 将 FlowModification 应用到 Flow 的请求或响应上。
10///
11/// `phase` 约定:以 "request" 开头表示修改请求,以 "response" 开头表示修改响应,
12/// 其他值返回 Continue。
13///
14/// 如果 Flow 的 Layer 不支持(Tcp/Unknown 等),同样返回 Continue。
15pub fn apply_flow_modification(
16    flow: &Flow,
17    phase: &str,
18    mods: FlowModification,
19) -> InterceptionResult {
20    if phase.starts_with("request") {
21        let mut req = match &flow.layer {
22            Layer::Http(h) => h.request.clone(),
23            Layer::WebSocket(ws) => ws.handshake_request.clone(),
24            _ => return InterceptionResult::Continue,
25        };
26
27        if let Some(m) = mods.method {
28            req.method = m;
29        }
30        if let Some(u) = mods.url {
31            // 无效 URL 静默忽略,保留原始值
32            if let Ok(parsed) = Url::parse(&u) {
33                req.url = parsed;
34            }
35        }
36        if let Some(h) = mods.request_headers {
37            req.headers = h.into_iter().collect();
38        }
39        if let Some(b) = mods.request_body {
40            req.body = Some(BodyData {
41                encoding: "utf-8".to_string(),
42                size: b.len() as u64,
43                content: b,
44            });
45        }
46
47        InterceptionResult::ModifiedRequest(req)
48    } else if phase.starts_with("response") {
49        let mut res = match &flow.layer {
50            Layer::Http(h) => h.response.clone().unwrap_or_else(|| HttpResponse {
51                status: 200,
52                status_text: "OK".to_string(),
53                version: "HTTP/1.1".to_string(),
54                headers: vec![],
55                body: None,
56                timing: ResponseTiming {
57                    time_to_first_byte: None,
58                    time_to_last_byte: None,
59                    connect_time_ms: None,
60                    ssl_time_ms: None,
61                },
62                cookies: vec![],
63            }),
64            Layer::WebSocket(ws) => ws.handshake_response.clone(),
65            _ => return InterceptionResult::Continue,
66        };
67
68        if let Some(s) = mods.status_code {
69            res.status = s;
70        }
71        if let Some(h) = mods.response_headers {
72            res.headers = h.into_iter().collect();
73        }
74        if let Some(b) = mods.response_body {
75            res.body = Some(BodyData {
76                encoding: "utf-8".to_string(),
77                size: b.len() as u64,
78                content: b,
79            });
80        }
81
82        InterceptionResult::ModifiedResponse(res)
83    } else {
84        InterceptionResult::Continue
85    }
86}
87
88/// 将 FlowModification 应用到 WebSocket 消息上。
89///
90/// 仅修改 message_content;其余字段(方向、opcode、时间戳等)保持不变。
91/// 若 modification 不含 message_content,则原样返回消息。
92pub fn apply_ws_modification(
93    message: &WebSocketMessage,
94    mods: FlowModification,
95) -> InterceptionResult {
96    let mut new_msg = message.clone();
97    if let Some(content) = mods.message_content {
98        new_msg.content.size = content.len() as u64;
99        new_msg.content.content = content;
100    }
101    InterceptionResult::ModifiedMessage(new_msg)
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
        ResponseTiming, TransportProtocol, WebSocketLayer, WebSocketMessage,
    };
    use relay_core_api::modification::FlowModification;
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    // --- shared fixture pieces ---

    /// Plain-TCP connection from 127.0.0.1:12345 to 1.1.1.1:80, used by every fixture flow.
    fn plain_tcp_network() -> NetworkInfo {
        NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        }
    }

    /// `HTTP/1.1 200 OK` response with no headers, body, cookies or timing data.
    fn bare_response() -> HttpResponse {
        HttpResponse {
            status: 200,
            status_text: "OK".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            timing: ResponseTiming {
                time_to_first_byte: None,
                time_to_last_byte: None,
                connect_time_ms: None,
                ssl_time_ms: None,
            },
            cookies: vec![],
        }
    }

    /// `GET <url> HTTP/1.1` request with no headers, body, cookies or query.
    fn get_request(url: &str) -> HttpRequest {
        HttpRequest {
            method: "GET".to_string(),
            url: Url::parse(url).unwrap(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            cookies: vec![],
            query: vec![],
        }
    }

    /// Fresh flow (new id, started now) on the fixture network carrying `layer`.
    fn flow_with_layer(layer: Layer) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: plain_tcp_network(),
            layer,
            tags: vec![],
            meta: HashMap::new(),
            resilience_trace: None,
            rule_variables: HashMap::new(),
            matched_rules: vec![],
        }
    }

    // --- fixture flows and messages ---

    fn make_http_flow(url: &str) -> Flow {
        flow_with_layer(Layer::Http(HttpLayer {
            request: get_request(url),
            response: None,
            error: None,
        }))
    }

    fn make_http_flow_with_response(url: &str) -> Flow {
        flow_with_layer(Layer::Http(HttpLayer {
            request: get_request(url),
            response: Some(bare_response()),
            error: None,
        }))
    }

    fn make_ws_flow(url: &str) -> Flow {
        flow_with_layer(Layer::WebSocket(WebSocketLayer {
            handshake_request: HttpRequest {
                headers: vec![("Upgrade".to_string(), "websocket".to_string())],
                ..get_request(url)
            },
            handshake_response: HttpResponse {
                status: 101,
                status_text: "Switching Protocols".to_string(),
                ..bare_response()
            },
            messages: vec![],
            closed: false,
        }))
    }

    fn make_ws_message(content: &str) -> WebSocketMessage {
        let text = content.to_string();
        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                size: text.len() as u64,
                content: text,
            },
            opcode: "Text".to_string(),
        }
    }

    // --- apply_flow_modification: request phase ---

    #[test]
    fn test_request_modification_applies_all_fields() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("POST".to_string()),
            url: Some("http://example.com/v2/api".to_string()),
            request_headers: Some(HashMap::from([("X-Custom".to_string(), "123".to_string())])),
            request_body: Some("new-body".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "POST");
                assert_eq!(req.url.as_str(), "http://example.com/v2/api");
                let has_custom = req
                    .headers
                    .iter()
                    .any(|(k, v)| k == "X-Custom" && v == "123");
                assert!(has_custom);
                assert_eq!(req.body.unwrap().content, "new-body");
            }
            _ => panic!("expected ModifiedRequest"),
        }
    }

    #[test]
    fn test_request_modification_invalid_url_keeps_original() {
        let flow = make_http_flow("http://example.com/api");
        let original_url = if let Layer::Http(h) = &flow.layer {
            h.request.url.clone()
        } else {
            panic!("expected http layer")
        };
        let mods = FlowModification {
            method: Some("PUT".to_string()),
            url: Some("://invalid-url".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "PUT");
                assert_eq!(req.url, original_url, "invalid URL should keep original");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_request_headers_phase_prefix_routes_to_request() {
        let mods = FlowModification {
            method: Some("PATCH".to_string()),
            ..Default::default()
        };
        let outcome = apply_flow_modification(
            &make_http_flow("http://example.com/api"),
            "request_headers",
            mods,
        );
        assert!(matches!(outcome, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_request_body_phase_prefix_routes_to_request() {
        let mods = FlowModification {
            request_body: Some("hello".to_string()),
            ..Default::default()
        };
        let outcome = apply_flow_modification(
            &make_http_flow("http://example.com/api"),
            "request_body",
            mods,
        );
        assert!(matches!(outcome, InterceptionResult::ModifiedRequest(_)));
    }

    // --- apply_flow_modification: response phase ---

    #[test]
    fn test_response_modification_applies_all_fields() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(404),
            response_headers: Some(HashMap::from([(
                "Content-Type".to_string(),
                "application/json".to_string(),
            )])),
            response_body: Some("{\"error\": \"not found\"}".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                assert_eq!(res.status, 404);
                let has_json_type = res
                    .headers
                    .iter()
                    .any(|(k, v)| k == "Content-Type" && v == "application/json");
                assert!(has_json_type);
                assert_eq!(res.body.unwrap().content, "{\"error\": \"not found\"}");
            }
            _ => panic!("expected ModifiedResponse"),
        }
    }

    #[test]
    fn test_response_modification_no_existing_response_uses_default() {
        // The fixture flow has not received a response yet.
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(503),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response_headers", mods) {
            InterceptionResult::ModifiedResponse(res) => assert_eq!(res.status, 503),
            _ => panic!("expected ModifiedResponse"),
        }
    }

    // --- apply_flow_modification: websocket handshake ---

    #[test]
    fn test_ws_handshake_request_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            url: Some("ws://example.com/socket-v2".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.url.as_str(), "ws://example.com/socket-v2");
            }
            _ => panic!("expected ModifiedRequest for WebSocket handshake"),
        }
    }

    // --- apply_flow_modification: unknown phase ---

    #[test]
    fn test_unknown_phase_returns_continue() {
        let outcome = apply_flow_modification(
            &make_http_flow("http://example.com/api"),
            "pre-request",
            FlowModification::default(),
        );
        assert!(matches!(outcome, InterceptionResult::Continue));
    }

    // --- apply_ws_modification ---

    #[test]
    fn test_ws_modification_replaces_content() {
        let msg = make_ws_message("original");
        let mods = FlowModification {
            message_content: Some("modified".to_string()),
            ..Default::default()
        };

        match apply_ws_modification(&msg, mods) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "modified");
                assert_eq!(new_msg.content.size, 8);
                assert_eq!(new_msg.direction, Direction::ClientToServer);
                assert_eq!(new_msg.opcode, "Text");
            }
            _ => panic!("expected ModifiedMessage"),
        }
    }

    #[test]
    fn test_ws_modification_no_content_returns_original_message() {
        let msg = make_ws_message("origin");

        match apply_ws_modification(&msg, FlowModification::default()) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "origin");
                assert_eq!(new_msg.content.size, 6);
            }
            _ => panic!("expected ModifiedMessage"),
        }
    }
}