use atomr_streams::{RestartSettings, RestartSource, Source};
use bytes::Bytes;
use futures_util::StreamExt;
/// One WebSocket message as emitted by [`WsSource`].
///
/// Payload-carrying variants own their data (`String` for text, [`Bytes`]
/// for binary/control payloads).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WsFrame {
    /// A UTF-8 text message.
    Text(String),
    /// A binary message.
    Binary(Bytes),
    /// A ping control frame and its payload.
    Ping(Bytes),
    /// A pong control frame and its payload.
    Pong(Bytes),
    /// The peer sent a close frame; any close code/reason is not retained.
    Close,
}
#[derive(Debug, thiserror::Error)]
pub enum WsError {
#[error("websocket connect error: {0}")]
Connect(String),
#[error("websocket protocol error: {0}")]
Protocol(String),
}
/// Entry point for building WebSocket frame sources with automatic reconnects.
pub struct WsSource;

impl WsSource {
    /// Streams frames from `url`, reconnecting under the policy in `restart`.
    ///
    /// Every attempt is produced by `connect_once` with its own clone of
    /// `url`. Restart timing and limits are delegated entirely to
    /// `RestartSource::with_backoff` using `restart`.
    pub fn connect(url: url::Url, restart: RestartSettings) -> Source<Result<WsFrame, WsError>> {
        // The factory runs once per attempt, so it must clone rather than move `url`.
        let make_attempt = move || connect_once(url.clone());
        RestartSource::with_backoff(restart, make_attempt)
    }
}
fn connect_once(url: url::Url) -> Source<Result<WsFrame, WsError>> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<WsFrame, WsError>>();
tokio::spawn(async move {
match tokio_tungstenite::connect_async(url.as_str()).await {
Err(e) => {
let _ = tx.send(Err(WsError::Connect(e.to_string())));
}
Ok((ws, _resp)) => {
let (_write, mut read) = ws.split();
while let Some(msg) = read.next().await {
match msg {
Ok(m) => {
if let Some(frame) = map_message(m) {
if tx.send(Ok(frame)).is_err() {
return; }
}
}
Err(e) => {
let _ = tx.send(Err(WsError::Protocol(e.to_string())));
return; }
}
}
}
}
});
Source::from_receiver(rx)
}
/// Converts a tungstenite message into the crate's [`WsFrame`].
///
/// Text, binary, ping, pong and close messages each map to the matching
/// variant (the close frame's payload is discarded). Raw `Frame` messages have
/// no counterpart and yield `None`.
fn map_message(m: tokio_tungstenite::tungstenite::Message) -> Option<WsFrame> {
    use tokio_tungstenite::tungstenite::Message;
    let frame = match m {
        Message::Text(text) => WsFrame::Text(text),
        Message::Binary(data) => WsFrame::Binary(data.into()),
        Message::Ping(data) => WsFrame::Ping(data.into()),
        Message::Pong(data) => WsFrame::Pong(data.into()),
        Message::Close(_) => WsFrame::Close,
        Message::Frame(_) => return None,
    };
    Some(frame)
}
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_streams::Sink;
    use std::time::Duration;
    use tokio::net::TcpListener;

    /// Restart policy with millisecond-scale, jitter-free backoff and at most
    /// `max` restarts, so tests never wait noticeably.
    fn fast_restart(max: usize) -> RestartSettings {
        let min_backoff = Duration::from_millis(1);
        let max_backoff = Duration::from_millis(5);
        RestartSettings {
            min_backoff,
            max_backoff,
            random_factor: 0.0,
            max_restarts: Some(max),
        }
    }

    /// Accepts one client on `listener`, upgrades it to WebSocket, sends a
    /// single text message containing `text`, then closes. Any failure just
    /// ends the server; the test observes the outcome from the client side.
    async fn serve_single_text(listener: TcpListener, text: &'static str) {
        use futures_util::SinkExt;
        use tokio_tungstenite::tungstenite::Message;

        let stream = match listener.accept().await {
            Ok((stream, _)) => stream,
            Err(_) => return,
        };
        let mut ws = match tokio_tungstenite::accept_async(stream).await {
            Ok(ws) => ws,
            Err(_) => return,
        };
        let _ = ws.send(Message::Text(text.to_string())).await;
        let _ = ws.close(None).await;
    }

    #[tokio::test]
    async fn first_frame_is_text_against_local_server() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(serve_single_text(listener, "hello"));

        let url = url::Url::parse(&format!("ws://{addr}/")).unwrap();
        let first = Sink::first(WsSource::connect(url, fast_restart(1)))
            .await
            .expect("expected one frame");
        match first {
            Ok(WsFrame::Text(s)) => assert_eq!(s, "hello"),
            other => panic!("expected Ok(Text(\"hello\")), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn refused_port_surfaces_connect_err() {
        // Port 1 on loopback is expected to have no listener.
        let url = url::Url::parse("ws://127.0.0.1:1/").unwrap();
        let first = Sink::first(WsSource::connect(url, fast_restart(1)))
            .await
            .expect("expected one emission");
        match first {
            Err(WsError::Connect(_)) => {}
            other => panic!("expected Err(Connect), got {other:?}"),
        }
    }
}