1use std::{net::SocketAddr, sync::{Arc, atomic::{AtomicUsize, Ordering}}};
2
3use bevy::{prelude::*, utils::tracing::Instrument};
4use dashmap::DashMap;
5use derive_more::Display;
6use tokio::{
7    io::{AsyncReadExt, AsyncWriteExt, BufWriter},
8    net::{TcpStream, ToSocketAddrs},
9    runtime::Runtime,
10    sync::mpsc::{unbounded_channel, UnboundedSender},
11    task::JoinHandle,
12};
13
14use crate::{
15    error::NetworkError,
16    network_message::{ClientMessage, NetworkMessage, ServerMessage},
17    ClientNetworkEvent, ConnectionId, NetworkData, NetworkPacket, NetworkSettings, SyncChannel,
18};
19
/// Handle to the currently active connection to the server.
///
/// Bundles the two background tasks that service the socket with the channel
/// used to queue outgoing packets for the send task.
#[derive(Display)]
#[display(fmt = "Server connection to {}", peer_addr)]
struct ServerConnection {
    /// Remote address of the server this connection talks to.
    peer_addr: SocketAddr,
    /// Background task that reads incoming packets from the socket.
    receive_task: JoinHandle<()>,
    /// Background task that writes queued packets to the socket.
    send_task: JoinHandle<()>,
    /// Sending half of the queue consumed by `send_task`.
    send_message: UnboundedSender<NetworkPacket>,
}
28
29impl ServerConnection {
30    fn stop(self) {
31        self.receive_task.abort();
32        self.send_task.abort();
33    }
34}
35
/// Client side of the networking layer: owns the async runtime, the (optional)
/// connection to a server and the channels that carry results from the
/// background tasks back to the bevy systems in this file.
pub struct NetworkClient {
    /// Tokio runtime on which the connect, send and receive tasks are spawned.
    runtime: Runtime,
    /// The active server connection, or `None` while not connected.
    server_connection: Option<ServerConnection>,
    /// Received messages, buffered per message name until the matching
    /// `register_client_message` system drains them into bevy events.
    recv_message_map: Arc<DashMap<&'static str, Vec<Box<dyn NetworkMessage>>>>,
    /// Events produced by background tasks (errors, disconnects) waiting to be
    /// forwarded to bevy by `send_client_network_events`.
    network_events: SyncChannel<ClientNetworkEvent>,
    /// Freshly established streams handed from the connect task to
    /// `handle_connection_event`.
    connection_events: SyncChannel<(TcpStream, SocketAddr, NetworkSettings)>,
}
45
46impl std::fmt::Debug for NetworkClient {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        if let Some(conn) = self.server_connection.as_ref() {
49            write!(f, "NetworkClient [Connected to {}]", conn.peer_addr)?;
50        } else {
51            write!(f, "NetworkClient [Not Connected]")?;
52        }
53
54        Ok(())
55    }
56}
57
58impl NetworkClient {
59    pub(crate) fn new() -> NetworkClient {
60        NetworkClient {
61            runtime: tokio::runtime::Builder::new_multi_thread()
62                .enable_all()
63                .thread_name_fn(|| {
64                    static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
65                    let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
66                    format!("bevy-spicy-networking-client-worker-{}", id)
67                })
68                .worker_threads(2)
69                .build()
70                .expect("Could not build tokio runtime"),
71            server_connection: None,
72            recv_message_map: Arc::new(DashMap::new()),
73            network_events: SyncChannel::new(),
74            connection_events: SyncChannel::new(),
75        }
76    }
77
78    pub fn connect(
83        &mut self,
84        addr: impl ToSocketAddrs + Send + 'static,
85        network_settings: NetworkSettings,
86    ) {
87        debug!("Starting connection");
88
89        self.disconnect();
90
91        let network_error_sender = self.network_events.sender.clone();
92        let connection_event_sender = self.connection_events.sender.clone();
93
94        let debug_span = debug_span!("network_start_connection");
95
96        self.runtime.spawn(
97            async move {
98                let debug_span = debug_span!("network_connect");
99                let stream = match TcpStream::connect(addr).instrument(debug_span).await {
100                    Ok(stream) => stream,
101                    Err(error) => {
102                        match network_error_sender
103                            .send(ClientNetworkEvent::Error(NetworkError::Connection(error)))
104                        {
105                            Ok(_) => (),
106                            Err(err) => {
107                                error!("Could not send error event: {}", err);
108                            }
109                        }
110
111                        return;
112                    }
113                };
114
115                let addr = stream
116                    .peer_addr()
117                    .expect("Could not fetch peer_addr of existing stream");
118
119                match connection_event_sender.send((stream, addr, network_settings)) {
120                    Ok(_) => (),
121                    Err(err) => {
122                        error!("Could not initiate connection: {}", err);
123                    }
124                }
125
126                debug!("Connected to: {:?}", addr);
127            }
128            .instrument(debug_span),
129        );
130    }
131
132    pub fn disconnect(&mut self) {
137        if let Some(conn) = self.server_connection.take() {
138            conn.stop();
139
140            let _ = self
141                .network_events
142                .sender
143                .send(ClientNetworkEvent::Disconnected);
144        }
145    }
146
147    pub fn send_message<T: ServerMessage>(&self, message: T) -> Result<(), NetworkError> {
150        debug!("Sending message to server");
151        let server_connection = match self.server_connection.as_ref() {
152            Some(server) => server,
153            None => return Err(NetworkError::NotConnected),
154        };
155
156        let packet = NetworkPacket {
157            kind: String::from(T::NAME),
158            data: Box::new(message),
159        };
160
161        match server_connection.send_message.send(packet) {
162            Ok(_) => (),
163            Err(err) => {
164                error!("Server disconnected: {}", err);
165                return Err(NetworkError::NotConnected);
166            }
167        }
168
169        Ok(())
170    }
171
172    pub fn is_connected(&self) -> bool {
177        self.server_connection.is_some()
178    }
179}
180
/// Extension trait for registering the network messages a client application
/// wants to receive.
pub trait AppNetworkClientMessage {
    /// Registers `T` as a message type the client listens for.
    ///
    /// The `AppBuilder` implementation in this file adds a `NetworkData<T>`
    /// event and a system that forwards received `T` messages into it; it
    /// panics if `NetworkClient` is missing or `T` was already registered.
    fn listen_for_client_message<T: ClientMessage>(&mut self);
}
192
193impl AppNetworkClientMessage for AppBuilder {
194    fn listen_for_client_message<T: ClientMessage>(&mut self) {
195        let client = self.world().get_resource::<NetworkClient>().expect("Could not find `NetworkClient`. Be sure to include the `ClientPlugin` before listening for client messages.");
196
197        debug!("Registered a new ClientMessage: {}", T::NAME);
198
199        assert!(
200            !client.recv_message_map.contains_key(T::NAME),
201            "Duplicate registration of ClientMessage: {}",
202            T::NAME
203        );
204        client.recv_message_map.insert(T::NAME, Vec::new());
205
206        self.add_event::<NetworkData<T>>();
207        self.add_system_to_stage(CoreStage::PreUpdate, register_client_message::<T>.system());
208    }
209}
210
211fn register_client_message<T>(
212    net_res: ResMut<NetworkClient>,
213    mut events: EventWriter<NetworkData<T>>,
214) where
215    T: ClientMessage,
216{
217    let mut messages = match net_res.recv_message_map.get_mut(T::NAME) {
218        Some(messages) => messages,
219        None => return,
220    };
221
222    events.send_batch(
223        messages
224            .drain(..)
225            .flat_map(|msg| msg.downcast())
226            .map(|msg| {
227                NetworkData::new(
228                    ConnectionId::server(
229                        net_res
230                            .server_connection
231                            .as_ref()
232                            .map(|conn| conn.peer_addr),
233                    ),
234                    *msg,
235                )
236            }),
237    );
238}
239
240pub fn handle_connection_event(
241    mut net_res: ResMut<NetworkClient>,
242    mut events: EventWriter<ClientNetworkEvent>,
243) {
244    let (connection, peer_addr, network_settings) =
245        match net_res.connection_events.receiver.try_recv() {
246            Ok(event) => event,
247            Err(_err) => {
248                return;
249            }
250        };
251
252    let (read_socket, send_socket) = connection.into_split();
253    let recv_message_map = net_res.recv_message_map.clone();
254    let (send_message, recv_message) = unbounded_channel();
255    let network_event_sender = net_res.network_events.sender.clone();
256    let network_event_sender_two = net_res.network_events.sender.clone();
257
258    let send_span = debug_span!(
259        "network_message_send_task",
260        peer_addr = &peer_addr.to_string()[..]
261    );
262    let receive_span = debug_span!(
263        "network_message_receive_task",
264        peer_addr = &peer_addr.to_string()[..]
265    );
266
267    net_res.server_connection = Some(ServerConnection {
268        peer_addr,
269        send_task: net_res.runtime.spawn(
270            async move {
271                let mut recv_message = recv_message;
272                let mut send_socket = BufWriter::new(send_socket);
273
274                debug!("Starting new server connection, sending task");
275
276                while let Some(message) = recv_message.recv().await {
277                    let debug_span = debug_span!("network_message_encode");
278
279                    let encoded = match debug_span.in_scope(|| serde_cbor::to_vec(&message)) {
280                        Ok(encoded) => encoded,
281                        Err(err) => {
282                            error!("Could not encode packet {:?}: {}", message, err);
283                            continue;
284                        }
285                    };
286
287                    let compression_level = 3;
288
289                    let debug_span = debug_span!(
290                        "network_message_compress",
291                        compression_level = compression_level
292                    );
293
294                    let compressed = debug_span.in_scope(|| {
295                        miniz_oxide::deflate::compress_to_vec(&encoded, compression_level)
296                    });
297
298                    let len = compressed.len();
299
300                    let debug_span = debug_span!("network_message_sending_length", len = len);
301
302                    match send_socket
303                        .write_u32(len as u32)
304                        .instrument(debug_span)
305                        .await
306                    {
307                        Ok(_) => (),
308                        Err(err) => {
309                            error!("Could not send packet length: {:?}: {}", len, err);
310                            break;
311                        }
312                    }
313
314                    let debug_span = debug_span!("network_message_send_data");
315
316                    match send_socket
317                        .write_all(&compressed)
318                        .instrument(debug_span)
319                        .await
320                    {
321                        Ok(_) => (),
322                        Err(err) => {
323                            error!("Could not send packet: {:?}: {}", message, err);
324                            break;
325                        }
326                    }
327
328                    match send_socket.flush().await {
329                        Ok(_) => (),
330                        Err(err) => {
331                            error!("Could not flush packet: {:?}: {}", message, err);
332                            break;
333                        }
334                    }
335
336                    trace!("Succesfully written all!");
337                }
338
339                let _ = network_event_sender_two.send(ClientNetworkEvent::Disconnected);
340            }
341            .instrument(send_span),
342        ),
343        receive_task: net_res.runtime.spawn(
344            async move {
345                let mut read_socket = read_socket;
346                let network_settings = network_settings;
347                let recv_message_map = recv_message_map;
348
349                let mut buffer: Vec<u8> = vec![0; network_settings.max_packet_length];
350                loop {
351                    let debug_span = debug_span!("network_message_receive_length");
352
353                    let length = match read_socket.read_u32().instrument(debug_span).await {
354                        Ok(len) => len as usize,
355                        Err(err) => {
356                            error!(
357                                "Encountered error while fetching length [{}]: {}",
358                                peer_addr, err
359                            );
360                            break;
361                        }
362                    };
363
364                    if length > network_settings.max_packet_length {
365                        error!(
366                            "Received too large packet from [{}]: {} > {}",
367                            peer_addr, length, network_settings.max_packet_length
368                        );
369                        break;
370                    }
371
372                    let debug_span = debug_span!("network_message_receive_data", length = length);
373
374                    match read_socket
375                        .read_exact(&mut buffer[..length])
376                        .instrument(debug_span)
377                        .await
378                    {
379                        Ok(_) => (),
380                        Err(err) => {
381                            error!(
382                                "Encountered error while fetching stream of length {} [{}]: {}",
383                                length, peer_addr, err
384                            );
385                            break;
386                        }
387                    }
388
389                    let debug_span = debug_span!("network_message_decompression");
390
391                    let decompressed_packet = match debug_span
392                        .in_scope(|| miniz_oxide::inflate::decompress_to_vec(&buffer[..length]))
393                    {
394                        Ok(decom) => decom,
395                        Err(err) => {
396                            error!("Encountered error while decompressing: {:?}", err);
397                            break;
398                        }
399                    };
400
401                    let debug_span = debug_span!("network_message_parsing");
402
403                    let packet: NetworkPacket = match debug_span
404                        .in_scope(|| serde_cbor::from_slice(&decompressed_packet))
405                    {
406                        Ok(packet) => packet,
407                        Err(err) => {
408                            error!(
409                                "Failed to decode network packet from [{}]: {}",
410                                peer_addr, err
411                            );
412                            break;
413                        }
414                    };
415
416                    match recv_message_map.get_mut(&packet.kind[..]) {
417                        Some(mut packets) => packets.push(packet.data),
418                        None => {
419                            error!(
420                                "Could not find existing entries for message kinds: {:?}",
421                                packet
422                            );
423                        }
424                    }
425                    debug!("Finished receiving message from: {}", peer_addr);
426                }
427
428                let _ = network_event_sender.send(ClientNetworkEvent::Disconnected);
429            }
430            .instrument(receive_span),
431        ),
432        send_message,
433    });
434
435    events.send(ClientNetworkEvent::Connected);
436}
437
438pub fn send_client_network_events(
439    client_server: ResMut<NetworkClient>,
440    mut client_network_events: EventWriter<ClientNetworkEvent>,
441) {
442    client_network_events.send_batch(client_server.network_events.receiver.try_iter());
443}