Skip to main content

relay_core_runtime/
modification.rs

1use relay_core_api::flow::{BodyData, Flow, HttpResponse, Layer, ResponseTiming, WebSocketMessage};
2use relay_core_api::modification::FlowModification;
3use relay_core_lib::InterceptionResult;
4use url::Url;
5
6// Re-export API modification types at relay_core_runtime::modification::
7pub use relay_core_api::modification::{FlowQuery, FlowSummary};
8
9/// 将 FlowModification 应用到 Flow 的请求或响应上。
10///
11/// `phase` 约定:以 "request" 开头表示修改请求,以 "response" 开头表示修改响应,
12/// 其他值返回 Continue。
13///
14/// 如果 Flow 的 Layer 不支持(Tcp/Unknown 等),同样返回 Continue。
15pub fn apply_flow_modification(flow: &Flow, phase: &str, mods: FlowModification) -> InterceptionResult {
16    if phase.starts_with("request") {
17        let mut req = match &flow.layer {
18            Layer::Http(h) => h.request.clone(),
19            Layer::WebSocket(ws) => ws.handshake_request.clone(),
20            _ => return InterceptionResult::Continue,
21        };
22
23        if let Some(m) = mods.method {
24            req.method = m;
25        }
26        if let Some(u) = mods.url {
27            // 无效 URL 静默忽略,保留原始值
28            if let Ok(parsed) = Url::parse(&u) {
29                req.url = parsed;
30            }
31        }
32        if let Some(h) = mods.request_headers {
33            req.headers = h.into_iter().collect();
34        }
35        if let Some(b) = mods.request_body {
36            req.body = Some(BodyData {
37                encoding: "utf-8".to_string(),
38                size: b.len() as u64,
39                content: b,
40            });
41        }
42
43        InterceptionResult::ModifiedRequest(req)
44    } else if phase.starts_with("response") {
45        let mut res = match &flow.layer {
46            Layer::Http(h) => h.response.clone().unwrap_or_else(|| HttpResponse {
47                status: 200,
48                status_text: "OK".to_string(),
49                version: "HTTP/1.1".to_string(),
50                headers: vec![],
51                body: None,
52                timing: ResponseTiming { time_to_first_byte: None, time_to_last_byte: None },
53                cookies: vec![],
54            }),
55            Layer::WebSocket(ws) => ws.handshake_response.clone(),
56            _ => return InterceptionResult::Continue,
57        };
58
59        if let Some(s) = mods.status_code {
60            res.status = s;
61        }
62        if let Some(h) = mods.response_headers {
63            res.headers = h.into_iter().collect();
64        }
65        if let Some(b) = mods.response_body {
66            res.body = Some(BodyData {
67                encoding: "utf-8".to_string(),
68                size: b.len() as u64,
69                content: b,
70            });
71        }
72
73        InterceptionResult::ModifiedResponse(res)
74    } else {
75        InterceptionResult::Continue
76    }
77}
78
79/// 将 FlowModification 应用到 WebSocket 消息上。
80///
81/// 仅修改 message_content;其余字段(方向、opcode、时间戳等)保持不变。
82/// 若 modification 不含 message_content,则原样返回消息。
83pub fn apply_ws_modification(message: &WebSocketMessage, mods: FlowModification) -> InterceptionResult {
84    let mut new_msg = message.clone();
85    if let Some(content) = mods.message_content {
86        new_msg.content.size = content.len() as u64;
87        new_msg.content.content = content;
88    }
89    InterceptionResult::ModifiedMessage(new_msg)
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
        ResponseTiming, TransportProtocol, WebSocketLayer, WebSocketMessage,
    };
    use relay_core_api::modification::FlowModification;
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Plain-TCP connection 127.0.0.1:12345 -> 1.1.1.1:80 shared by all flow fixtures.
    fn make_network() -> NetworkInfo {
        NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        }
    }

    /// HTTP flow with a `GET url` request and no response captured yet.
    fn make_http_flow(url: &str) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: make_network(),
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse(url).unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    body: None,
                    cookies: vec![],
                    query: vec![],
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// Same as `make_http_flow`, plus an empty `HTTP/1.1 200 OK` response.
    fn make_http_flow_with_response(url: &str) -> Flow {
        let mut flow = make_http_flow(url);
        if let Layer::Http(ref mut h) = flow.layer {
            h.response = Some(HttpResponse {
                status: 200,
                status_text: "OK".to_string(),
                version: "HTTP/1.1".to_string(),
                headers: vec![],
                body: None,
                timing: ResponseTiming { time_to_first_byte: None, time_to_last_byte: None },
                cookies: vec![],
            });
        }
        flow
    }

    /// WebSocket flow whose handshake is `GET url` (with `Upgrade: websocket`)
    /// answered by `101 Switching Protocols`, with no messages yet.
    fn make_ws_flow(url: &str) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: make_network(),
            layer: Layer::WebSocket(WebSocketLayer {
                handshake_request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse(url).unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![("Upgrade".to_string(), "websocket".to_string())],
                    body: None,
                    cookies: vec![],
                    query: vec![],
                },
                handshake_response: HttpResponse {
                    status: 101,
                    status_text: "Switching Protocols".to_string(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    body: None,
                    timing: ResponseTiming { time_to_first_byte: None, time_to_last_byte: None },
                    cookies: vec![],
                },
                messages: vec![],
                closed: false,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// Client-to-server UTF-8 text message carrying `content`.
    fn make_ws_message(content: &str) -> WebSocketMessage {
        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: content.to_string(),
                size: content.len() as u64,
            },
            opcode: "Text".to_string(),
        }
    }

    // --- apply_flow_modification: request phase ---

    #[test]
    fn test_request_modification_applies_all_fields() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("POST".to_string()),
            url: Some("http://example.com/v2/api".to_string()),
            request_headers: Some(HashMap::from([("X-Custom".to_string(), "123".to_string())])),
            request_body: Some("new-body".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "request", mods);

        if let InterceptionResult::ModifiedRequest(req) = result {
            assert_eq!(req.method, "POST");
            assert_eq!(req.url.as_str(), "http://example.com/v2/api");
            assert!(req.headers.iter().any(|(k, v)| k == "X-Custom" && v == "123"));
            assert_eq!(req.body.unwrap().content, "new-body");
        } else {
            panic!("expected ModifiedRequest");
        }
    }

    #[test]
    fn test_request_modification_invalid_url_keeps_original() {
        let flow = make_http_flow("http://example.com/api");
        let original_url = match &flow.layer {
            Layer::Http(h) => h.request.url.clone(),
            _ => panic!("expected http layer"),
        };
        let mods = FlowModification {
            method: Some("PUT".to_string()),
            url: Some("://invalid-url".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "request", mods);

        match result {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "PUT");
                assert_eq!(req.url, original_url, "invalid URL should keep original");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_request_headers_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("PATCH".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_headers", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_request_body_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("hello".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_body", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_request_body_size_counts_utf8_bytes() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("日本".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request_body", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                let body = req.body.expect("request body should be set");
                assert_eq!(body.content, "日本");
                assert_eq!(body.encoding, "utf-8");
                assert_eq!(body.size, 6, "size is the UTF-8 byte length, not the char count");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_empty_request_modification_keeps_request_fields() {
        let flow = make_http_flow("http://example.com/api");

        match apply_flow_modification(&flow, "request", FlowModification::default()) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "GET");
                assert_eq!(req.url.as_str(), "http://example.com/api");
                assert!(req.headers.is_empty());
                assert!(req.body.is_none());
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    // --- apply_flow_modification: response phase ---

    #[test]
    fn test_response_modification_applies_all_fields() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(404),
            response_headers: Some(HashMap::from([("Content-Type".to_string(), "application/json".to_string())])),
            response_body: Some("{\"error\": \"not found\"}".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "response", mods);

        if let InterceptionResult::ModifiedResponse(res) = result {
            assert_eq!(res.status, 404);
            assert!(res.headers.iter().any(|(k, v)| k == "Content-Type" && v == "application/json"));
            assert_eq!(res.body.unwrap().content, "{\"error\": \"not found\"}");
        } else {
            panic!("expected ModifiedResponse");
        }
    }

    #[test]
    fn test_response_modification_no_existing_response_uses_default() {
        // The flow has no response yet.
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(503),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "response_headers", mods);

        if let InterceptionResult::ModifiedResponse(res) = result {
            assert_eq!(res.status, 503);
        } else {
            panic!("expected ModifiedResponse");
        }
    }

    #[test]
    fn test_response_body_size_counts_utf8_bytes() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let mods = FlowModification {
            response_body: Some("héllo".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response_body", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                let body = res.body.expect("response body should be set");
                assert_eq!(body.content, "héllo");
                assert_eq!(body.encoding, "utf-8");
                assert_eq!(body.size, 6, "size is the UTF-8 byte length, not the char count");
            }
            other => panic!("expected ModifiedResponse, got {:?}", other),
        }
    }

    // --- apply_flow_modification: websocket handshake ---

    #[test]
    fn test_ws_handshake_request_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            url: Some("ws://example.com/socket-v2".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "request", mods);

        if let InterceptionResult::ModifiedRequest(req) = result {
            assert_eq!(req.url.as_str(), "ws://example.com/socket-v2");
        } else {
            panic!("expected ModifiedRequest for WebSocket handshake");
        }
    }

    #[test]
    fn test_ws_handshake_response_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            status_code: Some(403),
            response_headers: Some(HashMap::from([("X-Blocked".to_string(), "1".to_string())])),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                assert_eq!(res.status, 403);
                assert!(res.headers.iter().any(|(k, v)| k == "X-Blocked" && v == "1"));
            }
            other => panic!("expected ModifiedResponse for WebSocket handshake, got {:?}", other),
        }
    }

    // --- apply_flow_modification: unknown phase ---

    #[test]
    fn test_unknown_phase_returns_continue() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification::default();
        let result = apply_flow_modification(&flow, "pre-request", mods);
        assert!(matches!(result, InterceptionResult::Continue));
    }

    // --- apply_ws_modification ---

    #[test]
    fn test_ws_modification_replaces_content() {
        let msg = make_ws_message("original");
        let mods = FlowModification {
            message_content: Some("modified".to_string()),
            ..Default::default()
        };

        let result = apply_ws_modification(&msg, mods);

        if let InterceptionResult::ModifiedMessage(new_msg) = result {
            assert_eq!(new_msg.content.content, "modified");
            assert_eq!(new_msg.content.size, 8);
            assert_eq!(new_msg.direction, Direction::ClientToServer);
            assert_eq!(new_msg.opcode, "Text");
        } else {
            panic!("expected ModifiedMessage");
        }
    }

    #[test]
    fn test_ws_modification_no_content_returns_original_message() {
        let msg = make_ws_message("origin");
        let mods = FlowModification::default();

        let result = apply_ws_modification(&msg, mods);

        if let InterceptionResult::ModifiedMessage(new_msg) = result {
            assert_eq!(new_msg.content.content, "origin");
            assert_eq!(new_msg.content.size, 6);
        } else {
            panic!("expected ModifiedMessage");
        }
    }

    #[test]
    fn test_ws_modification_size_counts_utf8_bytes() {
        let msg = make_ws_message("a");
        let mods = FlowModification {
            message_content: Some("ü".to_string()),
            ..Default::default()
        };

        match apply_ws_modification(&msg, mods) {
            InterceptionResult::ModifiedMessage(new_msg) => {
                assert_eq!(new_msg.content.content, "ü");
                assert_eq!(new_msg.content.size, 2, "size is the UTF-8 byte length");
            }
            other => panic!("expected ModifiedMessage, got {:?}", other),
        }
    }
}