Skip to main content

relay_core_script/
lib.rs

1//! Internal Deno/V8 scripting engine for [relay-core](https://crates.io/crates/relay-core).
2//! Provides the `ScriptInterceptor` that implements runtime script modification of traffic.
3//!
4//! **This is a feature backend for `relay-core`.** Enable with `relay-core = { features = ["script"] }`.
5
6pub mod deno_engine;
7pub mod engine_trait;
8pub mod streams;
9
10use crate::deno_engine::DenoScriptEngine;
11use crate::engine_trait::ScriptEngineTrait;
12use async_trait::async_trait;
13use relay_core_api::flow::{Flow, WebSocketMessage};
14use relay_core_lib::interceptor::{
15    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
16    WebSocketMessageAction,
17};
18use std::collections::hash_map::DefaultHasher;
19use std::hash::{Hash, Hasher};
20use tokio::sync::RwLock;
21
22pub struct ScriptInterceptor {
23    engines: Vec<RwLock<Box<dyn ScriptEngineTrait>>>,
24}
25
26impl ScriptInterceptor {
27    pub async fn new() -> Result<Self, BoxError> {
28        let pool_size = std::thread::available_parallelism()
29            .map(|n| n.get())
30            .unwrap_or(4);
31        let mut engines = Vec::with_capacity(pool_size);
32
33        for _ in 0..pool_size {
34            let engine: Box<dyn ScriptEngineTrait> = Box::new(DenoScriptEngine::new());
35            engines.push(RwLock::new(engine));
36        }
37
38        Ok(Self { engines })
39    }
40
41    pub async fn load_script(&self, script: &str) -> Result<(), BoxError> {
42        // Load script into ALL engines to keep them consistent.
43        // Optimization: Load into new engines first to avoid blocking request processing,
44        // then swap them in quickly.
45        let pool_size = self.engines.len();
46        let mut new_engines = Vec::with_capacity(pool_size);
47
48        for _ in 0..pool_size {
49            let mut engine = DenoScriptEngine::new();
50            engine.load_script(script).await?;
51            new_engines.push(Box::new(engine) as Box<dyn ScriptEngineTrait>);
52        }
53
54        // We acquire write locks only for the swap.
55        // The new engines are already prepared, so the critical section is very short.
56        let mut new_engines_iter = new_engines.into_iter();
57        for engine_lock in &self.engines {
58            if let Some(new_engine) = new_engines_iter.next() {
59                let mut guard = engine_lock.write().await;
60                *guard = new_engine;
61            }
62        }
63
64        Ok(())
65    }
66
67    fn get_engine_index(&self) -> usize {
68        // Optimization: Allow task-local override for engine index (e.g. for testing or specific routing)
69        if let Ok(index) = relay_core_lib::interceptor::ENGINE_INDEX.try_with(|i| *i) {
70            return index % self.engines.len();
71        }
72
73        let thread_id = std::thread::current().id();
74        let mut hasher = DefaultHasher::new();
75        thread_id.hash(&mut hasher);
76        (hasher.finish() as usize) % self.engines.len()
77    }
78}
79
80#[async_trait]
81impl Interceptor for ScriptInterceptor {
82    async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
83        let index = self.get_engine_index();
84        let engine_lock = &self.engines[index];
85        let engine = engine_lock.read().await;
86
87        match engine.on_request_headers(flow).await {
88            Ok(Some(modified_flow)) => {
89                *flow = modified_flow;
90                InterceptionResult::ModifiedRequest(match &flow.layer {
91                    relay_core_api::flow::Layer::Http(h) => h.request.clone(),
92                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_request.clone(),
93                    _ => return InterceptionResult::Continue,
94                })
95            }
96            Ok(None) => InterceptionResult::Continue,
97            Err(e) => {
98                tracing::error!("Script execution error (on_request_headers): {}", e);
99                flow.tags.push("script-error".to_string());
100                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
101                    http.error = Some(format!("Script Error: {}", e));
102                }
103                InterceptionResult::Continue
104            }
105        }
106    }
107
108    async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
109        let index = self.get_engine_index();
110        let engine_lock = &self.engines[index];
111        let engine = engine_lock.read().await;
112
113        engine.on_request(flow, body).await
114    }
115
116    async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
117        let index = self.get_engine_index();
118        let engine_lock = &self.engines[index];
119        let engine = engine_lock.read().await;
120
121        match engine.on_response_headers(flow).await {
122            Ok(Some(modified_flow)) => {
123                *flow = modified_flow;
124                InterceptionResult::ModifiedResponse(match &flow.layer {
125                    relay_core_api::flow::Layer::Http(h) => h.response.clone().unwrap(),
126                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_response.clone(),
127                    _ => return InterceptionResult::Continue,
128                })
129            }
130            Ok(None) => InterceptionResult::Continue,
131            Err(e) => {
132                tracing::error!("Script execution error (on_response_headers): {}", e);
133                flow.tags.push("script-error".to_string());
134                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
135                    http.error = Some(format!("Script Error: {}", e));
136                }
137                InterceptionResult::Continue
138            }
139        }
140    }
141
142    async fn on_response(
143        &self,
144        flow: &mut Flow,
145        body: HttpBody,
146    ) -> Result<ResponseAction, BoxError> {
147        let index = self.get_engine_index();
148        let engine_lock = &self.engines[index];
149        let engine = engine_lock.read().await;
150
151        engine.on_response(flow, body).await
152    }
153
154    async fn on_websocket_message(
155        &self,
156        flow: &mut Flow,
157        mut message: WebSocketMessage,
158    ) -> Result<WebSocketMessageAction, BoxError> {
159        let index = self.get_engine_index();
160        let engine_lock = &self.engines[index];
161        let engine = engine_lock.read().await;
162
163        match engine.on_websocket_message(flow, &mut message).await {
164            Ok(action) => Ok(action),
165            Err(e) => {
166                tracing::error!("Script execution error (on_websocket_message): {}", e);
167                flow.tags.push("script-error".to_string());
168                Ok(WebSocketMessageAction::Continue(message))
169            }
170        }
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use chrono::Utc;
    use http_body_util::{BodyExt, Empty};
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, Layer, NetworkInfo, TransportProtocol,
        WebSocketMessage,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds a plain `GET http://example.com` HTTP flow with no response, tags,
    /// or metadata.
    fn sample_http_flow() -> Flow {
        let network = NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        };
        let request = HttpRequest {
            method: "GET".to_string(),
            url: Url::parse("http://example.com").unwrap(),
            version: "HTTP/1.1".to_string(),
            headers: Vec::new(),
            cookies: Vec::new(),
            query: Vec::new(),
            body: None,
        };

        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network,
            layer: Layer::Http(HttpLayer {
                request,
                response: None,
                error: None,
            }),
            tags: Vec::new(),
            meta: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_script_error_propagation() {
        let source = r#"
            globalThis.onRequestHeaders = (flow) => {
                throw new Error("Test Error 123");
            };
        "#;
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor.load_script(source).await.unwrap();

        let mut flow = sample_http_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;

        // A throwing header hook must not abort the flow, but must leave a trace.
        assert!(
            matches!(outcome, InterceptionResult::Continue),
            "Expected Continue"
        );
        assert!(flow.tags.iter().any(|tag| tag == "script-error"));

        let Layer::Http(http) = &flow.layer else {
            panic!("Expected Http layer");
        };
        assert!(http.error.is_some());
        assert!(http.error.as_deref().unwrap().contains("Test Error 123"));
    }

    #[tokio::test]
    async fn test_script_api_relay_body() {
        let source = r#"
            globalThis.onRequest = (body, flow) => {
                if (!(body instanceof RelayBody)) {
                    throw new Error("First argument is not RelayBody");
                }
                // We return nothing (undefined), which means "continue with original body"
                // But we successfully verified the type.
            };
        "#;
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor.load_script(source).await.unwrap();

        let mut flow = sample_http_flow();
        let empty_body = Empty::<Bytes>::new()
            .map_err(|_| -> BoxError { unreachable!() })
            .boxed();

        // The hook throws unless its first argument is a RelayBody. Assuming a thrown
        // hook surfaces as `Err`, an `Ok` here confirms the argument type.
        assert!(interceptor.on_request(&mut flow, empty_body).await.is_ok());
    }

    #[tokio::test]
    async fn test_load_script_failure_does_not_replace_existing_engines() {
        let good_script = r#"
            globalThis.onRequestHeaders = (_context, flow) => {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-Good-Script", "1"]);
                }
                return flow;
            };
        "#;
        let bad_script = "globalThis.onRequestHeaders = () => { invalid javascript !!!";

        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(good_script)
            .await
            .expect("good script should load");
        assert!(
            interceptor.load_script(bad_script).await.is_err(),
            "bad script should be rejected"
        );

        let mut flow = sample_http_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(outcome, InterceptionResult::ModifiedRequest(_)),
            "existing good script should still be active after failed reload"
        );

        let Layer::Http(http) = &flow.layer else {
            panic!("Expected Http layer");
        };
        let has_marker = http
            .request
            .headers
            .iter()
            .any(|(name, value)| name == "X-Good-Script" && value == "1");
        assert!(has_marker);
    }

    #[tokio::test]
    async fn test_websocket_script_error_falls_back_continue_and_tags_flow() {
        let source = r#"
            globalThis.onWebSocketMessage = function(_context, _flow, _message) {
                throw new Error("ws explode");
            };
        "#;
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(source)
            .await
            .expect("script should load");

        let mut flow = sample_http_flow();
        let payload = BodyData {
            encoding: "utf-8".to_string(),
            content: "hello".to_string(),
            size: 5,
        };
        let msg = WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: payload,
            opcode: "Text".to_string(),
        };

        let action = interceptor
            .on_websocket_message(&mut flow, msg)
            .await
            .expect("websocket interception should not return hard error");

        // The failing hook should fall back to forwarding the original message.
        match action {
            WebSocketMessageAction::Continue(forwarded) => {
                assert_eq!(forwarded.content.content, "hello");
            }
            other => panic!("expected Continue fallback, got {:?}", other),
        }
        assert!(
            flow.tags.iter().any(|t| t == "script-error"),
            "script error should be tagged for observability"
        );
    }
}