1use std::cell::RefCell;
2use std::collections::{BTreeMap, HashMap, VecDeque};
3use std::rc::Rc;
4use std::time::{Duration, SystemTime};
5
6use crate::value::{VmError, VmValue};
7use crate::vm::Vm;
8
9use futures::{SinkExt, StreamExt};
10use reqwest_eventsource::{Event as SseEvent, EventSource};
11use tokio_tungstenite::tungstenite::Message as WsMessage;
12
13#[derive(Clone)]
16struct MockResponse {
17 status: i64,
18 body: String,
19 headers: BTreeMap<String, VmValue>,
20}
21
22#[derive(Clone, Debug, PartialEq, Eq)]
23pub struct HttpMockResponse {
24 pub status: i64,
25 pub body: String,
26 pub headers: BTreeMap<String, String>,
27}
28
29impl HttpMockResponse {
30 pub fn new(status: i64, body: impl Into<String>) -> Self {
31 Self {
32 status,
33 body: body.into(),
34 headers: BTreeMap::new(),
35 }
36 }
37
38 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
39 self.headers.insert(name.into(), value.into());
40 self
41 }
42}
43
44impl From<HttpMockResponse> for MockResponse {
45 fn from(value: HttpMockResponse) -> Self {
46 Self {
47 status: value.status,
48 body: value.body,
49 headers: value
50 .headers
51 .into_iter()
52 .map(|(key, value)| (key, VmValue::String(Rc::from(value))))
53 .collect(),
54 }
55 }
56}
57
58struct HttpMock {
59 method: String,
60 url_pattern: String,
61 responses: Vec<MockResponse>,
62 next_response: usize,
63}
64
65#[derive(Clone)]
66struct HttpMockCall {
67 method: String,
68 url: String,
69 headers: BTreeMap<String, VmValue>,
70 body: Option<String>,
71}
72
73#[derive(Clone, Debug, PartialEq, Eq)]
74pub struct HttpMockCallSnapshot {
75 pub method: String,
76 pub url: String,
77 pub headers: BTreeMap<String, String>,
78 pub body: Option<String>,
79}
80
81#[derive(Clone)]
82struct RetryConfig {
83 max: u32,
84 backoff_ms: u64,
85 retryable_statuses: Vec<u16>,
86 retryable_methods: Vec<String>,
87 respect_retry_after: bool,
88}
89
90#[derive(Clone)]
91struct HttpRequestConfig {
92 timeout_ms: u64,
93 retry: RetryConfig,
94 follow_redirects: bool,
95 max_redirects: usize,
96}
97
98#[derive(Clone)]
99struct HttpSession {
100 client: reqwest::Client,
101 options: BTreeMap<String, VmValue>,
102}
103
104struct HttpRequestParts {
105 method: reqwest::Method,
106 headers: reqwest::header::HeaderMap,
107 recorded_headers: BTreeMap<String, VmValue>,
108 body: Option<String>,
109}
110
111struct SseMock {
112 url_pattern: String,
113 events: Vec<MockStreamEvent>,
114}
115
116#[derive(Clone)]
117struct MockStreamEvent {
118 event_type: String,
119 data: String,
120 id: Option<String>,
121 retry_ms: Option<i64>,
122}
123
124struct SseHandle {
125 kind: SseHandleKind,
126 url: String,
127 max_events: usize,
128 max_message_bytes: usize,
129 received: usize,
130}
131
132enum SseHandleKind {
133 Real(Rc<tokio::sync::Mutex<EventSource>>),
134 Fake(Rc<tokio::sync::Mutex<FakeSseStream>>),
135}
136
137struct FakeSseStream {
138 events: VecDeque<MockStreamEvent>,
139 opened: bool,
140 closed: bool,
141}
142
143struct WebSocketMock {
144 url_pattern: String,
145 messages: Vec<MockWsMessage>,
146 echo: bool,
147}
148
149#[derive(Clone)]
150struct MockWsMessage {
151 message_type: String,
152 data: Vec<u8>,
153}
154
155type RealWebSocket =
156 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
157
158struct WebSocketHandle {
159 kind: WebSocketHandleKind,
160 url: String,
161 max_messages: usize,
162 max_message_bytes: usize,
163 received: usize,
164}
165
166enum WebSocketHandleKind {
167 Real(Rc<tokio::sync::Mutex<RealWebSocket>>),
168 Fake(Rc<tokio::sync::Mutex<FakeWebSocket>>),
169}
170
171struct FakeWebSocket {
172 messages: VecDeque<MockWsMessage>,
173 echo: bool,
174 closed: bool,
175}
176
177#[derive(Clone)]
178struct TransportMockCall {
179 kind: String,
180 handle: Option<String>,
181 url: String,
182 message_type: Option<String>,
183 data: Option<String>,
184}
185
186const DEFAULT_TIMEOUT_MS: u64 = 30_000;
187const DEFAULT_BACKOFF_MS: u64 = 1_000;
188const MAX_RETRY_DELAY_MS: u64 = 60_000;
189const DEFAULT_RETRYABLE_STATUSES: [u16; 6] = [408, 429, 500, 502, 503, 504];
190const DEFAULT_RETRYABLE_METHODS: [&str; 5] = ["GET", "HEAD", "PUT", "DELETE", "OPTIONS"];
191const DEFAULT_TRANSPORT_RECEIVE_TIMEOUT_MS: u64 = 30_000;
192const DEFAULT_MAX_STREAM_EVENTS: usize = 10_000;
193const DEFAULT_MAX_MESSAGE_BYTES: usize = 1024 * 1024;
194const MAX_HTTP_SESSIONS: usize = 64;
195const MAX_SSE_STREAMS: usize = 64;
196const MAX_WEBSOCKETS: usize = 64;
197
198thread_local! {
199 static HTTP_MOCKS: RefCell<Vec<HttpMock>> = const { RefCell::new(Vec::new()) };
200 static HTTP_MOCK_CALLS: RefCell<Vec<HttpMockCall>> = const { RefCell::new(Vec::new()) };
201 static HTTP_CLIENTS: RefCell<HashMap<String, reqwest::Client>> = RefCell::new(HashMap::new());
202 static HTTP_SESSIONS: RefCell<HashMap<String, HttpSession>> = RefCell::new(HashMap::new());
203 static SSE_MOCKS: RefCell<Vec<SseMock>> = const { RefCell::new(Vec::new()) };
204 static SSE_HANDLES: RefCell<HashMap<String, SseHandle>> = RefCell::new(HashMap::new());
205 static WEBSOCKET_MOCKS: RefCell<Vec<WebSocketMock>> = const { RefCell::new(Vec::new()) };
206 static WEBSOCKET_HANDLES: RefCell<HashMap<String, WebSocketHandle>> = RefCell::new(HashMap::new());
207 static TRANSPORT_MOCK_CALLS: RefCell<Vec<TransportMockCall>> = const { RefCell::new(Vec::new()) };
208 static TRANSPORT_HANDLE_COUNTER: RefCell<u64> = const { RefCell::new(0) };
209}
210
211pub fn reset_http_state() {
213 HTTP_MOCKS.with(|m| m.borrow_mut().clear());
214 HTTP_MOCK_CALLS.with(|c| c.borrow_mut().clear());
215 HTTP_CLIENTS.with(|clients| clients.borrow_mut().clear());
216 HTTP_SESSIONS.with(|sessions| sessions.borrow_mut().clear());
217 SSE_MOCKS.with(|mocks| mocks.borrow_mut().clear());
218 SSE_HANDLES.with(|handles| {
219 for handle in handles.borrow_mut().values_mut() {
220 if let SseHandleKind::Real(stream) = &handle.kind {
221 if let Ok(mut stream) = stream.try_lock() {
222 stream.close();
223 }
224 }
225 }
226 handles.borrow_mut().clear();
227 });
228 WEBSOCKET_MOCKS.with(|mocks| mocks.borrow_mut().clear());
229 WEBSOCKET_HANDLES.with(|handles| handles.borrow_mut().clear());
230 TRANSPORT_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
231 TRANSPORT_HANDLE_COUNTER.with(|counter| *counter.borrow_mut() = 0);
232}
233
234pub fn push_http_mock(
235 method: impl Into<String>,
236 url_pattern: impl Into<String>,
237 responses: Vec<HttpMockResponse>,
238) {
239 let responses = if responses.is_empty() {
240 vec![MockResponse::from(HttpMockResponse::new(200, ""))]
241 } else {
242 responses.into_iter().map(MockResponse::from).collect()
243 };
244 let method = method.into();
245 let url_pattern = url_pattern.into();
246 HTTP_MOCKS.with(|mocks| {
247 let mut mocks = mocks.borrow_mut();
248 mocks.retain(|mock| !(mock.method == method && mock.url_pattern == url_pattern));
253 mocks.push(HttpMock {
254 method,
255 url_pattern,
256 responses,
257 next_response: 0,
258 });
259 });
260}
261
262pub fn http_mock_calls_snapshot() -> Vec<HttpMockCallSnapshot> {
263 HTTP_MOCK_CALLS.with(|calls| {
264 calls
265 .borrow()
266 .iter()
267 .map(|call| HttpMockCallSnapshot {
268 method: call.method.clone(),
269 url: call.url.clone(),
270 headers: call
271 .headers
272 .iter()
273 .map(|(key, value)| (key.clone(), value.display()))
274 .collect(),
275 body: call.body.clone(),
276 })
277 .collect()
278 })
279}
280
281fn url_matches(pattern: &str, url: &str) -> bool {
283 if pattern == "*" {
284 return true;
285 }
286 if !pattern.contains('*') {
287 return pattern == url;
288 }
289 let parts: Vec<&str> = pattern.split('*').collect();
291 let mut remaining = url;
292 for (i, part) in parts.iter().enumerate() {
293 if part.is_empty() {
294 continue;
295 }
296 if i == 0 {
297 if !remaining.starts_with(part) {
298 return false;
299 }
300 remaining = &remaining[part.len()..];
301 } else if i == parts.len() - 1 {
302 if !remaining.ends_with(part) {
303 return false;
304 }
305 remaining = "";
306 } else {
307 match remaining.find(part) {
308 Some(pos) => remaining = &remaining[pos + part.len()..],
309 None => return false,
310 }
311 }
312 }
313 true
314}
315
316fn build_http_response(status: i64, headers: BTreeMap<String, VmValue>, body: String) -> VmValue {
318 let mut result = BTreeMap::new();
319 result.insert("status".to_string(), VmValue::Int(status));
320 result.insert("headers".to_string(), VmValue::Dict(Rc::new(headers)));
321 result.insert("body".to_string(), VmValue::String(Rc::from(body)));
322 result.insert(
323 "ok".to_string(),
324 VmValue::Bool((200..300).contains(&(status as u16))),
325 );
326 VmValue::Dict(Rc::new(result))
327}
328
329fn vm_error(message: impl Into<String>) -> VmError {
330 VmError::Thrown(VmValue::String(Rc::from(message.into())))
331}
332
333fn next_transport_handle(prefix: &str) -> String {
334 TRANSPORT_HANDLE_COUNTER.with(|counter| {
335 let mut counter = counter.borrow_mut();
336 *counter += 1;
337 format!("{prefix}-{}", *counter)
338 })
339}
340
341fn handle_from_value(value: &VmValue, builtin: &str) -> Result<String, VmError> {
342 match value {
343 VmValue::String(handle) => Ok(handle.to_string()),
344 VmValue::Dict(dict) => dict
345 .get("id")
346 .map(|id| id.display())
347 .filter(|id| !id.is_empty())
348 .ok_or_else(|| vm_error(format!("{builtin}: handle dict must contain id"))),
349 _ => Err(vm_error(format!(
350 "{builtin}: first argument must be a handle string or dict"
351 ))),
352 }
353}
354
355fn get_options_arg(args: &[VmValue], index: usize) -> BTreeMap<String, VmValue> {
356 args.get(index)
357 .and_then(|value| value.as_dict())
358 .cloned()
359 .unwrap_or_default()
360}
361
362fn merge_options(
363 base: &BTreeMap<String, VmValue>,
364 overrides: &BTreeMap<String, VmValue>,
365) -> BTreeMap<String, VmValue> {
366 let mut merged = base.clone();
367 for (key, value) in overrides {
368 merged.insert(key.clone(), value.clone());
369 }
370 merged
371}
372
373fn transport_limit_option(options: &BTreeMap<String, VmValue>, key: &str, default: usize) -> usize {
374 options
375 .get(key)
376 .and_then(|value| value.as_int())
377 .map(|value| value.max(0) as usize)
378 .unwrap_or(default)
379}
380
381fn receive_timeout_arg(args: &[VmValue], index: usize) -> u64 {
382 match args.get(index) {
383 Some(VmValue::Duration(ms)) => *ms,
384 Some(value) => value
385 .as_int()
386 .map(|ms| ms.max(0) as u64)
387 .unwrap_or(DEFAULT_TRANSPORT_RECEIVE_TIMEOUT_MS),
388 None => DEFAULT_TRANSPORT_RECEIVE_TIMEOUT_MS,
389 }
390}
391
392fn timeout_event() -> VmValue {
393 let mut dict = BTreeMap::new();
394 dict.insert("type".to_string(), VmValue::String(Rc::from("timeout")));
395 VmValue::Dict(Rc::new(dict))
396}
397
398fn closed_event() -> VmValue {
399 let mut dict = BTreeMap::new();
400 dict.insert("type".to_string(), VmValue::String(Rc::from("close")));
401 VmValue::Dict(Rc::new(dict))
402}
403
404fn record_transport_call(call: TransportMockCall) {
405 TRANSPORT_MOCK_CALLS.with(|calls| calls.borrow_mut().push(call));
406}
407
408async fn http_verb_handler(
412 method: &str,
413 has_body: bool,
414 args: Vec<VmValue>,
415) -> Result<VmValue, VmError> {
416 let url = args.first().map(|a| a.display()).unwrap_or_default();
417 if url.is_empty() {
418 return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
419 "http_{}: URL is required",
420 method.to_ascii_lowercase()
421 )))));
422 }
423 let mut options = if has_body {
424 match args.get(2) {
425 Some(VmValue::Dict(d)) => (**d).clone(),
426 _ => BTreeMap::new(),
427 }
428 } else {
429 match args.get(1) {
430 Some(VmValue::Dict(d)) => (**d).clone(),
431 _ => BTreeMap::new(),
432 }
433 };
434 if has_body {
435 let body = args.get(1).map(|a| a.display()).unwrap_or_default();
436 options.insert("body".to_string(), VmValue::String(Rc::from(body)));
437 }
438 vm_execute_http_request(method, &url, &options).await
439}
440
441fn parse_mock_response_dict(response: &BTreeMap<String, VmValue>) -> MockResponse {
442 let status = response
443 .get("status")
444 .and_then(|v| v.as_int())
445 .unwrap_or(200);
446 let body = response
447 .get("body")
448 .map(|v| v.display())
449 .unwrap_or_default();
450 let headers = response
451 .get("headers")
452 .and_then(|v| v.as_dict())
453 .cloned()
454 .unwrap_or_default();
455 MockResponse {
456 status,
457 body,
458 headers,
459 }
460}
461
462fn parse_mock_responses(response: &BTreeMap<String, VmValue>) -> Vec<MockResponse> {
463 let scripted = response
464 .get("responses")
465 .and_then(|value| match value {
466 VmValue::List(items) => Some(
467 items
468 .iter()
469 .filter_map(|item| item.as_dict().map(parse_mock_response_dict))
470 .collect::<Vec<_>>(),
471 ),
472 _ => None,
473 })
474 .unwrap_or_default();
475
476 if scripted.is_empty() {
477 vec![parse_mock_response_dict(response)]
478 } else {
479 scripted
480 }
481}
482
483fn consume_http_mock(
484 method: &str,
485 url: &str,
486 headers: BTreeMap<String, VmValue>,
487 body: Option<String>,
488) -> Option<MockResponse> {
489 let response = HTTP_MOCKS.with(|mocks| {
490 let mut mocks = mocks.borrow_mut();
491 for mock in mocks.iter_mut() {
492 if (mock.method == "*" || mock.method.eq_ignore_ascii_case(method))
493 && url_matches(&mock.url_pattern, url)
494 {
495 let Some(last_index) = mock.responses.len().checked_sub(1) else {
496 continue;
497 };
498 let index = mock.next_response.min(last_index);
499 let response = mock.responses[index].clone();
500 if mock.next_response < last_index {
501 mock.next_response += 1;
502 }
503 return Some(response);
504 }
505 }
506 None
507 })?;
508
509 HTTP_MOCK_CALLS.with(|calls| {
510 calls.borrow_mut().push(HttpMockCall {
511 method: method.to_string(),
512 url: url.to_string(),
513 headers,
514 body,
515 });
516 });
517
518 Some(response)
519}
520
521pub fn register_http_builtins(vm: &mut Vm) {
523 vm.register_async_builtin("http_get", |args| async move {
524 http_verb_handler("GET", false, args).await
525 });
526 vm.register_async_builtin("http_post", |args| async move {
527 http_verb_handler("POST", true, args).await
528 });
529 vm.register_async_builtin("http_put", |args| async move {
530 http_verb_handler("PUT", true, args).await
531 });
532 vm.register_async_builtin("http_patch", |args| async move {
533 http_verb_handler("PATCH", true, args).await
534 });
535 vm.register_async_builtin("http_delete", |args| async move {
536 http_verb_handler("DELETE", false, args).await
537 });
538
539 vm.register_builtin("http_mock", |args, _out| {
547 let method = args.first().map(|a| a.display()).unwrap_or_default();
548 let url_pattern = args.get(1).map(|a| a.display()).unwrap_or_default();
549 let response = args
550 .get(2)
551 .and_then(|a| a.as_dict())
552 .cloned()
553 .unwrap_or_default();
554 let responses = parse_mock_responses(&response);
555
556 HTTP_MOCKS.with(|mocks| {
557 let mut mocks = mocks.borrow_mut();
558 mocks.retain(|mock| !(mock.method == method && mock.url_pattern == url_pattern));
559 mocks.push(HttpMock {
560 method,
561 url_pattern,
562 responses,
563 next_response: 0,
564 });
565 });
566 Ok(VmValue::Nil)
567 });
568
569 vm.register_builtin("http_mock_clear", |_args, _out| {
571 HTTP_MOCKS.with(|mocks| mocks.borrow_mut().clear());
572 HTTP_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
573 Ok(VmValue::Nil)
574 });
575
576 vm.register_builtin("http_mock_calls", |_args, _out| {
578 let calls = HTTP_MOCK_CALLS.with(|calls| calls.borrow().clone());
579 let result: Vec<VmValue> = calls
580 .iter()
581 .map(|c| {
582 let mut dict = BTreeMap::new();
583 dict.insert(
584 "method".to_string(),
585 VmValue::String(Rc::from(c.method.as_str())),
586 );
587 dict.insert("url".to_string(), VmValue::String(Rc::from(c.url.as_str())));
588 dict.insert(
589 "headers".to_string(),
590 VmValue::Dict(Rc::new(c.headers.clone())),
591 );
592 dict.insert(
593 "body".to_string(),
594 match &c.body {
595 Some(b) => VmValue::String(Rc::from(b.as_str())),
596 None => VmValue::Nil,
597 },
598 );
599 VmValue::Dict(Rc::new(dict))
600 })
601 .collect();
602 Ok(VmValue::List(Rc::new(result)))
603 });
604
605 vm.register_async_builtin("http_request", |args| async move {
606 let method = args
607 .first()
608 .map(|a| a.display())
609 .unwrap_or_default()
610 .to_uppercase();
611 if method.is_empty() {
612 return Err(VmError::Thrown(VmValue::String(Rc::from(
613 "http_request: method is required",
614 ))));
615 }
616 let url = args.get(1).map(|a| a.display()).unwrap_or_default();
617 if url.is_empty() {
618 return Err(VmError::Thrown(VmValue::String(Rc::from(
619 "http_request: URL is required",
620 ))));
621 }
622 let options = match args.get(2) {
623 Some(VmValue::Dict(d)) => (**d).clone(),
624 _ => BTreeMap::new(),
625 };
626 vm_execute_http_request(&method, &url, &options).await
627 });
628
629 vm.register_builtin("http_session", |args, _out| {
630 let options = get_options_arg(args, 0);
631 let config = parse_http_options(&options);
632 let client = build_http_client(&config)?;
633 let id = next_transport_handle("http-session");
634 HTTP_SESSIONS.with(|sessions| {
635 let mut sessions = sessions.borrow_mut();
636 if sessions.len() >= MAX_HTTP_SESSIONS {
637 return Err(vm_error(format!(
638 "http_session: maximum open sessions ({MAX_HTTP_SESSIONS}) reached"
639 )));
640 }
641 sessions.insert(id.clone(), HttpSession { client, options });
642 Ok(())
643 })?;
644 Ok(VmValue::String(Rc::from(id)))
645 });
646
647 vm.register_async_builtin("http_session_request", |args| async move {
648 if args.len() < 3 {
649 return Err(vm_error(
650 "http_session_request: requires session, method, and URL",
651 ));
652 }
653 let session_id = handle_from_value(&args[0], "http_session_request")?;
654 let method = args[1].display().to_uppercase();
655 if method.is_empty() {
656 return Err(vm_error("http_session_request: method is required"));
657 }
658 let url = args[2].display();
659 if url.is_empty() {
660 return Err(vm_error("http_session_request: URL is required"));
661 }
662 let options = get_options_arg(&args, 3);
663 vm_execute_http_session_request(&session_id, &method, &url, &options).await
664 });
665
666 vm.register_builtin("http_session_close", |args, _out| {
667 let Some(handle) = args.first() else {
668 return Err(vm_error("http_session_close: requires a session handle"));
669 };
670 let session_id = handle_from_value(handle, "http_session_close")?;
671 let removed = HTTP_SESSIONS.with(|sessions| sessions.borrow_mut().remove(&session_id));
672 Ok(VmValue::Bool(removed.is_some()))
673 });
674
675 vm.register_builtin("sse_mock", |args, _out| {
676 let url_pattern = args.first().map(|arg| arg.display()).unwrap_or_default();
677 if url_pattern.is_empty() {
678 return Err(vm_error("sse_mock: URL pattern is required"));
679 }
680 let events = parse_mock_stream_events(args.get(1));
681 SSE_MOCKS.with(|mocks| {
682 mocks.borrow_mut().push(SseMock {
683 url_pattern,
684 events,
685 });
686 });
687 Ok(VmValue::Nil)
688 });
689
690 vm.register_async_builtin("sse_connect", |args| async move {
691 let method = args
692 .first()
693 .map(|arg| arg.display())
694 .filter(|method| !method.is_empty())
695 .unwrap_or_else(|| "GET".to_string())
696 .to_uppercase();
697 let url = args.get(1).map(|arg| arg.display()).unwrap_or_default();
698 if url.is_empty() {
699 return Err(vm_error("sse_connect: URL is required"));
700 }
701 let options = get_options_arg(&args, 2);
702 vm_sse_connect(&method, &url, &options).await
703 });
704
705 vm.register_async_builtin("sse_receive", |args| async move {
706 let Some(handle) = args.first() else {
707 return Err(vm_error("sse_receive: requires a stream handle"));
708 };
709 let stream_id = handle_from_value(handle, "sse_receive")?;
710 let timeout_ms = receive_timeout_arg(&args, 1);
711 vm_sse_receive(&stream_id, timeout_ms).await
712 });
713
714 vm.register_builtin("sse_close", |args, _out| {
715 let Some(handle) = args.first() else {
716 return Err(vm_error("sse_close: requires a stream handle"));
717 };
718 let stream_id = handle_from_value(handle, "sse_close")?;
719 let removed = SSE_HANDLES.with(|handles| {
720 let mut handles = handles.borrow_mut();
721 let removed = handles.remove(&stream_id);
722 if let Some(handle) = &removed {
723 if let SseHandleKind::Real(stream) = &handle.kind {
724 if let Ok(mut stream) = stream.try_lock() {
725 stream.close();
726 }
727 }
728 }
729 removed
730 });
731 Ok(VmValue::Bool(removed.is_some()))
732 });
733
734 vm.register_builtin("websocket_mock", |args, _out| {
735 let url_pattern = args.first().map(|arg| arg.display()).unwrap_or_default();
736 if url_pattern.is_empty() {
737 return Err(vm_error("websocket_mock: URL pattern is required"));
738 }
739 let (messages, echo) = parse_websocket_mock(args.get(1));
740 WEBSOCKET_MOCKS.with(|mocks| {
741 mocks.borrow_mut().push(WebSocketMock {
742 url_pattern,
743 messages,
744 echo,
745 });
746 });
747 Ok(VmValue::Nil)
748 });
749
750 vm.register_async_builtin("websocket_connect", |args| async move {
751 let url = args.first().map(|arg| arg.display()).unwrap_or_default();
752 if url.is_empty() {
753 return Err(vm_error("websocket_connect: URL is required"));
754 }
755 let options = get_options_arg(&args, 1);
756 vm_websocket_connect(&url, &options).await
757 });
758
759 vm.register_async_builtin("websocket_send", |args| async move {
760 if args.len() < 2 {
761 return Err(vm_error(
762 "websocket_send: requires socket handle and message",
763 ));
764 }
765 let socket_id = handle_from_value(&args[0], "websocket_send")?;
766 let message = args[1].clone();
767 let options = get_options_arg(&args, 2);
768 vm_websocket_send(&socket_id, message, &options).await
769 });
770
771 vm.register_async_builtin("websocket_receive", |args| async move {
772 let Some(handle) = args.first() else {
773 return Err(vm_error("websocket_receive: requires a socket handle"));
774 };
775 let socket_id = handle_from_value(handle, "websocket_receive")?;
776 let timeout_ms = receive_timeout_arg(&args, 1);
777 vm_websocket_receive(&socket_id, timeout_ms).await
778 });
779
780 vm.register_async_builtin("websocket_close", |args| async move {
781 let Some(handle) = args.first() else {
782 return Err(vm_error("websocket_close: requires a socket handle"));
783 };
784 let socket_id = handle_from_value(handle, "websocket_close")?;
785 vm_websocket_close(&socket_id).await
786 });
787
788 vm.register_builtin("transport_mock_clear", |_args, _out| {
789 SSE_MOCKS.with(|mocks| mocks.borrow_mut().clear());
790 SSE_HANDLES.with(|handles| handles.borrow_mut().clear());
791 WEBSOCKET_MOCKS.with(|mocks| mocks.borrow_mut().clear());
792 WEBSOCKET_HANDLES.with(|handles| handles.borrow_mut().clear());
793 TRANSPORT_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
794 Ok(VmValue::Nil)
795 });
796
797 vm.register_builtin("transport_mock_calls", |_args, _out| {
798 let calls = TRANSPORT_MOCK_CALLS.with(|calls| calls.borrow().clone());
799 let values = calls
800 .iter()
801 .map(transport_mock_call_value)
802 .collect::<Vec<_>>();
803 Ok(VmValue::List(Rc::new(values)))
804 });
805}
806
807fn vm_get_int_option(options: &BTreeMap<String, VmValue>, key: &str, default: i64) -> i64 {
808 options.get(key).and_then(|v| v.as_int()).unwrap_or(default)
809}
810
811fn vm_get_bool_option(options: &BTreeMap<String, VmValue>, key: &str, default: bool) -> bool {
812 match options.get(key) {
813 Some(VmValue::Bool(b)) => *b,
814 _ => default,
815 }
816}
817
818fn vm_get_int_option_prefer(
819 options: &BTreeMap<String, VmValue>,
820 canonical: &str,
821 alias: &str,
822 default: i64,
823) -> i64 {
824 options
825 .get(canonical)
826 .and_then(|value| value.as_int())
827 .or_else(|| options.get(alias).and_then(|value| value.as_int()))
828 .unwrap_or(default)
829}
830
831fn parse_retry_statuses(options: &BTreeMap<String, VmValue>) -> Vec<u16> {
832 match options.get("retry_on") {
833 Some(VmValue::List(values)) => {
834 let statuses: Vec<u16> = values
835 .iter()
836 .filter_map(|value| value.as_int())
837 .filter(|status| (0..=u16::MAX as i64).contains(status))
838 .map(|status| status as u16)
839 .collect();
840 if statuses.is_empty() {
841 DEFAULT_RETRYABLE_STATUSES.to_vec()
842 } else {
843 statuses
844 }
845 }
846 _ => DEFAULT_RETRYABLE_STATUSES.to_vec(),
847 }
848}
849
850fn parse_retry_methods(options: &BTreeMap<String, VmValue>) -> Vec<String> {
851 match options.get("retry_methods") {
852 Some(VmValue::List(values)) => {
853 let methods: Vec<String> = values
854 .iter()
855 .map(|value| value.display().trim().to_ascii_uppercase())
856 .filter(|value| !value.is_empty())
857 .collect();
858 if methods.is_empty() {
859 DEFAULT_RETRYABLE_METHODS
860 .iter()
861 .map(|method| (*method).to_string())
862 .collect()
863 } else {
864 methods
865 }
866 }
867 _ => DEFAULT_RETRYABLE_METHODS
868 .iter()
869 .map(|method| (*method).to_string())
870 .collect(),
871 }
872}
873
874fn parse_http_options(options: &BTreeMap<String, VmValue>) -> HttpRequestConfig {
875 let timeout_ms = vm_get_int_option_prefer(
876 options,
877 "timeout_ms",
878 "timeout",
879 DEFAULT_TIMEOUT_MS as i64,
880 )
881 .max(0) as u64;
882 let retry_options = options.get("retry").and_then(|value| value.as_dict());
883 let retry_max = retry_options
884 .and_then(|retry| retry.get("max"))
885 .and_then(|value| value.as_int())
886 .unwrap_or_else(|| vm_get_int_option(options, "retries", 0))
887 .max(0) as u32;
888 let retry_backoff_ms = retry_options
889 .and_then(|retry| retry.get("backoff_ms"))
890 .and_then(|value| value.as_int())
891 .unwrap_or_else(|| vm_get_int_option(options, "backoff", DEFAULT_BACKOFF_MS as i64))
892 .max(0) as u64;
893 let respect_retry_after = vm_get_bool_option(options, "respect_retry_after", true);
894 let follow_redirects = vm_get_bool_option(options, "follow_redirects", true);
895 let max_redirects = vm_get_int_option(options, "max_redirects", 10).max(0) as usize;
896
897 HttpRequestConfig {
898 timeout_ms,
899 retry: RetryConfig {
900 max: retry_max,
901 backoff_ms: retry_backoff_ms,
902 retryable_statuses: parse_retry_statuses(options),
903 retryable_methods: parse_retry_methods(options),
904 respect_retry_after,
905 },
906 follow_redirects,
907 max_redirects,
908 }
909}
910
911fn http_client_key(config: &HttpRequestConfig) -> String {
912 format!(
913 "follow_redirects={};max_redirects={}",
914 config.follow_redirects, config.max_redirects
915 )
916}
917
918fn build_http_client(config: &HttpRequestConfig) -> Result<reqwest::Client, VmError> {
919 let redirect_policy = if config.follow_redirects {
920 reqwest::redirect::Policy::limited(config.max_redirects)
921 } else {
922 reqwest::redirect::Policy::none()
923 };
924
925 reqwest::Client::builder()
926 .redirect(redirect_policy)
927 .build()
928 .map_err(|e| vm_error(format!("http: failed to build client: {e}")))
929}
930
931fn pooled_http_client(config: &HttpRequestConfig) -> Result<reqwest::Client, VmError> {
932 let key = http_client_key(config);
933 if let Some(client) = HTTP_CLIENTS.with(|clients| clients.borrow().get(&key).cloned()) {
934 return Ok(client);
935 }
936
937 let client = build_http_client(config)?;
938 HTTP_CLIENTS.with(|clients| {
939 clients.borrow_mut().insert(key, client.clone());
940 });
941 Ok(client)
942}
943
944fn parse_http_request_parts(
945 method: &str,
946 options: &BTreeMap<String, VmValue>,
947) -> Result<HttpRequestParts, VmError> {
948 let req_method = method
949 .parse::<reqwest::Method>()
950 .map_err(|e| vm_error(format!("http: invalid method '{method}': {e}")))?;
951
952 let mut header_map = reqwest::header::HeaderMap::new();
953 let mut recorded_headers = BTreeMap::new();
954
955 if let Some(auth_val) = options.get("auth") {
956 match auth_val {
957 VmValue::String(s) => {
958 let hv = reqwest::header::HeaderValue::from_str(s)
959 .map_err(|e| vm_error(format!("http: invalid auth header value: {e}")))?;
960 header_map.insert(reqwest::header::AUTHORIZATION, hv);
961 recorded_headers.insert(
962 "Authorization".to_string(),
963 VmValue::String(Rc::from(s.as_ref())),
964 );
965 }
966 VmValue::Dict(d) => {
967 if let Some(bearer) = d.get("bearer") {
968 let token = bearer.display();
969 let authorization = format!("Bearer {token}");
970 let hv = reqwest::header::HeaderValue::from_str(&authorization)
971 .map_err(|e| vm_error(format!("http: invalid bearer token: {e}")))?;
972 header_map.insert(reqwest::header::AUTHORIZATION, hv);
973 recorded_headers.insert(
974 "Authorization".to_string(),
975 VmValue::String(Rc::from(authorization)),
976 );
977 } else if let Some(VmValue::Dict(basic)) = d.get("basic") {
978 let user = basic.get("user").map(|v| v.display()).unwrap_or_default();
979 let password = basic
980 .get("password")
981 .map(|v| v.display())
982 .unwrap_or_default();
983 use base64::Engine;
984 let encoded = base64::engine::general_purpose::STANDARD
985 .encode(format!("{user}:{password}"));
986 let authorization = format!("Basic {encoded}");
987 let hv = reqwest::header::HeaderValue::from_str(&authorization)
988 .map_err(|e| vm_error(format!("http: invalid basic auth: {e}")))?;
989 header_map.insert(reqwest::header::AUTHORIZATION, hv);
990 recorded_headers.insert(
991 "Authorization".to_string(),
992 VmValue::String(Rc::from(authorization)),
993 );
994 }
995 }
996 _ => {}
997 }
998 }
999
1000 if let Some(VmValue::Dict(hdrs)) = options.get("headers") {
1001 for (k, v) in hdrs.iter() {
1002 let name = reqwest::header::HeaderName::from_bytes(k.as_bytes())
1003 .map_err(|e| vm_error(format!("http: invalid header name '{k}': {e}")))?;
1004 let val = reqwest::header::HeaderValue::from_str(&v.display())
1005 .map_err(|e| vm_error(format!("http: invalid header value for '{k}': {e}")))?;
1006 header_map.insert(name, val);
1007 recorded_headers.insert(k.clone(), VmValue::String(Rc::from(v.display())));
1008 }
1009 }
1010
1011 Ok(HttpRequestParts {
1012 method: req_method,
1013 headers: header_map,
1014 recorded_headers,
1015 body: options.get("body").map(|v| v.display()),
1016 })
1017}
1018
1019fn session_from_options(options: &BTreeMap<String, VmValue>) -> Option<String> {
1020 options
1021 .get("session")
1022 .and_then(|value| handle_from_value(value, "http_request").ok())
1023}
1024
1025fn parse_mock_stream_event(value: &VmValue) -> MockStreamEvent {
1026 match value {
1027 VmValue::Dict(dict) => MockStreamEvent {
1028 event_type: dict
1029 .get("event")
1030 .or_else(|| dict.get("type"))
1031 .map(|value| value.display())
1032 .filter(|value| !value.is_empty())
1033 .unwrap_or_else(|| "message".to_string()),
1034 data: dict
1035 .get("data")
1036 .map(|value| value.display())
1037 .unwrap_or_default(),
1038 id: dict
1039 .get("id")
1040 .map(|value| value.display())
1041 .filter(|value| !value.is_empty()),
1042 retry_ms: dict.get("retry_ms").and_then(|value| value.as_int()),
1043 },
1044 _ => MockStreamEvent {
1045 event_type: "message".to_string(),
1046 data: value.display(),
1047 id: None,
1048 retry_ms: None,
1049 },
1050 }
1051}
1052
1053fn parse_mock_stream_events(value: Option<&VmValue>) -> Vec<MockStreamEvent> {
1054 let Some(value) = value else {
1055 return Vec::new();
1056 };
1057 match value {
1058 VmValue::Dict(dict) => dict
1059 .get("events")
1060 .and_then(|events| match events {
1061 VmValue::List(items) => Some(items.iter().map(parse_mock_stream_event).collect()),
1062 _ => None,
1063 })
1064 .unwrap_or_default(),
1065 VmValue::List(items) => items.iter().map(parse_mock_stream_event).collect(),
1066 other => vec![parse_mock_stream_event(other)],
1067 }
1068}
1069
1070fn sse_event_value(event: &MockStreamEvent) -> VmValue {
1071 let mut dict = BTreeMap::new();
1072 dict.insert("type".to_string(), VmValue::String(Rc::from("event")));
1073 dict.insert(
1074 "event".to_string(),
1075 VmValue::String(Rc::from(event.event_type.as_str())),
1076 );
1077 dict.insert(
1078 "data".to_string(),
1079 VmValue::String(Rc::from(event.data.as_str())),
1080 );
1081 dict.insert(
1082 "id".to_string(),
1083 event
1084 .id
1085 .as_deref()
1086 .map(|id| VmValue::String(Rc::from(id)))
1087 .unwrap_or(VmValue::Nil),
1088 );
1089 dict.insert(
1090 "retry_ms".to_string(),
1091 event.retry_ms.map(VmValue::Int).unwrap_or(VmValue::Nil),
1092 );
1093 VmValue::Dict(Rc::new(dict))
1094}
1095
1096fn real_sse_event_value(event: SseEvent) -> VmValue {
1097 match event {
1098 SseEvent::Open => {
1099 let mut dict = BTreeMap::new();
1100 dict.insert("type".to_string(), VmValue::String(Rc::from("open")));
1101 VmValue::Dict(Rc::new(dict))
1102 }
1103 SseEvent::Message(message) => {
1104 let retry_ms = message.retry.map(|retry| retry.as_millis() as i64);
1105 sse_event_value(&MockStreamEvent {
1106 event_type: if message.event.is_empty() {
1107 "message".to_string()
1108 } else {
1109 message.event
1110 },
1111 data: message.data,
1112 id: if message.id.is_empty() {
1113 None
1114 } else {
1115 Some(message.id)
1116 },
1117 retry_ms,
1118 })
1119 }
1120 }
1121}
1122
1123fn consume_sse_mock(url: &str) -> Option<Vec<MockStreamEvent>> {
1124 SSE_MOCKS.with(|mocks| {
1125 mocks
1126 .borrow()
1127 .iter()
1128 .find(|mock| url_matches(&mock.url_pattern, url))
1129 .map(|mock| mock.events.clone())
1130 })
1131}
1132
1133fn parse_ws_message(value: &VmValue) -> MockWsMessage {
1134 match value {
1135 VmValue::Dict(dict) => {
1136 let message_type = dict
1137 .get("type")
1138 .map(|value| value.display())
1139 .filter(|value| !value.is_empty())
1140 .unwrap_or_else(|| "text".to_string());
1141 let data = if dict
1142 .get("base64")
1143 .and_then(|value| match value {
1144 VmValue::Bool(value) => Some(*value),
1145 _ => None,
1146 })
1147 .unwrap_or(false)
1148 {
1149 use base64::Engine;
1150 dict.get("data")
1151 .map(|value| value.display())
1152 .and_then(|data| base64::engine::general_purpose::STANDARD.decode(data).ok())
1153 .unwrap_or_default()
1154 } else {
1155 dict.get("data")
1156 .map(|value| value.display().into_bytes())
1157 .unwrap_or_default()
1158 };
1159 MockWsMessage { message_type, data }
1160 }
1161 VmValue::Bytes(bytes) => MockWsMessage {
1162 message_type: "binary".to_string(),
1163 data: bytes.as_ref().clone(),
1164 },
1165 other => MockWsMessage {
1166 message_type: "text".to_string(),
1167 data: other.display().into_bytes(),
1168 },
1169 }
1170}
1171
1172fn parse_websocket_mock(value: Option<&VmValue>) -> (Vec<MockWsMessage>, bool) {
1173 let Some(value) = value else {
1174 return (Vec::new(), false);
1175 };
1176 match value {
1177 VmValue::Dict(dict) => {
1178 let echo = dict
1179 .get("echo")
1180 .and_then(|value| match value {
1181 VmValue::Bool(value) => Some(*value),
1182 _ => None,
1183 })
1184 .unwrap_or(false);
1185 let messages = dict
1186 .get("messages")
1187 .and_then(|messages| match messages {
1188 VmValue::List(items) => Some(items.iter().map(parse_ws_message).collect()),
1189 _ => None,
1190 })
1191 .unwrap_or_default();
1192 (messages, echo)
1193 }
1194 VmValue::List(items) => (items.iter().map(parse_ws_message).collect(), false),
1195 other => (vec![parse_ws_message(other)], false),
1196 }
1197}
1198
1199fn consume_websocket_mock(url: &str) -> Option<(Vec<MockWsMessage>, bool)> {
1200 WEBSOCKET_MOCKS.with(|mocks| {
1201 mocks
1202 .borrow()
1203 .iter()
1204 .find(|mock| url_matches(&mock.url_pattern, url))
1205 .map(|mock| (mock.messages.clone(), mock.echo))
1206 })
1207}
1208
1209fn ws_message_data(message: &MockWsMessage) -> String {
1210 match message.message_type.as_str() {
1211 "text" => String::from_utf8_lossy(&message.data).into_owned(),
1212 _ => {
1213 use base64::Engine;
1214 base64::engine::general_purpose::STANDARD.encode(&message.data)
1215 }
1216 }
1217}
1218
1219fn ws_event_value(message: MockWsMessage) -> VmValue {
1220 let mut dict = BTreeMap::new();
1221 dict.insert(
1222 "type".to_string(),
1223 VmValue::String(Rc::from(message.message_type.as_str())),
1224 );
1225 match message.message_type.as_str() {
1226 "text" => {
1227 dict.insert(
1228 "data".to_string(),
1229 VmValue::String(Rc::from(String::from_utf8_lossy(&message.data).as_ref())),
1230 );
1231 }
1232 _ => {
1233 use base64::Engine;
1234 dict.insert(
1235 "data_base64".to_string(),
1236 VmValue::String(Rc::from(
1237 base64::engine::general_purpose::STANDARD
1238 .encode(&message.data)
1239 .as_str(),
1240 )),
1241 );
1242 }
1243 }
1244 VmValue::Dict(Rc::new(dict))
1245}
1246
1247fn real_ws_event_value(message: WsMessage) -> VmValue {
1248 match message {
1249 WsMessage::Text(text) => ws_event_value(MockWsMessage {
1250 message_type: "text".to_string(),
1251 data: text.as_bytes().to_vec(),
1252 }),
1253 WsMessage::Binary(bytes) => ws_event_value(MockWsMessage {
1254 message_type: "binary".to_string(),
1255 data: bytes.to_vec(),
1256 }),
1257 WsMessage::Ping(bytes) => ws_event_value(MockWsMessage {
1258 message_type: "ping".to_string(),
1259 data: bytes.to_vec(),
1260 }),
1261 WsMessage::Pong(bytes) => ws_event_value(MockWsMessage {
1262 message_type: "pong".to_string(),
1263 data: bytes.to_vec(),
1264 }),
1265 WsMessage::Close(_) => closed_event(),
1266 WsMessage::Frame(_) => VmValue::Nil,
1267 }
1268}
1269
1270fn transport_mock_call_value(call: &TransportMockCall) -> VmValue {
1271 let mut dict = BTreeMap::new();
1272 dict.insert(
1273 "kind".to_string(),
1274 VmValue::String(Rc::from(call.kind.as_str())),
1275 );
1276 dict.insert(
1277 "url".to_string(),
1278 VmValue::String(Rc::from(call.url.as_str())),
1279 );
1280 dict.insert(
1281 "handle".to_string(),
1282 call.handle
1283 .as_deref()
1284 .map(|handle| VmValue::String(Rc::from(handle)))
1285 .unwrap_or(VmValue::Nil),
1286 );
1287 dict.insert(
1288 "type".to_string(),
1289 call.message_type
1290 .as_deref()
1291 .map(|message_type| VmValue::String(Rc::from(message_type)))
1292 .unwrap_or(VmValue::Nil),
1293 );
1294 dict.insert(
1295 "data".to_string(),
1296 call.data
1297 .as_deref()
1298 .map(|data| VmValue::String(Rc::from(data)))
1299 .unwrap_or(VmValue::Nil),
1300 );
1301 VmValue::Dict(Rc::new(dict))
1302}
1303
1304fn method_is_retryable(retry: &RetryConfig, method: &reqwest::Method) -> bool {
1305 retry
1306 .retryable_methods
1307 .iter()
1308 .any(|candidate| candidate.eq_ignore_ascii_case(method.as_str()))
1309}
1310
1311fn should_retry_response(
1312 config: &HttpRequestConfig,
1313 method: &reqwest::Method,
1314 status: u16,
1315 attempt: u32,
1316) -> bool {
1317 attempt < config.retry.max
1318 && method_is_retryable(&config.retry, method)
1319 && config.retry.retryable_statuses.contains(&status)
1320}
1321
1322fn should_retry_transport(
1323 config: &HttpRequestConfig,
1324 method: &reqwest::Method,
1325 error: &reqwest::Error,
1326 attempt: u32,
1327) -> bool {
1328 attempt < config.retry.max
1329 && method_is_retryable(&config.retry, method)
1330 && (error.is_timeout() || error.is_connect())
1331}
1332
1333fn parse_retry_after_value(value: &str) -> Option<Duration> {
1334 let value = value.trim();
1335 if value.is_empty() {
1336 return None;
1337 }
1338
1339 if let Ok(secs) = value.parse::<f64>() {
1340 if !secs.is_finite() || secs < 0.0 {
1341 return Some(Duration::from_millis(0));
1342 }
1343 let millis = (secs * 1_000.0) as u64;
1344 return Some(Duration::from_millis(millis.min(MAX_RETRY_DELAY_MS)));
1345 }
1346
1347 if let Ok(target) = httpdate::parse_http_date(value) {
1348 let millis = target
1349 .duration_since(SystemTime::now())
1350 .map(|delta| delta.as_millis() as u64)
1351 .unwrap_or(0);
1352 return Some(Duration::from_millis(millis.min(MAX_RETRY_DELAY_MS)));
1353 }
1354
1355 None
1356}
1357
1358fn parse_retry_after_header(value: &reqwest::header::HeaderValue) -> Option<Duration> {
1359 value.to_str().ok().and_then(parse_retry_after_value)
1360}
1361
1362fn mock_retry_after(status: u16, headers: &BTreeMap<String, VmValue>) -> Option<Duration> {
1363 if !(status == 429 || status == 503) {
1364 return None;
1365 }
1366
1367 headers
1368 .iter()
1369 .find(|(name, _)| name.eq_ignore_ascii_case("retry-after"))
1370 .and_then(|(_, value)| parse_retry_after_value(&value.display()))
1371}
1372
1373fn response_retry_after(
1374 status: u16,
1375 headers: &reqwest::header::HeaderMap,
1376 respect_retry_after: bool,
1377) -> Option<Duration> {
1378 if !respect_retry_after || !(status == 429 || status == 503) {
1379 return None;
1380 }
1381 headers
1382 .get(reqwest::header::RETRY_AFTER)
1383 .and_then(parse_retry_after_header)
1384}
1385
1386fn compute_retry_delay(attempt: u32, base_ms: u64, retry_after: Option<Duration>) -> Duration {
1387 use rand::RngExt;
1388
1389 let base_delay = base_ms.saturating_mul(1u64 << attempt.min(30));
1390 let jitter: f64 = rand::rng().random_range(0.75..=1.25);
1391 let exponential_ms = ((base_delay as f64 * jitter) as u64).min(MAX_RETRY_DELAY_MS);
1392 let retry_after_ms = retry_after
1393 .map(|duration| duration.as_millis() as u64)
1394 .unwrap_or(0)
1395 .min(MAX_RETRY_DELAY_MS);
1396 Duration::from_millis(exponential_ms.max(retry_after_ms))
1397}
1398
1399async fn vm_execute_http_request(
1400 method: &str,
1401 url: &str,
1402 options: &BTreeMap<String, VmValue>,
1403) -> Result<VmValue, VmError> {
1404 if let Some(session_id) = session_from_options(options) {
1405 return vm_execute_http_session_request(&session_id, method, url, options).await;
1406 }
1407
1408 let config = parse_http_options(options);
1409 let client = pooled_http_client(&config)?;
1410 vm_execute_http_request_with_client(client, &config, method, url, options).await
1411}
1412
1413async fn vm_execute_http_session_request(
1414 session_id: &str,
1415 method: &str,
1416 url: &str,
1417 options: &BTreeMap<String, VmValue>,
1418) -> Result<VmValue, VmError> {
1419 let session = HTTP_SESSIONS.with(|sessions| sessions.borrow().get(session_id).cloned());
1420 let Some(session) = session else {
1421 return Err(vm_error(format!(
1422 "http_session_request: unknown HTTP session '{session_id}'"
1423 )));
1424 };
1425 let merged_options = merge_options(&session.options, options);
1426 let config = parse_http_options(&merged_options);
1427 vm_execute_http_request_with_client(session.client, &config, method, url, &merged_options).await
1428}
1429
1430async fn vm_execute_http_request_with_client(
1431 client: reqwest::Client,
1432 config: &HttpRequestConfig,
1433 method: &str,
1434 url: &str,
1435 options: &BTreeMap<String, VmValue>,
1436) -> Result<VmValue, VmError> {
1437 let parts = parse_http_request_parts(method, options)?;
1438
1439 for attempt in 0..=config.retry.max {
1440 if let Some(mock_response) = consume_http_mock(
1441 method,
1442 url,
1443 parts.recorded_headers.clone(),
1444 parts.body.clone(),
1445 ) {
1446 let status = mock_response.status.clamp(0, u16::MAX as i64) as u16;
1447 if should_retry_response(config, &parts.method, status, attempt) {
1448 let retry_after = if config.retry.respect_retry_after {
1449 mock_retry_after(status, &mock_response.headers)
1450 } else {
1451 None
1452 };
1453 tokio::time::sleep(compute_retry_delay(
1454 attempt,
1455 config.retry.backoff_ms,
1456 retry_after,
1457 ))
1458 .await;
1459 continue;
1460 }
1461
1462 return Ok(build_http_response(
1463 mock_response.status,
1464 mock_response.headers,
1465 mock_response.body,
1466 ));
1467 }
1468
1469 if !url.starts_with("http://") && !url.starts_with("https://") {
1470 return Err(vm_error(format!(
1471 "http: URL must start with http:// or https://, got '{url}'"
1472 )));
1473 }
1474
1475 let mut req = client.request(parts.method.clone(), url);
1476 req = req
1477 .headers(parts.headers.clone())
1478 .timeout(Duration::from_millis(config.timeout_ms));
1479 if let Some(ref b) = parts.body {
1480 req = req.body(b.clone());
1481 }
1482
1483 match req.send().await {
1484 Ok(response) => {
1485 let status = response.status().as_u16();
1486 if should_retry_response(config, &parts.method, status, attempt) {
1487 let retry_after = response_retry_after(
1488 status,
1489 response.headers(),
1490 config.retry.respect_retry_after,
1491 );
1492 tokio::time::sleep(compute_retry_delay(
1493 attempt,
1494 config.retry.backoff_ms,
1495 retry_after,
1496 ))
1497 .await;
1498 continue;
1499 }
1500
1501 let mut resp_headers = BTreeMap::new();
1502 for (name, value) in response.headers() {
1503 if let Ok(v) = value.to_str() {
1504 resp_headers
1505 .insert(name.as_str().to_string(), VmValue::String(Rc::from(v)));
1506 }
1507 }
1508
1509 let body_text = response
1510 .text()
1511 .await
1512 .map_err(|e| vm_error(format!("http: failed to read response body: {e}")))?;
1513 return Ok(build_http_response(status as i64, resp_headers, body_text));
1514 }
1515 Err(e) => {
1516 if should_retry_transport(config, &parts.method, &e, attempt) {
1517 tokio::time::sleep(compute_retry_delay(attempt, config.retry.backoff_ms, None))
1518 .await;
1519 continue;
1520 }
1521 return Err(vm_error(format!("http: request failed: {e}")));
1522 }
1523 }
1524 }
1525
1526 Err(vm_error("http: request failed"))
1527}
1528
1529async fn vm_sse_connect(
1530 method: &str,
1531 url: &str,
1532 options: &BTreeMap<String, VmValue>,
1533) -> Result<VmValue, VmError> {
1534 let id = next_transport_handle("sse");
1535 let max_events =
1536 transport_limit_option(options, "max_events", DEFAULT_MAX_STREAM_EVENTS).max(1);
1537 let max_message_bytes =
1538 transport_limit_option(options, "max_message_bytes", DEFAULT_MAX_MESSAGE_BYTES).max(1);
1539
1540 if let Some(events) = consume_sse_mock(url) {
1541 let handle = SseHandle {
1542 kind: SseHandleKind::Fake(Rc::new(tokio::sync::Mutex::new(FakeSseStream {
1543 events: events.into(),
1544 opened: false,
1545 closed: false,
1546 }))),
1547 url: url.to_string(),
1548 max_events,
1549 max_message_bytes,
1550 received: 0,
1551 };
1552 SSE_HANDLES.with(|handles| {
1553 let mut handles = handles.borrow_mut();
1554 if handles.len() >= MAX_SSE_STREAMS {
1555 return Err(vm_error(format!(
1556 "sse_connect: maximum open streams ({MAX_SSE_STREAMS}) reached"
1557 )));
1558 }
1559 handles.insert(id.clone(), handle);
1560 Ok(())
1561 })?;
1562 record_transport_call(TransportMockCall {
1563 kind: "sse_connect".to_string(),
1564 handle: Some(id.clone()),
1565 url: url.to_string(),
1566 message_type: None,
1567 data: None,
1568 });
1569 return Ok(VmValue::String(Rc::from(id)));
1570 }
1571
1572 if !url.starts_with("http://") && !url.starts_with("https://") {
1573 return Err(vm_error(format!(
1574 "sse_connect: URL must start with http:// or https://, got '{url}'"
1575 )));
1576 }
1577
1578 let config = parse_http_options(options);
1579 let client = if let Some(session_id) = session_from_options(options) {
1580 let session = HTTP_SESSIONS.with(|sessions| sessions.borrow().get(&session_id).cloned());
1581 session
1582 .map(|session| session.client)
1583 .ok_or_else(|| vm_error(format!("sse_connect: unknown HTTP session '{session_id}'")))?
1584 } else {
1585 pooled_http_client(&config)?
1586 };
1587 let parts = parse_http_request_parts(method, options)?;
1588 let mut request = client
1589 .request(parts.method, url)
1590 .headers(parts.headers)
1591 .timeout(Duration::from_millis(config.timeout_ms));
1592 if let Some(body) = parts.body {
1593 request = request.body(body);
1594 }
1595 let stream = EventSource::new(request)
1596 .map_err(|error| vm_error(format!("sse_connect: failed to create stream: {error}")))?;
1597 let handle = SseHandle {
1598 kind: SseHandleKind::Real(Rc::new(tokio::sync::Mutex::new(stream))),
1599 url: url.to_string(),
1600 max_events,
1601 max_message_bytes,
1602 received: 0,
1603 };
1604 SSE_HANDLES.with(|handles| {
1605 let mut handles = handles.borrow_mut();
1606 if handles.len() >= MAX_SSE_STREAMS {
1607 return Err(vm_error(format!(
1608 "sse_connect: maximum open streams ({MAX_SSE_STREAMS}) reached"
1609 )));
1610 }
1611 handles.insert(id.clone(), handle);
1612 Ok(())
1613 })?;
1614 Ok(VmValue::String(Rc::from(id)))
1615}
1616
1617async fn vm_sse_receive(stream_id: &str, timeout_ms: u64) -> Result<VmValue, VmError> {
1618 let stream = SSE_HANDLES.with(|handles| {
1619 let mut handles = handles.borrow_mut();
1620 let handle = handles.get_mut(stream_id)?;
1621 if handle.received >= handle.max_events {
1622 return Some(Err(vm_error(format!(
1623 "sse_receive: stream '{stream_id}' exceeded max_events"
1624 ))));
1625 }
1626 handle.received += 1;
1627 let url = handle.url.clone();
1628 let max_message_bytes = handle.max_message_bytes;
1629 let kind = match &handle.kind {
1630 SseHandleKind::Real(stream) => SseHandleKind::Real(stream.clone()),
1631 SseHandleKind::Fake(stream) => SseHandleKind::Fake(stream.clone()),
1632 };
1633 Some(Ok((kind, url, max_message_bytes)))
1634 });
1635 let Some(stream) = stream else {
1636 return Err(vm_error(format!(
1637 "sse_receive: unknown stream '{stream_id}'"
1638 )));
1639 };
1640 let (kind, _url, max_message_bytes) = stream?;
1641
1642 match kind {
1643 SseHandleKind::Fake(stream) => {
1644 let mut stream = stream.lock().await;
1645 if stream.closed {
1646 return Ok(VmValue::Nil);
1647 }
1648 if !stream.opened {
1649 stream.opened = true;
1650 let mut dict = BTreeMap::new();
1651 dict.insert("type".to_string(), VmValue::String(Rc::from("open")));
1652 return Ok(VmValue::Dict(Rc::new(dict)));
1653 }
1654 let Some(event) = stream.events.pop_front() else {
1655 stream.closed = true;
1656 return Ok(VmValue::Nil);
1657 };
1658 if event.data.len() > max_message_bytes {
1659 return Err(vm_error(format!(
1660 "sse_receive: message exceeded max_message_bytes ({max_message_bytes})"
1661 )));
1662 }
1663 Ok(sse_event_value(&event))
1664 }
1665 SseHandleKind::Real(stream) => {
1666 let mut stream = stream.lock().await;
1667 let next = stream.next();
1668 let event = match tokio::time::timeout(Duration::from_millis(timeout_ms), next).await {
1669 Ok(Some(Ok(event))) => event,
1670 Ok(Some(Err(error))) => {
1671 return Err(vm_error(format!("sse_receive: stream error: {error}")));
1672 }
1673 Ok(None) => return Ok(VmValue::Nil),
1674 Err(_) => return Ok(timeout_event()),
1675 };
1676 if let SseEvent::Message(message) = &event {
1677 if message.data.len() > max_message_bytes {
1678 stream.close();
1679 return Err(vm_error(format!(
1680 "sse_receive: message exceeded max_message_bytes ({max_message_bytes})"
1681 )));
1682 }
1683 }
1684 Ok(real_sse_event_value(event))
1685 }
1686 }
1687}
1688
1689async fn vm_websocket_connect(
1690 url: &str,
1691 options: &BTreeMap<String, VmValue>,
1692) -> Result<VmValue, VmError> {
1693 let id = next_transport_handle("websocket");
1694 let max_messages =
1695 transport_limit_option(options, "max_messages", DEFAULT_MAX_STREAM_EVENTS).max(1);
1696 let max_message_bytes =
1697 transport_limit_option(options, "max_message_bytes", DEFAULT_MAX_MESSAGE_BYTES).max(1);
1698
1699 if let Some((messages, echo)) = consume_websocket_mock(url) {
1700 let handle = WebSocketHandle {
1701 kind: WebSocketHandleKind::Fake(Rc::new(tokio::sync::Mutex::new(FakeWebSocket {
1702 messages: messages.into(),
1703 echo,
1704 closed: false,
1705 }))),
1706 url: url.to_string(),
1707 max_messages,
1708 max_message_bytes,
1709 received: 0,
1710 };
1711 WEBSOCKET_HANDLES.with(|handles| {
1712 let mut handles = handles.borrow_mut();
1713 if handles.len() >= MAX_WEBSOCKETS {
1714 return Err(vm_error(format!(
1715 "websocket_connect: maximum open sockets ({MAX_WEBSOCKETS}) reached"
1716 )));
1717 }
1718 handles.insert(id.clone(), handle);
1719 Ok(())
1720 })?;
1721 record_transport_call(TransportMockCall {
1722 kind: "websocket_connect".to_string(),
1723 handle: Some(id.clone()),
1724 url: url.to_string(),
1725 message_type: None,
1726 data: None,
1727 });
1728 return Ok(VmValue::String(Rc::from(id)));
1729 }
1730
1731 if !url.starts_with("ws://") && !url.starts_with("wss://") {
1732 return Err(vm_error(format!(
1733 "websocket_connect: URL must start with ws:// or wss://, got '{url}'"
1734 )));
1735 }
1736 let timeout_ms = vm_get_int_option_prefer(
1737 options,
1738 "timeout_ms",
1739 "timeout",
1740 DEFAULT_TIMEOUT_MS as i64,
1741 )
1742 .max(0) as u64;
1743 let connect = tokio_tungstenite::connect_async(url);
1744 let (socket, _) = tokio::time::timeout(Duration::from_millis(timeout_ms), connect)
1745 .await
1746 .map_err(|_| vm_error(format!("websocket_connect: timed out after {timeout_ms}ms")))?
1747 .map_err(|error| vm_error(format!("websocket_connect: failed: {error}")))?;
1748 let handle = WebSocketHandle {
1749 kind: WebSocketHandleKind::Real(Rc::new(tokio::sync::Mutex::new(socket))),
1750 url: url.to_string(),
1751 max_messages,
1752 max_message_bytes,
1753 received: 0,
1754 };
1755 WEBSOCKET_HANDLES.with(|handles| {
1756 let mut handles = handles.borrow_mut();
1757 if handles.len() >= MAX_WEBSOCKETS {
1758 return Err(vm_error(format!(
1759 "websocket_connect: maximum open sockets ({MAX_WEBSOCKETS}) reached"
1760 )));
1761 }
1762 handles.insert(id.clone(), handle);
1763 Ok(())
1764 })?;
1765 Ok(VmValue::String(Rc::from(id)))
1766}
1767
1768fn websocket_message_from_vm(
1769 value: VmValue,
1770 options: &BTreeMap<String, VmValue>,
1771) -> Result<MockWsMessage, VmError> {
1772 let message_type = options
1773 .get("type")
1774 .map(|value| value.display())
1775 .filter(|value| !value.is_empty())
1776 .unwrap_or_else(|| match value {
1777 VmValue::Bytes(_) => "binary".to_string(),
1778 _ => "text".to_string(),
1779 });
1780 let data = match value {
1781 VmValue::Bytes(bytes) => bytes.as_ref().clone(),
1782 other
1783 if options
1784 .get("base64")
1785 .and_then(|value| match value {
1786 VmValue::Bool(value) => Some(*value),
1787 _ => None,
1788 })
1789 .unwrap_or(false) =>
1790 {
1791 use base64::Engine;
1792 base64::engine::general_purpose::STANDARD
1793 .decode(other.display())
1794 .map_err(|error| vm_error(format!("websocket_send: invalid base64: {error}")))?
1795 }
1796 other => other.display().into_bytes(),
1797 };
1798 Ok(MockWsMessage { message_type, data })
1799}
1800
1801fn real_ws_message(message: &MockWsMessage) -> Result<WsMessage, VmError> {
1802 match message.message_type.as_str() {
1803 "text" => Ok(WsMessage::Text(
1804 String::from_utf8(message.data.clone())
1805 .map_err(|error| vm_error(format!("websocket_send: text is not UTF-8: {error}")))?
1806 .into(),
1807 )),
1808 "binary" => Ok(WsMessage::Binary(message.data.clone().into())),
1809 "ping" => Ok(WsMessage::Ping(message.data.clone().into())),
1810 "pong" => Ok(WsMessage::Pong(message.data.clone().into())),
1811 "close" => Ok(WsMessage::Close(None)),
1812 other => Err(vm_error(format!(
1813 "websocket_send: unsupported message type '{other}'"
1814 ))),
1815 }
1816}
1817
1818async fn vm_websocket_send(
1819 socket_id: &str,
1820 value: VmValue,
1821 options: &BTreeMap<String, VmValue>,
1822) -> Result<VmValue, VmError> {
1823 let message = websocket_message_from_vm(value, options)?;
1824 let socket = WEBSOCKET_HANDLES.with(|handles| {
1825 let handles = handles.borrow();
1826 let handle = handles.get(socket_id)?;
1827 let url = handle.url.clone();
1828 let max_message_bytes = handle.max_message_bytes;
1829 let kind = match &handle.kind {
1830 WebSocketHandleKind::Real(socket) => WebSocketHandleKind::Real(socket.clone()),
1831 WebSocketHandleKind::Fake(socket) => WebSocketHandleKind::Fake(socket.clone()),
1832 };
1833 Some((kind, url, max_message_bytes))
1834 });
1835 let Some((kind, url, max_message_bytes)) = socket else {
1836 return Err(vm_error(format!(
1837 "websocket_send: unknown socket '{socket_id}'"
1838 )));
1839 };
1840 if message.data.len() > max_message_bytes {
1841 return Err(vm_error(format!(
1842 "websocket_send: message exceeded max_message_bytes ({max_message_bytes})"
1843 )));
1844 }
1845 match kind {
1846 WebSocketHandleKind::Fake(socket) => {
1847 let mut socket = socket.lock().await;
1848 if socket.closed {
1849 return Ok(VmValue::Bool(false));
1850 }
1851 if message.message_type == "close" {
1852 socket.closed = true;
1853 } else if socket.echo {
1854 socket.messages.push_back(message.clone());
1855 }
1856 record_transport_call(TransportMockCall {
1857 kind: "websocket_send".to_string(),
1858 handle: Some(socket_id.to_string()),
1859 url,
1860 message_type: Some(message.message_type.clone()),
1861 data: Some(ws_message_data(&message)),
1862 });
1863 Ok(VmValue::Bool(true))
1864 }
1865 WebSocketHandleKind::Real(socket) => {
1866 let mut socket = socket.lock().await;
1867 socket
1868 .send(real_ws_message(&message)?)
1869 .await
1870 .map_err(|error| vm_error(format!("websocket_send: failed: {error}")))?;
1871 Ok(VmValue::Bool(true))
1872 }
1873 }
1874}
1875
1876async fn vm_websocket_receive(socket_id: &str, timeout_ms: u64) -> Result<VmValue, VmError> {
1877 let socket = WEBSOCKET_HANDLES.with(|handles| {
1878 let mut handles = handles.borrow_mut();
1879 let handle = handles.get_mut(socket_id)?;
1880 if handle.received >= handle.max_messages {
1881 return Some(Err(vm_error(format!(
1882 "websocket_receive: socket '{socket_id}' exceeded max_messages"
1883 ))));
1884 }
1885 handle.received += 1;
1886 let max_message_bytes = handle.max_message_bytes;
1887 let kind = match &handle.kind {
1888 WebSocketHandleKind::Real(socket) => WebSocketHandleKind::Real(socket.clone()),
1889 WebSocketHandleKind::Fake(socket) => WebSocketHandleKind::Fake(socket.clone()),
1890 };
1891 Some(Ok((kind, max_message_bytes)))
1892 });
1893 let Some(socket) = socket else {
1894 return Err(vm_error(format!(
1895 "websocket_receive: unknown socket '{socket_id}'"
1896 )));
1897 };
1898 let (kind, max_message_bytes) = socket?;
1899 match kind {
1900 WebSocketHandleKind::Fake(socket) => {
1901 let mut socket = socket.lock().await;
1902 if socket.closed {
1903 return Ok(VmValue::Nil);
1904 }
1905 let Some(message) = socket.messages.pop_front() else {
1906 return Ok(timeout_event());
1907 };
1908 if message.data.len() > max_message_bytes {
1909 socket.closed = true;
1910 return Err(vm_error(format!(
1911 "websocket_receive: message exceeded max_message_bytes ({max_message_bytes})"
1912 )));
1913 }
1914 if message.message_type == "close" {
1915 socket.closed = true;
1916 }
1917 Ok(ws_event_value(message))
1918 }
1919 WebSocketHandleKind::Real(socket) => {
1920 let mut socket = socket.lock().await;
1921 let next = socket.next();
1922 let message = match tokio::time::timeout(Duration::from_millis(timeout_ms), next).await
1923 {
1924 Ok(Some(Ok(message))) => message,
1925 Ok(Some(Err(error))) => {
1926 return Err(vm_error(format!("websocket_receive: failed: {error}")));
1927 }
1928 Ok(None) => return Ok(VmValue::Nil),
1929 Err(_) => return Ok(timeout_event()),
1930 };
1931 match &message {
1932 WsMessage::Text(text) if text.len() > max_message_bytes => {
1933 return Err(vm_error(format!(
1934 "websocket_receive: message exceeded max_message_bytes ({max_message_bytes})"
1935 )));
1936 }
1937 WsMessage::Binary(bytes) | WsMessage::Ping(bytes) | WsMessage::Pong(bytes)
1938 if bytes.len() > max_message_bytes =>
1939 {
1940 return Err(vm_error(format!(
1941 "websocket_receive: message exceeded max_message_bytes ({max_message_bytes})"
1942 )));
1943 }
1944 _ => {}
1945 }
1946 Ok(real_ws_event_value(message))
1947 }
1948 }
1949}
1950
1951async fn vm_websocket_close(socket_id: &str) -> Result<VmValue, VmError> {
1952 let removed = WEBSOCKET_HANDLES.with(|handles| handles.borrow_mut().remove(socket_id));
1953 let Some(handle) = removed else {
1954 return Ok(VmValue::Bool(false));
1955 };
1956 match handle.kind {
1957 WebSocketHandleKind::Fake(socket) => {
1958 socket.lock().await.closed = true;
1959 record_transport_call(TransportMockCall {
1960 kind: "websocket_close".to_string(),
1961 handle: Some(socket_id.to_string()),
1962 url: handle.url,
1963 message_type: None,
1964 data: None,
1965 });
1966 Ok(VmValue::Bool(true))
1967 }
1968 WebSocketHandleKind::Real(socket) => {
1969 let mut socket = socket.lock().await;
1970 socket
1971 .close(None)
1972 .await
1973 .map_err(|error| vm_error(format!("websocket_close: failed: {error}")))?;
1974 Ok(VmValue::Bool(true))
1975 }
1976 }
1977}
1978
1979#[cfg(test)]
1980mod tests {
1981 use super::{
1982 compute_retry_delay, http_mock_calls_snapshot, parse_retry_after_value, push_http_mock,
1983 reset_http_state, vm_execute_http_request, HttpMockResponse,
1984 };
1985 use crate::value::VmValue;
1986 use std::collections::BTreeMap;
1987 use std::time::{Duration, SystemTime};
1988
1989 #[test]
1990 fn parses_retry_after_delta_seconds() {
1991 assert_eq!(parse_retry_after_value("5"), Some(Duration::from_secs(5)));
1992 }
1993
1994 #[test]
1995 fn parses_retry_after_http_date() {
1996 let header = httpdate::fmt_http_date(SystemTime::now() + Duration::from_secs(2));
1997 let parsed = parse_retry_after_value(&header).expect("http-date should parse");
1998 assert!(parsed <= Duration::from_secs(2));
1999 assert!(parsed <= Duration::from_secs(60));
2000 }
2001
2002 #[test]
2003 fn malformed_retry_after_returns_none() {
2004 assert_eq!(parse_retry_after_value("soon-ish"), None);
2005 }
2006
2007 #[test]
2008 fn retry_delay_honors_retry_after_floor() {
2009 let delay = compute_retry_delay(0, 1, Some(Duration::from_millis(250)));
2010 assert!(delay >= Duration::from_millis(250));
2011 assert!(delay <= Duration::from_secs(60));
2012 }
2013
2014 #[tokio::test]
2015 async fn typed_mock_api_drives_http_request_retries() {
2016 reset_http_state();
2017 push_http_mock(
2018 "GET",
2019 "https://api.example.com/retry",
2020 vec![
2021 HttpMockResponse::new(503, "busy").with_header("retry-after", "0"),
2022 HttpMockResponse::new(200, "ok"),
2023 ],
2024 );
2025 let result = vm_execute_http_request(
2026 "GET",
2027 "https://api.example.com/retry",
2028 &BTreeMap::from([
2029 ("retries".to_string(), VmValue::Int(1)),
2030 ("backoff".to_string(), VmValue::Int(0)),
2031 ]),
2032 )
2033 .await
2034 .expect("mocked request should succeed after retry");
2035
2036 let dict = result.as_dict().expect("response dict");
2037 assert_eq!(dict["status"].as_int(), Some(200));
2038 let calls = http_mock_calls_snapshot();
2039 assert_eq!(calls.len(), 2);
2040 assert_eq!(calls[0].url, "https://api.example.com/retry");
2041 reset_http_state();
2042 }
2043}