// relay_core_runtime/modification.rs

1use relay_core_api::flow::{BodyData, Flow, HttpResponse, Layer, ResponseTiming, WebSocketMessage};
2use relay_core_api::modification::FlowModification;
3use relay_core_lib::InterceptionResult;
4use url::Url;
5
6// Re-export API modification types at relay_core_runtime::modification::
7pub use relay_core_api::modification::{FlowQuery, FlowSummary};
8
9/// 将 FlowModification 应用到 Flow 的请求或响应上。
10///
11/// `phase` 约定:以 "request" 开头表示修改请求,以 "response" 开头表示修改响应,
12/// 其他值返回 Continue。
13///
14/// 如果 Flow 的 Layer 不支持(Tcp/Unknown 等),同样返回 Continue。
15pub fn apply_flow_modification(
16    flow: &Flow,
17    phase: &str,
18    mods: FlowModification,
19) -> InterceptionResult {
20    if phase.starts_with("request") {
21        let mut req = match &flow.layer {
22            Layer::Http(h) => h.request.clone(),
23            Layer::WebSocket(ws) => ws.handshake_request.clone(),
24            _ => return InterceptionResult::Continue,
25        };
26
27        if let Some(m) = mods.method {
28            req.method = m;
29        }
30        if let Some(u) = mods.url {
31            // 无效 URL 静默忽略,保留原始值
32            if let Ok(parsed) = Url::parse(&u) {
33                req.url = parsed;
34            }
35        }
36        if let Some(h) = mods.request_headers {
37            req.headers = h.into_iter().collect();
38        }
39        if let Some(b) = mods.request_body {
40            req.body = Some(BodyData {
41                encoding: "utf-8".to_string(),
42                size: b.len() as u64,
43                content: b,
44            });
45        }
46
47        InterceptionResult::ModifiedRequest(req)
48    } else if phase.starts_with("response") {
49        let mut res = match &flow.layer {
50            Layer::Http(h) => h.response.clone().unwrap_or_else(|| HttpResponse {
51                status: 200,
52                status_text: "OK".to_string(),
53                version: "HTTP/1.1".to_string(),
54                headers: vec![],
55                body: None,
56                timing: ResponseTiming {
57                    time_to_first_byte: None,
58                    time_to_last_byte: None,
59                    connect_time_ms: None,
60                    ssl_time_ms: None,
61                },
62                cookies: vec![],
63            }),
64            Layer::WebSocket(ws) => ws.handshake_response.clone(),
65            _ => return InterceptionResult::Continue,
66        };
67
68        if let Some(s) = mods.status_code {
69            res.status = s;
70        }
71        if let Some(h) = mods.response_headers {
72            res.headers = h.into_iter().collect();
73        }
74        if let Some(b) = mods.response_body {
75            res.body = Some(BodyData {
76                encoding: "utf-8".to_string(),
77                size: b.len() as u64,
78                content: b,
79            });
80        }
81
82        InterceptionResult::ModifiedResponse(res)
83    } else {
84        InterceptionResult::Continue
85    }
86}
87
88/// 将 FlowModification 应用到 WebSocket 消息上。
89///
90/// 仅修改 message_content;其余字段(方向、opcode、时间戳等)保持不变。
91/// 若 modification 不含 message_content,则原样返回消息。
92pub fn apply_ws_modification(
93    message: &WebSocketMessage,
94    mods: FlowModification,
95) -> InterceptionResult {
96    let mut new_msg = message.clone();
97    if let Some(content) = mods.message_content {
98        new_msg.content.size = content.len() as u64;
99        new_msg.content.content = content;
100    }
101    InterceptionResult::ModifiedMessage(new_msg)
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
        ResponseTiming, TransportProtocol, WebSocketLayer, WebSocketMessage,
    };
    use relay_core_api::modification::FlowModification;
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    // --- fixtures ---

    /// Plain-TCP connection from 127.0.0.1:12345 to 1.1.1.1:80, no TLS.
    fn local_network() -> NetworkInfo {
        NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        }
    }

    /// Wraps `layer` in a fresh flow (new id, started now, no end time, tags or meta).
    fn flow_with_layer(layer: Layer) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: local_network(),
            layer,
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// Bodyless HTTP/1.1 `GET` for `url` carrying `headers`.
    fn get_request(url: &str, headers: Vec<(String, String)>) -> HttpRequest {
        HttpRequest {
            method: "GET".to_string(),
            url: Url::parse(url).unwrap(),
            version: "HTTP/1.1".to_string(),
            headers,
            body: None,
            cookies: vec![],
            query: vec![],
        }
    }

    /// Bodyless `HTTP/1.1 200 OK` with no headers, cookies or timing data.
    fn ok_response() -> HttpResponse {
        HttpResponse {
            status: 200,
            status_text: "OK".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            body: None,
            timing: ResponseTiming {
                time_to_first_byte: None,
                time_to_last_byte: None,
                connect_time_ms: None,
                ssl_time_ms: None,
            },
            cookies: vec![],
        }
    }

    fn make_http_flow(url: &str) -> Flow {
        flow_with_layer(Layer::Http(HttpLayer {
            request: get_request(url, vec![]),
            response: None,
            error: None,
        }))
    }

    fn make_http_flow_with_response(url: &str) -> Flow {
        let mut flow = make_http_flow(url);
        if let Layer::Http(ref mut http) = flow.layer {
            http.response = Some(ok_response());
        }
        flow
    }

    fn make_ws_flow(url: &str) -> Flow {
        let upgrade = vec![("Upgrade".to_string(), "websocket".to_string())];
        flow_with_layer(Layer::WebSocket(WebSocketLayer {
            handshake_request: get_request(url, upgrade),
            handshake_response: HttpResponse {
                status: 101,
                status_text: "Switching Protocols".to_string(),
                ..ok_response()
            },
            messages: vec![],
            closed: false,
        }))
    }

    /// Client-to-server UTF-8 text frame carrying `text`.
    fn make_ws_message(text: &str) -> WebSocketMessage {
        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: text.to_string(),
                size: text.len() as u64,
            },
            opcode: "Text".to_string(),
        }
    }

    // --- apply_flow_modification: request phase ---

    #[test]
    fn test_request_modification_applies_all_fields() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("POST".to_string()),
            url: Some("http://example.com/v2/api".to_string()),
            request_headers: Some(HashMap::from([("X-Custom".to_string(), "123".to_string())])),
            request_body: Some("new-body".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "POST");
                assert_eq!(req.url.as_str(), "http://example.com/v2/api");
                assert!(req.headers.iter().any(|(k, v)| k == "X-Custom" && v == "123"));
                assert_eq!(req.body.unwrap().content, "new-body");
            }
            _ => panic!("expected ModifiedRequest"),
        }
    }

    #[test]
    fn test_request_modification_invalid_url_keeps_original() {
        let flow = make_http_flow("http://example.com/api");
        let original_url = if let Layer::Http(http) = &flow.layer {
            http.request.url.clone()
        } else {
            panic!("expected http layer")
        };
        let mods = FlowModification {
            method: Some("PUT".to_string()),
            url: Some("://invalid-url".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "PUT");
                assert_eq!(req.url, original_url, "invalid URL should keep original");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_request_headers_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("PATCH".to_string()),
            ..Default::default()
        };
        let outcome = apply_flow_modification(&flow, "request_headers", mods);
        assert!(matches!(outcome, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_request_body_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("hello".to_string()),
            ..Default::default()
        };
        let outcome = apply_flow_modification(&flow, "request_body", mods);
        assert!(matches!(outcome, InterceptionResult::ModifiedRequest(_)));
    }

    // --- apply_flow_modification: response phase ---

    #[test]
    fn test_response_modification_applies_all_fields() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let json_type = ("Content-Type".to_string(), "application/json".to_string());
        let mods = FlowModification {
            status_code: Some(404),
            response_headers: Some(HashMap::from([json_type])),
            response_body: Some("{\"error\": \"not found\"}".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                assert_eq!(res.status, 404);
                assert!(res
                    .headers
                    .iter()
                    .any(|(k, v)| k == "Content-Type" && v == "application/json"));
                assert_eq!(res.body.unwrap().content, "{\"error\": \"not found\"}");
            }
            _ => panic!("expected ModifiedResponse"),
        }
    }

    #[test]
    fn test_response_modification_no_existing_response_uses_default() {
        // The flow has not received a response yet.
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(503),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response_headers", mods) {
            InterceptionResult::ModifiedResponse(res) => assert_eq!(res.status, 503),
            _ => panic!("expected ModifiedResponse"),
        }
    }

    // --- apply_flow_modification: websocket handshake ---

    #[test]
    fn test_ws_handshake_request_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            url: Some("ws://example.com/socket-v2".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.url.as_str(), "ws://example.com/socket-v2");
            }
            _ => panic!("expected ModifiedRequest for WebSocket handshake"),
        }
    }

    // --- apply_flow_modification: unknown phase ---

    #[test]
    fn test_unknown_phase_returns_continue() {
        let flow = make_http_flow("http://example.com/api");
        let outcome = apply_flow_modification(&flow, "pre-request", FlowModification::default());
        assert!(matches!(outcome, InterceptionResult::Continue));
    }

    // --- apply_ws_modification ---

    #[test]
    fn test_ws_modification_replaces_content() {
        let msg = make_ws_message("original");
        let mods = FlowModification {
            message_content: Some("modified".to_string()),
            ..Default::default()
        };

        match apply_ws_modification(&msg, mods) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "modified");
                assert_eq!(new_msg.content.size, 8);
                assert_eq!(new_msg.direction, Direction::ClientToServer);
                assert_eq!(new_msg.opcode, "Text");
            }
            _ => panic!("expected ModifiedMessage"),
        }
    }

    #[test]
    fn test_ws_modification_no_content_returns_original_message() {
        let msg = make_ws_message("origin");

        match apply_ws_modification(&msg, FlowModification::default()) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "origin");
                assert_eq!(new_msg.content.size, 6);
            }
            _ => panic!("expected ModifiedMessage"),
        }
    }
}