Skip to main content

pdk_unit/tester/
websocket.rs

// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file
4
5use crate::host::implementation::{FlowType, ProxyWasmStub};
6use crate::tester::io::{RequestResponse, UnitFrame, UnitHttpResponse};
7use crate::tester::unit_test::{add_request_properties, respond_call, Backends};
8use pdk_websockets_lib::{Decoder, Encoder, Frame, SinkResult};
9use proxy_wasm_stub::stub::Host;
10use proxy_wasm_stub::traits::HttpContext;
11use proxy_wasm_stub::types::{Action, BufferType, MapType};
12use std::cell::RefCell;
13use std::collections::VecDeque;
14use std::rc::{Rc, Weak};
15use std::task::Poll;
// ── Upgrade ───────────────────────────────────────────────────────────────────
17
18/// A handle to an in-progress WebSocket upgrade being processed by the policy.
19///
20/// Use [`poll`](Self::poll) to drive the upgrade handshake through
21/// `on_http_request_headers` and `on_http_response_headers`.
22#[derive(Clone)]
23pub struct UnitTestUpgrade {
24    pub(crate) inner: Rc<RefCell<InnerUnitUpgrade>>,
25}
26
27impl UnitTestUpgrade {
28    pub(crate) fn new(inner: InnerUnitUpgrade) -> Self {
29        Self {
30            inner: Rc::new(RefCell::new(inner)),
31        }
32    }
33
34    /// Advances the upgrade handshake.
35    ///
36    /// Returns `Poll::Ready(Ok(conn))` when the upgrade completes,
37    /// `Poll::Ready(Err(response))` if the policy rejected, or
38    /// `Poll::Pending` if waiting for an async call response.
39    pub fn poll(&mut self) -> Poll<Result<UpgradeConnection, UnitHttpResponse>> {
40        self.inner.borrow_mut().poll()
41    }
42}
43
/// Phase an upgrade handshake has reached in the policy's header hooks.
///
/// Variant order matters: the derived `PartialOrd` follows declaration order.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
enum UpgradeState {
    /// `on_http_request_headers` has not run yet.
    RequestHeaders,
    /// `on_http_request_headers` returned `Pause`; waiting for a resume.
    RequestHeadersPaused,
    /// Request phase finished; the backend and `on_http_response_headers` come next.
    ResponseHeaders,
    /// `on_http_response_headers` returned `Pause`; waiting for a resume.
    ResponseHeadersPaused,
    /// Handshake completed; a connection can be handed out.
    Done,
    /// The policy answered the upgrade with its own local response.
    Rejected,
}
53
54pub(crate) struct InnerUnitUpgrade {
55    state: UpgradeState,
56    context_id: u32,
57    chunk_size: usize,
58    request: RequestResponse,
59    http_context: Box<dyn HttpContext>,
60    backends: Rc<RefCell<Backends>>,
61    host: Rc<RefCell<ProxyWasmStub>>,
62    cached_connection: Option<Rc<RefCell<ConnectionInner>>>,
63    cached_response: Option<UnitHttpResponse>,
64}
65
66impl InnerUnitUpgrade {
67    pub(crate) fn new(
68        context_id: u32,
69        request: RequestResponse,
70        http_context: Box<dyn HttpContext>,
71        backends: Rc<RefCell<Backends>>,
72        host: Rc<RefCell<ProxyWasmStub>>,
73        chunk_size: usize,
74    ) -> Self {
75        Self {
76            state: UpgradeState::RequestHeaders,
77            context_id,
78            chunk_size,
79            request,
80            http_context,
81            backends,
82            host,
83            cached_connection: None,
84            cached_response: None,
85        }
86    }
87
88    pub(crate) fn poll(&mut self) -> Poll<Result<UpgradeConnection, UnitHttpResponse>> {
89        if let Poll::Ready(result) = self.build_result() {
90            return Poll::Ready(result);
91        }
92
93        let context_id = self.context_id;
94        self.host.borrow_mut().set_context(context_id);
95
96        self.host.borrow_mut().set_flow_mode(FlowType::Upstream);
97        self.resume();
98
99        if self.state == UpgradeState::RequestHeaders {
100            self.host.borrow_mut().create_map(
101                context_id,
102                MapType::HttpRequestHeaders,
103                self.request
104                    .headers()
105                    .iter()
106                    .map(|(k, v)| (k.clone(), v.as_bytes().to_vec()))
107                    .collect(),
108            );
109
110            let action = self
111                .http_context
112                .on_http_request_headers(self.request.headers().len(), false);
113
114            if action == Action::Pause && self.clear_send_response() {
115                self.state = UpgradeState::Rejected;
116            } else if action == Action::Pause {
117                self.state = UpgradeState::RequestHeadersPaused;
118            } else {
119                self.state = UpgradeState::ResponseHeaders;
120            }
121
122            self.respond_calls();
123        }
124
125        self.host.borrow_mut().set_flow_mode(FlowType::Downstream);
126
127        if self.state == UpgradeState::ResponseHeaders {
128            let backend_response = self
129                .backends
130                .borrow()
131                .backend
132                .call(self.request.clone().into())
133                .inner;
134            let backend_response = add_request_properties(backend_response, context_id);
135
136            self.host.borrow_mut().create_map(
137                context_id,
138                MapType::HttpResponseHeaders,
139                backend_response
140                    .headers()
141                    .iter()
142                    .map(|(k, v)| (k.clone(), v.as_bytes().to_vec()))
143                    .collect(),
144            );
145
146            let num_headers = backend_response.headers().len();
147            self.cached_response = Some(UnitHttpResponse::from(backend_response));
148
149            let action = self
150                .http_context
151                .on_http_response_headers(num_headers, false);
152
153            if action == Action::Pause && self.clear_send_response() {
154                self.state = UpgradeState::Rejected;
155            } else if action == Action::Pause {
156                self.state = UpgradeState::ResponseHeadersPaused;
157            } else {
158                self.state = UpgradeState::Done;
159            }
160
161            self.respond_calls();
162        }
163
164        self.build_result()
165    }
166
167    fn build_result(&mut self) -> Poll<Result<UpgradeConnection, UnitHttpResponse>> {
168        match self.state {
169            UpgradeState::Done => Poll::Ready(Ok(self.build_connection())),
170            UpgradeState::Rejected => Poll::Ready(Err(self.read_response().into())),
171            _ => Poll::Pending,
172        }
173    }
174    fn build_connection(&mut self) -> UpgradeConnection {
175        if self.cached_connection.is_none() {
176            let context_id = self.context_id;
177            self.cached_connection = Some(Rc::new(RefCell::new(ConnectionInner {
178                context_id,
179                chunk_size: self.chunk_size,
180                http_context: self.take_http_context(),
181                backends: Rc::clone(&self.backends),
182                host: Rc::clone(&self.host),
183                connection_state: Default::default(),
184            })));
185        }
186        let response = self
187            .cached_response
188            .clone()
189            .unwrap_or_else(UnitHttpResponse::upgrade);
190        UpgradeConnection::from_inner(
191            Rc::clone(self.cached_connection.as_ref().unwrap()),
192            response,
193        )
194    }
195
196    fn resume(&mut self) -> bool {
197        let resume = self.host.borrow_mut().clear_resume(self.context_id);
198        let send_response = self.clear_send_response();
199        assert!(!(resume && send_response));
200
201        if resume {
202            self.state = match self.state {
203                UpgradeState::RequestHeadersPaused => UpgradeState::ResponseHeaders,
204                UpgradeState::ResponseHeadersPaused => UpgradeState::Done,
205                state => panic!("Called resume on non-paused upgrade state {state:?}"),
206            };
207            true
208        } else if send_response {
209            self.state = UpgradeState::Rejected;
210            true
211        } else {
212            false
213        }
214    }
215
216    fn clear_send_response(&mut self) -> bool {
217        self.host.borrow_mut().clear_send_response(self.context_id)
218    }
219
220    fn respond_calls(&mut self) {
221        let prev_flow = self.host.borrow_mut().set_flow_mode(FlowType::Async);
222        let mut pending = self.host.borrow_mut().pending_calls(self.context_id);
223        while !pending.is_empty() {
224            for (id, upstream, call) in pending {
225                respond_call(
226                    self.http_context.as_mut(),
227                    &self.host,
228                    &self.backends,
229                    self.context_id,
230                    id,
231                    upstream,
232                    call,
233                );
234                if self.resume() {
235                    self.host.borrow_mut().set_flow_mode(prev_flow);
236                    return;
237                }
238            }
239            pending = self.host.borrow_mut().pending_calls(self.context_id);
240        }
241        self.host.borrow_mut().set_flow_mode(prev_flow);
242    }
243
244    fn take_http_context(&mut self) -> Box<dyn HttpContext> {
245        struct NoopHttp;
246        impl proxy_wasm_stub::traits::Context for NoopHttp {}
247        impl HttpContext for NoopHttp {}
248
249        let mut placeholder: Box<dyn HttpContext> = Box::new(NoopHttp);
250        std::mem::swap(&mut self.http_context, &mut placeholder);
251        placeholder
252    }
253
254    fn read_response(&self) -> RequestResponse {
255        let headers = self
256            .host
257            .borrow()
258            .read_map(self.context_id, MapType::HttpResponseHeaders)
259            .into_iter()
260            .map(|(k, v)| (k, String::from_utf8(v).unwrap()))
261            .collect();
262        let body = self
263            .host
264            .borrow()
265            .read_buffer(self.context_id, BufferType::HttpResponseBody);
266        RequestResponse::create(headers, body, Default::default())
267    }
268}
269
// ── Connection ────────────────────────────────────────────────────────────────
271
272/// A live upgraded WebSocket connection managed by the policy under test.
273///
274/// Obtained from [`UnitTest::upgrade`] or [`UnitTest::upgrade_partial`]. Use
275/// [`client`](Self::client) and [`server`](Self::server) to exchange frames.
276/// Pending outgoing calls triggered during body processing are resolved inline;
277/// call [`UnitTest::tick`] for calls that require simulated time to elapse.
278pub struct UpgradeConnection {
279    inner: Rc<RefCell<ConnectionInner>>,
280    response: UnitHttpResponse,
281}
282
283impl Drop for UpgradeConnection {
284    fn drop(&mut self) {
285        self.inner.borrow_mut().http_context.on_log();
286        self.inner.borrow_mut().http_context.on_done();
287    }
288}
289
290impl UpgradeConnection {
291    fn from_inner(inner: Rc<RefCell<ConnectionInner>>, response: UnitHttpResponse) -> Self {
292        Self { inner, response }
293    }
294
295    /// Returns the HTTP 101 response that completed the upgrade handshake.
296    pub fn response(&self) -> &UnitHttpResponse {
297        &self.response
298    }
299
300    /// Returns a client-side handle for sending frames to the server and reading frames back.
301    pub fn client(&self) -> ClientHandle {
302        ClientHandle {
303            inner: Rc::clone(&self.inner),
304        }
305    }
306
307    /// Returns a server-side handle for sending frames to the client and reading frames forwarded toward the server.
308    pub fn server(&self) -> ServerHandle {
309        ServerHandle {
310            inner: Rc::clone(&self.inner),
311        }
312    }
313
314    pub(crate) fn weak_inner(&self) -> Weak<RefCell<ConnectionInner>> {
315        Rc::downgrade(&self.inner)
316    }
317}
318
// ── Handles ───────────────────────────────────────────────────────────────────
320
321/// A handle for the client side of an upgraded WebSocket connection.
322pub struct ClientHandle {
323    inner: Rc<RefCell<ConnectionInner>>,
324}
325
326impl ClientHandle {
327    /// Encode `frame` and deliver it to the policy via `on_http_request_body`.
328    ///
329    /// Outgoing calls triggered by the policy are resolved inline.
330    pub fn send_to_server(&self, frame: UnitFrame) {
331        self.inner.borrow_mut().send_upstream(vec![frame.frame])
332    }
333
334    #[cfg(feature = "experimental_websocket_bytes")]
335    /// Deliver the bytes to the policy via `on_http_request_body`.
336    pub fn send_bytes_to_server(&self, bytes: Vec<u8>) {
337        self.inner.borrow_mut().send_upstream_bytes(bytes)
338    }
339
340    /// Dequeue the next frame forwarded back to the client by the policy, or `None`.
341    pub fn next(&self) -> Option<UnitFrame> {
342        self.inner
343            .borrow_mut()
344            .connection_state
345            .client_ready_frames
346            .pop_front()
347            .map(|frame| UnitFrame { frame })
348    }
349
350    #[cfg(feature = "experimental_websocket_bytes")]
351    /// Takes the bytes that reached the client.
352    pub fn bytes(&self) -> Vec<u8> {
353        self.inner
354            .borrow_mut()
355            .connection_state
356            .client_ready_bytes
357            .split_off(0)
358    }
359}
360
361/// A handle for the server side of an upgraded WebSocket connection.
362pub struct ServerHandle {
363    inner: Rc<RefCell<ConnectionInner>>,
364}
365
366impl ServerHandle {
367    /// Encode `frame` and deliver it to the policy via `on_http_response_body`.
368    ///
369    /// Outgoing calls triggered by the policy are resolved inline.
370    pub fn send_to_client(&self, frame: UnitFrame) {
371        self.inner.borrow_mut().send_downstream(vec![frame.frame])
372    }
373
374    #[cfg(feature = "experimental_websocket_bytes")]
375    /// Deliver the bytes to the policy via `on_http_response_body`.
376    pub fn send_bytes_to_client(&self, bytes: Vec<u8>) {
377        self.inner.borrow_mut().send_downstream_bytes(bytes)
378    }
379
380    /// Dequeue the next frame forwarded toward the server by the policy, or `None`.
381    pub fn next(&self) -> Option<UnitFrame> {
382        self.inner
383            .borrow_mut()
384            .connection_state
385            .server_ready_frames
386            .pop_front()
387            .map(|frame| UnitFrame { frame })
388    }
389
390    #[cfg(feature = "experimental_websocket_bytes")]
391    /// Takes the bytes that reached the server.
392    pub fn bytes(&self) -> Vec<u8> {
393        self.inner
394            .borrow_mut()
395            .connection_state
396            .server_ready_bytes
397            .split_off(0)
398    }
399}
400
// ── Shared connection state ───────────────────────────────────────────────────
402
403pub(crate) struct ConnectionInner {
404    context_id: u32,
405    chunk_size: usize,
406    http_context: Box<dyn HttpContext>,
407    backends: Rc<RefCell<Backends>>,
408    host: Rc<RefCell<ProxyWasmStub>>,
409    connection_state: ConnectionState,
410}
411
/// The way data travels through the policy on an upgraded connection.
#[derive(Clone, Copy)]
pub(crate) enum Direction {
    /// Client toward server, delivered through `on_http_request_body`.
    Upstream,
    /// Server toward client, delivered through `on_http_response_body`.
    Downstream,
}
417
418#[derive(Default)]
419struct ConnectionState {
420    #[cfg(feature = "experimental_websocket_bytes")]
421    server_ready_bytes: Vec<u8>,
422    server_ready_decoder: Decoder,
423    pub(crate) server_ready_frames: VecDeque<Frame>,
424    pub(crate) upstream_paused: bool,
425    #[cfg(feature = "experimental_websocket_bytes")]
426    client_ready_bytes: Vec<u8>,
427    client_ready_decoder: Decoder,
428    pub(crate) client_ready_frames: VecDeque<Frame>,
429    pub(crate) downstream_paused: bool,
430}
431impl ConnectionState {
432    fn set_paused(&mut self, direction: Direction, value: bool) {
433        match direction {
434            Direction::Upstream => {
435                self.upstream_paused = value;
436            }
437            Direction::Downstream => {
438                self.downstream_paused = value;
439            }
440        }
441    }
442
443    fn on_body(
444        &mut self,
445        context_id: u32,
446        context: &mut dyn HttpContext,
447        host: Rc<RefCell<ProxyWasmStub>>,
448        direction: Direction,
449        bytes: Vec<u8>,
450    ) -> Action {
451        match direction {
452            Direction::Upstream => {
453                let mut buffer = host
454                    .borrow()
455                    .get_buffer(BufferType::HttpRequestBody, 0, usize::MAX)
456                    .unwrap_or_default()
457                    .unwrap_or_default();
458
459                buffer.extend_from_slice(&bytes);
460
461                let len = buffer.len();
462                host.borrow_mut()
463                    .create_buffer(context_id, BufferType::HttpRequestBody, buffer);
464
465                context.on_http_request_body(len, false)
466            }
467            Direction::Downstream => {
468                let mut buffer = host
469                    .borrow()
470                    .get_buffer(BufferType::HttpResponseBody, 0, usize::MAX)
471                    .unwrap_or_default()
472                    .unwrap_or_default();
473
474                buffer.extend_from_slice(&bytes);
475
476                let len = buffer.len();
477                host.borrow_mut()
478                    .create_buffer(context_id, BufferType::HttpResponseBody, buffer);
479                context.on_http_response_body(len, false)
480            }
481        }
482    }
483
484    fn clean_buffer(&mut self, context_id: u32, host: &mut ProxyWasmStub, direction: Direction) {
485        match direction {
486            Direction::Upstream => {
487                let bytes = host.read_buffer(context_id, BufferType::HttpRequestBody);
488                #[cfg(feature = "experimental_websocket_bytes")]
489                self.server_ready_bytes.extend_from_slice(&bytes);
490                if let SinkResult::Complete(frames) = self.server_ready_decoder.sink(bytes) {
491                    frames
492                        .into_iter()
493                        .for_each(|frame| self.server_ready_frames.push_back(frame))
494                }
495                host.create_buffer(context_id, BufferType::HttpRequestBody, vec![]);
496            }
497            Direction::Downstream => {
498                let bytes = host.read_buffer(context_id, BufferType::HttpResponseBody);
499                #[cfg(feature = "experimental_websocket_bytes")]
500                self.client_ready_bytes.extend_from_slice(&bytes);
501                if let SinkResult::Complete(frames) = self.client_ready_decoder.sink(bytes) {
502                    frames
503                        .into_iter()
504                        .for_each(|frame| self.client_ready_frames.push_back(frame))
505                }
506                host.create_buffer(context_id, BufferType::HttpResponseBody, vec![]);
507            }
508        }
509    }
510
511    fn resume(&mut self, context_id: u32, host: &mut ProxyWasmStub, direction: Direction) {
512        match direction {
513            Direction::Upstream => {
514                let resume = host.clear_resume_request(context_id);
515                if resume && !self.upstream_paused {
516                    panic!("Called resume on non-paused request state")
517                }
518                if resume {
519                    self.clean_buffer(context_id, host, direction)
520                }
521            }
522            Direction::Downstream => {
523                let resume = host.clear_resume_response(context_id);
524                if resume && !self.downstream_paused {
525                    panic!("Called resume on non-paused response state")
526                }
527                if resume {
528                    self.clean_buffer(context_id, host, direction)
529                }
530            }
531        }
532    }
533}
534
535impl ConnectionInner {
536    pub(crate) fn send_upstream(&mut self, frames: Vec<Frame>) {
537        let encoded = Encoder::default().encode_client(frames);
538        self.send_upstream_bytes(encoded);
539    }
540
541    pub(crate) fn send_upstream_bytes(&mut self, bytes: Vec<u8>) {
542        self.host.borrow_mut().set_flow_mode(FlowType::Upstream);
543        self.drive(Direction::Upstream, bytes);
544    }
545
546    pub(crate) fn send_downstream(&mut self, frames: Vec<Frame>) {
547        let encoded = Encoder::default().encode_server(frames);
548        self.send_downstream_bytes(encoded);
549    }
550
551    pub(crate) fn send_downstream_bytes(&mut self, bytes: Vec<u8>) {
552        self.host.borrow_mut().set_flow_mode(FlowType::Downstream);
553        self.drive(Direction::Downstream, bytes);
554    }
555
556    fn drive(&mut self, direction: Direction, bytes: Vec<u8>) {
557        self.host.borrow_mut().set_context(self.context_id);
558
559        let mut start = 0;
560        let mut end = 0;
561
562        while end < bytes.len() {
563            end += self.chunk_size;
564            if end > bytes.len() {
565                end = bytes.len();
566            }
567
568            let buffer = bytes[start..end].to_vec();
569            let ctx = self.http_context.as_mut();
570            let action = self.connection_state.on_body(
571                self.context_id,
572                ctx,
573                Rc::clone(&self.host),
574                direction,
575                buffer,
576            );
577            match action {
578                Action::Continue => {
579                    self.connection_state.clean_buffer(
580                        self.context_id,
581                        &mut self.host.borrow_mut(),
582                        direction,
583                    );
584                    self.connection_state.set_paused(direction, false);
585                }
586                Action::Pause => {
587                    self.connection_state.set_paused(direction, true);
588                }
589                _ => {
590                    panic!("unexpected action: {action:?}");
591                }
592            }
593            start = end
594        }
595
596        self.respond_calls();
597        self.resume(direction);
598    }
599
600    pub(crate) fn resume(&mut self, direction: Direction) {
601        if self.host.borrow_mut().clear_send_response(self.context_id) {
602            panic!("Called send response on websocket flow.")
603        }
604
605        self.connection_state
606            .resume(self.context_id, &mut self.host.borrow_mut(), direction);
607    }
608
609    fn respond_calls(&mut self) {
610        self.host.borrow_mut().set_context(self.context_id);
611        let prev_flow = self.host.borrow_mut().set_flow_mode(FlowType::Async);
612        let mut pending = self.host.borrow_mut().pending_calls(self.context_id);
613        while !pending.is_empty() {
614            for (id, upstream, call) in pending {
615                respond_call(
616                    self.http_context.as_mut(),
617                    &self.host,
618                    &self.backends,
619                    self.context_id,
620                    id,
621                    upstream,
622                    call,
623                );
624            }
625            pending = self.host.borrow_mut().pending_calls(self.context_id);
626        }
627        self.host.borrow_mut().set_flow_mode(prev_flow);
628    }
629}
630
631impl Drop for ConnectionInner {
632    fn drop(&mut self) {
633        self.host.borrow_mut().set_context(self.context_id);
634        self.http_context.on_done();
635    }
636}