Skip to main content

proof_engine/networking/
websocket.rs

//! WebSocket client with auto-reconnect, message queueing, and channel multiplexing.
//!
//! Driven by `tick(dt)` each frame. Incoming notifications arrive as `WsEvent`
//! values polled from `drain_events()`. Send messages with `send()`.
//!
//! Auto-reconnect: exponential backoff on disconnect.
//! Message queue: outgoing messages are buffered while disconnected.
//! Channel mux: multiple logical channels share one connection.
9
10use std::collections::{VecDeque, HashMap};
11use std::time::{Duration, Instant};
12
13// ── WsMessage ─────────────────────────────────────────────────────────────────
14
/// A single WebSocket frame: either data (`Text` / `Binary`) or control
/// (`Close` / `Ping` / `Pong`).
#[derive(Debug, Clone)]
pub enum WsMessage {
    /// UTF-8 text payload.
    Text(String),
    /// Raw binary payload.
    Binary(Vec<u8>),
    /// Close frame carrying a status code and a human-readable reason.
    Close {
        code: u16,
        reason: String,
    },
    /// Ping carrying an arbitrary payload.
    Ping(Vec<u8>),
    /// Pong answering a Ping.
    Pong(Vec<u8>),
}

impl WsMessage {
    /// Builds a `Text` message from anything convertible into a `String`.
    pub fn text(s: impl Into<String>) -> Self {
        let owned: String = s.into();
        WsMessage::Text(owned)
    }

    /// Wraps `v` as a `Binary` message.
    pub fn binary(v: Vec<u8>) -> Self {
        WsMessage::Binary(v)
    }

    /// A `Close` frame with status 1000 and the reason "Normal closure".
    pub fn close_normal() -> Self {
        WsMessage::Close {
            code: 1000,
            reason: String::from("Normal closure"),
        }
    }

    /// `true` for data frames (`Text`, `Binary`), `false` for control frames.
    pub fn is_data(&self) -> bool {
        match self {
            WsMessage::Text(_) | WsMessage::Binary(_) => true,
            WsMessage::Close { .. } | WsMessage::Ping(_) | WsMessage::Pong(_) => false,
        }
    }

    /// Borrows the text of a `Text` message; `None` for every other kind.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            WsMessage::Text(s) => Some(s.as_str()),
            _ => None,
        }
    }

    /// Payload length in bytes for data frames.
    ///
    /// Control frames (`Close`, `Ping`, `Pong`) always report 0, even when
    /// they carry a payload.
    pub fn len(&self) -> usize {
        match self {
            WsMessage::Text(text) => text.len(),
            WsMessage::Binary(bytes) => bytes.len(),
            WsMessage::Close { .. } | WsMessage::Ping(_) | WsMessage::Pong(_) => 0,
        }
    }

    /// `true` when [`len`](Self::len) is 0, which includes every control frame.
    pub fn is_empty(&self) -> bool {
        0 == self.len()
    }
}
50
51// ── WsState ───────────────────────────────────────────────────────────────────
52
/// Connection lifecycle state of the WebSocket client.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsState {
    /// No connection; waiting for `connect()` or for auto-reconnect to kick in.
    Disconnected,
    /// The TCP connection is being established.
    Connecting,
    /// The WebSocket opening handshake is underway.
    Handshaking,
    /// Fully connected and able to exchange messages.
    Connected,
    /// Sitting out a backoff delay before the next reconnect attempt.
    ReconnectBackoff,
    /// Closed on purpose; no automatic reconnect will happen.
    Closed,
}

impl WsState {
    /// `true` only for [`WsState::Connected`].
    pub fn is_connected(self) -> bool {
        matches!(self, WsState::Connected)
    }

    /// `true` while handshaking or fully connected.
    pub fn is_live(self) -> bool {
        match self {
            WsState::Handshaking | WsState::Connected => true,
            WsState::Disconnected
            | WsState::Connecting
            | WsState::ReconnectBackoff
            | WsState::Closed => false,
        }
    }
}
73
74// ── WsEvent ───────────────────────────────────────────────────────────────────
75
76#[derive(Debug, Clone)]
77pub enum WsEvent {
78    /// Successfully connected and handshaked.
79    Connected { url: String },
80    /// Connection closed by server or error.
81    Disconnected { url: String, code: u16, reason: String, will_reconnect: bool },
82    /// Received a message from the server.
83    Message { message: WsMessage, channel: Option<String> },
84    /// Reconnect attempt starting.
85    Reconnecting { url: String, attempt: u32, backoff_ms: u64 },
86    /// Error during connection or message.
87    Error { description: String },
88    /// Ping round-trip time measured.
89    PingRtt { millis: f64 },
90}
91
// ── OutboundMessage ───────────────────────────────────────────────────────────
93
94#[derive(Debug, Clone)]
95struct OutboundMessage {
96    message: WsMessage,
97    channel: Option<String>,
98    queued_at: Instant,
99}
100
101// ── WsClient ─────────────────────────────────────────────────────────────────
102
103/// WebSocket client driven by `tick()`.
104pub struct WsClient {
105    pub url:             String,
106    pub state:           WsState,
107    /// Auto-reconnect on unexpected disconnect.
108    pub auto_reconnect:  bool,
109    /// Max reconnect attempts before giving up (0 = unlimited).
110    pub max_reconnects:  u32,
111    /// Keepalive ping interval.
112    pub ping_interval:   Duration,
113    /// Close connection if no pong received within this time.
114    pub pong_timeout:    Duration,
115    /// Max size of outbound queue.
116    pub max_queue_size:  usize,
117
118    reconnect_attempt:   u32,
119    reconnect_timer:     f32,
120    reconnect_backoff:   f32,
121    last_ping:           Option<Instant>,
122    last_pong:           Option<Instant>,
123    ping_payload:        Vec<u8>,
124    outbound:            VecDeque<OutboundMessage>,
125    events:              VecDeque<WsEvent>,
126    /// Logical channels: name -> subscription filter
127    channels:            HashMap<String, ChannelConfig>,
128    connect_time:        Option<Instant>,
129    messages_sent:       u64,
130    messages_received:   u64,
131    bytes_sent:          u64,
132    bytes_received:      u64,
133    /// Simulated state (real impl would use TCP stream).
134    sim_connected_at:    Option<Instant>,
135}
136
/// Subscription record for one logical channel multiplexed over the socket.
#[derive(Debug, Clone)]
pub struct ChannelConfig {
    /// Channel name; `subscribe` also uses it as the registry key.
    pub name: String,
    /// Optional subscription filter (`subscribe` leaves it unset).
    pub filter: Option<String>,
    /// Whether the subscription is active (`subscribe` sets this to `true`).
    pub active: bool,
}
143
144impl WsClient {
145    pub fn new(url: impl Into<String>) -> Self {
146        Self {
147            url:               url.into(),
148            state:             WsState::Disconnected,
149            auto_reconnect:    true,
150            max_reconnects:    10,
151            ping_interval:     Duration::from_secs(30),
152            pong_timeout:      Duration::from_secs(10),
153            max_queue_size:    1024,
154            reconnect_attempt: 0,
155            reconnect_timer:   0.0,
156            reconnect_backoff: 1.0,
157            last_ping:         None,
158            last_pong:         None,
159            ping_payload:      vec![1, 2, 3, 4],
160            outbound:          VecDeque::new(),
161            events:            VecDeque::new(),
162            channels:          HashMap::new(),
163            connect_time:      None,
164            messages_sent:     0,
165            messages_received: 0,
166            bytes_sent:        0,
167            bytes_received:    0,
168            sim_connected_at:  None,
169        }
170    }
171
172    // ── Control ───────────────────────────────────────────────────────────────
173
174    /// Initiate a connection. No-op if already connected.
175    pub fn connect(&mut self) {
176        if matches!(self.state, WsState::Disconnected | WsState::Closed) {
177            self.state = WsState::Connecting;
178            self.reconnect_attempt = 0;
179        }
180    }
181
182    /// Close the connection intentionally (will not auto-reconnect).
183    pub fn close(&mut self) {
184        if self.state != WsState::Closed {
185            self.send_raw(WsMessage::close_normal(), None);
186            self.state = WsState::Closed;
187        }
188    }
189
190    /// Disconnect but allow auto-reconnect.
191    pub fn disconnect(&mut self) {
192        self.state = WsState::Disconnected;
193        self.sim_connected_at = None;
194    }
195
196    // ── Sending ───────────────────────────────────────────────────────────────
197
198    /// Send a message. Queued if not currently connected.
199    pub fn send(&mut self, message: WsMessage) -> bool {
200        self.send_raw(message, None)
201    }
202
203    /// Send a message on a named channel.
204    pub fn send_on_channel(&mut self, channel: &str, message: WsMessage) -> bool {
205        self.send_raw(message, Some(channel.to_owned()))
206    }
207
208    fn send_raw(&mut self, message: WsMessage, channel: Option<String>) -> bool {
209        if self.outbound.len() >= self.max_queue_size {
210            self.events.push_back(WsEvent::Error {
211                description: "Outbound queue full, message dropped".into(),
212            });
213            return false;
214        }
215        let len = message.len();
216        self.outbound.push_back(OutboundMessage {
217            message,
218            channel,
219            queued_at: Instant::now(),
220        });
221        self.bytes_sent += len as u64;
222        true
223    }
224
225    // ── Channels ──────────────────────────────────────────────────────────────
226
227    /// Subscribe to a named logical channel.
228    pub fn subscribe(&mut self, channel: impl Into<String>) {
229        let name = channel.into();
230        self.channels.insert(name.clone(), ChannelConfig {
231            name,
232            filter: None,
233            active: true,
234        });
235    }
236
237    pub fn unsubscribe(&mut self, channel: &str) {
238        self.channels.remove(channel);
239    }
240
241    // ── Tick ──────────────────────────────────────────────────────────────────
242
243    /// Drive the WebSocket state machine. Call once per frame.
244    pub fn tick(&mut self, dt: f32) {
245        match self.state {
246            WsState::Disconnected => {
247                if self.auto_reconnect
248                    && (self.max_reconnects == 0 || self.reconnect_attempt < self.max_reconnects)
249                {
250                    self.reconnect_timer -= dt;
251                    if self.reconnect_timer <= 0.0 {
252                        self.state = WsState::Connecting;
253                        self.events.push_back(WsEvent::Reconnecting {
254                            url: self.url.clone(),
255                            attempt: self.reconnect_attempt,
256                            backoff_ms: (self.reconnect_backoff * 1000.0) as u64,
257                        });
258                    }
259                }
260            }
261
262            WsState::Connecting => {
263                // Simulate connection establishment
264                self.state = WsState::Handshaking;
265                self.sim_connected_at = Some(Instant::now());
266            }
267
268            WsState::Handshaking => {
269                // Simulate handshake completion after brief delay
270                if let Some(t) = self.sim_connected_at {
271                    if t.elapsed() >= Duration::from_millis(10) {
272                        self.state = WsState::Connected;
273                        self.connect_time = Some(Instant::now());
274                        self.reconnect_attempt = 0;
275                        self.reconnect_backoff = 1.0;
276                        self.events.push_back(WsEvent::Connected { url: self.url.clone() });
277                    }
278                }
279            }
280
281            WsState::Connected => {
282                // Flush outbound queue
283                while let Some(msg) = self.outbound.pop_front() {
284                    self.messages_sent += 1;
285                    // In the real impl, write to the TCP stream here
286                }
287
288                // Keepalive ping
289                let should_ping = match self.last_ping {
290                    None    => true,
291                    Some(t) => t.elapsed() >= self.ping_interval,
292                };
293                if should_ping {
294                    self.send_raw(WsMessage::Ping(self.ping_payload.clone()), None);
295                    self.last_ping = Some(Instant::now());
296                }
297
298                // Pong timeout check
299                if let (Some(ping_t), None) = (self.last_ping, self.last_pong) {
300                    if ping_t.elapsed() > self.pong_timeout {
301                        self.handle_disconnect(1001, "Pong timeout".into());
302                    }
303                }
304            }
305
306            WsState::ReconnectBackoff => {
307                self.reconnect_timer -= dt;
308                if self.reconnect_timer <= 0.0 {
309                    self.state = WsState::Connecting;
310                }
311            }
312
313            WsState::Closed => {}
314        }
315    }
316
317    fn handle_disconnect(&mut self, code: u16, reason: String) {
318        let will_reconnect = self.auto_reconnect
319            && (self.max_reconnects == 0 || self.reconnect_attempt < self.max_reconnects);
320
321        self.events.push_back(WsEvent::Disconnected {
322            url: self.url.clone(),
323            code,
324            reason,
325            will_reconnect,
326        });
327
328        if will_reconnect {
329            self.reconnect_attempt += 1;
330            // Exponential backoff with cap at 60s
331            self.reconnect_backoff = (self.reconnect_backoff * 2.0).min(60.0);
332            self.reconnect_timer = self.reconnect_backoff;
333            self.state = WsState::ReconnectBackoff;
334        } else {
335            self.state = WsState::Closed;
336        }
337        self.sim_connected_at = None;
338    }
339
340    // ── Stats ─────────────────────────────────────────────────────────────────
341
342    pub fn drain_events(&mut self) -> impl Iterator<Item = WsEvent> + '_ {
343        self.events.drain(..)
344    }
345
346    pub fn is_connected(&self) -> bool { self.state.is_connected() }
347    pub fn messages_sent(&self) -> u64 { self.messages_sent }
348    pub fn messages_received(&self) -> u64 { self.messages_received }
349    pub fn bytes_sent(&self) -> u64 { self.bytes_sent }
350    pub fn bytes_received(&self) -> u64 { self.bytes_received }
351    pub fn uptime(&self) -> Option<Duration> { self.connect_time.map(|t| t.elapsed()) }
352    pub fn pending_outbound(&self) -> usize { self.outbound.len() }
353}