use std::pin::Pin;
use std::task::{Context, Poll};
use syncable_ag_ui_core::{Event, JsonValue};
use futures::{SinkExt, Stream};
use tokio_tungstenite::{
connect_async,
tungstenite::{self, Message},
MaybeTlsStream, WebSocketStream,
};
use crate::error::{ClientError, Result};
/// Options applied when opening a WebSocket connection.
///
/// Start from [`WsConfig::new`] (or `Default`) and chain the builder methods below.
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// Extra HTTP headers sent with the handshake request, kept in insertion order.
    pub headers: Vec<(String, String)>,
    /// Whether the event stream answers incoming `Ping` frames with a `Pong`.
    pub auto_pong: bool,
}

impl Default for WsConfig {
    /// No extra headers; automatic pong replies enabled.
    fn default() -> Self {
        WsConfig {
            auto_pong: true,
            headers: Vec::default(),
        }
    }
}

impl WsConfig {
    /// Creates a configuration identical to [`WsConfig::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Appends one handshake header. Repeated names are kept, not replaced.
    pub fn header(self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let entry = (name.into(), value.into());
        let mut headers = self.headers;
        headers.push(entry);
        Self { headers, ..self }
    }

    /// Appends an `Authorization: Bearer <token>` header.
    pub fn bearer_token(self, token: impl Into<String>) -> Self {
        let token: String = token.into();
        let value = format!("Bearer {}", token);
        self.header("Authorization", value)
    }

    /// Turns off the automatic `Pong` reply to incoming `Ping` frames.
    pub fn disable_auto_pong(self) -> Self {
        Self {
            auto_pong: false,
            ..self
        }
    }
}
/// A connected WebSocket client whose incoming text frames carry AG-UI [`Event`]s.
///
/// Open one with [`WsClient::connect`] or [`WsClient::connect_with_config`], then
/// turn it into an event stream with [`WsClient::into_stream`].
pub struct WsClient {
    socket: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
    auto_pong: bool,
}

impl WsClient {
    /// Connects to `url` using the default [`WsConfig`].
    ///
    /// # Errors
    ///
    /// Same as [`WsClient::connect_with_config`].
    pub async fn connect(url: &str) -> Result<Self> {
        let config = WsConfig::default();
        Self::connect_with_config(url, config).await
    }

    /// Performs the WebSocket upgrade handshake against `url`.
    ///
    /// The request carries the standard upgrade headers (`Host`, `Connection`,
    /// `Upgrade`, `Sec-WebSocket-Version`, a freshly generated `Sec-WebSocket-Key`).
    /// These are followed by every entry of `config.headers` in order, each added
    /// through the request builder's `header` method.
    ///
    /// # Errors
    ///
    /// - `ClientError::InvalidUrl` if `url` cannot be parsed or has no host.
    /// - An error built with `ClientError::connection` if the request cannot be
    ///   assembled (e.g. an invalid header) or if the handshake fails.
    pub async fn connect_with_config(url: &str, config: WsConfig) -> Result<Self> {
        let host = extract_host(url)?;
        let handshake = tungstenite::http::Request::builder()
            .uri(url)
            .header("Host", host)
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .header("Sec-WebSocket-Version", "13")
            .header(
                "Sec-WebSocket-Key",
                tungstenite::handshake::client::generate_key(),
            );

        let auto_pong = config.auto_pong;
        let request = config
            .headers
            .into_iter()
            .fold(handshake, |builder, (name, value)| builder.header(name, value))
            .body(())
            .map_err(|err| ClientError::connection(err.to_string()))?;

        let (socket, _handshake_response) = connect_async(request)
            .await
            .map_err(|err| ClientError::connection(err.to_string()))?;

        Ok(Self { socket, auto_pong })
    }

    /// Consumes the client and returns a [`WsEventStream`] over the same socket,
    /// carrying over the auto-pong setting.
    pub fn into_stream(self) -> WsEventStream {
        let Self { socket, auto_pong } = self;
        WsEventStream { socket, auto_pong }
    }

    /// Closes the connection via `WebSocketStream::close` with no close frame payload.
    ///
    /// # Errors
    ///
    /// Any failure is reported through `ClientError::connection`.
    pub async fn close(mut self) -> Result<()> {
        let closed = self.socket.close(None).await;
        closed.map_err(|err| ClientError::connection(err.to_string()))
    }
}
pub struct WsEventStream {
socket: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
auto_pong: bool,
}
impl Stream for WsEventStream {
type Item = Result<Event<JsonValue>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match Pin::new(&mut self.socket).poll_next(cx) {
Poll::Ready(Some(Ok(msg))) => {
match msg {
Message::Text(text) => {
match serde_json::from_str::<Event<JsonValue>>(&text) {
Ok(event) => return Poll::Ready(Some(Ok(event))),
Err(e) => {
return Poll::Ready(Some(Err(ClientError::parse(format!(
"failed to parse event: {}",
e
)))))
}
}
}
Message::Ping(data) => {
if self.auto_pong {
let mut socket = Pin::new(&mut self.socket);
let _ = socket.start_send_unpin(Message::Pong(data));
}
continue;
}
Message::Pong(_) => {
continue;
}
Message::Close(_) => {
return Poll::Ready(None);
}
Message::Binary(_) | Message::Frame(_) => {
continue;
}
}
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(ClientError::WebSocket(e))))
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending,
}
}
}
}
/// Builds the HTTP `Host` header value for a WebSocket URL.
///
/// The result is the URL's host. `:port` is appended only when `url::Url::port`
/// returns a port; per the `url` crate docs, that method yields `None` for a scheme's
/// default port.
///
/// # Errors
///
/// Returns `ClientError::InvalidUrl` if `url` fails to parse or has no host component.
fn extract_host(url: &str) -> Result<String> {
    let parsed = url::Url::parse(url).map_err(|err| ClientError::InvalidUrl(err.to_string()))?;
    let host = match parsed.host_str() {
        Some(host) => host,
        None => return Err(ClientError::InvalidUrl("missing host".to_string())),
    };
    let authority = if let Some(port) = parsed.port() {
        format!("{}:{}", host, port)
    } else {
        host.to_owned()
    };
    Ok(authority)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ws_config_default() {
        let config = WsConfig::default();
        assert!(config.headers.is_empty());
        assert!(config.auto_pong);
    }

    #[test]
    fn test_ws_config_new_matches_default() {
        let fresh = WsConfig::new();
        let default = WsConfig::default();
        assert_eq!(fresh.headers, default.headers);
        assert_eq!(fresh.auto_pong, default.auto_pong);
    }

    #[test]
    fn test_ws_config_builder() {
        let config = WsConfig::new()
            .header("X-Custom", "value")
            .bearer_token("token123")
            .disable_auto_pong();
        assert_eq!(config.headers.len(), 2);
        assert_eq!(config.headers[0], ("X-Custom".to_string(), "value".to_string()));
        assert_eq!(
            config.headers[1],
            ("Authorization".to_string(), "Bearer token123".to_string())
        );
        assert!(!config.auto_pong);
    }

    #[test]
    fn test_extract_host() {
        assert_eq!(extract_host("ws://localhost:3000/ws").unwrap(), "localhost:3000");
        assert_eq!(extract_host("wss://example.com/events").unwrap(), "example.com");
        assert_eq!(
            extract_host("ws://api.example.com:8080/stream").unwrap(),
            "api.example.com:8080"
        );
    }

    #[test]
    fn test_extract_host_omits_default_port() {
        // `Url::port` reports no port when it equals the scheme default.
        assert_eq!(extract_host("ws://example.com:80/ws").unwrap(), "example.com");
        assert_eq!(extract_host("wss://example.com:443/ws").unwrap(), "example.com");
    }

    #[test]
    fn test_extract_host_ipv6() {
        // IPv6 literals must keep their brackets in the Host header.
        assert_eq!(extract_host("ws://[::1]:9000/ws").unwrap(), "[::1]:9000");
    }

    #[test]
    fn test_extract_host_invalid() {
        assert!(extract_host("not a url").is_err());
        // Parses as a URL but has no host component.
        assert!(extract_host("data:text/plain,hello").is_err());
    }
}