cloudpub_common/
protocol.rs

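// Protocol Buffers message types and framing helpers. The prost-generated types are
// `include!`d directly into this module further down; no separate submodule is created.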
2use anyhow::{anyhow, bail, Context, Result};
3use bytes::{Bytes, BytesMut};
4use serde::{Deserialize, Serialize};
5use std::fmt::{self, Display, Formatter};
6use std::str::FromStr;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tracing::{debug, trace};
9use urlencoding::encode;
10
11use crate::fair_channel::FairGroup;
12use crate::utils::get_version_number;
13pub use prost::Message as ProstMessage;
14
/// Something that can be rendered as a connection URL.
pub trait Endpoint {
    /// The full URL for this endpoint (scheme, optional credentials, address, ...).
    fn as_url(&self) -> String;

    /// Credential prefix to embed in the URL. The implementations in this module return either
    /// an empty string or a `user[:password]@` prefix.
    fn credentials(&self) -> String;
}
19
/// Exposes the conventional port for a protocol, when there is one.
pub trait DefaultPort {
    /// `Some(port)` if a well-known default exists, `None` otherwise.
    fn default_port(&self) -> Option<u16>;
}
23
24pub fn parse_enum<E: FromStr + Into<i32>>(name: &str) -> Result<i32> {
25    let proto = E::from_str(name).map_err(|_| anyhow!("Invalid enum: {}", name))?;
26    Ok(proto.into())
27}
28
/// Converts the integer value `e` into the enum `E` and returns its string form.
///
/// Returns the literal `"unknown"` when `e` does not correspond to any variant of `E`.
pub fn str_enum<E: TryFrom<i32> + ToString>(e: i32) -> String {
    E::try_from(e)
        .map(|variant| variant.to_string())
        // Lazily build the fallback so the common (valid) path does not allocate it.
        .unwrap_or_else(|_| "unknown".to_string())
}
34
// Splice the generated Protocol Buffers code (presumably produced by prost-build in the build
// script and written to Cargo's OUT_DIR — confirm against build.rs) directly into this module,
// so its types are referenced here without a module prefix.
include!(concat!(env!("OUT_DIR"), "/protocol.rs"));
36
37pub struct Data {
38    pub data: Bytes,
39    pub socket_addr: Option<std::net::SocketAddr>,
40}
41
42impl Display for Protocol {
43    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
44        match self {
45            Protocol::Http => write!(f, "http"),
46            Protocol::Https => write!(f, "https"),
47            Protocol::Tcp => write!(f, "tcp"),
48            Protocol::Udp => write!(f, "udp"),
49            Protocol::OneC => write!(f, "1c"),
50            Protocol::Minecraft => write!(f, "minecraft"),
51            Protocol::Webdav => write!(f, "webdav"),
52            Protocol::Rtsp => write!(f, "rtsp"),
53            Protocol::Rdp => write!(f, "rdp"),
54            Protocol::Vnc => write!(f, "vnc"),
55            Protocol::Ssh => write!(f, "ssh"),
56        }
57    }
58}
59
60impl FromStr for Protocol {
61    type Err = anyhow::Error;
62
63    fn from_str(s: &str) -> Result<Self> {
64        match s {
65            "http" => Ok(Protocol::Http),
66            "https" => Ok(Protocol::Https),
67            "tcp" => Ok(Protocol::Tcp),
68            "udp" => Ok(Protocol::Udp),
69            "1c" => Ok(Protocol::OneC),
70            "minecraft" => Ok(Protocol::Minecraft),
71            "webdav" => Ok(Protocol::Webdav),
72            "rtsp" => Ok(Protocol::Rtsp),
73            "rdp" => Ok(Protocol::Rdp),
74            "vnc" => Ok(Protocol::Vnc),
75            "ssh" => Ok(Protocol::Ssh),
76            _ => bail!("Invalid protocol: {}", s),
77        }
78    }
79}
80
81impl DefaultPort for Protocol {
82    fn default_port(&self) -> Option<u16> {
83        match self {
84            Protocol::Http => Some(80),
85            Protocol::Https => Some(443),
86            Protocol::Tcp => None,
87            Protocol::Udp => None,
88            Protocol::OneC => None,
89            Protocol::Minecraft => Some(25565),
90            Protocol::Webdav => None,
91            Protocol::Rtsp => Some(554),
92            Protocol::Rdp => Some(3389),
93            Protocol::Vnc => Some(5900),
94            Protocol::Ssh => Some(22),
95        }
96    }
97}
98
99impl FromStr for Role {
100    type Err = anyhow::Error;
101
102    fn from_str(s: &str) -> Result<Self> {
103        match s {
104            "none" => Ok(Role::Nobody),
105            "admin" => Ok(Role::Admin),
106            "reader" => Ok(Role::Reader),
107            "writer" => Ok(Role::Writer),
108            _ => bail!("Invalid access: {}", s),
109        }
110    }
111}
112
113impl Display for Role {
114    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
115        match self {
116            Role::Nobody => write!(f, "none"),
117            Role::Admin => write!(f, "admin"),
118            Role::Reader => write!(f, "reader"),
119            Role::Writer => write!(f, "writer"),
120        }
121    }
122}
123
124impl FromStr for Auth {
125    type Err = anyhow::Error;
126
127    fn from_str(s: &str) -> Result<Self> {
128        match s {
129            "none" => Ok(Auth::None),
130            "basic" => Ok(Auth::Basic),
131            "form" => Ok(Auth::Form),
132            _ => bail!("Invalid auth: {}", s),
133        }
134    }
135}
136
137impl Display for Auth {
138    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
139        match self {
140            Auth::None => write!(f, "none"),
141            Auth::Basic => write!(f, "basic"),
142            Auth::Form => write!(f, "form"),
143        }
144    }
145}
146
147impl FromStr for ProxyProtocol {
148    type Err = anyhow::Error;
149
150    fn from_str(s: &str) -> Result<Self> {
151        match s {
152            "none" => Ok(ProxyProtocol::None),
153            "v2" => Ok(ProxyProtocol::V2),
154            _ => bail!("Invalid proxy protocol: {}", s),
155        }
156    }
157}
158
159impl Display for ProxyProtocol {
160    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
161        match self {
162            ProxyProtocol::None => write!(f, "none"),
163            ProxyProtocol::V2 => write!(f, "v2"),
164        }
165    }
166}
167
168impl FromStr for FilterAction {
169    type Err = anyhow::Error;
170
171    fn from_str(s: &str) -> Result<Self> {
172        match s {
173            "allow" => Ok(FilterAction::FilterAllow),
174            "deny" => Ok(FilterAction::FilterDeny),
175            "redirect" => Ok(FilterAction::FilterRedirect),
176            "log" => Ok(FilterAction::FilterLog),
177            _ => bail!("Invalid filter action: {}", s),
178        }
179    }
180}
181
182impl PartialEq for ClientEndpoint {
183    fn eq(&self, other: &Self) -> bool {
184        self.local_proto == other.local_proto
185            && self.local_addr == other.local_addr
186            && self.local_port == other.local_port
187    }
188}
189
190impl Display for ClientEndpoint {
191    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
192        if let Some(name) = self.description.as_ref() {
193            if !name.is_empty() {
194                write!(f, "[{}] ", name)?;
195            }
196        }
197        write!(f, "{}", self.as_url())
198    }
199}
200
201impl Endpoint for ClientEndpoint {
202    fn credentials(&self) -> String {
203        let mut s = String::new();
204        if !self.username.is_empty() {
205            s.push_str(&encode(&self.username));
206        }
207        if !self.password.is_empty() {
208            s.push(':');
209            s.push_str(&encode(&self.password));
210        }
211        if !s.is_empty() {
212            s.push('@');
213        }
214        s
215    }
216
217    fn as_url(&self) -> String {
218        match self.local_proto.try_into().unwrap() {
219            Protocol::OneC | Protocol::Minecraft | Protocol::Webdav => {
220                let credentials = self.credentials();
221                format!(
222                    "{}://{}{}",
223                    str_enum::<Protocol>(self.local_proto),
224                    credentials,
225                    &self.local_addr
226                )
227            }
228            Protocol::Http
229            | Protocol::Https
230            | Protocol::Tcp
231            | Protocol::Udp
232            | Protocol::Rtsp
233            | Protocol::Rdp
234            | Protocol::Vnc
235            | Protocol::Ssh => {
236                let credentials = self.credentials();
237                format!(
238                    "{}://{}{}:{}{}",
239                    str_enum::<Protocol>(self.local_proto),
240                    credentials,
241                    self.local_addr,
242                    self.local_port,
243                    self.local_path
244                )
245            }
246        }
247    }
248}
249
250impl Endpoint for ServerEndpoint {
251    fn credentials(&self) -> String {
252        String::new()
253    }
254
255    fn as_url(&self) -> String {
256        format!(
257            "{}://{}:{}",
258            str_enum::<Protocol>(self.remote_proto),
259            self.remote_addr,
260            self.remote_port,
261        )
262    }
263}
264
265impl Display for ServerEndpoint {
266    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
267        let client = self.client.as_ref().unwrap();
268        if self.error.is_empty() {
269            write!(
270                f,
271                "{} -> {}://{}{}:{}{}",
272                client,
273                Protocol::try_from(self.remote_proto).unwrap(),
274                client.credentials(),
275                self.remote_addr,
276                self.remote_port,
277                client.local_path
278            )
279        } else {
280            write!(f, "{} -> {}", client, self.error)
281        }
282    }
283}
284
285impl AgentInfo {
286    pub fn is_support_server_control(&self) -> bool {
287        get_version_number(&self.version) >= get_version_number("2.1.1")
288    }
289
290    pub fn is_support_backpressure(&self) -> bool {
291        get_version_number(&self.version) >= get_version_number("2.2.0")
292    }
293
294    pub fn get_unique_id(&self) -> String {
295        if self.hwid.is_empty() {
296            self.agent_id.clone()
297        } else {
298            self.hwid.clone()
299        }
300    }
301}
302
303impl PartialEq for ServerEndpoint {
304    fn eq(&self, other: &Self) -> bool {
305        self.client == other.client
306    }
307}
308
309// New Protocol Buffers message reading and writing functions
310pub async fn read_message<T: AsyncRead + Unpin>(conn: &mut T) -> Result<message::Message> {
311    let mut buf = [0u8; std::mem::size_of::<u32>()];
312    conn.read_exact(&mut buf).await?;
313    let len = u32::from_le_bytes(buf) as usize;
314    if !(1..=1024 * 1024 * 10).contains(&len) {
315        bail!("Invalid message length: {}", len);
316    }
317    let mut buf = vec![0u8; len];
318    conn.read_exact(&mut buf).await?;
319
320    let proto_msg =
321        Message::decode(buf.as_slice()).context("Failed to decode Protocol Buffers message")?;
322
323    Ok(proto_msg.message.unwrap())
324}
325
326pub async fn write_message<T: AsyncWrite + Unpin>(
327    conn: &mut T,
328    msg: &message::Message,
329) -> Result<()> {
330    let proto_msg = Message {
331        message: Some(msg.clone()),
332    };
333
334    debug!("Sending proto message: {:?}", msg);
335
336    let mut buf = BytesMut::new();
337    proto_msg
338        .encode(&mut buf)
339        .context("Failed to encode Protocol Buffers message")?;
340
341    let len = buf.len() as u32;
342    let len_bytes = len.to_le_bytes();
343
344    conn.write_all(&len_bytes).await?;
345    conn.write_all(&buf).await?;
346    conn.flush().await?;
347    Ok(())
348}
349
350impl FairGroup for message::Message {
351    fn group_id(&self) -> Option<u32> {
352        match self {
353            message::Message::DataChannelData(lhs) => Some(lhs.channel_id),
354            message::Message::DataChannelDataUdp(lhs) => Some(lhs.channel_id),
355            message::Message::DataChannelEof(lhs) => Some(lhs.channel_id),
356            message::Message::DataChannelAck(lhs) => Some(lhs.channel_id),
357            _ => None,
358        }
359    }
360
361    fn get_size(&self) -> Option<usize> {
362        match self {
363            message::Message::DataChannelData(data) => Some(data.data.len()),
364            message::Message::DataChannelDataUdp(data) => Some(data.data.len()),
365            _ => None, // Control messages have no size
366        }
367    }
368}
369
/// Length type for a single UDP payload in the UDP-over-stream framing.
// `u16` should be enough for any practical UDP traffic on the Internet
pub type UdpPacketLen = u16;

/// Per-datagram header, serialized with bincode (legacy config) by `UdpTraffic`.
///
/// NOTE: field order is part of the wire format (bincode serializes fields in declaration
/// order); do not reorder or insert fields without versioning the protocol.
#[derive(Deserialize, Serialize, Debug)]
pub struct UdpHeader {
    // Address recorded as the datagram's origin.
    from: std::net::SocketAddr,
    // Number of payload bytes that follow the header on the stream.
    len: UdpPacketLen,
}
377
/// One UDP datagram carried over a byte stream: the origin address plus its payload.
#[derive(Debug)]
pub struct UdpTraffic {
    /// Address written into the header as the datagram's origin.
    pub from: std::net::SocketAddr,
    /// Datagram payload.
    pub data: Bytes,
}
383
384impl UdpTraffic {
385    pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
386        let hdr = UdpHeader {
387            from: self.from,
388            len: self.data.len() as UdpPacketLen,
389        };
390
391        let v = bincode::serde::encode_to_vec(&hdr, bincode::config::legacy()).unwrap();
392
393        trace!("Write {:?} of length {}", hdr, v.len());
394        writer.write_u8(v.len() as u8).await?;
395        writer.write_all(&v).await?;
396
397        writer.write_all(&self.data).await?;
398
399        Ok(())
400    }
401
402    #[allow(dead_code)]
403    pub async fn write_slice<T: AsyncWrite + Unpin>(
404        writer: &mut T,
405        from: std::net::SocketAddr,
406        data: &[u8],
407    ) -> Result<()> {
408        let hdr = UdpHeader {
409            from,
410            len: data.len() as UdpPacketLen,
411        };
412
413        let v = bincode::serde::encode_to_vec(&hdr, bincode::config::legacy()).unwrap();
414
415        trace!("Write {:?} of length {}", hdr, v.len());
416        writer.write_u8(v.len() as u8).await?;
417        writer.write_all(&v).await?;
418
419        writer.write_all(data).await?;
420
421        Ok(())
422    }
423
424    pub async fn read<T: AsyncRead + Unpin>(reader: &mut T, hdr_len: u8) -> Result<UdpTraffic> {
425        let mut buf = vec![0; hdr_len as usize];
426        reader
427            .read_exact(&mut buf)
428            .await
429            .with_context(|| "Failed to read udp header")?;
430
431        let hdr: UdpHeader = bincode::serde::decode_from_slice(&buf, bincode::config::legacy())
432            .with_context(|| "Failed to deserialize UdpHeader")?
433            .0;
434
435        trace!("hdr {:?}", hdr);
436
437        let mut data = BytesMut::new();
438        data.resize(hdr.len as usize, 0);
439        reader.read_exact(&mut data).await?;
440
441        Ok(UdpTraffic {
442            from: hdr.from,
443            data: data.freeze(),
444        })
445    }
446}