Skip to main content

relay_core_script/
lib.rs

1//! Internal Deno/V8 scripting engine for [relay-core](https://crates.io/crates/relay-core).
2//! Provides the `ScriptInterceptor` that implements runtime script modification of traffic.
3//!
4//! **This is a feature backend for `relay-core`.** Enable with `relay-core = { features = ["script"] }`.
5
6pub mod deno_engine;
7pub mod engine_trait;
8pub mod streams;
9
10use crate::deno_engine::DenoScriptEngine;
11use crate::engine_trait::ScriptEngineTrait;
12use async_trait::async_trait;
13use relay_core_api::flow::{Flow, WebSocketMessage};
14use relay_core_lib::interceptor::{
15    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
16    WebSocketMessageAction,
17};
18use std::collections::HashSet;
19use std::collections::hash_map::DefaultHasher;
20use std::hash::{Hash, Hasher};
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::Instant;
23use tokio::sync::RwLock;
24
/// S4: Per-hook script execution metrics exposed via Prometheus/text.
///
/// Each script hook has three counters: cumulative execution time in microseconds, number of
/// invocations, and number of errors. The counters are independent relaxed atomics, so a
/// rendering produced by [`ScriptMetrics::prometheus_lines`] is not a consistent snapshot
/// across fields.
#[derive(Debug, Default)]
pub struct ScriptMetrics {
    /// Cumulative time spent in `onRequestHeaders`, in microseconds.
    pub on_request_headers_duration_us: AtomicU64,
    /// Number of recorded `onRequestHeaders` invocations.
    pub on_request_headers_invocations: AtomicU64,
    /// Number of recorded `onRequestHeaders` errors.
    pub on_request_headers_errors: AtomicU64,

    /// Cumulative time spent in `onRequest`, in microseconds.
    pub on_request_duration_us: AtomicU64,
    /// Number of recorded `onRequest` invocations.
    pub on_request_invocations: AtomicU64,
    /// Number of recorded `onRequest` errors.
    pub on_request_errors: AtomicU64,

    /// Cumulative time spent in `onResponseHeaders`, in microseconds.
    pub on_response_headers_duration_us: AtomicU64,
    /// Number of recorded `onResponseHeaders` invocations.
    pub on_response_headers_invocations: AtomicU64,
    /// Number of recorded `onResponseHeaders` errors.
    pub on_response_headers_errors: AtomicU64,

    /// Cumulative time spent in `onResponse`, in microseconds.
    pub on_response_duration_us: AtomicU64,
    /// Number of recorded `onResponse` invocations.
    pub on_response_invocations: AtomicU64,
    /// Number of recorded `onResponse` errors.
    pub on_response_errors: AtomicU64,

    /// Cumulative time spent in `onWebSocketMessage`, in microseconds.
    pub on_websocket_message_duration_us: AtomicU64,
    /// Number of recorded `onWebSocketMessage` invocations.
    pub on_websocket_message_invocations: AtomicU64,
    /// Number of recorded `onWebSocketMessage` errors.
    pub on_websocket_message_errors: AtomicU64,
}

impl ScriptMetrics {
    /// Prometheus metric family names. Index `i` selects counter `i` of every entry returned
    /// by `hook_counters` (0 = duration, 1 = invocations, 2 = errors).
    const FAMILIES: [&'static str; 3] = [
        "relay_core_script_hook_duration_us",
        "relay_core_script_hook_invocations_total",
        "relay_core_script_hook_errors_total",
    ];

    /// Returns `(hook label, [duration, invocations, errors])` for every hook, in output order.
    fn hook_counters(&self) -> [(&'static str, [&AtomicU64; 3]); 5] {
        [
            (
                "onRequestHeaders",
                [
                    &self.on_request_headers_duration_us,
                    &self.on_request_headers_invocations,
                    &self.on_request_headers_errors,
                ],
            ),
            (
                "onRequest",
                [
                    &self.on_request_duration_us,
                    &self.on_request_invocations,
                    &self.on_request_errors,
                ],
            ),
            (
                "onResponseHeaders",
                [
                    &self.on_response_headers_duration_us,
                    &self.on_response_headers_invocations,
                    &self.on_response_headers_errors,
                ],
            ),
            (
                "onResponse",
                [
                    &self.on_response_duration_us,
                    &self.on_response_invocations,
                    &self.on_response_errors,
                ],
            ),
            (
                "onWebSocketMessage",
                [
                    &self.on_websocket_message_duration_us,
                    &self.on_websocket_message_invocations,
                    &self.on_websocket_message_errors,
                ],
            ),
        ]
    }

    /// Renders all counters as Prometheus text-format sample lines, one `\n`-terminated line
    /// per `(metric family, hook)` pair, labelled `hook="<hookName>"`.
    ///
    /// Lines are grouped by metric family: every `duration_us` sample first, then every
    /// `invocations_total`, then every `errors_total`. The exposition format requires all
    /// samples of one family to form a single contiguous group, so families must not be
    /// interleaved hook by hook.
    pub fn prometheus_lines(&self) -> String {
        use std::fmt::Write as _;

        let hooks = self.hook_counters();
        // ~64 bytes per line is a comfortable upper bound for these names and values.
        let mut out = String::with_capacity(Self::FAMILIES.len() * hooks.len() * 64);
        for (counter_idx, family) in Self::FAMILIES.iter().enumerate() {
            for (hook, counters) in &hooks {
                writeln!(
                    out,
                    "{}{{hook=\"{}\"}} {}",
                    family,
                    hook,
                    counters[counter_idx].load(Ordering::Relaxed)
                )
                .expect("writing to a String cannot fail");
            }
        }
        out
    }
}
125
126pub struct ScriptInterceptor {
127    engines: Vec<RwLock<Box<dyn ScriptEngineTrait>>>,
128    pub metrics: ScriptMetrics,
129    env_allow: RwLock<HashSet<String>>,
130}
131
132impl ScriptInterceptor {
133    pub async fn new() -> Result<Self, BoxError> {
134        Self::new_with_env(HashSet::new()).await
135    }
136
137    pub async fn new_with_env(env_allow: HashSet<String>) -> Result<Self, BoxError> {
138        let pool_size = std::thread::available_parallelism()
139            .map(|n| n.get())
140            .unwrap_or(4);
141        let mut engines = Vec::with_capacity(pool_size);
142
143        for _ in 0..pool_size {
144            let engine: Box<dyn ScriptEngineTrait> =
145                Box::new(DenoScriptEngine::new(env_allow.clone()));
146            engines.push(RwLock::new(engine));
147        }
148
149        Ok(Self {
150            engines,
151            metrics: ScriptMetrics::default(),
152            env_allow: RwLock::new(env_allow),
153        })
154    }
155
156    pub async fn set_env_allow(&self, env_allow: HashSet<String>) {
157        let mut guard = self.env_allow.write().await;
158        *guard = env_allow;
159    }
160
161    pub async fn load_script(&self, script: &str) -> Result<(), BoxError> {
162        // Load script into ALL engines to keep them consistent.
163        let pool_size = self.engines.len();
164        let mut new_engines = Vec::with_capacity(pool_size);
165
166        let env = self.env_allow.read().await.clone();
167        for _ in 0..pool_size {
168            let mut engine = DenoScriptEngine::new(env.clone());
169            engine.load_script(script).await?;
170            new_engines.push(Box::new(engine) as Box<dyn ScriptEngineTrait>);
171        }
172
173        // We acquire write locks only for the swap.
174        // The new engines are already prepared, so the critical section is very short.
175        let mut new_engines_iter = new_engines.into_iter();
176        for engine_lock in &self.engines {
177            if let Some(new_engine) = new_engines_iter.next() {
178                let mut guard = engine_lock.write().await;
179                *guard = new_engine;
180            }
181        }
182
183        Ok(())
184    }
185
186    fn get_engine_index(&self) -> usize {
187        // Optimization: Allow task-local override for engine index (e.g. for testing or specific routing)
188        if let Ok(index) = relay_core_lib::interceptor::ENGINE_INDEX.try_with(|i| *i) {
189            return index % self.engines.len();
190        }
191
192        let thread_id = std::thread::current().id();
193        let mut hasher = DefaultHasher::new();
194        thread_id.hash(&mut hasher);
195        (hasher.finish() as usize) % self.engines.len()
196    }
197}
198
199#[async_trait]
200impl Interceptor for ScriptInterceptor {
201    async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
202        let start = Instant::now();
203        let index = self.get_engine_index();
204        let engine_lock = &self.engines[index];
205        let engine = engine_lock.read().await;
206
207        let result = match engine.on_request_headers(flow).await {
208            Ok(Some(modified_flow)) => {
209                *flow = modified_flow;
210                InterceptionResult::ModifiedRequest(match &flow.layer {
211                    relay_core_api::flow::Layer::Http(h) => h.request.clone(),
212                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_request.clone(),
213                    _ => {
214                        self.metrics
215                            .on_request_headers_errors
216                            .fetch_add(1, Ordering::Relaxed);
217                        return InterceptionResult::Continue;
218                    }
219                })
220            }
221            Ok(None) => InterceptionResult::Continue,
222            Err(e) => {
223                tracing::error!("Script execution error (on_request_headers): {}", e);
224                flow.tags.push("script-error".to_string());
225                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
226                    http.error = Some(format!("Script Error: {}", e));
227                }
228                self.metrics
229                    .on_request_headers_errors
230                    .fetch_add(1, Ordering::Relaxed);
231                InterceptionResult::Continue
232            }
233        };
234        let dur_us = start.elapsed().as_micros() as u64;
235        self.metrics
236            .on_request_headers_duration_us
237            .fetch_add(dur_us, Ordering::Relaxed);
238        self.metrics
239            .on_request_headers_invocations
240            .fetch_add(1, Ordering::Relaxed);
241        result
242    }
243
244    async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
245        let start = Instant::now();
246        let index = self.get_engine_index();
247        let engine_lock = &self.engines[index];
248        let engine = engine_lock.read().await;
249
250        match engine.on_request(flow, body).await {
251            Ok(action) => {
252                let dur_us = start.elapsed().as_micros() as u64;
253                self.metrics
254                    .on_request_duration_us
255                    .fetch_add(dur_us, Ordering::Relaxed);
256                self.metrics
257                    .on_request_invocations
258                    .fetch_add(1, Ordering::Relaxed);
259                Ok(action)
260            }
261            Err(e) => {
262                self.metrics
263                    .on_request_errors
264                    .fetch_add(1, Ordering::Relaxed);
265                Err(e)
266            }
267        }
268    }
269
270    async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
271        let start = Instant::now();
272        let index = self.get_engine_index();
273        let engine_lock = &self.engines[index];
274        let engine = engine_lock.read().await;
275
276        let result = match engine.on_response_headers(flow).await {
277            Ok(Some(modified_flow)) => {
278                *flow = modified_flow;
279                InterceptionResult::ModifiedResponse(match &flow.layer {
280                    relay_core_api::flow::Layer::Http(h) => h.response.clone().unwrap(),
281                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_response.clone(),
282                    _ => {
283                        self.metrics
284                            .on_response_headers_errors
285                            .fetch_add(1, Ordering::Relaxed);
286                        return InterceptionResult::Continue;
287                    }
288                })
289            }
290            Ok(None) => InterceptionResult::Continue,
291            Err(e) => {
292                tracing::error!("Script execution error (on_response_headers): {}", e);
293                flow.tags.push("script-error".to_string());
294                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
295                    http.error = Some(format!("Script Error: {}", e));
296                }
297                self.metrics
298                    .on_response_headers_errors
299                    .fetch_add(1, Ordering::Relaxed);
300                InterceptionResult::Continue
301            }
302        };
303        let dur_us = start.elapsed().as_micros() as u64;
304        self.metrics
305            .on_response_headers_duration_us
306            .fetch_add(dur_us, Ordering::Relaxed);
307        self.metrics
308            .on_response_headers_invocations
309            .fetch_add(1, Ordering::Relaxed);
310        result
311    }
312
313    async fn on_response(
314        &self,
315        flow: &mut Flow,
316        body: HttpBody,
317    ) -> Result<ResponseAction, BoxError> {
318        let start = Instant::now();
319        let index = self.get_engine_index();
320        let engine_lock = &self.engines[index];
321        let engine = engine_lock.read().await;
322
323        match engine.on_response(flow, body).await {
324            Ok(action) => {
325                let dur_us = start.elapsed().as_micros() as u64;
326                self.metrics
327                    .on_response_duration_us
328                    .fetch_add(dur_us, Ordering::Relaxed);
329                self.metrics
330                    .on_response_invocations
331                    .fetch_add(1, Ordering::Relaxed);
332                Ok(action)
333            }
334            Err(e) => {
335                self.metrics
336                    .on_response_errors
337                    .fetch_add(1, Ordering::Relaxed);
338                Err(e)
339            }
340        }
341    }
342
343    async fn on_websocket_message(
344        &self,
345        flow: &mut Flow,
346        mut message: WebSocketMessage,
347    ) -> Result<WebSocketMessageAction, BoxError> {
348        let start = Instant::now();
349        let index = self.get_engine_index();
350        let engine_lock = &self.engines[index];
351        let engine = engine_lock.read().await;
352
353        match engine.on_websocket_message(flow, &mut message).await {
354            Ok(action) => {
355                let dur_us = start.elapsed().as_micros() as u64;
356                self.metrics
357                    .on_websocket_message_duration_us
358                    .fetch_add(dur_us, Ordering::Relaxed);
359                self.metrics
360                    .on_websocket_message_invocations
361                    .fetch_add(1, Ordering::Relaxed);
362                Ok(action)
363            }
364            Err(e) => {
365                tracing::error!("Script execution error (on_websocket_message): {}", e);
366                flow.tags.push("script-error".to_string());
367                self.metrics
368                    .on_websocket_message_errors
369                    .fetch_add(1, Ordering::Relaxed);
370                Ok(WebSocketMessageAction::Continue(message))
371            }
372        }
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use chrono::Utc;
    use http_body_util::{BodyExt, Empty};
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, Layer, NetworkInfo, TransportProtocol,
        WebSocketMessage,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds a plain-HTTP `GET http://example.com` flow that has no response yet.
    fn create_test_flow() -> Flow {
        let network = NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        };
        let request = HttpRequest {
            method: "GET".to_string(),
            url: Url::parse("http://example.com").unwrap(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            cookies: vec![],
            query: vec![],
            body: None,
        };
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network,
            layer: Layer::Http(HttpLayer {
                request,
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// A throwing `onRequestHeaders` must not fail the request; the error lands on the flow.
    #[tokio::test]
    async fn test_script_error_propagation() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let throwing_script = r#"
            globalThis.onRequestHeaders = (flow) => {
                throw new Error("Test Error 123");
            };
        "#;
        interceptor.load_script(throwing_script).await.unwrap();

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(outcome, InterceptionResult::Continue),
            "Expected Continue"
        );
        assert!(flow.tags.iter().any(|tag| tag == "script-error"));

        match &flow.layer {
            Layer::Http(http) => {
                let recorded = http
                    .error
                    .as_deref()
                    .expect("script error should be recorded on the HTTP layer");
                assert!(recorded.contains("Test Error 123"));
            }
            _ => panic!("Expected Http layer"),
        }
    }

    /// `onRequest` receives the body wrapped as a `RelayBody` instance.
    #[tokio::test]
    async fn test_script_api_relay_body() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let type_checking_script = r#"
            globalThis.onRequest = (body, flow) => {
                if (!(body instanceof RelayBody)) {
                    throw new Error("First argument is not RelayBody");
                }
                // We return nothing (undefined), which means "continue with original body"
                // But we successfully verified the type.
            };
        "#;
        interceptor.load_script(type_checking_script).await.unwrap();

        let mut flow = create_test_flow();
        let empty_body = Empty::<Bytes>::new()
            .map_err(|_| -> BoxError { unreachable!() })
            .boxed();

        let outcome = interceptor.on_request(&mut flow, empty_body).await;
        assert!(outcome.is_ok());
    }

    /// A reload that fails to compile must leave the previously loaded script active.
    #[tokio::test]
    async fn test_load_script_failure_does_not_replace_existing_engines() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let good_script = r#"
            globalThis.onRequestHeaders = (_context, flow) => {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-Good-Script", "1"]);
                }
                return flow;
            };
        "#;
        interceptor
            .load_script(good_script)
            .await
            .expect("good script should load");

        let bad_script = "globalThis.onRequestHeaders = () => { invalid javascript !!!";
        let reload = interceptor.load_script(bad_script).await;
        assert!(reload.is_err(), "bad script should be rejected");

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(outcome, InterceptionResult::ModifiedRequest(_)),
            "existing good script should still be active after failed reload"
        );

        match &flow.layer {
            Layer::Http(http) => {
                let has_marker = http
                    .request
                    .headers
                    .iter()
                    .any(|(name, value)| name == "X-Good-Script" && value == "1");
                assert!(has_marker);
            }
            _ => panic!("Expected Http layer"),
        }
    }

    /// A throwing WebSocket hook forwards the original message and tags the flow.
    #[tokio::test]
    async fn test_websocket_script_error_falls_back_continue_and_tags_flow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let throwing_script = r#"
            globalThis.onWebSocketMessage = function(_context, _flow, _message) {
                throw new Error("ws explode");
            };
        "#;
        interceptor
            .load_script(throwing_script)
            .await
            .expect("script should load");

        let mut flow = create_test_flow();
        let payload = BodyData {
            encoding: "utf-8".to_string(),
            content: "hello".to_string(),
            size: 5,
        };
        let outgoing = WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: payload,
            opcode: "Text".to_string(),
        };

        let action = interceptor
            .on_websocket_message(&mut flow, outgoing.clone())
            .await
            .expect("websocket interception should not return hard error");
        match action {
            WebSocketMessageAction::Continue(forwarded) => {
                assert_eq!(forwarded.content.content, "hello");
            }
            other => panic!("expected Continue fallback, got {:?}", other),
        }

        assert!(
            flow.tags.iter().any(|tag| tag == "script-error"),
            "script error should be tagged for observability"
        );
    }
}