// relay_core_runtime/modification.rs

1use relay_core_api::flow::{BodyData, Flow, HttpResponse, Layer, ResponseTiming, WebSocketMessage};
2use relay_core_api::modification::FlowModification;
3use relay_core_lib::InterceptionResult;
4use url::Url;
5
6// Re-export API modification types at relay_core_runtime::modification::
7pub use relay_core_api::modification::{FlowQuery, FlowSummary};
8
9/// 将 FlowModification 应用到 Flow 的请求或响应上。
10///
11/// `phase` 约定:以 "request" 开头表示修改请求,以 "response" 开头表示修改响应,
12/// 其他值返回 Continue。
13///
14/// 如果 Flow 的 Layer 不支持(Tcp/Unknown 等),同样返回 Continue。
15pub fn apply_flow_modification(flow: &Flow, phase: &str, mods: FlowModification) -> InterceptionResult {
16    if phase.starts_with("request") {
17        let mut req = match &flow.layer {
18            Layer::Http(h) => h.request.clone(),
19            Layer::WebSocket(ws) => ws.handshake_request.clone(),
20            _ => return InterceptionResult::Continue,
21        };
22
23        if let Some(m) = mods.method {
24            req.method = m;
25        }
26        if let Some(u) = mods.url {
27            // 无效 URL 静默忽略,保留原始值
28            if let Ok(parsed) = Url::parse(&u) {
29                req.url = parsed;
30            }
31        }
32        if let Some(h) = mods.request_headers {
33            req.headers = h.into_iter().collect();
34        }
35        if let Some(b) = mods.request_body {
36            req.body = Some(BodyData {
37                encoding: "utf-8".to_string(),
38                size: b.len() as u64,
39                content: b,
40            });
41        }
42
43        InterceptionResult::ModifiedRequest(req)
44    } else if phase.starts_with("response") {
45        let mut res = match &flow.layer {
46            Layer::Http(h) => h.response.clone().unwrap_or_else(|| HttpResponse {
47                status: 200,
48                status_text: "OK".to_string(),
49                version: "HTTP/1.1".to_string(),
50                headers: vec![],
51                body: None,
52                timing: ResponseTiming { time_to_first_byte: None, time_to_last_byte: None,
53                connect_time_ms: None,
54                ssl_time_ms: None,
55            },
56                cookies: vec![],
57            }),
58            Layer::WebSocket(ws) => ws.handshake_response.clone(),
59            _ => return InterceptionResult::Continue,
60        };
61
62        if let Some(s) = mods.status_code {
63            res.status = s;
64        }
65        if let Some(h) = mods.response_headers {
66            res.headers = h.into_iter().collect();
67        }
68        if let Some(b) = mods.response_body {
69            res.body = Some(BodyData {
70                encoding: "utf-8".to_string(),
71                size: b.len() as u64,
72                content: b,
73            });
74        }
75
76        InterceptionResult::ModifiedResponse(res)
77    } else {
78        InterceptionResult::Continue
79    }
80}
81
82/// 将 FlowModification 应用到 WebSocket 消息上。
83///
84/// 仅修改 message_content;其余字段(方向、opcode、时间戳等)保持不变。
85/// 若 modification 不含 message_content,则原样返回消息。
86pub fn apply_ws_modification(message: &WebSocketMessage, mods: FlowModification) -> InterceptionResult {
87    let mut new_msg = message.clone();
88    if let Some(content) = mods.message_content {
89        new_msg.content.size = content.len() as u64;
90        new_msg.content.content = content;
91    }
92    InterceptionResult::ModifiedMessage(new_msg)
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
        ResponseTiming, TransportProtocol, WebSocketLayer, WebSocketMessage,
    };
    use relay_core_api::modification::FlowModification;
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    // --- fixtures ---

    /// Plain-TCP connection 127.0.0.1:12345 -> 1.1.1.1:80, shared by every test flow.
    fn make_network() -> NetworkInfo {
        NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        }
    }

    /// Response timing with every measurement left unset.
    fn empty_timing() -> ResponseTiming {
        ResponseTiming {
            time_to_first_byte: None,
            time_to_last_byte: None,
            connect_time_ms: None,
            ssl_time_ms: None,
        }
    }

    /// An `HTTP/1.1` GET request for `url` with the given headers and no body.
    fn make_get_request(url: &str, headers: Vec<(String, String)>) -> HttpRequest {
        HttpRequest {
            method: "GET".to_string(),
            url: Url::parse(url).unwrap(),
            version: "HTTP/1.1".to_string(),
            headers,
            body: None,
            cookies: vec![],
            query: vec![],
        }
    }

    /// Wraps `layer` in a freshly-timestamped flow with no tags or metadata.
    fn make_flow(layer: Layer) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: make_network(),
            layer,
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// HTTP flow with a GET request and no response yet.
    fn make_http_flow(url: &str) -> Flow {
        make_flow(Layer::Http(HttpLayer {
            request: make_get_request(url, vec![]),
            response: None,
            error: None,
        }))
    }

    /// HTTP flow with a GET request and an empty `200 OK` response.
    fn make_http_flow_with_response(url: &str) -> Flow {
        let mut flow = make_http_flow(url);
        if let Layer::Http(ref mut h) = flow.layer {
            h.response = Some(HttpResponse {
                status: 200,
                status_text: "OK".to_string(),
                version: "HTTP/1.1".to_string(),
                headers: vec![],
                body: None,
                timing: empty_timing(),
                cookies: vec![],
            });
        }
        flow
    }

    /// WebSocket flow whose handshake is an `Upgrade: websocket` GET answered by `101`.
    fn make_ws_flow(url: &str) -> Flow {
        make_flow(Layer::WebSocket(WebSocketLayer {
            handshake_request: make_get_request(
                url,
                vec![("Upgrade".to_string(), "websocket".to_string())],
            ),
            handshake_response: HttpResponse {
                status: 101,
                status_text: "Switching Protocols".to_string(),
                version: "HTTP/1.1".to_string(),
                headers: vec![],
                body: None,
                timing: empty_timing(),
                cookies: vec![],
            },
            messages: vec![],
            closed: false,
        }))
    }

    /// Client-to-server text frame carrying `content`.
    fn make_ws_message(content: &str) -> WebSocketMessage {
        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: content.to_string(),
                size: content.len() as u64,
            },
            opcode: "Text".to_string(),
        }
    }

    // --- apply_flow_modification: request phase ---

    #[test]
    fn test_request_modification_applies_all_fields() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("POST".to_string()),
            url: Some("http://example.com/v2/api".to_string()),
            request_headers: Some(HashMap::from([("X-Custom".to_string(), "123".to_string())])),
            request_body: Some("new-body".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "request", mods);

        if let InterceptionResult::ModifiedRequest(req) = result {
            assert_eq!(req.method, "POST");
            assert_eq!(req.url.as_str(), "http://example.com/v2/api");
            assert!(req.headers.iter().any(|(k, v)| k == "X-Custom" && v == "123"));
            assert_eq!(req.body.unwrap().content, "new-body");
        } else {
            panic!("expected ModifiedRequest");
        }
    }

    #[test]
    fn test_request_modification_invalid_url_keeps_original() {
        let flow = make_http_flow("http://example.com/api");
        let original_url = match &flow.layer {
            Layer::Http(h) => h.request.url.clone(),
            _ => panic!("expected http layer"),
        };
        let mods = FlowModification {
            method: Some("PUT".to_string()),
            url: Some("://invalid-url".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "request", mods);

        match result {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.method, "PUT");
                assert_eq!(req.url, original_url, "invalid URL should keep original");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_request_headers_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            method: Some("PATCH".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_headers", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_request_body_phase_prefix_routes_to_request() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("hello".to_string()),
            ..Default::default()
        };
        let result = apply_flow_modification(&flow, "request_body", mods);
        assert!(matches!(result, InterceptionResult::ModifiedRequest(_)));
    }

    #[test]
    fn test_request_body_size_is_byte_length() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            request_body: Some("héllo".to_string()),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                let body = req.body.expect("request body should be set");
                assert_eq!(body.content, "héllo");
                assert_eq!(body.size, 6, "size counts UTF-8 bytes, not chars");
                assert_eq!(body.encoding, "utf-8");
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_request_headers_replace_existing_list() {
        // The WebSocket handshake starts with an `Upgrade` header; replacing the headers
        // must drop it rather than merge.
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            request_headers: Some(HashMap::from([("X-Only".to_string(), "1".to_string())])),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "request", mods) {
            InterceptionResult::ModifiedRequest(req) => {
                assert_eq!(req.headers.len(), 1);
                assert!(req.headers.iter().any(|(k, v)| k == "X-Only" && v == "1"));
                assert!(!req.headers.iter().any(|(k, _)| k == "Upgrade"));
            }
            other => panic!("expected ModifiedRequest, got {:?}", other),
        }
    }

    // --- apply_flow_modification: response phase ---

    #[test]
    fn test_response_modification_applies_all_fields() {
        let flow = make_http_flow_with_response("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(404),
            response_headers: Some(HashMap::from([(
                "Content-Type".to_string(),
                "application/json".to_string(),
            )])),
            response_body: Some("{\"error\": \"not found\"}".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "response", mods);

        if let InterceptionResult::ModifiedResponse(res) = result {
            assert_eq!(res.status, 404);
            assert!(res.headers.iter().any(|(k, v)| k == "Content-Type" && v == "application/json"));
            assert_eq!(res.body.unwrap().content, "{\"error\": \"not found\"}");
        } else {
            panic!("expected ModifiedResponse");
        }
    }

    #[test]
    fn test_response_modification_no_existing_response_uses_default() {
        // Flow has no response yet
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification {
            status_code: Some(503),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "response_headers", mods);

        if let InterceptionResult::ModifiedResponse(res) = result {
            assert_eq!(res.status, 503);
        } else {
            panic!("expected ModifiedResponse");
        }
    }

    // --- apply_flow_modification: websocket handshake ---

    #[test]
    fn test_ws_handshake_request_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            url: Some("ws://example.com/socket-v2".to_string()),
            ..Default::default()
        };

        let result = apply_flow_modification(&flow, "request", mods);

        if let InterceptionResult::ModifiedRequest(req) = result {
            assert_eq!(req.url.as_str(), "ws://example.com/socket-v2");
        } else {
            panic!("expected ModifiedRequest for WebSocket handshake");
        }
    }

    #[test]
    fn test_ws_handshake_response_modification() {
        let flow = make_ws_flow("ws://example.com/socket");
        let mods = FlowModification {
            status_code: Some(403),
            response_headers: Some(HashMap::from([("X-Reason".to_string(), "denied".to_string())])),
            ..Default::default()
        };

        match apply_flow_modification(&flow, "response", mods) {
            InterceptionResult::ModifiedResponse(res) => {
                assert_eq!(res.status, 403);
                assert!(res.headers.iter().any(|(k, v)| k == "X-Reason" && v == "denied"));
            }
            other => panic!("expected ModifiedResponse for WebSocket handshake, got {:?}", other),
        }
    }

    // --- apply_flow_modification: unknown phase ---

    #[test]
    fn test_unknown_phase_returns_continue() {
        let flow = make_http_flow("http://example.com/api");
        let mods = FlowModification::default();
        let result = apply_flow_modification(&flow, "pre-request", mods);
        assert!(matches!(result, InterceptionResult::Continue));
    }

    // --- apply_ws_modification ---

    #[test]
    fn test_ws_modification_replaces_content() {
        let msg = make_ws_message("original");
        let mods = FlowModification {
            message_content: Some("modified".to_string()),
            ..Default::default()
        };

        let result = apply_ws_modification(&msg, mods);

        if let InterceptionResult::ModifiedMessage(new_msg) = result {
            assert_eq!(new_msg.content.content, "modified");
            assert_eq!(new_msg.content.size, 8);
            assert_eq!(new_msg.direction, Direction::ClientToServer);
            assert_eq!(new_msg.opcode, "Text");
        } else {
            panic!("expected ModifiedMessage");
        }
    }

    #[test]
    fn test_ws_modification_no_content_returns_original_message() {
        let msg = make_ws_message("origin");
        let mods = FlowModification::default();

        let result = apply_ws_modification(&msg, mods);

        if let InterceptionResult::ModifiedMessage(new_msg) = result {
            assert_eq!(new_msg.content.content, "origin");
            assert_eq!(new_msg.content.size, 6);
        } else {
            panic!("expected ModifiedMessage");
        }
    }
}