1use crate::config::{WebsocketAuth, WebsocketSourceConfig, decode_frame, shape_record};
4use async_trait::async_trait;
5use base64::Engine;
6use faucet_core::{
7 AuthSpec, Credential, FaucetError, SharedAuthProvider, Source, Stream, StreamPage,
8};
9use futures::{SinkExt, StreamExt};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::pin::Pin;
13use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
14use tokio::net::TcpStream;
15use tokio_tungstenite::tungstenite::client::IntoClientRequest;
16use tokio_tungstenite::tungstenite::handshake::client::Request;
17use tokio_tungstenite::tungstenite::http::{HeaderName, HeaderValue, header};
18use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
19use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
21
/// Client websocket stream type returned by `connect_async_with_config`
/// (plain TCP or TLS, depending on the URL scheme).
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
23
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` when the system clock reports a time before the epoch, so
/// callers never have to handle a clock error.
fn now_unix_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
30
31fn credential_to_auth(cred: Credential) -> WebsocketAuth {
37 use std::collections::BTreeMap;
38 match cred {
39 Credential::Bearer(token) => WebsocketAuth::Bearer { token },
40 Credential::Token(t) => WebsocketAuth::Custom {
41 headers: BTreeMap::from([("Authorization".to_string(), t)]),
42 },
43 Credential::Header { name, value } => WebsocketAuth::Custom {
44 headers: BTreeMap::from([(name, value)]),
45 },
46 Credential::Basic { username, password } => {
47 let encoded =
48 base64::engine::general_purpose::STANDARD.encode(format!("{username}:{password}"));
49 WebsocketAuth::Custom {
50 headers: BTreeMap::from([(
51 "Authorization".to_string(),
52 format!("Basic {encoded}"),
53 )]),
54 }
55 }
56 }
57}
58
59pub(crate) fn apply_auth(request: &mut Request, auth: &WebsocketAuth) -> Result<(), FaucetError> {
61 let headers = request.headers_mut();
62 match auth {
63 WebsocketAuth::None => {}
64 WebsocketAuth::Bearer { token } => {
65 let value = HeaderValue::from_str(&format!("Bearer {token}"))
66 .map_err(|e| FaucetError::Config(format!("websocket bearer header: {e}")))?;
67 headers.insert(header::AUTHORIZATION, value);
68 }
69 WebsocketAuth::Custom { headers: custom } => {
70 for (k, v) in custom {
71 let name = HeaderName::from_bytes(k.as_bytes())
72 .map_err(|e| FaucetError::Config(format!("websocket header name {k}: {e}")))?;
73 let value = HeaderValue::from_str(v)
74 .map_err(|e| FaucetError::Config(format!("websocket header value {k}: {e}")))?;
75 headers.insert(name, value);
76 }
77 }
78 }
79 Ok(())
80}
81
/// Source connector that reads records from a websocket endpoint.
///
/// Construct with [`WebsocketSource::new`] (which validates the config) and
/// optionally attach a shared auth provider with
/// [`WebsocketSource::with_auth_provider`].
pub struct WebsocketSource {
    /// Connector configuration, already checked by `validate()` in `new`.
    config: WebsocketSourceConfig,
    /// Optional shared auth provider. When set, its credential is fetched on
    /// every (re)connect and takes precedence over `config.auth`.
    auth_provider: Option<SharedAuthProvider>,
}
90
91impl WebsocketSource {
92 pub fn new(config: WebsocketSourceConfig) -> Result<Self, FaucetError> {
96 config.validate()?;
97 Ok(Self {
98 config,
99 auth_provider: None,
100 })
101 }
102
103 pub fn with_auth_provider(mut self, provider: SharedAuthProvider) -> Self {
108 self.auth_provider = Some(provider);
109 self
110 }
111
112 async fn connect(&self, url: &str) -> Result<WsStream, FaucetError> {
117 let mut request = url
118 .into_client_request()
119 .map_err(|e| FaucetError::Config(format!("websocket url {url}: {e}")))?;
120
121 let effective_auth = if let Some(p) = &self.auth_provider {
123 credential_to_auth(p.credential().await?)
124 } else {
125 match &self.config.auth {
126 AuthSpec::Inline(a) => a.clone(),
127 AuthSpec::Reference(r) => {
128 return Err(FaucetError::Auth(format!(
129 "auth references provider '{}' but no provider was supplied; \
130 set one via the CLI `auth:` catalog or `with_auth_provider`",
131 r.name
132 )));
133 }
134 }
135 };
136 apply_auth(&mut request, &effective_auth)?;
137
138 let ws_config = self.config.max_message_bytes.map(|n| {
139 WebSocketConfig::default()
140 .max_message_size(Some(n))
141 .max_frame_size(Some(n))
142 });
143
144 let (mut ws, _resp) = connect_async_with_config(request, ws_config, false)
145 .await
146 .map_err(|e| FaucetError::Source(format!("websocket connect {url}: {e}")))?;
147
148 for msg in &self.config.subscribe_messages {
149 ws.send(Message::Text(msg.clone().into()))
150 .await
151 .map_err(|e| FaucetError::Source(format!("websocket subscribe: {e}")))?;
152 }
153 Ok(ws)
154 }
155}
156
157#[async_trait]
158impl Source for WebsocketSource {
159 async fn fetch_with_context(
164 &self,
165 context: &HashMap<String, Value>,
166 ) -> Result<Vec<Value>, FaucetError> {
167 let mut out = Vec::new();
168 let mut pages = self.stream_pages(context, self.config.batch_size);
169 while let Some(page) = pages.next().await {
170 out.extend(page?.records);
171 }
172 Ok(out)
173 }
174
175 fn stream_pages<'a>(
176 &'a self,
177 context: &'a HashMap<String, Value>,
178 _batch_size: usize,
179 ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
180 let resolved_url = faucet_core::util::substitute_context(&self.config.url, context);
181 let batch_size = self.config.batch_size;
182 let page_chunk = if batch_size == 0 {
183 usize::MAX
184 } else {
185 batch_size
186 };
187 let max_messages = self.config.max_messages.unwrap_or(usize::MAX);
188 let idle_timeout = self.config.idle_timeout;
189 let reconnect = self.config.reconnect;
190 let backoff = self.config.reconnect_backoff;
191 let max_attempts = self.config.max_reconnect_attempts;
192 let ping_interval = self.config.ping_interval;
193 let format = self.config.message_format;
194 let on_parse_error = self.config.on_parse_error;
195 let envelope = self.config.envelope;
196
197 Box::pin(async_stream::try_stream! {
198 let mut buffer: Vec<Value> = Vec::new();
199 let mut total: usize = 0;
200 let mut last_message_at = Instant::now();
201 let mut reconnect_attempts: usize = 0;
202
203 'outer: loop {
204 if let Some(t) = idle_timeout
206 && Instant::now() >= last_message_at + t
207 {
208 tracing::debug!("websocket source: idle_timeout reached, stopping");
209 break 'outer;
210 }
211
212 let ws = match self.connect(&resolved_url).await {
214 Ok(ws) => {
215 reconnect_attempts = 0;
216 ws
217 }
218 Err(e) => {
219 if reconnect
220 && max_attempts.is_none_or(|m| reconnect_attempts < m)
221 {
222 reconnect_attempts += 1;
223 tracing::warn!(error = %e, attempt = reconnect_attempts, "websocket source: connect failed, retrying");
224 tokio::time::sleep(backoff).await;
225 continue 'outer;
226 }
227 Err(e)?;
228 break 'outer; }
230 };
231
232 let (mut write, mut read) = ws.split();
233 let mut ping_timer = ping_interval.map(|interval| {
237 tokio::time::interval_at(tokio::time::Instant::now() + interval, interval)
238 });
239
240 loop {
241 let idle_deadline = idle_timeout.map(|t| last_message_at + t);
242 let poll_budget = match idle_deadline {
243 Some(d) => d.saturating_duration_since(Instant::now()),
244 None => Duration::from_secs(3600),
245 };
246
247 let mut stop = false;
250 let mut fatal: Option<FaucetError> = None;
251 let mut reconnect_now = false;
252
253 let mut handle_payload = |payload: &[u8]| {
258 last_message_at = Instant::now();
267 match decode_frame(format, on_parse_error, payload) {
268 Ok(Some(v)) => {
269 let now = if envelope { now_unix_ms() } else { 0 };
270 buffer.push(shape_record(v, envelope, &resolved_url, now));
271 reconnect_attempts = 0;
272 total += 1;
273 if total >= max_messages {
274 stop = true;
275 }
276 }
277 Ok(None) => {}
278 Err(e) => fatal = Some(e),
279 }
280 };
281
282 tokio::select! {
283 biased;
284 _ = tokio::signal::ctrl_c() => {
285 tracing::info!("websocket source: ctrl_c received, stopping cleanly");
286 stop = true;
287 }
288 _ = async { ping_timer.as_mut().unwrap().tick().await }, if ping_timer.is_some() => {
289 if let Err(e) = write.send(Message::Ping(Vec::new().into())).await {
290 tracing::warn!(error = %e, "websocket source: ping failed, treating as disconnect");
291 reconnect_now = true;
292 }
293 }
294 recv = tokio::time::timeout(poll_budget, read.next()) => {
295 match recv {
296 Ok(Some(Ok(msg))) => {
297 match msg {
298 Message::Text(t) => handle_payload(t.as_bytes()),
299 Message::Binary(b) => handle_payload(&b),
300 Message::Ping(payload) => {
301 if let Err(e) = write.send(Message::Pong(payload)).await {
302 tracing::warn!(error = %e, "websocket source: pong failed");
303 reconnect_now = true;
304 }
305 }
306 Message::Pong(_) | Message::Frame(_) => {}
307 Message::Close(frame) => {
308 let clean = frame
309 .as_ref()
310 .map(|f| f.code == CloseCode::Normal)
311 .unwrap_or(true);
312 if clean && !reconnect {
313 tracing::info!("websocket source: server closed (1000), stopping");
314 stop = true;
315 } else {
316 tracing::warn!(?frame, "websocket source: connection closed");
317 reconnect_now = true;
318 }
319 }
320 }
321 }
322 Ok(Some(Err(e))) => {
323 tracing::warn!(error = %e, "websocket source: read error");
324 reconnect_now = true;
325 }
326 Ok(None) => {
327 tracing::warn!("websocket source: stream ended");
328 reconnect_now = true;
329 }
330 Err(_elapsed) => {
331 if let Some(d) = idle_deadline
332 && Instant::now() >= d
333 {
334 tracing::debug!("websocket source: idle_timeout reached, stopping");
335 stop = true;
336 }
337 }
338 }
339 }
340 }
341
342 if let Some(e) = fatal {
343 Err(e)?;
344 }
345
346 if !buffer.is_empty() && buffer.len() >= page_chunk {
347 let page = std::mem::take(&mut buffer);
348 yield StreamPage { records: page, bookmark: None };
349 }
350
351 if stop {
352 break 'outer;
353 }
354
355 if reconnect_now {
356 if reconnect && max_attempts.is_none_or(|m| reconnect_attempts < m) {
363 reconnect_attempts += 1;
364 tracing::warn!(attempt = reconnect_attempts, "websocket source: reconnecting");
365 tokio::time::sleep(backoff).await;
366 continue 'outer;
367 } else if reconnect {
368 Err(FaucetError::Source(format!(
369 "websocket source: exceeded max_reconnect_attempts ({})",
370 max_attempts.unwrap_or(0)
371 )))?;
372 } else {
373 Err(FaucetError::Source(
374 "websocket source: connection closed and reconnect=false".into(),
375 ))?;
376 }
377 }
378 }
379 }
380
381 if !buffer.is_empty() {
382 yield StreamPage { records: buffer, bookmark: None };
383 }
384
385 tracing::info!(messages = total, "websocket source: stream complete");
386 })
387 }
388
389 fn config_schema(&self) -> Value {
390 let schema = schemars::schema_for!(WebsocketSourceConfig);
391 serde_json::to_value(&schema).unwrap_or(Value::Null)
392 }
393
394 fn connector_name(&self) -> &'static str {
395 "websocket"
396 }
397
398 async fn check(
408 &self,
409 ctx: &faucet_core::check::CheckContext,
410 ) -> Result<faucet_core::check::CheckReport, FaucetError> {
411 use faucet_core::check::{CheckReport, Probe};
412
413 let start = std::time::Instant::now();
414
415 let (host, port) = match resolve_host_port(&self.config.url) {
419 Ok(hp) => hp,
420 Err(reason) => {
421 return Ok(CheckReport::single(Probe::fail_hint(
422 "network",
423 start.elapsed(),
424 reason,
425 "url must be ws://host[:port]/... or wss://host[:port]/...",
426 )));
427 }
428 };
429
430 let connect = tokio::net::TcpStream::connect((host.as_str(), port));
431 match tokio::time::timeout(ctx.timeout, connect).await {
432 Ok(Ok(stream)) => {
433 drop(stream);
434 Ok(CheckReport::single(Probe::pass("network", start.elapsed())))
435 }
436 Ok(Err(e)) => Ok(CheckReport::single(Probe::fail_hint(
437 "network",
438 start.elapsed(),
439 e.to_string(),
440 format!("cannot reach {host}:{port} over TCP"),
441 ))),
442 Err(_elapsed) => Ok(CheckReport::single(Probe::fail_hint(
443 "network",
444 start.elapsed(),
445 format!("TCP connect to {host}:{port} timed out"),
446 format!("{host}:{port} did not accept a connection within the check timeout"),
447 ))),
448 }
449 }
450}
451
452fn resolve_host_port(url: &str) -> Result<(String, u16), String> {
458 let request = url
459 .into_client_request()
460 .map_err(|e| format!("invalid websocket url: {e}"))?;
461 let uri = request.uri();
462 let host = uri
463 .host()
464 .filter(|h| !h.is_empty())
465 .ok_or_else(|| "websocket url is missing a host".to_string())?
466 .to_string();
467 let default_port = match uri.scheme_str() {
468 Some("wss") => 443,
469 _ => 80,
470 };
471 let port = uri.port_u16().unwrap_or(default_port);
472 Ok((host, port))
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// Minimal valid config pointing at `url`. No auth, JSON frames, stop
    /// after one message, no reconnects.
    fn config_for(url: String) -> WebsocketSourceConfig {
        WebsocketSourceConfig {
            url,
            auth: AuthSpec::Inline(WebsocketAuth::None),
            subscribe_messages: vec![],
            message_format: crate::config::WsMessageFormat::Json,
            on_parse_error: crate::config::OnParseError::Fail,
            envelope: false,
            ping_interval: None,
            max_messages: Some(1),
            idle_timeout: None,
            reconnect: false,
            reconnect_backoff: Duration::from_secs(1),
            max_reconnect_attempts: None,
            max_message_bytes: None,
            batch_size: faucet_core::DEFAULT_BATCH_SIZE,
        }
    }

    /// `WebsocketAuth::Custom` holding exactly one header.
    fn single_header(name: &str, value: &str) -> WebsocketAuth {
        WebsocketAuth::Custom {
            headers: BTreeMap::from([(name.to_string(), value.to_string())]),
        }
    }

    #[test]
    fn credential_bearer_maps_to_bearer() {
        let expected = WebsocketAuth::Bearer {
            token: "tok".into(),
        };
        assert_eq!(credential_to_auth(Credential::Bearer("tok".into())), expected);
    }

    #[test]
    fn credential_token_maps_to_custom_authorization() {
        let auth = credential_to_auth(Credential::Token("Custom xyz".into()));
        assert_eq!(auth, single_header("Authorization", "Custom xyz"));
    }

    #[test]
    fn credential_header_maps_to_custom() {
        let auth = credential_to_auth(Credential::Header {
            name: "X-Api-Key".into(),
            value: "k123".into(),
        });
        assert_eq!(auth, single_header("X-Api-Key", "k123"));
    }

    #[test]
    fn credential_basic_maps_to_base64_authorization() {
        // base64("user:pass") == "dXNlcjpwYXNz"
        let auth = credential_to_auth(Credential::Basic {
            username: "user".into(),
            password: "pass".into(),
        });
        assert_eq!(auth, single_header("Authorization", "Basic dXNlcjpwYXNz"));
    }

    #[test]
    fn resolve_host_port_applies_scheme_defaults() {
        let cases: [(&str, &str, u16); 3] = [
            ("ws://example.com/feed", "example.com", 80),
            ("wss://example.com/feed", "example.com", 443),
            ("wss://example.com:9443/feed", "example.com", 9443),
        ];
        for (url, host, port) in cases {
            assert_eq!(
                resolve_host_port(url).unwrap(),
                (host.to_string(), port),
                "url: {url}"
            );
        }
    }

    #[tokio::test]
    async fn check_passes_against_a_live_tcp_listener() {
        use faucet_core::check::{CheckContext, ProbeStatus};

        // Keep the listener bound for the whole test. The kernel completes the
        // TCP handshake even though nothing calls `accept`.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let source = WebsocketSource::new(config_for(format!("ws://{addr}/feed"))).unwrap();
        let report = source.check(&CheckContext::default()).await.unwrap();

        assert_eq!(report.probes.len(), 1);
        let probe = &report.probes[0];
        assert_eq!(probe.name, "network");
        assert!(
            matches!(probe.status, ProbeStatus::Pass),
            "expected Pass, got {:?}",
            probe.status
        );
    }

    #[tokio::test]
    async fn check_fails_against_a_closed_port() {
        use faucet_core::check::{CheckContext, ProbeStatus};

        // Bind to learn a free port, then drop the listener so it is closed.
        let addr = {
            let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap()
        };

        let source = WebsocketSource::new(config_for(format!("ws://{addr}/feed"))).unwrap();
        let ctx = CheckContext {
            timeout: Duration::from_secs(2),
        };
        let report = source.check(&ctx).await.unwrap();

        assert_eq!(report.probes.len(), 1);
        let probe = &report.probes[0];
        assert_eq!(probe.name, "network");
        assert!(
            matches!(probe.status, ProbeStatus::Fail { .. }),
            "expected Fail, got {:?}",
            probe.status
        );
        assert_eq!(report.failed_count(), 1);
    }
}