1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
4use std::time::Duration;
5
6use dashmap::DashMap;
7use pushwire_core::{ChannelKind, Frame, SystemOp};
8use tokio::sync::{Notify, mpsc};
9use tracing::{debug, info, warn};
10use uuid::Uuid;
11
12use crate::connection::{ActiveTransport, InboundMsg, connect_with_preference};
13use crate::cursor::{CursorResult, CursorTracker};
14use crate::dispatch::ChannelReceiver;
15use crate::reconnect::ReconnectPolicy;
16use crate::subscription::SubscriptionTracker;
17
18pub use crate::connection::TransportPreference;
19
/// Lifecycle state of a `PushClient` connection.
///
/// `#[repr(u8)]` pins each discriminant so the state can be stored in an
/// `AtomicU8` and shared between the client and its background processor task.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No transport is active (initial state, after `disconnect`, or after
    /// reconnect attempts are exhausted).
    Disconnected = 0,
    /// `connect` is establishing a transport.
    Connecting = 1,
    /// A transport is established.
    Connected = 2,
    /// The transport closed and an automatic reconnect is in progress.
    Resuming = 3,
}
33
34#[non_exhaustive]
40pub struct ClientConfig {
41 pub url: String,
42 pub client_id: Uuid,
43 pub token: Option<String>,
44 pub reconnect: ReconnectPolicy,
45 pub transport_preference: TransportPreference,
46 pub binary_mode: bool,
47}
48
49impl ClientConfig {
50 pub fn new(url: impl Into<String>) -> Self {
51 Self {
52 url: url.into(),
53 client_id: Uuid::new_v4(),
54 token: None,
55 reconnect: ReconnectPolicy::default(),
56 transport_preference: TransportPreference::WsFirst,
57 binary_mode: false,
58 }
59 }
60}
61
62#[derive(Debug, thiserror::Error)]
68pub enum ConnectError {
69 #[error("transport error: {0}")]
70 Transport(String),
71 #[error("auth rejected: {0}")]
72 AuthRejected(String),
73 #[error("timeout")]
74 Timeout,
75}
76
77#[derive(Debug, thiserror::Error)]
78pub enum SendError {
79 #[error("not connected")]
80 NotConnected,
81 #[error("channel closed")]
82 ChannelClosed,
83 #[error("serialization error: {0}")]
84 Serialize(#[from] serde_json::Error),
85}
86
87pub struct PushClient<C: ChannelKind> {
97 config: ClientConfig,
98 cursors: Arc<CursorTracker<C>>,
99 receivers: Arc<DashMap<C, Arc<dyn ChannelReceiver<C>>>>,
100 subscriptions: Arc<SubscriptionTracker<C>>,
101 state: Arc<AtomicU8>,
102 transport: Option<ActiveTransport<C>>,
103 shutdown: Arc<Notify>,
104 processor_handle: Option<tokio::task::JoinHandle<()>>,
105}
106
107impl<C: ChannelKind> PushClient<C> {
108 pub fn new(config: ClientConfig) -> Self {
109 Self {
110 config,
111 cursors: Arc::new(CursorTracker::new()),
112 receivers: Arc::new(DashMap::new()),
113 subscriptions: Arc::new(SubscriptionTracker::new()),
114 state: Arc::new(AtomicU8::new(ConnectionState::Disconnected as u8)),
115 transport: None,
116 shutdown: Arc::new(Notify::new()),
117 processor_handle: None,
118 }
119 }
120
121 pub fn on(&mut self, channel: C, receiver: impl ChannelReceiver<C>) {
123 self.subscriptions.subscribe(&[channel]);
124 self.receivers.insert(channel, Arc::new(receiver));
125 }
126
127 pub async fn connect(&mut self) -> Result<(), ConnectError> {
130 self.set_state(ConnectionState::Connecting);
131
132 let capabilities = self.subscriptions.active();
133 let resume_cursors = self.cursors.export();
134
135 let (transport, inbound_rx) = connect_with_preference(
136 self.config.transport_preference,
137 &self.config.url,
138 self.config.client_id,
139 self.config.token.as_deref(),
140 &capabilities,
141 resume_cursors,
142 )
143 .await?;
144
145 self.transport = Some(transport);
146 self.set_state(ConnectionState::Connected);
147
148 self.spawn_processor(inbound_rx);
150
151 info!(client_id = ?self.config.client_id, "connected");
152 Ok(())
153 }
154
155 pub async fn send(&self, frame: Frame<C>) -> Result<(), SendError> {
157 if self.state() != ConnectionState::Connected {
158 return Err(SendError::NotConnected);
159 }
160 match &self.transport {
161 Some(t) => t.send_frame(frame).await,
162 None => Err(SendError::NotConnected),
163 }
164 }
165
166 pub async fn subscribe(&self, channels: &[C]) -> Result<(), SendError> {
168 if let Some(op) = self.subscriptions.subscribe(channels)
169 && let Some(t) = &self.transport
170 {
171 t.send_system(op).await?;
172 }
173 Ok(())
174 }
175
176 pub async fn unsubscribe(&self, channels: &[C]) -> Result<(), SendError> {
178 if let Some(op) = self.subscriptions.unsubscribe(channels)
179 && let Some(t) = &self.transport
180 {
181 t.send_system(op).await?;
182 }
183 Ok(())
184 }
185
186 pub async fn disconnect(&mut self) -> Result<(), SendError> {
188 self.shutdown.notify_waiters();
189
190 if let Some(t) = &self.transport {
191 let _ = t.send_system(SystemOp::Goodbye { reason: None }).await;
192 }
193
194 if let Some(transport) = self.transport.take() {
195 transport.close().await;
196 }
197
198 if let Some(handle) = self.processor_handle.take() {
199 handle.abort();
200 }
201
202 self.set_state(ConnectionState::Disconnected);
203 info!(client_id = ?self.config.client_id, "disconnected");
204 Ok(())
205 }
206
207 pub fn state(&self) -> ConnectionState {
209 match self.state.load(Ordering::SeqCst) {
210 0 => ConnectionState::Disconnected,
211 1 => ConnectionState::Connecting,
212 2 => ConnectionState::Connected,
213 3 => ConnectionState::Resuming,
214 _ => ConnectionState::Disconnected,
215 }
216 }
217
218 pub fn cursors(&self) -> HashMap<C, u64> {
220 self.cursors.export()
221 }
222
223 fn set_state(&self, state: ConnectionState) {
228 self.state.store(state as u8, Ordering::SeqCst);
229 }
230
231 fn spawn_processor(&mut self, mut inbound_rx: mpsc::Receiver<InboundMsg<C>>) {
232 let cursors = self.cursors.clone();
233 let receivers = self.receivers.clone();
234 let state = self.state.clone();
235 let shutdown = self.shutdown.clone();
236
237 let reconnect_policy = self.config.reconnect.clone();
240 let url = self.config.url.clone();
241 let client_id = self.config.client_id;
242 let token = self.config.token.clone();
243 let transport_pref = self.config.transport_preference;
244 let subscriptions = self.subscriptions.clone();
245 let attempt_count = Arc::new(AtomicU32::new(0));
246
247 self.processor_handle = Some(tokio::spawn(async move {
248 loop {
249 tokio::select! {
250 _ = shutdown.notified() => {
251 debug!("processor: shutdown signal received");
252 break;
253 }
254 msg = inbound_rx.recv() => {
255 match msg {
256 Some(InboundMsg::Frame(frame)) => {
257 if let Some(cursor) = frame.cursor {
259 let result = cursors.advance(frame.channel, cursor);
260 if let CursorResult::GapDetected { expected, got } = result {
261 warn!(
262 channel = frame.channel.name(),
263 expected, got,
264 "cursor gap detected"
265 );
266 }
267 }
268
269 if let Some(receiver) = receivers.get(&frame.channel) {
271 receiver.on_frame(frame);
272 } else {
273 debug!(
274 channel = frame.channel.name(),
275 "no receiver for channel, dropping"
276 );
277 }
278
279 attempt_count.store(0, Ordering::SeqCst);
281 }
282 Some(InboundMsg::System(op)) => {
283 handle_system_op(&op);
284 attempt_count.store(0, Ordering::SeqCst);
285 }
286 Some(InboundMsg::Closed) | None => {
287 info!("transport closed");
288 state.store(
289 ConnectionState::Disconnected as u8,
290 Ordering::SeqCst,
291 );
292
293 let attempts = attempt_count.load(Ordering::SeqCst);
295 if !reconnect_policy.should_retry(attempts) {
296 info!("reconnect exhausted, staying disconnected");
297 break;
298 }
299
300 state.store(
301 ConnectionState::Resuming as u8,
302 Ordering::SeqCst,
303 );
304
305 let delay = reconnect_policy.delay_for_attempt(attempts);
306 let jittered = if reconnect_policy.jitter {
307 add_jitter(delay)
308 } else {
309 delay
310 };
311 info!(
312 attempt = attempts + 1,
313 delay_ms = jittered.as_millis(),
314 "reconnecting"
315 );
316 tokio::time::sleep(jittered).await;
317
318 let capabilities = subscriptions.active();
319 let resume = cursors.export();
320
321 match connect_with_preference(
322 transport_pref,
323 &url,
324 client_id,
325 token.as_deref(),
326 &capabilities,
327 resume,
328 )
329 .await
330 {
331 Ok((_transport, new_rx)) => {
332 inbound_rx = new_rx;
337 attempt_count.store(0, Ordering::SeqCst);
338 state.store(
339 ConnectionState::Connected as u8,
340 Ordering::SeqCst,
341 );
342 info!("reconnected successfully");
343 }
344 Err(e) => {
345 warn!(?e, "reconnect failed");
346 attempt_count.fetch_add(1, Ordering::SeqCst);
347 inbound_rx.close();
350 continue;
351 }
352 }
353 }
354 }
355 }
356 }
357 }
358 }));
359 }
360}
361
362fn handle_system_op<C: ChannelKind>(op: &SystemOp<C>) {
363 match op {
364 SystemOp::Ping => {
365 debug!("received application-level Ping");
368 }
369 SystemOp::Pong => {
370 debug!("received Pong");
371 }
372 SystemOp::Error { message } => {
373 warn!(message, "server error");
374 }
375 SystemOp::ResumeRequired {
376 channel,
377 from_cursor,
378 } => {
379 warn!(
380 channel = channel.name(),
381 from_cursor, "server requires full resync from cursor"
382 );
383 }
384 SystemOp::Goodbye { reason } => {
385 info!(?reason, "server goodbye");
386 }
387 SystemOp::Health { status, detail } => {
388 debug!(?status, ?detail, "server health");
389 }
390 other => {
391 debug!(?other, "unhandled system op");
392 }
393 }
394}
395
396fn add_jitter(delay: Duration) -> Duration {
397 use rand::Rng;
398 let jitter_range = delay.as_millis() as f64 * 0.25;
399 let jitter = rand::thread_rng().gen_range(-jitter_range..jitter_range);
400 let ms = (delay.as_millis() as f64 + jitter).max(0.0);
401 Duration::from_millis(ms as u64)
402}