1use std::fmt;
7use std::net::SocketAddr;
8
9use crate::cursor::{Cursor, CursorMut};
10use crate::filter::Filter;
11use crate::server::Region;
12use crate::Error;
13
14#[derive(Clone, Debug, PartialEq)]
16pub struct QueryServers<T> {
17 pub region: Region,
19 pub last: SocketAddr,
21 pub filter: T,
23}
24
25impl QueryServers<()> {
26 pub const HEADER: &'static [u8] = b"1";
28}
29
30impl<'a, T: 'a> QueryServers<T>
31where
32 T: TryFrom<&'a [u8], Error = Error>,
33{
34 pub fn decode(src: &'a [u8]) -> Result<Self, Error> {
36 let mut cur = Cursor::new(src);
37 cur.expect(QueryServers::HEADER)?;
38 let region = cur.get_u8()?.try_into().map_err(|_| Error::InvalidRegion)?;
39 let last = cur.get_cstr_as_str()?;
40 let filter = match cur.get_bytes(cur.remaining())? {
41 [x @ .., 0] => x,
43 x => x,
44 };
45 Ok(Self {
46 region,
47 last: last.parse().map_err(|_| Error::InvalidQueryServersLast)?,
48 filter: T::try_from(filter)?,
49 })
50 }
51}
52
53impl<'a, T: 'a> QueryServers<T>
54where
55 for<'b> &'b T: fmt::Display,
56{
57 pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> {
59 Ok(CursorMut::new(buf)
60 .put_bytes(QueryServers::HEADER)?
61 .put_u8(self.region as u8)?
62 .put_as_str(self.last)?
63 .put_u8(0)?
64 .put_as_str(&self.filter)?
65 .put_u8(0)?
66 .pos())
67 }
68}
69
/// A request for server information, recognised by the `GetServerInfo::HEADER` prefix.
///
/// The type holds a single `u8`, so it is cheap to copy and has a total
/// equality; it derives `Copy` and `Eq` accordingly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GetServerInfo {
    /// Protocol number, carried as decimal text after the header.
    pub protocol: u8,
}
76
77impl GetServerInfo {
78 pub const HEADER: &'static [u8] = b"\xff\xff\xff\xffinfo ";
80
81 pub fn new(protocol: u8) -> Self {
83 Self { protocol }
84 }
85
86 pub fn decode(src: &[u8]) -> Result<Self, Error> {
88 let mut cur = Cursor::new(src);
89 cur.expect(Self::HEADER)?;
90 let protocol = cur
91 .get_str(cur.remaining())?
92 .parse()
93 .map_err(|_| Error::InvalidPacket)?;
94 Ok(Self { protocol })
95 }
96
97 pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> {
99 Ok(CursorMut::new(buf)
100 .put_bytes(Self::HEADER)?
101 .put_as_str(self.protocol)?
102 .pos())
103 }
104}
105
106#[derive(Clone, Debug, PartialEq)]
108pub enum Packet<'a> {
109 QueryServers(QueryServers<Filter<'a>>),
111 GetServerInfo(GetServerInfo),
113}
114
115impl<'a> Packet<'a> {
116 pub fn decode(src: &'a [u8]) -> Result<Option<Self>, Error> {
118 if src.starts_with(QueryServers::HEADER) {
119 QueryServers::decode(src).map(Self::QueryServers)
120 } else if src.starts_with(GetServerInfo::HEADER) {
121 GetServerInfo::decode(src).map(Self::GetServerInfo)
122 } else {
123 return Ok(None);
124 }
125 .map(Some)
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::filter::{FilterFlags, Version};
    use crate::wrappers::Str;
    use std::net::{IpAddr, Ipv4Addr};

    #[test]
    fn query_servers() {
        let p = QueryServers {
            region: Region::RestOfTheWorld,
            last: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
            filter: Filter {
                gamedir: Some(Str(&b"valve"[..])),
                map: Some(Str(&b"crossfire"[..])),
                key: Some(0xdeadbeef),
                protocol: Some(49),
                clver: Some(Version::new(0, 20)),
                flags: FilterFlags::all(),
                flags_mask: FilterFlags::all(),
            },
        };
        let mut buf = [0; 512];
        let n = p.encode(&mut buf).unwrap();
        assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::QueryServers(p))));
    }

    #[test]
    fn query_servers_filter_bug() {
        let p = QueryServers {
            region: Region::RestOfTheWorld,
            last: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
            filter: Filter {
                gamedir: None,
                protocol: Some(48),
                map: None,
                key: None,
                clver: Some(Version::new(0, 20)),
                flags: FilterFlags::empty(),
                flags_mask: FilterFlags::NAT,
            },
        };

        let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0\0";
        assert_eq!(Packet::decode(s), Ok(Some(Packet::QueryServers(p.clone()))));

        let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0";
        assert_eq!(Packet::decode(s), Ok(Some(Packet::QueryServers(p))));
    }

    #[test]
    fn get_server_info() {
        let p = GetServerInfo::new(49);
        let mut buf = [0; 512];
        let n = p.encode(&mut buf).unwrap();
        assert_eq!(
            Packet::decode(&buf[..n]),
            Ok(Some(Packet::GetServerInfo(p)))
        );
    }

    /// Input that starts with neither known header is not an error; it
    /// decodes to `Ok(None)`.
    #[test]
    fn unknown_packet() {
        assert_eq!(Packet::decode(&[]), Ok(None));
        assert_eq!(Packet::decode(b"\xff\xff\xff\xffping"), Ok(None));
    }

    /// A `GetServerInfo` packet whose protocol is not a number is rejected
    /// with `Error::InvalidPacket`.
    #[test]
    fn get_server_info_invalid_protocol() {
        assert_eq!(
            Packet::decode(b"\xff\xff\xff\xffinfo abc"),
            Err(Error::InvalidPacket)
        );
    }
}