1use std::cell::RefCell;
2use std::collections::{BTreeMap, HashMap, VecDeque};
3use std::rc::Rc;
4use std::time::{Duration, SystemTime};
5
6use crate::value::{VmError, VmValue};
7use crate::vm::Vm;
8
9use futures::{SinkExt, StreamExt};
10use reqwest_eventsource::{Event as SseEvent, EventSource};
11use tokio_tungstenite::tungstenite::Message as WsMessage;
12
/// Internal, VM-facing form of one scripted HTTP mock response.
///
/// Built either from a script-side dict (`parse_mock_response_dict`) or from the public
/// [`HttpMockResponse`] via its `From` impl.
#[derive(Clone)]
struct MockResponse {
    /// Status code reported for this response.
    status: i64,
    /// Response body text.
    body: String,
    /// Response headers keyed by header name; values are VM values.
    headers: BTreeMap<String, VmValue>,
}
21
/// Public, stdlib-only description of a mocked HTTP response, used by `push_http_mock`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpMockResponse {
    pub status: i64,
    pub body: String,
    pub headers: BTreeMap<String, String>,
}

impl HttpMockResponse {
    /// Creates a response with the given status and body and no headers.
    pub fn new(status: i64, body: impl Into<String>) -> Self {
        let body: String = body.into();
        Self {
            headers: BTreeMap::new(),
            body,
            status,
        }
    }

    /// Returns `self` with header `name` set to `value` (replacing any earlier value).
    pub fn with_header(self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let mut headers = self.headers;
        headers.insert(name.into(), value.into());
        Self { headers, ..self }
    }
}
43
44impl From<HttpMockResponse> for MockResponse {
45 fn from(value: HttpMockResponse) -> Self {
46 Self {
47 status: value.status,
48 body: value.body,
49 headers: value
50 .headers
51 .into_iter()
52 .map(|(key, value)| (key, VmValue::String(Rc::from(value))))
53 .collect(),
54 }
55 }
56}
57
/// A registered HTTP mock. `consume_http_mock` answers a request from the first mock whose
/// method and URL pattern match it.
struct HttpMock {
    /// Method compared case-insensitively, or `"*"` to match any method.
    method: String,
    /// URL pattern where `*` matches any run of characters (see `url_matches`).
    url_pattern: String,
    /// Scripted responses served in order; once the last one is reached it keeps repeating.
    responses: Vec<MockResponse>,
    /// Index of the next response to serve; never advanced past the last index.
    next_response: usize,
}

/// One request answered by a mock, appended to `HTTP_MOCK_CALLS` by `consume_http_mock`.
#[derive(Clone)]
struct HttpMockCall {
    method: String,
    url: String,
    /// Request headers as passed to `consume_http_mock` (presumably the `recorded_headers`
    /// from `parse_http_request_parts` — confirm at the call site).
    headers: BTreeMap<String, VmValue>,
    /// Request body, if one was sent.
    body: Option<String>,
}

/// Stdlib-only copy of an [`HttpMockCall`] returned by `http_mock_calls_snapshot`; header
/// values are rendered with `VmValue::display`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpMockCallSnapshot {
    pub method: String,
    pub url: String,
    pub headers: BTreeMap<String, String>,
    pub body: Option<String>,
}
80
/// Retry policy parsed from request options by `parse_http_options`.
#[derive(Clone)]
struct RetryConfig {
    /// From `retry.max`, else top-level `retries` (default 0). Presumably the maximum
    /// number of retries after the first attempt — confirm in the request loop.
    max: u32,
    /// From `retry.backoff_ms`, else `backoff` (default `DEFAULT_BACKOFF_MS`).
    backoff_ms: u64,
    /// Status codes eligible for retry (`retry_on`, see `parse_retry_statuses`).
    retryable_statuses: Vec<u16>,
    /// Upper-cased method names eligible for retry (`retry_methods`).
    retryable_methods: Vec<String>,
    /// `respect_retry_after` option (default true).
    respect_retry_after: bool,
}

/// Per-request configuration produced by `parse_http_options`.
#[derive(Clone)]
struct HttpRequestConfig {
    /// `timeout_ms` (preferred) or `timeout`, default `DEFAULT_TIMEOUT_MS`.
    timeout_ms: u64,
    retry: RetryConfig,
    /// When false, `build_http_client` disables redirects entirely.
    follow_redirects: bool,
    /// Redirect limit used when `follow_redirects` is true (default 10).
    max_redirects: usize,
}

/// A script-created session: a dedicated client plus the options given to `http_session`.
#[derive(Clone)]
struct HttpSession {
    client: reqwest::Client,
    /// Session-level options (presumably merged with per-request options via
    /// `merge_options` — confirm in `vm_execute_http_session_request`).
    options: BTreeMap<String, VmValue>,
}

/// Request pieces derived from options by `parse_http_request_parts`.
struct HttpRequestParts {
    method: reqwest::Method,
    /// Headers that will be sent.
    headers: reqwest::header::HeaderMap,
    /// String copies of the headers that were set, keyed as the script wrote them.
    recorded_headers: BTreeMap<String, VmValue>,
    /// The `body` option rendered with `display()`, if present.
    body: Option<String>,
}
110
/// A registered SSE mock: connections whose URL matches `url_pattern` replay `events`.
struct SseMock {
    url_pattern: String,
    events: Vec<MockStreamEvent>,
}

/// One server-sent event, either scripted for a mock or converted from a real stream.
#[derive(Clone)]
struct MockStreamEvent {
    /// Event name; defaults to `"message"` when empty or absent.
    event_type: String,
    data: String,
    /// Event id, `None` when absent or empty.
    id: Option<String>,
    /// Reconnection delay in milliseconds, if the event carried one.
    retry_ms: Option<i64>,
}

/// An open SSE stream tracked in `SSE_HANDLES` under its handle id.
struct SseHandle {
    kind: SseHandleKind,
    url: String,
    /// Event and message-size limits; presumably enforced by the receive path outside this
    /// view — TODO confirm.
    max_events: usize,
    max_message_bytes: usize,
    /// Count of events delivered so far (presumably compared with `max_events`).
    received: usize,
}

/// Backing implementation of an SSE handle.
enum SseHandleKind {
    /// A live `EventSource`; closed via `close()` in `sse_close` / `reset_http_state`.
    Real(Rc<tokio::sync::Mutex<EventSource>>),
    /// A mock-backed stream replaying scripted events.
    Fake(Rc<tokio::sync::Mutex<FakeSseStream>>),
}

/// State of a mock-backed SSE stream.
struct FakeSseStream {
    /// Events not yet delivered.
    events: VecDeque<MockStreamEvent>,
    /// Presumably whether an `open` event has been emitted yet — confirm in receive code.
    opened: bool,
    closed: bool,
}
142
/// A registered WebSocket mock: connections whose URL matches `url_pattern` receive
/// `messages`, and, when `echo` is set, presumably get sent messages echoed back.
struct WebSocketMock {
    url_pattern: String,
    messages: Vec<MockWsMessage>,
    echo: bool,
}

/// A scripted WebSocket message (see `parse_ws_message`).
#[derive(Clone)]
struct MockWsMessage {
    /// `"text"` by default, `"binary"` for byte values, or a script-provided type.
    message_type: String,
    /// Raw payload bytes (base64-decoded when the mock dict set `base64: true`).
    data: Vec<u8>,
}

/// A live WebSocket connection, possibly over TLS.
type RealWebSocket =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

/// An open WebSocket tracked in `WEBSOCKET_HANDLES` under its handle id.
struct WebSocketHandle {
    kind: WebSocketHandleKind,
    url: String,
    /// Message-count and message-size limits; presumably enforced by the receive path
    /// outside this view — TODO confirm.
    max_messages: usize,
    max_message_bytes: usize,
    /// Count of messages delivered so far.
    received: usize,
}

/// Backing implementation of a WebSocket handle.
enum WebSocketHandleKind {
    Real(Rc<tokio::sync::Mutex<RealWebSocket>>),
    Fake(Rc<tokio::sync::Mutex<FakeWebSocket>>),
}

/// State of a mock-backed WebSocket.
struct FakeWebSocket {
    /// Messages not yet delivered.
    messages: VecDeque<MockWsMessage>,
    echo: bool,
    closed: bool,
}

/// One transport (SSE/WebSocket) operation recorded in `TRANSPORT_MOCK_CALLS`.
/// NOTE(review): the exact `kind` strings are set outside this view — confirm there.
#[derive(Clone)]
struct TransportMockCall {
    kind: String,
    handle: Option<String>,
    url: String,
    message_type: Option<String>,
    data: Option<String>,
}
185
/// Request timeout (ms) used when neither `timeout_ms` nor `timeout` is supplied.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Retry backoff (ms) used when neither `retry.backoff_ms` nor `backoff` is supplied.
const DEFAULT_BACKOFF_MS: u64 = 1_000;
/// Not referenced in this part of the file; presumably caps a single retry delay
/// (backoff or `Retry-After`) — TODO confirm.
const MAX_RETRY_DELAY_MS: u64 = 60_000;
/// Statuses retried when `retry_on` is absent, not a list, or yields no valid codes.
const DEFAULT_RETRYABLE_STATUSES: [u16; 6] = [408, 429, 500, 502, 503, 504];
/// Methods retried when `retry_methods` is absent, not a list, or empty; POST and PATCH
/// are not included.
const DEFAULT_RETRYABLE_METHODS: [&str; 5] = ["GET", "HEAD", "PUT", "DELETE", "OPTIONS"];
/// Receive timeout (ms) for `sse_receive` / `websocket_receive` when no usable timeout
/// argument is given (see `receive_timeout_arg`).
const DEFAULT_TRANSPORT_RECEIVE_TIMEOUT_MS: u64 = 30_000;
/// Presumably the default `max_events` / `max_messages` for stream handles — confirm where
/// `transport_limit_option` is called.
const DEFAULT_MAX_STREAM_EVENTS: usize = 10_000;
/// Presumably the default `max_message_bytes` for stream handles (1 MiB) — confirm.
const DEFAULT_MAX_MESSAGE_BYTES: usize = 1024 * 1024;
/// Maximum number of simultaneously open `http_session` handles.
const MAX_HTTP_SESSIONS: usize = 64;
/// Presumably the cap on open SSE handles, enforced outside this view.
const MAX_SSE_STREAMS: usize = 64;
/// Presumably the cap on open WebSocket handles, enforced outside this view.
const MAX_WEBSOCKETS: usize = 64;
197
// Per-thread transport state. Everything lives in `RefCell`s inside `thread_local!`, so the
// state is not shared across threads; `reset_http_state` clears all of it.
thread_local! {
    // Registered HTTP mocks, searched in insertion order.
    static HTTP_MOCKS: RefCell<Vec<HttpMock>> = const { RefCell::new(Vec::new()) };
    // Log of requests answered by HTTP mocks.
    static HTTP_MOCK_CALLS: RefCell<Vec<HttpMockCall>> = const { RefCell::new(Vec::new()) };
    // Pooled clients keyed by `http_client_key` (redirect settings).
    static HTTP_CLIENTS: RefCell<HashMap<String, reqwest::Client>> = RefCell::new(HashMap::new());
    // Open `http_session` handles by id.
    static HTTP_SESSIONS: RefCell<HashMap<String, HttpSession>> = RefCell::new(HashMap::new());
    // Registered SSE mocks.
    static SSE_MOCKS: RefCell<Vec<SseMock>> = const { RefCell::new(Vec::new()) };
    // Open SSE streams by handle id.
    static SSE_HANDLES: RefCell<HashMap<String, SseHandle>> = RefCell::new(HashMap::new());
    // Registered WebSocket mocks.
    static WEBSOCKET_MOCKS: RefCell<Vec<WebSocketMock>> = const { RefCell::new(Vec::new()) };
    // Open WebSockets by handle id.
    static WEBSOCKET_HANDLES: RefCell<HashMap<String, WebSocketHandle>> = RefCell::new(HashMap::new());
    // Log of SSE/WebSocket operations (see `record_transport_call`).
    static TRANSPORT_MOCK_CALLS: RefCell<Vec<TransportMockCall>> = const { RefCell::new(Vec::new()) };
    // Monotonic counter used by `next_transport_handle` to mint handle ids.
    static TRANSPORT_HANDLE_COUNTER: RefCell<u64> = const { RefCell::new(0) };
}
210
211pub fn reset_http_state() {
213 HTTP_MOCKS.with(|m| m.borrow_mut().clear());
214 HTTP_MOCK_CALLS.with(|c| c.borrow_mut().clear());
215 HTTP_CLIENTS.with(|clients| clients.borrow_mut().clear());
216 HTTP_SESSIONS.with(|sessions| sessions.borrow_mut().clear());
217 SSE_MOCKS.with(|mocks| mocks.borrow_mut().clear());
218 SSE_HANDLES.with(|handles| {
219 for handle in handles.borrow_mut().values_mut() {
220 if let SseHandleKind::Real(stream) = &handle.kind {
221 if let Ok(mut stream) = stream.try_lock() {
222 stream.close();
223 }
224 }
225 }
226 handles.borrow_mut().clear();
227 });
228 WEBSOCKET_MOCKS.with(|mocks| mocks.borrow_mut().clear());
229 WEBSOCKET_HANDLES.with(|handles| handles.borrow_mut().clear());
230 TRANSPORT_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
231 TRANSPORT_HANDLE_COUNTER.with(|counter| *counter.borrow_mut() = 0);
232}
233
234pub fn push_http_mock(
235 method: impl Into<String>,
236 url_pattern: impl Into<String>,
237 responses: Vec<HttpMockResponse>,
238) {
239 let responses = if responses.is_empty() {
240 vec![MockResponse::from(HttpMockResponse::new(200, ""))]
241 } else {
242 responses.into_iter().map(MockResponse::from).collect()
243 };
244 HTTP_MOCKS.with(|mocks| {
245 mocks.borrow_mut().push(HttpMock {
246 method: method.into(),
247 url_pattern: url_pattern.into(),
248 responses,
249 next_response: 0,
250 });
251 });
252}
253
254pub fn http_mock_calls_snapshot() -> Vec<HttpMockCallSnapshot> {
255 HTTP_MOCK_CALLS.with(|calls| {
256 calls
257 .borrow()
258 .iter()
259 .map(|call| HttpMockCallSnapshot {
260 method: call.method.clone(),
261 url: call.url.clone(),
262 headers: call
263 .headers
264 .iter()
265 .map(|(key, value)| (key.clone(), value.display()))
266 .collect(),
267 body: call.body.clone(),
268 })
269 .collect()
270 })
271}
272
/// Glob-style URL match in which each `*` matches any (possibly empty) run of characters.
///
/// A pattern without `*` must equal `url` exactly, and `"*"` alone matches everything. The
/// first segment anchors at the start and the last at the end. Middle segments are found
/// left to right at their earliest occurrence.
fn url_matches(pattern: &str, url: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == url;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    let last = segments.len() - 1;
    let mut rest = url;
    for (index, segment) in segments.into_iter().enumerate() {
        if segment.is_empty() {
            continue;
        }
        rest = if index == 0 {
            let Some(tail) = rest.strip_prefix(segment) else {
                return false;
            };
            tail
        } else if index == last {
            if !rest.ends_with(segment) {
                return false;
            }
            ""
        } else {
            let Some(at) = rest.find(segment) else {
                return false;
            };
            &rest[at + segment.len()..]
        };
    }
    true
}
307
308fn build_http_response(status: i64, headers: BTreeMap<String, VmValue>, body: String) -> VmValue {
310 let mut result = BTreeMap::new();
311 result.insert("status".to_string(), VmValue::Int(status));
312 result.insert("headers".to_string(), VmValue::Dict(Rc::new(headers)));
313 result.insert("body".to_string(), VmValue::String(Rc::from(body)));
314 result.insert(
315 "ok".to_string(),
316 VmValue::Bool((200..300).contains(&(status as u16))),
317 );
318 VmValue::Dict(Rc::new(result))
319}
320
321fn vm_error(message: impl Into<String>) -> VmError {
322 VmError::Thrown(VmValue::String(Rc::from(message.into())))
323}
324
325fn next_transport_handle(prefix: &str) -> String {
326 TRANSPORT_HANDLE_COUNTER.with(|counter| {
327 let mut counter = counter.borrow_mut();
328 *counter += 1;
329 format!("{prefix}-{}", *counter)
330 })
331}
332
333fn handle_from_value(value: &VmValue, builtin: &str) -> Result<String, VmError> {
334 match value {
335 VmValue::String(handle) => Ok(handle.to_string()),
336 VmValue::Dict(dict) => dict
337 .get("id")
338 .map(|id| id.display())
339 .filter(|id| !id.is_empty())
340 .ok_or_else(|| vm_error(format!("{builtin}: handle dict must contain id"))),
341 _ => Err(vm_error(format!(
342 "{builtin}: first argument must be a handle string or dict"
343 ))),
344 }
345}
346
347fn get_options_arg(args: &[VmValue], index: usize) -> BTreeMap<String, VmValue> {
348 args.get(index)
349 .and_then(|value| value.as_dict())
350 .cloned()
351 .unwrap_or_default()
352}
353
354fn merge_options(
355 base: &BTreeMap<String, VmValue>,
356 overrides: &BTreeMap<String, VmValue>,
357) -> BTreeMap<String, VmValue> {
358 let mut merged = base.clone();
359 for (key, value) in overrides {
360 merged.insert(key.clone(), value.clone());
361 }
362 merged
363}
364
365fn transport_limit_option(options: &BTreeMap<String, VmValue>, key: &str, default: usize) -> usize {
366 options
367 .get(key)
368 .and_then(|value| value.as_int())
369 .map(|value| value.max(0) as usize)
370 .unwrap_or(default)
371}
372
373fn receive_timeout_arg(args: &[VmValue], index: usize) -> u64 {
374 match args.get(index) {
375 Some(VmValue::Duration(ms)) => *ms,
376 Some(value) => value
377 .as_int()
378 .map(|ms| ms.max(0) as u64)
379 .unwrap_or(DEFAULT_TRANSPORT_RECEIVE_TIMEOUT_MS),
380 None => DEFAULT_TRANSPORT_RECEIVE_TIMEOUT_MS,
381 }
382}
383
384fn timeout_event() -> VmValue {
385 let mut dict = BTreeMap::new();
386 dict.insert("type".to_string(), VmValue::String(Rc::from("timeout")));
387 VmValue::Dict(Rc::new(dict))
388}
389
390fn closed_event() -> VmValue {
391 let mut dict = BTreeMap::new();
392 dict.insert("type".to_string(), VmValue::String(Rc::from("close")));
393 VmValue::Dict(Rc::new(dict))
394}
395
396fn record_transport_call(call: TransportMockCall) {
397 TRANSPORT_MOCK_CALLS.with(|calls| calls.borrow_mut().push(call));
398}
399
400async fn http_verb_handler(
404 method: &str,
405 has_body: bool,
406 args: Vec<VmValue>,
407) -> Result<VmValue, VmError> {
408 let url = args.first().map(|a| a.display()).unwrap_or_default();
409 if url.is_empty() {
410 return Err(VmError::Thrown(VmValue::String(Rc::from(format!(
411 "http_{}: URL is required",
412 method.to_ascii_lowercase()
413 )))));
414 }
415 let mut options = if has_body {
416 match args.get(2) {
417 Some(VmValue::Dict(d)) => (**d).clone(),
418 _ => BTreeMap::new(),
419 }
420 } else {
421 match args.get(1) {
422 Some(VmValue::Dict(d)) => (**d).clone(),
423 _ => BTreeMap::new(),
424 }
425 };
426 if has_body {
427 let body = args.get(1).map(|a| a.display()).unwrap_or_default();
428 options.insert("body".to_string(), VmValue::String(Rc::from(body)));
429 }
430 vm_execute_http_request(method, &url, &options).await
431}
432
433fn parse_mock_response_dict(response: &BTreeMap<String, VmValue>) -> MockResponse {
434 let status = response
435 .get("status")
436 .and_then(|v| v.as_int())
437 .unwrap_or(200);
438 let body = response
439 .get("body")
440 .map(|v| v.display())
441 .unwrap_or_default();
442 let headers = response
443 .get("headers")
444 .and_then(|v| v.as_dict())
445 .cloned()
446 .unwrap_or_default();
447 MockResponse {
448 status,
449 body,
450 headers,
451 }
452}
453
454fn parse_mock_responses(response: &BTreeMap<String, VmValue>) -> Vec<MockResponse> {
455 let scripted = response
456 .get("responses")
457 .and_then(|value| match value {
458 VmValue::List(items) => Some(
459 items
460 .iter()
461 .filter_map(|item| item.as_dict().map(parse_mock_response_dict))
462 .collect::<Vec<_>>(),
463 ),
464 _ => None,
465 })
466 .unwrap_or_default();
467
468 if scripted.is_empty() {
469 vec![parse_mock_response_dict(response)]
470 } else {
471 scripted
472 }
473}
474
475fn consume_http_mock(
476 method: &str,
477 url: &str,
478 headers: BTreeMap<String, VmValue>,
479 body: Option<String>,
480) -> Option<MockResponse> {
481 let response = HTTP_MOCKS.with(|mocks| {
482 let mut mocks = mocks.borrow_mut();
483 for mock in mocks.iter_mut() {
484 if (mock.method == "*" || mock.method.eq_ignore_ascii_case(method))
485 && url_matches(&mock.url_pattern, url)
486 {
487 let Some(last_index) = mock.responses.len().checked_sub(1) else {
488 continue;
489 };
490 let index = mock.next_response.min(last_index);
491 let response = mock.responses[index].clone();
492 if mock.next_response < last_index {
493 mock.next_response += 1;
494 }
495 return Some(response);
496 }
497 }
498 None
499 })?;
500
501 HTTP_MOCK_CALLS.with(|calls| {
502 calls.borrow_mut().push(HttpMockCall {
503 method: method.to_string(),
504 url: url.to_string(),
505 headers,
506 body,
507 });
508 });
509
510 Some(response)
511}
512
513pub fn register_http_builtins(vm: &mut Vm) {
515 vm.register_async_builtin("http_get", |args| async move {
516 http_verb_handler("GET", false, args).await
517 });
518 vm.register_async_builtin("http_post", |args| async move {
519 http_verb_handler("POST", true, args).await
520 });
521 vm.register_async_builtin("http_put", |args| async move {
522 http_verb_handler("PUT", true, args).await
523 });
524 vm.register_async_builtin("http_patch", |args| async move {
525 http_verb_handler("PATCH", true, args).await
526 });
527 vm.register_async_builtin("http_delete", |args| async move {
528 http_verb_handler("DELETE", false, args).await
529 });
530
531 vm.register_builtin("http_mock", |args, _out| {
535 let method = args.first().map(|a| a.display()).unwrap_or_default();
536 let url_pattern = args.get(1).map(|a| a.display()).unwrap_or_default();
537 let response = args
538 .get(2)
539 .and_then(|a| a.as_dict())
540 .cloned()
541 .unwrap_or_default();
542 let responses = parse_mock_responses(&response);
543
544 HTTP_MOCKS.with(|mocks| {
545 mocks.borrow_mut().push(HttpMock {
546 method,
547 url_pattern,
548 responses,
549 next_response: 0,
550 });
551 });
552 Ok(VmValue::Nil)
553 });
554
555 vm.register_builtin("http_mock_clear", |_args, _out| {
557 HTTP_MOCKS.with(|mocks| mocks.borrow_mut().clear());
558 HTTP_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
559 Ok(VmValue::Nil)
560 });
561
562 vm.register_builtin("http_mock_calls", |_args, _out| {
564 let calls = HTTP_MOCK_CALLS.with(|calls| calls.borrow().clone());
565 let result: Vec<VmValue> = calls
566 .iter()
567 .map(|c| {
568 let mut dict = BTreeMap::new();
569 dict.insert(
570 "method".to_string(),
571 VmValue::String(Rc::from(c.method.as_str())),
572 );
573 dict.insert("url".to_string(), VmValue::String(Rc::from(c.url.as_str())));
574 dict.insert(
575 "headers".to_string(),
576 VmValue::Dict(Rc::new(c.headers.clone())),
577 );
578 dict.insert(
579 "body".to_string(),
580 match &c.body {
581 Some(b) => VmValue::String(Rc::from(b.as_str())),
582 None => VmValue::Nil,
583 },
584 );
585 VmValue::Dict(Rc::new(dict))
586 })
587 .collect();
588 Ok(VmValue::List(Rc::new(result)))
589 });
590
591 vm.register_async_builtin("http_request", |args| async move {
592 let method = args
593 .first()
594 .map(|a| a.display())
595 .unwrap_or_default()
596 .to_uppercase();
597 if method.is_empty() {
598 return Err(VmError::Thrown(VmValue::String(Rc::from(
599 "http_request: method is required",
600 ))));
601 }
602 let url = args.get(1).map(|a| a.display()).unwrap_or_default();
603 if url.is_empty() {
604 return Err(VmError::Thrown(VmValue::String(Rc::from(
605 "http_request: URL is required",
606 ))));
607 }
608 let options = match args.get(2) {
609 Some(VmValue::Dict(d)) => (**d).clone(),
610 _ => BTreeMap::new(),
611 };
612 vm_execute_http_request(&method, &url, &options).await
613 });
614
615 vm.register_builtin("http_session", |args, _out| {
616 let options = get_options_arg(args, 0);
617 let config = parse_http_options(&options);
618 let client = build_http_client(&config)?;
619 let id = next_transport_handle("http-session");
620 HTTP_SESSIONS.with(|sessions| {
621 let mut sessions = sessions.borrow_mut();
622 if sessions.len() >= MAX_HTTP_SESSIONS {
623 return Err(vm_error(format!(
624 "http_session: maximum open sessions ({MAX_HTTP_SESSIONS}) reached"
625 )));
626 }
627 sessions.insert(id.clone(), HttpSession { client, options });
628 Ok(())
629 })?;
630 Ok(VmValue::String(Rc::from(id)))
631 });
632
633 vm.register_async_builtin("http_session_request", |args| async move {
634 if args.len() < 3 {
635 return Err(vm_error(
636 "http_session_request: requires session, method, and URL",
637 ));
638 }
639 let session_id = handle_from_value(&args[0], "http_session_request")?;
640 let method = args[1].display().to_uppercase();
641 if method.is_empty() {
642 return Err(vm_error("http_session_request: method is required"));
643 }
644 let url = args[2].display();
645 if url.is_empty() {
646 return Err(vm_error("http_session_request: URL is required"));
647 }
648 let options = get_options_arg(&args, 3);
649 vm_execute_http_session_request(&session_id, &method, &url, &options).await
650 });
651
652 vm.register_builtin("http_session_close", |args, _out| {
653 let Some(handle) = args.first() else {
654 return Err(vm_error("http_session_close: requires a session handle"));
655 };
656 let session_id = handle_from_value(handle, "http_session_close")?;
657 let removed = HTTP_SESSIONS.with(|sessions| sessions.borrow_mut().remove(&session_id));
658 Ok(VmValue::Bool(removed.is_some()))
659 });
660
661 vm.register_builtin("sse_mock", |args, _out| {
662 let url_pattern = args.first().map(|arg| arg.display()).unwrap_or_default();
663 if url_pattern.is_empty() {
664 return Err(vm_error("sse_mock: URL pattern is required"));
665 }
666 let events = parse_mock_stream_events(args.get(1));
667 SSE_MOCKS.with(|mocks| {
668 mocks.borrow_mut().push(SseMock {
669 url_pattern,
670 events,
671 });
672 });
673 Ok(VmValue::Nil)
674 });
675
676 vm.register_async_builtin("sse_connect", |args| async move {
677 let method = args
678 .first()
679 .map(|arg| arg.display())
680 .filter(|method| !method.is_empty())
681 .unwrap_or_else(|| "GET".to_string())
682 .to_uppercase();
683 let url = args.get(1).map(|arg| arg.display()).unwrap_or_default();
684 if url.is_empty() {
685 return Err(vm_error("sse_connect: URL is required"));
686 }
687 let options = get_options_arg(&args, 2);
688 vm_sse_connect(&method, &url, &options).await
689 });
690
691 vm.register_async_builtin("sse_receive", |args| async move {
692 let Some(handle) = args.first() else {
693 return Err(vm_error("sse_receive: requires a stream handle"));
694 };
695 let stream_id = handle_from_value(handle, "sse_receive")?;
696 let timeout_ms = receive_timeout_arg(&args, 1);
697 vm_sse_receive(&stream_id, timeout_ms).await
698 });
699
700 vm.register_builtin("sse_close", |args, _out| {
701 let Some(handle) = args.first() else {
702 return Err(vm_error("sse_close: requires a stream handle"));
703 };
704 let stream_id = handle_from_value(handle, "sse_close")?;
705 let removed = SSE_HANDLES.with(|handles| {
706 let mut handles = handles.borrow_mut();
707 let removed = handles.remove(&stream_id);
708 if let Some(handle) = &removed {
709 if let SseHandleKind::Real(stream) = &handle.kind {
710 if let Ok(mut stream) = stream.try_lock() {
711 stream.close();
712 }
713 }
714 }
715 removed
716 });
717 Ok(VmValue::Bool(removed.is_some()))
718 });
719
720 vm.register_builtin("websocket_mock", |args, _out| {
721 let url_pattern = args.first().map(|arg| arg.display()).unwrap_or_default();
722 if url_pattern.is_empty() {
723 return Err(vm_error("websocket_mock: URL pattern is required"));
724 }
725 let (messages, echo) = parse_websocket_mock(args.get(1));
726 WEBSOCKET_MOCKS.with(|mocks| {
727 mocks.borrow_mut().push(WebSocketMock {
728 url_pattern,
729 messages,
730 echo,
731 });
732 });
733 Ok(VmValue::Nil)
734 });
735
736 vm.register_async_builtin("websocket_connect", |args| async move {
737 let url = args.first().map(|arg| arg.display()).unwrap_or_default();
738 if url.is_empty() {
739 return Err(vm_error("websocket_connect: URL is required"));
740 }
741 let options = get_options_arg(&args, 1);
742 vm_websocket_connect(&url, &options).await
743 });
744
745 vm.register_async_builtin("websocket_send", |args| async move {
746 if args.len() < 2 {
747 return Err(vm_error(
748 "websocket_send: requires socket handle and message",
749 ));
750 }
751 let socket_id = handle_from_value(&args[0], "websocket_send")?;
752 let message = args[1].clone();
753 let options = get_options_arg(&args, 2);
754 vm_websocket_send(&socket_id, message, &options).await
755 });
756
757 vm.register_async_builtin("websocket_receive", |args| async move {
758 let Some(handle) = args.first() else {
759 return Err(vm_error("websocket_receive: requires a socket handle"));
760 };
761 let socket_id = handle_from_value(handle, "websocket_receive")?;
762 let timeout_ms = receive_timeout_arg(&args, 1);
763 vm_websocket_receive(&socket_id, timeout_ms).await
764 });
765
766 vm.register_async_builtin("websocket_close", |args| async move {
767 let Some(handle) = args.first() else {
768 return Err(vm_error("websocket_close: requires a socket handle"));
769 };
770 let socket_id = handle_from_value(handle, "websocket_close")?;
771 vm_websocket_close(&socket_id).await
772 });
773
774 vm.register_builtin("transport_mock_clear", |_args, _out| {
775 SSE_MOCKS.with(|mocks| mocks.borrow_mut().clear());
776 SSE_HANDLES.with(|handles| handles.borrow_mut().clear());
777 WEBSOCKET_MOCKS.with(|mocks| mocks.borrow_mut().clear());
778 WEBSOCKET_HANDLES.with(|handles| handles.borrow_mut().clear());
779 TRANSPORT_MOCK_CALLS.with(|calls| calls.borrow_mut().clear());
780 Ok(VmValue::Nil)
781 });
782
783 vm.register_builtin("transport_mock_calls", |_args, _out| {
784 let calls = TRANSPORT_MOCK_CALLS.with(|calls| calls.borrow().clone());
785 let values = calls
786 .iter()
787 .map(transport_mock_call_value)
788 .collect::<Vec<_>>();
789 Ok(VmValue::List(Rc::new(values)))
790 });
791}
792
793fn vm_get_int_option(options: &BTreeMap<String, VmValue>, key: &str, default: i64) -> i64 {
794 options.get(key).and_then(|v| v.as_int()).unwrap_or(default)
795}
796
797fn vm_get_bool_option(options: &BTreeMap<String, VmValue>, key: &str, default: bool) -> bool {
798 match options.get(key) {
799 Some(VmValue::Bool(b)) => *b,
800 _ => default,
801 }
802}
803
804fn vm_get_int_option_prefer(
805 options: &BTreeMap<String, VmValue>,
806 canonical: &str,
807 alias: &str,
808 default: i64,
809) -> i64 {
810 options
811 .get(canonical)
812 .and_then(|value| value.as_int())
813 .or_else(|| options.get(alias).and_then(|value| value.as_int()))
814 .unwrap_or(default)
815}
816
817fn parse_retry_statuses(options: &BTreeMap<String, VmValue>) -> Vec<u16> {
818 match options.get("retry_on") {
819 Some(VmValue::List(values)) => {
820 let statuses: Vec<u16> = values
821 .iter()
822 .filter_map(|value| value.as_int())
823 .filter(|status| (0..=u16::MAX as i64).contains(status))
824 .map(|status| status as u16)
825 .collect();
826 if statuses.is_empty() {
827 DEFAULT_RETRYABLE_STATUSES.to_vec()
828 } else {
829 statuses
830 }
831 }
832 _ => DEFAULT_RETRYABLE_STATUSES.to_vec(),
833 }
834}
835
836fn parse_retry_methods(options: &BTreeMap<String, VmValue>) -> Vec<String> {
837 match options.get("retry_methods") {
838 Some(VmValue::List(values)) => {
839 let methods: Vec<String> = values
840 .iter()
841 .map(|value| value.display().trim().to_ascii_uppercase())
842 .filter(|value| !value.is_empty())
843 .collect();
844 if methods.is_empty() {
845 DEFAULT_RETRYABLE_METHODS
846 .iter()
847 .map(|method| (*method).to_string())
848 .collect()
849 } else {
850 methods
851 }
852 }
853 _ => DEFAULT_RETRYABLE_METHODS
854 .iter()
855 .map(|method| (*method).to_string())
856 .collect(),
857 }
858}
859
860fn parse_http_options(options: &BTreeMap<String, VmValue>) -> HttpRequestConfig {
861 let timeout_ms = vm_get_int_option_prefer(
862 options,
863 "timeout_ms",
864 "timeout",
865 DEFAULT_TIMEOUT_MS as i64,
866 )
867 .max(0) as u64;
868 let retry_options = options.get("retry").and_then(|value| value.as_dict());
869 let retry_max = retry_options
870 .and_then(|retry| retry.get("max"))
871 .and_then(|value| value.as_int())
872 .unwrap_or_else(|| vm_get_int_option(options, "retries", 0))
873 .max(0) as u32;
874 let retry_backoff_ms = retry_options
875 .and_then(|retry| retry.get("backoff_ms"))
876 .and_then(|value| value.as_int())
877 .unwrap_or_else(|| vm_get_int_option(options, "backoff", DEFAULT_BACKOFF_MS as i64))
878 .max(0) as u64;
879 let respect_retry_after = vm_get_bool_option(options, "respect_retry_after", true);
880 let follow_redirects = vm_get_bool_option(options, "follow_redirects", true);
881 let max_redirects = vm_get_int_option(options, "max_redirects", 10).max(0) as usize;
882
883 HttpRequestConfig {
884 timeout_ms,
885 retry: RetryConfig {
886 max: retry_max,
887 backoff_ms: retry_backoff_ms,
888 retryable_statuses: parse_retry_statuses(options),
889 retryable_methods: parse_retry_methods(options),
890 respect_retry_after,
891 },
892 follow_redirects,
893 max_redirects,
894 }
895}
896
897fn http_client_key(config: &HttpRequestConfig) -> String {
898 format!(
899 "follow_redirects={};max_redirects={}",
900 config.follow_redirects, config.max_redirects
901 )
902}
903
904fn build_http_client(config: &HttpRequestConfig) -> Result<reqwest::Client, VmError> {
905 let redirect_policy = if config.follow_redirects {
906 reqwest::redirect::Policy::limited(config.max_redirects)
907 } else {
908 reqwest::redirect::Policy::none()
909 };
910
911 reqwest::Client::builder()
912 .redirect(redirect_policy)
913 .build()
914 .map_err(|e| vm_error(format!("http: failed to build client: {e}")))
915}
916
917fn pooled_http_client(config: &HttpRequestConfig) -> Result<reqwest::Client, VmError> {
918 let key = http_client_key(config);
919 if let Some(client) = HTTP_CLIENTS.with(|clients| clients.borrow().get(&key).cloned()) {
920 return Ok(client);
921 }
922
923 let client = build_http_client(config)?;
924 HTTP_CLIENTS.with(|clients| {
925 clients.borrow_mut().insert(key, client.clone());
926 });
927 Ok(client)
928}
929
930fn parse_http_request_parts(
931 method: &str,
932 options: &BTreeMap<String, VmValue>,
933) -> Result<HttpRequestParts, VmError> {
934 let req_method = method
935 .parse::<reqwest::Method>()
936 .map_err(|e| vm_error(format!("http: invalid method '{method}': {e}")))?;
937
938 let mut header_map = reqwest::header::HeaderMap::new();
939 let mut recorded_headers = BTreeMap::new();
940
941 if let Some(auth_val) = options.get("auth") {
942 match auth_val {
943 VmValue::String(s) => {
944 let hv = reqwest::header::HeaderValue::from_str(s)
945 .map_err(|e| vm_error(format!("http: invalid auth header value: {e}")))?;
946 header_map.insert(reqwest::header::AUTHORIZATION, hv);
947 recorded_headers.insert(
948 "Authorization".to_string(),
949 VmValue::String(Rc::from(s.as_ref())),
950 );
951 }
952 VmValue::Dict(d) => {
953 if let Some(bearer) = d.get("bearer") {
954 let token = bearer.display();
955 let authorization = format!("Bearer {token}");
956 let hv = reqwest::header::HeaderValue::from_str(&authorization)
957 .map_err(|e| vm_error(format!("http: invalid bearer token: {e}")))?;
958 header_map.insert(reqwest::header::AUTHORIZATION, hv);
959 recorded_headers.insert(
960 "Authorization".to_string(),
961 VmValue::String(Rc::from(authorization)),
962 );
963 } else if let Some(VmValue::Dict(basic)) = d.get("basic") {
964 let user = basic.get("user").map(|v| v.display()).unwrap_or_default();
965 let password = basic
966 .get("password")
967 .map(|v| v.display())
968 .unwrap_or_default();
969 use base64::Engine;
970 let encoded = base64::engine::general_purpose::STANDARD
971 .encode(format!("{user}:{password}"));
972 let authorization = format!("Basic {encoded}");
973 let hv = reqwest::header::HeaderValue::from_str(&authorization)
974 .map_err(|e| vm_error(format!("http: invalid basic auth: {e}")))?;
975 header_map.insert(reqwest::header::AUTHORIZATION, hv);
976 recorded_headers.insert(
977 "Authorization".to_string(),
978 VmValue::String(Rc::from(authorization)),
979 );
980 }
981 }
982 _ => {}
983 }
984 }
985
986 if let Some(VmValue::Dict(hdrs)) = options.get("headers") {
987 for (k, v) in hdrs.iter() {
988 let name = reqwest::header::HeaderName::from_bytes(k.as_bytes())
989 .map_err(|e| vm_error(format!("http: invalid header name '{k}': {e}")))?;
990 let val = reqwest::header::HeaderValue::from_str(&v.display())
991 .map_err(|e| vm_error(format!("http: invalid header value for '{k}': {e}")))?;
992 header_map.insert(name, val);
993 recorded_headers.insert(k.clone(), VmValue::String(Rc::from(v.display())));
994 }
995 }
996
997 Ok(HttpRequestParts {
998 method: req_method,
999 headers: header_map,
1000 recorded_headers,
1001 body: options.get("body").map(|v| v.display()),
1002 })
1003}
1004
1005fn session_from_options(options: &BTreeMap<String, VmValue>) -> Option<String> {
1006 options
1007 .get("session")
1008 .and_then(|value| handle_from_value(value, "http_request").ok())
1009}
1010
1011fn parse_mock_stream_event(value: &VmValue) -> MockStreamEvent {
1012 match value {
1013 VmValue::Dict(dict) => MockStreamEvent {
1014 event_type: dict
1015 .get("event")
1016 .or_else(|| dict.get("type"))
1017 .map(|value| value.display())
1018 .filter(|value| !value.is_empty())
1019 .unwrap_or_else(|| "message".to_string()),
1020 data: dict
1021 .get("data")
1022 .map(|value| value.display())
1023 .unwrap_or_default(),
1024 id: dict
1025 .get("id")
1026 .map(|value| value.display())
1027 .filter(|value| !value.is_empty()),
1028 retry_ms: dict.get("retry_ms").and_then(|value| value.as_int()),
1029 },
1030 _ => MockStreamEvent {
1031 event_type: "message".to_string(),
1032 data: value.display(),
1033 id: None,
1034 retry_ms: None,
1035 },
1036 }
1037}
1038
1039fn parse_mock_stream_events(value: Option<&VmValue>) -> Vec<MockStreamEvent> {
1040 let Some(value) = value else {
1041 return Vec::new();
1042 };
1043 match value {
1044 VmValue::Dict(dict) => dict
1045 .get("events")
1046 .and_then(|events| match events {
1047 VmValue::List(items) => Some(items.iter().map(parse_mock_stream_event).collect()),
1048 _ => None,
1049 })
1050 .unwrap_or_default(),
1051 VmValue::List(items) => items.iter().map(parse_mock_stream_event).collect(),
1052 other => vec![parse_mock_stream_event(other)],
1053 }
1054}
1055
1056fn sse_event_value(event: &MockStreamEvent) -> VmValue {
1057 let mut dict = BTreeMap::new();
1058 dict.insert("type".to_string(), VmValue::String(Rc::from("event")));
1059 dict.insert(
1060 "event".to_string(),
1061 VmValue::String(Rc::from(event.event_type.as_str())),
1062 );
1063 dict.insert(
1064 "data".to_string(),
1065 VmValue::String(Rc::from(event.data.as_str())),
1066 );
1067 dict.insert(
1068 "id".to_string(),
1069 event
1070 .id
1071 .as_deref()
1072 .map(|id| VmValue::String(Rc::from(id)))
1073 .unwrap_or(VmValue::Nil),
1074 );
1075 dict.insert(
1076 "retry_ms".to_string(),
1077 event.retry_ms.map(VmValue::Int).unwrap_or(VmValue::Nil),
1078 );
1079 VmValue::Dict(Rc::new(dict))
1080}
1081
1082fn real_sse_event_value(event: SseEvent) -> VmValue {
1083 match event {
1084 SseEvent::Open => {
1085 let mut dict = BTreeMap::new();
1086 dict.insert("type".to_string(), VmValue::String(Rc::from("open")));
1087 VmValue::Dict(Rc::new(dict))
1088 }
1089 SseEvent::Message(message) => {
1090 let retry_ms = message.retry.map(|retry| retry.as_millis() as i64);
1091 sse_event_value(&MockStreamEvent {
1092 event_type: if message.event.is_empty() {
1093 "message".to_string()
1094 } else {
1095 message.event
1096 },
1097 data: message.data,
1098 id: if message.id.is_empty() {
1099 None
1100 } else {
1101 Some(message.id)
1102 },
1103 retry_ms,
1104 })
1105 }
1106 }
1107}
1108
1109fn consume_sse_mock(url: &str) -> Option<Vec<MockStreamEvent>> {
1110 SSE_MOCKS.with(|mocks| {
1111 mocks
1112 .borrow()
1113 .iter()
1114 .find(|mock| url_matches(&mock.url_pattern, url))
1115 .map(|mock| mock.events.clone())
1116 })
1117}
1118
1119fn parse_ws_message(value: &VmValue) -> MockWsMessage {
1120 match value {
1121 VmValue::Dict(dict) => {
1122 let message_type = dict
1123 .get("type")
1124 .map(|value| value.display())
1125 .filter(|value| !value.is_empty())
1126 .unwrap_or_else(|| "text".to_string());
1127 let data = if dict
1128 .get("base64")
1129 .and_then(|value| match value {
1130 VmValue::Bool(value) => Some(*value),
1131 _ => None,
1132 })
1133 .unwrap_or(false)
1134 {
1135 use base64::Engine;
1136 dict.get("data")
1137 .map(|value| value.display())
1138 .and_then(|data| base64::engine::general_purpose::STANDARD.decode(data).ok())
1139 .unwrap_or_default()
1140 } else {
1141 dict.get("data")
1142 .map(|value| value.display().into_bytes())
1143 .unwrap_or_default()
1144 };
1145 MockWsMessage { message_type, data }
1146 }
1147 VmValue::Bytes(bytes) => MockWsMessage {
1148 message_type: "binary".to_string(),
1149 data: bytes.as_ref().clone(),
1150 },
1151 other => MockWsMessage {
1152 message_type: "text".to_string(),
1153 data: other.display().into_bytes(),
1154 },
1155 }
1156}
1157
1158fn parse_websocket_mock(value: Option<&VmValue>) -> (Vec<MockWsMessage>, bool) {
1159 let Some(value) = value else {
1160 return (Vec::new(), false);
1161 };
1162 match value {
1163 VmValue::Dict(dict) => {
1164 let echo = dict
1165 .get("echo")
1166 .and_then(|value| match value {
1167 VmValue::Bool(value) => Some(*value),
1168 _ => None,
1169 })
1170 .unwrap_or(false);
1171 let messages = dict
1172 .get("messages")
1173 .and_then(|messages| match messages {
1174 VmValue::List(items) => Some(items.iter().map(parse_ws_message).collect()),
1175 _ => None,
1176 })
1177 .unwrap_or_default();
1178 (messages, echo)
1179 }
1180 VmValue::List(items) => (items.iter().map(parse_ws_message).collect(), false),
1181 other => (vec![parse_ws_message(other)], false),
1182 }
1183}
1184
1185fn consume_websocket_mock(url: &str) -> Option<(Vec<MockWsMessage>, bool)> {
1186 WEBSOCKET_MOCKS.with(|mocks| {
1187 mocks
1188 .borrow()
1189 .iter()
1190 .find(|mock| url_matches(&mock.url_pattern, url))
1191 .map(|mock| (mock.messages.clone(), mock.echo))
1192 })
1193}
1194
1195fn ws_message_data(message: &MockWsMessage) -> String {
1196 match message.message_type.as_str() {
1197 "text" => String::from_utf8_lossy(&message.data).into_owned(),
1198 _ => {
1199 use base64::Engine;
1200 base64::engine::general_purpose::STANDARD.encode(&message.data)
1201 }
1202 }
1203}
1204
1205fn ws_event_value(message: MockWsMessage) -> VmValue {
1206 let mut dict = BTreeMap::new();
1207 dict.insert(
1208 "type".to_string(),
1209 VmValue::String(Rc::from(message.message_type.as_str())),
1210 );
1211 match message.message_type.as_str() {
1212 "text" => {
1213 dict.insert(
1214 "data".to_string(),
1215 VmValue::String(Rc::from(String::from_utf8_lossy(&message.data).as_ref())),
1216 );
1217 }
1218 _ => {
1219 use base64::Engine;
1220 dict.insert(
1221 "data_base64".to_string(),
1222 VmValue::String(Rc::from(
1223 base64::engine::general_purpose::STANDARD
1224 .encode(&message.data)
1225 .as_str(),
1226 )),
1227 );
1228 }
1229 }
1230 VmValue::Dict(Rc::new(dict))
1231}
1232
1233fn real_ws_event_value(message: WsMessage) -> VmValue {
1234 match message {
1235 WsMessage::Text(text) => ws_event_value(MockWsMessage {
1236 message_type: "text".to_string(),
1237 data: text.as_bytes().to_vec(),
1238 }),
1239 WsMessage::Binary(bytes) => ws_event_value(MockWsMessage {
1240 message_type: "binary".to_string(),
1241 data: bytes.to_vec(),
1242 }),
1243 WsMessage::Ping(bytes) => ws_event_value(MockWsMessage {
1244 message_type: "ping".to_string(),
1245 data: bytes.to_vec(),
1246 }),
1247 WsMessage::Pong(bytes) => ws_event_value(MockWsMessage {
1248 message_type: "pong".to_string(),
1249 data: bytes.to_vec(),
1250 }),
1251 WsMessage::Close(_) => closed_event(),
1252 WsMessage::Frame(_) => VmValue::Nil,
1253 }
1254}
1255
1256fn transport_mock_call_value(call: &TransportMockCall) -> VmValue {
1257 let mut dict = BTreeMap::new();
1258 dict.insert(
1259 "kind".to_string(),
1260 VmValue::String(Rc::from(call.kind.as_str())),
1261 );
1262 dict.insert(
1263 "url".to_string(),
1264 VmValue::String(Rc::from(call.url.as_str())),
1265 );
1266 dict.insert(
1267 "handle".to_string(),
1268 call.handle
1269 .as_deref()
1270 .map(|handle| VmValue::String(Rc::from(handle)))
1271 .unwrap_or(VmValue::Nil),
1272 );
1273 dict.insert(
1274 "type".to_string(),
1275 call.message_type
1276 .as_deref()
1277 .map(|message_type| VmValue::String(Rc::from(message_type)))
1278 .unwrap_or(VmValue::Nil),
1279 );
1280 dict.insert(
1281 "data".to_string(),
1282 call.data
1283 .as_deref()
1284 .map(|data| VmValue::String(Rc::from(data)))
1285 .unwrap_or(VmValue::Nil),
1286 );
1287 VmValue::Dict(Rc::new(dict))
1288}
1289
1290fn method_is_retryable(retry: &RetryConfig, method: &reqwest::Method) -> bool {
1291 retry
1292 .retryable_methods
1293 .iter()
1294 .any(|candidate| candidate.eq_ignore_ascii_case(method.as_str()))
1295}
1296
1297fn should_retry_response(
1298 config: &HttpRequestConfig,
1299 method: &reqwest::Method,
1300 status: u16,
1301 attempt: u32,
1302) -> bool {
1303 attempt < config.retry.max
1304 && method_is_retryable(&config.retry, method)
1305 && config.retry.retryable_statuses.contains(&status)
1306}
1307
1308fn should_retry_transport(
1309 config: &HttpRequestConfig,
1310 method: &reqwest::Method,
1311 error: &reqwest::Error,
1312 attempt: u32,
1313) -> bool {
1314 attempt < config.retry.max
1315 && method_is_retryable(&config.retry, method)
1316 && (error.is_timeout() || error.is_connect())
1317}
1318
1319fn parse_retry_after_value(value: &str) -> Option<Duration> {
1320 let value = value.trim();
1321 if value.is_empty() {
1322 return None;
1323 }
1324
1325 if let Ok(secs) = value.parse::<f64>() {
1326 if !secs.is_finite() || secs < 0.0 {
1327 return Some(Duration::from_millis(0));
1328 }
1329 let millis = (secs * 1_000.0) as u64;
1330 return Some(Duration::from_millis(millis.min(MAX_RETRY_DELAY_MS)));
1331 }
1332
1333 if let Ok(target) = httpdate::parse_http_date(value) {
1334 let millis = target
1335 .duration_since(SystemTime::now())
1336 .map(|delta| delta.as_millis() as u64)
1337 .unwrap_or(0);
1338 return Some(Duration::from_millis(millis.min(MAX_RETRY_DELAY_MS)));
1339 }
1340
1341 None
1342}
1343
1344fn parse_retry_after_header(value: &reqwest::header::HeaderValue) -> Option<Duration> {
1345 value.to_str().ok().and_then(parse_retry_after_value)
1346}
1347
1348fn mock_retry_after(status: u16, headers: &BTreeMap<String, VmValue>) -> Option<Duration> {
1349 if !(status == 429 || status == 503) {
1350 return None;
1351 }
1352
1353 headers
1354 .iter()
1355 .find(|(name, _)| name.eq_ignore_ascii_case("retry-after"))
1356 .and_then(|(_, value)| parse_retry_after_value(&value.display()))
1357}
1358
1359fn response_retry_after(
1360 status: u16,
1361 headers: &reqwest::header::HeaderMap,
1362 respect_retry_after: bool,
1363) -> Option<Duration> {
1364 if !respect_retry_after || !(status == 429 || status == 503) {
1365 return None;
1366 }
1367 headers
1368 .get(reqwest::header::RETRY_AFTER)
1369 .and_then(parse_retry_after_header)
1370}
1371
1372fn compute_retry_delay(attempt: u32, base_ms: u64, retry_after: Option<Duration>) -> Duration {
1373 use rand::RngExt;
1374
1375 let base_delay = base_ms.saturating_mul(1u64 << attempt.min(30));
1376 let jitter: f64 = rand::rng().random_range(0.75..=1.25);
1377 let exponential_ms = ((base_delay as f64 * jitter) as u64).min(MAX_RETRY_DELAY_MS);
1378 let retry_after_ms = retry_after
1379 .map(|duration| duration.as_millis() as u64)
1380 .unwrap_or(0)
1381 .min(MAX_RETRY_DELAY_MS);
1382 Duration::from_millis(exponential_ms.max(retry_after_ms))
1383}
1384
1385async fn vm_execute_http_request(
1386 method: &str,
1387 url: &str,
1388 options: &BTreeMap<String, VmValue>,
1389) -> Result<VmValue, VmError> {
1390 if let Some(session_id) = session_from_options(options) {
1391 return vm_execute_http_session_request(&session_id, method, url, options).await;
1392 }
1393
1394 let config = parse_http_options(options);
1395 let client = pooled_http_client(&config)?;
1396 vm_execute_http_request_with_client(client, &config, method, url, options).await
1397}
1398
1399async fn vm_execute_http_session_request(
1400 session_id: &str,
1401 method: &str,
1402 url: &str,
1403 options: &BTreeMap<String, VmValue>,
1404) -> Result<VmValue, VmError> {
1405 let session = HTTP_SESSIONS.with(|sessions| sessions.borrow().get(session_id).cloned());
1406 let Some(session) = session else {
1407 return Err(vm_error(format!(
1408 "http_session_request: unknown HTTP session '{session_id}'"
1409 )));
1410 };
1411 let merged_options = merge_options(&session.options, options);
1412 let config = parse_http_options(&merged_options);
1413 vm_execute_http_request_with_client(session.client, &config, method, url, &merged_options).await
1414}
1415
/// Sends an HTTP request with `client`, applying the retry policy in `config`.
///
/// Makes at most `config.retry.max + 1` attempts. Each attempt first asks
/// `consume_http_mock` for a mocked response matching `method`/`url`. When a
/// mock answers, no network request is made for that attempt. Only when no
/// mock answers is `url` validated (it must start with http:// or https://)
/// and the real request sent, with `config.timeout_ms` as its timeout.
///
/// An attempt is retried, after sleeping `compute_retry_delay(...)`, when:
/// - the response status is retryable per `should_retry_response`, or
/// - the transport error is a timeout/connect error per `should_retry_transport`.
///
/// For both mock and real responses, a Retry-After header (429/503 only) is
/// used as a floor on the delay when `config.retry.respect_retry_after` is set.
///
/// On success returns `build_http_response(status, headers, body_text)`.
///
/// # Errors
/// Invalid request options, a non-http(s) URL, a non-retryable transport
/// error, or a failure reading the response body.
async fn vm_execute_http_request_with_client(
    client: reqwest::Client,
    config: &HttpRequestConfig,
    method: &str,
    url: &str,
    options: &BTreeMap<String, VmValue>,
) -> Result<VmValue, VmError> {
    let parts = parse_http_request_parts(method, options)?;

    // `attempt` runs 0..=max; the `should_retry_*` helpers require
    // `attempt < max`, so the final iteration always returns.
    for attempt in 0..=config.retry.max {
        // Mocks take precedence over the network on every attempt, so a mock
        // can script a failing response followed by a successful one.
        if let Some(mock_response) = consume_http_mock(
            method,
            url,
            parts.recorded_headers.clone(),
            parts.body.clone(),
        ) {
            // Mock statuses are i64; clamp into u16 range for the retry check only.
            let status = mock_response.status.clamp(0, u16::MAX as i64) as u16;
            if should_retry_response(config, &parts.method, status, attempt) {
                let retry_after = if config.retry.respect_retry_after {
                    mock_retry_after(status, &mock_response.headers)
                } else {
                    None
                };
                tokio::time::sleep(compute_retry_delay(
                    attempt,
                    config.retry.backoff_ms,
                    retry_after,
                ))
                .await;
                continue;
            }

            // The caller sees the mock's original (unclamped) status.
            return Ok(build_http_response(
                mock_response.status,
                mock_response.headers,
                mock_response.body,
            ));
        }

        // URL validation happens only on the real-network path, after the
        // mock lookup above.
        if !url.starts_with("http://") && !url.starts_with("https://") {
            return Err(vm_error(format!(
                "http: URL must start with http:// or https://, got '{url}'"
            )));
        }

        let mut req = client.request(parts.method.clone(), url);
        req = req
            .headers(parts.headers.clone())
            .timeout(Duration::from_millis(config.timeout_ms));
        if let Some(ref b) = parts.body {
            req = req.body(b.clone());
        }

        match req.send().await {
            Ok(response) => {
                let status = response.status().as_u16();
                if should_retry_response(config, &parts.method, status, attempt) {
                    let retry_after = response_retry_after(
                        status,
                        response.headers(),
                        config.retry.respect_retry_after,
                    );
                    // The retried response is dropped without reading its body.
                    tokio::time::sleep(compute_retry_delay(
                        attempt,
                        config.retry.backoff_ms,
                        retry_after,
                    ))
                    .await;
                    continue;
                }

                // Header values that `HeaderValue::to_str` rejects are skipped.
                // NOTE(review): repeated header names (e.g. set-cookie) keep only
                // the last value because of `insert` — confirm this is intended.
                let mut resp_headers = BTreeMap::new();
                for (name, value) in response.headers() {
                    if let Ok(v) = value.to_str() {
                        resp_headers
                            .insert(name.as_str().to_string(), VmValue::String(Rc::from(v)));
                    }
                }

                let body_text = response
                    .text()
                    .await
                    .map_err(|e| vm_error(format!("http: failed to read response body: {e}")))?;
                return Ok(build_http_response(status as i64, resp_headers, body_text));
            }
            Err(e) => {
                // Only timeouts/connect errors are retried; Retry-After does not apply.
                if should_retry_transport(config, &parts.method, &e, attempt) {
                    tokio::time::sleep(compute_retry_delay(attempt, config.retry.backoff_ms, None))
                        .await;
                    continue;
                }
                return Err(vm_error(format!("http: request failed: {e}")));
            }
        }
    }

    // Defensive fallback: the last loop iteration never `continue`s (see above).
    Err(vm_error("http: request failed"))
}
1514
1515async fn vm_sse_connect(
1516 method: &str,
1517 url: &str,
1518 options: &BTreeMap<String, VmValue>,
1519) -> Result<VmValue, VmError> {
1520 let id = next_transport_handle("sse");
1521 let max_events =
1522 transport_limit_option(options, "max_events", DEFAULT_MAX_STREAM_EVENTS).max(1);
1523 let max_message_bytes =
1524 transport_limit_option(options, "max_message_bytes", DEFAULT_MAX_MESSAGE_BYTES).max(1);
1525
1526 if let Some(events) = consume_sse_mock(url) {
1527 let handle = SseHandle {
1528 kind: SseHandleKind::Fake(Rc::new(tokio::sync::Mutex::new(FakeSseStream {
1529 events: events.into(),
1530 opened: false,
1531 closed: false,
1532 }))),
1533 url: url.to_string(),
1534 max_events,
1535 max_message_bytes,
1536 received: 0,
1537 };
1538 SSE_HANDLES.with(|handles| {
1539 let mut handles = handles.borrow_mut();
1540 if handles.len() >= MAX_SSE_STREAMS {
1541 return Err(vm_error(format!(
1542 "sse_connect: maximum open streams ({MAX_SSE_STREAMS}) reached"
1543 )));
1544 }
1545 handles.insert(id.clone(), handle);
1546 Ok(())
1547 })?;
1548 record_transport_call(TransportMockCall {
1549 kind: "sse_connect".to_string(),
1550 handle: Some(id.clone()),
1551 url: url.to_string(),
1552 message_type: None,
1553 data: None,
1554 });
1555 return Ok(VmValue::String(Rc::from(id)));
1556 }
1557
1558 if !url.starts_with("http://") && !url.starts_with("https://") {
1559 return Err(vm_error(format!(
1560 "sse_connect: URL must start with http:// or https://, got '{url}'"
1561 )));
1562 }
1563
1564 let config = parse_http_options(options);
1565 let client = if let Some(session_id) = session_from_options(options) {
1566 let session = HTTP_SESSIONS.with(|sessions| sessions.borrow().get(&session_id).cloned());
1567 session
1568 .map(|session| session.client)
1569 .ok_or_else(|| vm_error(format!("sse_connect: unknown HTTP session '{session_id}'")))?
1570 } else {
1571 pooled_http_client(&config)?
1572 };
1573 let parts = parse_http_request_parts(method, options)?;
1574 let mut request = client
1575 .request(parts.method, url)
1576 .headers(parts.headers)
1577 .timeout(Duration::from_millis(config.timeout_ms));
1578 if let Some(body) = parts.body {
1579 request = request.body(body);
1580 }
1581 let stream = EventSource::new(request)
1582 .map_err(|error| vm_error(format!("sse_connect: failed to create stream: {error}")))?;
1583 let handle = SseHandle {
1584 kind: SseHandleKind::Real(Rc::new(tokio::sync::Mutex::new(stream))),
1585 url: url.to_string(),
1586 max_events,
1587 max_message_bytes,
1588 received: 0,
1589 };
1590 SSE_HANDLES.with(|handles| {
1591 let mut handles = handles.borrow_mut();
1592 if handles.len() >= MAX_SSE_STREAMS {
1593 return Err(vm_error(format!(
1594 "sse_connect: maximum open streams ({MAX_SSE_STREAMS}) reached"
1595 )));
1596 }
1597 handles.insert(id.clone(), handle);
1598 Ok(())
1599 })?;
1600 Ok(VmValue::String(Rc::from(id)))
1601}
1602
1603async fn vm_sse_receive(stream_id: &str, timeout_ms: u64) -> Result<VmValue, VmError> {
1604 let stream = SSE_HANDLES.with(|handles| {
1605 let mut handles = handles.borrow_mut();
1606 let handle = handles.get_mut(stream_id)?;
1607 if handle.received >= handle.max_events {
1608 return Some(Err(vm_error(format!(
1609 "sse_receive: stream '{stream_id}' exceeded max_events"
1610 ))));
1611 }
1612 handle.received += 1;
1613 let url = handle.url.clone();
1614 let max_message_bytes = handle.max_message_bytes;
1615 let kind = match &handle.kind {
1616 SseHandleKind::Real(stream) => SseHandleKind::Real(stream.clone()),
1617 SseHandleKind::Fake(stream) => SseHandleKind::Fake(stream.clone()),
1618 };
1619 Some(Ok((kind, url, max_message_bytes)))
1620 });
1621 let Some(stream) = stream else {
1622 return Err(vm_error(format!(
1623 "sse_receive: unknown stream '{stream_id}'"
1624 )));
1625 };
1626 let (kind, _url, max_message_bytes) = stream?;
1627
1628 match kind {
1629 SseHandleKind::Fake(stream) => {
1630 let mut stream = stream.lock().await;
1631 if stream.closed {
1632 return Ok(VmValue::Nil);
1633 }
1634 if !stream.opened {
1635 stream.opened = true;
1636 let mut dict = BTreeMap::new();
1637 dict.insert("type".to_string(), VmValue::String(Rc::from("open")));
1638 return Ok(VmValue::Dict(Rc::new(dict)));
1639 }
1640 let Some(event) = stream.events.pop_front() else {
1641 stream.closed = true;
1642 return Ok(VmValue::Nil);
1643 };
1644 if event.data.len() > max_message_bytes {
1645 return Err(vm_error(format!(
1646 "sse_receive: message exceeded max_message_bytes ({max_message_bytes})"
1647 )));
1648 }
1649 Ok(sse_event_value(&event))
1650 }
1651 SseHandleKind::Real(stream) => {
1652 let mut stream = stream.lock().await;
1653 let next = stream.next();
1654 let event = match tokio::time::timeout(Duration::from_millis(timeout_ms), next).await {
1655 Ok(Some(Ok(event))) => event,
1656 Ok(Some(Err(error))) => {
1657 return Err(vm_error(format!("sse_receive: stream error: {error}")));
1658 }
1659 Ok(None) => return Ok(VmValue::Nil),
1660 Err(_) => return Ok(timeout_event()),
1661 };
1662 if let SseEvent::Message(message) = &event {
1663 if message.data.len() > max_message_bytes {
1664 stream.close();
1665 return Err(vm_error(format!(
1666 "sse_receive: message exceeded max_message_bytes ({max_message_bytes})"
1667 )));
1668 }
1669 }
1670 Ok(real_sse_event_value(event))
1671 }
1672 }
1673}
1674
/// Opens a WebSocket connection and returns its handle id.
///
/// If a WebSocket mock matches `url`, the call makes no network access.
/// Instead it registers a fake socket preloaded with the mock's messages and
/// `echo` flag, and records the connect as a transport call.
///
/// Otherwise `url` must start with ws:// or wss://, and a real connection is
/// attempted under a timeout read from the `timeout_ms`/`timeout` options
/// (default `DEFAULT_TIMEOUT_MS`; negative values are clamped to 0).
///
/// `max_messages` (at least 1) caps `websocket_receive` calls on the handle.
/// `max_message_bytes` (at least 1) bounds payload sizes for both send and
/// receive.
///
/// # Errors
/// Fails for any of:
/// - a non-ws(s) URL;
/// - a connect timeout or handshake failure;
/// - `MAX_WEBSOCKETS` sockets already open.
async fn vm_websocket_connect(
    url: &str,
    options: &BTreeMap<String, VmValue>,
) -> Result<VmValue, VmError> {
    let id = next_transport_handle("websocket");
    let max_messages =
        transport_limit_option(options, "max_messages", DEFAULT_MAX_STREAM_EVENTS).max(1);
    let max_message_bytes =
        transport_limit_option(options, "max_message_bytes", DEFAULT_MAX_MESSAGE_BYTES).max(1);

    // Mocked sockets bypass URL validation and the network entirely.
    if let Some((messages, echo)) = consume_websocket_mock(url) {
        let handle = WebSocketHandle {
            kind: WebSocketHandleKind::Fake(Rc::new(tokio::sync::Mutex::new(FakeWebSocket {
                messages: messages.into(),
                echo,
                closed: false,
            }))),
            url: url.to_string(),
            max_messages,
            max_message_bytes,
            received: 0,
        };
        WEBSOCKET_HANDLES.with(|handles| {
            let mut handles = handles.borrow_mut();
            if handles.len() >= MAX_WEBSOCKETS {
                return Err(vm_error(format!(
                    "websocket_connect: maximum open sockets ({MAX_WEBSOCKETS}) reached"
                )));
            }
            handles.insert(id.clone(), handle);
            Ok(())
        })?;
        record_transport_call(TransportMockCall {
            kind: "websocket_connect".to_string(),
            handle: Some(id.clone()),
            url: url.to_string(),
            message_type: None,
            data: None,
        });
        return Ok(VmValue::String(Rc::from(id)));
    }

    if !url.starts_with("ws://") && !url.starts_with("wss://") {
        return Err(vm_error(format!(
            "websocket_connect: URL must start with ws:// or wss://, got '{url}'"
        )));
    }
    // NOTE(review): precedence between `timeout_ms` and `timeout` is decided by
    // `vm_get_int_option_prefer`; presumably `timeout_ms` wins — confirm.
    let timeout_ms = vm_get_int_option_prefer(
        options,
        "timeout_ms",
        "timeout",
        DEFAULT_TIMEOUT_MS as i64,
    )
    .max(0) as u64;
    let connect = tokio_tungstenite::connect_async(url);
    // Outer `?` handles the elapsed timeout, inner `?` the handshake error.
    let (socket, _) = tokio::time::timeout(Duration::from_millis(timeout_ms), connect)
        .await
        .map_err(|_| vm_error(format!("websocket_connect: timed out after {timeout_ms}ms")))?
        .map_err(|error| vm_error(format!("websocket_connect: failed: {error}")))?;
    let handle = WebSocketHandle {
        kind: WebSocketHandleKind::Real(Rc::new(tokio::sync::Mutex::new(socket))),
        url: url.to_string(),
        max_messages,
        max_message_bytes,
        received: 0,
    };
    // The open-socket limit is checked only after the handshake succeeded; if
    // it is hit, the freshly opened connection is dropped with `handle`.
    WEBSOCKET_HANDLES.with(|handles| {
        let mut handles = handles.borrow_mut();
        if handles.len() >= MAX_WEBSOCKETS {
            return Err(vm_error(format!(
                "websocket_connect: maximum open sockets ({MAX_WEBSOCKETS}) reached"
            )));
        }
        handles.insert(id.clone(), handle);
        Ok(())
    })?;
    Ok(VmValue::String(Rc::from(id)))
}
1753
1754fn websocket_message_from_vm(
1755 value: VmValue,
1756 options: &BTreeMap<String, VmValue>,
1757) -> Result<MockWsMessage, VmError> {
1758 let message_type = options
1759 .get("type")
1760 .map(|value| value.display())
1761 .filter(|value| !value.is_empty())
1762 .unwrap_or_else(|| match value {
1763 VmValue::Bytes(_) => "binary".to_string(),
1764 _ => "text".to_string(),
1765 });
1766 let data = match value {
1767 VmValue::Bytes(bytes) => bytes.as_ref().clone(),
1768 other
1769 if options
1770 .get("base64")
1771 .and_then(|value| match value {
1772 VmValue::Bool(value) => Some(*value),
1773 _ => None,
1774 })
1775 .unwrap_or(false) =>
1776 {
1777 use base64::Engine;
1778 base64::engine::general_purpose::STANDARD
1779 .decode(other.display())
1780 .map_err(|error| vm_error(format!("websocket_send: invalid base64: {error}")))?
1781 }
1782 other => other.display().into_bytes(),
1783 };
1784 Ok(MockWsMessage { message_type, data })
1785}
1786
1787fn real_ws_message(message: &MockWsMessage) -> Result<WsMessage, VmError> {
1788 match message.message_type.as_str() {
1789 "text" => Ok(WsMessage::Text(
1790 String::from_utf8(message.data.clone())
1791 .map_err(|error| vm_error(format!("websocket_send: text is not UTF-8: {error}")))?
1792 .into(),
1793 )),
1794 "binary" => Ok(WsMessage::Binary(message.data.clone().into())),
1795 "ping" => Ok(WsMessage::Ping(message.data.clone().into())),
1796 "pong" => Ok(WsMessage::Pong(message.data.clone().into())),
1797 "close" => Ok(WsMessage::Close(None)),
1798 other => Err(vm_error(format!(
1799 "websocket_send: unsupported message type '{other}'"
1800 ))),
1801 }
1802}
1803
/// Sends one message on a WebSocket handle and returns whether it was sent.
///
/// The message is built from `value` and `options` (`type`, `base64`) before
/// the handle is looked up, so conversion errors are reported even for an
/// unknown socket. The size limit applies to the payload bytes after any
/// base64 decoding.
///
/// Fake sockets:
/// - return `false` once closed;
/// - a `close` message marks the socket closed;
/// - any other message is queued for `websocket_receive` when `echo` is on;
/// - every send on an open fake socket is recorded as a transport call.
///
/// Real sockets convert the message with `real_ws_message` and send it.
///
/// # Errors
/// Fails for any of:
/// - invalid base64 or an unsupported type;
/// - an unknown socket;
/// - a payload larger than `max_message_bytes`;
/// - a failed send.
async fn vm_websocket_send(
    socket_id: &str,
    value: VmValue,
    options: &BTreeMap<String, VmValue>,
) -> Result<VmValue, VmError> {
    let message = websocket_message_from_vm(value, options)?;
    // Clone the shared socket out of the registry so the RefCell borrow is
    // released before any `.await`.
    let socket = WEBSOCKET_HANDLES.with(|handles| {
        let handles = handles.borrow();
        let handle = handles.get(socket_id)?;
        let url = handle.url.clone();
        let max_message_bytes = handle.max_message_bytes;
        let kind = match &handle.kind {
            WebSocketHandleKind::Real(socket) => WebSocketHandleKind::Real(socket.clone()),
            WebSocketHandleKind::Fake(socket) => WebSocketHandleKind::Fake(socket.clone()),
        };
        Some((kind, url, max_message_bytes))
    });
    let Some((kind, url, max_message_bytes)) = socket else {
        return Err(vm_error(format!(
            "websocket_send: unknown socket '{socket_id}'"
        )));
    };
    if message.data.len() > max_message_bytes {
        return Err(vm_error(format!(
            "websocket_send: message exceeded max_message_bytes ({max_message_bytes})"
        )));
    }
    match kind {
        WebSocketHandleKind::Fake(socket) => {
            let mut socket = socket.lock().await;
            if socket.closed {
                return Ok(VmValue::Bool(false));
            }
            if message.message_type == "close" {
                socket.closed = true;
            } else if socket.echo {
                // Echo mode: the sent message becomes the next receivable one.
                socket.messages.push_back(message.clone());
            }
            record_transport_call(TransportMockCall {
                kind: "websocket_send".to_string(),
                handle: Some(socket_id.to_string()),
                url,
                message_type: Some(message.message_type.clone()),
                data: Some(ws_message_data(&message)),
            });
            Ok(VmValue::Bool(true))
        }
        WebSocketHandleKind::Real(socket) => {
            // A `close` message is sent as a Close frame; the handle itself
            // stays registered until `websocket_close` removes it.
            let mut socket = socket.lock().await;
            socket
                .send(real_ws_message(&message)?)
                .await
                .map_err(|error| vm_error(format!("websocket_send: failed: {error}")))?;
            Ok(VmValue::Bool(true))
        }
    }
}
1861
/// Receives the next message from a WebSocket handle.
///
/// Every call counts toward the handle's `max_messages` budget, including
/// calls that time out, because the counter is bumped before anything is read.
///
/// Fake sockets never wait:
/// - a closed socket yields nil;
/// - an empty queue yields `timeout_event()`;
/// - a scripted `close` message is returned and marks the socket closed.
///
/// Real sockets wait up to `timeout_ms`:
/// - expiry yields `timeout_event()`;
/// - end of stream yields nil;
/// - a Close frame becomes `closed_event()`.
///
/// # Errors
/// Fails for any of:
/// - an unknown socket;
/// - an exhausted `max_messages` budget;
/// - a transport error;
/// - a payload larger than `max_message_bytes`.
async fn vm_websocket_receive(socket_id: &str, timeout_ms: u64) -> Result<VmValue, VmError> {
    let socket = WEBSOCKET_HANDLES.with(|handles| {
        let mut handles = handles.borrow_mut();
        let handle = handles.get_mut(socket_id)?;
        if handle.received >= handle.max_messages {
            return Some(Err(vm_error(format!(
                "websocket_receive: socket '{socket_id}' exceeded max_messages"
            ))));
        }
        handle.received += 1;
        let max_message_bytes = handle.max_message_bytes;
        let kind = match &handle.kind {
            WebSocketHandleKind::Real(socket) => WebSocketHandleKind::Real(socket.clone()),
            WebSocketHandleKind::Fake(socket) => WebSocketHandleKind::Fake(socket.clone()),
        };
        Some(Ok((kind, max_message_bytes)))
    });
    let Some(socket) = socket else {
        return Err(vm_error(format!(
            "websocket_receive: unknown socket '{socket_id}'"
        )));
    };
    let (kind, max_message_bytes) = socket?;
    match kind {
        WebSocketHandleKind::Fake(socket) => {
            let mut socket = socket.lock().await;
            if socket.closed {
                return Ok(VmValue::Nil);
            }
            let Some(message) = socket.messages.pop_front() else {
                return Ok(timeout_event());
            };
            // NOTE(review): the fake path closes the socket on an oversized
            // message while the real path below does not — confirm which
            // behavior is intended.
            if message.data.len() > max_message_bytes {
                socket.closed = true;
                return Err(vm_error(format!(
                    "websocket_receive: message exceeded max_message_bytes ({max_message_bytes})"
                )));
            }
            if message.message_type == "close" {
                socket.closed = true;
            }
            Ok(ws_event_value(message))
        }
        WebSocketHandleKind::Real(socket) => {
            let mut socket = socket.lock().await;
            let next = socket.next();
            let message = match tokio::time::timeout(Duration::from_millis(timeout_ms), next).await
            {
                Ok(Some(Ok(message))) => message,
                Ok(Some(Err(error))) => {
                    return Err(vm_error(format!("websocket_receive: failed: {error}")));
                }
                Ok(None) => return Ok(VmValue::Nil),
                Err(_) => return Ok(timeout_event()),
            };
            // Size is measured in payload bytes (UTF-8 bytes for text).
            match &message {
                WsMessage::Text(text) if text.len() > max_message_bytes => {
                    return Err(vm_error(format!(
                        "websocket_receive: message exceeded max_message_bytes ({max_message_bytes})"
                    )));
                }
                WsMessage::Binary(bytes) | WsMessage::Ping(bytes) | WsMessage::Pong(bytes)
                    if bytes.len() > max_message_bytes =>
                {
                    return Err(vm_error(format!(
                        "websocket_receive: message exceeded max_message_bytes ({max_message_bytes})"
                    )));
                }
                _ => {}
            }
            Ok(real_ws_event_value(message))
        }
    }
}
1936
1937async fn vm_websocket_close(socket_id: &str) -> Result<VmValue, VmError> {
1938 let removed = WEBSOCKET_HANDLES.with(|handles| handles.borrow_mut().remove(socket_id));
1939 let Some(handle) = removed else {
1940 return Ok(VmValue::Bool(false));
1941 };
1942 match handle.kind {
1943 WebSocketHandleKind::Fake(socket) => {
1944 socket.lock().await.closed = true;
1945 record_transport_call(TransportMockCall {
1946 kind: "websocket_close".to_string(),
1947 handle: Some(socket_id.to_string()),
1948 url: handle.url,
1949 message_type: None,
1950 data: None,
1951 });
1952 Ok(VmValue::Bool(true))
1953 }
1954 WebSocketHandleKind::Real(socket) => {
1955 let mut socket = socket.lock().await;
1956 socket
1957 .close(None)
1958 .await
1959 .map_err(|error| vm_error(format!("websocket_close: failed: {error}")))?;
1960 Ok(VmValue::Bool(true))
1961 }
1962 }
1963}
1964
#[cfg(test)]
mod tests {
    use super::{
        compute_retry_delay, http_mock_calls_snapshot, parse_retry_after_value, push_http_mock,
        reset_http_state, vm_execute_http_request, HttpMockResponse,
    };
    use crate::value::VmValue;
    use std::collections::BTreeMap;
    use std::time::{Duration, SystemTime};

    #[test]
    fn parses_retry_after_delta_seconds() {
        assert_eq!(parse_retry_after_value("5"), Some(Duration::from_secs(5)));
    }

    #[test]
    fn parses_retry_after_http_date() {
        // HTTP-dates have one-second resolution, so the target is truncated to
        // a whole second: the parsed delay is above ~1s minus the time taken
        // to get here, and never more than 2s.
        let header = httpdate::fmt_http_date(SystemTime::now() + Duration::from_secs(2));
        let parsed = parse_retry_after_value(&header).expect("http-date should parse");
        assert!(parsed <= Duration::from_secs(2));
        assert!(parsed >= Duration::from_millis(500));
    }

    #[test]
    fn malformed_retry_after_returns_none() {
        assert_eq!(parse_retry_after_value("soon-ish"), None);
        assert_eq!(parse_retry_after_value("   "), None);
    }

    #[test]
    fn negative_or_non_finite_retry_after_means_no_wait() {
        assert_eq!(parse_retry_after_value("-3"), Some(Duration::ZERO));
        assert_eq!(parse_retry_after_value("inf"), Some(Duration::ZERO));
    }

    #[test]
    fn retry_delay_honors_retry_after_floor() {
        let delay = compute_retry_delay(0, 1, Some(Duration::from_millis(250)));
        assert!(delay >= Duration::from_millis(250));
        assert!(delay <= Duration::from_secs(60));
    }

    #[tokio::test]
    async fn typed_mock_api_drives_http_request_retries() {
        reset_http_state();
        push_http_mock(
            "GET",
            "https://api.example.com/retry",
            vec![
                HttpMockResponse::new(503, "busy").with_header("retry-after", "0"),
                HttpMockResponse::new(200, "ok"),
            ],
        );
        let result = vm_execute_http_request(
            "GET",
            "https://api.example.com/retry",
            &BTreeMap::from([
                ("retries".to_string(), VmValue::Int(1)),
                ("backoff".to_string(), VmValue::Int(0)),
            ]),
        )
        .await
        .expect("mocked request should succeed after retry");

        let dict = result.as_dict().expect("response dict");
        assert_eq!(dict["status"].as_int(), Some(200));
        let calls = http_mock_calls_snapshot();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].url, "https://api.example.com/retry");
        reset_http_state();
    }
}