1use anyhow::bail;
14use anyhow::Result;
15use rand::rngs::StdRng;
16use rand::Rng;
17use rand::SeedableRng;
18use tokio::sync::mpsc::unbounded_channel;
19use tokio::sync::mpsc::UnboundedReceiver;
20use tokio::sync::mpsc::UnboundedSender;
21
22use crate::wire::command::Command;
23use crate::wire::command::CommandProperties;
24use crate::wire::command::ToClientCommand;
25use crate::wire::deser::Deserialize;
26use crate::wire::deser::Deserializer;
27use crate::wire::packet::AckBody;
28use crate::wire::packet::ControlBody;
29use crate::wire::packet::InnerBody;
30use crate::wire::packet::Packet;
31use crate::wire::packet::PacketBody;
32use crate::wire::packet::PeerId;
33use crate::wire::packet::ReliableBody;
34use crate::wire::packet::SetPeerIdBody;
35use crate::wire::ser::Serialize;
36use crate::wire::ser::VecSerializer;
37use crate::wire::types::ProtocolContext;
38
39use super::reliable_receiver::ReliableReceiver;
40use super::reliable_sender::ReliableSender;
41use super::split_receiver::SplitReceiver;
42use super::split_sender::SplitSender;
43
44use std::collections::VecDeque;
45use std::net::SocketAddr;
46use std::time::Duration;
47use std::time::Instant;
48
/// Window after `connect_time` during which a client's packets that still carry
/// `sender_peer_id == 0` are accepted; after it elapses they are ignored.
const INEXISTENT_PEER_ID_GRACE: Duration = Duration::from_millis(20_000);
51
52#[derive(thiserror::Error, Debug)]
53pub enum PeerError {
54 #[error("Peer sent disconnect packet")]
55 PeerSentDisconnect,
56 #[error("Socket Closed")]
57 SocketClosed,
58 #[error("Controller Closed")]
59 ControllerClosed,
60 #[error("Internal Peer error")]
61 InternalPeerError,
62}
63
/// Protocol channel number (this module keeps three channels, 0..=2).
pub type ChannelNum = u8;
/// Sequence number held in 64 bits (presumably an unwrapped form of the wire
/// sequence number — TODO confirm against the reliable sender/receiver).
pub type FullSeqNum = u64;
66
67pub struct Peer {
69 remote_addr: SocketAddr,
70 remote_is_server: bool,
71 send: UnboundedSender<Command>,
73 recv: UnboundedReceiver<Result<Command>>,
74}
75
76impl Peer {
77 pub fn remote_addr(&self) -> SocketAddr {
78 self.remote_addr
79 }
80
81 pub fn is_server(&self) -> bool {
82 self.remote_is_server
83 }
84
85 pub async fn send(&self, command: Command) -> Result<()> {
88 self.send.send(command)?;
89 Ok(())
90 }
91
92 pub async fn recv(&mut self) -> anyhow::Result<Command> {
96 match self.recv.recv().await {
97 Some(result) => result,
98 None => bail!(PeerError::InternalPeerError),
99 }
100 }
101}
102
103pub struct PeerIO {
105 relay: UnboundedSender<SocketToPeer>,
106}
107
108pub fn new_peer(
109 remote_addr: SocketAddr,
110 remote_is_server: bool,
111 peer_to_socket: UnboundedSender<PeerToSocket>,
112) -> (Peer, PeerIO) {
113 let (peer_send_tx, peer_send_rx) = unbounded_channel();
114 let (peer_recv_tx, peer_recv_rx) = unbounded_channel();
115 let (relay_tx, relay_rx) = unbounded_channel();
116
117 let socket_peer = Peer {
118 remote_addr,
119 remote_is_server,
120 send: peer_send_tx,
121 recv: peer_recv_rx,
122 };
123 let socket_peer_io = PeerIO { relay: relay_tx };
124 let socket_peer_runner = PeerRunner {
125 remote_addr,
126 remote_is_server,
127 recv_context: ProtocolContext::latest_for_receive(remote_is_server),
128 send_context: ProtocolContext::latest_for_send(remote_is_server),
129 connect_time: Instant::now(),
130 remote_peer_id: 0,
131 local_peer_id: 0,
132 from_socket: relay_rx,
133 from_controller: peer_send_rx,
134 to_controller: peer_recv_tx.clone(),
135 to_socket: peer_to_socket,
136 channels: vec![
137 Channel::new(remote_is_server, peer_recv_tx.clone()),
138 Channel::new(remote_is_server, peer_recv_tx.clone()),
139 Channel::new(remote_is_server, peer_recv_tx.clone()),
140 ],
141 rng: StdRng::from_entropy(),
142 now: Instant::now(),
143 last_received: Instant::now(),
144 };
145 tokio::spawn(async move { socket_peer_runner.run().await });
146 (socket_peer, socket_peer_io)
147}
148
149impl PeerIO {
150 pub fn send(&mut self, data: &[u8]) {
154 let _ = self.relay.send(SocketToPeer::Received(data.to_vec()));
156 }
157}
158
159struct Channel {
160 unreliable_out: VecDeque<InnerBody>,
161
162 reliable_in: ReliableReceiver,
163 reliable_out: ReliableSender,
164
165 split_in: SplitReceiver,
166 split_out: SplitSender,
167
168 to_controller: UnboundedSender<Result<Command>>,
169 now: Instant,
170 recv_context: ProtocolContext,
171 send_context: ProtocolContext,
172}
173
174impl Channel {
175 pub fn new(remote_is_server: bool, to_controller: UnboundedSender<Result<Command>>) -> Self {
176 Self {
177 unreliable_out: VecDeque::new(),
178 reliable_in: ReliableReceiver::new(),
179 reliable_out: ReliableSender::new(),
180 split_in: SplitReceiver::new(),
181 split_out: SplitSender::new(),
182 to_controller,
183 now: Instant::now(),
184 recv_context: ProtocolContext::latest_for_receive(remote_is_server),
185 send_context: ProtocolContext::latest_for_send(remote_is_server),
186 }
187 }
188
189 pub fn update_now(&mut self, now: &Instant) {
190 self.now = *now;
191 }
192
193 pub fn update_context(
194 &mut self,
195 recv_context: &ProtocolContext,
196 send_context: &ProtocolContext,
197 ) {
198 self.recv_context = *recv_context;
199 self.send_context = *send_context;
200 }
201
202 pub async fn process(&mut self, body: PacketBody) -> anyhow::Result<()> {
205 match body {
206 PacketBody::Reliable(rb) => self.process_reliable(rb).await?,
207 PacketBody::Inner(ib) => self.process_inner(ib).await?,
208 }
209 Ok(())
210 }
211
212 pub async fn process_reliable(&mut self, body: ReliableBody) -> anyhow::Result<()> {
213 self.reliable_in.push(body);
214 while let Some(inner) = self.reliable_in.pop() {
215 self.process_inner(inner).await?;
216 }
217 Ok(())
218 }
219
220 pub async fn process_inner(&mut self, body: InnerBody) -> anyhow::Result<()> {
221 match body {
222 InnerBody::Control(body) => self.process_control(body),
223 InnerBody::Original(body) => self.process_command(body.command).await,
224 InnerBody::Split(body) => {
225 if let Some(payload) = self.split_in.push(self.now, body)? {
226 let mut buf = Deserializer::new(self.recv_context, &payload);
227 let command = Command::deserialize(&mut buf)?;
228 self.process_command(command).await;
229 }
230 }
231 }
232 Ok(())
233 }
234
235 pub fn process_control(&mut self, body: ControlBody) {
236 match body {
237 ControlBody::Ack(ack) => {
238 self.reliable_out.process_ack(ack);
239 }
240 _ => (),
242 }
243 }
244
245 pub async fn process_command(&mut self, command: Command) {
246 match self.to_controller.send(Ok(command)) {
247 Ok(_) => (),
248 Err(e) => panic!("Unexpected command channel shutdown: {:?}", e),
249 }
250 }
251
252 pub fn send(&mut self, reliable: bool, command: Command) -> anyhow::Result<()> {
254 let bodies = self.split_out.push(self.send_context, command)?;
255 for body in bodies.into_iter() {
256 self.send_inner(reliable, body);
257 }
258 Ok(())
259 }
260
261 pub fn send_inner(&mut self, reliable: bool, body: InnerBody) {
262 if reliable {
263 self.reliable_out.push(body);
264 } else {
265 self.unreliable_out.push_back(body);
266 }
267 }
268
269 pub fn next_send(&mut self, now: Instant) -> Option<PacketBody> {
271 match self.unreliable_out.pop_front() {
272 Some(body) => return Some(PacketBody::Inner(body)),
273 None => (),
274 };
275 match self.reliable_out.pop(now) {
276 Some(body) => return Some(body),
277 None => (),
278 }
279 None
280 }
281
282 pub fn next_timeout(&mut self) -> Option<Instant> {
284 self.reliable_out.next_timeout()
285 }
286}
287
/// Message from the socket task to a peer's runner.
#[derive(Debug)]
pub enum SocketToPeer {
    /// Raw datagram received from this peer's address.
    Received(std::vec::Vec<u8>),
}
293
/// Message from a peer's runner to the socket task.
#[derive(Debug)]
pub enum PeerToSocket {
    /// Serialized packet flagged as priority (the runner uses this for acks);
    /// presumably sent ahead of queued `Send` traffic — confirm in socket task.
    SendImmediate(SocketAddr, Vec<u8>),
    /// Serialized packet to send in normal order.
    Send(SocketAddr, Vec<u8>),
    /// The runner for this address has stopped.
    PeerIsDisconnected(SocketAddr),
}
301
302pub struct PeerRunner {
303 remote_addr: SocketAddr,
304 remote_is_server: bool,
305 connect_time: Instant,
306 recv_context: ProtocolContext,
307 send_context: ProtocolContext,
308
309 from_socket: UnboundedReceiver<SocketToPeer>,
311 to_socket: UnboundedSender<PeerToSocket>,
312
313 from_controller: UnboundedReceiver<Command>,
315 to_controller: UnboundedSender<Result<Command>>,
316
317 remote_peer_id: PeerId,
322 local_peer_id: PeerId,
323 rng: StdRng,
324
325 channels: Vec<Channel>,
326
327 now: Instant,
329
330 last_received: Instant,
332}
333
334impl PeerRunner {
335 pub fn update_now(&mut self) {
336 self.now = Instant::now();
337 for num in 0..=2 {
338 self.channels[num].update_now(&self.now);
339 }
340 }
341
342 pub fn serialize_for_send(&mut self, channel: u8, body: PacketBody) -> Result<Vec<u8>> {
343 let pkt = Packet::new(self.local_peer_id, channel, body);
344 let mut serializer = VecSerializer::new(self.send_context, 512);
345 Packet::serialize(&pkt, &mut serializer)?;
346 Ok(serializer.take())
347 }
348
349 pub async fn send_raw(&mut self, channel: u8, body: PacketBody) -> Result<()> {
350 let raw = self.serialize_for_send(channel, body)?;
351 self.to_socket
352 .send(PeerToSocket::Send(self.remote_addr, raw))?;
353 Ok(())
354 }
355
356 pub async fn send_raw_priority(&mut self, channel: u8, body: PacketBody) -> Result<()> {
357 let raw = self.serialize_for_send(channel, body)?;
358 self.to_socket
359 .send(PeerToSocket::SendImmediate(self.remote_addr, raw))?;
360 Ok(())
361 }
362
363 pub async fn run(mut self) {
364 if let Err(err) = self.run_inner().await {
365 let disconnected_cleanly: bool = if let Some(e) = err.downcast_ref::<PeerError>() {
370 matches!(e, PeerError::PeerSentDisconnect)
371 } else {
372 false
373 };
374 if !disconnected_cleanly {
375 let _ = self
377 .send_raw(0, (ControlBody::Disconnect).into_inner().into_unreliable())
378 .await;
379 }
380 let _ = self
381 .to_socket
382 .send(PeerToSocket::PeerIsDisconnected(self.remote_addr));
383
384 let _ = self.to_controller.send(Err(err));
386 }
387 }
388
389 pub async fn run_inner(&mut self) -> anyhow::Result<()> {
390 self.update_now();
391
392 let never = self.now + Duration::from_secs(315576000);
394
395 loop {
396 let mut next_wakeup = never;
399 for num in 0..=2 {
400 loop {
401 let pkt = self.channels[num].next_send(self.now);
402 match pkt {
403 Some(body) => self.send_raw(num as u8, body).await?,
404 None => break,
405 }
406 }
407 if let Some(timeout) = self.channels[num].next_timeout() {
408 next_wakeup = std::cmp::min(next_wakeup, timeout);
409 }
410 }
411
412 tokio::select! {
414 msg = self.from_socket.recv() => self.handle_from_socket(msg).await?,
415 command = self.from_controller.recv() => self.handle_from_controller(command).await?,
416 _ = tokio::time::sleep_until(next_wakeup.into()) => self.handle_timeout().await?,
417 }
418 }
419 }
420
421 async fn handle_from_socket(&mut self, msg: Option<SocketToPeer>) -> anyhow::Result<()> {
422 self.update_now();
423 let msg = match msg {
424 Some(msg) => msg,
425 None => bail!(PeerError::SocketClosed),
426 };
427 match msg {
428 SocketToPeer::Received(buf) => {
429 let mut deser = Deserializer::new(self.recv_context, &buf);
430 let pkt = Packet::deserialize(&mut deser)?;
431 self.last_received = self.now;
432 self.process_packet(pkt).await?;
433 }
434 };
435 Ok(())
436 }
437
438 async fn handle_from_controller(&mut self, command: Option<Command>) -> anyhow::Result<()> {
439 self.update_now();
440 let command = match command {
441 Some(command) => command,
442 None => bail!(PeerError::ControllerClosed),
443 };
444 self.sniff_hello(&command);
445
446 self.send_command(command).await?;
447 Ok(())
448 }
449
450 async fn handle_timeout(&mut self) -> anyhow::Result<()> {
451 self.update_now();
452 self.process_timeouts().await?;
453 Ok(())
454 }
455
456 async fn process_packet(&mut self, pkt: Packet) -> anyhow::Result<()> {
458 if !self.remote_is_server {
459 if self.remote_peer_id == 0 {
461 self.local_peer_id = 1;
463 self.remote_peer_id = self.rng.gen_range(2..65535);
464
465 let set_peer_id = SetPeerIdBody::new(self.remote_peer_id).into_inner();
467 self.channels[0].send_inner(true, set_peer_id);
468 }
469 if pkt.sender_peer_id == 0 {
470 if self.now > self.connect_time + INEXISTENT_PEER_ID_GRACE {
471 println!("Ignoring peer_id 0 packet");
473 return Ok(());
474 }
475 } else if pkt.sender_peer_id != self.remote_peer_id {
476 println!("Invalid peer_id on packet");
478 return Ok(());
479 }
480 } else {
481 if pkt.sender_peer_id != 1 {
482 println!("Server sending from wrong peer id");
483 return Ok(());
484 }
485 }
486
487 if let Some(rb) = pkt.as_reliable() {
489 self.send_ack(pkt.channel, rb).await?;
490 }
491
492 if let Some(control) = pkt.as_control() {
498 match control {
499 ControlBody::Ack(_) => {
500 }
502 ControlBody::SetPeerId(set_peer_id) => {
503 if !self.remote_is_server {
504 bail!("Invalid set_peer_id received from client");
505 } else {
506 if self.local_peer_id == 0 {
507 self.local_peer_id = set_peer_id.peer_id;
508 } else if self.local_peer_id != set_peer_id.peer_id {
509 bail!("Peer id mismatch in duplicate SetPeerId");
510 }
511 }
512 }
513 ControlBody::Ping => {
514 }
516 ControlBody::Disconnect => bail!(PeerError::PeerSentDisconnect),
517 }
518 }
519 if let Some(command) = pkt.body.command_ref() {
521 self.sniff_hello(command);
522 }
523
524 self.channels[pkt.channel as usize].process(pkt.body).await
525 }
526
527 fn sniff_hello(&mut self, command: &Command) {
528 match command {
529 Command::ToClient(ToClientCommand::Hello(spec)) => {
530 self.update_context(spec.serialization_ver, spec.proto_ver);
531 }
532 _ => (),
533 }
534 }
535
536 fn update_context(&mut self, ser_fmt: u8, protocol_version: u16) {
537 self.recv_context.protocol_version = protocol_version;
538 self.recv_context.ser_fmt = ser_fmt;
539 self.send_context.protocol_version = protocol_version;
540 self.send_context.ser_fmt = ser_fmt;
541 for num in 0..=2 {
542 self.channels[num].update_context(&self.recv_context, &self.send_context);
543 }
544 }
545
546 async fn send_ack(&mut self, channel: u8, rb: &ReliableBody) -> anyhow::Result<()> {
549 let ack = AckBody::new(rb.seqnum).into_inner().into_unreliable();
550 self.send_raw_priority(channel, ack).await?;
551 Ok(())
552 }
553
554 async fn send_command(&mut self, command: Command) -> anyhow::Result<()> {
556 let channel = command.default_channel();
557 let reliable = command.default_reliability();
558 assert!((0..=2).contains(&channel));
559 self.channels[channel as usize].send(reliable, command)
560 }
561
562 async fn process_timeouts(&mut self) -> anyhow::Result<()> {
563 Ok(())
564 }
565}