1use std::pin::Pin;
31use std::task::{Context, Poll};
32
33use futures::{Stream, StreamExt};
34use tokio_tungstenite::tungstenite::Message;
35
36#[derive(Debug, Clone)]
43pub enum NotificationStreamEvent {
44 Init { unread_count: u64 },
46 Other(serde_json::Value),
48}
49
50fn parse_event(raw: &str) -> Result<NotificationStreamEvent, serde_json::Error> {
51 let value: serde_json::Value = serde_json::from_str(raw)?;
52 if let Some(ty) = value.get("type").and_then(|v| v.as_str()) {
53 if ty == "init" {
54 if let Some(count) = value.get("unread_count").and_then(|v| v.as_u64()) {
55 return Ok(NotificationStreamEvent::Init {
56 unread_count: count,
57 });
58 }
59 }
60 }
61 Ok(NotificationStreamEvent::Other(value))
62}
63
/// Errors produced while opening or reading the notification stream.
///
/// Each variant wraps the underlying library error and can be created from it
/// with `?` thanks to the `#[from]` conversions.
#[derive(Debug, thiserror::Error)]
pub enum WsError {
    /// The URL could not be parsed or joined (see `build_ws_url`).
    #[error("invalid base URL: {0}")]
    Url(#[from] url::ParseError),
    /// The WebSocket layer reported an error, either while connecting or
    /// while reading from the established stream.
    #[error("WebSocket error: {0}")]
    Ws(#[from] tokio_tungstenite::tungstenite::Error),
    /// A received message was not valid JSON.
    #[error("failed to parse event JSON: {0}")]
    Parse(#[from] serde_json::Error),
}
74
75pub async fn stream_notifications(
83 base_url: &str,
84 token: &str,
85) -> Result<impl Stream<Item = Result<NotificationStreamEvent, WsError>>, WsError> {
86 let ws_url = build_ws_url(base_url, "/api/v1/notifications/stream", token)?;
87 let request = tokio_tungstenite::tungstenite::http::Request::builder()
88 .method("GET")
89 .uri(ws_url.as_str())
90 .header("Authorization", format!("Bearer {token}"))
91 .header("User-Agent", concat!("nucel-sdk-api/", env!("CARGO_PKG_VERSION")))
92 .header("Host", ws_url.host_str().unwrap_or(""))
93 .header("Connection", "Upgrade")
94 .header("Upgrade", "websocket")
95 .header("Sec-WebSocket-Version", "13")
96 .header(
97 "Sec-WebSocket-Key",
98 tokio_tungstenite::tungstenite::handshake::client::generate_key(),
99 )
100 .body(())
101 .expect("valid request");
102 let (ws, _resp) = tokio_tungstenite::connect_async(request).await?;
103 Ok(NotificationStream { inner: ws })
104}
105
106struct NotificationStream<S> {
107 inner: S,
108}
109
110impl<S> Stream for NotificationStream<S>
111where
112 S: Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>> + Unpin,
113{
114 type Item = Result<NotificationStreamEvent, WsError>;
115
116 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
117 loop {
118 match futures::ready!(self.inner.poll_next_unpin(cx)) {
119 Some(Ok(Message::Text(text))) => {
120 return Poll::Ready(Some(parse_event(&text).map_err(Into::into)));
121 }
122 Some(Ok(Message::Binary(bin))) => match std::str::from_utf8(&bin) {
123 Ok(text) => {
124 return Poll::Ready(Some(parse_event(text).map_err(Into::into)));
125 }
126 Err(_) => continue,
127 },
128 Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
129 Some(Ok(Message::Close(_))) | None => return Poll::Ready(None),
130 Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
131 }
132 }
133 }
134}
135
136pub fn build_ws_url(base_url: &str, path: &str, token: &str) -> Result<url::Url, url::ParseError> {
138 let mut url = url::Url::parse(base_url)?.join(path)?;
139 let scheme = if url.scheme() == "https" { "wss" } else { "ws" };
140 let _ = url.set_scheme(scheme);
141 url.query_pairs_mut().append_pair("token", token);
142 Ok(url)
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint path shared by the URL-building tests.
    const STREAM_PATH: &str = "/api/v1/notifications/stream";

    /// An `https` base URL yields a `wss` URL with host, path and token kept.
    #[test]
    fn build_ws_url_converts_https_to_wss() {
        let built = build_ws_url("https://nucel.dev", STREAM_PATH, "tok").unwrap();
        assert_eq!(built.scheme(), "wss");
        assert_eq!(built.host_str(), Some("nucel.dev"));
        assert_eq!(built.path(), "/api/v1/notifications/stream");
        let query = built.query().unwrap();
        assert!(query.contains("token=tok"));
    }

    /// An `http` base URL yields a `ws` URL and keeps its explicit port.
    #[test]
    fn build_ws_url_converts_http_to_ws() {
        let built = build_ws_url("http://localhost:17321", STREAM_PATH, "x").unwrap();
        assert_eq!(built.scheme(), "ws");
        assert_eq!(built.port(), Some(17321));
    }

    /// An `init` message with a numeric `unread_count` decodes to `Init`.
    #[test]
    fn parse_event_init() {
        let event = parse_event(r#"{"type":"init","unread_count":7}"#).unwrap();
        if let NotificationStreamEvent::Init { unread_count } = event {
            assert_eq!(unread_count, 7);
        } else {
            panic!("expected Init");
        }
    }

    /// A message with an unrecognised `type` is preserved as `Other`.
    #[test]
    fn parse_event_unknown_falls_back_to_other() {
        let event = parse_event(r#"{"type":"issue_opened","issue_id":"123"}"#).unwrap();
        if let NotificationStreamEvent::Other(value) = event {
            assert_eq!(value["issue_id"], "123");
        } else {
            panic!("expected Other");
        }
    }
}