cloudpub_common/
protocol.rs

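// Shared protocol definitions: the Protocol Buffers generated code is included below,
// followed by string conversions, URL formatting and message framing helpers for it.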
2use anyhow::{anyhow, bail, Context, Result};
3use bytes::{Bytes, BytesMut};
4use serde::{Deserialize, Serialize};
5use std::fmt::{self, Display, Formatter};
6use std::str::FromStr;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tracing::{debug, trace};
9use urlencoding::encode;
10
11use crate::fair_channel::FairGroup;
12use crate::utils::get_version_number;
13pub use prost::Message as ProstMessage;
14
/// Common view of an endpoint that can be rendered as a URL string.
pub trait Endpoint {
    /// Returns the credentials part to embed in a URL, or an empty string when
    /// there is none to show.
    fn credentials(&self) -> String;
    /// Renders the endpoint as a URL-like string.
    fn as_url(&self) -> String;
}
19
/// Types that may have a default port associated with them.
pub trait DefaultPort {
    /// Returns the default port, or `None` when there is no default.
    fn default_port(&self) -> Option<u16>;
}
23
24pub fn parse_enum<E: FromStr + Into<i32>>(name: &str) -> Result<i32> {
25    let proto = E::from_str(name).map_err(|_| anyhow!("Invalid enum: {}", name))?;
26    Ok(proto.into())
27}
28
/// Converts the raw `i32` value `e` into `E` and returns its string form.
///
/// Returns `"unknown"` when `e` cannot be converted into `E`.
pub fn str_enum<E: TryFrom<i32> + ToString>(e: i32) -> String {
    match E::try_from(e) {
        Ok(variant) => variant.to_string(),
        Err(_) => "unknown".to_string(),
    }
}
34
// Pull the generated `protocol.rs` from the build output directory (`OUT_DIR`)
// into this module.
include!(concat!(env!("OUT_DIR"), "/protocol.rs"));
36
/// A chunk of payload bytes together with an optional socket address.
pub struct Data {
    /// The payload bytes.
    pub data: Bytes,
    /// Socket address associated with the payload, if any.
    /// NOTE(review): presumably the peer address for datagram traffic — confirm against callers.
    pub socket_addr: Option<std::net::SocketAddr>,
}
41
42impl Display for Protocol {
43    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
44        match self {
45            Protocol::Http => write!(f, "http"),
46            Protocol::Https => write!(f, "https"),
47            Protocol::Tcp => write!(f, "tcp"),
48            Protocol::Udp => write!(f, "udp"),
49            Protocol::OneC => write!(f, "1c"),
50            Protocol::Minecraft => write!(f, "minecraft"),
51            Protocol::Webdav => write!(f, "webdav"),
52            Protocol::Rtsp => write!(f, "rtsp"),
53        }
54    }
55}
56
57impl FromStr for Protocol {
58    type Err = anyhow::Error;
59
60    fn from_str(s: &str) -> Result<Self> {
61        match s {
62            "http" => Ok(Protocol::Http),
63            "https" => Ok(Protocol::Https),
64            "tcp" => Ok(Protocol::Tcp),
65            "udp" => Ok(Protocol::Udp),
66            "1c" => Ok(Protocol::OneC),
67            "minecraft" => Ok(Protocol::Minecraft),
68            "webdav" => Ok(Protocol::Webdav),
69            "rtsp" => Ok(Protocol::Rtsp),
70            _ => bail!("Invalid protocol: {}", s),
71        }
72    }
73}
74
75impl DefaultPort for Protocol {
76    fn default_port(&self) -> Option<u16> {
77        match self {
78            Protocol::Http => Some(80),
79            Protocol::Https => Some(443),
80            Protocol::Tcp => None,
81            Protocol::Udp => None,
82            Protocol::OneC => None,
83            Protocol::Minecraft => Some(25565),
84            Protocol::Webdav => None,
85            Protocol::Rtsp => Some(554),
86        }
87    }
88}
89
90impl FromStr for Role {
91    type Err = anyhow::Error;
92
93    fn from_str(s: &str) -> Result<Self> {
94        match s {
95            "none" => Ok(Role::Nobody),
96            "admin" => Ok(Role::Admin),
97            "reader" => Ok(Role::Reader),
98            "writer" => Ok(Role::Writer),
99            _ => bail!("Invalid access: {}", s),
100        }
101    }
102}
103
104impl Display for Role {
105    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
106        match self {
107            Role::Nobody => write!(f, "none"),
108            Role::Admin => write!(f, "admin"),
109            Role::Reader => write!(f, "reader"),
110            Role::Writer => write!(f, "writer"),
111        }
112    }
113}
114
115impl FromStr for Auth {
116    type Err = anyhow::Error;
117
118    fn from_str(s: &str) -> Result<Self> {
119        match s {
120            "none" => Ok(Auth::None),
121            "basic" => Ok(Auth::Basic),
122            "form" => Ok(Auth::Form),
123            _ => bail!("Invalid auth: {}", s),
124        }
125    }
126}
127
128impl Display for Auth {
129    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
130        match self {
131            Auth::None => write!(f, "none"),
132            Auth::Basic => write!(f, "basic"),
133            Auth::Form => write!(f, "form"),
134        }
135    }
136}
137
138impl FromStr for FilterAction {
139    type Err = anyhow::Error;
140
141    fn from_str(s: &str) -> Result<Self> {
142        match s {
143            "allow" => Ok(FilterAction::FilterAllow),
144            "deny" => Ok(FilterAction::FilterDeny),
145            "redirect" => Ok(FilterAction::FilterRedirect),
146            "log" => Ok(FilterAction::FilterLog),
147            _ => bail!("Invalid filter action: {}", s),
148        }
149    }
150}
151
152impl PartialEq for ClientEndpoint {
153    fn eq(&self, other: &Self) -> bool {
154        self.local_proto == other.local_proto
155            && self.local_addr == other.local_addr
156            && self.local_port == other.local_port
157    }
158}
159
160impl Display for ClientEndpoint {
161    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
162        if let Some(name) = self.description.as_ref() {
163            write!(f, "[{}] ", name)?;
164        }
165        write!(f, "{}", self.as_url())
166    }
167}
168
169impl Endpoint for ClientEndpoint {
170    fn credentials(&self) -> String {
171        let mut s = String::new();
172        if !self.username.is_empty() {
173            s.push_str(&encode(&self.username));
174        }
175        if !self.password.is_empty() {
176            s.push(':');
177            s.push_str(&encode(&self.password));
178        }
179        if !s.is_empty() {
180            s.push('@');
181        }
182        s
183    }
184
185    fn as_url(&self) -> String {
186        match self.local_proto.try_into().unwrap() {
187            Protocol::OneC | Protocol::Minecraft | Protocol::Webdav => {
188                let credentials = self.credentials();
189                format!(
190                    "{}://{}{}",
191                    str_enum::<Protocol>(self.local_proto),
192                    credentials,
193                    &self.local_addr
194                )
195            }
196            Protocol::Http | Protocol::Https | Protocol::Tcp | Protocol::Udp | Protocol::Rtsp => {
197                let credentials = self.credentials();
198                format!(
199                    "{}://{}{}:{}{}",
200                    str_enum::<Protocol>(self.local_proto),
201                    credentials,
202                    self.local_addr,
203                    self.local_port,
204                    self.local_path
205                )
206            }
207        }
208    }
209}
210
211impl Endpoint for ServerEndpoint {
212    fn credentials(&self) -> String {
213        String::new()
214    }
215
216    fn as_url(&self) -> String {
217        format!(
218            "{}://{}:{}",
219            str_enum::<Protocol>(self.remote_proto),
220            self.remote_addr,
221            self.remote_port,
222        )
223    }
224}
225
226impl Display for ServerEndpoint {
227    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
228        let client = self.client.as_ref().unwrap();
229        if self.error.is_empty() {
230            write!(
231                f,
232                "{} -> {}://{}{}:{}{}",
233                client,
234                Protocol::try_from(self.remote_proto).unwrap(),
235                client.credentials(),
236                self.remote_addr,
237                self.remote_port,
238                client.local_path
239            )
240        } else {
241            write!(f, "{} -> {}", client, self.error)
242        }
243    }
244}
245
246impl AgentInfo {
247    pub fn is_support_server_control(&self) -> bool {
248        get_version_number(&self.version) >= get_version_number("2.1.1")
249    }
250
251    pub fn is_support_backpressure(&self) -> bool {
252        get_version_number(&self.version) >= get_version_number("2.2.0")
253    }
254
255    pub fn get_unique_id(&self) -> String {
256        if self.hwid.is_empty() {
257            self.agent_id.clone()
258        } else {
259            self.hwid.clone()
260        }
261    }
262}
263
264impl PartialEq for ServerEndpoint {
265    fn eq(&self, other: &Self) -> bool {
266        self.client == other.client
267    }
268}
269
270// New Protocol Buffers message reading and writing functions
271pub async fn read_message<T: AsyncRead + Unpin>(conn: &mut T) -> Result<message::Message> {
272    let mut buf = [0u8; std::mem::size_of::<u32>()];
273    conn.read_exact(&mut buf).await?;
274    let len = u32::from_le_bytes(buf) as usize;
275    if !(1..=1024 * 1024 * 10).contains(&len) {
276        bail!("Invalid message length: {}", len);
277    }
278    let mut buf = vec![0u8; len];
279    conn.read_exact(&mut buf).await?;
280
281    let proto_msg =
282        Message::decode(buf.as_slice()).context("Failed to decode Protocol Buffers message")?;
283
284    Ok(proto_msg.message.unwrap())
285}
286
287pub async fn write_message<T: AsyncWrite + Unpin>(
288    conn: &mut T,
289    msg: &message::Message,
290) -> Result<()> {
291    let proto_msg = Message {
292        message: Some(msg.clone()),
293    };
294
295    debug!("Sending proto message: {:?}", msg);
296
297    let mut buf = BytesMut::new();
298    proto_msg
299        .encode(&mut buf)
300        .context("Failed to encode Protocol Buffers message")?;
301
302    let len = buf.len() as u32;
303    let len_bytes = len.to_le_bytes();
304
305    conn.write_all(&len_bytes).await?;
306    conn.write_all(&buf).await?;
307    conn.flush().await?;
308    Ok(())
309}
310
311impl FairGroup for message::Message {
312    fn group_id(&self) -> Option<u32> {
313        match self {
314            message::Message::DataChannelData(lhs) => Some(lhs.channel_id),
315            message::Message::DataChannelDataUdp(lhs) => Some(lhs.channel_id),
316            message::Message::DataChannelEof(lhs) => Some(lhs.channel_id),
317            message::Message::DataChannelAck(lhs) => Some(lhs.channel_id),
318            _ => None,
319        }
320    }
321
322    fn get_size(&self) -> Option<usize> {
323        match self {
324            message::Message::DataChannelData(data) => Some(data.data.len()),
325            message::Message::DataChannelDataUdp(data) => Some(data.data.len()),
326            _ => None, // Control messages have no size
327        }
328    }
329}
330
/// Integer type used for the payload length in a UDP header.
///
/// `u16` should be enough for any practical UDP traffic on the Internet.
pub type UdpPacketLen = u16;
/// Header that precedes each UDP payload on the stream.
///
/// It is serialized with bincode's legacy configuration by the `UdpTraffic`
/// write functions and parsed back by `UdpTraffic::read`.
#[derive(Deserialize, Serialize, Debug)]
pub struct UdpHeader {
    /// Socket address the datagram is attributed to.
    from: std::net::SocketAddr,
    /// Length in bytes of the payload that follows the header.
    len: UdpPacketLen,
}
338
/// A UDP datagram payload paired with the socket address it is attributed to.
#[derive(Debug)]
pub struct UdpTraffic {
    /// Socket address stored in the header when the datagram is written.
    pub from: std::net::SocketAddr,
    /// The datagram payload.
    pub data: Bytes,
}
344
345impl UdpTraffic {
346    pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
347        let hdr = UdpHeader {
348            from: self.from,
349            len: self.data.len() as UdpPacketLen,
350        };
351
352        let v = bincode::serde::encode_to_vec(&hdr, bincode::config::legacy()).unwrap();
353
354        trace!("Write {:?} of length {}", hdr, v.len());
355        writer.write_u8(v.len() as u8).await?;
356        writer.write_all(&v).await?;
357
358        writer.write_all(&self.data).await?;
359
360        Ok(())
361    }
362
363    #[allow(dead_code)]
364    pub async fn write_slice<T: AsyncWrite + Unpin>(
365        writer: &mut T,
366        from: std::net::SocketAddr,
367        data: &[u8],
368    ) -> Result<()> {
369        let hdr = UdpHeader {
370            from,
371            len: data.len() as UdpPacketLen,
372        };
373
374        let v = bincode::serde::encode_to_vec(&hdr, bincode::config::legacy()).unwrap();
375
376        trace!("Write {:?} of length {}", hdr, v.len());
377        writer.write_u8(v.len() as u8).await?;
378        writer.write_all(&v).await?;
379
380        writer.write_all(data).await?;
381
382        Ok(())
383    }
384
385    pub async fn read<T: AsyncRead + Unpin>(reader: &mut T, hdr_len: u8) -> Result<UdpTraffic> {
386        let mut buf = vec![0; hdr_len as usize];
387        reader
388            .read_exact(&mut buf)
389            .await
390            .with_context(|| "Failed to read udp header")?;
391
392        let hdr: UdpHeader = bincode::serde::decode_from_slice(&buf, bincode::config::legacy())
393            .with_context(|| "Failed to deserialize UdpHeader")?
394            .0;
395
396        trace!("hdr {:?}", hdr);
397
398        let mut data = BytesMut::new();
399        data.resize(hdr.len as usize, 0);
400        reader.read_exact(&mut data).await?;
401
402        Ok(UdpTraffic {
403            from: hdr.from,
404            data: data.freeze(),
405        })
406    }
407}