Skip to main content

pdk_websockets_lib/
lib.rs

1// Copyright (c) 2026, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5//! PDK WebSockets Library
6//!
7//! Library for decoding and encoding WebSocket frames in Flex Gateway custom policies.
//! It wraps [`websocket-sans-io`](websocket_sans_io) to provide ergonomic frame-level access for policies
9//! that operate on WebSocket upgrade connections.
10//!
11//! ## Primary types
12//!
13//! - [`Frame`]: a single WebSocket frame with its payload and metadata
14//! - [`FrameType`]: the kind of frame (Text, Binary, Ping, Pong, etc.)
15//! - [`Decoder`]: incrementally decodes raw bytes into [`Frame`]s
16//! - [`Encoder`]: re-encodes a collection of [`Frame`]s into bytes
17//! - [`SinkResult`]: outcome of feeding bytes to the [`Decoder`]
18//!
19//! ## Example
20//!
21//! ```ignore
22//! use pdk_websockets::{Decoder, Encoder, Frame, FrameType, SinkResult};
23//!
24//! let mut decoder = Decoder::default();
25//! match decoder.sink(raw_bytes) {
26//!     SinkResult::MidFrame => { /* pause and wait for more bytes */ }
27//!     SinkResult::Complete(mut frames) => {
28//!         // inspect frames
29//!         if let Some(frame) = frames.first() {
30//!             if let FrameType::Text = frame.frame_type() {
31//!                 let text = String::from_utf8_lossy(frame.data());
32//!             }
33//!         }
//!         // modify frames, then re-encode for the client-to-server direction
//!         // (use `encode_server` for server-to-client traffic)
//!         frames.push(Frame::ping());
//!         let bytes = Encoder::default().encode_client(frames);
37//!     }
38//! }
39//! ```
40
41use websocket_sans_io::{
42    FrameInfo, Opcode, WebsocketFrameDecoder, WebsocketFrameEncoder, WebsocketFrameEvent,
43};
44
45/// A single WebSocket frame.
46#[derive(Clone, Debug)]
47pub struct Frame {
48    info: FrameInfo,
49    data: Vec<u8>,
50}
51
/// The kind of a WebSocket frame, derived from its opcode.
///
/// Implements `Eq` and `Hash` so it can be used as a map/set key
/// (e.g. to count or filter frames by kind).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameType {
    /// A UTF-8 text data frame.
    Text,
    /// A binary data frame.
    Binary,
    /// A continuation frame for a fragmented message.
    Continuation,
    /// A ping control frame.
    Ping,
    /// A pong control frame.
    Pong,
    /// A connection-close control frame.
    ConnectionClose,
    /// A frame with a reserved or unknown opcode.
    Reserved,
}
71
72impl Frame {
73    /// Creates a ping control frame.
74    pub fn ping() -> Self {
75        Self::control_frame(Opcode::Ping)
76    }
77
78    /// Creates a pong control frame.
79    pub fn pong() -> Self {
80        Self::control_frame(Opcode::Pong)
81    }
82
83    /// Creates a connection-close control frame.
84    pub fn connection_close() -> Self {
85        Self::control_frame(Opcode::ConnectionClose)
86    }
87
88    /// Creates a text data frame.
89    ///
90    /// Set `fin` to `true` for an unfragmented message or the final fragment,
91    /// `false` for intermediate fragments.
92    pub fn text<T: Into<Vec<u8>>>(text: T, fin: bool) -> Frame {
93        Self::data_frame(Opcode::Text, text, fin)
94    }
95
96    /// Creates a binary data frame.
97    ///
98    /// Set `fin` to `true` for an unfragmented message or the final fragment,
99    /// `false` for intermediate fragments.
100    pub fn binary<T: Into<Vec<u8>>>(text: T, fin: bool) -> Frame {
101        Self::data_frame(Opcode::Binary, text, fin)
102    }
103
104    /// Creates a continuation data frame for a fragmented message.
105    ///
106    /// Set `fin` to `true` for the final fragment, `false` for intermediate ones.
107    pub fn continuation<T: Into<Vec<u8>>>(text: T, fin: bool) -> Frame {
108        Self::data_frame(Opcode::Continuation, text, fin)
109    }
110}
111
112impl Frame {
113    /// Returns the [`FrameType`] of this frame.
114    pub fn frame_type(&self) -> FrameType {
115        match self.info.opcode {
116            Opcode::Continuation => FrameType::Continuation,
117            Opcode::Text => FrameType::Text,
118            Opcode::Binary => FrameType::Binary,
119            Opcode::ConnectionClose => FrameType::ConnectionClose,
120            Opcode::Ping => FrameType::Ping,
121            Opcode::Pong => FrameType::Pong,
122            _ => FrameType::Reserved,
123        }
124    }
125
126    /// Returns a reference to the frame's payload bytes.
127    pub fn data(&self) -> &[u8] {
128        &self.data
129    }
130
131    /// Consumes the frame and returns its payload bytes.
132    pub fn take(self) -> Vec<u8> {
133        self.data
134    }
135
136    /// Replaces the frame's payload with `data`.
137    pub fn update<U: Into<Vec<u8>>>(&mut self, data: U) {
138        let data = data.into();
139        self.info.payload_length = data.len() as u64;
140        self.data = data;
141    }
142
143    /// Returns `true` if this is the final fragment of a message (FIN bit is set).
144    pub fn fin(&self) -> bool {
145        self.info.fin
146    }
147}
148
149/// Internal functions
150impl Frame {
151    fn control_frame(opcode: Opcode) -> Frame {
152        Frame {
153            info: Self::info(opcode, &[], true),
154            data: Vec::new(),
155        }
156    }
157
158    fn data_frame<D: Into<Vec<u8>>>(opcode: Opcode, data: D, fin: bool) -> Frame {
159        let data = data.into();
160        Frame {
161            info: Self::info(opcode, &data, fin),
162            data,
163        }
164    }
165
166    fn info(opcode: Opcode, data: &[u8], fin: bool) -> FrameInfo {
167        FrameInfo {
168            opcode,
169            payload_length: data.len() as u64,
170            mask: Some(rand::random()),
171            fin,
172            reserved: 0,
173        }
174    }
175
176    fn encode(mut self, encoder: &mut WebsocketFrameEncoder, result: &mut Vec<u8>) {
177        result.extend(encoder.start_frame(&self.info));
178        if self.info.payload_length != 0 {
179            encoder.transform_frame_payload(&mut self.data);
180            result.extend(self.data);
181        }
182    }
183}
184
185/// Incrementally decodes raw bytes into [`Frame`]s.
186///
187/// Feed chunks of bytes via [`Decoder::sink`]. If the last frame in a chunk is
188/// incomplete, `sink` returns [`SinkResult::MidFrame`] and buffers the partial
189/// state internally. Call `sink` again with the next chunk to continue.
190#[derive(Default)]
191pub struct Decoder {
192    decoder: WebsocketFrameDecoder,
193    started: bool,
194    ongoing: Vec<u8>,
195    parsed: Vec<Frame>,
196}
197
198/// The result of feeding bytes to [`Decoder::sink`].
199pub enum SinkResult {
200    /// All bytes were consumed and every frame found is complete.
201    /// Contains the decoded frames.
202    Complete(Vec<Frame>),
203    /// The last frame in the supplied bytes is incomplete.
204    /// The decoder has buffered the partial state; supply more bytes to finish it.
205    MidFrame,
206}
207
208impl Decoder {
209    /// Drains and returns all frames that have been fully decoded so far,
210    /// without consuming partially-decoded state.
211    pub fn take_complete(&mut self) -> Vec<Frame> {
212        self.parsed.split_off(0)
213    }
214
215    /// Feeds a chunk of raw bytes into the decoder.
216    ///
217    /// Returns [`SinkResult::Complete`] with all fully parsed frames when every
218    /// byte in `body` has been consumed, or [`SinkResult::MidFrame`] when the
219    /// final frame in the chunk is still incomplete.
220    pub fn sink(&mut self, mut body: Vec<u8>) -> SinkResult {
221        let mut position = 0;
222        while position < body.len() {
223            // this should never happen as infallible due to feature selection
224            let frame = self.decoder.add_data(&mut body[position..]).unwrap();
225
226            if let Some(event) = frame.event {
227                match event {
228                    WebsocketFrameEvent::Start { .. } => {
229                        self.started = true;
230                    }
231                    WebsocketFrameEvent::PayloadChunk { .. } => {
232                        self.ongoing
233                            .extend_from_slice(&body[position..position + frame.consumed_bytes]);
234                    }
235                    WebsocketFrameEvent::End { frame_info, .. } => {
236                        self.parsed.push(Frame {
237                            info: frame_info,
238                            data: self.ongoing.split_off(0),
239                        });
240                        self.started = false;
241                    }
242                }
243                position += frame.consumed_bytes;
244            }
245        }
246
247        if self.started {
248            // call with empty data to see if end event is there
249            let frame = self.decoder.add_data(&mut []).unwrap();
250            if let Some(WebsocketFrameEvent::End { frame_info, .. }) = frame.event {
251                self.parsed.push(Frame {
252                    info: frame_info,
253                    data: self.ongoing.split_off(0),
254                });
255                self.started = false;
256            }
257        }
258
259        if self.started {
260            SinkResult::MidFrame
261        } else {
262            SinkResult::Complete(self.parsed.split_off(0))
263        }
264    }
265}
266
/// Encodes a collection of [`Frame`]s back into raw bytes suitable for
/// writing to a WebSocket connection.
///
/// Use [`Encoder::encode_client`] for client-to-server traffic and
/// [`Encoder::encode_server`] for server-to-client traffic.
#[derive(Debug, Default, Clone)]
pub struct Encoder {}
271
272impl Encoder {
273    /// Encodes each frame in `frame` and returns the concatenated byte representation for the flow client to server direction.
274    pub fn encode_client(&mut self, frame: Vec<Frame>) -> Vec<u8> {
275        let mut encoder = WebsocketFrameEncoder::new();
276
277        let mut result = Vec::new();
278        for frame in frame {
279            frame.encode(&mut encoder, &mut result);
280        }
281
282        result
283    }
284
285    /// Encodes each frame in `frame` and returns the concatenated byte representation for the flow server to client direction.
286    pub fn encode_server(&mut self, mut frame: Vec<Frame>) -> Vec<u8> {
287        for frame in &mut frame {
288            frame.info.mask = None;
289        }
290
291        self.encode_client(frame)
292    }
293}
294
/// Connection phase tracked by [`UpgradeTracker`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Plain HTTP traffic; no successful WebSocket upgrade observed yet.
    #[default]
    Http,
    /// A `101` response with `Upgrade: websocket` has been observed.
    Websocket,
}

/// Tracks whether a connection has been upgraded from HTTP to WebSocket.
///
/// Feed it the response status and `Upgrade` header value via
/// [`UpgradeTracker::track_upgrade_headers`]. Once a successful upgrade has been
/// observed, [`UpgradeTracker::ready`] returns `true` for the rest of the
/// tracker's lifetime; later calls never revert it.
#[derive(Debug, Default)]
pub struct UpgradeTracker {
    state: State,
}

impl UpgradeTracker {
    /// Returns `true` once a WebSocket upgrade has been observed.
    pub fn ready(&self) -> bool {
        self.state == State::Websocket
    }

    /// Records a response's status and `Upgrade` header value.
    ///
    /// The connection counts as upgraded when `status` is exactly `"101"` and
    /// `upgrade` is `websocket`. The `upgrade` match ignores ASCII case and
    /// surrounding whitespace, as RFC 6455 §4.1 requires an ASCII
    /// case-insensitive match. Missing values (`None`) never count as an upgrade.
    /// Calls made after the upgrade has been observed have no effect.
    pub fn track_upgrade_headers(&mut self, status: Option<&str>, upgrade: Option<&str>) {
        if self.state != State::Http {
            return;
        }

        let is_switching = status == Some("101");
        let is_websocket =
            upgrade.map_or(false, |value| value.trim().eq_ignore_ascii_case("websocket"));
        if is_switching && is_websocket {
            self.state = State::Websocket;
        }
    }
}