1use std::collections::{VecDeque, HashMap};
11use std::time::{Duration, Instant};
12
/// A single WebSocket frame handled by the client.
///
/// `Text` and `Binary` carry application data (see `WsMessage::is_data`); `Close`,
/// `Ping` and `Pong` are control frames.
///
/// `PartialEq`/`Eq` are derived so messages can be compared directly (e.g. in filters
/// or tests); all payload types already support equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// A text data frame; the payload is a Rust `String`, so it is always valid UTF-8.
    Text(String),
    /// A binary data frame.
    Binary(Vec<u8>),
    /// A close frame with a status code and a human-readable reason
    /// (`WsMessage::close_normal` builds code 1000).
    Close { code: u16, reason: String },
    /// A ping control frame carrying an opaque payload.
    Ping(Vec<u8>),
    /// A pong control frame — presumably the reply to a `Ping`; TODO confirm against the receive path.
    Pong(Vec<u8>),
}
26
27impl WsMessage {
28 pub fn text(s: impl Into<String>) -> Self { Self::Text(s.into()) }
29 pub fn binary(v: Vec<u8>) -> Self { Self::Binary(v) }
30 pub fn close_normal() -> Self { Self::Close { code: 1000, reason: "Normal closure".into() } }
31
32 pub fn is_data(&self) -> bool {
33 matches!(self, Self::Text(_) | Self::Binary(_))
34 }
35
36 pub fn as_text(&self) -> Option<&str> {
37 if let Self::Text(s) = self { Some(s) } else { None }
38 }
39
40 pub fn len(&self) -> usize {
41 match self {
42 Self::Text(s) => s.len(),
43 Self::Binary(v) => v.len(),
44 _ => 0,
45 }
46 }
47
48 pub fn is_empty(&self) -> bool { self.len() == 0 }
49}
50
/// Lifecycle state of a `WsClient` connection, advanced by `WsClient::tick`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsState {
    /// No connection. With auto-reconnect enabled, `tick` may start connecting on its own.
    Disconnected,
    /// A connection was requested; the next `tick` begins the handshake.
    Connecting,
    /// Waiting for the opening handshake to finish.
    Handshaking,
    /// The connection is open; queued outbound messages are drained on each `tick`.
    Connected,
    /// Waiting out the reconnect backoff delay after the connection dropped.
    ReconnectBackoff,
    /// Closed; `tick` does nothing until `connect` is called again.
    Closed,
}
68
69impl WsState {
70 pub fn is_connected(self) -> bool { self == Self::Connected }
71 pub fn is_live(self) -> bool { matches!(self, Self::Connected | Self::Handshaking) }
72}
73
74#[derive(Debug, Clone)]
77pub enum WsEvent {
78 Connected { url: String },
80 Disconnected { url: String, code: u16, reason: String, will_reconnect: bool },
82 Message { message: WsMessage, channel: Option<String> },
84 Reconnecting { url: String, attempt: u32, backoff_ms: u64 },
86 Error { description: String },
88 PingRtt { millis: f64 },
90}
91
92#[derive(Debug, Clone)]
95struct OutboundMessage {
96 message: WsMessage,
97 channel: Option<String>,
98 queued_at: Instant,
99}
100
101pub struct WsClient {
105 pub url: String,
106 pub state: WsState,
107 pub auto_reconnect: bool,
109 pub max_reconnects: u32,
111 pub ping_interval: Duration,
113 pub pong_timeout: Duration,
115 pub max_queue_size: usize,
117
118 reconnect_attempt: u32,
119 reconnect_timer: f32,
120 reconnect_backoff: f32,
121 last_ping: Option<Instant>,
122 last_pong: Option<Instant>,
123 ping_payload: Vec<u8>,
124 outbound: VecDeque<OutboundMessage>,
125 events: VecDeque<WsEvent>,
126 channels: HashMap<String, ChannelConfig>,
128 connect_time: Option<Instant>,
129 messages_sent: u64,
130 messages_received: u64,
131 bytes_sent: u64,
132 bytes_received: u64,
133 sim_connected_at: Option<Instant>,
135}
136
/// A channel subscription held by the client.
#[derive(Debug, Clone)]
pub struct ChannelConfig {
    /// Channel name; also the key the client stores the subscription under.
    pub name: String,
    /// Optional filter string — NOTE(review): not interpreted anywhere in this file; confirm its format.
    pub filter: Option<String>,
    /// Whether the subscription is active (`WsClient::subscribe` sets it to true).
    pub active: bool,
}
143
144impl WsClient {
145 pub fn new(url: impl Into<String>) -> Self {
146 Self {
147 url: url.into(),
148 state: WsState::Disconnected,
149 auto_reconnect: true,
150 max_reconnects: 10,
151 ping_interval: Duration::from_secs(30),
152 pong_timeout: Duration::from_secs(10),
153 max_queue_size: 1024,
154 reconnect_attempt: 0,
155 reconnect_timer: 0.0,
156 reconnect_backoff: 1.0,
157 last_ping: None,
158 last_pong: None,
159 ping_payload: vec![1, 2, 3, 4],
160 outbound: VecDeque::new(),
161 events: VecDeque::new(),
162 channels: HashMap::new(),
163 connect_time: None,
164 messages_sent: 0,
165 messages_received: 0,
166 bytes_sent: 0,
167 bytes_received: 0,
168 sim_connected_at: None,
169 }
170 }
171
172 pub fn connect(&mut self) {
176 if matches!(self.state, WsState::Disconnected | WsState::Closed) {
177 self.state = WsState::Connecting;
178 self.reconnect_attempt = 0;
179 }
180 }
181
182 pub fn close(&mut self) {
184 if self.state != WsState::Closed {
185 self.send_raw(WsMessage::close_normal(), None);
186 self.state = WsState::Closed;
187 }
188 }
189
190 pub fn disconnect(&mut self) {
192 self.state = WsState::Disconnected;
193 self.sim_connected_at = None;
194 }
195
196 pub fn send(&mut self, message: WsMessage) -> bool {
200 self.send_raw(message, None)
201 }
202
203 pub fn send_on_channel(&mut self, channel: &str, message: WsMessage) -> bool {
205 self.send_raw(message, Some(channel.to_owned()))
206 }
207
208 fn send_raw(&mut self, message: WsMessage, channel: Option<String>) -> bool {
209 if self.outbound.len() >= self.max_queue_size {
210 self.events.push_back(WsEvent::Error {
211 description: "Outbound queue full, message dropped".into(),
212 });
213 return false;
214 }
215 let len = message.len();
216 self.outbound.push_back(OutboundMessage {
217 message,
218 channel,
219 queued_at: Instant::now(),
220 });
221 self.bytes_sent += len as u64;
222 true
223 }
224
225 pub fn subscribe(&mut self, channel: impl Into<String>) {
229 let name = channel.into();
230 self.channels.insert(name.clone(), ChannelConfig {
231 name,
232 filter: None,
233 active: true,
234 });
235 }
236
237 pub fn unsubscribe(&mut self, channel: &str) {
238 self.channels.remove(channel);
239 }
240
241 pub fn tick(&mut self, dt: f32) {
245 match self.state {
246 WsState::Disconnected => {
247 if self.auto_reconnect
248 && (self.max_reconnects == 0 || self.reconnect_attempt < self.max_reconnects)
249 {
250 self.reconnect_timer -= dt;
251 if self.reconnect_timer <= 0.0 {
252 self.state = WsState::Connecting;
253 self.events.push_back(WsEvent::Reconnecting {
254 url: self.url.clone(),
255 attempt: self.reconnect_attempt,
256 backoff_ms: (self.reconnect_backoff * 1000.0) as u64,
257 });
258 }
259 }
260 }
261
262 WsState::Connecting => {
263 self.state = WsState::Handshaking;
265 self.sim_connected_at = Some(Instant::now());
266 }
267
268 WsState::Handshaking => {
269 if let Some(t) = self.sim_connected_at {
271 if t.elapsed() >= Duration::from_millis(10) {
272 self.state = WsState::Connected;
273 self.connect_time = Some(Instant::now());
274 self.reconnect_attempt = 0;
275 self.reconnect_backoff = 1.0;
276 self.events.push_back(WsEvent::Connected { url: self.url.clone() });
277 }
278 }
279 }
280
281 WsState::Connected => {
282 while let Some(msg) = self.outbound.pop_front() {
284 self.messages_sent += 1;
285 }
287
288 let should_ping = match self.last_ping {
290 None => true,
291 Some(t) => t.elapsed() >= self.ping_interval,
292 };
293 if should_ping {
294 self.send_raw(WsMessage::Ping(self.ping_payload.clone()), None);
295 self.last_ping = Some(Instant::now());
296 }
297
298 if let (Some(ping_t), None) = (self.last_ping, self.last_pong) {
300 if ping_t.elapsed() > self.pong_timeout {
301 self.handle_disconnect(1001, "Pong timeout".into());
302 }
303 }
304 }
305
306 WsState::ReconnectBackoff => {
307 self.reconnect_timer -= dt;
308 if self.reconnect_timer <= 0.0 {
309 self.state = WsState::Connecting;
310 }
311 }
312
313 WsState::Closed => {}
314 }
315 }
316
317 fn handle_disconnect(&mut self, code: u16, reason: String) {
318 let will_reconnect = self.auto_reconnect
319 && (self.max_reconnects == 0 || self.reconnect_attempt < self.max_reconnects);
320
321 self.events.push_back(WsEvent::Disconnected {
322 url: self.url.clone(),
323 code,
324 reason,
325 will_reconnect,
326 });
327
328 if will_reconnect {
329 self.reconnect_attempt += 1;
330 self.reconnect_backoff = (self.reconnect_backoff * 2.0).min(60.0);
332 self.reconnect_timer = self.reconnect_backoff;
333 self.state = WsState::ReconnectBackoff;
334 } else {
335 self.state = WsState::Closed;
336 }
337 self.sim_connected_at = None;
338 }
339
340 pub fn drain_events(&mut self) -> impl Iterator<Item = WsEvent> + '_ {
343 self.events.drain(..)
344 }
345
346 pub fn is_connected(&self) -> bool { self.state.is_connected() }
347 pub fn messages_sent(&self) -> u64 { self.messages_sent }
348 pub fn messages_received(&self) -> u64 { self.messages_received }
349 pub fn bytes_sent(&self) -> u64 { self.bytes_sent }
350 pub fn bytes_received(&self) -> u64 { self.bytes_received }
351 pub fn uptime(&self) -> Option<Duration> { self.connect_time.map(|t| t.elapsed()) }
352 pub fn pending_outbound(&self) -> usize { self.outbound.len() }
353}