1use anyhow::{anyhow, bail, Context, Result};
3use bytes::{Bytes, BytesMut};
4use serde::{Deserialize, Serialize};
5use std::fmt::{self, Display, Formatter};
6use std::str::FromStr;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tracing::{debug, trace};
9use urlencoding::encode;
10
11use crate::fair_channel::FairGroup;
12use crate::utils::get_version_number;
13pub use prost::Message as ProstMessage;
14
/// Something that can be rendered as a URL string.
pub trait Endpoint {
    /// Returns the URL form of this endpoint.
    fn as_url(&self) -> String;

    /// Returns the credentials prefix to embed in a URL; implementations in this
    /// module return either a `user:pass@`-style prefix or an empty string.
    fn credentials(&self) -> String;
}
19
/// Types that may have a conventional, well-known port number.
pub trait DefaultPort {
    /// Returns the conventional port, or `None` when there is no single default.
    fn default_port(&self) -> Option<u16>;
}
23
24pub fn parse_enum<E: FromStr + Into<i32>>(name: &str) -> Result<i32> {
25 let proto = E::from_str(name).map_err(|_| anyhow!("Invalid enum: {}", name))?;
26 Ok(proto.into())
27}
28
/// Renders the `i32` value of an enum through `E`'s `ToString`, or returns `"unknown"`
/// when `e` does not convert into any variant of `E`.
pub fn str_enum<E: TryFrom<i32> + ToString>(e: i32) -> String {
    // `map_or_else` builds the fallback string only on the error path, instead of
    // allocating it eagerly on every call.
    E::try_from(e).map_or_else(|_| "unknown".to_string(), |variant| variant.to_string())
}
34
35include!(concat!(env!("OUT_DIR"), "/protocol.rs"));
36
37pub struct Data {
38 pub data: Bytes,
39 pub socket_addr: Option<std::net::SocketAddr>,
40}
41
42impl Display for Protocol {
43 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
44 match self {
45 Protocol::Http => write!(f, "http"),
46 Protocol::Https => write!(f, "https"),
47 Protocol::Tcp => write!(f, "tcp"),
48 Protocol::Udp => write!(f, "udp"),
49 Protocol::OneC => write!(f, "1c"),
50 Protocol::Minecraft => write!(f, "minecraft"),
51 Protocol::Webdav => write!(f, "webdav"),
52 Protocol::Rtsp => write!(f, "rtsp"),
53 Protocol::Rdp => write!(f, "rdp"),
54 Protocol::Vnc => write!(f, "vnc"),
55 Protocol::Ssh => write!(f, "ssh"),
56 }
57 }
58}
59
60impl FromStr for Protocol {
61 type Err = anyhow::Error;
62
63 fn from_str(s: &str) -> Result<Self> {
64 match s {
65 "http" => Ok(Protocol::Http),
66 "https" => Ok(Protocol::Https),
67 "tcp" => Ok(Protocol::Tcp),
68 "udp" => Ok(Protocol::Udp),
69 "1c" => Ok(Protocol::OneC),
70 "minecraft" => Ok(Protocol::Minecraft),
71 "webdav" => Ok(Protocol::Webdav),
72 "rtsp" => Ok(Protocol::Rtsp),
73 "rdp" => Ok(Protocol::Rdp),
74 "vnc" => Ok(Protocol::Vnc),
75 "ssh" => Ok(Protocol::Ssh),
76 _ => bail!("Invalid protocol: {}", s),
77 }
78 }
79}
80
81impl DefaultPort for Protocol {
82 fn default_port(&self) -> Option<u16> {
83 match self {
84 Protocol::Http => Some(80),
85 Protocol::Https => Some(443),
86 Protocol::Tcp => None,
87 Protocol::Udp => None,
88 Protocol::OneC => None,
89 Protocol::Minecraft => Some(25565),
90 Protocol::Webdav => None,
91 Protocol::Rtsp => Some(554),
92 Protocol::Rdp => Some(3389),
93 Protocol::Vnc => Some(5900),
94 Protocol::Ssh => Some(22),
95 }
96 }
97}
98
99impl FromStr for Role {
100 type Err = anyhow::Error;
101
102 fn from_str(s: &str) -> Result<Self> {
103 match s {
104 "none" => Ok(Role::Nobody),
105 "admin" => Ok(Role::Admin),
106 "reader" => Ok(Role::Reader),
107 "writer" => Ok(Role::Writer),
108 _ => bail!("Invalid access: {}", s),
109 }
110 }
111}
112
113impl Display for Role {
114 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
115 match self {
116 Role::Nobody => write!(f, "none"),
117 Role::Admin => write!(f, "admin"),
118 Role::Reader => write!(f, "reader"),
119 Role::Writer => write!(f, "writer"),
120 }
121 }
122}
123
124impl FromStr for Auth {
125 type Err = anyhow::Error;
126
127 fn from_str(s: &str) -> Result<Self> {
128 match s {
129 "none" => Ok(Auth::None),
130 "basic" => Ok(Auth::Basic),
131 "form" => Ok(Auth::Form),
132 _ => bail!("Invalid auth: {}", s),
133 }
134 }
135}
136
137impl Display for Auth {
138 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
139 match self {
140 Auth::None => write!(f, "none"),
141 Auth::Basic => write!(f, "basic"),
142 Auth::Form => write!(f, "form"),
143 }
144 }
145}
146
147impl FromStr for ProxyProtocol {
148 type Err = anyhow::Error;
149
150 fn from_str(s: &str) -> Result<Self> {
151 match s {
152 "none" => Ok(ProxyProtocol::None),
153 "v2" => Ok(ProxyProtocol::V2),
154 _ => bail!("Invalid proxy protocol: {}", s),
155 }
156 }
157}
158
159impl Display for ProxyProtocol {
160 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
161 match self {
162 ProxyProtocol::None => write!(f, "none"),
163 ProxyProtocol::V2 => write!(f, "v2"),
164 }
165 }
166}
167
168impl FromStr for FilterAction {
169 type Err = anyhow::Error;
170
171 fn from_str(s: &str) -> Result<Self> {
172 match s {
173 "allow" => Ok(FilterAction::FilterAllow),
174 "deny" => Ok(FilterAction::FilterDeny),
175 "redirect" => Ok(FilterAction::FilterRedirect),
176 "log" => Ok(FilterAction::FilterLog),
177 _ => bail!("Invalid filter action: {}", s),
178 }
179 }
180}
181
182impl PartialEq for ClientEndpoint {
183 fn eq(&self, other: &Self) -> bool {
184 self.local_proto == other.local_proto
185 && self.local_addr == other.local_addr
186 && self.local_port == other.local_port
187 }
188}
189
190impl Display for ClientEndpoint {
191 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
192 if let Some(name) = self.description.as_ref() {
193 if !name.is_empty() {
194 write!(f, "[{}] ", name)?;
195 }
196 }
197 write!(f, "{}", self.as_url())
198 }
199}
200
201impl Endpoint for ClientEndpoint {
202 fn credentials(&self) -> String {
203 let mut s = String::new();
204 if !self.username.is_empty() {
205 s.push_str(&encode(&self.username));
206 }
207 if !self.password.is_empty() {
208 s.push(':');
209 s.push_str(&encode(&self.password));
210 }
211 if !s.is_empty() {
212 s.push('@');
213 }
214 s
215 }
216
217 fn as_url(&self) -> String {
218 match self.local_proto.try_into().unwrap() {
219 Protocol::OneC | Protocol::Minecraft | Protocol::Webdav => {
220 let credentials = self.credentials();
221 format!(
222 "{}://{}{}",
223 str_enum::<Protocol>(self.local_proto),
224 credentials,
225 &self.local_addr
226 )
227 }
228 Protocol::Http
229 | Protocol::Https
230 | Protocol::Tcp
231 | Protocol::Udp
232 | Protocol::Rtsp
233 | Protocol::Rdp
234 | Protocol::Vnc
235 | Protocol::Ssh => {
236 let credentials = self.credentials();
237 format!(
238 "{}://{}{}:{}{}",
239 str_enum::<Protocol>(self.local_proto),
240 credentials,
241 self.local_addr,
242 self.local_port,
243 self.local_path
244 )
245 }
246 }
247 }
248}
249
250impl Endpoint for ServerEndpoint {
251 fn credentials(&self) -> String {
252 String::new()
253 }
254
255 fn as_url(&self) -> String {
256 format!(
257 "{}://{}:{}",
258 str_enum::<Protocol>(self.remote_proto),
259 self.remote_addr,
260 self.remote_port,
261 )
262 }
263}
264
265impl Display for ServerEndpoint {
266 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
267 let client = self.client.as_ref().unwrap();
268 if self.error.is_empty() {
269 write!(
270 f,
271 "{} -> {}://{}{}:{}{}",
272 client,
273 Protocol::try_from(self.remote_proto).unwrap(),
274 client.credentials(),
275 self.remote_addr,
276 self.remote_port,
277 client.local_path
278 )
279 } else {
280 write!(f, "{} -> {}", client, self.error)
281 }
282 }
283}
284
285impl AgentInfo {
286 pub fn is_support_server_control(&self) -> bool {
287 get_version_number(&self.version) >= get_version_number("2.1.1")
288 }
289
290 pub fn is_support_backpressure(&self) -> bool {
291 get_version_number(&self.version) >= get_version_number("2.2.0")
292 }
293
294 pub fn get_unique_id(&self) -> String {
295 if self.hwid.is_empty() {
296 self.agent_id.clone()
297 } else {
298 self.hwid.clone()
299 }
300 }
301}
302
303impl PartialEq for ServerEndpoint {
304 fn eq(&self, other: &Self) -> bool {
305 self.client == other.client
306 }
307}
308
309pub async fn read_message<T: AsyncRead + Unpin>(conn: &mut T) -> Result<message::Message> {
311 let mut buf = [0u8; std::mem::size_of::<u32>()];
312 conn.read_exact(&mut buf).await?;
313 let len = u32::from_le_bytes(buf) as usize;
314 if !(1..=1024 * 1024 * 10).contains(&len) {
315 bail!("Invalid message length: {}", len);
316 }
317 let mut buf = vec![0u8; len];
318 conn.read_exact(&mut buf).await?;
319
320 let proto_msg =
321 Message::decode(buf.as_slice()).context("Failed to decode Protocol Buffers message")?;
322
323 Ok(proto_msg.message.unwrap())
324}
325
326pub async fn write_message<T: AsyncWrite + Unpin>(
327 conn: &mut T,
328 msg: &message::Message,
329) -> Result<()> {
330 let proto_msg = Message {
331 message: Some(msg.clone()),
332 };
333
334 debug!("Sending proto message: {:?}", msg);
335
336 let mut buf = BytesMut::new();
337 proto_msg
338 .encode(&mut buf)
339 .context("Failed to encode Protocol Buffers message")?;
340
341 let len = buf.len() as u32;
342 let len_bytes = len.to_le_bytes();
343
344 conn.write_all(&len_bytes).await?;
345 conn.write_all(&buf).await?;
346 conn.flush().await?;
347 Ok(())
348}
349
350impl FairGroup for message::Message {
351 fn group_id(&self) -> Option<u32> {
352 match self {
353 message::Message::DataChannelData(lhs) => Some(lhs.channel_id),
354 message::Message::DataChannelDataUdp(lhs) => Some(lhs.channel_id),
355 message::Message::DataChannelEof(lhs) => Some(lhs.channel_id),
356 message::Message::DataChannelAck(lhs) => Some(lhs.channel_id),
357 _ => None,
358 }
359 }
360
361 fn get_size(&self) -> Option<usize> {
362 match self {
363 message::Message::DataChannelData(data) => Some(data.data.len()),
364 message::Message::DataChannelDataUdp(data) => Some(data.data.len()),
365 _ => None, }
367 }
368}
369
370pub type UdpPacketLen = u16; #[derive(Deserialize, Serialize, Debug)]
373pub struct UdpHeader {
374 from: std::net::SocketAddr,
375 len: UdpPacketLen,
376}
377
378#[derive(Debug)]
379pub struct UdpTraffic {
380 pub from: std::net::SocketAddr,
381 pub data: Bytes,
382}
383
384impl UdpTraffic {
385 pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
386 let hdr = UdpHeader {
387 from: self.from,
388 len: self.data.len() as UdpPacketLen,
389 };
390
391 let v = bincode::serde::encode_to_vec(&hdr, bincode::config::legacy()).unwrap();
392
393 trace!("Write {:?} of length {}", hdr, v.len());
394 writer.write_u8(v.len() as u8).await?;
395 writer.write_all(&v).await?;
396
397 writer.write_all(&self.data).await?;
398
399 Ok(())
400 }
401
402 #[allow(dead_code)]
403 pub async fn write_slice<T: AsyncWrite + Unpin>(
404 writer: &mut T,
405 from: std::net::SocketAddr,
406 data: &[u8],
407 ) -> Result<()> {
408 let hdr = UdpHeader {
409 from,
410 len: data.len() as UdpPacketLen,
411 };
412
413 let v = bincode::serde::encode_to_vec(&hdr, bincode::config::legacy()).unwrap();
414
415 trace!("Write {:?} of length {}", hdr, v.len());
416 writer.write_u8(v.len() as u8).await?;
417 writer.write_all(&v).await?;
418
419 writer.write_all(data).await?;
420
421 Ok(())
422 }
423
424 pub async fn read<T: AsyncRead + Unpin>(reader: &mut T, hdr_len: u8) -> Result<UdpTraffic> {
425 let mut buf = vec![0; hdr_len as usize];
426 reader
427 .read_exact(&mut buf)
428 .await
429 .with_context(|| "Failed to read udp header")?;
430
431 let hdr: UdpHeader = bincode::serde::decode_from_slice(&buf, bincode::config::legacy())
432 .with_context(|| "Failed to deserialize UdpHeader")?
433 .0;
434
435 trace!("hdr {:?}", hdr);
436
437 let mut data = BytesMut::new();
438 data.resize(hdr.len as usize, 0);
439 reader.read_exact(&mut data).await?;
440
441 Ok(UdpTraffic {
442 from: hdr.from,
443 data: data.freeze(),
444 })
445 }
446}