use std::pin::Pin;
use std::task::{Context, Poll};
use futures::{Stream, StreamExt};
use tokio_tungstenite::tungstenite::Message;
/// An event decoded from one JSON message of the notification stream.
#[derive(Debug, Clone)]
pub enum NotificationStreamEvent {
/// A message whose `type` is `"init"` and whose `unread_count` is an
/// unsigned integer.
Init { unread_count: u64 },
/// Any other valid JSON message, kept as the parsed value.
Other(serde_json::Value),
}
/// Decodes one raw JSON message into a [`NotificationStreamEvent`].
///
/// A message is reported as `Init` only when its `type` field is the string
/// `"init"` and its `unread_count` field is an unsigned integer. Every other
/// valid JSON document is returned unchanged as `Other`.
///
/// # Errors
///
/// Returns the `serde_json` error when `raw` is not valid JSON.
fn parse_event(raw: &str) -> Result<NotificationStreamEvent, serde_json::Error> {
    let payload: serde_json::Value = serde_json::from_str(raw)?;

    // `Some(count)` only for a well-formed init message; `None` otherwise.
    let init_count = payload
        .get("type")
        .and_then(serde_json::Value::as_str)
        .filter(|kind| *kind == "init")
        .and_then(|_| payload.get("unread_count"))
        .and_then(serde_json::Value::as_u64);

    Ok(match init_count {
        Some(unread_count) => NotificationStreamEvent::Init { unread_count },
        None => NotificationStreamEvent::Other(payload),
    })
}
/// Errors produced while connecting to or reading from the notification stream.
#[derive(Debug, thiserror::Error)]
pub enum WsError {
/// A URL could not be parsed; converted from `url::ParseError`.
#[error("invalid base URL: {0}")]
Url(#[from] url::ParseError),
/// An error reported by the WebSocket library; converted from
/// `tokio_tungstenite::tungstenite::Error`.
#[error("WebSocket error: {0}")]
Ws(#[from] tokio_tungstenite::tungstenite::Error),
/// A received message was not valid JSON; converted from `serde_json::Error`.
#[error("failed to parse event JSON: {0}")]
Parse(#[from] serde_json::Error),
}
pub async fn stream_notifications(
base_url: &str,
token: &str,
) -> Result<impl Stream<Item = Result<NotificationStreamEvent, WsError>>, WsError> {
let ws_url = build_ws_url(base_url, "/api/v1/notifications/stream", token)?;
let request = tokio_tungstenite::tungstenite::http::Request::builder()
.method("GET")
.uri(ws_url.as_str())
.header("Authorization", format!("Bearer {token}"))
.header("User-Agent", concat!("nucel-sdk-api/", env!("CARGO_PKG_VERSION")))
.header("Host", ws_url.host_str().unwrap_or(""))
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header(
"Sec-WebSocket-Key",
tokio_tungstenite::tungstenite::handshake::client::generate_key(),
)
.body(())
.expect("valid request");
let (ws, _resp) = tokio_tungstenite::connect_async(request).await?;
Ok(NotificationStream { inner: ws })
}
/// Wrapper around an inner message stream; its `Stream` implementation turns
/// the inner WebSocket messages into `NotificationStreamEvent` values.
struct NotificationStream<S> {
// The wrapped stream of WebSocket messages.
inner: S,
}
impl<S> Stream for NotificationStream<S>
where
S: Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>> + Unpin,
{
type Item = Result<NotificationStreamEvent, WsError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match futures::ready!(self.inner.poll_next_unpin(cx)) {
Some(Ok(Message::Text(text))) => {
return Poll::Ready(Some(parse_event(&text).map_err(Into::into)));
}
Some(Ok(Message::Binary(bin))) => match std::str::from_utf8(&bin) {
Ok(text) => {
return Poll::Ready(Some(parse_event(text).map_err(Into::into)));
}
Err(_) => continue,
},
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
Some(Ok(Message::Close(_))) | None => return Poll::Ready(None),
Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
}
}
}
}
/// Builds the WebSocket URL for `path` on the server at `base_url`.
///
/// `path` is resolved against `base_url` with `Url::join`, the scheme is
/// mapped to its WebSocket counterpart (`https` and `wss` become `wss`,
/// anything else becomes `ws`), and `token` is appended as the `token` query
/// parameter.
///
/// # Errors
///
/// Returns the `url::ParseError` raised when `base_url` cannot be parsed or
/// `path` cannot be joined onto it.
pub fn build_ws_url(base_url: &str, path: &str, token: &str) -> Result<url::Url, url::ParseError> {
    let mut url = url::Url::parse(base_url)?.join(path)?;
    let scheme = match url.scheme() {
        // A base URL that is already `wss` must stay secure rather than being
        // downgraded to plaintext `ws`.
        "https" | "wss" => "wss",
        _ => "ws",
    };
    // Best effort: if the `url` crate rejects the scheme change, the URL
    // keeps its original scheme.
    let _ = url.set_scheme(scheme);
    url.query_pairs_mut().append_pair("token", token);
    Ok(url)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An `https` base URL yields a `wss` URL with host, path and token kept.
    #[test]
    fn build_ws_url_converts_https_to_wss() {
        let built = build_ws_url("https://nucel.dev", "/api/v1/notifications/stream", "tok");
        let built = built.unwrap();

        assert_eq!(built.scheme(), "wss");
        assert_eq!(built.host_str(), Some("nucel.dev"));
        assert_eq!(built.path(), "/api/v1/notifications/stream");

        let query = built.query().unwrap();
        assert!(query.contains("token=tok"));
    }

    /// An `http` base URL yields a `ws` URL and keeps its explicit port.
    #[test]
    fn build_ws_url_converts_http_to_ws() {
        let built = build_ws_url("http://localhost:17321", "/api/v1/notifications/stream", "x");
        let built = built.unwrap();

        assert_eq!(built.scheme(), "ws");
        assert_eq!(built.port(), Some(17321));
    }

    /// A well-formed init message is decoded into the `Init` variant.
    #[test]
    fn parse_event_init() {
        let event = parse_event(r#"{"type":"init","unread_count":7}"#).unwrap();

        if let NotificationStreamEvent::Init { unread_count } = event {
            assert_eq!(unread_count, 7);
        } else {
            panic!("expected Init");
        }
    }

    /// A message with an unrecognised `type` is passed through as `Other`.
    #[test]
    fn parse_event_unknown_falls_back_to_other() {
        let event = parse_event(r#"{"type":"issue_opened","issue_id":"123"}"#).unwrap();

        if let NotificationStreamEvent::Other(value) = event {
            assert_eq!(value["issue_id"], "123");
        } else {
            panic!("expected Other");
        }
    }
}