Skip to main content

faucet_source_websocket/
stream.rs

1//! WebSocket source stream executor.
2
3use crate::config::{WebsocketAuth, WebsocketSourceConfig, decode_frame, shape_record};
4use async_trait::async_trait;
5use base64::Engine;
6use faucet_core::{
7    AuthSpec, Credential, FaucetError, SharedAuthProvider, Source, Stream, StreamPage,
8};
9use futures::{SinkExt, StreamExt};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::pin::Pin;
13use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
14use tokio::net::TcpStream;
15use tokio_tungstenite::tungstenite::client::IntoClientRequest;
16use tokio_tungstenite::tungstenite::handshake::client::Request;
17use tokio_tungstenite::tungstenite::http::{HeaderName, HeaderValue, header};
18use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
19use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
20use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
21
22type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
23
/// Current wall-clock time as whole milliseconds since the Unix epoch.
///
/// Returns `0` when the system clock reads earlier than the epoch, and
/// saturates at `u64::MAX` (instead of silently wrapping through an `as`
/// cast) should the millisecond count ever exceed 64 bits.
fn now_unix_ms() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
        })
}
30
31/// Map a [`Credential`] from a shared provider onto [`WebsocketAuth`] so the
32/// existing header-application path can be reused.
33///
34/// This mapping is infallible: every credential kind can be expressed as either
35/// `Bearer` or `Custom` headers.
36fn credential_to_auth(cred: Credential) -> WebsocketAuth {
37    use std::collections::BTreeMap;
38    match cred {
39        Credential::Bearer(token) => WebsocketAuth::Bearer { token },
40        Credential::Token(t) => WebsocketAuth::Custom {
41            headers: BTreeMap::from([("Authorization".to_string(), t)]),
42        },
43        Credential::Header { name, value } => WebsocketAuth::Custom {
44            headers: BTreeMap::from([(name, value)]),
45        },
46        Credential::Basic { username, password } => {
47            let encoded =
48                base64::engine::general_purpose::STANDARD.encode(format!("{username}:{password}"));
49            WebsocketAuth::Custom {
50                headers: BTreeMap::from([(
51                    "Authorization".to_string(),
52                    format!("Basic {encoded}"),
53                )]),
54            }
55        }
56    }
57}
58
59/// Apply `auth` to the HTTP upgrade `request`.
60pub(crate) fn apply_auth(request: &mut Request, auth: &WebsocketAuth) -> Result<(), FaucetError> {
61    let headers = request.headers_mut();
62    match auth {
63        WebsocketAuth::None => {}
64        WebsocketAuth::Bearer { token } => {
65            let value = HeaderValue::from_str(&format!("Bearer {token}"))
66                .map_err(|e| FaucetError::Config(format!("websocket bearer header: {e}")))?;
67            headers.insert(header::AUTHORIZATION, value);
68        }
69        WebsocketAuth::Custom { headers: custom } => {
70            for (k, v) in custom {
71                let name = HeaderName::from_bytes(k.as_bytes())
72                    .map_err(|e| FaucetError::Config(format!("websocket header name {k}: {e}")))?;
73                let value = HeaderValue::from_str(v)
74                    .map_err(|e| FaucetError::Config(format!("websocket header value {k}: {e}")))?;
75                headers.insert(name, value);
76            }
77        }
78    }
79    Ok(())
80}
81
82/// A WebSocket streaming source.
83pub struct WebsocketSource {
84    config: WebsocketSourceConfig,
85    /// Optional shared auth provider. When present, its [`Credential`] is
86    /// resolved on every (re)connect so a freshly-rotated token is used after
87    /// reconnect. Takes precedence over inline auth.
88    auth_provider: Option<SharedAuthProvider>,
89}
90
91impl WebsocketSource {
92    /// Create a new WebSocket source. Validates the config; the connection is
93    /// established lazily inside the stream loop (so reconnect can re-establish
94    /// it mid-run).
95    pub fn new(config: WebsocketSourceConfig) -> Result<Self, FaucetError> {
96        config.validate()?;
97        Ok(Self {
98            config,
99            auth_provider: None,
100        })
101    }
102
103    /// Attach a shared [`AuthProvider`](faucet_core::AuthProvider). When set,
104    /// the provider's credential is resolved on every (re)connect — so a
105    /// refreshed token is used automatically after reconnect. Takes precedence
106    /// over inline auth.
107    pub fn with_auth_provider(mut self, provider: SharedAuthProvider) -> Self {
108        self.auth_provider = Some(provider);
109        self
110    }
111
112    /// Connect, apply auth + size limits, and send the subscribe frames.
113    ///
114    /// Auth is resolved here (not once at construction) so that reconnects
115    /// always pick up a freshly-rotated token from a shared provider.
116    async fn connect(&self, url: &str) -> Result<WsStream, FaucetError> {
117        let mut request = url
118            .into_client_request()
119            .map_err(|e| FaucetError::Config(format!("websocket url {url}: {e}")))?;
120
121        // Resolve effective auth: provider-first, then inline, or error on Reference.
122        let effective_auth = if let Some(p) = &self.auth_provider {
123            credential_to_auth(p.credential().await?)
124        } else {
125            match &self.config.auth {
126                AuthSpec::Inline(a) => a.clone(),
127                AuthSpec::Reference(r) => {
128                    return Err(FaucetError::Auth(format!(
129                        "auth references provider '{}' but no provider was supplied; \
130                         set one via the CLI `auth:` catalog or `with_auth_provider`",
131                        r.name
132                    )));
133                }
134            }
135        };
136        apply_auth(&mut request, &effective_auth)?;
137
138        let ws_config = self.config.max_message_bytes.map(|n| {
139            WebSocketConfig::default()
140                .max_message_size(Some(n))
141                .max_frame_size(Some(n))
142        });
143
144        let (mut ws, _resp) = connect_async_with_config(request, ws_config, false)
145            .await
146            .map_err(|e| FaucetError::Source(format!("websocket connect {url}: {e}")))?;
147
148        for msg in &self.config.subscribe_messages {
149            ws.send(Message::Text(msg.clone().into()))
150                .await
151                .map_err(|e| FaucetError::Source(format!("websocket subscribe: {e}")))?;
152        }
153        Ok(ws)
154    }
155}
156
157#[async_trait]
158impl Source for WebsocketSource {
159    /// Drain the entire run window into memory. This buffers every record the
160    /// run produces (bounded only by `max_messages` / `idle_timeout`); prefer
161    /// [`Source::stream_pages`] for large or long-running feeds so memory stays
162    /// bounded at `batch_size`.
163    async fn fetch_with_context(
164        &self,
165        context: &HashMap<String, Value>,
166    ) -> Result<Vec<Value>, FaucetError> {
167        let mut out = Vec::new();
168        let mut pages = self.stream_pages(context, self.config.batch_size);
169        while let Some(page) = pages.next().await {
170            out.extend(page?.records);
171        }
172        Ok(out)
173    }
174
175    fn stream_pages<'a>(
176        &'a self,
177        context: &'a HashMap<String, Value>,
178        _batch_size: usize,
179    ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
180        let resolved_url = faucet_core::util::substitute_context(&self.config.url, context);
181        let batch_size = self.config.batch_size;
182        let page_chunk = if batch_size == 0 {
183            usize::MAX
184        } else {
185            batch_size
186        };
187        let max_messages = self.config.max_messages.unwrap_or(usize::MAX);
188        let idle_timeout = self.config.idle_timeout;
189        let reconnect = self.config.reconnect;
190        let backoff = self.config.reconnect_backoff;
191        let max_attempts = self.config.max_reconnect_attempts;
192        let ping_interval = self.config.ping_interval;
193        let format = self.config.message_format;
194        let on_parse_error = self.config.on_parse_error;
195        let envelope = self.config.envelope;
196
197        Box::pin(async_stream::try_stream! {
198            let mut buffer: Vec<Value> = Vec::new();
199            let mut total: usize = 0;
200            let mut last_message_at = Instant::now();
201            let mut reconnect_attempts: usize = 0;
202
203            'outer: loop {
204                // Idle cap also bounds connect-failure spins and reconnect gaps.
205                if let Some(t) = idle_timeout
206                    && Instant::now() >= last_message_at + t
207                {
208                    tracing::debug!("websocket source: idle_timeout reached, stopping");
209                    break 'outer;
210                }
211
212                // (Re)connect.
213                let ws = match self.connect(&resolved_url).await {
214                    Ok(ws) => {
215                        reconnect_attempts = 0;
216                        ws
217                    }
218                    Err(e) => {
219                        if reconnect
220                            && max_attempts.is_none_or(|m| reconnect_attempts < m)
221                        {
222                            reconnect_attempts += 1;
223                            tracing::warn!(error = %e, attempt = reconnect_attempts, "websocket source: connect failed, retrying");
224                            tokio::time::sleep(backoff).await;
225                            continue 'outer;
226                        }
227                        Err(e)?;
228                        break 'outer; // unreachable; satisfies the type checker
229                    }
230                };
231
232                let (mut write, mut read) = ws.split();
233                // Start one interval out so the first `tick()` does not fire
234                // immediately (which would send a Ping before any read on every
235                // (re)connect). `tokio::time::interval` ticks at t=0.
236                let mut ping_timer = ping_interval.map(|interval| {
237                    tokio::time::interval_at(tokio::time::Instant::now() + interval, interval)
238                });
239
240                loop {
241                    let idle_deadline = idle_timeout.map(|t| last_message_at + t);
242                    let poll_budget = match idle_deadline {
243                        Some(d) => d.saturating_duration_since(Instant::now()),
244                        None => Duration::from_secs(3600),
245                    };
246
247                    // Flags collected from the select arms; `?` cannot cross
248                    // the select match boundary into the try_stream! body.
249                    let mut stop = false;
250                    let mut fatal: Option<FaucetError> = None;
251                    let mut reconnect_now = false;
252
253                    // Decode a data-frame payload (Text or Binary), shape it,
254                    // push it, and update the run-window counters. The only
255                    // per-arm difference is `t.as_bytes()` vs `&b`, so both
256                    // arms funnel through this single closure.
257                    let mut handle_payload = |payload: &[u8]| {
258                        // A data frame (Text/Binary) arrived, so the server is
259                        // delivering — reset the idle timer here, before decode,
260                        // so a frame dropped by on_parse_error=skip (Ok(None))
261                        // still counts as activity (#146 M9). Control frames
262                        // (Ping/Pong/Close) deliberately do NOT reset it: a
263                        // client keepalive (ping_interval) elicits pongs, and
264                        // resetting on those would make idle_timeout unreachable
265                        // whenever ping_interval < idle_timeout.
266                        last_message_at = Instant::now();
267                        match decode_frame(format, on_parse_error, payload) {
268                            Ok(Some(v)) => {
269                                let now = if envelope { now_unix_ms() } else { 0 };
270                                buffer.push(shape_record(v, envelope, &resolved_url, now));
271                                reconnect_attempts = 0;
272                                total += 1;
273                                if total >= max_messages {
274                                    stop = true;
275                                }
276                            }
277                            Ok(None) => {}
278                            Err(e) => fatal = Some(e),
279                        }
280                    };
281
282                    tokio::select! {
283                        biased;
284                        _ = tokio::signal::ctrl_c() => {
285                            tracing::info!("websocket source: ctrl_c received, stopping cleanly");
286                            stop = true;
287                        }
288                        _ = async { ping_timer.as_mut().unwrap().tick().await }, if ping_timer.is_some() => {
289                            if let Err(e) = write.send(Message::Ping(Vec::new().into())).await {
290                                tracing::warn!(error = %e, "websocket source: ping failed, treating as disconnect");
291                                reconnect_now = true;
292                            }
293                        }
294                        recv = tokio::time::timeout(poll_budget, read.next()) => {
295                            match recv {
296                                Ok(Some(Ok(msg))) => {
297                                    match msg {
298                                        Message::Text(t) => handle_payload(t.as_bytes()),
299                                        Message::Binary(b) => handle_payload(&b),
300                                        Message::Ping(payload) => {
301                                            if let Err(e) = write.send(Message::Pong(payload)).await {
302                                                tracing::warn!(error = %e, "websocket source: pong failed");
303                                                reconnect_now = true;
304                                            }
305                                        }
306                                        Message::Pong(_) | Message::Frame(_) => {}
307                                        Message::Close(frame) => {
308                                            let clean = frame
309                                                .as_ref()
310                                                .map(|f| f.code == CloseCode::Normal)
311                                                .unwrap_or(true);
312                                            if clean && !reconnect {
313                                                tracing::info!("websocket source: server closed (1000), stopping");
314                                                stop = true;
315                                            } else {
316                                                tracing::warn!(?frame, "websocket source: connection closed");
317                                                reconnect_now = true;
318                                            }
319                                        }
320                                    }
321                                }
322                                Ok(Some(Err(e))) => {
323                                    tracing::warn!(error = %e, "websocket source: read error");
324                                    reconnect_now = true;
325                                }
326                                Ok(None) => {
327                                    tracing::warn!("websocket source: stream ended");
328                                    reconnect_now = true;
329                                }
330                                Err(_elapsed) => {
331                                    if let Some(d) = idle_deadline
332                                        && Instant::now() >= d
333                                    {
334                                        tracing::debug!("websocket source: idle_timeout reached, stopping");
335                                        stop = true;
336                                    }
337                                }
338                            }
339                        }
340                    }
341
342                    if let Some(e) = fatal {
343                        Err(e)?;
344                    }
345
346                    if !buffer.is_empty() && buffer.len() >= page_chunk {
347                        let page = std::mem::take(&mut buffer);
348                        yield StreamPage { records: page, bookmark: None };
349                    }
350
351                    if stop {
352                        break 'outer;
353                    }
354
355                    if reconnect_now {
356                        // At-most-once across a reconnect: a live WebSocket has no
357                        // replayable offset, so `continue 'outer` re-connects and
358                        // re-subscribes to the *current* stream — any frames the
359                        // server pushed during the disconnect gap are not replayed
360                        // and are lost. Inherent to live feeds (documented in the
361                        // README under "Not resumable"), not a bug.
362                        if reconnect && max_attempts.is_none_or(|m| reconnect_attempts < m) {
363                            reconnect_attempts += 1;
364                            tracing::warn!(attempt = reconnect_attempts, "websocket source: reconnecting");
365                            tokio::time::sleep(backoff).await;
366                            continue 'outer;
367                        } else if reconnect {
368                            Err(FaucetError::Source(format!(
369                                "websocket source: exceeded max_reconnect_attempts ({})",
370                                max_attempts.unwrap_or(0)
371                            )))?;
372                        } else {
373                            Err(FaucetError::Source(
374                                "websocket source: connection closed and reconnect=false".into(),
375                            ))?;
376                        }
377                    }
378                }
379            }
380
381            if !buffer.is_empty() {
382                yield StreamPage { records: buffer, bookmark: None };
383            }
384
385            tracing::info!(messages = total, "websocket source: stream complete");
386        })
387    }
388
389    fn config_schema(&self) -> Value {
390        let schema = schemars::schema_for!(WebsocketSourceConfig);
391        serde_json::to_value(&schema).unwrap_or(Value::Null)
392    }
393
394    fn connector_name(&self) -> &'static str {
395        "websocket"
396    }
397
398    /// Preflight probe that does **not** open the WebSocket stream.
399    ///
400    /// The default `Source::check` would call `stream_pages`, which connects,
401    /// sends subscribe frames, and then blocks waiting for inbound frames until
402    /// `max_messages` / `idle_timeout` — useless as a fast preflight. Instead we
403    /// only verify TCP reachability of the configured endpoint: parse the
404    /// `ws://`/`wss://` URL, resolve host + port (default 80 for `ws`, 443 for
405    /// `wss`), open a [`tokio::net::TcpStream`] (no WS upgrade handshake, no
406    /// auth, no frames), and immediately drop it.
407    async fn check(
408        &self,
409        ctx: &faucet_core::check::CheckContext,
410    ) -> Result<faucet_core::check::CheckReport, FaucetError> {
411        use faucet_core::check::{CheckReport, Probe};
412
413        let start = std::time::Instant::now();
414
415        // Resolve host + port from the configured URL. The config is validated
416        // at construction (`ws://`/`wss://`), so parse failures here are
417        // probe-level failures, never panics.
418        let (host, port) = match resolve_host_port(&self.config.url) {
419            Ok(hp) => hp,
420            Err(reason) => {
421                return Ok(CheckReport::single(Probe::fail_hint(
422                    "network",
423                    start.elapsed(),
424                    reason,
425                    "url must be ws://host[:port]/... or wss://host[:port]/...",
426                )));
427            }
428        };
429
430        let connect = tokio::net::TcpStream::connect((host.as_str(), port));
431        match tokio::time::timeout(ctx.timeout, connect).await {
432            Ok(Ok(stream)) => {
433                drop(stream);
434                Ok(CheckReport::single(Probe::pass("network", start.elapsed())))
435            }
436            Ok(Err(e)) => Ok(CheckReport::single(Probe::fail_hint(
437                "network",
438                start.elapsed(),
439                e.to_string(),
440                format!("cannot reach {host}:{port} over TCP"),
441            ))),
442            Err(_elapsed) => Ok(CheckReport::single(Probe::fail_hint(
443                "network",
444                start.elapsed(),
445                format!("TCP connect to {host}:{port} timed out"),
446                format!("{host}:{port} did not accept a connection within the check timeout"),
447            ))),
448        }
449    }
450}
451
452/// Parse a `ws://`/`wss://` URL into `(host, port)`, applying the scheme's
453/// default port (80 for `ws`, 443 for `wss`) when none is specified.
454///
455/// Returns the human-readable reason string on failure (never leaks the full
456/// URL, which may carry query-string credentials — only the host is surfaced).
457fn resolve_host_port(url: &str) -> Result<(String, u16), String> {
458    let request = url
459        .into_client_request()
460        .map_err(|e| format!("invalid websocket url: {e}"))?;
461    let uri = request.uri();
462    let host = uri
463        .host()
464        .filter(|h| !h.is_empty())
465        .ok_or_else(|| "websocket url is missing a host".to_string())?
466        .to_string();
467    let default_port = match uri.scheme_str() {
468        Some("wss") => 443,
469        _ => 80,
470    };
471    let port = uri.port_u16().unwrap_or(default_port);
472    Ok((host, port))
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use faucet_core::check::{CheckContext, ProbeStatus};
    use std::collections::BTreeMap;
    use std::net::SocketAddr;

    /// Config shared by the `check` tests: a plain `ws://` URL aimed at
    /// `addr`, no auth, no subscriptions, no reconnect.
    fn probe_config(addr: SocketAddr) -> WebsocketSourceConfig {
        WebsocketSourceConfig {
            url: format!("ws://{addr}/feed"),
            auth: AuthSpec::Inline(WebsocketAuth::None),
            subscribe_messages: vec![],
            message_format: crate::config::WsMessageFormat::Json,
            on_parse_error: crate::config::OnParseError::Fail,
            envelope: false,
            ping_interval: None,
            max_messages: Some(1),
            idle_timeout: None,
            reconnect: false,
            reconnect_backoff: Duration::from_secs(1),
            max_reconnect_attempts: None,
            max_message_bytes: None,
            batch_size: faucet_core::DEFAULT_BATCH_SIZE,
        }
    }

    /// Shorthand for the expected single-header `Custom` auth.
    fn custom(name: &str, value: &str) -> WebsocketAuth {
        WebsocketAuth::Custom {
            headers: BTreeMap::from([(name.into(), value.into())]),
        }
    }

    #[test]
    fn credential_bearer_maps_to_bearer() {
        let expected = WebsocketAuth::Bearer {
            token: "tok".into(),
        };
        assert_eq!(credential_to_auth(Credential::Bearer("tok".into())), expected);
    }

    #[test]
    fn credential_token_maps_to_custom_authorization() {
        let mapped = credential_to_auth(Credential::Token("Custom xyz".into()));
        assert_eq!(mapped, custom("Authorization", "Custom xyz"));
    }

    #[test]
    fn credential_header_maps_to_custom() {
        let mapped = credential_to_auth(Credential::Header {
            name: "X-Api-Key".into(),
            value: "k123".into(),
        });
        assert_eq!(mapped, custom("X-Api-Key", "k123"));
    }

    #[test]
    fn credential_basic_maps_to_base64_authorization() {
        let mapped = credential_to_auth(Credential::Basic {
            username: "user".into(),
            password: "pass".into(),
        });
        // base64("user:pass") == "dXNlcjpwYXNz"
        assert_eq!(mapped, custom("Authorization", "Basic dXNlcjpwYXNz"));
    }

    #[test]
    fn resolve_host_port_applies_scheme_defaults() {
        let cases = [
            ("ws://example.com/feed", 80),
            ("wss://example.com/feed", 443),
            ("wss://example.com:9443/feed", 9443),
        ];
        for (url, port) in cases {
            assert_eq!(
                resolve_host_port(url).unwrap(),
                ("example.com".to_string(), port)
            );
        }
    }

    #[tokio::test]
    async fn check_passes_against_a_live_tcp_listener() {
        // A plain TCP listener is enough: the probe stops after the TCP
        // handshake and never attempts the WebSocket upgrade. The listener
        // must stay bound until the probe has run.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let source = WebsocketSource::new(probe_config(addr)).unwrap();
        let report = source.check(&CheckContext::default()).await.unwrap();

        assert_eq!(report.probes.len(), 1);
        assert_eq!(report.probes[0].name, "network");
        assert!(
            matches!(report.probes[0].status, ProbeStatus::Pass),
            "expected Pass, got {:?}",
            report.probes[0].status
        );
    }

    #[tokio::test]
    async fn check_fails_against_a_closed_port() {
        // Bind, note the port, then release it: the port is (almost
        // certainly) closed afterwards, so the connect is refused quickly.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);

        let source = WebsocketSource::new(probe_config(addr)).unwrap();
        let ctx = CheckContext {
            timeout: Duration::from_secs(2),
        };
        let report = source.check(&ctx).await.unwrap();

        assert_eq!(report.probes.len(), 1);
        assert_eq!(report.probes[0].name, "network");
        assert!(
            matches!(report.probes[0].status, ProbeStatus::Fail { .. }),
            "expected Fail, got {:?}",
            report.probes[0].status
        );
        assert_eq!(report.failed_count(), 1);
    }
}