use base64::Engine;
use faucet_core::{AuthSpec, DEFAULT_BATCH_SIZE, FaucetError};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::BTreeMap;
use std::time::Duration;
// Configuration for the websocket source.
//
// Only `url` lacks a serde default, so it is the only required key; the
// cross-field rules serde cannot express are enforced by `validate`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct WebsocketSourceConfig {
    // Endpoint to connect to. `validate` requires it to be non-empty and to
    // start with `ws://` or `wss://` (both checked on the trimmed value).
    pub url: String,
    // Authentication, given inline (e.g. `{"type": "bearer", "config": {...}}`)
    // or as a named reference (`{"ref": "..."}`); `AuthSpec::default()` when omitted.
    #[serde(default)]
    pub auth: AuthSpec<WebsocketAuth>,
    // Raw text messages, presumably sent after connecting to subscribe to
    // channels. NOTE(review): confirm send timing in the runtime. Empty by default.
    #[serde(default)]
    pub subscribe_messages: Vec<String>,
    // How each frame payload becomes a record (see `decode_frame`); JSON by default.
    #[serde(default)]
    pub message_format: WsMessageFormat,
    // Policy for payloads that fail to parse; `decode_frame` only consults it
    // for `WsMessageFormat::Json`. Fails by default.
    #[serde(default)]
    pub on_parse_error: OnParseError,
    // When true, each record is wrapped as `{"data", "received_at", "url"}`
    // (see `shape_record`); false by default.
    #[serde(default)]
    pub envelope: bool,
    // Optional ping interval, (de)serialized through `duration_secs_option`
    // (schema type u64, i.e. seconds); omitted from serialized output when unset.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "faucet_core::config::duration_secs_option"
    )]
    #[schemars(with = "Option<u64>")]
    pub ping_interval: Option<Duration>,
    // Optional message cap. `validate` requires this or `idle_timeout` to be set.
    // NOTE(review): presumably the source stops after this many messages — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_messages: Option<usize>,
    // Optional idle timeout in seconds (same encoding as `ping_interval`).
    // `validate` requires this or `max_messages` to be set.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "faucet_core::config::duration_secs_option"
    )]
    #[schemars(with = "Option<u64>")]
    pub idle_timeout: Option<Duration>,
    // Whether to reconnect, presumably after the connection drops; false by default.
    #[serde(default)]
    pub reconnect: bool,
    // Presumably the wait before a reconnect attempt. Encoded as seconds
    // (`duration_secs`, schema u64); defaults to 1 s via `default_backoff`.
    #[serde(
        default = "default_backoff",
        with = "faucet_core::config::duration_secs"
    )]
    #[schemars(with = "u64")]
    pub reconnect_backoff: Duration,
    // Optional limit on reconnect attempts. NOTE(review): meaning of `None`
    // (presumably unlimited) is decided by the runtime — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_reconnect_attempts: Option<usize>,
    // Optional size limit, presumably per message in bytes — confirm enforcement.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_message_bytes: Option<usize>,
    // Records per batch; defaults to `DEFAULT_BATCH_SIZE` and is range-checked
    // by `faucet_core::validate_batch_size` inside `validate`.
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
}
/// Serde default for `reconnect_backoff`: one second.
fn default_backoff() -> Duration {
    Duration::from_millis(1_000)
}
/// Serde default for `batch_size`: the crate-wide `DEFAULT_BATCH_SIZE` from `faucet_core`.
fn default_batch_size() -> usize {
    DEFAULT_BATCH_SIZE
}
// Authentication options for the websocket connection.
//
// Adjacently tagged on the wire: `{"type": "<snake_case variant>", "config": {...}}`,
// e.g. `{"type": "bearer", "config": {"token": "..."}}`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(tag = "type", content = "config", rename_all = "snake_case")]
pub enum WebsocketAuth {
    // No authentication; the default.
    #[default]
    None,
    // A bearer token. NOTE(review): presumably sent as `Authorization: Bearer <token>`
    // by the connector — confirm.
    Bearer { token: String },
    // Arbitrary header name/value pairs, kept in sorted (BTreeMap) order.
    // NOTE(review): presumably added to the handshake request — confirm.
    Custom { headers: BTreeMap<String, String> },
}
// How `decode_frame` turns a frame payload into a record value.
// Written in config as `json`, `raw_string` or `binary`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WsMessageFormat {
    // Parse the payload as JSON; the default.
    #[default]
    Json,
    // Keep the payload as a JSON string; invalid UTF-8 becomes U+FFFD.
    RawString,
    // Store the payload as a standard base64-encoded JSON string.
    Binary,
}
// What `decode_frame` does when a `WsMessageFormat::Json` payload fails to parse.
// Written in config as `fail` or `skip`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OnParseError {
    // Return a `FaucetError::Source`; the default.
    #[default]
    Fail,
    // Log a warning and drop the frame.
    Skip,
}
impl WebsocketSourceConfig {
    /// Checks the configuration for mistakes that can be detected before connecting.
    ///
    /// # Errors
    ///
    /// Returns [`FaucetError::Config`] when:
    /// - `url` is empty or whitespace-only;
    /// - `url` (after trimming) does not start with `ws://` or `wss://`;
    /// - `url` has no host between the scheme and the first `/`, `?` or `#`;
    /// - `ping_interval` is set to zero;
    /// - neither `max_messages` nor `idle_timeout` is set.
    ///
    /// Errors from `faucet_core::validate_batch_size` are propagated unchanged.
    pub fn validate(&self) -> Result<(), FaucetError> {
        let url = self.url.trim();
        if url.is_empty() {
            return Err(FaucetError::Config(
                "websocket source: url must not be empty".into(),
            ));
        }
        let Some(after_scheme) = url
            .strip_prefix("ws://")
            .or_else(|| url.strip_prefix("wss://"))
        else {
            return Err(FaucetError::Config(format!(
                "websocket source: url must start with ws:// or wss:// (got {url})"
            )));
        };
        // "ws://" alone or "wss:///path" passes the prefix check but names no
        // host, so it could never connect; catch it here instead of at runtime.
        let authority = after_scheme
            .split(['/', '?', '#'])
            .next()
            .unwrap_or_default();
        if authority.is_empty() {
            return Err(FaucetError::Config(format!(
                "websocket source: url must include a host (got {url})"
            )));
        }
        // A zero ping period is never meaningful, and common timer APIs
        // (e.g. tokio's `interval`) panic on it; reject it up front.
        if self.ping_interval.is_some_and(|interval| interval.is_zero()) {
            return Err(FaucetError::Config(
                "websocket source: ping_interval must be greater than zero".into(),
            ));
        }
        // Require at least one termination condition for the stream.
        if self.max_messages.is_none() && self.idle_timeout.is_none() {
            return Err(FaucetError::Config(
                "websocket source: at least one of max_messages or idle_timeout must be set".into(),
            ));
        }
        faucet_core::validate_batch_size(self.batch_size)?;
        Ok(())
    }

    /// Returns this config with `batch_size` replaced.
    ///
    /// The new size is not checked here; call [`Self::validate`] afterwards.
    #[must_use]
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }
}
/// Turns a raw frame payload into a record value according to `format`.
///
/// - [`WsMessageFormat::Json`]: parses the bytes as JSON. If parsing fails, `on_parse_error`
///   chooses between returning [`FaucetError::Source`] ([`OnParseError::Fail`]) and logging a
///   warning then returning `Ok(None)` ([`OnParseError::Skip`]).
/// - [`WsMessageFormat::RawString`]: a JSON string; invalid UTF-8 is replaced with U+FFFD.
/// - [`WsMessageFormat::Binary`]: a JSON string holding the standard (padded) base64 encoding.
///
/// The two string formats cannot fail and ignore `on_parse_error`.
pub(crate) fn decode_frame(
    format: WsMessageFormat,
    on_parse_error: OnParseError,
    payload: &[u8],
) -> Result<Option<Value>, FaucetError> {
    let text = match format {
        WsMessageFormat::RawString => String::from_utf8_lossy(payload).into_owned(),
        WsMessageFormat::Binary => base64::engine::general_purpose::STANDARD.encode(payload),
        WsMessageFormat::Json => {
            // JSON is the only format that can reject a payload, so it answers directly.
            return match (serde_json::from_slice::<Value>(payload), on_parse_error) {
                (Ok(parsed), _) => Ok(Some(parsed)),
                (Err(e), OnParseError::Fail) => {
                    Err(FaucetError::Source(format!("websocket json parse: {e}")))
                }
                (Err(e), OnParseError::Skip) => {
                    tracing::warn!(error = %e, "websocket source: dropping non-JSON frame");
                    Ok(None)
                }
            };
        }
    };
    Ok(Some(Value::String(text)))
}
/// Returns `value` unchanged, or wrapped in an envelope when `envelope` is true.
///
/// The envelope is `{"data": value, "received_at": now_ms, "url": url}`. `now_ms` is stored
/// exactly as given; presumably the caller passes a Unix timestamp in milliseconds.
pub(crate) fn shape_record(value: Value, envelope: bool, url: &str, now_ms: u64) -> Value {
    if !envelope {
        return value;
    }
    json!({ "data": value, "received_at": now_ms, "url": url })
}
#[cfg(test)]
mod config_tests {
    use super::*;

    /// A config that passes `validate`; each test mutates only what it exercises.
    fn minimal() -> WebsocketSourceConfig {
        WebsocketSourceConfig {
            url: "wss://example.com/ws".into(),
            auth: AuthSpec::Inline(WebsocketAuth::None),
            subscribe_messages: vec![],
            message_format: WsMessageFormat::Json,
            on_parse_error: OnParseError::Fail,
            envelope: false,
            ping_interval: None,
            max_messages: Some(10),
            idle_timeout: None,
            reconnect: false,
            reconnect_backoff: Duration::from_secs(1),
            max_reconnect_attempts: None,
            max_message_bytes: None,
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }

    #[test]
    fn validate_accepts_minimal() {
        assert!(minimal().validate().is_ok());
    }

    #[test]
    fn validate_accepts_plain_ws_scheme() {
        let mut c = minimal();
        c.url = "ws://localhost:8080/stream".into();
        assert!(c.validate().is_ok());
    }

    #[test]
    fn validate_rejects_empty_url() {
        let mut c = minimal();
        c.url = " ".into();
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_non_ws_scheme() {
        let mut c = minimal();
        c.url = "https://example.com".into();
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_no_termination() {
        let mut c = minimal();
        c.max_messages = None;
        c.idle_timeout = None;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_accepts_idle_only() {
        let mut c = minimal();
        c.max_messages = None;
        c.idle_timeout = Some(Duration::from_secs(5));
        assert!(c.validate().is_ok());
    }

    #[test]
    fn validate_rejects_oversize_batch() {
        let mut c = minimal();
        c.batch_size = faucet_core::MAX_BATCH_SIZE + 1;
        assert!(c.validate().is_err());
    }

    #[test]
    fn deserialize_fills_defaults() {
        let c: WebsocketSourceConfig = serde_json::from_value(serde_json::json!({
            "url": "wss://example.com/ws",
            "max_messages": 5
        }))
        .unwrap();
        assert!(c.subscribe_messages.is_empty());
        assert_eq!(c.message_format, WsMessageFormat::Json);
        assert_eq!(c.on_parse_error, OnParseError::Fail);
        assert!(!c.envelope);
        assert!(c.ping_interval.is_none());
        assert!(c.idle_timeout.is_none());
        assert!(!c.reconnect);
        assert_eq!(c.reconnect_backoff, Duration::from_secs(1));
        assert!(c.max_reconnect_attempts.is_none());
        assert!(c.max_message_bytes.is_none());
        assert_eq!(c.batch_size, DEFAULT_BATCH_SIZE);
        assert!(c.validate().is_ok());
    }

    #[test]
    fn format_and_parse_error_enums_use_snake_case() {
        let format: WsMessageFormat =
            serde_json::from_value(serde_json::json!("raw_string")).unwrap();
        assert_eq!(format, WsMessageFormat::RawString);
        assert_eq!(
            serde_json::to_value(WsMessageFormat::Binary).unwrap(),
            serde_json::json!("binary")
        );
        assert_eq!(
            serde_json::to_value(OnParseError::Skip).unwrap(),
            serde_json::json!("skip")
        );
    }

    #[test]
    fn auth_bearer_round_trips_as_adjacently_tagged() {
        let json = serde_json::json!({"type": "bearer", "config": {"token": "abc"}});
        let auth: WebsocketAuth = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(
            auth,
            WebsocketAuth::Bearer {
                token: "abc".into()
            }
        );
        // Serializing must reproduce the same adjacently tagged shape.
        assert_eq!(serde_json::to_value(&auth).unwrap(), json);
    }

    #[test]
    fn auth_custom_round_trips_as_adjacently_tagged() {
        let json = serde_json::json!({
            "type": "custom",
            "config": {"headers": {"X-Api-Key": "k"}}
        });
        let auth: WebsocketAuth = serde_json::from_value(json.clone()).unwrap();
        let mut headers = BTreeMap::new();
        headers.insert("X-Api-Key".to_string(), "k".to_string());
        assert_eq!(auth, WebsocketAuth::Custom { headers });
        assert_eq!(serde_json::to_value(&auth).unwrap(), json);
    }

    #[test]
    fn auth_spec_inline_deserializes() {
        let json = serde_json::json!({"type": "bearer", "config": {"token": "tok"}});
        let spec: AuthSpec<WebsocketAuth> = serde_json::from_value(json).unwrap();
        assert!(matches!(
            spec,
            AuthSpec::Inline(WebsocketAuth::Bearer { .. })
        ));
    }

    #[test]
    fn auth_spec_ref_deserializes() {
        let json = serde_json::json!({"ref": "my-provider"});
        let spec: AuthSpec<WebsocketAuth> = serde_json::from_value(json).unwrap();
        assert_eq!(spec.reference_name(), Some("my-provider"));
    }
}
#[cfg(test)]
mod helper_tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn decode_json_object() {
        let v = decode_frame(WsMessageFormat::Json, OnParseError::Fail, br#"{"a":1}"#)
            .unwrap()
            .unwrap();
        assert_eq!(v, json!({"a": 1}));
    }

    #[test]
    fn decode_json_scalar() {
        let v = decode_frame(WsMessageFormat::Json, OnParseError::Fail, b"42")
            .unwrap()
            .unwrap();
        assert_eq!(v, json!(42));
    }

    #[test]
    fn decode_json_invalid_fails() {
        let r = decode_frame(WsMessageFormat::Json, OnParseError::Fail, b"not json");
        assert!(r.is_err());
    }

    #[test]
    fn decode_json_invalid_skipped_yields_none() {
        let r = decode_frame(WsMessageFormat::Json, OnParseError::Skip, b"not json").unwrap();
        assert!(r.is_none());
    }

    #[test]
    fn decode_json_valid_with_skip_policy_still_decodes() {
        let r = decode_frame(WsMessageFormat::Json, OnParseError::Skip, b"[1,2]").unwrap();
        assert_eq!(r, Some(json!([1, 2])));
    }

    #[test]
    fn decode_raw_string() {
        let v = decode_frame(WsMessageFormat::RawString, OnParseError::Fail, b"hello")
            .unwrap()
            .unwrap();
        assert_eq!(v, json!("hello"));
    }

    #[test]
    fn decode_raw_string_replaces_invalid_utf8() {
        // RawString never fails: invalid bytes become U+FFFD even under `Fail`.
        let v = decode_frame(WsMessageFormat::RawString, OnParseError::Fail, &[b'f', 0xff])
            .unwrap()
            .unwrap();
        assert_eq!(v, json!("f\u{FFFD}"));
    }

    #[test]
    fn decode_binary_base64() {
        let v = decode_frame(WsMessageFormat::Binary, OnParseError::Fail, b"hello")
            .unwrap()
            .unwrap();
        assert_eq!(v, json!("aGVsbG8="));
    }

    #[test]
    fn decode_binary_non_utf8_and_empty() {
        let v = decode_frame(WsMessageFormat::Binary, OnParseError::Fail, &[0xff, 0x00])
            .unwrap()
            .unwrap();
        assert_eq!(v, json!("/wA="));
        let empty = decode_frame(WsMessageFormat::Binary, OnParseError::Fail, b"")
            .unwrap()
            .unwrap();
        assert_eq!(empty, json!(""));
    }

    #[test]
    fn shape_raw_passthrough() {
        let v = shape_record(json!({"a": 1}), false, "wss://x", 123);
        assert_eq!(v, json!({"a": 1}));
    }

    #[test]
    fn shape_enveloped() {
        let v = shape_record(json!({"a": 1}), true, "wss://x", 123);
        assert_eq!(
            v,
            json!({"data": {"a": 1}, "received_at": 123, "url": "wss://x"})
        );
    }
}