Skip to main content

faucet_source_websocket/
config.rs

1//! Configuration types for the WebSocket source.
2
3use base64::Engine;
4use faucet_core::{AuthSpec, DEFAULT_BATCH_SIZE, FaucetError};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use std::collections::BTreeMap;
9use std::time::Duration;
10
11/// Configuration for the WebSocket source.
12#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
13pub struct WebsocketSourceConfig {
14    /// WebSocket endpoint, `ws://` or `wss://`. Supports `{placeholder}`
15    /// parent-matrix context substitution.
16    pub url: String,
17
18    /// Authentication applied to the HTTP upgrade request. Either inline
19    /// (`{ type, config }`) or a `{ ref: <name> }` pointer to a shared
20    /// provider in the CLI's top-level `auth:` catalog.
21    #[serde(default)]
22    pub auth: AuthSpec<WebsocketAuth>,
23
24    /// Subscription frames sent (in order) immediately after every
25    /// (re)connect. Empty = send nothing.
26    #[serde(default)]
27    pub subscribe_messages: Vec<String>,
28
29    /// How to interpret each incoming frame.
30    #[serde(default)]
31    pub message_format: WsMessageFormat,
32
33    /// In `Json` mode, what to do when a frame is not valid JSON.
34    #[serde(default)]
35    pub on_parse_error: OnParseError,
36
37    /// `false` (default) emits the record raw; `true` wraps it as
38    /// `{ data, received_at, url }`.
39    #[serde(default)]
40    pub envelope: bool,
41
42    /// If set, send a WebSocket Ping frame on this interval (seconds) to keep
43    /// the connection alive through proxies/load balancers.
44    #[serde(
45        default,
46        skip_serializing_if = "Option::is_none",
47        with = "faucet_core::config::duration_secs_option"
48    )]
49    #[schemars(with = "Option<u64>")]
50    pub ping_interval: Option<Duration>,
51
52    /// Stop after this many messages. At least one of `max_messages` /
53    /// `idle_timeout` must be set.
54    #[serde(default, skip_serializing_if = "Option::is_none")]
55    pub max_messages: Option<usize>,
56
57    /// Stop after this many seconds with no message. The idle clock keeps
58    /// ticking across reconnect gaps, so it also caps a connection outage.
59    #[serde(
60        default,
61        skip_serializing_if = "Option::is_none",
62        with = "faucet_core::config::duration_secs_option"
63    )]
64    #[schemars(with = "Option<u64>")]
65    pub idle_timeout: Option<Duration>,
66
67    /// Reconnect on transport error / non-1000 close.
68    #[serde(default)]
69    pub reconnect: bool,
70
71    /// Fixed wait (seconds) between reconnect attempts. Default 1s.
72    #[serde(
73        default = "default_backoff",
74        with = "faucet_core::config::duration_secs"
75    )]
76    #[schemars(with = "u64")]
77    pub reconnect_backoff: Duration,
78
79    /// Cap on *consecutive* failed reconnects (resets on any received
80    /// message). `None` = unlimited (then `idle_timeout` is the natural cap).
81    #[serde(default, skip_serializing_if = "Option::is_none")]
82    pub max_reconnect_attempts: Option<usize>,
83
84    /// Bound the max WebSocket message/frame size (bytes) to prevent runaway
85    /// memory. `None` = tungstenite default (64 MiB message / 16 MiB frame).
86    #[serde(default, skip_serializing_if = "Option::is_none")]
87    pub max_message_bytes: Option<usize>,
88
89    /// Records per emitted [`StreamPage`](faucet_core::StreamPage). Default
90    /// [`DEFAULT_BATCH_SIZE`]. `0` drains the entire run window into a single
91    /// page (same sentinel as the Kafka source).
92    #[serde(default = "default_batch_size")]
93    pub batch_size: usize,
94}
95
/// Serde default for `reconnect_backoff`: a one-second pause.
fn default_backoff() -> Duration {
    Duration::new(1, 0)
}
99
100fn default_batch_size() -> usize {
101    DEFAULT_BATCH_SIZE
102}
103
104/// Authentication for the WebSocket upgrade request.
105#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
106#[serde(tag = "type", content = "config", rename_all = "snake_case")]
107pub enum WebsocketAuth {
108    /// No authentication (default).
109    #[default]
110    None,
111    /// `Authorization: Bearer <token>`.
112    Bearer { token: String },
113    /// Arbitrary request headers.
114    Custom { headers: BTreeMap<String, String> },
115}
116
117/// How each incoming frame is converted into a record.
118#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
119#[serde(rename_all = "snake_case")]
120pub enum WsMessageFormat {
121    /// Parse the frame payload as JSON (default).
122    #[default]
123    Json,
124    /// Emit the frame payload as a UTF-8 string (lossy for invalid UTF-8).
125    RawString,
126    /// Base64-encode the frame payload as a string.
127    Binary,
128}
129
130/// What to do when a `Json`-mode frame is not valid JSON.
131#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
132#[serde(rename_all = "snake_case")]
133pub enum OnParseError {
134    /// Abort the run with a [`FaucetError::Source`] (default).
135    #[default]
136    Fail,
137    /// Log a warning and drop the frame.
138    Skip,
139}
140
141impl WebsocketSourceConfig {
142    /// Validate the config at construction time. Called by `WebsocketSource::new`.
143    pub fn validate(&self) -> Result<(), FaucetError> {
144        let url = self.url.trim();
145        if url.is_empty() {
146            return Err(FaucetError::Config(
147                "websocket source: url must not be empty".into(),
148            ));
149        }
150        if !(url.starts_with("ws://") || url.starts_with("wss://")) {
151            return Err(FaucetError::Config(format!(
152                "websocket source: url must start with ws:// or wss:// (got {url})"
153            )));
154        }
155        if self.max_messages.is_none() && self.idle_timeout.is_none() {
156            return Err(FaucetError::Config(
157                "websocket source: at least one of max_messages or idle_timeout must be set".into(),
158            ));
159        }
160        faucet_core::validate_batch_size(self.batch_size)?;
161        Ok(())
162    }
163
164    /// Set the per-page record count for
165    /// [`Source::stream_pages`](faucet_core::Source::stream_pages). Pass `0` to
166    /// drain the entire run window into a single page.
167    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
168        self.batch_size = batch_size;
169        self
170    }
171}
172
173/// Convert a data-frame payload into a record value per `format`.
174///
175/// Returns `Ok(None)` only when `on_parse_error == Skip` swallows an invalid
176/// `Json` frame; the caller drops it. Used for both Text and Binary frames —
177/// `payload` is the raw frame bytes (for Text frames, the UTF-8 bytes).
178pub(crate) fn decode_frame(
179    format: WsMessageFormat,
180    on_parse_error: OnParseError,
181    payload: &[u8],
182) -> Result<Option<Value>, FaucetError> {
183    match format {
184        WsMessageFormat::Json => match serde_json::from_slice::<Value>(payload) {
185            Ok(v) => Ok(Some(v)),
186            Err(e) => match on_parse_error {
187                OnParseError::Fail => {
188                    Err(FaucetError::Source(format!("websocket json parse: {e}")))
189                }
190                OnParseError::Skip => {
191                    tracing::warn!(error = %e, "websocket source: dropping non-JSON frame");
192                    Ok(None)
193                }
194            },
195        },
196        WsMessageFormat::RawString => Ok(Some(Value::String(
197            String::from_utf8_lossy(payload).into_owned(),
198        ))),
199        WsMessageFormat::Binary => {
200            let encoded = base64::engine::general_purpose::STANDARD.encode(payload);
201            Ok(Some(Value::String(encoded)))
202        }
203    }
204}
205
206/// Wrap (or not) the decoded value into the emitted record shape.
207///
208/// `now_ms` is injected so the function stays pure and testable; the stream
209/// loop passes `now_unix_ms()` only when `envelope` is true.
210pub(crate) fn shape_record(value: Value, envelope: bool, url: &str, now_ms: u64) -> Value {
211    if envelope {
212        json!({ "data": value, "received_at": now_ms, "url": url })
213    } else {
214        value
215    }
216}
217
#[cfg(test)]
mod config_tests {
    use super::*;

    /// Baseline config that passes validation: a `wss://` URL, no auth, and
    /// `max_messages` as the only stop condition. Tests override single
    /// fields on top of it.
    fn minimal() -> WebsocketSourceConfig {
        WebsocketSourceConfig {
            url: String::from("wss://example.com/ws"),
            auth: AuthSpec::Inline(WebsocketAuth::None),
            subscribe_messages: Vec::new(),
            message_format: WsMessageFormat::Json,
            on_parse_error: OnParseError::Fail,
            envelope: false,
            ping_interval: None,
            max_messages: Some(10),
            idle_timeout: None,
            reconnect: false,
            reconnect_backoff: Duration::from_secs(1),
            max_reconnect_attempts: None,
            max_message_bytes: None,
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }

    #[test]
    fn validate_accepts_minimal() {
        let cfg = minimal();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn validate_rejects_empty_url() {
        // Whitespace-only counts as empty because the URL is trimmed first.
        let cfg = WebsocketSourceConfig {
            url: "  ".into(),
            ..minimal()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_non_ws_scheme() {
        let cfg = WebsocketSourceConfig {
            url: "https://example.com".into(),
            ..minimal()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_no_termination() {
        let cfg = WebsocketSourceConfig {
            max_messages: None,
            idle_timeout: None,
            ..minimal()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_idle_only() {
        let cfg = WebsocketSourceConfig {
            max_messages: None,
            idle_timeout: Some(Duration::from_secs(5)),
            ..minimal()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn validate_rejects_oversize_batch() {
        let cfg = WebsocketSourceConfig {
            batch_size: faucet_core::MAX_BATCH_SIZE + 1,
            ..minimal()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn auth_bearer_round_trips_as_adjacently_tagged() {
        // WebsocketAuth is adjacently tagged: tag="type", content="config".
        let input = serde_json::json!({"type": "bearer", "config": {"token": "abc"}});
        let parsed: WebsocketAuth = serde_json::from_value(input).unwrap();
        let expected = WebsocketAuth::Bearer {
            token: "abc".into(),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn auth_spec_inline_round_trips() {
        // The inline form of AuthSpec carries the adjacently tagged
        // WebsocketAuth shape directly.
        let input = serde_json::json!({"type": "bearer", "config": {"token": "tok"}});
        let parsed: AuthSpec<WebsocketAuth> = serde_json::from_value(input).unwrap();
        let is_inline_bearer = matches!(parsed, AuthSpec::Inline(WebsocketAuth::Bearer { .. }));
        assert!(is_inline_bearer);
    }

    #[test]
    fn auth_spec_ref_round_trips() {
        let input = serde_json::json!({"ref": "my-provider"});
        let parsed: AuthSpec<WebsocketAuth> = serde_json::from_value(input).unwrap();
        assert_eq!(parsed.reference_name(), Some("my-provider"));
    }
}
314
#[cfg(test)]
mod helper_tests {
    use super::*;
    use serde_json::json;

    /// Decode `payload` with `OnParseError::Fail` and unwrap both layers, for
    /// tests that expect a record to come out.
    fn decoded(format: WsMessageFormat, payload: &[u8]) -> Value {
        decode_frame(format, OnParseError::Fail, payload)
            .unwrap()
            .unwrap()
    }

    #[test]
    fn decode_json_object() {
        let record = decoded(WsMessageFormat::Json, br#"{"a":1}"#);
        assert_eq!(record, json!({"a": 1}));
    }

    #[test]
    fn decode_json_invalid_fails() {
        let outcome = decode_frame(WsMessageFormat::Json, OnParseError::Fail, b"not json");
        assert!(outcome.is_err());
    }

    #[test]
    fn decode_json_invalid_skipped_yields_none() {
        let outcome = decode_frame(WsMessageFormat::Json, OnParseError::Skip, b"not json");
        assert!(outcome.unwrap().is_none());
    }

    #[test]
    fn decode_raw_string() {
        let record = decoded(WsMessageFormat::RawString, b"hello");
        assert_eq!(record, json!("hello"));
    }

    #[test]
    fn decode_binary_base64() {
        let record = decoded(WsMessageFormat::Binary, b"hello");
        // "aGVsbG8=" is base64("hello").
        assert_eq!(record, json!("aGVsbG8="));
    }

    #[test]
    fn shape_raw_passthrough() {
        let shaped = shape_record(json!({"a": 1}), false, "wss://x", 123);
        assert_eq!(shaped, json!({"a": 1}));
    }

    #[test]
    fn shape_enveloped() {
        let shaped = shape_record(json!({"a": 1}), true, "wss://x", 123);
        let expected = json!({"data": {"a": 1}, "received_at": 123, "url": "wss://x"});
        assert_eq!(shaped, expected);
    }
}