Skip to main content

nucel_sdk_api/
ws.rs

1//! WebSocket helpers for endpoints that stream events over a persistent
2//! connection (e.g. `/api/v1/notifications/stream`).
3//!
4//! Built on [`tokio_tungstenite`] with rustls + webpki roots so it works out
5//! of the box against TLS deployments. Exposes a [`futures::Stream`] of
6//! parsed JSON events — callers decide how to handle each one.
7//!
8//! # Example
9//!
10//! ```no_run
11//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
12//! use futures::StreamExt;
13//! use nucel_sdk_api::ws::{NotificationStreamEvent, stream_notifications};
14//!
15//! let mut events = stream_notifications("https://nucel.dev", "ghp_token").await?;
16//! while let Some(Ok(ev)) = events.next().await {
17//!     match ev {
18//!         NotificationStreamEvent::Init { unread_count } => {
19//!             println!("connected, {unread_count} unread");
20//!         }
21//!         NotificationStreamEvent::Other(v) => {
22//!             println!("event: {v}");
23//!         }
24//!     }
25//! }
26//! # Ok(())
27//! # }
28//! ```
29
30use std::pin::Pin;
31use std::task::{Context, Poll};
32
33use futures::{Stream, StreamExt};
34use tokio_tungstenite::tungstenite::Message;
35
36/// An event on the notification stream.
37///
38/// The server sends `{"type":"init", ...}` as the first frame and arbitrary
39/// event shapes afterwards. This enum handles the one well-known frame and
40/// preserves everything else as raw JSON so new server-side event types
41/// don't break older clients.
42#[derive(Debug, Clone)]
43pub enum NotificationStreamEvent {
44    /// First frame on connect. Contains the user's unread notification count.
45    Init { unread_count: u64 },
46    /// Any other event kind. The raw JSON is preserved.
47    Other(serde_json::Value),
48}
49
50fn parse_event(raw: &str) -> Result<NotificationStreamEvent, serde_json::Error> {
51    let value: serde_json::Value = serde_json::from_str(raw)?;
52    if let Some(ty) = value.get("type").and_then(|v| v.as_str()) {
53        if ty == "init" {
54            if let Some(count) = value.get("unread_count").and_then(|v| v.as_u64()) {
55                return Ok(NotificationStreamEvent::Init {
56                    unread_count: count,
57                });
58            }
59        }
60    }
61    Ok(NotificationStreamEvent::Other(value))
62}
63
64/// Errors the WebSocket stream can produce.
65#[derive(Debug, thiserror::Error)]
66pub enum WsError {
67    #[error("invalid base URL: {0}")]
68    Url(#[from] url::ParseError),
69    #[error("WebSocket error: {0}")]
70    Ws(#[from] tokio_tungstenite::tungstenite::Error),
71    #[error("failed to parse event JSON: {0}")]
72    Parse(#[from] serde_json::Error),
73}
74
// `WsError` gets its `Display`, `std::error::Error`, and `From` impls from the
// `thiserror` derive and its `#[from]` attributes above, so `thiserror` must be
// a regular (not dev-only) dependency of this crate.
77
78/// Open a WebSocket to `/api/v1/notifications/stream` and return a stream
79/// of parsed events. The `token` is sent as both an `Authorization: Bearer`
80/// header AND a `?token=` query param so the call works against any server
81/// config.
82pub async fn stream_notifications(
83    base_url: &str,
84    token: &str,
85) -> Result<impl Stream<Item = Result<NotificationStreamEvent, WsError>>, WsError> {
86    let ws_url = build_ws_url(base_url, "/api/v1/notifications/stream", token)?;
87    let request = tokio_tungstenite::tungstenite::http::Request::builder()
88        .method("GET")
89        .uri(ws_url.as_str())
90        .header("Authorization", format!("Bearer {token}"))
91        .header("User-Agent", concat!("nucel-sdk-api/", env!("CARGO_PKG_VERSION")))
92        .header("Host", ws_url.host_str().unwrap_or(""))
93        .header("Connection", "Upgrade")
94        .header("Upgrade", "websocket")
95        .header("Sec-WebSocket-Version", "13")
96        .header(
97            "Sec-WebSocket-Key",
98            tokio_tungstenite::tungstenite::handshake::client::generate_key(),
99        )
100        .body(())
101        .expect("valid request");
102    let (ws, _resp) = tokio_tungstenite::connect_async(request).await?;
103    Ok(NotificationStream { inner: ws })
104}
105
106struct NotificationStream<S> {
107    inner: S,
108}
109
110impl<S> Stream for NotificationStream<S>
111where
112    S: Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>> + Unpin,
113{
114    type Item = Result<NotificationStreamEvent, WsError>;
115
116    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
117        loop {
118            match futures::ready!(self.inner.poll_next_unpin(cx)) {
119                Some(Ok(Message::Text(text))) => {
120                    return Poll::Ready(Some(parse_event(&text).map_err(Into::into)));
121                }
122                Some(Ok(Message::Binary(bin))) => match std::str::from_utf8(&bin) {
123                    Ok(text) => {
124                        return Poll::Ready(Some(parse_event(text).map_err(Into::into)));
125                    }
126                    Err(_) => continue,
127                },
128                Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
129                Some(Ok(Message::Close(_))) | None => return Poll::Ready(None),
130                Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
131            }
132        }
133    }
134}
135
136/// Turn `http(s)://host[:port]` + a path into a `ws(s)://host[:port]/path?token=...`.
137pub fn build_ws_url(base_url: &str, path: &str, token: &str) -> Result<url::Url, url::ParseError> {
138    let mut url = url::Url::parse(base_url)?.join(path)?;
139    let scheme = if url.scheme() == "https" { "wss" } else { "ws" };
140    let _ = url.set_scheme(scheme);
141    url.query_pairs_mut().append_pair("token", token);
142    Ok(url)
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// The notification stream endpoint used by every URL test below.
    const STREAM_PATH: &str = "/api/v1/notifications/stream";

    #[test]
    fn build_ws_url_converts_https_to_wss() {
        let built = build_ws_url("https://nucel.dev", STREAM_PATH, "tok").unwrap();
        assert_eq!(built.scheme(), "wss");
        assert_eq!(built.host_str(), Some("nucel.dev"));
        assert_eq!(built.path(), STREAM_PATH);
        let query = built.query().unwrap();
        assert!(query.contains("token=tok"));
    }

    #[test]
    fn build_ws_url_converts_http_to_ws() {
        let built = build_ws_url("http://localhost:17321", STREAM_PATH, "x").unwrap();
        assert_eq!(built.scheme(), "ws");
        assert_eq!(built.port(), Some(17321));
    }

    #[test]
    fn parse_event_init() {
        let event = parse_event(r#"{"type":"init","unread_count":7}"#).unwrap();
        if let NotificationStreamEvent::Init { unread_count } = event {
            assert_eq!(unread_count, 7);
        } else {
            panic!("expected Init");
        }
    }

    #[test]
    fn parse_event_unknown_falls_back_to_other() {
        let event = parse_event(r#"{"type":"issue_opened","issue_id":"123"}"#).unwrap();
        if let NotificationStreamEvent::Other(value) = event {
            assert_eq!(value["issue_id"], "123");
        } else {
            panic!("expected Other");
        }
    }
}