1use std::{
2 collections::HashMap,
3 sync::{
4 Arc,
5 atomic::{AtomicU64, Ordering},
6 },
7 time::Duration,
8};
9
10use futures_util::{SinkExt, StreamExt, stream::SplitSink};
11use prost::Message;
12use tokio::{
13 sync::{Mutex, Notify, RwLock, mpsc, oneshot},
14 task::JoinHandle,
15 time,
16};
17use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
18
19use crate::{
20 emsg::EMsg,
21 error::{Error, Result},
22 message::{NO_JOB_ID, Packet, decode_frame, encode_message},
23 protobuf::{CMsgClientHeartBeat, CMsgProtoBufHeader},
24 transport::websocket::{SteamWebSocket, connect},
25};
26
/// In-flight single-reply requests: job id -> one-shot slot that receives either
/// the reply packet or the error that ended the connection.
type PendingJobs = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Packet>>>>>;
/// In-flight multi-reply requests: job id -> channel that receives every packet
/// routed to that job id until the stream is ended, its receiver is dropped, or
/// the connection fails.
type PendingStreams = Arc<Mutex<HashMap<u64, mpsc::UnboundedSender<Result<Packet>>>>>;
/// Receiving side for packets not claimed by any pending job or stream, plus the
/// terminal error reported by the read task.
type IncomingEvents = mpsc::UnboundedReceiver<Result<Packet>>;
/// Write half of the Steam CM WebSocket.
type WriteHalf = SplitSink<SteamWebSocket, WebSocketMessage>;
31
/// Session state shared between a [`Connection`] and its background read and
/// heartbeat tasks.
#[derive(Debug, Default, Clone)]
pub struct ConnectionState {
    /// SteamID recorded by `Connection::set_logged_on`; `None` before logon.
    /// Copied into every heartbeat header.
    pub steamid: Option<u64>,
    /// Client session id recorded by `Connection::set_logged_on`; copied into
    /// every heartbeat header.
    pub client_session_id: Option<i32>,
    /// Heartbeat interval in seconds, as passed to `Connection::set_logged_on`.
    pub heartbeat_seconds: Option<i32>,
    /// Why the connection stopped being usable. `Some` means closed: outgoing
    /// sends fail with this text and `Connection::is_closed` returns `true`.
    pub close_reason: Option<String>,
    /// Set to `true` once `Connection::set_package_ids` has been called.
    pub license_list_received: bool,
    /// Package ids stored by `Connection::set_package_ids` (presumably taken from
    /// the account's license list — confirm against the caller).
    pub package_ids: Vec<u32>,
}
41
/// A single WebSocket connection to a Steam CM server.
///
/// Outgoing frames are written through `sender`. A background read task routes
/// each incoming packet to a pending request, a registered stream, or the
/// `incoming` event channel. Dropping the connection aborts the read task and
/// the heartbeat task.
#[derive(Debug)]
pub struct Connection {
    /// Write half of the socket, shared with the heartbeat task.
    sender: Arc<Mutex<WriteHalf>>,
    /// One-shot reply slots for in-flight requests, keyed by job id.
    pending_jobs: PendingJobs,
    /// Multi-reply subscriptions keyed by job id; removed by `end_stream` or by
    /// the read task once the receiver is gone.
    pending_streams: PendingStreams,
    /// Packets not claimed by any job or stream, plus terminal errors.
    incoming: IncomingEvents,
    /// Source of job ids for outgoing requests; starts at 1.
    next_job_id: AtomicU64,
    /// Session state shared with the background tasks.
    state: Arc<RwLock<ConnectionState>>,
    /// Signalled by `set_package_ids` after a package list is stored.
    license_notify: Arc<Notify>,
    /// Background task that reads frames from the socket.
    read_task: JoinHandle<()>,
    /// Periodic heartbeat task, started by `set_logged_on`.
    heartbeat_task: Option<JoinHandle<()>>,
}
54
55impl Connection {
56 pub async fn connect(url: &str) -> Result<Self> {
57 let socket = connect(url).await?;
58 let (writer, mut reader) = socket.split();
59 let sender = Arc::new(Mutex::new(writer));
60 let pending_jobs = Arc::new(Mutex::new(
61 HashMap::<u64, oneshot::Sender<Result<Packet>>>::new(),
62 ));
63 let pending_streams = Arc::new(Mutex::new(HashMap::<
64 u64,
65 mpsc::UnboundedSender<Result<Packet>>,
66 >::new()));
67 let state = Arc::new(RwLock::new(ConnectionState::default()));
68 let license_notify = Arc::new(Notify::new());
69 let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
70 let pending_jobs_for_read = Arc::clone(&pending_jobs);
71 let pending_streams_for_read = Arc::clone(&pending_streams);
72 let state_for_read = Arc::clone(&state);
73
74 let read_task = tokio::spawn(async move {
75 while let Some(frame) = reader.next().await {
76 let binary = match frame {
77 Ok(WebSocketMessage::Binary(payload)) => payload,
78 Ok(WebSocketMessage::Close(_)) => {
79 mark_closed(&state_for_read, "Steam CM closed the connection").await;
80 fail_pending_jobs(
81 &pending_jobs_for_read,
82 &pending_streams_for_read,
83 Error::Closed,
84 )
85 .await;
86 let _ = incoming_tx.send(Err(Error::Closed));
87 break;
88 }
89 Ok(WebSocketMessage::Ping(_))
90 | Ok(WebSocketMessage::Pong(_))
91 | Ok(WebSocketMessage::Text(_))
92 | Ok(WebSocketMessage::Frame(_)) => {
93 continue;
94 }
95 Err(error) => {
96 let message = error.to_string();
97 let wrapped = Error::from(error);
98 mark_closed(&state_for_read, message.clone()).await;
99 fail_pending_jobs(
100 &pending_jobs_for_read,
101 &pending_streams_for_read,
102 Error::Transport(message),
103 )
104 .await;
105 let _ = incoming_tx.send(Err(wrapped));
106 break;
107 }
108 };
109
110 match decode_frame(&binary) {
111 Ok(packets) => {
112 for packet in packets {
113 let is_server_push = packet.emsg
120 == crate::emsg::EMsg::ServiceMethod.raw()
121 || packet.emsg
122 == crate::emsg::EMsg::ServiceMethodSendToClient.raw();
123
124 if !is_server_push && let Some(job_id) = packet.jobid_target() {
125 let waiter = {
126 let mut pending = pending_jobs_for_read.lock().await;
127 pending.remove(&job_id)
128 };
129 if let Some(waiter) = waiter {
130 let _ = waiter.send(Ok(packet));
131 continue;
132 }
133
134 let stream = {
135 let pending = pending_streams_for_read.lock().await;
136 pending.get(&job_id).cloned()
137 };
138 if let Some(stream) = stream {
139 if stream.send(Ok(packet)).is_err() {
140 let mut pending = pending_streams_for_read.lock().await;
141 pending.remove(&job_id);
142 }
143 continue;
144 }
145 }
146
147 let _ = incoming_tx.send(Ok(packet));
148 }
149 }
150 Err(error) => {
151 let message = error.to_string();
152 mark_closed(&state_for_read, message.clone()).await;
153 fail_pending_jobs(
154 &pending_jobs_for_read,
155 &pending_streams_for_read,
156 Error::Transport(message),
157 )
158 .await;
159 let _ = incoming_tx.send(Err(error));
160 break;
161 }
162 }
163 }
164
165 mark_closed_if_unset(&state_for_read, "Steam CM read loop ended").await;
166 fail_pending_jobs(
167 &pending_jobs_for_read,
168 &pending_streams_for_read,
169 Error::Closed,
170 )
171 .await;
172 });
173
174 Ok(Self {
175 sender,
176 pending_jobs,
177 pending_streams,
178 incoming: incoming_rx,
179 next_job_id: AtomicU64::new(1),
180 state,
181 license_notify,
182 read_task,
183 heartbeat_task: None,
184 })
185 }
186
187 pub async fn send_message<M>(
188 &self,
189 emsg: EMsg,
190 header: &CMsgProtoBufHeader,
191 body: &M,
192 ) -> Result<()>
193 where
194 M: Message,
195 {
196 let payload = encode_message(emsg, header, body)?;
197 self.send_frame(payload).await
198 }
199
200 pub async fn request<M>(
201 &self,
202 emsg: EMsg,
203 header: CMsgProtoBufHeader,
204 body: &M,
205 ) -> Result<Packet>
206 where
207 M: Message,
208 {
209 let rx = self.send_request(emsg, header, body).await?;
210 rx.await
211 .map_err(|_| self.closed_error())
212 .and_then(|result| result)
213 }
214
215 pub async fn send_request<M>(
218 &self,
219 emsg: EMsg,
220 mut header: CMsgProtoBufHeader,
221 body: &M,
222 ) -> Result<oneshot::Receiver<Result<Packet>>>
223 where
224 M: Message,
225 {
226 let job_id = self.next_job_id.fetch_add(1, Ordering::Relaxed);
227 header.jobid_source = Some(job_id);
228 if header.jobid_target.is_none() {
229 header.jobid_target = Some(NO_JOB_ID);
230 }
231
232 let (tx, rx) = oneshot::channel();
233 self.pending_jobs.lock().await.insert(job_id, tx);
234
235 if let Err(error) = self.send_message(emsg, &header, body).await {
236 self.pending_jobs.lock().await.remove(&job_id);
237 return Err(error);
238 }
239
240 Ok(rx)
241 }
242
243 pub async fn send_request_stream<M>(
247 &self,
248 emsg: EMsg,
249 mut header: CMsgProtoBufHeader,
250 body: &M,
251 ) -> Result<(u64, mpsc::UnboundedReceiver<Result<Packet>>)>
252 where
253 M: Message,
254 {
255 let job_id = self.next_job_id.fetch_add(1, Ordering::Relaxed);
256 header.jobid_source = Some(job_id);
257 if header.jobid_target.is_none() {
258 header.jobid_target = Some(NO_JOB_ID);
259 }
260
261 let (tx, rx) = mpsc::unbounded_channel();
262 self.pending_streams.lock().await.insert(job_id, tx);
263
264 if let Err(error) = self.send_message(emsg, &header, body).await {
265 self.pending_streams.lock().await.remove(&job_id);
266 return Err(error);
267 }
268
269 Ok((job_id, rx))
270 }
271
272 pub async fn end_stream(&self, job_id: u64) {
273 self.pending_streams.lock().await.remove(&job_id);
274 }
275
276 pub async fn next_event(&mut self) -> Option<Result<Packet>> {
277 self.incoming.recv().await
278 }
279
280 pub async fn set_logged_on(
281 &mut self,
282 steamid: u64,
283 client_session_id: i32,
284 heartbeat_seconds: i32,
285 ) -> Result<()> {
286 {
287 let mut state = self.state.write().await;
288 state.steamid = Some(steamid);
289 state.client_session_id = Some(client_session_id);
290 state.heartbeat_seconds = Some(heartbeat_seconds);
291 }
292
293 self.start_heartbeat(Duration::from_secs(heartbeat_seconds as u64))
294 .await
295 }
296
297 pub fn take_incoming(&mut self) -> IncomingEvents {
300 let (_dead_tx, dead_rx) = mpsc::unbounded_channel();
301 std::mem::replace(&mut self.incoming, dead_rx)
302 }
303
304 pub async fn state_snapshot(&self) -> ConnectionState {
305 self.state.read().await.clone()
306 }
307
308 pub async fn set_package_ids(&self, package_ids: Vec<u32>) {
309 {
310 let mut state = self.state.write().await;
311 state.license_list_received = true;
312 state.package_ids = package_ids;
313 }
314 self.license_notify.notify_waiters();
315 }
316
317 pub fn license_notify(&self) -> Arc<Notify> {
318 Arc::clone(&self.license_notify)
319 }
320
321 pub async fn is_closed(&self) -> bool {
322 self.state.read().await.close_reason.is_some()
323 }
324
325 async fn send_frame(&self, payload: bytes::Bytes) -> Result<()> {
326 if let Some(reason) = self.state.read().await.close_reason.clone() {
327 return Err(Error::Transport(reason));
328 }
329
330 let mut sender = self.sender.lock().await;
331 if let Err(error) = sender.send(WebSocketMessage::Binary(payload)).await {
332 let message = error.to_string();
333 {
334 let mut state = self.state.write().await;
335 state.close_reason = Some(message.clone());
336 }
337 return Err(Error::Transport(message));
338 }
339 Ok(())
340 }
341
342 async fn start_heartbeat(&mut self, interval: Duration) -> Result<()> {
343 if let Some(task) = self.heartbeat_task.take() {
344 task.abort();
345 }
346
347 let sender = Arc::clone(&self.sender);
348 let state = Arc::clone(&self.state);
349 self.heartbeat_task = Some(tokio::spawn(async move {
350 let mut ticker = time::interval(interval);
351 loop {
352 ticker.tick().await;
353
354 let state_snapshot = state.read().await.clone();
355 let header = CMsgProtoBufHeader {
356 steamid: state_snapshot.steamid,
357 client_sessionid: state_snapshot.client_session_id,
358 ..Default::default()
359 };
360 let payload = match encode_message(
361 EMsg::ClientHeartBeat,
362 &header,
363 &CMsgClientHeartBeat {
364 send_reply: Some(false),
365 },
366 ) {
367 Ok(payload) => payload,
368 Err(_) => break,
369 };
370
371 let mut writer = sender.lock().await;
372 if writer
373 .send(WebSocketMessage::Binary(payload))
374 .await
375 .is_err()
376 {
377 break;
378 }
379 }
380 }));
381
382 Ok(())
383 }
384}
385
386impl Drop for Connection {
387 fn drop(&mut self) {
388 self.read_task.abort();
389 if let Some(task) = self.heartbeat_task.take() {
390 task.abort();
391 }
392 }
393}
394
395impl Connection {
396 fn closed_error(&self) -> Error {
397 match self.state.try_read() {
398 Ok(state) => state
399 .close_reason
400 .clone()
401 .map(Error::Transport)
402 .unwrap_or(Error::Closed),
403 Err(_) => Error::Closed,
404 }
405 }
406}
407
408async fn mark_closed(state: &Arc<RwLock<ConnectionState>>, reason: impl Into<String>) {
409 let mut state = state.write().await;
410 state.close_reason = Some(reason.into());
411}
412
413async fn mark_closed_if_unset(state: &Arc<RwLock<ConnectionState>>, reason: impl Into<String>) {
414 let mut state = state.write().await;
415 if state.close_reason.is_none() {
416 state.close_reason = Some(reason.into());
417 }
418}
419
420async fn fail_pending_jobs(
421 pending_jobs: &PendingJobs,
422 pending_streams: &PendingStreams,
423 error: Error,
424) {
425 let waiters = {
426 let mut pending = pending_jobs.lock().await;
427 pending
428 .drain()
429 .map(|(_, waiter)| waiter)
430 .collect::<Vec<_>>()
431 };
432 let streams = {
433 let mut pending = pending_streams.lock().await;
434 pending
435 .drain()
436 .map(|(_, stream)| stream)
437 .collect::<Vec<_>>()
438 };
439
440 let error_message = error.to_string();
441 for waiter in waiters {
442 let _ = waiter.send(Err(Error::Transport(error_message.clone())));
443 }
444 for stream in streams {
445 let _ = stream.send(Err(Error::Transport(error_message.clone())));
446 }
447}