1mod channel;
15mod reliable_receiver;
16mod reliable_sender;
17mod sequence_number;
18mod split_receiver;
19mod split_sender;
20
21use anyhow::Result;
22use anyhow::bail;
23use channel::Channel;
24use log::debug;
25use log::error;
26use log::info;
27use log::trace;
28use log::warn;
29use tokio::sync::mpsc::UnboundedReceiver;
30use tokio::sync::mpsc::UnboundedSender;
31use tokio::sync::mpsc::unbounded_channel;
32
33use crate::commands::Command;
34use crate::commands::CommandProperties;
35use crate::commands::server_to_client::ToClientCommand;
36use crate::types::ProtocolContext;
37use crate::wire::channel_id::ChannelId;
38use crate::wire::deser::Deserialize;
39use crate::wire::deser::Deserializer;
40use crate::wire::packet::AckBody;
41use crate::wire::packet::ControlBody;
42use crate::wire::packet::Packet;
43use crate::wire::packet::PacketBody;
44use crate::wire::packet::ReliableBody;
45use crate::wire::packet::SetPeerIdBody;
46use crate::wire::peer_id::PeerId;
47use crate::wire::ser::Serialize;
48use crate::wire::ser::VecSerializer;
49
50use reliable_receiver::ReliableReceiver;
51use reliable_sender::ReliableSender;
52use split_receiver::SplitReceiver;
53use split_sender::SplitSender;
54
55use std::net::SocketAddr;
56use std::time::Duration;
57use std::time::Instant;
58
/// Grace period, measured from `PeerRunner::connect_time`, during which
/// `process_packet` still accepts packets from a non-server remote whose
/// sender peer id is "none". After it elapses such packets are logged and
/// ignored.
59const INEXISTENT_PEER_ID_GRACE: Duration = Duration::from_secs(20);
61
/// Typed reasons for a peer connection ending.
///
/// Values are raised through `bail!` and therefore travel inside an
/// `anyhow` error; `PeerRunner::run` recovers them with `downcast_ref`.
62#[derive(thiserror::Error, Debug)]
63pub enum PeerError {
 /// The remote sent a `ControlBody::Disconnect` control packet.
 /// `PeerRunner::run` does not send a disconnect packet back for this one.
64 #[error("Peer sent disconnect packet")]
65 PeerSentDisconnect,
 /// The channel delivering data from the socket yielded `None`.
66 #[error("Socket Closed")]
67 SocketClosed,
 /// The channel delivering commands from the controller yielded `None`.
68 #[error("Controller Closed")]
69 ControllerClosed,
 /// Returned by `Peer::recv` when the runner's result channel yields `None`.
70 #[error("Internal Peer error")]
71 InternalPeerError,
72}
73
/// Alias for `u64`. Not referenced in the visible part of this file;
/// presumably a sequence number widened beyond its on-wire width —
/// TODO confirm against the `sequence_number` module.
74pub type FullSeqNum = u64;
75
/// Controller-side handle to one remote peer.
///
/// Created by `new_peer` together with a spawned `PeerRunner`; the two
/// communicate only through the unbounded channels stored here.
76pub struct Peer {
 /// Address of the remote endpoint, as passed to `new_peer`.
78 remote_addr: SocketAddr,
 /// Whether the remote side is the server, as passed to `new_peer`.
79 remote_is_server: bool,
 /// Commands to transmit; the receiving end is `PeerRunner::from_controller`.
80 send: UnboundedSender<Command>,
 /// Results produced by the runner and its channels (received commands,
 /// or the error that ended the connection).
82 recv: UnboundedReceiver<Result<Command>>,
83}
84
85impl Peer {
86 #[must_use]
87 pub fn remote_addr(&self) -> SocketAddr {
88 self.remote_addr
89 }
90
91 #[must_use]
93 pub fn is_server(&self) -> bool {
94 self.remote_is_server
95 }
96
97 pub fn send(&self, command: Command) -> Result<()> {
100 self.send.send(command)?;
101 Ok(())
102 }
103
104 pub async fn recv(&mut self) -> Result<Command> {
108 match self.recv.recv().await {
109 Some(result) => result,
110 None => bail!(PeerError::InternalPeerError),
111 }
112 }
113}
114
/// Socket-side handle to one remote peer, used to feed raw received data
/// to that peer's `PeerRunner`.
115pub struct PeerIO {
 /// Sender whose receiving end is `PeerRunner::from_socket`.
117 relay: UnboundedSender<SocketToPeer>,
118}
119
120#[must_use]
121pub fn new_peer(
122 remote_addr: SocketAddr,
123 remote_is_server: bool,
124 peer_to_socket: UnboundedSender<PeerToSocket>,
125) -> (Peer, PeerIO) {
126 let (peer_send_tx, peer_send_rx) = unbounded_channel();
127 let (peer_recv_tx, peer_recv_rx) = unbounded_channel();
128 let (relay_tx, relay_rx) = unbounded_channel();
129
130 let socket_peer = Peer {
131 remote_addr,
132 remote_is_server,
133 send: peer_send_tx,
134 recv: peer_recv_rx,
135 };
136 let socket_peer_io = PeerIO { relay: relay_tx };
137 let socket_peer_runner = PeerRunner {
138 remote_addr,
139 remote_is_server,
140 recv_context: ProtocolContext::latest_for_receive(remote_is_server),
141 send_context: ProtocolContext::latest_for_send(remote_is_server),
142 connect_time: Instant::now(),
143 remote_peer_id: PeerId::NONE,
144 local_peer_id: PeerId::NONE,
145 from_socket: relay_rx,
146 from_controller: peer_send_rx,
147 to_controller: peer_recv_tx.clone(),
148 to_socket: peer_to_socket,
149 channels: vec![
150 Channel::new(remote_is_server, peer_recv_tx.clone()),
151 Channel::new(remote_is_server, peer_recv_tx.clone()),
152 Channel::new(remote_is_server, peer_recv_tx.clone()),
153 ],
154 now: Instant::now(),
155 last_received: Instant::now(),
156 };
157 tokio::spawn(async move { socket_peer_runner.run().await });
158 (socket_peer, socket_peer_io)
159}
160
161impl PeerIO {
162 pub fn send(&mut self, data: &[u8]) {
166 self.relay
168 .send(SocketToPeer::Received(data.to_vec()))
169 .unwrap_or_else(|error| {
170 error!("failed to relay packet: {error}");
172 });
173 }
174}
175
/// Messages sent from the socket side to a `PeerRunner`.
176#[derive(Debug)]
177pub enum SocketToPeer {
 /// Raw bytes received for this peer; the runner deserializes them as a `Packet`.
178 Received(Vec<u8>),
180}
181
/// Messages sent from a `PeerRunner` to the socket side.
182#[derive(Debug)]
183pub enum PeerToSocket {
 /// Serialized packet for the given address, emitted by
 /// `PeerRunner::send_raw_priority` (used for acks).
184 SendImmediate(SocketAddr, Vec<u8>),
 /// Serialized packet for the given address, emitted by `PeerRunner::send_raw`.
186 Send(SocketAddr, Vec<u8>),
 /// Sent by `PeerRunner::run` once the runner loop has ended with an error.
187 PeerIsDisconnected(SocketAddr),
188}
189
/// State of the task that drives one peer connection: it moves data between
/// the socket, the per-channel state machines and the controller.
190pub struct PeerRunner {
 /// Address used as the destination of every outgoing `PeerToSocket` message.
191 remote_addr: SocketAddr,
 /// Whether the remote side is the server; selects the validation path in
 /// `process_packet`.
192 remote_is_server: bool,
 /// Instant captured when the runner was constructed in `new_peer`.
193 connect_time: Instant,
 /// Context used when deserializing incoming packets.
194 recv_context: ProtocolContext,
 /// Context used when serializing outgoing packets.
195 send_context: ProtocolContext,
196
 /// Raw data relayed from the socket (fed by `PeerIO::send`).
197 from_socket: UnboundedReceiver<SocketToPeer>,
 /// Outgoing packets and the disconnect notification for the socket side.
199 to_socket: UnboundedSender<PeerToSocket>,
200
 /// Commands queued by `Peer::send`.
201 from_controller: UnboundedReceiver<Command>,
 /// Results delivered to `Peer::recv`; the runner itself only sends the
 /// final error here.
203 to_controller: UnboundedSender<Result<Command>>,
204
 /// Peer id expected on incoming packets from a non-server remote; assigned
 /// randomly when the first packet arrives.
205 remote_peer_id: PeerId,
 /// Peer id written into outgoing packets by `serialize_for_send`.
210 local_peer_id: PeerId,
211
 /// Per-channel state, indexed by `usize::from(ChannelId)`.
212 channels: Vec<Channel>,
213
 /// Cached timestamp refreshed by `update_now`.
214 now: Instant,
216
 /// Value of `now` when the last packet was successfully deserialized.
 /// Not read anywhere in the visible code.
217 last_received: Instant,
219}
220
221impl PeerRunner {
222 pub fn update_now(&mut self) {
223 self.now = Instant::now();
224 self.channels
225 .iter_mut()
226 .for_each(|channel| channel.update_now(&self.now));
227 }
228
229 pub fn serialize_for_send(&mut self, channel: ChannelId, body: PacketBody) -> Result<Vec<u8>> {
230 let pkt = Packet::new(self.local_peer_id, channel, body);
231 let mut serializer = VecSerializer::new(self.send_context, 512);
232 Packet::serialize(&pkt, &mut serializer)?;
233 Ok(serializer.take())
234 }
235
236 pub fn send_raw(&mut self, channel: ChannelId, body: PacketBody) -> Result<()> {
237 let raw = self.serialize_for_send(channel, body)?;
238 self.to_socket
239 .send(PeerToSocket::Send(self.remote_addr, raw))?;
240 Ok(())
241 }
242
243 pub fn send_raw_priority(&mut self, channel: ChannelId, body: PacketBody) -> Result<()> {
244 let raw = self.serialize_for_send(channel, body)?;
245 self.to_socket
246 .send(PeerToSocket::SendImmediate(self.remote_addr, raw))?;
247 Ok(())
248 }
249
250 pub async fn run(mut self) {
251 if let Err(err) = self.run_inner().await {
252 let disconnected_cleanly = if let Some(error) = err.downcast_ref::<PeerError>() {
257 matches!(error, PeerError::PeerSentDisconnect)
258 } else {
259 false
260 };
261 if !disconnected_cleanly {
262 #[expect(
264 clippy::unwrap_used,
265 reason = "// TODO clarify error condition and handling"
266 )]
267 self.send_raw(
268 ChannelId::Default,
269 (ControlBody::Disconnect).into_inner().into_unreliable(),
270 )
271 .unwrap();
272 }
273 #[expect(
274 clippy::unwrap_used,
275 reason = "// TODO clarify error condition and handling"
276 )]
277 self.to_socket
278 .send(PeerToSocket::PeerIsDisconnected(self.remote_addr))
279 .unwrap();
280
281 self.to_controller.send(Err(err)).unwrap_or_else(|err| {
283 debug!("controller is no longer available: {err}");
286 });
287 }
288 }
289
290 pub async fn run_inner(&mut self) -> Result<()> {
291 self.update_now();
292
293 let never = self.now + Duration::from_secs(315_576_000);
295
296 loop {
297 let mut next_wakeup = never;
300 for channel_id in ChannelId::all() {
301 loop {
302 let pkt = self.channels[usize::from(channel_id)].next_send(self.now);
303 match pkt {
304 Some(body) => self.send_raw(channel_id, body)?,
305 None => break,
306 }
307 }
308 if let Some(timeout) = self.channels[usize::from(channel_id)].next_timeout() {
309 next_wakeup = std::cmp::min(next_wakeup, timeout);
310 }
311 }
312
313 tokio::select! {
315 msg = self.from_socket.recv() => self.handle_from_socket(msg)?,
316 command = self.from_controller.recv() => self.handle_from_controller(command)?,
317 () = tokio::time::sleep_until(next_wakeup.into()) => self.handle_timeout()?,
318 }
319 }
320 }
321
322 fn handle_from_socket(&mut self, msg: Option<SocketToPeer>) -> Result<()> {
323 self.update_now();
324 let Some(msg) = msg else {
325 bail!(PeerError::SocketClosed);
326 };
327 match msg {
328 SocketToPeer::Received(buf) => {
329 trace!(
330 "received {} bytes from socket: {:?}",
331 buf.len(),
332 &buf[0..buf.len().min(64)]
333 );
334 let mut deser = Deserializer::new(self.recv_context, &buf);
335 let pkt = Packet::deserialize(&mut deser)?;
336 self.last_received = self.now;
337 self.process_packet(pkt)?;
338 }
339 };
340 Ok(())
341 }
342
343 fn handle_from_controller(&mut self, command: Option<Command>) -> Result<()> {
344 trace!("received command from controller: {command:?}",);
345
346 self.update_now();
347 let Some(command) = command else {
348 bail!(PeerError::ControllerClosed);
349 };
350 self.sniff_hello(&command);
351
352 self.send_command(command)?;
353 Ok(())
354 }
355
356 fn handle_timeout(&mut self) -> Result<()> {
357 self.update_now();
358 self.process_timeouts()?;
359 Ok(())
360 }
361
362 fn process_packet(&mut self, pkt: Packet) -> Result<()> {
364 if self.remote_is_server {
365 if !pkt.sender_peer_id.is_server() {
366 warn!("Server sending from wrong peer id");
367 return Ok(());
368 }
369 } else {
370 if self.remote_peer_id.is_none() {
372 self.local_peer_id = PeerId::SERVER;
374 self.remote_peer_id = PeerId::random();
376
377 let set_peer_id = SetPeerIdBody::new(self.remote_peer_id).into_inner();
379 self.channels[0].send_inner(true, set_peer_id);
380 }
381 if pkt.sender_peer_id.is_none() {
382 if self.now > self.connect_time + INEXISTENT_PEER_ID_GRACE {
383 warn!("Ignoring peer_id 0 packet");
385 return Ok(());
386 }
387 } else if pkt.sender_peer_id != self.remote_peer_id {
388 warn!("Invalid peer_id on packet");
390 return Ok(());
391 }
392 }
393
394 if let Some(rb) = pkt.as_reliable() {
396 self.send_ack(pkt.channel, rb)?;
397 }
398
399 if let Some(control) = pkt.as_control() {
405 #[expect(clippy::match_same_arms, reason = "for better documentation")]
406 match control {
407 ControlBody::Ack(_) => {
408 }
410 ControlBody::SetPeerId(set_peer_id) => {
411 if self.remote_is_server {
412 if self.local_peer_id.is_none() {
413 self.local_peer_id = set_peer_id.peer_id;
414 } else if self.local_peer_id != set_peer_id.peer_id {
415 bail!("Peer id mismatch in duplicate SetPeerId");
416 }
417 } else {
418 bail!("Invalid set_peer_id received from client");
419 }
420 }
421 ControlBody::Ping => {
422 }
424 ControlBody::Disconnect => bail!(PeerError::PeerSentDisconnect),
425 }
426 }
427 if let Some(command) = pkt.body.command() {
429 self.sniff_hello(command);
430 }
431
432 self.channels[usize::from(pkt.channel)].process(pkt.body)
433 }
434
435 fn sniff_hello(&mut self, command: &Command) {
436 if let Command::ToClient(ToClientCommand::Hello(spec)) = command {
437 info!(
438 "Server protocol version {} / serialization version {}",
439 spec.proto_ver, spec.serialization_ver
440 );
441 self.update_context(spec.serialization_ver, spec.proto_ver);
442 }
443 }
444
445 fn update_context(&mut self, ser_fmt: u8, protocol_version: u16) {
446 self.recv_context.protocol_version = protocol_version;
447 self.recv_context.ser_fmt = ser_fmt;
448 self.send_context.protocol_version = protocol_version;
449 self.send_context.ser_fmt = ser_fmt;
450 self.channels
451 .iter_mut()
452 .for_each(|channel| channel.update_context(self.recv_context, self.send_context));
453 }
454
455 fn send_ack(&mut self, channel: ChannelId, rb: &ReliableBody) -> Result<()> {
458 let ack = AckBody::new(rb.seqnum).into_inner().into_unreliable();
459 self.send_raw_priority(channel, ack)?;
460 Ok(())
461 }
462
463 fn send_command(&mut self, command: Command) -> Result<()> {
465 let channel = command.default_channel();
466 let reliable = command.default_reliability();
467 self.channels[usize::from(channel)].send(reliable, command)
468 }
469
470 #[expect(
471 clippy::unused_self,
472 clippy::unnecessary_wraps,
473 reason = "// TODO this implementation looks incomplete"
474 )]
475 fn process_timeouts(&mut self) -> Result<()> {
476 Ok(())
477 }
478}