use std::collections::HashMap;
use futures_util::{SinkExt, StreamExt};
use pushwire_core::{ChannelKind, Frame, SystemOp};
use reqwest::Client as HttpClient;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tracing::{debug, warn};
use uuid::Uuid;
use crate::session::ConnectError;
/// Which transport(s) to attempt when connecting, and in what order.
///
/// Interpreted by `connect_with_preference`: the `*First` variants try the named
/// transport and fall back to the other one if it fails; the `*Only` variants make
/// a single attempt with no fallback.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportPreference {
/// Try WebSocket first; fall back to SSE if the WebSocket connect fails.
WsFirst,
/// Try SSE first; fall back to WebSocket if the SSE connect fails.
SseFirst,
/// WebSocket only, no fallback.
WsOnly,
/// SSE only, no fallback.
SseOnly,
}
/// A command consumed by the WebSocket writer task.
#[derive(Debug)]
pub(crate) enum OutboundMsg<C: ChannelKind> {
/// Serialize the frame to JSON and send it as a text message.
Frame(Frame<C>),
/// Serialize the system op to JSON and send it as a text message.
System(SystemOp<C>),
/// Send a WebSocket Close frame and stop the writer task.
Close,
}
/// A message delivered from a transport's reader task to the session.
#[derive(Debug)]
pub(crate) enum InboundMsg<C: ChannelKind> {
/// An application frame. System-channel frames whose payload does not decode
/// as a `SystemOp` are also forwarded through this variant.
Frame(Frame<C>),
/// A decoded system operation.
System(SystemOp<C>),
/// The transport is gone; emitted by the reader tasks on a Close frame or a
/// read/stream error.
Closed,
}
/// A connected transport together with the background tasks that service it.
pub(crate) enum ActiveTransport<C: ChannelKind> {
/// Bidirectional WebSocket connection.
WebSocket {
/// Queue feeding the writer task.
outbound_tx: mpsc::Sender<OutboundMsg<C>>,
/// Task decoding inbound WebSocket messages into `InboundMsg`s.
reader_handle: JoinHandle<()>,
/// Task serializing `OutboundMsg`s onto the socket.
writer_handle: JoinHandle<()>,
},
/// Server-sent-events stream. The only upstream path is acks POSTed over HTTP.
Sse {
/// HTTP client used to POST acks.
http: HttpClient,
/// Endpoint acks are POSTed to.
ack_url: String,
/// Included in every ack body.
client_id: Uuid,
/// Task parsing the SSE byte stream into `InboundMsg`s.
reader_handle: JoinHandle<()>,
},
}
impl<C: ChannelKind> ActiveTransport<C> {
    /// Upper bound on how long `close` waits for a graceful WebSocket shutdown
    /// before aborting the background tasks.
    const CLOSE_GRACE: std::time::Duration = std::time::Duration::from_millis(100);

    /// Queues a data frame for delivery to the server.
    ///
    /// # Errors
    /// - WebSocket: `SendError::ChannelClosed` if the writer task is no longer
    ///   receiving.
    /// - SSE: always `SendError::NotConnected`; this transport has no upstream
    ///   path for data frames.
    pub(crate) async fn send_frame(
        &self,
        frame: Frame<C>,
    ) -> Result<(), crate::session::SendError> {
        match self {
            ActiveTransport::WebSocket { outbound_tx, .. } => outbound_tx
                .send(OutboundMsg::Frame(frame))
                .await
                .map_err(|_| crate::session::SendError::ChannelClosed),
            ActiveTransport::Sse { .. } => Err(crate::session::SendError::NotConnected),
        }
    }

    /// Sends a system (control) operation.
    ///
    /// Over WebSocket the op is queued for the writer task. Over SSE only
    /// `SystemOp::Ack` is supported: it is POSTed as JSON to the ack endpoint.
    /// Every other op is dropped with a warning.
    ///
    /// # Errors
    /// - WebSocket: `SendError::ChannelClosed` if the writer task is gone.
    /// - SSE: never fails. Ack delivery is best-effort, and failures (request
    ///   errors or non-2xx responses) are logged rather than returned.
    pub(crate) async fn send_system(
        &self,
        op: SystemOp<C>,
    ) -> Result<(), crate::session::SendError> {
        match self {
            ActiveTransport::WebSocket { outbound_tx, .. } => outbound_tx
                .send(OutboundMsg::System(op))
                .await
                .map_err(|_| crate::session::SendError::ChannelClosed),
            ActiveTransport::Sse {
                http,
                ack_url,
                client_id,
                ..
            } => {
                if let SystemOp::Ack { channel, cursor } = &op {
                    let body = serde_json::json!({
                        "client_id": client_id,
                        "channel": channel,
                        "cursor": cursor,
                    });
                    // Keep the best-effort contract (always Ok), but surface the
                    // failure instead of discarding it silently.
                    match http.post(ack_url).json(&body).send().await {
                        Ok(resp) if !resp.status().is_success() => {
                            warn!(status = %resp.status(), "SSE ack rejected by server");
                        }
                        Ok(_) => {}
                        Err(e) => warn!(?e, "SSE ack request failed"),
                    }
                } else {
                    warn!("system op not supported in SSE mode, dropping");
                }
                Ok(())
            }
        }
    }

    /// Shuts the transport down and stops its background tasks.
    ///
    /// For WebSocket, a Close request is queued for the writer. The writer
    /// (which sends the Close frame) and then the reader (which exits on the
    /// server's Close reply) are given at most `CLOSE_GRACE` in total to
    /// finish. After that, both tasks are aborted; aborting a finished task
    /// is a no-op. The single deadline also covers the enqueue itself, so
    /// `close` cannot hang when the outbound queue is full and the writer is
    /// stalled. For SSE the reader task is aborted immediately.
    pub(crate) async fn close(self) {
        match self {
            ActiveTransport::WebSocket {
                outbound_tx,
                mut reader_handle,
                mut writer_handle,
            } => {
                let graceful = async {
                    if outbound_tx.send(OutboundMsg::Close).await.is_ok() {
                        // `&mut JoinHandle` is itself a future, so the handles stay
                        // available for the unconditional abort below.
                        let _ = (&mut writer_handle).await;
                        let _ = (&mut reader_handle).await;
                    }
                };
                let _ = tokio::time::timeout(Self::CLOSE_GRACE, graceful).await;
                reader_handle.abort();
                writer_handle.abort();
            }
            ActiveTransport::Sse { reader_handle, .. } => {
                reader_handle.abort();
            }
        }
    }
}
pub(crate) async fn connect_ws<C: ChannelKind>(
url: &str,
client_id: Uuid,
token: Option<&str>,
capabilities: &[C],
resume_cursors: HashMap<C, u64>,
) -> Result<(ActiveTransport<C>, mpsc::Receiver<InboundMsg<C>>), ConnectError> {
let ws_url = http_to_ws_url(url);
let rps_url = format!("{ws_url}/rps");
let (ws_stream, _response) = tokio_tungstenite::connect_async(&rps_url)
.await
.map_err(|e| ConnectError::Transport(format!("WebSocket connect failed: {e}")))?;
let (mut ws_tx, mut ws_rx) = ws_stream.split();
let global_cursor = resume_cursors.values().copied().max();
let auth = SystemOp::<C>::Auth {
client_id,
token: token.map(String::from),
capabilities: capabilities.to_vec(),
resume_cursor: global_cursor,
resume_cursors: resume_cursors.clone(),
};
let auth_json =
serde_json::to_string(&auth).map_err(|e| ConnectError::Transport(e.to_string()))?;
ws_tx
.send(WsMessage::Text(auth_json))
.await
.map_err(|e| ConnectError::Transport(format!("failed to send auth: {e}")))?;
let auth_reply = ws_rx
.next()
.await
.ok_or(ConnectError::Transport(
"connection closed before auth reply".into(),
))?
.map_err(|e| ConnectError::Transport(format!("auth reply read error: {e}")))?;
let auth_ok: SystemOp<C> = match auth_reply {
WsMessage::Text(text) => serde_json::from_str(&text)
.map_err(|e| ConnectError::AuthRejected(format!("invalid auth reply: {e}")))?,
WsMessage::Close(frame) => {
let reason = frame
.map(|f| f.reason.to_string())
.unwrap_or_else(|| "unknown".into());
return Err(ConnectError::AuthRejected(reason));
}
other => {
return Err(ConnectError::Transport(format!(
"unexpected auth reply type: {other:?}"
)));
}
};
match auth_ok {
SystemOp::AuthOk { .. } => {
debug!(?client_id, "auth handshake complete");
}
SystemOp::Error { message } => return Err(ConnectError::AuthRejected(message)),
other => {
return Err(ConnectError::AuthRejected(format!(
"expected AuthOk, got {other:?}"
)));
}
}
let (inbound_tx, inbound_rx) = mpsc::channel::<InboundMsg<C>>(256);
let (outbound_tx, mut outbound_rx) = mpsc::channel::<OutboundMsg<C>>(64);
let reader_inbound_tx = inbound_tx.clone();
let reader_handle = tokio::spawn(async move {
while let Some(msg) = ws_rx.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
if let Ok(frame) = serde_json::from_str::<Frame<C>>(&text) {
if frame.channel.is_system() {
if let Ok(op) =
serde_json::from_value::<SystemOp<C>>(frame.payload.clone())
{
if reader_inbound_tx
.send(InboundMsg::System(op))
.await
.is_err()
{
break;
}
} else {
if reader_inbound_tx
.send(InboundMsg::Frame(frame))
.await
.is_err()
{
break;
}
}
} else if reader_inbound_tx
.send(InboundMsg::Frame(frame))
.await
.is_err()
{
break;
}
} else if let Ok(op) = serde_json::from_str::<SystemOp<C>>(&text) {
if reader_inbound_tx
.send(InboundMsg::System(op))
.await
.is_err()
{
break;
}
} else {
warn!("failed to parse inbound WS message");
}
}
Ok(WsMessage::Close(_)) => {
let _ = reader_inbound_tx.send(InboundMsg::Closed).await;
break;
}
Ok(WsMessage::Ping(_) | WsMessage::Pong(_)) => {
}
Ok(_) => {}
Err(e) => {
warn!(?e, "WS read error");
let _ = reader_inbound_tx.send(InboundMsg::Closed).await;
break;
}
}
}
});
let writer_handle = tokio::spawn(async move {
while let Some(msg) = outbound_rx.recv().await {
let ws_msg = match msg {
OutboundMsg::Frame(frame) => match serde_json::to_string(&frame) {
Ok(json) => WsMessage::Text(json),
Err(e) => {
warn!(?e, "failed to serialize outbound frame");
continue;
}
},
OutboundMsg::System(op) => match serde_json::to_string(&op) {
Ok(json) => WsMessage::Text(json),
Err(e) => {
warn!(?e, "failed to serialize outbound system op");
continue;
}
},
OutboundMsg::Close => {
let _ = ws_tx.send(WsMessage::Close(None)).await;
break;
}
};
if ws_tx.send(ws_msg).await.is_err() {
break;
}
}
});
let transport = ActiveTransport::WebSocket {
outbound_tx,
reader_handle,
writer_handle,
};
Ok((transport, inbound_rx))
}
/// Opens a server-sent-events transport at `{url}/rps/sse`.
///
/// `client_id`, `token`, `capabilities` (sent as a comma-separated list under
/// both `capabilities` and `channels`) and `resume_cursor` are passed as
/// percent-encoded query parameters. Acks are later POSTed to `{url}/rps/ack`
/// with the same HTTP client.
///
/// The spawned reader parses the event stream. Events whose type is `frame` or
/// unset, and whose data decodes as a `Frame`, are forwarded as `InboundMsg`s
/// (256-slot queue). The reader ends with `InboundMsg::Closed` on a stream
/// error or a clean end of stream, and stops early if the session drops the
/// receiver.
///
/// # Errors
/// - `ConnectError::Transport`: the request failed, or the server answered
///   with a non-success, non-4xx status (e.g. 5xx).
/// - `ConnectError::AuthRejected`: the server answered with a 4xx status.
pub(crate) async fn connect_sse<C: ChannelKind>(
    url: &str,
    client_id: Uuid,
    token: Option<&str>,
    capabilities: &[C],
    resume_cursor: Option<u64>,
) -> Result<(ActiveTransport<C>, mpsc::Receiver<InboundMsg<C>>), ConnectError> {
    /// Maps a decoded frame to an inbound message. A frame on a system channel
    /// whose payload decodes as a `SystemOp` becomes `InboundMsg::System`; every
    /// other frame is forwarded unchanged as `InboundMsg::Frame`.
    fn classify<C: ChannelKind>(frame: Frame<C>) -> InboundMsg<C> {
        if frame.channel.is_system() {
            if let Ok(op) = serde_json::from_value::<SystemOp<C>>(frame.payload.clone()) {
                return InboundMsg::System(op);
            }
        }
        InboundMsg::Frame(frame)
    }

    let http = HttpClient::new();

    // Let reqwest build the query string so every value (notably `token`) is
    // percent-encoded instead of being spliced raw into the URL.
    let mut query: Vec<(&str, String)> = vec![("client_id", client_id.to_string())];
    if let Some(tok) = token {
        query.push(("token", tok.to_owned()));
    }
    if !capabilities.is_empty() {
        let caps = capabilities
            .iter()
            .map(|c| c.name())
            .collect::<Vec<&str>>()
            .join(",");
        query.push(("capabilities", caps.clone()));
        query.push(("channels", caps));
    }
    if let Some(cursor) = resume_cursor {
        query.push(("resume_cursor", cursor.to_string()));
    }

    let response = http
        .get(format!("{url}/rps/sse"))
        .query(&query)
        .send()
        .await
        .map_err(|e| ConnectError::Transport(format!("SSE connect failed: {e}")))?;
    let status = response.status();
    if !status.is_success() {
        let reason = format!("SSE returned {status}");
        // Only a 4xx means the server refused this client; 5xx and other
        // statuses are transport trouble and must not look like an auth failure.
        return Err(if status.is_client_error() {
            ConnectError::AuthRejected(reason)
        } else {
            ConnectError::Transport(reason)
        });
    }

    let (inbound_tx, inbound_rx) = mpsc::channel::<InboundMsg<C>>(256);
    let ack_url = format!("{url}/rps/ack");

    let reader_handle = tokio::spawn(async move {
        let mut stream = response.bytes_stream();
        // Raw bytes are buffered until a full line is available. `\n` is never
        // part of a multi-byte UTF-8 sequence, so decoding whole lines cannot
        // split a character across network chunks.
        let mut pending: Vec<u8> = Vec::new();
        let mut event_type = String::new();
        let mut data_lines: Vec<String> = Vec::new();

        while let Some(chunk) = stream.next().await {
            let bytes = match chunk {
                Ok(b) => b,
                Err(e) => {
                    warn!(?e, "SSE stream error");
                    let _ = inbound_tx.send(InboundMsg::Closed).await;
                    return;
                }
            };
            pending.extend_from_slice(&bytes);

            while let Some(pos) = pending.iter().position(|&b| b == b'\n') {
                let raw: Vec<u8> = pending.drain(..=pos).collect();
                let decoded = String::from_utf8_lossy(&raw[..pos]);
                let line = decoded.trim_end_matches('\r');

                if line.is_empty() {
                    // A blank line dispatches the accumulated event.
                    let dispatch = !data_lines.is_empty()
                        && (event_type == "frame" || event_type.is_empty());
                    if dispatch {
                        let data = data_lines.join("\n");
                        if let Ok(frame) = serde_json::from_str::<Frame<C>>(&data) {
                            if inbound_tx.send(classify(frame)).await.is_err() {
                                // The session dropped its receiver; stop reading.
                                return;
                            }
                        }
                    }
                    event_type.clear();
                    data_lines.clear();
                } else if let Some(value) = line.strip_prefix("event:") {
                    event_type = value.trim().to_string();
                } else if let Some(value) = line.strip_prefix("data:") {
                    // Per the SSE spec, only a single leading space is removed.
                    data_lines.push(value.strip_prefix(' ').unwrap_or(value).to_string());
                }
                // Comment lines (":...") and other fields are ignored.
            }
        }
        // The server ended the stream cleanly; still tell the session it is gone.
        let _ = inbound_tx.send(InboundMsg::Closed).await;
    });

    let transport = ActiveTransport::Sse {
        // Reuse the client (and its connection pool) for acks.
        http,
        ack_url,
        client_id,
        reader_handle,
    };
    Ok((transport, inbound_rx))
}
/// Connects using the transport order requested by `preference`.
///
/// The WebSocket attempt receives the full per-channel `resume_cursors`. The
/// SSE attempt takes a single cursor, so it receives the largest value in the
/// map (`None` if the map is empty). For the `*First` preferences, the first
/// attempt's error is only logged at debug level and the fallback's result is
/// returned as-is.
pub(crate) async fn connect_with_preference<C: ChannelKind>(
    preference: TransportPreference,
    url: &str,
    client_id: Uuid,
    token: Option<&str>,
    capabilities: &[C],
    resume_cursors: HashMap<C, u64>,
) -> Result<(ActiveTransport<C>, mpsc::Receiver<InboundMsg<C>>), ConnectError> {
    // SSE carries one resume position: collapse the per-channel map to its maximum.
    let sse_cursor = resume_cursors.values().copied().max();

    match preference {
        TransportPreference::WsOnly => {
            connect_ws(url, client_id, token, capabilities, resume_cursors).await
        }
        TransportPreference::SseOnly => {
            connect_sse(url, client_id, token, capabilities, sse_cursor).await
        }
        TransportPreference::WsFirst => {
            let ws_err = match connect_ws(url, client_id, token, capabilities, resume_cursors)
                .await
            {
                Ok(connected) => return Ok(connected),
                Err(err) => err,
            };
            debug!(?ws_err, "WS failed, falling back to SSE");
            connect_sse(url, client_id, token, capabilities, sse_cursor).await
        }
        TransportPreference::SseFirst => {
            let sse_err = match connect_sse(url, client_id, token, capabilities, sse_cursor).await
            {
                Ok(connected) => return Ok(connected),
                Err(err) => err,
            };
            debug!(?sse_err, "SSE failed, falling back to WS");
            connect_ws(url, client_id, token, capabilities, resume_cursors).await
        }
    }
}
/// Rewrites a base URL into its WebSocket equivalent.
///
/// - `http://…` becomes `ws://…` and `https://…` becomes `wss://…`.
/// - URLs already using `ws://` or `wss://` are returned unchanged.
/// - Anything else, including a bare `host:port`, is prefixed with `ws://`.
///
/// URL schemes are case-insensitive (RFC 3986 §3.1), so matching ignores
/// ASCII case: `HTTPS://host` maps to `wss://host` rather than
/// `ws://HTTPS://host`.
fn http_to_ws_url(url: &str) -> String {
    match url.split_once("://") {
        Some((scheme, rest)) if scheme.eq_ignore_ascii_case("http") => format!("ws://{rest}"),
        Some((scheme, rest)) if scheme.eq_ignore_ascii_case("https") => format!("wss://{rest}"),
        Some((scheme, _))
            if scheme.eq_ignore_ascii_case("ws") || scheme.eq_ignore_ascii_case("wss") =>
        {
            url.to_string()
        }
        _ => format!("ws://{url}"),
    }
}