1use base64::Engine;
4use faucet_core::{AuthSpec, DEFAULT_BATCH_SIZE, FaucetError};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use std::collections::BTreeMap;
9use std::time::Duration;
10
11#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
13pub struct WebsocketSourceConfig {
14 pub url: String,
17
18 #[serde(default)]
22 pub auth: AuthSpec<WebsocketAuth>,
23
24 #[serde(default)]
27 pub subscribe_messages: Vec<String>,
28
29 #[serde(default)]
31 pub message_format: WsMessageFormat,
32
33 #[serde(default)]
35 pub on_parse_error: OnParseError,
36
37 #[serde(default)]
40 pub envelope: bool,
41
42 #[serde(
45 default,
46 skip_serializing_if = "Option::is_none",
47 with = "faucet_core::config::duration_secs_option"
48 )]
49 #[schemars(with = "Option<u64>")]
50 pub ping_interval: Option<Duration>,
51
52 #[serde(default, skip_serializing_if = "Option::is_none")]
55 pub max_messages: Option<usize>,
56
57 #[serde(
60 default,
61 skip_serializing_if = "Option::is_none",
62 with = "faucet_core::config::duration_secs_option"
63 )]
64 #[schemars(with = "Option<u64>")]
65 pub idle_timeout: Option<Duration>,
66
67 #[serde(default)]
69 pub reconnect: bool,
70
71 #[serde(
73 default = "default_backoff",
74 with = "faucet_core::config::duration_secs"
75 )]
76 #[schemars(with = "u64")]
77 pub reconnect_backoff: Duration,
78
79 #[serde(default, skip_serializing_if = "Option::is_none")]
82 pub max_reconnect_attempts: Option<usize>,
83
84 #[serde(default, skip_serializing_if = "Option::is_none")]
87 pub max_message_bytes: Option<usize>,
88
89 #[serde(default = "default_batch_size")]
93 pub batch_size: usize,
94}
95
/// Serde default for `WebsocketSourceConfig::reconnect_backoff`: one second.
fn default_backoff() -> Duration {
    Duration::from_millis(1_000)
}
99
100fn default_batch_size() -> usize {
101 DEFAULT_BATCH_SIZE
102}
103
104#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
106#[serde(tag = "type", content = "config", rename_all = "snake_case")]
107pub enum WebsocketAuth {
108 #[default]
110 None,
111 Bearer { token: String },
113 Custom { headers: BTreeMap<String, String> },
115}
116
117#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
119#[serde(rename_all = "snake_case")]
120pub enum WsMessageFormat {
121 #[default]
123 Json,
124 RawString,
126 Binary,
128}
129
130#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
132#[serde(rename_all = "snake_case")]
133pub enum OnParseError {
134 #[default]
136 Fail,
137 Skip,
139}
140
141impl WebsocketSourceConfig {
142 pub fn validate(&self) -> Result<(), FaucetError> {
144 let url = self.url.trim();
145 if url.is_empty() {
146 return Err(FaucetError::Config(
147 "websocket source: url must not be empty".into(),
148 ));
149 }
150 if !(url.starts_with("ws://") || url.starts_with("wss://")) {
151 return Err(FaucetError::Config(format!(
152 "websocket source: url must start with ws:// or wss:// (got {url})"
153 )));
154 }
155 if self.max_messages.is_none() && self.idle_timeout.is_none() {
156 return Err(FaucetError::Config(
157 "websocket source: at least one of max_messages or idle_timeout must be set".into(),
158 ));
159 }
160 faucet_core::validate_batch_size(self.batch_size)?;
161 Ok(())
162 }
163
164 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
168 self.batch_size = batch_size;
169 self
170 }
171}
172
173pub(crate) fn decode_frame(
179 format: WsMessageFormat,
180 on_parse_error: OnParseError,
181 payload: &[u8],
182) -> Result<Option<Value>, FaucetError> {
183 match format {
184 WsMessageFormat::Json => match serde_json::from_slice::<Value>(payload) {
185 Ok(v) => Ok(Some(v)),
186 Err(e) => match on_parse_error {
187 OnParseError::Fail => {
188 Err(FaucetError::Source(format!("websocket json parse: {e}")))
189 }
190 OnParseError::Skip => {
191 tracing::warn!(error = %e, "websocket source: dropping non-JSON frame");
192 Ok(None)
193 }
194 },
195 },
196 WsMessageFormat::RawString => Ok(Some(Value::String(
197 String::from_utf8_lossy(payload).into_owned(),
198 ))),
199 WsMessageFormat::Binary => {
200 let encoded = base64::engine::general_purpose::STANDARD.encode(payload);
201 Ok(Some(Value::String(encoded)))
202 }
203 }
204}
205
206pub(crate) fn shape_record(value: Value, envelope: bool, url: &str, now_ms: u64) -> Value {
211 if envelope {
212 json!({ "data": value, "received_at": now_ms, "url": url })
213 } else {
214 value
215 }
216}
217
#[cfg(test)]
mod config_tests {
    use super::*;

    /// Smallest configuration that passes `validate`: a `wss://` URL plus a
    /// `max_messages` termination bound, everything else at its default.
    fn minimal() -> WebsocketSourceConfig {
        WebsocketSourceConfig {
            url: "wss://example.com/ws".into(),
            auth: AuthSpec::Inline(WebsocketAuth::None),
            subscribe_messages: vec![],
            message_format: WsMessageFormat::Json,
            on_parse_error: OnParseError::Fail,
            envelope: false,
            ping_interval: None,
            max_messages: Some(10),
            idle_timeout: None,
            reconnect: false,
            reconnect_backoff: Duration::from_secs(1),
            max_reconnect_attempts: None,
            max_message_bytes: None,
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }

    #[test]
    fn validate_accepts_minimal() {
        assert!(minimal().validate().is_ok());
    }

    #[test]
    fn validate_rejects_empty_url() {
        for url in ["", "   "] {
            let mut c = minimal();
            c.url = url.into();
            assert!(
                matches!(c.validate(), Err(FaucetError::Config(_))),
                "url {url:?} should be rejected"
            );
        }
    }

    #[test]
    fn validate_rejects_non_ws_scheme() {
        let mut c = minimal();
        c.url = "https://example.com".into();
        assert!(matches!(c.validate(), Err(FaucetError::Config(_))));
    }

    #[test]
    fn validate_rejects_no_termination() {
        let mut c = minimal();
        c.max_messages = None;
        c.idle_timeout = None;
        assert!(matches!(c.validate(), Err(FaucetError::Config(_))));
    }

    #[test]
    fn validate_accepts_idle_only() {
        let mut c = minimal();
        c.max_messages = None;
        c.idle_timeout = Some(Duration::from_secs(5));
        assert!(c.validate().is_ok());
    }

    #[test]
    fn validate_rejects_oversize_batch() {
        let mut c = minimal();
        c.batch_size = faucet_core::MAX_BATCH_SIZE + 1;
        assert!(c.validate().is_err());
    }

    #[test]
    fn auth_bearer_round_trips_as_adjacently_tagged() {
        // Adjacent tagging: variant name under "type", variant fields under "config".
        let json = serde_json::json!({"type": "bearer", "config": {"token": "abc"}});
        let auth: WebsocketAuth = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(
            auth,
            WebsocketAuth::Bearer {
                token: "abc".into()
            }
        );
        assert_eq!(serde_json::to_value(&auth).unwrap(), json);
    }

    #[test]
    fn auth_custom_round_trips() {
        let json = serde_json::json!({
            "type": "custom",
            "config": {"headers": {"X-Api-Key": "k"}}
        });
        let auth: WebsocketAuth = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(
            auth,
            WebsocketAuth::Custom {
                headers: BTreeMap::from([("X-Api-Key".to_string(), "k".to_string())])
            }
        );
        assert_eq!(serde_json::to_value(&auth).unwrap(), json);
    }

    #[test]
    fn auth_none_round_trips_and_is_default() {
        // Unit variants carry only the tag under adjacent tagging.
        let json = serde_json::json!({"type": "none"});
        let auth: WebsocketAuth = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(auth, WebsocketAuth::None);
        assert_eq!(auth, WebsocketAuth::default());
        assert_eq!(serde_json::to_value(&auth).unwrap(), json);
    }

    #[test]
    fn auth_spec_inline_round_trips() {
        // A payload without a "ref" key deserializes as an inline auth definition.
        let json = serde_json::json!({"type": "bearer", "config": {"token": "tok"}});
        let spec: AuthSpec<WebsocketAuth> = serde_json::from_value(json).unwrap();
        assert!(matches!(
            spec,
            AuthSpec::Inline(WebsocketAuth::Bearer { .. })
        ));
    }

    #[test]
    fn auth_spec_ref_round_trips() {
        let json = serde_json::json!({"ref": "my-provider"});
        let spec: AuthSpec<WebsocketAuth> = serde_json::from_value(json).unwrap();
        assert_eq!(spec.reference_name(), Some("my-provider"));
    }
}
314
#[cfg(test)]
mod helper_tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn decode_json_object() {
        let v = decode_frame(WsMessageFormat::Json, OnParseError::Fail, br#"{"a":1}"#)
            .unwrap()
            .unwrap();
        assert_eq!(v, json!({"a": 1}));
    }

    #[test]
    fn decode_json_scalar() {
        let v = decode_frame(WsMessageFormat::Json, OnParseError::Fail, b"42")
            .unwrap()
            .unwrap();
        assert_eq!(v, json!(42));
    }

    #[test]
    fn decode_json_invalid_fails() {
        let r = decode_frame(WsMessageFormat::Json, OnParseError::Fail, b"not json");
        assert!(matches!(r, Err(FaucetError::Source(_))));
    }

    #[test]
    fn decode_json_invalid_skipped_yields_none() {
        let r = decode_frame(WsMessageFormat::Json, OnParseError::Skip, b"not json").unwrap();
        assert!(r.is_none());
    }

    #[test]
    fn decode_json_valid_unaffected_by_skip() {
        let v = decode_frame(WsMessageFormat::Json, OnParseError::Skip, br#"[1,2]"#)
            .unwrap()
            .unwrap();
        assert_eq!(v, json!([1, 2]));
    }

    #[test]
    fn decode_raw_string() {
        let v = decode_frame(WsMessageFormat::RawString, OnParseError::Fail, b"hello")
            .unwrap()
            .unwrap();
        assert_eq!(v, json!("hello"));
    }

    #[test]
    fn decode_raw_string_keeps_json_text_as_string() {
        let v = decode_frame(WsMessageFormat::RawString, OnParseError::Fail, br#"{"a":1}"#)
            .unwrap()
            .unwrap();
        assert_eq!(v, json!(r#"{"a":1}"#));
    }

    #[test]
    fn decode_raw_string_replaces_invalid_utf8() {
        let v = decode_frame(WsMessageFormat::RawString, OnParseError::Fail, b"hi\xff")
            .unwrap()
            .unwrap();
        assert_eq!(v, json!("hi\u{FFFD}"));
    }

    #[test]
    fn decode_binary_base64() {
        // "hello" in standard, padded base64.
        let v = decode_frame(WsMessageFormat::Binary, OnParseError::Fail, b"hello")
            .unwrap()
            .unwrap();
        assert_eq!(v, json!("aGVsbG8="));
    }

    #[test]
    fn decode_binary_empty_payload() {
        let v = decode_frame(WsMessageFormat::Binary, OnParseError::Fail, b"")
            .unwrap()
            .unwrap();
        assert_eq!(v, json!(""));
    }

    #[test]
    fn shape_raw_passthrough() {
        let v = shape_record(json!({"a": 1}), false, "wss://x", 123);
        assert_eq!(v, json!({"a": 1}));
    }

    #[test]
    fn shape_enveloped() {
        let v = shape_record(json!({"a": 1}), true, "wss://x", 123);
        assert_eq!(
            v,
            json!({"data": {"a": 1}, "received_at": 123, "url": "wss://x"})
        );
    }

    #[test]
    fn shape_enveloped_non_object() {
        let v = shape_record(json!("text"), true, "wss://x", 7);
        assert_eq!(
            v,
            json!({"data": "text", "received_at": 7, "url": "wss://x"})
        );
    }
}