1use anyhow::{anyhow, bail, Context, Result};
3use bytes::{Bytes, BytesMut};
4use serde::{Deserialize, Serialize};
5use std::fmt::{self, Display, Formatter};
6use std::str::FromStr;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tracing::{debug, trace};
9use urlencoding::encode;
10
11use crate::fair_channel::FairGroup;
12use crate::utils::get_version_number;
13pub use prost::Message as ProstMessage;
14
/// Something that can be rendered as a URL, optionally carrying credentials.
pub trait Endpoint {
    /// URL userinfo prefix (e.g. `user:pass@`), or an empty string when there is none.
    fn credentials(&self) -> String;

    /// Full URL representation of this endpoint.
    fn as_url(&self) -> String;
}
19
/// Provides the conventional port for a protocol, if it has one.
pub trait DefaultPort {
    /// `Some(port)` for protocols with a well-known port, `None` otherwise.
    fn default_port(&self) -> Option<u16>;
}
23
24pub fn parse_enum<E: FromStr + Into<i32>>(name: &str) -> Result<i32> {
25 let proto = E::from_str(name).map_err(|_| anyhow!("Invalid enum: {}", name))?;
26 Ok(proto.into())
27}
28
/// Convert a raw integer enum value into its string form via `E`'s `ToString`.
///
/// Returns `"unknown"` when `e` does not map to a variant of `E`.
pub fn str_enum<E: TryFrom<i32> + ToString>(e: i32) -> String {
    // The fallback is built lazily so the common (valid) case does not allocate twice.
    E::try_from(e).map_or_else(|_| "unknown".to_string(), |value| value.to_string())
}
34
35include!(concat!(env!("OUT_DIR"), "/protocol.rs"));
36
37pub struct Data {
38 pub data: Bytes,
39 pub socket_addr: Option<std::net::SocketAddr>,
40}
41
42impl Display for Protocol {
43 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
44 match self {
45 Protocol::Http => write!(f, "http"),
46 Protocol::Https => write!(f, "https"),
47 Protocol::Tcp => write!(f, "tcp"),
48 Protocol::Udp => write!(f, "udp"),
49 Protocol::OneC => write!(f, "1c"),
50 Protocol::Minecraft => write!(f, "minecraft"),
51 Protocol::Webdav => write!(f, "webdav"),
52 Protocol::Rtsp => write!(f, "rtsp"),
53 }
54 }
55}
56
57impl FromStr for Protocol {
58 type Err = anyhow::Error;
59
60 fn from_str(s: &str) -> Result<Self> {
61 match s {
62 "http" => Ok(Protocol::Http),
63 "https" => Ok(Protocol::Https),
64 "tcp" => Ok(Protocol::Tcp),
65 "udp" => Ok(Protocol::Udp),
66 "1c" => Ok(Protocol::OneC),
67 "minecraft" => Ok(Protocol::Minecraft),
68 "webdav" => Ok(Protocol::Webdav),
69 "rtsp" => Ok(Protocol::Rtsp),
70 _ => bail!("Invalid protocol: {}", s),
71 }
72 }
73}
74
75impl DefaultPort for Protocol {
76 fn default_port(&self) -> Option<u16> {
77 match self {
78 Protocol::Http => Some(80),
79 Protocol::Https => Some(443),
80 Protocol::Tcp => None,
81 Protocol::Udp => None,
82 Protocol::OneC => None,
83 Protocol::Minecraft => Some(25565),
84 Protocol::Webdav => None,
85 Protocol::Rtsp => Some(554),
86 }
87 }
88}
89
90impl FromStr for Role {
91 type Err = anyhow::Error;
92
93 fn from_str(s: &str) -> Result<Self> {
94 match s {
95 "none" => Ok(Role::Nobody),
96 "admin" => Ok(Role::Admin),
97 "reader" => Ok(Role::Reader),
98 "writer" => Ok(Role::Writer),
99 _ => bail!("Invalid access: {}", s),
100 }
101 }
102}
103
104impl Display for Role {
105 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
106 match self {
107 Role::Nobody => write!(f, "none"),
108 Role::Admin => write!(f, "admin"),
109 Role::Reader => write!(f, "reader"),
110 Role::Writer => write!(f, "writer"),
111 }
112 }
113}
114
115impl FromStr for Auth {
116 type Err = anyhow::Error;
117
118 fn from_str(s: &str) -> Result<Self> {
119 match s {
120 "none" => Ok(Auth::None),
121 "basic" => Ok(Auth::Basic),
122 "form" => Ok(Auth::Form),
123 _ => bail!("Invalid auth: {}", s),
124 }
125 }
126}
127
128impl Display for Auth {
129 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
130 match self {
131 Auth::None => write!(f, "none"),
132 Auth::Basic => write!(f, "basic"),
133 Auth::Form => write!(f, "form"),
134 }
135 }
136}
137
138impl FromStr for FilterAction {
139 type Err = anyhow::Error;
140
141 fn from_str(s: &str) -> Result<Self> {
142 match s {
143 "allow" => Ok(FilterAction::FilterAllow),
144 "deny" => Ok(FilterAction::FilterDeny),
145 "redirect" => Ok(FilterAction::FilterRedirect),
146 "log" => Ok(FilterAction::FilterLog),
147 _ => bail!("Invalid filter action: {}", s),
148 }
149 }
150}
151
152impl PartialEq for ClientEndpoint {
153 fn eq(&self, other: &Self) -> bool {
154 self.local_proto == other.local_proto
155 && self.local_addr == other.local_addr
156 && self.local_port == other.local_port
157 }
158}
159
160impl Display for ClientEndpoint {
161 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
162 if let Some(name) = self.description.as_ref() {
163 write!(f, "[{}] ", name)?;
164 }
165 write!(f, "{}", self.as_url())
166 }
167}
168
169impl Endpoint for ClientEndpoint {
170 fn credentials(&self) -> String {
171 let mut s = String::new();
172 if !self.username.is_empty() {
173 s.push_str(&encode(&self.username));
174 }
175 if !self.password.is_empty() {
176 s.push(':');
177 s.push_str(&encode(&self.password));
178 }
179 if !s.is_empty() {
180 s.push('@');
181 }
182 s
183 }
184
185 fn as_url(&self) -> String {
186 match self.local_proto.try_into().unwrap() {
187 Protocol::OneC | Protocol::Minecraft | Protocol::Webdav => {
188 let credentials = self.credentials();
189 format!(
190 "{}://{}{}",
191 str_enum::<Protocol>(self.local_proto),
192 credentials,
193 &self.local_addr
194 )
195 }
196 Protocol::Http | Protocol::Https | Protocol::Tcp | Protocol::Udp | Protocol::Rtsp => {
197 let credentials = self.credentials();
198 format!(
199 "{}://{}{}:{}{}",
200 str_enum::<Protocol>(self.local_proto),
201 credentials,
202 self.local_addr,
203 self.local_port,
204 self.local_path
205 )
206 }
207 }
208 }
209}
210
211impl Endpoint for ServerEndpoint {
212 fn credentials(&self) -> String {
213 String::new()
214 }
215
216 fn as_url(&self) -> String {
217 format!(
218 "{}://{}:{}",
219 str_enum::<Protocol>(self.remote_proto),
220 self.remote_addr,
221 self.remote_port,
222 )
223 }
224}
225
226impl Display for ServerEndpoint {
227 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
228 let client = self.client.as_ref().unwrap();
229 if self.error.is_empty() {
230 write!(
231 f,
232 "{} -> {}://{}{}:{}{}",
233 client,
234 Protocol::try_from(self.remote_proto).unwrap(),
235 client.credentials(),
236 self.remote_addr,
237 self.remote_port,
238 client.local_path
239 )
240 } else {
241 write!(f, "{} -> {}", client, self.error)
242 }
243 }
244}
245
246impl AgentInfo {
247 pub fn is_support_server_control(&self) -> bool {
248 get_version_number(&self.version) >= get_version_number("2.1.1")
249 }
250
251 pub fn is_support_backpressure(&self) -> bool {
252 get_version_number(&self.version) >= get_version_number("2.2.0")
253 }
254
255 pub fn get_unique_id(&self) -> String {
256 if self.hwid.is_empty() {
257 self.agent_id.clone()
258 } else {
259 self.hwid.clone()
260 }
261 }
262}
263
264impl PartialEq for ServerEndpoint {
265 fn eq(&self, other: &Self) -> bool {
266 self.client == other.client
267 }
268}
269
270pub async fn read_message<T: AsyncRead + Unpin>(conn: &mut T) -> Result<message::Message> {
272 let mut buf = [0u8; std::mem::size_of::<u32>()];
273 conn.read_exact(&mut buf).await?;
274 let len = u32::from_le_bytes(buf) as usize;
275 if !(1..=1024 * 1024 * 10).contains(&len) {
276 bail!("Invalid message length: {}", len);
277 }
278 let mut buf = vec![0u8; len];
279 conn.read_exact(&mut buf).await?;
280
281 let proto_msg =
282 Message::decode(buf.as_slice()).context("Failed to decode Protocol Buffers message")?;
283
284 Ok(proto_msg.message.unwrap())
285}
286
287pub async fn write_message<T: AsyncWrite + Unpin>(
288 conn: &mut T,
289 msg: &message::Message,
290) -> Result<()> {
291 let proto_msg = Message {
292 message: Some(msg.clone()),
293 };
294
295 debug!("Sending proto message: {:?}", msg);
296
297 let mut buf = BytesMut::new();
298 proto_msg
299 .encode(&mut buf)
300 .context("Failed to encode Protocol Buffers message")?;
301
302 let len = buf.len() as u32;
303 let len_bytes = len.to_le_bytes();
304
305 conn.write_all(&len_bytes).await?;
306 conn.write_all(&buf).await?;
307 conn.flush().await?;
308 Ok(())
309}
310
311impl FairGroup for message::Message {
312 fn group_id(&self) -> Option<u32> {
313 match self {
314 message::Message::DataChannelData(lhs) => Some(lhs.channel_id),
315 message::Message::DataChannelDataUdp(lhs) => Some(lhs.channel_id),
316 message::Message::DataChannelEof(lhs) => Some(lhs.channel_id),
317 message::Message::DataChannelAck(lhs) => Some(lhs.channel_id),
318 _ => None,
319 }
320 }
321
322 fn get_size(&self) -> Option<usize> {
323 match self {
324 message::Message::DataChannelData(data) => Some(data.data.len()),
325 message::Message::DataChannelDataUdp(data) => Some(data.data.len()),
326 _ => None, }
328 }
329}
330
331pub type UdpPacketLen = u16; #[derive(Deserialize, Serialize, Debug)]
334pub struct UdpHeader {
335 from: std::net::SocketAddr,
336 len: UdpPacketLen,
337}
338
339#[derive(Debug)]
340pub struct UdpTraffic {
341 pub from: std::net::SocketAddr,
342 pub data: Bytes,
343}
344
345impl UdpTraffic {
346 pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
347 let hdr = UdpHeader {
348 from: self.from,
349 len: self.data.len() as UdpPacketLen,
350 };
351
352 let v = bincode::serde::encode_to_vec(&hdr, bincode::config::legacy()).unwrap();
353
354 trace!("Write {:?} of length {}", hdr, v.len());
355 writer.write_u8(v.len() as u8).await?;
356 writer.write_all(&v).await?;
357
358 writer.write_all(&self.data).await?;
359
360 Ok(())
361 }
362
363 #[allow(dead_code)]
364 pub async fn write_slice<T: AsyncWrite + Unpin>(
365 writer: &mut T,
366 from: std::net::SocketAddr,
367 data: &[u8],
368 ) -> Result<()> {
369 let hdr = UdpHeader {
370 from,
371 len: data.len() as UdpPacketLen,
372 };
373
374 let v = bincode::serde::encode_to_vec(&hdr, bincode::config::legacy()).unwrap();
375
376 trace!("Write {:?} of length {}", hdr, v.len());
377 writer.write_u8(v.len() as u8).await?;
378 writer.write_all(&v).await?;
379
380 writer.write_all(data).await?;
381
382 Ok(())
383 }
384
385 pub async fn read<T: AsyncRead + Unpin>(reader: &mut T, hdr_len: u8) -> Result<UdpTraffic> {
386 let mut buf = vec![0; hdr_len as usize];
387 reader
388 .read_exact(&mut buf)
389 .await
390 .with_context(|| "Failed to read udp header")?;
391
392 let hdr: UdpHeader = bincode::serde::decode_from_slice(&buf, bincode::config::legacy())
393 .with_context(|| "Failed to deserialize UdpHeader")?
394 .0;
395
396 trace!("hdr {:?}", hdr);
397
398 let mut data = BytesMut::new();
399 data.resize(hdr.len as usize, 0);
400 reader.read_exact(&mut data).await?;
401
402 Ok(UdpTraffic {
403 from: hdr.from,
404 data: data.freeze(),
405 })
406 }
407}