rak_rs/client/handshake.rs

1use crate::client::discovery;
2use crate::client::discovery::DiscoveryStatus;
3use crate::client::discovery::MtuDiscovery;
4use crate::client::util::send_packet;
5use crate::connection::queue::send::SendQueue;
6use crate::connection::queue::RecvQueue;
7use crate::protocol::frame::FramePacket;
8use crate::protocol::packet::offline::{SessionInfoReply, SessionInfoRequest};
9use crate::protocol::packet::online::ConnectedPong;
10use crate::protocol::packet::online::{ConnectionRequest, NewConnection, OnlinePacket};
11use crate::protocol::reliability::Reliability;
12use crate::protocol::Magic;
13use crate::rakrs_debug;
14use crate::server::current_epoch;
15#[cfg(feature = "async_std")]
16use async_std::{
17    future::timeout,
18    future::Future,
19    net::UdpSocket,
20    task::{self, Context, Poll, Waker},
21};
22use binary_util::interfaces::Reader;
23use binary_util::io::ByteReader;
24#[cfg(feature = "async_tokio")]
25use std::future::Future;
26use std::sync::Arc;
27use std::sync::Mutex;
28#[cfg(feature = "async_tokio")]
29use std::task::{Context, Poll, Waker};
30use std::time::Duration;
31#[cfg(feature = "async_tokio")]
32use tokio::{
33    net::UdpSocket,
34    task::{self},
35    time::timeout,
36};
37
/// Waits on `$socket` for a datagram whose first byte (the packet id) is one of `$ids`.
///
/// Each receive is bounded by `$timeout` seconds (`Duration::from_secs`). The macro
/// evaluates to `Some(bytes)` holding the whole matching datagram, or `None` when:
/// - a receive times out (the wait ends immediately), or
/// - five receive errors have occurred.
///
/// Datagrams with any other id are skipped without counting as a failed try. Empty
/// datagrams are also skipped, because they have no id byte.
///
/// The expansion calls `timeout`, `Duration` and `rakrs_debug!` unqualified. These must be
/// in scope at the call site, as they are in this file for the enabled async runtime.
#[macro_export]
macro_rules! match_ids {
    ($socket: expr, $timeout: expr, $($ids: expr),*) => {
        {
            let mut recv_buf: [u8; 2048] = [0; 2048];
            let mut tries: u8 = 0;
            let ids = [$($ids),*];
            let mut pk: Option<Vec<u8>> = None;

            while tries < 5 {
                let recv_result = timeout(
                    Duration::from_secs($timeout),
                    $socket.recv(&mut recv_buf)
                ).await;

                let len = match recv_result {
                    Err(_) => {
                        rakrs_debug!(true, "[CLIENT] Failed to receive packet from server! Is it offline?");
                        break;
                    }
                    Ok(Err(e)) => {
                        tries += 1;
                        rakrs_debug!(true, "[CLIENT] Failed to receive packet from server! {}", e);
                        continue;
                    }
                    Ok(Ok(l)) => l,
                };

                // `$crate` (not `crate`) so the path resolves to this crate even when the
                // exported macro is invoked from another crate.
                $crate::rakrs_debug_buffers!(true, "[annon]\n {:?}", &recv_buf[..len]);

                // An empty datagram has no id byte; `recv_buf[0]` would still hold the first
                // byte of an earlier packet and could produce a false match.
                if len > 0 && ids.contains(&recv_buf[0]) {
                    pk = Some(recv_buf[..len].to_vec());
                    break;
                }
            }

            pk
        }
    };
}
86
/// Receives datagrams from `$socket` until one decodes as `$reply`.
///
/// For each datagram, the first byte (the packet id) is skipped and the rest is decoded with
/// `<$reply>::read`. The id byte itself is not compared against `$reply`. The macro evaluates
/// to `Some(reply)` on the first successful decode, or `None` when:
/// - a receive times out after `$timeout` seconds, or
/// - five receive errors have occurred.
///
/// Undecodable datagrams, including empty ones, are logged and skipped.
///
/// The expansion calls `timeout`, `Duration`, `ByteReader` and `rakrs_debug!` unqualified,
/// and needs the `Reader` trait in scope at the call site.
macro_rules! expect_reply {
    ($socket: expr, $reply: ty, $timeout: expr) => {{
        let mut recv_buf: [u8; 2048] = [0; 2048];
        let mut tries: u8 = 0;
        let mut pk: Option<$reply> = None;

        while tries < 5 {
            let recv_result =
                timeout(Duration::from_secs($timeout), $socket.recv(&mut recv_buf)).await;

            let len = match recv_result {
                Err(_) => {
                    rakrs_debug!(
                        true,
                        "[CLIENT] Failed to receive packet from server! Is it offline?"
                    );
                    break;
                }
                Ok(Err(_)) => {
                    tries += 1;
                    continue;
                }
                Ok(Ok(l)) => l,
            };

            $crate::rakrs_debug_buffers!(true, "[annon]\n {:?}", &recv_buf[..len]);

            // Skip the id byte. Slicing `recv_buf[1..len]` directly would panic on an empty
            // datagram (start 1 > end 0), killing the task that awaits this macro.
            let parsed = match recv_buf[..len].split_first() {
                Some((_, body)) => {
                    let mut reader = ByteReader::from(body);
                    <$reply>::read(&mut reader).ok()
                }
                None => None,
            };

            match parsed {
                Some(packet) => {
                    pk = Some(packet);
                    break;
                }
                None => {
                    rakrs_debug!(true, "[CLIENT] Failed to parse packet!");
                }
            }
        }

        pk
    }};
}
133
/// Publishes a handshake status to the shared `HandshakeState` and wakes the waiting
/// `ClientHandshake` future, if it has registered a waker.
///
/// - `update_state!(done, shared_state, status)` records a final status. It marks the
///   state as done and then `return`s from the enclosing function (the spawned handshake
///   task). The `done` argument is not read and the state is always marked done. Every
///   call in this file passes `true`.
/// - `update_state!(shared_state, status)` records an intermediate status and leaves
///   `done` set to `false`.
///
/// The mutex guard is released before the waker is invoked. Otherwise a future woken on
/// another thread would immediately block in `poll` on the lock this macro still holds.
macro_rules! update_state {
    ($done: expr, $shared_state: expr, $state: expr) => {{
        let mut state = $shared_state.lock().unwrap();
        state.status = $state;
        state.done = true;
        let waker = state.waker.take();
        ::std::mem::drop(state);
        if let Some(waker) = waker {
            waker.wake();
        }
        return;
    }};
    ($shared_state: expr, $state: expr) => {{
        let mut state = $shared_state.lock().unwrap();
        state.status = $state;
        state.done = false;
        let waker = state.waker.take();
        ::std::mem::drop(state);
        if let Some(waker) = waker {
            waker.wake();
        }
    }};
}
153
/// Stage reached by, or final outcome of, a client handshake run by `ClientHandshake`.
///
/// The derived ordering follows declaration order, so variants must not be reordered.
#[derive(Debug, Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum HandshakeStatus {
    /// Initial value, before the handshake task has published anything.
    Created,
    /// The handshake task has started (MTU discovery phase).
    Opening,
    /// MTU discovery succeeded and the session is being opened with the server.
    SessionOpen,
    /// A handshake step failed without a more specific variant, for example a packet
    /// could not be sent or the server answered with a different MTU.
    Failed,
    /// MTU discovery did not produce an MTU size.
    FailedMtuDiscovery,
    /// The server never answered the session request.
    FailedNoSessionReply,
    /// MTU discovery reported that the client's protocol version is incompatible.
    IncompatibleVersion,
    /// The server accepted the connection and the client acknowledged it.
    Completed,
}
165
166impl std::fmt::Display for HandshakeStatus {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        write!(
169            f,
170            "{}",
171            match self {
172                HandshakeStatus::Created => "Handshake created",
173                HandshakeStatus::Opening => "Opening handshake",
174                HandshakeStatus::SessionOpen => "Session open",
175                HandshakeStatus::Failed => "Handshake failed",
176                HandshakeStatus::FailedMtuDiscovery => "MTU discovery failed",
177                HandshakeStatus::FailedNoSessionReply => "No session reply",
178                HandshakeStatus::IncompatibleVersion => "Incompatible version",
179                HandshakeStatus::Completed => "Handshake completed",
180            }
181        )
182    }
183}
184
185pub(crate) struct HandshakeState {
186    status: HandshakeStatus,
187    done: bool,
188    waker: Option<Waker>,
189}
190
191pub struct ClientHandshake {
192    status: Arc<Mutex<HandshakeState>>,
193}
194
195impl ClientHandshake {
196    pub fn new(
197        socket: Arc<UdpSocket>,
198        id: i64,
199        version: u8,
200        mut mtu: u16,
201        attempts: u8,
202        timeout: u16,
203    ) -> Self {
204        let state = Arc::new(Mutex::new(HandshakeState {
205            done: false,
206            status: HandshakeStatus::Created,
207            waker: None,
208        }));
209
210        let shared_state = state.clone();
211
212        task::spawn(async move {
213            update_state!(shared_state, HandshakeStatus::Opening);
214
215            rakrs_debug!(true, "[CLIENT] Sending OpenConnectRequest to server...");
216
217            match MtuDiscovery::new(
218                socket.clone(),
219                discovery::MtuDiscoveryMeta {
220                    id,
221                    version,
222                    mtu,
223                    timeout,
224                },
225            )
226            .await
227            {
228                DiscoveryStatus::Discovered(m) => {
229                    rakrs_debug!(true, "[CLIENT] Discovered MTU size: {}", m);
230                    mtu = m;
231                }
232                DiscoveryStatus::IncompatibleVersion => {
233                    rakrs_debug!(
234                        true,
235                        "[CLIENT] Client is using incompatible protocol version."
236                    );
237                    update_state!(true, shared_state, HandshakeStatus::IncompatibleVersion);
238                }
239                _ => {
240                    update_state!(true, shared_state, HandshakeStatus::FailedMtuDiscovery);
241                }
242            }
243
244            let session_info = SessionInfoRequest {
245                magic: Magic::new(),
246                address: socket.peer_addr().unwrap(),
247                mtu_size: mtu,
248                client_id: id,
249            };
250
251            rakrs_debug!(true, "[CLIENT] Sending SessionInfoRequest to server...");
252
253            update_state!(shared_state, HandshakeStatus::SessionOpen);
254
255            if !send_packet(&socket, session_info.into()).await {
256                rakrs_debug!(
257                    true,
258                    "[CLIENT] Failed to send SessionInfoRequest to server."
259                );
260                update_state!(true, shared_state, HandshakeStatus::Failed);
261            }
262
263            let session_reply = expect_reply!(socket, SessionInfoReply, timeout.into());
264
265            if session_reply.is_none() {
266                rakrs_debug!(true, "[CLIENT] Server did not reply with SessionInfoReply!");
267                update_state!(true, shared_state, HandshakeStatus::FailedNoSessionReply);
268            }
269
270            let session_reply = session_reply.unwrap();
271
272            if session_reply.mtu_size != mtu {
273                rakrs_debug!(
274                    true,
275                    "[CLIENT] Server replied with incompatible MTU size! ({} != {})",
276                    session_reply.mtu_size,
277                    mtu
278                );
279                update_state!(true, shared_state, HandshakeStatus::Failed);
280            }
281
282            rakrs_debug!(true, "[CLIENT] Received SessionInfoReply from server!");
283
284            // create a temporary sendq
285            let mut send_q = SendQueue::new(
286                mtu,
287                timeout,
288                attempts.clone().into(),
289                socket.clone(),
290                socket.peer_addr().unwrap(),
291            );
292            let mut recv_q = RecvQueue::new();
293
294            if let Err(_) = Self::send_connection_request(&mut send_q, id).await {
295                update_state!(true, shared_state, HandshakeStatus::Failed);
296            }
297
298            rakrs_debug!(true, "[CLIENT] Sent ConnectionRequest to server!");
299
300            let mut send_time = current_epoch() as i64;
301            let mut tries = 0_u8;
302
303            let mut buf: [u8; 2048] = [0; 2048];
304
305            loop {
306                let len: usize;
307                let rec = socket.recv_from(&mut buf).await;
308
309                if (send_time + 2) <= current_epoch() as i64 {
310                    send_time = current_epoch() as i64;
311
312                    rakrs_debug!(
313                        true,
314                        "[CLIENT] Server did not reply with ConnectAccept, sending another..."
315                    );
316
317                    if let Err(_) = Self::send_connection_request(&mut send_q, id).await {
318                        update_state!(true, shared_state, HandshakeStatus::Failed);
319                    }
320
321                    tries += 1;
322                    if tries >= 5 {
323                        update_state!(true, shared_state, HandshakeStatus::Failed);
324                    }
325                }
326
327                match rec {
328                    Err(_) => {
329                        continue;
330                    }
331                    Ok((l, _)) => len = l,
332                };
333
334                let mut reader = ByteReader::from(&buf[..len]);
335
336                // proccess frame packet
337                match buf[0] {
338                    0x80..=0x8d => {
339                        if let Ok(pk) = FramePacket::read(&mut reader) {
340                            if let Err(_) = recv_q.insert(pk) {
341                                continue;
342                            }
343
344                            let raw_packets = recv_q.flush();
345
346                            for raw_pk in raw_packets {
347                                let mut pk = ByteReader::from(&raw_pk[..]);
348
349                                if let Ok(pk) = OnlinePacket::read(&mut pk) {
350                                    match pk {
351                                        OnlinePacket::ConnectedPing(pk) => {
352                                            rakrs_debug!(
353                                                true,
354                                                "[CLIENT] Received ConnectedPing from server!"
355                                            );
356                                            let response = ConnectedPong {
357                                                ping_time: pk.time,
358                                                pong_time: current_epoch() as i64,
359                                            };
360
361                                            if let Err(_) = send_q
362                                                .send_packet(
363                                                    response.into(),
364                                                    Reliability::Reliable,
365                                                    true,
366                                                )
367                                                .await
368                                            {
369                                                rakrs_debug!(
370                                                    true,
371                                                    "[CLIENT] Failed to send pong packet!"
372                                                );
373                                            }
374
375                                            continue;
376                                        }
377                                        OnlinePacket::ConnectionAccept(pk) => {
378                                            // send new incoming connection
379                                            let new_incoming = NewConnection {
380                                                server_address: socket.peer_addr().unwrap(),
381                                                system_address: vec![
382                                                    socket.peer_addr().unwrap(),
383                                                    socket.peer_addr().unwrap(),
384                                                    socket.peer_addr().unwrap(),
385                                                    socket.peer_addr().unwrap(),
386                                                    socket.peer_addr().unwrap(),
387                                                    socket.peer_addr().unwrap(),
388                                                    socket.peer_addr().unwrap(),
389                                                    socket.peer_addr().unwrap(),
390                                                    socket.peer_addr().unwrap(),
391                                                    socket.peer_addr().unwrap(),
392                                                ],
393                                                request_time: pk.request_time,
394                                                timestamp: pk.timestamp,
395                                            };
396                                            if let Err(_) = send_q
397                                                .send_packet(
398                                                    new_incoming.into(),
399                                                    Reliability::Reliable,
400                                                    true,
401                                                )
402                                                .await
403                                            {
404                                                update_state!(
405                                                    true,
406                                                    shared_state,
407                                                    HandshakeStatus::Failed
408                                                );
409                                            } else {
410                                                update_state!(
411                                                    true,
412                                                    shared_state,
413                                                    HandshakeStatus::Completed
414                                                );
415                                            }
416                                        }
417                                        _ => {
418                                            rakrs_debug!(
419                                                true,
420                                                "[CLIENT] Received unknown packet from server!"
421                                            );
422                                        }
423                                    }
424                                }
425                            }
426                        }
427                    }
428                    _ => {}
429                }
430            }
431        });
432
433        Self { status: state }
434    }
435
436    pub(crate) async fn send_connection_request(
437        send_q: &mut SendQueue,
438        id: i64,
439    ) -> std::io::Result<()> {
440        let connect_request = ConnectionRequest {
441            time: current_epoch() as i64,
442            client_id: id,
443            security: false,
444        };
445
446        if let Err(_) = send_q
447            .send_packet(connect_request.into(), Reliability::Reliable, true)
448            .await
449        {
450            return Err(std::io::Error::new(
451                std::io::ErrorKind::Other,
452                "Failed to send ConnectionRequest!",
453            ));
454        }
455        return Ok(());
456    }
457}
458
459impl Future for ClientHandshake {
460    type Output = HandshakeStatus;
461
462    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
463        // see if we can finish
464        let mut state = self.status.lock().unwrap();
465
466        if state.done {
467            return Poll::Ready(state.status);
468        } else {
469            state.waker = Some(cx.waker().clone());
470            return Poll::Pending;
471        }
472    }
473}