pdk_websockets_lib/lib.rs
1// Copyright (c) 2026, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5//! PDK WebSockets Library
6//!
7//! Library for decoding and encoding WebSocket frames in Flex Gateway custom policies.
8//! It wraps [`websocket-sans-io`] to provide ergonomic frame-level access for policies
9//! that operate on WebSocket upgrade connections.
10//!
11//! ## Primary types
12//!
13//! - [`Frame`]: a single WebSocket frame with its payload and metadata
14//! - [`FrameType`]: the kind of frame (Text, Binary, Ping, Pong, etc.)
15//! - [`Decoder`]: incrementally decodes raw bytes into [`Frame`]s
16//! - [`Encoder`]: re-encodes a collection of [`Frame`]s into bytes
17//! - [`SinkResult`]: outcome of feeding bytes to the [`Decoder`]
18//!
19//! ## Example
20//!
21//! ```ignore
22//! use pdk_websockets::{Decoder, Encoder, Frame, FrameType, SinkResult};
23//!
24//! let mut decoder = Decoder::default();
25//! match decoder.sink(raw_bytes) {
26//! SinkResult::MidFrame => { /* pause and wait for more bytes */ }
27//! SinkResult::Complete(mut frames) => {
28//! // inspect frames
29//! if let Some(frame) = frames.first() {
30//! if let FrameType::Text = frame.frame_type() {
31//! let text = String::from_utf8_lossy(frame.data());
32//! }
33//! }
34//! // modify frames, then re-encode
35//! frames.push(Frame::ping());
//!         // use `encode_server` instead for the server-to-client direction
//!         let bytes = Encoder::default().encode_client(frames);
37//! }
38//! }
39//! ```
40
41use websocket_sans_io::{
42 FrameInfo, Opcode, WebsocketFrameDecoder, WebsocketFrameEncoder, WebsocketFrameEvent,
43};
44
45/// A single WebSocket frame.
46#[derive(Clone, Debug)]
47pub struct Frame {
48 info: FrameInfo,
49 data: Vec<u8>,
50}
51
/// The kind of a WebSocket frame, derived from its opcode.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it, so new variants can be added later.
///
/// It implements `Eq` and `Hash` in addition to `PartialEq`, so values can be
/// used as keys of hash-based collections.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameType {
    /// A UTF-8 text data frame.
    Text,
    /// A binary data frame.
    Binary,
    /// A continuation frame for a fragmented message.
    Continuation,
    /// A ping control frame.
    Ping,
    /// A pong control frame.
    Pong,
    /// A connection-close control frame.
    ConnectionClose,
    /// A frame with a reserved or unknown opcode.
    Reserved,
}
71
72impl Frame {
73 /// Creates a ping control frame.
74 pub fn ping() -> Self {
75 Self::control_frame(Opcode::Ping)
76 }
77
78 /// Creates a pong control frame.
79 pub fn pong() -> Self {
80 Self::control_frame(Opcode::Pong)
81 }
82
83 /// Creates a connection-close control frame.
84 pub fn connection_close() -> Self {
85 Self::control_frame(Opcode::ConnectionClose)
86 }
87
88 /// Creates a text data frame.
89 ///
90 /// Set `fin` to `true` for an unfragmented message or the final fragment,
91 /// `false` for intermediate fragments.
92 pub fn text<T: Into<Vec<u8>>>(text: T, fin: bool) -> Frame {
93 Self::data_frame(Opcode::Text, text, fin)
94 }
95
96 /// Creates a binary data frame.
97 ///
98 /// Set `fin` to `true` for an unfragmented message or the final fragment,
99 /// `false` for intermediate fragments.
100 pub fn binary<T: Into<Vec<u8>>>(text: T, fin: bool) -> Frame {
101 Self::data_frame(Opcode::Binary, text, fin)
102 }
103
104 /// Creates a continuation data frame for a fragmented message.
105 ///
106 /// Set `fin` to `true` for the final fragment, `false` for intermediate ones.
107 pub fn continuation<T: Into<Vec<u8>>>(text: T, fin: bool) -> Frame {
108 Self::data_frame(Opcode::Continuation, text, fin)
109 }
110}
111
112impl Frame {
113 /// Returns the [`FrameType`] of this frame.
114 pub fn frame_type(&self) -> FrameType {
115 match self.info.opcode {
116 Opcode::Continuation => FrameType::Continuation,
117 Opcode::Text => FrameType::Text,
118 Opcode::Binary => FrameType::Binary,
119 Opcode::ConnectionClose => FrameType::ConnectionClose,
120 Opcode::Ping => FrameType::Ping,
121 Opcode::Pong => FrameType::Pong,
122 _ => FrameType::Reserved,
123 }
124 }
125
126 /// Returns a reference to the frame's payload bytes.
127 pub fn data(&self) -> &[u8] {
128 &self.data
129 }
130
131 /// Consumes the frame and returns its payload bytes.
132 pub fn take(self) -> Vec<u8> {
133 self.data
134 }
135
136 /// Replaces the frame's payload with `data`.
137 pub fn update<U: Into<Vec<u8>>>(&mut self, data: U) {
138 let data = data.into();
139 self.info.payload_length = data.len() as u64;
140 self.data = data;
141 }
142
143 /// Returns `true` if this is the final fragment of a message (FIN bit is set).
144 pub fn fin(&self) -> bool {
145 self.info.fin
146 }
147}
148
149/// Internal functions
150impl Frame {
151 fn control_frame(opcode: Opcode) -> Frame {
152 Frame {
153 info: Self::info(opcode, &[], true),
154 data: Vec::new(),
155 }
156 }
157
158 fn data_frame<D: Into<Vec<u8>>>(opcode: Opcode, data: D, fin: bool) -> Frame {
159 let data = data.into();
160 Frame {
161 info: Self::info(opcode, &data, fin),
162 data,
163 }
164 }
165
166 fn info(opcode: Opcode, data: &[u8], fin: bool) -> FrameInfo {
167 FrameInfo {
168 opcode,
169 payload_length: data.len() as u64,
170 mask: Some(rand::random()),
171 fin,
172 reserved: 0,
173 }
174 }
175
176 fn encode(mut self, encoder: &mut WebsocketFrameEncoder, result: &mut Vec<u8>) {
177 result.extend(encoder.start_frame(&self.info));
178 if self.info.payload_length != 0 {
179 encoder.transform_frame_payload(&mut self.data);
180 result.extend(self.data);
181 }
182 }
183}
184
185/// Incrementally decodes raw bytes into [`Frame`]s.
186///
187/// Feed chunks of bytes via [`Decoder::sink`]. If the last frame in a chunk is
188/// incomplete, `sink` returns [`SinkResult::MidFrame`] and buffers the partial
189/// state internally. Call `sink` again with the next chunk to continue.
190#[derive(Default)]
191pub struct Decoder {
192 decoder: WebsocketFrameDecoder,
193 started: bool,
194 ongoing: Vec<u8>,
195 parsed: Vec<Frame>,
196}
197
198/// The result of feeding bytes to [`Decoder::sink`].
199pub enum SinkResult {
200 /// All bytes were consumed and every frame found is complete.
201 /// Contains the decoded frames.
202 Complete(Vec<Frame>),
203 /// The last frame in the supplied bytes is incomplete.
204 /// The decoder has buffered the partial state; supply more bytes to finish it.
205 MidFrame,
206}
207
208impl Decoder {
209 /// Drains and returns all frames that have been fully decoded so far,
210 /// without consuming partially-decoded state.
211 pub fn take_complete(&mut self) -> Vec<Frame> {
212 self.parsed.split_off(0)
213 }
214
215 /// Feeds a chunk of raw bytes into the decoder.
216 ///
217 /// Returns [`SinkResult::Complete`] with all fully parsed frames when every
218 /// byte in `body` has been consumed, or [`SinkResult::MidFrame`] when the
219 /// final frame in the chunk is still incomplete.
220 pub fn sink(&mut self, mut body: Vec<u8>) -> SinkResult {
221 let mut position = 0;
222 while position < body.len() {
223 // this should never happen as infallible due to feature selection
224 let frame = self.decoder.add_data(&mut body[position..]).unwrap();
225
226 if let Some(event) = frame.event {
227 match event {
228 WebsocketFrameEvent::Start { .. } => {
229 self.started = true;
230 }
231 WebsocketFrameEvent::PayloadChunk { .. } => {
232 self.ongoing
233 .extend_from_slice(&body[position..position + frame.consumed_bytes]);
234 }
235 WebsocketFrameEvent::End { frame_info, .. } => {
236 self.parsed.push(Frame {
237 info: frame_info,
238 data: self.ongoing.split_off(0),
239 });
240 self.started = false;
241 }
242 }
243 position += frame.consumed_bytes;
244 }
245 }
246
247 if self.started {
248 // call with empty data to see if end event is there
249 let frame = self.decoder.add_data(&mut []).unwrap();
250 if let Some(WebsocketFrameEvent::End { frame_info, .. }) = frame.event {
251 self.parsed.push(Frame {
252 info: frame_info,
253 data: self.ongoing.split_off(0),
254 });
255 self.started = false;
256 }
257 }
258
259 if self.started {
260 SinkResult::MidFrame
261 } else {
262 SinkResult::Complete(self.parsed.split_off(0))
263 }
264 }
265}
266
267/// Encodes a collection of [`Frame`]s back into raw bytes suitable for
268/// writing to a WebSocket connection.
269#[derive(Default)]
270pub struct Encoder {}
271
272impl Encoder {
273 /// Encodes each frame in `frame` and returns the concatenated byte representation for the flow client to server direction.
274 pub fn encode_client(&mut self, frame: Vec<Frame>) -> Vec<u8> {
275 let mut encoder = WebsocketFrameEncoder::new();
276
277 let mut result = Vec::new();
278 for frame in frame {
279 frame.encode(&mut encoder, &mut result);
280 }
281
282 result
283 }
284
285 /// Encodes each frame in `frame` and returns the concatenated byte representation for the flow server to client direction.
286 pub fn encode_server(&mut self, mut frame: Vec<Frame>) -> Vec<u8> {
287 for frame in &mut frame {
288 frame.info.mask = None;
289 }
290
291 self.encode_client(frame)
292 }
293}
294
/// Protocol phase of a connection as tracked by [`UpgradeTracker`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No successful WebSocket upgrade has been observed yet.
    #[default]
    Http,
    /// A `101` response carrying `Upgrade: websocket` has been observed.
    Websocket,
}

/// Tracks whether a connection has been upgraded from HTTP to WebSocket.
///
/// A tracker starts in the HTTP phase. Feed it the response status and the
/// `Upgrade` header through [`UpgradeTracker::track_upgrade_headers`]; once a
/// successful upgrade has been seen, [`UpgradeTracker::ready`] returns `true`
/// and stays `true`.
#[derive(Debug, Default)]
pub struct UpgradeTracker {
    state: State,
}

impl UpgradeTracker {
    /// Returns `true` once a WebSocket upgrade has been observed.
    pub fn ready(&self) -> bool {
        self.state == State::Websocket
    }

    /// Inspects the response status and `Upgrade` header value and switches
    /// to the WebSocket phase when they describe a successful upgrade.
    ///
    /// * `status` - the response status code as text; only `"101"` counts.
    /// * `upgrade` - the value of the `Upgrade` header, if present. It is
    ///   compared with `websocket` ignoring ASCII case and surrounding
    ///   whitespace, as RFC 6455 requires a case-insensitive match
    ///   (servers may send e.g. `WebSocket`).
    ///
    /// Calls made after the upgrade has been observed have no effect.
    pub fn track_upgrade_headers(&mut self, status: Option<&str>, upgrade: Option<&str>) {
        if self.state != State::Http {
            return;
        }

        let is_switching = status == Some("101");
        let is_websocket = matches!(
            upgrade,
            Some(value) if value.trim().eq_ignore_ascii_case("websocket")
        );

        if is_switching && is_websocket {
            self.state = State::Websocket;
        }
    }
}
321}