1use super::{
2 Command, Config, DisconnectReason, Event, Handle,
3 state::{FrameResult, HeartbeatState, Sink, State},
4};
5use crate::{
6 serde::{deserialize_json, serialize_json},
7 ws::ReceivedMessage,
8};
9use futures_util::{SinkExt, StreamExt};
10use serde::{Serialize, de::DeserializeOwned};
11use std::fmt::Debug;
12use std::time::Duration;
13use tokio::{
14 sync::mpsc,
15 time::{Instant, sleep, timeout},
16};
17use tokio_tungstenite::{connect_async, tungstenite::Message};
18use tracing::{debug, error, info, warn};
19
20pub struct Stream<C, M>
28where
29 C: Serialize + Send + Debug + 'static,
30 M: ReceivedMessage + DeserializeOwned + Send + Debug + 'static,
31{
32 config: Config,
33 cmd_rx: mpsc::Receiver<Command<C>>,
34 evt_tx: mpsc::Sender<Event<M>>,
35}
36
37impl<C, M> Stream<C, M>
38where
39 C: Serialize + Send + Debug + 'static,
40 M: ReceivedMessage + DeserializeOwned + Send + Debug + 'static,
41{
42 #[allow(clippy::new_ret_no_self)]
43 pub fn new(config: Config) -> (Handle<C>, mpsc::Receiver<Event<M>>) {
44 let (cmd_tx, cmd_rx) = mpsc::channel::<Command<C>>(config.command_queue_size);
45 let (evt_tx, evt_rx) = mpsc::channel::<Event<M>>(config.event_queue_size);
46
47 let stream = Self {
48 config,
49 cmd_rx,
50 evt_tx,
51 };
52
53 tokio::spawn(stream.run());
54
55 (Handle::<C>::new(cmd_tx), evt_rx)
56 }
57
58 async fn run(mut self) {
59 info!("stream started");
60 let mut state = State::Idle;
61
62 loop {
63 state = match state {
64 State::Idle => self.step_idle().await,
65 State::Connecting { attempt } => self.step_connecting(attempt).await,
66 State::Connected {
67 frame_rx,
68 read_task,
69 sink,
70 } => self.step_connected(frame_rx, read_task, sink).await,
71 State::Reconnecting { attempt, delay_ms } => {
72 self.step_reconnecting(attempt, delay_ms).await
73 }
74 State::Closing { sink } => self.step_closing(sink).await,
75 State::Done => break,
76 };
77 }
78
79 info!("stream shut down");
80 }
81
82 fn emit(&self, event: Event<M>) {
83 if let Err(e) = self.evt_tx.try_send(event) {
84 match e {
85 mpsc::error::TrySendError::Full(dropped) => {
86 warn!("event queue full, dropping event: {:?}", dropped);
87 }
88 mpsc::error::TrySendError::Closed(_) => {
89 debug!("event receiver dropped");
90 }
91 }
92 }
93 }
94
95 async fn step_idle(&mut self) -> State {
96 loop {
97 match self.cmd_rx.recv().await {
98 Some(Command::Connect) => {
99 return State::Connecting { attempt: 1 };
100 }
101 None => return State::Done,
102 Some(Command::Disconnect) => {
103 warn!("Command disconnect ignored - not connected");
104 }
105 Some(Command::Send(_)) => {
106 warn!("Send ignored - not connected");
107 }
108 }
109 }
110 }
111
112 async fn step_connecting(&mut self, attempt: u32) -> State {
113 debug!(attempt, "connecting…");
114
115 match connect_async(&self.config.url).await {
116 Ok((ws_stream, _)) => {
117 info!("websocket connected");
118 self.emit(Event::Connected);
119
120 let (sink, stream) = ws_stream.split();
121 let (frame_tx, frame_rx) =
122 mpsc::channel::<FrameResult>(self.config.event_queue_size);
123
124 let read_task = tokio::spawn(async move {
125 let mut stream = stream;
126 while let Some(msg) = stream.next().await {
127 if frame_tx.send(msg).await.is_err() {
128 break;
129 }
130 }
131 });
132
133 State::Connected {
134 frame_rx,
135 read_task,
136 sink: Box::new(sink),
137 }
138 }
139 Err(e) => {
140 error!(error = %e, attempt, "connection failed");
141 self.next_reconnect_state(attempt + 1, e.to_string())
142 }
143 }
144 }
145
146 async fn step_connected(
147 &mut self,
148 mut frame_rx: mpsc::Receiver<FrameResult>,
149 read_task: tokio::task::JoinHandle<()>,
150 mut sink: Sink,
151 ) -> State {
152 let ping_interval = self.config.ping_interval;
153 let pong_timeout_dur = self.config.pong_timeout;
154 let ttl_dur = self.config.connection_ttl;
155
156 let mut ping_timer = Box::pin(sleep(ping_interval));
162 let mut pong_timeout = Box::pin(sleep(FAR_FUTURE));
163 let mut ttl_timer = Box::pin(sleep(ttl_dur));
164 let mut hb = HeartbeatState::Idle;
165
166 loop {
167 tokio::select! {
168 biased;
169
170 frame = frame_rx.recv() => match frame {
171 None => {
172 info!("remote closed the connection");
173 read_task.abort();
174 return self.next_reconnect_state(1, "remote closed".into());
175 }
176 Some(Err(e)) => {
177 error!(error = %e, "websocket read error");
178 read_task.abort();
179 return self.next_reconnect_state(1, e.to_string());
180 }
181 Some(Ok(msg)) => match msg {
182 Message::Ping(bytes) => {
183 debug!("protocol ping received ({}B)", bytes.len());
184 if let Err(e) = sink.send(Message::Pong(bytes)).await {
185 error!(error = %e, "send protocol pong failed");
186 read_task.abort();
187 return self.next_reconnect_state(1, e.to_string());
188 }
189 hb = HeartbeatState::PongSent;
190 ping_timer.as_mut().reset(far_future_instant());
191 pong_timeout.as_mut().reset(Instant::now() + pong_timeout_dur);
192 }
193 Message::Text(json) => {
194 match deserialize_json::<M>(&json) {
195 Ok(msg) => {
196 if msg.server_shutdown_event_time().is_some() {
197 info!("server shutdown notice received, initiating reconnect");
198 self.emit(Event::Message(msg));
199 read_task.abort();
200 return self.next_reconnect_state(1, "server shutdown".into());
201 } else {
202 self.emit(Event::Message(msg))
203 }
204 }
205 Err(e) => {
206 warn!(error = %e, "parsing IncomingMessage failed");
207 self.emit(Event::ParseError(e.to_string()));
208 }
209 }
210 }
211 Message::Pong(bytes) => debug!("pong received ({}B)", bytes.len()),
212 Message::Binary(bytes) => debug!("binary message received ({}B)", bytes.len()),
213 Message::Close(close_frame) => {
214 debug!(?close_frame, "close frame received");
215 read_task.abort();
216 return self.next_reconnect_state(1, "remote close frame".into());
217 }
218 Message::Frame(frame) => debug!("frame received ({}B)", frame.len()),
219 },
220 },
221
222 cmd = self.cmd_rx.recv() => match cmd {
223 None | Some(Command::Disconnect) => {
224 info!("disconnect requested");
225 read_task.abort();
226 return State::Closing { sink };
227 }
228 Some(Command::Send(msg)) => {
229 let json = serialize_json(&msg).expect("serialize outgoing message failed");
230 let msg = Message::Text(json.into());
231 if let Err(e) = sink.send(msg).await {
232 error!(error = %e, "send error");
233 read_task.abort();
234 return self.next_reconnect_state(1, e.to_string());
235 }
236 }
237 Some(Command::Connect) => warn!("Connect ignored - already connected")
238 },
239
240 _ = ping_timer.as_mut(), if matches!(hb, HeartbeatState::Idle) => {
241 warn!("no ping received within ping_interval - connection assumed dead");
242 self.emit(Event::Disconnected { reason: DisconnectReason::PongTimeout });
243 read_task.abort();
244 return self.next_reconnect_state(1, "ping interval exceeded".into());
245 }
246
247 _ = pong_timeout.as_mut(), if matches!(hb, HeartbeatState::PongSent) => {
248 warn!("no ping received within pong_timeout after last pong - connection assumed dead");
249 self.emit(Event::Disconnected { reason: DisconnectReason::PongTimeout });
250 read_task.abort();
251 return self.next_reconnect_state(1, "pong timeout".into());
252 }
253
254 _ = ttl_timer.as_mut() => {
255 info!("connection TTL reached, reconnecting proactively");
256 read_task.abort();
257 return self.next_reconnect_state(1, "connection TTL reached".into());
258 }
259 }
260 }
261 }
262
263 async fn step_reconnecting(&mut self, attempt: u32, delay_ms: u64) -> State {
264 warn!(attempt, delay_ms, "waiting before reconnect");
265 self.emit(Event::Reconnecting { attempt, delay_ms });
266
267 let cancelled = tokio::select! {
268 _ = sleep(Duration::from_millis(delay_ms)) => false,
269 cmd = self.cmd_rx.recv() => matches!(cmd, None | Some(Command::Disconnect)),
270 };
271
272 if cancelled {
273 self.emit(Event::Disconnected {
274 reason: DisconnectReason::Requested,
275 });
276 State::Idle
277 } else {
278 State::Connecting { attempt }
279 }
280 }
281
282 async fn step_closing(&mut self, mut sink: Sink) -> State {
283 if let Err(e) = sink.send(Message::Close(None)).await {
284 error!(error = %e, "send close message failed");
285 }
286 if let Err(e) = timeout(self.config.close_timeout, self.cmd_rx.recv()).await {
287 error!(error = %e, "waiting for a clean close handshake failed");
288 }
289
290 self.emit(Event::Disconnected {
291 reason: DisconnectReason::Requested,
292 });
293 State::Idle
294 }
295
296 fn next_reconnect_state(&self, next_attempt: u32, reason: String) -> State {
297 if self.config.max_reconnect_attempts == 0
298 || next_attempt > self.config.max_reconnect_attempts
299 {
300 self.emit(Event::Disconnected {
301 reason: DisconnectReason::Error(String::from(
302 "all reconnection attempts have failed",
303 )),
304 });
305 return State::Idle;
306 }
307
308 let base_ms = self.config.reconnect_base_delay.as_millis() as u64;
309 let max_ms = self.config.reconnect_max_delay.as_millis() as u64;
310 let delay_ms = (base_ms.saturating_mul(1u64 << (next_attempt - 1).min(10))).min(max_ms);
311
312 debug!(next_attempt, delay_ms, reason, "scheduling reconnect");
313 State::Reconnecting {
314 attempt: next_attempt,
315 delay_ms,
316 }
317 }
318}
319
320const FAR_FUTURE: Duration = Duration::from_secs(u64::MAX / 4);
321
322#[inline]
323fn far_future_instant() -> Instant {
324 Instant::now() + FAR_FUTURE
325}