1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
use crate::errors::RconProtocolError;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::mem::size_of;

/// The kind of an RCON packet, carried in the packet's integer type field.
///
/// Conversion to and from the wire code is provided by the
/// `From<RconPacketType> for i32` and `TryFrom<i32> for RconPacketType` impls.
///
/// The enum is `Copy` and comparable, so callers can inspect a packet's type
/// (e.g. `packet.packet_type == RconPacketType::Response`) without moving it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum RconPacketType {
    /// Wire code `0`.
    Response,
    /// Wire code `3`.
    Login,
    /// Wire code `2`.
    RunCommand,
}

impl From<RconPacketType> for i32 {
    fn from(packet_type: RconPacketType) -> Self {
        match packet_type {
            RconPacketType::Response => 0,
            RconPacketType::RunCommand => 2,
            RconPacketType::Login => 3,
        }
    }
}

impl TryFrom<i32> for RconPacketType {
    type Error = RconProtocolError;

    /// Decode the integer type code read from a packet.
    ///
    /// # Errors
    ///
    /// Returns [`RconProtocolError::InvalidPacketType`] for any code other than
    /// `0`, `2` or `3`.
    fn try_from(code: i32) -> Result<Self, Self::Error> {
        let kind = match code {
            0 => Self::Response,
            2 => Self::RunCommand,
            3 => Self::Login,
            _ => return Err(RconProtocolError::InvalidPacketType),
        };

        Ok(kind)
    }
}

/// A single RCON packet.
///
/// Build one with [`RconPacket::new`] (which validates the payload), parse one
/// from the wire with `RconPacket::try_from(bytes)`, and serialize it with
/// [`RconPacket::bytes`].
#[derive(Debug)]
pub(super) struct RconPacket {
    /// Value written to / read from the packet's request-ID field.
    pub request_id: i32,
    /// Kind of packet, encoded in the packet's type field.
    pub packet_type: RconPacketType,
    /// Packet body, without the two trailing NUL bytes used on the wire.
    pub payload: String,
}

impl RconPacket {
    /// Create a packet after validating its payload.
    ///
    /// # Errors
    ///
    /// Returns [`RconProtocolError::NonAsciiPayload`] if `payload` contains any
    /// non-ASCII character.
    ///
    /// No upper bound on the payload length is enforced here. NOTE(review):
    /// servers typically cap request payloads (1446 bytes for Minecraft), so
    /// callers building outgoing commands should stay within that limit.
    pub fn new(
        request_id: i32,
        packet_type: RconPacketType,
        payload: String,
    ) -> Result<Self, RconProtocolError> {
        if !payload.is_ascii() {
            return Err(RconProtocolError::NonAsciiPayload);
        }

        Ok(Self {
            request_id,
            packet_type,
            payload,
        })
    }

    /// Serialize this packet into its wire representation.
    ///
    /// Equivalent to `Bytes::from(self)`.
    pub fn bytes(self) -> Bytes {
        Bytes::from(self)
    }
}

impl TryFrom<Bytes> for RconPacket {
    type Error = RconProtocolError;

    fn try_from(mut bytes: Bytes) -> Result<Self, Self::Error> {
        let len = bytes.get_i32_le(); // length of remaining packet (not including this integer)
        let request_id = bytes.get_i32_le();
        let packet_type = bytes.get_i32_le();

        let mut payload = "".to_string();

        loop {
            let current = bytes.get_u8();
            if current == 0 {
                // null terminated ASCII string, so stop reading here
                break;
            }

            payload.push(current as char);
        }

        // if the payload is already normal ASCII (without 0xa7), no need to
        // check each character to be ASCII or 0xa7
        if !payload.is_ascii() {
            for c in payload.chars() {
                // 0xa7 is an acceptable (though non-ASCII) character
                if !c.is_ascii() && (c as u8) != 0xa7 {
                    return Err(RconProtocolError::NonAsciiPayload);
                }
            }
        }

        let pad = bytes.get_u8(); // there must be a remaining 0 byte as padding
        if pad != 0 {
            return Err(RconProtocolError::InvalidRconResponse);
        }

        // validate if the lengths match
        if get_remaining_length(&payload) != len {
            return Err(RconProtocolError::InvalidRconResponse);
        }

        Self::new(request_id, packet_type.try_into()?, payload)
    }
}

impl From<RconPacket> for Bytes {
    fn from(packet: RconPacket) -> Self {
        let len = get_remaining_length(&packet.payload);
        let packet_type: i32 = packet.packet_type.into();

        let mut bytes = BytesMut::new();

        bytes.put_i32_le(len);
        bytes.put_i32_le(packet.request_id);
        bytes.put_i32_le(packet_type);
        bytes.put(packet.payload.as_bytes());
        bytes.put_u16(0x00_00);

        bytes.freeze()
    }
}

/// Compute the value of a packet's length field for the given payload.
///
/// The length field counts every byte that follows it, so the field's own
/// four bytes are excluded. That covers the request ID and type (two [i32]s),
/// the payload's bytes, and two terminating zero bytes. Two zero bytes are
/// appended explicitly because Rust strings are not null-terminated.
fn get_remaining_length(payload: &str) -> i32 {
    const HEADER_FIELDS: usize = 2; // request ID + packet type
    const TRAILING_NULS: usize = 2;

    let body_len = HEADER_FIELDS * size_of::<i32>() + payload.len() + TRAILING_NULS;
    body_len as i32
}