1use futures_util::{SinkExt, StreamExt};
2use std::time::Duration;
3use tokio::{
4 sync::mpsc,
5 time::{Instant, sleep, timeout},
6};
7use tokio_tungstenite::{connect_async, tungstenite::Message};
8use tracing::{debug, error, info, warn};
9
10use crate::{
11 serde::{deserialize_json, serialize_json},
12 ws::{IncomingMessage, OutgoingMessage},
13};
14
15use super::{
16 Command, Config, DisconnectReason, Event, Handle,
17 state::{FrameResult, HeartbeatState, Sink, State},
18};
19
20pub struct Stream {
21 config: Config,
22 cmd_rx: mpsc::Receiver<Command>,
23 evt_tx: mpsc::Sender<Event>,
24}
25
26impl Stream {
27 #[allow(clippy::new_ret_no_self)]
28 pub fn new(config: Config) -> (Handle, mpsc::Receiver<Event>) {
29 let (cmd_tx, cmd_rx) = mpsc::channel::<Command>(config.command_queue_size);
30 let (evt_tx, evt_rx) = mpsc::channel::<Event>(config.event_queue_size);
31
32 let stream = Self {
33 config,
34 cmd_rx,
35 evt_tx,
36 };
37
38 tokio::spawn(stream.run());
39
40 (Handle::new(cmd_tx), evt_rx)
41 }
42
43 async fn run(mut self) {
44 info!("stream started");
45 let mut state = State::Idle;
46
47 loop {
48 state = match state {
49 State::Idle => self.step_idle().await,
50 State::Connecting { attempt } => self.step_connecting(attempt).await,
51 State::Connected {
52 frame_rx,
53 read_task,
54 sink,
55 } => self.step_connected(frame_rx, read_task, sink).await,
56 State::Reconnecting { attempt, delay_ms } => {
57 self.step_reconnecting(attempt, delay_ms).await
58 }
59 State::Closing { sink } => self.step_closing(sink).await,
60 State::Done => break,
61 };
62 }
63
64 info!("stream shut down");
65 }
66
67 fn emit(&self, event: Event) {
68 if let Err(e) = self.evt_tx.try_send(event) {
69 match e {
70 mpsc::error::TrySendError::Full(dropped) => {
71 warn!("event queue full, dropping event: {:?}", dropped);
72 }
73 mpsc::error::TrySendError::Closed(_) => {
74 debug!("event receiver dropped");
75 }
76 }
77 }
78 }
79
80 async fn step_idle(&mut self) -> State {
81 loop {
82 match self.cmd_rx.recv().await {
83 Some(Command::Connect) => {
84 return State::Connecting { attempt: 1 };
85 }
86 None => return State::Done,
87 Some(Command::Disconnect) => {
88 warn!("Command disconnect ignored - not connected");
89 }
90 Some(Command::Send(_)) => {
91 warn!("Send ignored - not connected");
92 }
93 }
94 }
95 }
96
97 async fn step_connecting(&mut self, attempt: u32) -> State {
98 debug!(attempt, "connecting…");
99
100 match connect_async(&self.config.url).await {
101 Ok((ws_stream, _)) => {
102 info!("websocket connected");
103 self.emit(Event::Connected);
104
105 let (sink, stream) = ws_stream.split();
106 let (frame_tx, frame_rx) =
107 mpsc::channel::<FrameResult>(self.config.event_queue_size);
108
109 let read_task = tokio::spawn(async move {
110 let mut stream = stream;
111 while let Some(msg) = stream.next().await {
112 if frame_tx.send(msg).await.is_err() {
113 break;
114 }
115 }
116 });
117
118 State::Connected {
119 frame_rx,
120 read_task,
121 sink: Box::new(sink),
122 }
123 }
124 Err(e) => {
125 error!(error = %e, attempt, "connection failed");
126 self.next_reconnect_state(attempt + 1, e.to_string())
127 }
128 }
129 }
130
131 async fn step_connected(
132 &mut self,
133 mut frame_rx: mpsc::Receiver<FrameResult>,
134 read_task: tokio::task::JoinHandle<()>,
135 mut sink: Sink,
136 ) -> State {
137 let ping_interval = self.config.ping_interval;
138 let pong_timeout = self.config.pong_timeout;
139 let mut ping_timer = Box::pin(match ping_interval {
140 Some(d) => sleep(d),
141 None => sleep(FAR_FUTURE),
142 });
143 let mut pong_timer = Box::pin(sleep(FAR_FUTURE));
144 let mut hb = HeartbeatState::Idle;
145
146 loop {
147 tokio::select! {
148 biased;
149
150 frame = frame_rx.recv() => match frame {
151 None => {
152 info!("remote closed the connection");
153 read_task.abort();
154 return self.next_reconnect_state(1, "remote closed".into());
155 }
156 Some(Ok(msg)) => {
157 match msg {
158 Message::Text(json) => {
159 match deserialize_json::<IncomingMessage>(&json){
160 Ok(msg) => {
161 if matches!(hb, HeartbeatState::PingSent) && (msg.is_ping() || msg.is_pong()) {
166 debug!("receiving heartbeat pong");
167 hb = HeartbeatState::Idle;
168 pong_timer.as_mut().reset(far_future_instant());
169 if let Some(d) = ping_interval {
170 ping_timer.as_mut().reset(Instant::now() + d);
171 }
172 }
173
174 self.emit(Event::Message(msg));
175 }
176 Err(e) => {
177 warn!(error = %e, "parsing IncomingMessage failed");
178 self.emit(Event::ParseError(e.to_string()));
179 }
180 }
181 }
182 Message::Binary(bytes) => debug!("binary message received ({}B)", bytes.len()),
183 Message::Ping(bytes) => debug!("ping received ({}B)", bytes.len()),
184 Message::Pong(bytes) => debug!("pong received ({}B)", bytes.len()),
185 Message::Close(close_frame) => debug!("close frame received [{:?}]", close_frame),
186 Message::Frame(frame) =>debug!("frame received ({}B)", frame.len()),
187 }
188 }
189 Some(Err(e)) => {
190 error!(error = %e, "websocket read error");
191 read_task.abort();
192 return self.next_reconnect_state(1, e.to_string());
193 }
194 },
195
196 cmd = self.cmd_rx.recv() => match cmd {
197 None | Some(Command::Disconnect) => {
198 info!("disconnect requested");
199 read_task.abort();
200 return State::Closing { sink };
201 }
202 Some(Command::Send(msg)) => {
203 let json = serialize_json(&msg).expect("serialize outgoing message failed");
204 let msg = Message::Text(json.into());
205 if let Err(e) = sink.send(msg).await {
206 error!(error = %e, "send error");
207 read_task.abort();
208 return self.next_reconnect_state(1, e.to_string());
209 }
210 }
211 Some(Command::Connect) => warn!("Connect ignored - already connected")
212 },
213
214 () = &mut ping_timer, if ping_interval.is_some() => {
215 debug!("sending heartbeat ping");
216 if let Err(e) = sink.send(ping()).await {
217 error!(error = %e, "ping send error");
218 read_task.abort();
219 return self.next_reconnect_state(1, e.to_string());
220 }
221
222 hb = HeartbeatState::PingSent;
223 pong_timer.as_mut().reset(Instant::now() + pong_timeout);
224 ping_timer.as_mut().reset(far_future_instant());
225 },
226
227 () = &mut pong_timer, if matches!(hb, HeartbeatState::PingSent) => {
228 warn!("pong timeout - connection appears dead");
229 read_task.abort();
230 return self.next_reconnect_state( 1, "pong timeout".into());
231 },
232 }
233 }
234 }
235
236 async fn step_reconnecting(&mut self, attempt: u32, delay_ms: u64) -> State {
237 warn!(attempt, delay_ms, "waiting before reconnect");
238 self.emit(Event::Reconnecting { attempt, delay_ms });
239
240 let cancelled = tokio::select! {
241 _ = sleep(Duration::from_millis(delay_ms)) => false,
242 cmd = self.cmd_rx.recv() => matches!(cmd, None | Some(Command::Disconnect)),
243 };
244
245 if cancelled {
246 self.emit(Event::Disconnected {
247 reason: DisconnectReason::Requested,
248 });
249 State::Idle
250 } else {
251 State::Connecting { attempt }
252 }
253 }
254
255 async fn step_closing(&mut self, mut sink: Sink) -> State {
256 if let Err(e) = sink.send(Message::Close(None)).await {
257 error!(error = %e, "send close message failed");
258 }
259 if let Err(e) = timeout(self.config.close_timeout, self.cmd_rx.recv()).await {
260 error!(error = %e, "waiting for a clean close handshake failed");
261 }
262
263 self.emit(Event::Disconnected {
264 reason: DisconnectReason::Requested,
265 });
266 State::Idle
267 }
268
269 fn next_reconnect_state(&self, next_attempt: u32, reason: String) -> State {
270 if self.config.max_reconnect_attempts == 0
271 || next_attempt > self.config.max_reconnect_attempts
272 {
273 self.emit(Event::Disconnected {
274 reason: DisconnectReason::Error(String::from(
275 "all reconnection attempts have failed",
276 )),
277 });
278 return State::Idle;
279 }
280
281 let base_ms = self.config.reconnect_base_delay.as_millis() as u64;
282 let max_ms = self.config.reconnect_max_delay.as_millis() as u64;
283 let delay_ms = (base_ms.saturating_mul(1u64 << (next_attempt - 1).min(10))).min(max_ms);
284
285 debug!(next_attempt, delay_ms, reason, "scheduling reconnect");
286 State::Reconnecting {
287 attempt: next_attempt,
288 delay_ms,
289 }
290 }
291}
292
/// Delay long enough to act as "never"; used to park timers that must not fire.
const FAR_FUTURE: Duration = Duration::from_secs(u64::MAX >> 2);

/// Returns a deadline [`FAR_FUTURE`] from now, used to disarm a timer via `reset`.
#[inline]
fn far_future_instant() -> Instant {
    let now = Instant::now();
    now + FAR_FUTURE
}
299
300#[inline]
301fn ping() -> Message {
302 let msg = OutgoingMessage::Ping { req_id: None };
303 let json = serialize_json(&msg).expect("serialize ping outgoing message failed");
304 Message::Text(json.into())
305}