Skip to main content

relay_core_script/
lib.rs

1//! Internal Deno/V8 scripting engine for [relay-core](https://crates.io/crates/relay-core).
2//! Provides the `ScriptInterceptor` that implements runtime script modification of traffic.
3//!
4//! **This is a feature backend for `relay-core`.** Enable with `relay-core = { features = ["script"] }`.
5
6pub mod deno_engine;
7pub mod engine_trait;
8pub mod streams;
9
10use crate::deno_engine::DenoScriptEngine;
11use crate::engine_trait::ScriptEngineTrait;
12use async_trait::async_trait;
13use relay_core_api::flow::{Flow, WebSocketMessage};
14use relay_core_lib::interceptor::{
15    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
16    WebSocketMessageAction,
17};
18use std::collections::hash_map::DefaultHasher;
19use std::hash::{Hash, Hasher};
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::time::Instant;
22use tokio::sync::RwLock;
23
/// S4: Per-hook script execution metrics exposed via Prometheus/text.
///
/// Each hook owns three counters: accumulated run time in microseconds, the
/// number of recorded invocations, and the number of errors. Every field is an
/// atomic, so the counters can be updated through a shared reference.
/// `Default` yields a value with every counter at zero.
#[derive(Debug, Default)]
pub struct ScriptMetrics {
    /// Accumulated time spent in `onRequestHeaders`, in microseconds.
    pub on_request_headers_duration_us: AtomicU64,
    /// Number of recorded `onRequestHeaders` calls.
    pub on_request_headers_invocations: AtomicU64,
    /// Number of `onRequestHeaders` calls that were counted as errors.
    pub on_request_headers_errors: AtomicU64,

    /// Accumulated time spent in `onRequest`, in microseconds.
    pub on_request_duration_us: AtomicU64,
    /// Number of recorded `onRequest` calls.
    pub on_request_invocations: AtomicU64,
    /// Number of `onRequest` calls that were counted as errors.
    pub on_request_errors: AtomicU64,

    /// Accumulated time spent in `onResponseHeaders`, in microseconds.
    pub on_response_headers_duration_us: AtomicU64,
    /// Number of recorded `onResponseHeaders` calls.
    pub on_response_headers_invocations: AtomicU64,
    /// Number of `onResponseHeaders` calls that were counted as errors.
    pub on_response_headers_errors: AtomicU64,

    /// Accumulated time spent in `onResponse`, in microseconds.
    pub on_response_duration_us: AtomicU64,
    /// Number of recorded `onResponse` calls.
    pub on_response_invocations: AtomicU64,
    /// Number of `onResponse` calls that were counted as errors.
    pub on_response_errors: AtomicU64,

    /// Accumulated time spent in `onWebSocketMessage`, in microseconds.
    pub on_websocket_message_duration_us: AtomicU64,
    /// Number of recorded `onWebSocketMessage` calls.
    pub on_websocket_message_invocations: AtomicU64,
    /// Number of `onWebSocketMessage` calls that were counted as errors.
    pub on_websocket_message_errors: AtomicU64,
}

impl ScriptMetrics {
    /// Renders every counter as one `metric{hook="<name>"} <value>` line.
    ///
    /// Hooks appear in the order `onRequestHeaders`, `onRequest`,
    /// `onResponseHeaders`, `onResponse`, `onWebSocketMessage`; within a hook
    /// the order is duration, invocations, errors. Each line ends with `\n`.
    ///
    /// Counters are read one by one with `Ordering::Relaxed`, so the text is
    /// not a single atomic snapshot across all counters.
    pub fn prometheus_lines(&self) -> String {
        let hooks = [
            (
                "onRequestHeaders",
                &self.on_request_headers_duration_us,
                &self.on_request_headers_invocations,
                &self.on_request_headers_errors,
            ),
            (
                "onRequest",
                &self.on_request_duration_us,
                &self.on_request_invocations,
                &self.on_request_errors,
            ),
            (
                "onResponseHeaders",
                &self.on_response_headers_duration_us,
                &self.on_response_headers_invocations,
                &self.on_response_headers_errors,
            ),
            (
                "onResponse",
                &self.on_response_duration_us,
                &self.on_response_invocations,
                &self.on_response_errors,
            ),
            (
                "onWebSocketMessage",
                &self.on_websocket_message_duration_us,
                &self.on_websocket_message_invocations,
                &self.on_websocket_message_errors,
            ),
        ];

        // Three samples per hook at roughly 70 bytes each; reserve once up front.
        let mut out = String::with_capacity(hooks.len() * 3 * 80);
        for (hook, duration_us, invocations, errors) in hooks {
            let samples = [
                ("relay_core_script_hook_duration_us", duration_us),
                ("relay_core_script_hook_invocations_total", invocations),
                ("relay_core_script_hook_errors_total", errors),
            ];
            for (metric, counter) in samples {
                // Format straight into `out` instead of building a temporary
                // `String` per sample. Writing into a `String` cannot fail, so
                // the `fmt::Result` is deliberately ignored.
                let _ = std::fmt::Write::write_fmt(
                    &mut out,
                    format_args!(
                        "{}{{hook=\"{}\"}} {}\n",
                        metric,
                        hook,
                        counter.load(Ordering::Relaxed)
                    ),
                );
            }
        }
        out
    }
}
124
125pub struct ScriptInterceptor {
126    engines: Vec<RwLock<Box<dyn ScriptEngineTrait>>>,
127    pub metrics: ScriptMetrics,
128}
129
130impl ScriptInterceptor {
131    pub async fn new() -> Result<Self, BoxError> {
132        let pool_size = std::thread::available_parallelism()
133            .map(|n| n.get())
134            .unwrap_or(4);
135        let mut engines = Vec::with_capacity(pool_size);
136
137        for _ in 0..pool_size {
138            let engine: Box<dyn ScriptEngineTrait> = Box::new(DenoScriptEngine::new());
139            engines.push(RwLock::new(engine));
140        }
141
142        Ok(Self {
143            engines,
144            metrics: ScriptMetrics::default(),
145        })
146    }
147
148    pub async fn load_script(&self, script: &str) -> Result<(), BoxError> {
149        // Load script into ALL engines to keep them consistent.
150        // Optimization: Load into new engines first to avoid blocking request processing,
151        // then swap them in quickly.
152        let pool_size = self.engines.len();
153        let mut new_engines = Vec::with_capacity(pool_size);
154
155        for _ in 0..pool_size {
156            let mut engine = DenoScriptEngine::new();
157            engine.load_script(script).await?;
158            new_engines.push(Box::new(engine) as Box<dyn ScriptEngineTrait>);
159        }
160
161        // We acquire write locks only for the swap.
162        // The new engines are already prepared, so the critical section is very short.
163        let mut new_engines_iter = new_engines.into_iter();
164        for engine_lock in &self.engines {
165            if let Some(new_engine) = new_engines_iter.next() {
166                let mut guard = engine_lock.write().await;
167                *guard = new_engine;
168            }
169        }
170
171        Ok(())
172    }
173
174    fn get_engine_index(&self) -> usize {
175        // Optimization: Allow task-local override for engine index (e.g. for testing or specific routing)
176        if let Ok(index) = relay_core_lib::interceptor::ENGINE_INDEX.try_with(|i| *i) {
177            return index % self.engines.len();
178        }
179
180        let thread_id = std::thread::current().id();
181        let mut hasher = DefaultHasher::new();
182        thread_id.hash(&mut hasher);
183        (hasher.finish() as usize) % self.engines.len()
184    }
185}
186
187#[async_trait]
188impl Interceptor for ScriptInterceptor {
189    async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
190        let start = Instant::now();
191        let index = self.get_engine_index();
192        let engine_lock = &self.engines[index];
193        let engine = engine_lock.read().await;
194
195        let result = match engine.on_request_headers(flow).await {
196            Ok(Some(modified_flow)) => {
197                *flow = modified_flow;
198                InterceptionResult::ModifiedRequest(match &flow.layer {
199                    relay_core_api::flow::Layer::Http(h) => h.request.clone(),
200                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_request.clone(),
201                    _ => {
202                        self.metrics
203                            .on_request_headers_errors
204                            .fetch_add(1, Ordering::Relaxed);
205                        return InterceptionResult::Continue;
206                    }
207                })
208            }
209            Ok(None) => InterceptionResult::Continue,
210            Err(e) => {
211                tracing::error!("Script execution error (on_request_headers): {}", e);
212                flow.tags.push("script-error".to_string());
213                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
214                    http.error = Some(format!("Script Error: {}", e));
215                }
216                self.metrics
217                    .on_request_headers_errors
218                    .fetch_add(1, Ordering::Relaxed);
219                InterceptionResult::Continue
220            }
221        };
222        let dur_us = start.elapsed().as_micros() as u64;
223        self.metrics
224            .on_request_headers_duration_us
225            .fetch_add(dur_us, Ordering::Relaxed);
226        self.metrics
227            .on_request_headers_invocations
228            .fetch_add(1, Ordering::Relaxed);
229        result
230    }
231
232    async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
233        let start = Instant::now();
234        let index = self.get_engine_index();
235        let engine_lock = &self.engines[index];
236        let engine = engine_lock.read().await;
237
238        match engine.on_request(flow, body).await {
239            Ok(action) => {
240                let dur_us = start.elapsed().as_micros() as u64;
241                self.metrics
242                    .on_request_duration_us
243                    .fetch_add(dur_us, Ordering::Relaxed);
244                self.metrics
245                    .on_request_invocations
246                    .fetch_add(1, Ordering::Relaxed);
247                Ok(action)
248            }
249            Err(e) => {
250                self.metrics
251                    .on_request_errors
252                    .fetch_add(1, Ordering::Relaxed);
253                Err(e)
254            }
255        }
256    }
257
258    async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
259        let start = Instant::now();
260        let index = self.get_engine_index();
261        let engine_lock = &self.engines[index];
262        let engine = engine_lock.read().await;
263
264        let result = match engine.on_response_headers(flow).await {
265            Ok(Some(modified_flow)) => {
266                *flow = modified_flow;
267                InterceptionResult::ModifiedResponse(match &flow.layer {
268                    relay_core_api::flow::Layer::Http(h) => h.response.clone().unwrap(),
269                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_response.clone(),
270                    _ => {
271                        self.metrics
272                            .on_response_headers_errors
273                            .fetch_add(1, Ordering::Relaxed);
274                        return InterceptionResult::Continue;
275                    }
276                })
277            }
278            Ok(None) => InterceptionResult::Continue,
279            Err(e) => {
280                tracing::error!("Script execution error (on_response_headers): {}", e);
281                flow.tags.push("script-error".to_string());
282                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
283                    http.error = Some(format!("Script Error: {}", e));
284                }
285                self.metrics
286                    .on_response_headers_errors
287                    .fetch_add(1, Ordering::Relaxed);
288                InterceptionResult::Continue
289            }
290        };
291        let dur_us = start.elapsed().as_micros() as u64;
292        self.metrics
293            .on_response_headers_duration_us
294            .fetch_add(dur_us, Ordering::Relaxed);
295        self.metrics
296            .on_response_headers_invocations
297            .fetch_add(1, Ordering::Relaxed);
298        result
299    }
300
301    async fn on_response(
302        &self,
303        flow: &mut Flow,
304        body: HttpBody,
305    ) -> Result<ResponseAction, BoxError> {
306        let start = Instant::now();
307        let index = self.get_engine_index();
308        let engine_lock = &self.engines[index];
309        let engine = engine_lock.read().await;
310
311        match engine.on_response(flow, body).await {
312            Ok(action) => {
313                let dur_us = start.elapsed().as_micros() as u64;
314                self.metrics
315                    .on_response_duration_us
316                    .fetch_add(dur_us, Ordering::Relaxed);
317                self.metrics
318                    .on_response_invocations
319                    .fetch_add(1, Ordering::Relaxed);
320                Ok(action)
321            }
322            Err(e) => {
323                self.metrics
324                    .on_response_errors
325                    .fetch_add(1, Ordering::Relaxed);
326                Err(e)
327            }
328        }
329    }
330
331    async fn on_websocket_message(
332        &self,
333        flow: &mut Flow,
334        mut message: WebSocketMessage,
335    ) -> Result<WebSocketMessageAction, BoxError> {
336        let start = Instant::now();
337        let index = self.get_engine_index();
338        let engine_lock = &self.engines[index];
339        let engine = engine_lock.read().await;
340
341        match engine.on_websocket_message(flow, &mut message).await {
342            Ok(action) => {
343                let dur_us = start.elapsed().as_micros() as u64;
344                self.metrics
345                    .on_websocket_message_duration_us
346                    .fetch_add(dur_us, Ordering::Relaxed);
347                self.metrics
348                    .on_websocket_message_invocations
349                    .fetch_add(1, Ordering::Relaxed);
350                Ok(action)
351            }
352            Err(e) => {
353                tracing::error!("Script execution error (on_websocket_message): {}", e);
354                flow.tags.push("script-error".to_string());
355                self.metrics
356                    .on_websocket_message_errors
357                    .fetch_add(1, Ordering::Relaxed);
358                Ok(WebSocketMessageAction::Continue(message))
359            }
360        }
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use chrono::Utc;
    use http_body_util::{BodyExt, Empty};
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, Layer, NetworkInfo, TransportProtocol,
        WebSocketMessage,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds a plain-HTTP `GET http://example.com` flow with no response,
    /// no error, no tags and empty metadata.
    fn create_test_flow() -> Flow {
        let network = NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        };
        let request = HttpRequest {
            method: "GET".to_string(),
            url: Url::parse("http://example.com").unwrap(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            cookies: vec![],
            query: vec![],
            body: None,
        };

        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network,
            layer: Layer::Http(HttpLayer {
                request,
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// A throwing `onRequestHeaders` hook must yield `Continue`, tag the flow
    /// and store the error text on the HTTP layer.
    #[tokio::test]
    async fn test_script_error_propagation() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequestHeaders = (flow) => {
                throw new Error("Test Error 123");
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;

        assert!(
            matches!(outcome, InterceptionResult::Continue),
            "Expected Continue"
        );
        assert!(flow.tags.iter().any(|tag| tag == "script-error"));

        match &flow.layer {
            Layer::Http(http) => {
                assert!(http.error.is_some());
                let recorded = http.error.as_deref().unwrap();
                assert!(recorded.contains("Test Error 123"));
            }
            _ => panic!("Expected Http layer"),
        }
    }

    /// The `onRequest` hook must receive a `RelayBody` as its first argument.
    #[tokio::test]
    async fn test_script_api_relay_body() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequest = (body, flow) => {
                if (!(body instanceof RelayBody)) {
                    throw new Error("First argument is not RelayBody");
                }
                // We return nothing (undefined), which means "continue with original body"
                // But we successfully verified the type.
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        let empty_body = Empty::<Bytes>::new()
            .map_err(|_| -> BoxError { unreachable!() })
            .boxed();

        let outcome = interceptor.on_request(&mut flow, empty_body).await;
        assert!(outcome.is_ok());
    }

    /// A reload that fails must leave the previously loaded script active.
    #[tokio::test]
    async fn test_load_script_failure_does_not_replace_existing_engines() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let good_script = r#"
            globalThis.onRequestHeaders = (_context, flow) => {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-Good-Script", "1"]);
                }
                return flow;
            };
        "#;
        interceptor
            .load_script(good_script)
            .await
            .expect("good script should load");

        let bad_script = "globalThis.onRequestHeaders = () => { invalid javascript !!!";
        let reload = interceptor.load_script(bad_script).await;
        assert!(reload.is_err(), "bad script should be rejected");

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(outcome, InterceptionResult::ModifiedRequest(_)),
            "existing good script should still be active after failed reload"
        );

        match &flow.layer {
            Layer::Http(http) => {
                let header_present = http
                    .request
                    .headers
                    .iter()
                    .any(|(k, v)| k == "X-Good-Script" && v == "1");
                assert!(header_present);
            }
            _ => panic!("Expected Http layer"),
        }
    }

    /// A throwing `onWebSocketMessage` hook must forward the original message
    /// and tag the flow instead of returning an error.
    #[tokio::test]
    async fn test_websocket_script_error_falls_back_continue_and_tags_flow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        let script = r#"
            globalThis.onWebSocketMessage = function(_context, _flow, _message) {
                throw new Error("ws explode");
            };
        "#;
        interceptor
            .load_script(script)
            .await
            .expect("script should load");

        let mut flow = create_test_flow();
        let content = BodyData {
            encoding: "utf-8".to_string(),
            content: "hello".to_string(),
            size: 5,
        };
        let msg = WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content,
            opcode: "Text".to_string(),
        };

        let action = interceptor
            .on_websocket_message(&mut flow, msg)
            .await
            .expect("websocket interception should not return hard error");
        match action {
            WebSocketMessageAction::Continue(forwarded) => {
                assert_eq!(forwarded.content.content, "hello");
            }
            other => panic!("expected Continue fallback, got {:?}", other),
        }

        let tagged = flow.tags.iter().any(|t| t == "script-error");
        assert!(tagged, "script error should be tagged for observability");
    }
}
546}