Skip to main content

flaron_sdk/
ws.rs

1//! WebSocket primitives for flares that handle WebSocket upgrades.
2//!
3//! A WebSocket-enabled flare exports `ws_open`, `ws_message`, and `ws_close`
4//! instead of (or in addition to) `handle_request`. Inside those exports the
5//! event accessors return data the host wired up before the call:
6//!
7//! * [`event_type`] - `"open"`, `"message"`, or `"close"`.
8//! * [`event_data`] - the message payload (binary or UTF-8 text bytes).
9//! * [`conn_id`]    - the per-connection identifier the host issues.
10//! * [`close_code`] - close code from the peer (only meaningful in
11//!   `ws_close`).
12//!
13//! Use [`send`] to push frames to the peer and [`close`] to terminate the
14//! connection from the flare side.
15
16use crate::{ffi, mem};
17
18/// Errors returned by [`send`] when the host refuses the frame.
19#[derive(Debug, thiserror::Error)]
20pub enum WsSendError {
21    /// Per-invocation send-rate limit reached. Drop the frame and back off.
22    #[error("ws send: per-invocation rate limit reached")]
23    SendLimitReached,
24
25    /// Frame exceeds the host's per-message size cap.
26    #[error("ws send: message too large")]
27    MessageTooLarge,
28
29    /// Generic send failure (connection closed, write error, etc.).
30    #[error("ws send: write failed")]
31    SendError,
32
33    /// Unknown error code returned by the host.
34    #[error("ws send: unknown error code {0}")]
35    Unknown(i32),
36}
37
38impl WsSendError {
39    fn from_code(code: i32) -> Self {
40        match code {
41            1 => Self::SendLimitReached,
42            2 => Self::MessageTooLarge,
43            3 => Self::SendError,
44            _ => Self::Unknown(code),
45        }
46    }
47}
48
49/// Send a frame to the connected peer.
50///
51/// The host treats the bytes as opaque - pass UTF-8 for a text frame or
52/// arbitrary bytes for a binary frame.
53pub fn send(data: &[u8]) -> Result<(), WsSendError> {
54    let (data_ptr, data_len) = mem::host_arg_bytes(data);
55    let code = unsafe { ffi::ws_send(data_ptr, data_len) };
56    if code == 0 {
57        Ok(())
58    } else {
59        Err(WsSendError::from_code(code))
60    }
61}
62
63/// Convenience: send a text frame.
64pub fn send_text(text: &str) -> Result<(), WsSendError> {
65    send(text.as_bytes())
66}
67
68/// Close the connection with the given WebSocket status code.
69///
70/// Common codes:
71/// * `1000` - Normal Closure
72/// * `1001` - Going Away
73/// * `1002` - Protocol Error
74/// * `1008` - Policy Violation
75/// * `1011` - Internal Error
76pub fn close(code: u16) {
77    unsafe { ffi::ws_close_conn(code as i32) }
78}
79
80/// Per-connection identifier issued by the host. Stable for the lifetime of
81/// a single WebSocket connection.
82pub fn conn_id() -> String {
83    // SAFETY: host writes a valid UTF-8 connection ID into the bump arena.
84    unsafe { mem::read_packed_string(ffi::ws_conn_id()) }.unwrap_or_default()
85}
86
87/// Type of the current WebSocket event: `"open"`, `"message"`, or `"close"`.
88pub fn event_type() -> String {
89    // SAFETY: host writes a valid UTF-8 event type into the bump arena.
90    unsafe { mem::read_packed_string(ffi::ws_event_type()) }.unwrap_or_default()
91}
92
93/// Payload bytes for a `"message"` event. Returns an empty `Vec<u8>` for
94/// `"open"` and `"close"` events.
95pub fn event_data() -> Vec<u8> {
96    // SAFETY: host writes the event payload bytes into the bump arena.
97    unsafe { mem::read_packed_bytes(ffi::ws_event_data()) }.unwrap_or_default()
98}
99
100/// Convenience: event payload interpreted as a UTF-8 string. Invalid UTF-8
101/// is replaced with the Unicode replacement character.
102pub fn event_text() -> String {
103    String::from_utf8_lossy(&event_data()).into_owned()
104}
105
106/// Close code provided by the remote peer (only meaningful inside a
107/// `ws_close` handler - `0` otherwise).
108pub fn close_code() -> u16 {
109    unsafe { ffi::ws_close_code() as u16 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::test_host;

    #[test]
    fn send_records_payload() {
        test_host::reset();
        send(b"hello").unwrap();
        let sends = test_host::read_mock(|m| m.ws_sends.clone());
        assert_eq!(sends, vec![b"hello".to_vec()]);
    }

    #[test]
    fn send_text_records_utf8_bytes() {
        test_host::reset();
        send_text("héllo").unwrap();
        let sends = test_host::read_mock(|m| m.ws_sends.clone());
        assert_eq!(sends, vec!["héllo".as_bytes().to_vec()]);
    }

    #[test]
    fn send_maps_error_codes() {
        for (code, expected) in [
            (1, WsSendError::SendLimitReached),
            (2, WsSendError::MessageTooLarge),
            (3, WsSendError::SendError),
        ] {
            test_host::reset();
            test_host::with_mock(|m| m.ws_send_error = code);
            let err = send(b"x").unwrap_err();
            assert!(
                std::mem::discriminant(&err) == std::mem::discriminant(&expected),
                "code {} mismatch",
                code,
            );
        }
    }

    #[test]
    fn send_unknown_error_code() {
        test_host::reset();
        test_host::with_mock(|m| m.ws_send_error = 99);
        match send(b"x").unwrap_err() {
            WsSendError::Unknown(99) => {}
            other => panic!("expected Unknown(99), got {:?}", other),
        }
    }

    #[test]
    fn close_records_code() {
        test_host::reset();
        close(1000);
        close(1011);
        assert_eq!(
            test_host::read_mock(|m| m.ws_closes.clone()),
            vec![1000, 1011]
        );
    }

    #[test]
    fn conn_id_returns_host_value() {
        test_host::reset();
        test_host::with_mock(|m| m.ws_conn_id = Some("conn-abc-123".into()));
        assert_eq!(conn_id(), "conn-abc-123");
    }

    #[test]
    fn conn_id_empty_when_unset() {
        test_host::reset();
        assert_eq!(conn_id(), "");
    }

    #[test]
    fn event_type_open_message_close() {
        for ty in ["open", "message", "close"] {
            test_host::reset();
            test_host::with_mock(|m| m.ws_event_type = Some(ty.into()));
            assert_eq!(event_type(), ty);
        }
    }

    #[test]
    fn event_data_returns_payload_bytes() {
        test_host::reset();
        test_host::with_mock(|m| m.ws_event_data = Some(vec![1, 2, 3, 4]));
        assert_eq!(event_data(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn event_data_empty_for_open_close() {
        test_host::reset();
        assert!(event_data().is_empty());
    }

    #[test]
    fn event_text_decodes_utf8() {
        test_host::reset();
        test_host::with_mock(|m| m.ws_event_data = Some("héllo".as_bytes().to_vec()));
        assert_eq!(event_text(), "héllo");
    }

    // Pins the documented lossy behaviour: invalid UTF-8 bytes become U+FFFD
    // instead of causing an error or being dropped.
    #[test]
    fn event_text_replaces_invalid_utf8() {
        test_host::reset();
        test_host::with_mock(|m| m.ws_event_data = Some(vec![b'h', 0xFF, b'i']));
        assert_eq!(event_text(), "h\u{FFFD}i");
    }

    // With no payload set (mirrors `event_data_empty_for_open_close`), the
    // text view is the empty string.
    #[test]
    fn event_text_empty_when_no_payload() {
        test_host::reset();
        assert_eq!(event_text(), "");
    }

    #[test]
    fn close_code_returns_host_value() {
        test_host::reset();
        test_host::with_mock(|m| m.ws_close_code = 1006);
        assert_eq!(close_code(), 1006);
    }
}