use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
use std::time::Duration;
use dashmap::DashMap;
use pushwire_core::{ChannelKind, Frame, SystemOp};
use tokio::sync::{Notify, mpsc};
use tracing::{debug, info, warn};
use uuid::Uuid;
use crate::connection::{ActiveTransport, InboundMsg, connect_with_preference};
use crate::cursor::{CursorResult, CursorTracker};
use crate::dispatch::ChannelReceiver;
use crate::reconnect::ReconnectPolicy;
use crate::subscription::SubscriptionTracker;
pub use crate::connection::TransportPreference;
/// Lifecycle state of a [`PushClient`] connection.
///
/// The explicit `u8` discriminants are load-bearing: the client keeps the
/// current state in an `AtomicU8` and decodes it back by numeric value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum ConnectionState {
    /// No active connection: the initial state, after `disconnect`, or once the
    /// transport has closed and reconnecting has given up.
    Disconnected = 0,
    /// The initial `connect` handshake is in progress.
    Connecting = 1,
    /// A transport is established; `send` is permitted only in this state.
    Connected = 2,
    /// The transport closed and an automatic reconnect attempt is underway.
    Resuming = 3,
}
/// Connection settings for a [`PushClient`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal: start from [`ClientConfig::new`] and then adjust
/// the public fields.
#[non_exhaustive]
pub struct ClientConfig {
    /// Server endpoint handed to the transport connector.
    pub url: String,
    /// Client identifier passed to the connector and included in log events;
    /// `new` generates a random v4 UUID.
    pub client_id: Uuid,
    /// Optional credential passed to the connector (presumably an auth token —
    /// confirm against the transport implementation).
    pub token: Option<String>,
    /// Automatic-reconnect policy: retry limit, per-attempt delay and jitter.
    pub reconnect: ReconnectPolicy,
    /// Transport selection preference passed to the connector.
    pub transport_preference: TransportPreference,
    /// NOTE(review): not read anywhere in this file; confirm where (or whether)
    /// this flag takes effect.
    pub binary_mode: bool,
}
impl ClientConfig {
    /// Builds a config for `url` with defaults: a fresh random `client_id`, no
    /// token, `ReconnectPolicy::default()`, `TransportPreference::WsFirst`, and
    /// `binary_mode` disabled.
    pub fn new(url: impl Into<String>) -> Self {
        let url: String = url.into();
        let client_id = Uuid::new_v4();
        Self {
            url,
            client_id,
            token: None,
            reconnect: ReconnectPolicy::default(),
            transport_preference: TransportPreference::WsFirst,
            binary_mode: false,
        }
    }
}
/// Reasons [`PushClient::connect`] can fail.
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
    /// The transport could not be established; carries a description.
    #[error("transport error: {0}")]
    Transport(String),
    /// The server refused the client (presumably its token); carries the reason.
    #[error("auth rejected: {0}")]
    AuthRejected(String),
    /// The connection attempt did not complete in time.
    #[error("timeout")]
    Timeout,
}
/// Failures reported by the client's outbound paths (`send`, `subscribe`,
/// `unsubscribe`, `disconnect`).
#[derive(Debug, thiserror::Error)]
pub enum SendError {
    /// The client is not `Connected`, or has no transport installed.
    #[error("not connected")]
    NotConnected,
    /// NOTE(review): never constructed in this file; presumably produced by the
    /// transport when its outbound channel is gone — confirm.
    #[error("channel closed")]
    ChannelClosed,
    /// Wraps a `serde_json::Error`; converts automatically via `From` / `?`.
    #[error("serialization error: {0}")]
    Serialize(#[from] serde_json::Error),
}
pub struct PushClient<C: ChannelKind> {
config: ClientConfig,
cursors: Arc<CursorTracker<C>>,
receivers: Arc<DashMap<C, Arc<dyn ChannelReceiver<C>>>>,
subscriptions: Arc<SubscriptionTracker<C>>,
state: Arc<AtomicU8>,
transport: Option<ActiveTransport<C>>,
shutdown: Arc<Notify>,
processor_handle: Option<tokio::task::JoinHandle<()>>,
}
impl<C: ChannelKind> PushClient<C> {
pub fn new(config: ClientConfig) -> Self {
Self {
config,
cursors: Arc::new(CursorTracker::new()),
receivers: Arc::new(DashMap::new()),
subscriptions: Arc::new(SubscriptionTracker::new()),
state: Arc::new(AtomicU8::new(ConnectionState::Disconnected as u8)),
transport: None,
shutdown: Arc::new(Notify::new()),
processor_handle: None,
}
}
pub fn on(&mut self, channel: C, receiver: impl ChannelReceiver<C>) {
self.subscriptions.subscribe(&[channel]);
self.receivers.insert(channel, Arc::new(receiver));
}
pub async fn connect(&mut self) -> Result<(), ConnectError> {
self.set_state(ConnectionState::Connecting);
let capabilities = self.subscriptions.active();
let resume_cursors = self.cursors.export();
let (transport, inbound_rx) = connect_with_preference(
self.config.transport_preference,
&self.config.url,
self.config.client_id,
self.config.token.as_deref(),
&capabilities,
resume_cursors,
)
.await?;
self.transport = Some(transport);
self.set_state(ConnectionState::Connected);
self.spawn_processor(inbound_rx);
info!(client_id = ?self.config.client_id, "connected");
Ok(())
}
pub async fn send(&self, frame: Frame<C>) -> Result<(), SendError> {
if self.state() != ConnectionState::Connected {
return Err(SendError::NotConnected);
}
match &self.transport {
Some(t) => t.send_frame(frame).await,
None => Err(SendError::NotConnected),
}
}
pub async fn subscribe(&self, channels: &[C]) -> Result<(), SendError> {
if let Some(op) = self.subscriptions.subscribe(channels)
&& let Some(t) = &self.transport
{
t.send_system(op).await?;
}
Ok(())
}
pub async fn unsubscribe(&self, channels: &[C]) -> Result<(), SendError> {
if let Some(op) = self.subscriptions.unsubscribe(channels)
&& let Some(t) = &self.transport
{
t.send_system(op).await?;
}
Ok(())
}
pub async fn disconnect(&mut self) -> Result<(), SendError> {
self.shutdown.notify_waiters();
if let Some(t) = &self.transport {
let _ = t.send_system(SystemOp::Goodbye { reason: None }).await;
}
if let Some(transport) = self.transport.take() {
transport.close().await;
}
if let Some(handle) = self.processor_handle.take() {
handle.abort();
}
self.set_state(ConnectionState::Disconnected);
info!(client_id = ?self.config.client_id, "disconnected");
Ok(())
}
pub fn state(&self) -> ConnectionState {
match self.state.load(Ordering::SeqCst) {
0 => ConnectionState::Disconnected,
1 => ConnectionState::Connecting,
2 => ConnectionState::Connected,
3 => ConnectionState::Resuming,
_ => ConnectionState::Disconnected,
}
}
pub fn cursors(&self) -> HashMap<C, u64> {
self.cursors.export()
}
fn set_state(&self, state: ConnectionState) {
self.state.store(state as u8, Ordering::SeqCst);
}
fn spawn_processor(&mut self, mut inbound_rx: mpsc::Receiver<InboundMsg<C>>) {
let cursors = self.cursors.clone();
let receivers = self.receivers.clone();
let state = self.state.clone();
let shutdown = self.shutdown.clone();
let reconnect_policy = self.config.reconnect.clone();
let url = self.config.url.clone();
let client_id = self.config.client_id;
let token = self.config.token.clone();
let transport_pref = self.config.transport_preference;
let subscriptions = self.subscriptions.clone();
let attempt_count = Arc::new(AtomicU32::new(0));
self.processor_handle = Some(tokio::spawn(async move {
loop {
tokio::select! {
_ = shutdown.notified() => {
debug!("processor: shutdown signal received");
break;
}
msg = inbound_rx.recv() => {
match msg {
Some(InboundMsg::Frame(frame)) => {
if let Some(cursor) = frame.cursor {
let result = cursors.advance(frame.channel, cursor);
if let CursorResult::GapDetected { expected, got } = result {
warn!(
channel = frame.channel.name(),
expected, got,
"cursor gap detected"
);
}
}
if let Some(receiver) = receivers.get(&frame.channel) {
receiver.on_frame(frame);
} else {
debug!(
channel = frame.channel.name(),
"no receiver for channel, dropping"
);
}
attempt_count.store(0, Ordering::SeqCst);
}
Some(InboundMsg::System(op)) => {
handle_system_op(&op);
attempt_count.store(0, Ordering::SeqCst);
}
Some(InboundMsg::Closed) | None => {
info!("transport closed");
state.store(
ConnectionState::Disconnected as u8,
Ordering::SeqCst,
);
let attempts = attempt_count.load(Ordering::SeqCst);
if !reconnect_policy.should_retry(attempts) {
info!("reconnect exhausted, staying disconnected");
break;
}
state.store(
ConnectionState::Resuming as u8,
Ordering::SeqCst,
);
let delay = reconnect_policy.delay_for_attempt(attempts);
let jittered = if reconnect_policy.jitter {
add_jitter(delay)
} else {
delay
};
info!(
attempt = attempts + 1,
delay_ms = jittered.as_millis(),
"reconnecting"
);
tokio::time::sleep(jittered).await;
let capabilities = subscriptions.active();
let resume = cursors.export();
match connect_with_preference(
transport_pref,
&url,
client_id,
token.as_deref(),
&capabilities,
resume,
)
.await
{
Ok((_transport, new_rx)) => {
inbound_rx = new_rx;
attempt_count.store(0, Ordering::SeqCst);
state.store(
ConnectionState::Connected as u8,
Ordering::SeqCst,
);
info!("reconnected successfully");
}
Err(e) => {
warn!(?e, "reconnect failed");
attempt_count.fetch_add(1, Ordering::SeqCst);
inbound_rx.close();
continue;
}
}
}
}
}
}
}
}));
}
}
fn handle_system_op<C: ChannelKind>(op: &SystemOp<C>) {
match op {
SystemOp::Ping => {
debug!("received application-level Ping");
}
SystemOp::Pong => {
debug!("received Pong");
}
SystemOp::Error { message } => {
warn!(message, "server error");
}
SystemOp::ResumeRequired {
channel,
from_cursor,
} => {
warn!(
channel = channel.name(),
from_cursor, "server requires full resync from cursor"
);
}
SystemOp::Goodbye { reason } => {
info!(?reason, "server goodbye");
}
SystemOp::Health { status, detail } => {
debug!(?status, ?detail, "server health");
}
other => {
debug!(?other, "unhandled system op");
}
}
}
/// Randomizes `delay` uniformly within ±25% to de-synchronize reconnecting
/// clients.
///
/// Works at sub-millisecond precision. A delay too short to spread (e.g.
/// `Duration::ZERO`) is returned unchanged, because `gen_range` panics on an
/// empty range. If the jittered value cannot be represented as a `Duration`,
/// the un-jittered `delay` is returned instead of panicking.
fn add_jitter(delay: Duration) -> Duration {
    use rand::Rng;
    /// Half-width of the jitter window as a fraction of the base delay.
    const JITTER_FRACTION: f64 = 0.25;

    let base = delay.as_secs_f64();
    let spread = base * JITTER_FRACTION;
    if spread <= 0.0 {
        return delay;
    }
    let offset = rand::thread_rng().gen_range(-spread..spread);
    Duration::try_from_secs_f64((base + offset).max(0.0)).unwrap_or(delay)
}