Skip to main content

voltlane_net/
lib.rs

1//! This module provides utilities for working with TCP sockets in an asynchronous context using Tokio.
2//! It handles sending and receiving size-prefixed messages and packets with client IDs. The module
3//! ensures cancellation safety where applicable, allowing for robust handling of asynchronous operations.
4//!
//! If you call ANYTHING from this module, make SURE to read the notes about cancellation safety in each
//! doc comment. If you don't, you will be cursed, and your code will be haunted by the ghosts of the brain cells
//! that had to die during the tens of hours spent debugging and testing this code to make sure it's cancellation
//! safe where needed.
//!
//! If you want to touch this, first read up on cancellation safety in Rust and in the Tokio documentation.
//! Then grab a {drink_of_choice}, block a time slot of a few hours, and enjoy yourself.
//! Finally, make sure you can explain why the code is the way it is before you change it. And please don't
//! assume it was ever correct: if you find a bug, it's probably real.
14//!
15use tokio::{
16    io::{AsyncRead, AsyncWriteExt},
17    net::TcpStream,
18};
19use tokio_stream::StreamExt;
20use tokio_util::{
21    bytes::BytesMut,
22    codec::{FramedRead, LengthDelimitedCodec},
23};
24
/// Version number of the wire protocol implemented by this module.
///
/// NOTE(review): presumably sent by the client via `ClientServerPacket::ProtocolVersion` and
/// checked by the server — confirm, and bump it whenever the wire format changes.
pub const PROTOCOL_VERSION: u32 = 2;
26
27pub type FramedReader<T> = FramedRead<T, LengthDelimitedCodec>;
28
29#[derive(Clone, bincode::Encode, bincode::Decode)]
30pub enum ClientServerPacket {
31    /// Bidirectional
32    /// used for checking the connection after reconnect
33    Ping,
34    /// Unidirectional Client -> Server
35    ProtocolVersion(u32),
36    /// Bidirectional
37    PubKey(Vec<u8>),
38    /// Unidirectional Server -> Client
39    ClientId(u64),
40    /// Unidirectional Server -> Client
41    Challenge(Vec<u8>),
42    /// Unidirectional Client -> Server
43    ChallengeResponse(Vec<u8>),
44}
45
46impl ClientServerPacket {
47    pub fn into_vec(self) -> Result<Vec<u8>, bincode::error::EncodeError> {
48        bincode::encode_to_vec(self, bincode::config::standard())
49    }
50
51    pub fn from_slice(data: &[u8]) -> Result<Self, bincode::error::DecodeError> {
52        bincode::decode_from_slice(&data, bincode::config::standard()).map(|(packet, _)| packet)
53    }
54}
55
/// A packet tagged with the id of the client it concerns.
///
/// Serialized by [`TaggedPacket::into_vec`]; parsed back by [`recv_tagged_packet`].
#[derive(Debug, Clone)]
pub enum TaggedPacket {
    /// Raw payload bytes associated with `client_id` (wire tag `0x00`).
    Data { client_id: u64, data: Vec<u8> },
    /// An error message associated with `client_id` (wire tag `0x01`).
    Failure { client_id: u64, error: String },
    /// A kick notice for `client_id`; carries no payload (wire tag `0x02`).
    Kick { client_id: u64 },
    /// A reconnection notice for `client_id`; carries no payload (wire tag `0x03`).
    Reconnection { client_id: u64 },
}

impl TaggedPacket {
    /// Returns the client id carried by any variant.
    pub fn client_id(&self) -> u64 {
        match *self {
            TaggedPacket::Data { client_id, .. }
            | TaggedPacket::Failure { client_id, .. }
            | TaggedPacket::Kick { client_id }
            | TaggedPacket::Reconnection { client_id } => client_id,
        }
    }

    /// Serializes the packet into its wire bytes (without any length prefix).
    ///
    /// Layout: the client id as 8 little-endian bytes, one tag byte, then the payload:
    /// - `0x00` `Data`: the data bytes as-is
    /// - `0x01` `Failure`: the UTF-8 bytes of the error string
    /// - `0x02` `Kick`: nothing
    /// - `0x03` `Reconnection`: nothing
    pub fn into_vec(self) -> Vec<u8> {
        let id_bytes = self.client_id().to_le_bytes();
        let (tag, payload): (u8, Vec<u8>) = match self {
            TaggedPacket::Data { data, .. } => (0x00, data),
            TaggedPacket::Failure { error, .. } => (0x01, error.into_bytes()),
            TaggedPacket::Kick { .. } => (0x02, Vec::new()),
            TaggedPacket::Reconnection { .. } => (0x03, Vec::new()),
        };

        let mut out = Vec::with_capacity(id_bytes.len() + 1 + payload.len());
        out.extend_from_slice(&id_bytes);
        out.push(tag);
        out.extend_from_slice(&payload);
        out
    }
}
99
100/// Configures a TCP socket for performance by setting relevant socket options.
101pub fn configure_performance_tcp_socket(stream: &mut TcpStream) -> std::io::Result<()> {
102    stream.set_nodelay(true)?;
103    stream.set_linger(Some(std::time::Duration::from_secs(5)))?;
104    Ok(())
105}
106
107pub fn new_framed_reader<T: AsyncRead + Unpin>(stream: T) -> FramedReader<T> {
108    LengthDelimitedCodec::builder()
109        // NOTE(lion): do we want .max_frame_length(1024 * 1024 * 10) or something here?
110        .length_field_type::<u32>()
111        .little_endian()
112        .new_read(stream)
113}
114
115/// Receives a size-prefixed message.
116///
117/// Create a `FramedReader` with `new_framed_reader` and pass it to this function.
118pub async fn recv_size_prefixed<T: AsyncRead + Unpin>(
119    read: &mut FramedReader<T>,
120) -> anyhow::Result<BytesMut> {
121    Ok(read
122        .next()
123        .await
124        .ok_or_else(|| anyhow::format_err!("Connection closed or Eof"))??)
125}
126
127pub async fn send_size_prefixed<T: AsyncWriteExt + Unpin>(
128    stream: &mut T,
129    message: &[u8],
130) -> anyhow::Result<()> {
131    let size = message.len() as u32;
132    let size_bytes = size.to_le_bytes();
133    let mut combined_message = Vec::with_capacity(4 + message.len());
134    combined_message.extend_from_slice(&size_bytes);
135    combined_message.extend_from_slice(message);
136    stream.write_all(&combined_message).await?;
137    Ok(())
138}
139
140/// Receives a packet with a client id.
141///
142/// # Cancellation Safety
143///
144/// This method IS cancellation safe. It peeks the data until it knows that enough data is available,
145/// and then reads it in a cancellation safe way.
146pub async fn recv_tagged_packet<T: AsyncRead + Unpin>(
147    read: &mut FramedReader<T>,
148) -> anyhow::Result<TaggedPacket> {
149    let buffer = recv_size_prefixed(read).await?;
150    if buffer.len() < 8 {
151        return Err(anyhow::format_err!("Packet too small"));
152    }
153    let client_id = u64::from_le_bytes(buffer[0..8].try_into().unwrap());
154    let buf: &[u8] = buffer[8..].into();
155
156    match buf[0] {
157        0x00 => {
158            // Data
159            Ok(TaggedPacket::Data {
160                client_id,
161                data: buf[1..].into(),
162            })
163        }
164        0x01 => {
165            // Failure
166            let error = String::from_utf8_lossy(&buf[1..]).to_string();
167            Ok(TaggedPacket::Failure { client_id, error })
168        }
169        0x02 => {
170            // Kick
171            Ok(TaggedPacket::Kick { client_id })
172        }
173        0x03 => {
174            // Reconnection
175            Ok(TaggedPacket::Reconnection { client_id })
176        }
177        _ => {
178            return Err(anyhow::format_err!("Unknown packet type"));
179        }
180    }
181}
182
183/// Sends a packet with a client id.
184///
185/// # Cancellation Safety
186///
187/// This method is NOT cancellation safe.
188pub async fn send_tagged_packet<T: AsyncWriteExt + Unpin>(
189    stream: &mut T,
190    packet: TaggedPacket,
191) -> anyhow::Result<()> {
192    let data = packet.into_vec();
193    let size = data.len() as u32;
194    let size_bytes = size.to_le_bytes();
195
196    let mut combined_message = Vec::with_capacity(size as usize + 4);
197    combined_message.extend_from_slice(&size_bytes);
198    combined_message.extend_from_slice(&data);
199
200    stream.write_all(&combined_message).await?;
201    Ok(())
202}