Skip to main content

relay_core_script/
lib.rs

1//! Internal Deno/V8 scripting engine for [relay-core](https://crates.io/crates/relay-core).
2//! Provides the `ScriptInterceptor` that implements runtime script modification of traffic.
3//!
4//! **This is a feature backend for `relay-core`.** Enable with `relay-core = { features = ["script"] }`.
5
6pub mod deno_engine;
7pub mod engine_trait;
8pub mod streams;
9
10pub use deno_engine::ScriptFetchConfig;
11
12use crate::deno_engine::DenoScriptEngine;
13use crate::engine_trait::ScriptEngineTrait;
14use async_trait::async_trait;
15use relay_core_api::flow::{Flow, WebSocketMessage};
16use relay_core_lib::interceptor::{
17    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
18    WebSocketMessageAction,
19};
20use std::collections::HashSet;
21use std::collections::hash_map::DefaultHasher;
22use std::hash::{Hash, Hasher};
23use std::sync::atomic::{AtomicU64, Ordering};
24use std::time::Instant;
25use tokio::sync::RwLock;
26
/// S4: Per-hook script execution metrics, rendered as Prometheus text by
/// [`ScriptMetrics::prometheus_lines`].
///
/// Every field is a monotonically increasing counter updated with relaxed atomics by the
/// interceptor hooks. Durations are wall-clock microseconds measured from the start of the
/// hook call (so they include waiting for the engine lock), summed across calls.
#[derive(Debug, Default)]
pub struct ScriptMetrics {
    /// Cumulative time spent in `onRequestHeaders`, in microseconds.
    pub on_request_headers_duration_us: AtomicU64,
    /// Number of recorded `onRequestHeaders` calls.
    pub on_request_headers_invocations: AtomicU64,
    /// Number of failed `onRequestHeaders` calls.
    pub on_request_headers_errors: AtomicU64,

    /// Cumulative time spent in `onRequest`, in microseconds.
    pub on_request_duration_us: AtomicU64,
    /// Number of recorded `onRequest` calls.
    pub on_request_invocations: AtomicU64,
    /// Number of failed `onRequest` calls.
    pub on_request_errors: AtomicU64,

    /// Cumulative time spent in `onResponseHeaders`, in microseconds.
    pub on_response_headers_duration_us: AtomicU64,
    /// Number of recorded `onResponseHeaders` calls.
    pub on_response_headers_invocations: AtomicU64,
    /// Number of failed `onResponseHeaders` calls.
    pub on_response_headers_errors: AtomicU64,

    /// Cumulative time spent in `onResponse`, in microseconds.
    pub on_response_duration_us: AtomicU64,
    /// Number of recorded `onResponse` calls.
    pub on_response_invocations: AtomicU64,
    /// Number of failed `onResponse` calls.
    pub on_response_errors: AtomicU64,

    /// Cumulative time spent in `onWebSocketMessage`, in microseconds.
    pub on_websocket_message_duration_us: AtomicU64,
    /// Number of recorded `onWebSocketMessage` calls.
    pub on_websocket_message_invocations: AtomicU64,
    /// Number of failed `onWebSocketMessage` calls.
    pub on_websocket_message_errors: AtomicU64,
}
71
72impl ScriptMetrics {
73    pub fn prometheus_lines(&self) -> String {
74        let mut out = String::new();
75        macro_rules! push_metric {
76            ($name:expr, $dur:expr, $inv:expr, $err:expr) => {
77                out.push_str(&format!(
78                    "relay_core_script_hook_duration_us{{hook=\"{}\"}} {}\n",
79                    $name,
80                    $dur.load(Ordering::Relaxed)
81                ));
82                out.push_str(&format!(
83                    "relay_core_script_hook_invocations_total{{hook=\"{}\"}} {}\n",
84                    $name,
85                    $inv.load(Ordering::Relaxed)
86                ));
87                out.push_str(&format!(
88                    "relay_core_script_hook_errors_total{{hook=\"{}\"}} {}\n",
89                    $name,
90                    $err.load(Ordering::Relaxed)
91                ));
92            };
93        }
94        push_metric!(
95            "onRequestHeaders",
96            &self.on_request_headers_duration_us,
97            &self.on_request_headers_invocations,
98            &self.on_request_headers_errors
99        );
100        push_metric!(
101            "onRequest",
102            &self.on_request_duration_us,
103            &self.on_request_invocations,
104            &self.on_request_errors
105        );
106        push_metric!(
107            "onResponseHeaders",
108            &self.on_response_headers_duration_us,
109            &self.on_response_headers_invocations,
110            &self.on_response_headers_errors
111        );
112        push_metric!(
113            "onResponse",
114            &self.on_response_duration_us,
115            &self.on_response_invocations,
116            &self.on_response_errors
117        );
118        push_metric!(
119            "onWebSocketMessage",
120            &self.on_websocket_message_duration_us,
121            &self.on_websocket_message_invocations,
122            &self.on_websocket_message_errors
123        );
124        out.push_str(&format!(
125            "relay_core_script_env_access_total {}\n",
126            deno_engine::get_script_env_access_total()
127        ));
128        out.push_str(&format!(
129            "relay_core_script_fetch_total {}\n",
130            deno_engine::get_script_fetch_total()
131        ));
132        out.push_str(&format!(
133            "relay_core_script_fetch_rejected_total {}\n",
134            deno_engine::get_script_fetch_rejected_total()
135        ));
136        out
137    }
138}
139
140pub struct ScriptInterceptor {
141    engines: Vec<RwLock<Box<dyn ScriptEngineTrait>>>,
142    pub metrics: ScriptMetrics,
143    env_allow: RwLock<HashSet<String>>,
144    fetch_config: RwLock<deno_engine::ScriptFetchConfig>,
145}
146
147impl ScriptInterceptor {
148    pub async fn new() -> Result<Self, BoxError> {
149        Self::new_with_env(HashSet::new()).await
150    }
151
152    pub async fn new_with_env(env_allow: HashSet<String>) -> Result<Self, BoxError> {
153        Self::new_with_env_and_fetch(env_allow, deno_engine::ScriptFetchConfig::default()).await
154    }
155
156    pub async fn new_with_env_and_fetch(
157        env_allow: HashSet<String>,
158        fetch_config: deno_engine::ScriptFetchConfig,
159    ) -> Result<Self, BoxError> {
160        let pool_size = std::thread::available_parallelism()
161            .map(|n| n.get())
162            .unwrap_or(4);
163        let mut engines = Vec::with_capacity(pool_size);
164
165        for _ in 0..pool_size {
166            let engine: Box<dyn ScriptEngineTrait> = Box::new(DenoScriptEngine::new_with_fetch(
167                env_allow.clone(),
168                fetch_config.clone(),
169            ));
170            engines.push(RwLock::new(engine));
171        }
172
173        Ok(Self {
174            engines,
175            metrics: ScriptMetrics::default(),
176            env_allow: RwLock::new(env_allow),
177            fetch_config: RwLock::new(fetch_config),
178        })
179    }
180
181    pub async fn set_env_allow(&self, env_allow: HashSet<String>) {
182        let mut guard = self.env_allow.write().await;
183        *guard = env_allow;
184    }
185
186    pub async fn set_fetch_config(&self, config: deno_engine::ScriptFetchConfig) {
187        let mut guard = self.fetch_config.write().await;
188        *guard = config;
189    }
190
191    pub async fn load_script(&self, script: &str) -> Result<(), BoxError> {
192        // Load script into ALL engines to keep them consistent.
193        let pool_size = self.engines.len();
194        let mut new_engines = Vec::with_capacity(pool_size);
195
196        let env = self.env_allow.read().await.clone();
197        let fc = self.fetch_config.read().await.clone();
198        for _ in 0..pool_size {
199            let mut engine = DenoScriptEngine::new_with_fetch(env.clone(), fc.clone());
200            engine.load_script(script).await?;
201            new_engines.push(Box::new(engine) as Box<dyn ScriptEngineTrait>);
202        }
203
204        // We acquire write locks only for the swap.
205        // The new engines are already prepared, so the critical section is very short.
206        let mut new_engines_iter = new_engines.into_iter();
207        for engine_lock in &self.engines {
208            if let Some(new_engine) = new_engines_iter.next() {
209                let mut guard = engine_lock.write().await;
210                *guard = new_engine;
211            }
212        }
213
214        Ok(())
215    }
216
217    fn get_engine_index(&self) -> usize {
218        // Optimization: Allow task-local override for engine index (e.g. for testing or specific routing)
219        if let Ok(index) = relay_core_lib::interceptor::ENGINE_INDEX.try_with(|i| *i) {
220            return index % self.engines.len();
221        }
222
223        let thread_id = std::thread::current().id();
224        let mut hasher = DefaultHasher::new();
225        thread_id.hash(&mut hasher);
226        (hasher.finish() as usize) % self.engines.len()
227    }
228}
229
230#[async_trait]
231impl Interceptor for ScriptInterceptor {
232    async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
233        let start = Instant::now();
234        let index = self.get_engine_index();
235        let engine_lock = &self.engines[index];
236        let engine = engine_lock.read().await;
237
238        let result = match engine.on_request_headers(flow).await {
239            Ok(Some(modified_flow)) => {
240                *flow = modified_flow;
241                InterceptionResult::ModifiedRequest(match &flow.layer {
242                    relay_core_api::flow::Layer::Http(h) => h.request.clone(),
243                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_request.clone(),
244                    _ => {
245                        self.metrics
246                            .on_request_headers_errors
247                            .fetch_add(1, Ordering::Relaxed);
248                        return InterceptionResult::Continue;
249                    }
250                })
251            }
252            Ok(None) => InterceptionResult::Continue,
253            Err(e) => {
254                tracing::error!("Script execution error (on_request_headers): {}", e);
255                flow.tags.push("script-error".to_string());
256                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
257                    http.error = Some(format!("Script Error: {}", e));
258                }
259                self.metrics
260                    .on_request_headers_errors
261                    .fetch_add(1, Ordering::Relaxed);
262                InterceptionResult::Continue
263            }
264        };
265        let dur_us = start.elapsed().as_micros() as u64;
266        self.metrics
267            .on_request_headers_duration_us
268            .fetch_add(dur_us, Ordering::Relaxed);
269        self.metrics
270            .on_request_headers_invocations
271            .fetch_add(1, Ordering::Relaxed);
272        result
273    }
274
275    async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
276        let start = Instant::now();
277        let index = self.get_engine_index();
278        let engine_lock = &self.engines[index];
279        let engine = engine_lock.read().await;
280
281        match engine.on_request(flow, body).await {
282            Ok(action) => {
283                let dur_us = start.elapsed().as_micros() as u64;
284                self.metrics
285                    .on_request_duration_us
286                    .fetch_add(dur_us, Ordering::Relaxed);
287                self.metrics
288                    .on_request_invocations
289                    .fetch_add(1, Ordering::Relaxed);
290                Ok(action)
291            }
292            Err(e) => {
293                self.metrics
294                    .on_request_errors
295                    .fetch_add(1, Ordering::Relaxed);
296                Err(e)
297            }
298        }
299    }
300
301    async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
302        let start = Instant::now();
303        let index = self.get_engine_index();
304        let engine_lock = &self.engines[index];
305        let engine = engine_lock.read().await;
306
307        let result = match engine.on_response_headers(flow).await {
308            Ok(Some(modified_flow)) => {
309                *flow = modified_flow;
310                InterceptionResult::ModifiedResponse(match &flow.layer {
311                    relay_core_api::flow::Layer::Http(h) => h.response.clone().unwrap(),
312                    relay_core_api::flow::Layer::WebSocket(w) => w.handshake_response.clone(),
313                    _ => {
314                        self.metrics
315                            .on_response_headers_errors
316                            .fetch_add(1, Ordering::Relaxed);
317                        return InterceptionResult::Continue;
318                    }
319                })
320            }
321            Ok(None) => InterceptionResult::Continue,
322            Err(e) => {
323                tracing::error!("Script execution error (on_response_headers): {}", e);
324                flow.tags.push("script-error".to_string());
325                if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
326                    http.error = Some(format!("Script Error: {}", e));
327                }
328                self.metrics
329                    .on_response_headers_errors
330                    .fetch_add(1, Ordering::Relaxed);
331                InterceptionResult::Continue
332            }
333        };
334        let dur_us = start.elapsed().as_micros() as u64;
335        self.metrics
336            .on_response_headers_duration_us
337            .fetch_add(dur_us, Ordering::Relaxed);
338        self.metrics
339            .on_response_headers_invocations
340            .fetch_add(1, Ordering::Relaxed);
341        result
342    }
343
344    async fn on_response(
345        &self,
346        flow: &mut Flow,
347        body: HttpBody,
348    ) -> Result<ResponseAction, BoxError> {
349        let start = Instant::now();
350        let index = self.get_engine_index();
351        let engine_lock = &self.engines[index];
352        let engine = engine_lock.read().await;
353
354        match engine.on_response(flow, body).await {
355            Ok(action) => {
356                let dur_us = start.elapsed().as_micros() as u64;
357                self.metrics
358                    .on_response_duration_us
359                    .fetch_add(dur_us, Ordering::Relaxed);
360                self.metrics
361                    .on_response_invocations
362                    .fetch_add(1, Ordering::Relaxed);
363                Ok(action)
364            }
365            Err(e) => {
366                self.metrics
367                    .on_response_errors
368                    .fetch_add(1, Ordering::Relaxed);
369                Err(e)
370            }
371        }
372    }
373
374    async fn on_websocket_message(
375        &self,
376        flow: &mut Flow,
377        mut message: WebSocketMessage,
378    ) -> Result<WebSocketMessageAction, BoxError> {
379        let start = Instant::now();
380        let index = self.get_engine_index();
381        let engine_lock = &self.engines[index];
382        let engine = engine_lock.read().await;
383
384        match engine.on_websocket_message(flow, &mut message).await {
385            Ok(action) => {
386                let dur_us = start.elapsed().as_micros() as u64;
387                self.metrics
388                    .on_websocket_message_duration_us
389                    .fetch_add(dur_us, Ordering::Relaxed);
390                self.metrics
391                    .on_websocket_message_invocations
392                    .fetch_add(1, Ordering::Relaxed);
393                Ok(action)
394            }
395            Err(e) => {
396                tracing::error!("Script execution error (on_websocket_message): {}", e);
397                flow.tags.push("script-error".to_string());
398                self.metrics
399                    .on_websocket_message_errors
400                    .fetch_add(1, Ordering::Relaxed);
401                Ok(WebSocketMessageAction::Continue(message))
402            }
403        }
404    }
405
406    async fn on_connect(
407        &self,
408        conn: &relay_core_lib::interceptor::ConnectionInfo,
409    ) -> relay_core_lib::interceptor::ConnectAction {
410        let index = self.get_engine_index();
411        let engine_lock = &self.engines[index];
412        let engine = engine_lock.read().await;
413
414        match engine.on_connect(conn).await {
415            Ok(action) => action,
416            Err(e) => {
417                tracing::warn!("onConnect script error: {}", e);
418                relay_core_lib::interceptor::ConnectAction::Allow
419            }
420        }
421    }
422
423    async fn on_disconnect(
424        &self,
425        conn: &relay_core_lib::interceptor::ConnectionInfo,
426        stats: &relay_core_lib::interceptor::ConnectionStats,
427    ) {
428        let index = self.get_engine_index();
429        let engine_lock = &self.engines[index];
430        let engine = engine_lock.read().await;
431
432        if let Err(e) = engine.on_disconnect(conn, stats).await {
433            tracing::warn!("onDisconnect script error: {}", e);
434        }
435    }
436
437    async fn on_websocket_start(&self, flow: &mut Flow) {
438        let index = self.get_engine_index();
439        let engine_lock = &self.engines[index];
440        let engine = engine_lock.read().await;
441
442        if let Err(e) = engine.on_websocket_start(flow).await {
443            tracing::warn!("onWebSocketStart script error: {}", e);
444        }
445    }
446
447    async fn on_websocket_end(&self, flow: &mut Flow, close_code: u16, close_reason: &str) {
448        let index = self.get_engine_index();
449        let engine_lock = &self.engines[index];
450        let engine = engine_lock.read().await;
451
452        if let Err(e) = engine
453            .on_websocket_end(flow, close_code, close_reason)
454            .await
455        {
456            tracing::warn!("onWebSocketEnd script error: {}", e);
457        }
458    }
459
460    async fn on_websocket_error(&self, flow: &mut Flow, error: &str) {
461        let index = self.get_engine_index();
462        let engine_lock = &self.engines[index];
463        let engine = engine_lock.read().await;
464
465        if let Err(e) = engine.on_websocket_error(flow, error).await {
466            tracing::warn!("onWebSocketError script error: {}", e);
467        }
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use chrono::Utc;
    use http_body_util::{BodyExt, Empty};
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, Layer, NetworkInfo, TransportProtocol,
        WebSocketMessage,
    };
    use relay_core_lib::interceptor::{ConnectAction, ConnectionInfo, ConnectionStats};
    use std::collections::HashMap;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use url::Url;
    use uuid::Uuid;

    /// Builds a plain-TCP `GET http://example.com` HTTP flow that has no response yet.
    fn create_test_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com").unwrap(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
            resilience_trace: None,
            rule_variables: HashMap::new(),
            matched_rules: vec![],
        }
    }

    /// Builds a client-to-server UTF-8 text frame carrying `"hello"`.
    fn hello_text_message() -> WebSocketMessage {
        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: "hello".to_string(),
                size: 5,
            },
            opcode: "Text".to_string(),
        }
    }

    /// Connection from 127.0.0.1:54321 to 93.184.216.34:443 with SNI `example.com`.
    fn sample_connection_info() -> ConnectionInfo {
        ConnectionInfo {
            id: Uuid::new_v4(),
            client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 54321),
            server_addr: Some(SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)),
                443,
            )),
            tls_sni: Some("example.com".to_string()),
        }
    }

    /// Arbitrary non-zero connection statistics.
    fn sample_connection_stats() -> ConnectionStats {
        ConnectionStats {
            duration_ms: 1234,
            bytes_sent: 5000,
            bytes_received: 12000,
            flows_count: 3,
        }
    }

    #[tokio::test]
    async fn test_script_error_propagation() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequestHeaders = (flow) => {
                throw new Error("Test Error 123");
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();

        let result = interceptor.on_request_headers(&mut flow).await;

        match result {
            InterceptionResult::Continue => {}
            _ => panic!("Expected Continue"),
        }

        assert!(flow.tags.contains(&"script-error".to_string()));

        if let Layer::Http(http) = &flow.layer {
            assert!(http.error.is_some());
            let err = http.error.as_ref().unwrap();
            assert!(err.contains("Test Error 123"));
        } else {
            panic!("Expected Http layer");
        }

        // A failed call is counted both as an invocation and as an error.
        let metrics = &interceptor.metrics;
        assert_eq!(
            metrics.on_request_headers_errors.load(Ordering::Relaxed),
            1
        );
        assert_eq!(
            metrics.on_request_headers_invocations.load(Ordering::Relaxed),
            1
        );
    }

    #[tokio::test]
    async fn test_request_headers_call_is_counted_in_metrics() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequestHeaders = (_context, flow) => flow;
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        interceptor.on_request_headers(&mut flow).await;

        let metrics = &interceptor.metrics;
        assert_eq!(
            metrics.on_request_headers_invocations.load(Ordering::Relaxed),
            1
        );
        assert_eq!(
            metrics.on_request_headers_errors.load(Ordering::Relaxed),
            0
        );
    }

    #[test]
    fn test_prometheus_lines_reports_every_hook() {
        let metrics = ScriptMetrics::default();
        metrics.on_response_errors.fetch_add(2, Ordering::Relaxed);

        let text = metrics.prometheus_lines();
        for hook in [
            "onRequestHeaders",
            "onRequest",
            "onResponseHeaders",
            "onResponse",
            "onWebSocketMessage",
        ] {
            let expected = format!(
                "relay_core_script_hook_invocations_total{{hook=\"{}\"}} 0\n",
                hook
            );
            assert!(text.contains(&expected), "missing line: {}", expected);
        }
        assert!(text.contains("relay_core_script_hook_errors_total{hook=\"onResponse\"} 2\n"));
        assert!(text.contains("relay_core_script_fetch_total "));
    }

    #[tokio::test]
    async fn test_script_api_relay_body() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onRequest = (body, flow) => {
                if (!(body instanceof RelayBody)) {
                    throw new Error("First argument is not RelayBody");
                }
                // We return nothing (undefined), which means "continue with original body"
                // But we successfully verified the type.
            };
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        let body = Empty::<Bytes>::new()
            .map_err(|_| -> BoxError { unreachable!() })
            .boxed();

        let result = interceptor.on_request(&mut flow, body).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_load_script_failure_does_not_replace_existing_engines() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let good_script = r#"
            globalThis.onRequestHeaders = (_context, flow) => {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-Good-Script", "1"]);
                }
                return flow;
            };
        "#;
        interceptor
            .load_script(good_script)
            .await
            .expect("good script should load");

        let bad_script = "globalThis.onRequestHeaders = () => { invalid javascript !!!";
        let bad = interceptor.load_script(bad_script).await;
        assert!(bad.is_err(), "bad script should be rejected");

        let mut flow = create_test_flow();
        let result = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(result, InterceptionResult::ModifiedRequest(_)),
            "existing good script should still be active after failed reload"
        );
        if let Layer::Http(http) = &flow.layer {
            assert!(
                http.request
                    .headers
                    .iter()
                    .any(|(k, v)| k == "X-Good-Script" && v == "1")
            );
        } else {
            panic!("Expected Http layer");
        }
    }

    #[tokio::test]
    async fn test_websocket_script_error_falls_back_continue_and_tags_flow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        let script = r#"
            globalThis.onWebSocketMessage = function(_context, _flow, _message) {
                throw new Error("ws explode");
            };
        "#;
        interceptor
            .load_script(script)
            .await
            .expect("script should load");

        let mut flow = create_test_flow();

        let result = interceptor
            .on_websocket_message(&mut flow, hello_text_message())
            .await
            .expect("websocket interception should not return hard error");
        match result {
            WebSocketMessageAction::Continue(forwarded) => {
                assert_eq!(forwarded.content.content, "hello");
            }
            other => panic!("expected Continue fallback, got {:?}", other),
        }
        assert!(
            flow.tags.iter().any(|t| t == "script-error"),
            "script error should be tagged for observability"
        );
    }

    #[tokio::test]
    async fn test_on_connect_drop_rejects_connection() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onConnect = function(context, conn) {
                if (conn.tls_sni === "example.com") {
                    return { drop: true, reason: "sni blocked" };
                }
                return {};
            }
        "#;
        interceptor.load_script(script).await.unwrap();

        let conn = sample_connection_info();
        let action = interceptor.on_connect(&conn).await;
        assert!(
            matches!(action, ConnectAction::Drop { ref reason } if reason == "sni blocked"),
            "onConnect should drop connection for blocked SNI, got {:?}",
            action
        );
    }

    #[tokio::test]
    async fn test_on_connect_allow_no_drop() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onConnect = function(context, conn) {
                return {};
            }
        "#;
        interceptor.load_script(script).await.unwrap();

        let conn = sample_connection_info();
        let action = interceptor.on_connect(&conn).await;
        assert!(
            matches!(action, ConnectAction::Allow),
            "onConnect should allow connection by default"
        );
    }

    #[tokio::test]
    async fn test_on_connect_no_handler_defaults_allow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        // No onConnect defined in script
        let script = r#"
            globalThis.onRequestHeaders = function(context, flow) { return flow; }
        "#;
        interceptor.load_script(script).await.unwrap();

        let conn = sample_connection_info();
        let action = interceptor.on_connect(&conn).await;
        assert!(
            matches!(action, ConnectAction::Allow),
            "missing onConnect should default to Allow"
        );
    }

    #[tokio::test]
    async fn test_on_disconnect_fires() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onDisconnect = function(context, conn, stats) {
                // Verify stats fields are accessible in script
                if (typeof stats.duration_ms !== "number") throw new Error("missing duration_ms");
                if (typeof conn.client_addr !== "string") throw new Error("missing client_addr");
            }
        "#;
        interceptor.load_script(script).await.unwrap();

        let conn = sample_connection_info();
        let stats = sample_connection_stats();
        // on_disconnect is fire-and-forget; verify it does not panic
        interceptor.on_disconnect(&conn, &stats).await;
    }

    #[tokio::test]
    async fn test_on_websocket_start_fires() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onWebSocketStart = function(context, flow) {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-WS-Start", "1"]);
                }
                return flow;
            }
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        interceptor.on_websocket_start(&mut flow).await;

        if let Layer::Http(http) = &flow.layer {
            assert!(
                http.request
                    .headers
                    .iter()
                    .any(|(k, v)| k == "X-WS-Start" && v == "1"),
                "onWebSocketStart should inject header"
            );
        }
    }

    #[tokio::test]
    async fn test_on_websocket_end_fires() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onWebSocketEnd = function(context, flow, closeCode, closeReason) {
                flow.tags.push("ws-ended:" + closeCode);
                return flow;
            }
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        interceptor
            .on_websocket_end(&mut flow, 1000, "Normal Closure")
            .await;

        assert!(
            flow.tags.iter().any(|t| t == "ws-ended:1000"),
            "onWebSocketEnd should tag flow with close code"
        );
    }

    #[tokio::test]
    async fn test_on_websocket_error_fires() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onWebSocketError = function(context, flow, error) {
                flow.tags.push("ws-error");
                return flow;
            }
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();
        interceptor
            .on_websocket_error(&mut flow, "connection reset")
            .await;

        assert!(
            flow.tags.iter().any(|t| t == "ws-error"),
            "onWebSocketError should tag flow"
        );
    }

    #[tokio::test]
    async fn test_on_connect_error_falls_back_allow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onConnect = function(context, conn) {
                throw new Error("connect handler crash");
            }
        "#;
        interceptor.load_script(script).await.unwrap();

        let conn = sample_connection_info();
        let action = interceptor.on_connect(&conn).await;
        assert!(
            matches!(action, ConnectAction::Allow),
            "onConnect handler crash should fall back to Allow"
        );
    }

    #[tokio::test]
    async fn test_ws_lifecycle_normal_sequence() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onWebSocketStart = function(context, flow) {
                flow.tags.push("ws-life:start");
                return flow;
            }
            globalThis.onWebSocketMessage = function(context, flow, message) {
                message.content.content += " [mod]";
                return message;
            }
            globalThis.onWebSocketEnd = function(context, flow, closeCode, closeReason) {
                flow.tags.push("ws-life:end:" + closeCode);
                return flow;
            }
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();

        // Normal sequence: start → message → end
        interceptor.on_websocket_start(&mut flow).await;
        assert!(flow.tags.iter().any(|t| t == "ws-life:start"));

        let result = interceptor
            .on_websocket_message(&mut flow, hello_text_message())
            .await
            .expect("ws message interception ok");
        match result {
            WebSocketMessageAction::Continue(forwarded) => {
                assert!(forwarded.content.content.contains("[mod]"));
            }
            other => panic!("expected Continue, got {:?}", other),
        }

        interceptor.on_websocket_end(&mut flow, 1000, "done").await;
        assert!(flow.tags.iter().any(|t| t == "ws-life:end:1000"));
    }

    #[tokio::test]
    async fn test_ws_lifecycle_error_sequence() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let script = r#"
            globalThis.onWebSocketError = function(context, flow, error) {
                flow.tags.push("ws-life:error");
                return flow;
            }
        "#;
        interceptor.load_script(script).await.unwrap();

        let mut flow = create_test_flow();

        // Error sequence: onWebSocketError fires
        interceptor
            .on_websocket_error(&mut flow, "peer reset")
            .await;
        assert!(flow.tags.iter().any(|t| t == "ws-life:error"));
    }
}