1pub mod deno_engine;
7pub mod engine_trait;
8pub mod streams;
9
10pub use deno_engine::ScriptFetchConfig;
11
12use crate::deno_engine::DenoScriptEngine;
13use crate::engine_trait::ScriptEngineTrait;
14use async_trait::async_trait;
15use relay_core_api::flow::{Flow, WebSocketMessage};
16use relay_core_lib::interceptor::{
17 BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
18 WebSocketMessageAction,
19};
20use std::collections::HashSet;
21use std::collections::hash_map::DefaultHasher;
22use std::hash::{Hash, Hasher};
23use std::sync::atomic::{AtomicU64, Ordering};
24use std::time::Instant;
25use tokio::sync::RwLock;
26
/// Cumulative counters describing script hook execution.
///
/// Every hook has three counters: accumulated time in microseconds, number of
/// recorded invocations, and number of errors. All fields are atomics, so they
/// can be updated through a shared reference.
///
/// `Default` is derived: `AtomicU64::default()` is zero, so a fresh value
/// starts with every counter at `0`.
#[derive(Debug, Default)]
pub struct ScriptMetrics {
    /// Accumulated `onRequestHeaders` time, in microseconds.
    pub on_request_headers_duration_us: AtomicU64,
    /// Number of recorded `onRequestHeaders` invocations.
    pub on_request_headers_invocations: AtomicU64,
    /// Number of `onRequestHeaders` errors.
    pub on_request_headers_errors: AtomicU64,

    /// Accumulated `onRequest` time, in microseconds.
    pub on_request_duration_us: AtomicU64,
    /// Number of recorded `onRequest` invocations.
    pub on_request_invocations: AtomicU64,
    /// Number of `onRequest` errors.
    pub on_request_errors: AtomicU64,

    /// Accumulated `onResponseHeaders` time, in microseconds.
    pub on_response_headers_duration_us: AtomicU64,
    /// Number of recorded `onResponseHeaders` invocations.
    pub on_response_headers_invocations: AtomicU64,
    /// Number of `onResponseHeaders` errors.
    pub on_response_headers_errors: AtomicU64,

    /// Accumulated `onResponse` time, in microseconds.
    pub on_response_duration_us: AtomicU64,
    /// Number of recorded `onResponse` invocations.
    pub on_response_invocations: AtomicU64,
    /// Number of `onResponse` errors.
    pub on_response_errors: AtomicU64,

    /// Accumulated `onWebSocketMessage` time, in microseconds.
    pub on_websocket_message_duration_us: AtomicU64,
    /// Number of recorded `onWebSocketMessage` invocations.
    pub on_websocket_message_invocations: AtomicU64,
    /// Number of `onWebSocketMessage` errors.
    pub on_websocket_message_errors: AtomicU64,
}
71
72impl ScriptMetrics {
73 pub fn prometheus_lines(&self) -> String {
74 let mut out = String::new();
75 macro_rules! push_metric {
76 ($name:expr, $dur:expr, $inv:expr, $err:expr) => {
77 out.push_str(&format!(
78 "relay_core_script_hook_duration_us{{hook=\"{}\"}} {}\n",
79 $name,
80 $dur.load(Ordering::Relaxed)
81 ));
82 out.push_str(&format!(
83 "relay_core_script_hook_invocations_total{{hook=\"{}\"}} {}\n",
84 $name,
85 $inv.load(Ordering::Relaxed)
86 ));
87 out.push_str(&format!(
88 "relay_core_script_hook_errors_total{{hook=\"{}\"}} {}\n",
89 $name,
90 $err.load(Ordering::Relaxed)
91 ));
92 };
93 }
94 push_metric!(
95 "onRequestHeaders",
96 &self.on_request_headers_duration_us,
97 &self.on_request_headers_invocations,
98 &self.on_request_headers_errors
99 );
100 push_metric!(
101 "onRequest",
102 &self.on_request_duration_us,
103 &self.on_request_invocations,
104 &self.on_request_errors
105 );
106 push_metric!(
107 "onResponseHeaders",
108 &self.on_response_headers_duration_us,
109 &self.on_response_headers_invocations,
110 &self.on_response_headers_errors
111 );
112 push_metric!(
113 "onResponse",
114 &self.on_response_duration_us,
115 &self.on_response_invocations,
116 &self.on_response_errors
117 );
118 push_metric!(
119 "onWebSocketMessage",
120 &self.on_websocket_message_duration_us,
121 &self.on_websocket_message_invocations,
122 &self.on_websocket_message_errors
123 );
124 out.push_str(&format!(
125 "relay_core_script_env_access_total {}\n",
126 deno_engine::get_script_env_access_total()
127 ));
128 out.push_str(&format!(
129 "relay_core_script_fetch_total {}\n",
130 deno_engine::get_script_fetch_total()
131 ));
132 out.push_str(&format!(
133 "relay_core_script_fetch_rejected_total {}\n",
134 deno_engine::get_script_fetch_rejected_total()
135 ));
136 out
137 }
138}
139
140pub struct ScriptInterceptor {
141 engines: Vec<RwLock<Box<dyn ScriptEngineTrait>>>,
142 pub metrics: ScriptMetrics,
143 env_allow: RwLock<HashSet<String>>,
144 fetch_config: RwLock<deno_engine::ScriptFetchConfig>,
145}
146
147impl ScriptInterceptor {
148 pub async fn new() -> Result<Self, BoxError> {
149 Self::new_with_env(HashSet::new()).await
150 }
151
152 pub async fn new_with_env(env_allow: HashSet<String>) -> Result<Self, BoxError> {
153 Self::new_with_env_and_fetch(env_allow, deno_engine::ScriptFetchConfig::default()).await
154 }
155
156 pub async fn new_with_env_and_fetch(
157 env_allow: HashSet<String>,
158 fetch_config: deno_engine::ScriptFetchConfig,
159 ) -> Result<Self, BoxError> {
160 let pool_size = std::thread::available_parallelism()
161 .map(|n| n.get())
162 .unwrap_or(4);
163 let mut engines = Vec::with_capacity(pool_size);
164
165 for _ in 0..pool_size {
166 let engine: Box<dyn ScriptEngineTrait> = Box::new(DenoScriptEngine::new_with_fetch(
167 env_allow.clone(),
168 fetch_config.clone(),
169 ));
170 engines.push(RwLock::new(engine));
171 }
172
173 Ok(Self {
174 engines,
175 metrics: ScriptMetrics::default(),
176 env_allow: RwLock::new(env_allow),
177 fetch_config: RwLock::new(fetch_config),
178 })
179 }
180
181 pub async fn set_env_allow(&self, env_allow: HashSet<String>) {
182 let mut guard = self.env_allow.write().await;
183 *guard = env_allow;
184 }
185
186 pub async fn set_fetch_config(&self, config: deno_engine::ScriptFetchConfig) {
187 let mut guard = self.fetch_config.write().await;
188 *guard = config;
189 }
190
191 pub async fn load_script(&self, script: &str) -> Result<(), BoxError> {
192 let pool_size = self.engines.len();
194 let mut new_engines = Vec::with_capacity(pool_size);
195
196 let env = self.env_allow.read().await.clone();
197 let fc = self.fetch_config.read().await.clone();
198 for _ in 0..pool_size {
199 let mut engine = DenoScriptEngine::new_with_fetch(env.clone(), fc.clone());
200 engine.load_script(script).await?;
201 new_engines.push(Box::new(engine) as Box<dyn ScriptEngineTrait>);
202 }
203
204 let mut new_engines_iter = new_engines.into_iter();
207 for engine_lock in &self.engines {
208 if let Some(new_engine) = new_engines_iter.next() {
209 let mut guard = engine_lock.write().await;
210 *guard = new_engine;
211 }
212 }
213
214 Ok(())
215 }
216
217 fn get_engine_index(&self) -> usize {
218 if let Ok(index) = relay_core_lib::interceptor::ENGINE_INDEX.try_with(|i| *i) {
220 return index % self.engines.len();
221 }
222
223 let thread_id = std::thread::current().id();
224 let mut hasher = DefaultHasher::new();
225 thread_id.hash(&mut hasher);
226 (hasher.finish() as usize) % self.engines.len()
227 }
228}
229
230#[async_trait]
231impl Interceptor for ScriptInterceptor {
232 async fn on_request_headers(&self, flow: &mut Flow) -> InterceptionResult {
233 let start = Instant::now();
234 let index = self.get_engine_index();
235 let engine_lock = &self.engines[index];
236 let engine = engine_lock.read().await;
237
238 let result = match engine.on_request_headers(flow).await {
239 Ok(Some(modified_flow)) => {
240 *flow = modified_flow;
241 InterceptionResult::ModifiedRequest(match &flow.layer {
242 relay_core_api::flow::Layer::Http(h) => h.request.clone(),
243 relay_core_api::flow::Layer::WebSocket(w) => w.handshake_request.clone(),
244 _ => {
245 self.metrics
246 .on_request_headers_errors
247 .fetch_add(1, Ordering::Relaxed);
248 return InterceptionResult::Continue;
249 }
250 })
251 }
252 Ok(None) => InterceptionResult::Continue,
253 Err(e) => {
254 tracing::error!("Script execution error (on_request_headers): {}", e);
255 flow.tags.push("script-error".to_string());
256 if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
257 http.error = Some(format!("Script Error: {}", e));
258 }
259 self.metrics
260 .on_request_headers_errors
261 .fetch_add(1, Ordering::Relaxed);
262 InterceptionResult::Continue
263 }
264 };
265 let dur_us = start.elapsed().as_micros() as u64;
266 self.metrics
267 .on_request_headers_duration_us
268 .fetch_add(dur_us, Ordering::Relaxed);
269 self.metrics
270 .on_request_headers_invocations
271 .fetch_add(1, Ordering::Relaxed);
272 result
273 }
274
275 async fn on_request(&self, flow: &mut Flow, body: HttpBody) -> Result<RequestAction, BoxError> {
276 let start = Instant::now();
277 let index = self.get_engine_index();
278 let engine_lock = &self.engines[index];
279 let engine = engine_lock.read().await;
280
281 match engine.on_request(flow, body).await {
282 Ok(action) => {
283 let dur_us = start.elapsed().as_micros() as u64;
284 self.metrics
285 .on_request_duration_us
286 .fetch_add(dur_us, Ordering::Relaxed);
287 self.metrics
288 .on_request_invocations
289 .fetch_add(1, Ordering::Relaxed);
290 Ok(action)
291 }
292 Err(e) => {
293 self.metrics
294 .on_request_errors
295 .fetch_add(1, Ordering::Relaxed);
296 Err(e)
297 }
298 }
299 }
300
301 async fn on_response_headers(&self, flow: &mut Flow) -> InterceptionResult {
302 let start = Instant::now();
303 let index = self.get_engine_index();
304 let engine_lock = &self.engines[index];
305 let engine = engine_lock.read().await;
306
307 let result = match engine.on_response_headers(flow).await {
308 Ok(Some(modified_flow)) => {
309 *flow = modified_flow;
310 InterceptionResult::ModifiedResponse(match &flow.layer {
311 relay_core_api::flow::Layer::Http(h) => h.response.clone().unwrap(),
312 relay_core_api::flow::Layer::WebSocket(w) => w.handshake_response.clone(),
313 _ => {
314 self.metrics
315 .on_response_headers_errors
316 .fetch_add(1, Ordering::Relaxed);
317 return InterceptionResult::Continue;
318 }
319 })
320 }
321 Ok(None) => InterceptionResult::Continue,
322 Err(e) => {
323 tracing::error!("Script execution error (on_response_headers): {}", e);
324 flow.tags.push("script-error".to_string());
325 if let relay_core_api::flow::Layer::Http(http) = &mut flow.layer {
326 http.error = Some(format!("Script Error: {}", e));
327 }
328 self.metrics
329 .on_response_headers_errors
330 .fetch_add(1, Ordering::Relaxed);
331 InterceptionResult::Continue
332 }
333 };
334 let dur_us = start.elapsed().as_micros() as u64;
335 self.metrics
336 .on_response_headers_duration_us
337 .fetch_add(dur_us, Ordering::Relaxed);
338 self.metrics
339 .on_response_headers_invocations
340 .fetch_add(1, Ordering::Relaxed);
341 result
342 }
343
344 async fn on_response(
345 &self,
346 flow: &mut Flow,
347 body: HttpBody,
348 ) -> Result<ResponseAction, BoxError> {
349 let start = Instant::now();
350 let index = self.get_engine_index();
351 let engine_lock = &self.engines[index];
352 let engine = engine_lock.read().await;
353
354 match engine.on_response(flow, body).await {
355 Ok(action) => {
356 let dur_us = start.elapsed().as_micros() as u64;
357 self.metrics
358 .on_response_duration_us
359 .fetch_add(dur_us, Ordering::Relaxed);
360 self.metrics
361 .on_response_invocations
362 .fetch_add(1, Ordering::Relaxed);
363 Ok(action)
364 }
365 Err(e) => {
366 self.metrics
367 .on_response_errors
368 .fetch_add(1, Ordering::Relaxed);
369 Err(e)
370 }
371 }
372 }
373
374 async fn on_websocket_message(
375 &self,
376 flow: &mut Flow,
377 mut message: WebSocketMessage,
378 ) -> Result<WebSocketMessageAction, BoxError> {
379 let start = Instant::now();
380 let index = self.get_engine_index();
381 let engine_lock = &self.engines[index];
382 let engine = engine_lock.read().await;
383
384 match engine.on_websocket_message(flow, &mut message).await {
385 Ok(action) => {
386 let dur_us = start.elapsed().as_micros() as u64;
387 self.metrics
388 .on_websocket_message_duration_us
389 .fetch_add(dur_us, Ordering::Relaxed);
390 self.metrics
391 .on_websocket_message_invocations
392 .fetch_add(1, Ordering::Relaxed);
393 Ok(action)
394 }
395 Err(e) => {
396 tracing::error!("Script execution error (on_websocket_message): {}", e);
397 flow.tags.push("script-error".to_string());
398 self.metrics
399 .on_websocket_message_errors
400 .fetch_add(1, Ordering::Relaxed);
401 Ok(WebSocketMessageAction::Continue(message))
402 }
403 }
404 }
405
406 async fn on_connect(
407 &self,
408 conn: &relay_core_lib::interceptor::ConnectionInfo,
409 ) -> relay_core_lib::interceptor::ConnectAction {
410 let index = self.get_engine_index();
411 let engine_lock = &self.engines[index];
412 let engine = engine_lock.read().await;
413
414 match engine.on_connect(conn).await {
415 Ok(action) => action,
416 Err(e) => {
417 tracing::warn!("onConnect script error: {}", e);
418 relay_core_lib::interceptor::ConnectAction::Allow
419 }
420 }
421 }
422
423 async fn on_disconnect(
424 &self,
425 conn: &relay_core_lib::interceptor::ConnectionInfo,
426 stats: &relay_core_lib::interceptor::ConnectionStats,
427 ) {
428 let index = self.get_engine_index();
429 let engine_lock = &self.engines[index];
430 let engine = engine_lock.read().await;
431
432 if let Err(e) = engine.on_disconnect(conn, stats).await {
433 tracing::warn!("onDisconnect script error: {}", e);
434 }
435 }
436
437 async fn on_websocket_start(&self, flow: &mut Flow) {
438 let index = self.get_engine_index();
439 let engine_lock = &self.engines[index];
440 let engine = engine_lock.read().await;
441
442 if let Err(e) = engine.on_websocket_start(flow).await {
443 tracing::warn!("onWebSocketStart script error: {}", e);
444 }
445 }
446
447 async fn on_websocket_end(&self, flow: &mut Flow, close_code: u16, close_reason: &str) {
448 let index = self.get_engine_index();
449 let engine_lock = &self.engines[index];
450 let engine = engine_lock.read().await;
451
452 if let Err(e) = engine
453 .on_websocket_end(flow, close_code, close_reason)
454 .await
455 {
456 tracing::warn!("onWebSocketEnd script error: {}", e);
457 }
458 }
459
460 async fn on_websocket_error(&self, flow: &mut Flow, error: &str) {
461 let index = self.get_engine_index();
462 let engine_lock = &self.engines[index];
463 let engine = engine_lock.read().await;
464
465 if let Err(e) = engine.on_websocket_error(flow, error).await {
466 tracing::warn!("onWebSocketError script error: {}", e);
467 }
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use chrono::Utc;
    use http_body_util::{BodyExt, Empty};
    use relay_core_api::flow::{
        BodyData, Direction, Flow, HttpLayer, HttpRequest, Layer, NetworkInfo, TransportProtocol,
        WebSocketMessage,
    };
    use relay_core_lib::interceptor::{ConnectAction, ConnectionInfo, ConnectionStats};
    use std::collections::HashMap;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use url::Url;
    use uuid::Uuid;

    /// Builds a plain-HTTP `GET http://example.com` flow with no response,
    /// no tags and no error.
    fn create_test_flow() -> Flow {
        let network = NetworkInfo {
            client_ip: "127.0.0.1".to_string(),
            client_port: 12345,
            server_ip: "1.1.1.1".to_string(),
            server_port: 80,
            protocol: TransportProtocol::TCP,
            tls: false,
            tls_version: None,
            sni: None,
        };
        let request = HttpRequest {
            method: "GET".to_string(),
            url: Url::parse("http://example.com").unwrap(),
            version: "HTTP/1.1".to_string(),
            headers: vec![],
            cookies: vec![],
            query: vec![],
            body: None,
        };

        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network,
            layer: Layer::Http(HttpLayer {
                request,
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
            resilience_trace: None,
            rule_variables: HashMap::new(),
            matched_rules: vec![],
        }
    }

    /// Builds a client-to-server text frame whose payload is `hello`.
    fn hello_text_message() -> WebSocketMessage {
        WebSocketMessage {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            direction: Direction::ClientToServer,
            content: BodyData {
                encoding: "utf-8".to_string(),
                content: "hello".to_string(),
                size: 5,
            },
            opcode: "Text".to_string(),
        }
    }

    /// A throwing `onRequestHeaders` handler yields `Continue`, tags the flow
    /// and stores the error text on the HTTP layer.
    #[tokio::test]
    async fn test_script_error_propagation() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onRequestHeaders = (flow) => {
                    throw new Error("Test Error 123");
                };
                "#,
            )
            .await
            .unwrap();

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;

        assert!(
            matches!(outcome, InterceptionResult::Continue),
            "Expected Continue"
        );
        assert!(flow.tags.contains(&"script-error".to_string()));

        match &flow.layer {
            Layer::Http(http) => {
                assert!(http.error.is_some());
                let recorded = http.error.as_ref().unwrap();
                assert!(recorded.contains("Test Error 123"));
            }
            _ => panic!("Expected Http layer"),
        }
    }

    /// The first argument passed to `onRequest` is a `RelayBody` instance.
    #[tokio::test]
    async fn test_script_api_relay_body() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onRequest = (body, flow) => {
                    if (!(body instanceof RelayBody)) {
                        throw new Error("First argument is not RelayBody");
                    }
                    // We return nothing (undefined), which means "continue with original body"
                    // But we successfully verified the type.
                };
                "#,
            )
            .await
            .unwrap();

        let mut flow = create_test_flow();
        let empty_body = Empty::<Bytes>::new()
            .map_err(|_| -> BoxError { unreachable!() })
            .boxed();

        let outcome = interceptor.on_request(&mut flow, empty_body).await;
        assert!(outcome.is_ok());
    }

    /// A script that fails to load must leave the previously loaded script
    /// active.
    #[tokio::test]
    async fn test_load_script_failure_does_not_replace_existing_engines() {
        let interceptor = ScriptInterceptor::new().await.unwrap();

        let good_script = r#"
            globalThis.onRequestHeaders = (_context, flow) => {
                if (flow.layer.type === "Http") {
                    flow.layer.data.request.headers.push(["X-Good-Script", "1"]);
                }
                return flow;
            };
        "#;
        interceptor
            .load_script(good_script)
            .await
            .expect("good script should load");

        let bad_script = "globalThis.onRequestHeaders = () => { invalid javascript !!!";
        let reload = interceptor.load_script(bad_script).await;
        assert!(reload.is_err(), "bad script should be rejected");

        let mut flow = create_test_flow();
        let outcome = interceptor.on_request_headers(&mut flow).await;
        assert!(
            matches!(outcome, InterceptionResult::ModifiedRequest(_)),
            "existing good script should still be active after failed reload"
        );

        match &flow.layer {
            Layer::Http(http) => {
                let has_marker = http
                    .request
                    .headers
                    .iter()
                    .any(|(k, v)| k == "X-Good-Script" && v == "1");
                assert!(has_marker);
            }
            _ => panic!("Expected Http layer"),
        }
    }

    /// A throwing `onWebSocketMessage` handler forwards the original message
    /// and tags the flow.
    #[tokio::test]
    async fn test_websocket_script_error_falls_back_continue_and_tags_flow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        let script = r#"
            globalThis.onWebSocketMessage = function(_context, _flow, _message) {
                throw new Error("ws explode");
            };
        "#;
        interceptor
            .load_script(script)
            .await
            .expect("script should load");

        let mut flow = create_test_flow();
        let action = interceptor
            .on_websocket_message(&mut flow, hello_text_message())
            .await
            .expect("websocket interception should not return hard error");

        match action {
            WebSocketMessageAction::Continue(forwarded) => {
                assert_eq!(forwarded.content.content, "hello");
            }
            other => panic!("expected Continue fallback, got {:?}", other),
        }
        assert!(
            flow.tags.iter().any(|t| t == "script-error"),
            "script error should be tagged for observability"
        );
    }

    /// Connection from 127.0.0.1:54321 to 93.184.216.34:443 with SNI
    /// `example.com`.
    fn sample_connection_info() -> ConnectionInfo {
        let client_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 54321);
        let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)), 443);

        ConnectionInfo {
            id: Uuid::new_v4(),
            client_addr,
            server_addr: Some(server_addr),
            tls_sni: Some("example.com".to_string()),
        }
    }

    /// Fixed connection statistics used by the disconnect test.
    fn sample_connection_stats() -> ConnectionStats {
        ConnectionStats {
            duration_ms: 1234,
            bytes_sent: 5000,
            bytes_received: 12000,
            flows_count: 3,
        }
    }

    /// Returning `{ drop: true, reason }` from `onConnect` drops the
    /// connection with that reason.
    #[tokio::test]
    async fn test_on_connect_drop_rejects_connection() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onConnect = function(context, conn) {
                    if (conn.tls_sni === "example.com") {
                        return { drop: true, reason: "sni blocked" };
                    }
                    return {};
                }
                "#,
            )
            .await
            .unwrap();

        let action = interceptor.on_connect(&sample_connection_info()).await;
        assert!(
            matches!(action, ConnectAction::Drop { ref reason } if reason == "sni blocked"),
            "onConnect should drop connection for blocked SNI, got {:?}",
            action
        );
    }

    /// Returning an empty object from `onConnect` allows the connection.
    #[tokio::test]
    async fn test_on_connect_allow_no_drop() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onConnect = function(context, conn) {
                    return {};
                }
                "#,
            )
            .await
            .unwrap();

        let action = interceptor.on_connect(&sample_connection_info()).await;
        assert!(
            matches!(action, ConnectAction::Allow),
            "onConnect should allow connection by default"
        );
    }

    /// A script that defines no `onConnect` handler allows the connection.
    #[tokio::test]
    async fn test_on_connect_no_handler_defaults_allow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onRequestHeaders = function(context, flow) { return flow; }
                "#,
            )
            .await
            .unwrap();

        let action = interceptor.on_connect(&sample_connection_info()).await;
        assert!(
            matches!(action, ConnectAction::Allow),
            "missing onConnect should default to Allow"
        );
    }

    /// Invokes `onDisconnect` with sample data. `on_disconnect` only logs
    /// script errors, so this checks that the call completes.
    #[tokio::test]
    async fn test_on_disconnect_fires() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onDisconnect = function(context, conn, stats) {
                    // Verify stats fields are accessible in script
                    if (typeof stats.duration_ms !== "number") throw new Error("missing duration_ms");
                    if (typeof conn.client_addr !== "string") throw new Error("missing client_addr");
                }
                "#,
            )
            .await
            .unwrap();

        let conn = sample_connection_info();
        let stats = sample_connection_stats();
        interceptor.on_disconnect(&conn, &stats).await;
    }

    /// `onWebSocketStart` can mutate the flow; here it injects a header.
    #[tokio::test]
    async fn test_on_websocket_start_fires() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onWebSocketStart = function(context, flow) {
                    if (flow.layer.type === "Http") {
                        flow.layer.data.request.headers.push(["X-WS-Start", "1"]);
                    }
                    return flow;
                }
                "#,
            )
            .await
            .unwrap();

        let mut flow = create_test_flow();
        interceptor.on_websocket_start(&mut flow).await;

        if let Layer::Http(http) = &flow.layer {
            let injected = http
                .request
                .headers
                .iter()
                .any(|(k, v)| k == "X-WS-Start" && v == "1");
            assert!(injected, "onWebSocketStart should inject header");
        }
    }

    /// `onWebSocketEnd` receives the close code and can tag the flow with it.
    #[tokio::test]
    async fn test_on_websocket_end_fires() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onWebSocketEnd = function(context, flow, closeCode, closeReason) {
                    flow.tags.push("ws-ended:" + closeCode);
                    return flow;
                }
                "#,
            )
            .await
            .unwrap();

        let mut flow = create_test_flow();
        interceptor
            .on_websocket_end(&mut flow, 1000, "Normal Closure")
            .await;

        let tagged = flow.tags.iter().any(|t| t == "ws-ended:1000");
        assert!(tagged, "onWebSocketEnd should tag flow with close code");
    }

    /// `onWebSocketError` can tag the flow.
    #[tokio::test]
    async fn test_on_websocket_error_fires() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onWebSocketError = function(context, flow, error) {
                    flow.tags.push("ws-error");
                    return flow;
                }
                "#,
            )
            .await
            .unwrap();

        let mut flow = create_test_flow();
        interceptor
            .on_websocket_error(&mut flow, "connection reset")
            .await;

        let tagged = flow.tags.iter().any(|t| t == "ws-error");
        assert!(tagged, "onWebSocketError should tag flow");
    }

    /// A throwing `onConnect` handler falls back to allowing the connection.
    #[tokio::test]
    async fn test_on_connect_error_falls_back_allow() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onConnect = function(context, conn) {
                    throw new Error("connect handler crash");
                }
                "#,
            )
            .await
            .unwrap();

        let action = interceptor.on_connect(&sample_connection_info()).await;
        assert!(
            matches!(action, ConnectAction::Allow),
            "onConnect handler crash should fall back to Allow"
        );
    }

    /// Start, message, then end on a single flow: each handler's effect is
    /// visible after its call.
    #[tokio::test]
    async fn test_ws_lifecycle_normal_sequence() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onWebSocketStart = function(context, flow) {
                    flow.tags.push("ws-life:start");
                    return flow;
                }
                globalThis.onWebSocketMessage = function(context, flow, message) {
                    message.content.content += " [mod]";
                    return message;
                }
                globalThis.onWebSocketEnd = function(context, flow, closeCode, closeReason) {
                    flow.tags.push("ws-life:end:" + closeCode);
                    return flow;
                }
                "#,
            )
            .await
            .unwrap();

        let mut flow = create_test_flow();

        interceptor.on_websocket_start(&mut flow).await;
        assert!(flow.tags.iter().any(|t| t == "ws-life:start"));

        let action = interceptor
            .on_websocket_message(&mut flow, hello_text_message())
            .await
            .expect("ws message interception ok");
        match action {
            WebSocketMessageAction::Continue(forwarded) => {
                assert!(forwarded.content.content.contains("[mod]"));
            }
            other => panic!("expected Continue, got {:?}", other),
        }

        interceptor.on_websocket_end(&mut flow, 1000, "done").await;
        assert!(flow.tags.iter().any(|t| t == "ws-life:end:1000"));
    }

    /// The error handler runs and tags the flow when it is the only handler
    /// defined.
    #[tokio::test]
    async fn test_ws_lifecycle_error_sequence() {
        let interceptor = ScriptInterceptor::new().await.unwrap();
        interceptor
            .load_script(
                r#"
                globalThis.onWebSocketError = function(context, flow, error) {
                    flow.tags.push("ws-life:error");
                    return flow;
                }
                "#,
            )
            .await
            .unwrap();

        let mut flow = create_test_flow();

        interceptor
            .on_websocket_error(&mut flow, "peer reset")
            .await;
        assert!(flow.tags.iter().any(|t| t == "ws-life:error"));
    }
}