port_mapping/
mapping_rule.rs

1use std::{collections::HashMap, fmt::Display, ops::RangeInclusive};
2use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
3
/// Protocol selector as it is written in a mapping line.
///
/// Unlike [`Protocol`], this keeps the combined form: a single line may ask
/// for TCP, UDP, or both at once.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProtocolRaw {
    /// TCP only.
    Tcp,
    /// UDP only.
    Udp,
    /// Both TCP and UDP on the same ports.
    TcpUdp,
}
10
/// Transport protocol of a single, fully expanded mapping rule.
///
/// The type is a plain tag, so it is `Copy` and can be used directly as
/// (part of) a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Transmission Control Protocol.
    Tcp,
    /// User Datagram Protocol.
    Udp,
}
16
17#[derive(Debug)]
18pub struct MappingRuleRaw<'a> {
19    pub protocol: ProtocolRaw,
20    pub listen_port: RangeInclusive<u16>,
21    pub upstream_host: &'a str,
22    pub upstream_port: RangeInclusive<u16>,
23}
24
25#[derive(Debug)]
26pub struct MappingRule {
27    pub protocol: Protocol,
28    pub listen: String,
29    pub upstream: String,
30}
31
32impl Display for MappingRule {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "{}->{}", self.listen, self.upstream)
35    }
36}
37
/// Reasons a mapping line can be rejected.
///
/// Variants carrying a `&str` borrow the offending token from the parsed
/// line. The type implements [`Display`] and [`std::error::Error`], so it can
/// be logged or propagated like any other error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MappingRuleParseError<'a> {
    /// The line holds nothing but whitespace and/or a comment.
    Empty,
    /// The protocol is present but no listen port follows it.
    MissingListenPort,
    /// The listen port is present but no upstream follows it.
    MissingUpstream,
    /// The upstream has no `:port` part.
    MissingUpstreamPort,
    /// The (lower-cased) protocol token is not a known protocol.
    InvalidProtocol(String),
    /// The listen token is not a port number or `from-to` pair.
    InvalidListenPort(&'a str),
    /// The listen range starts above its end.
    InvalidListenPortRange(&'a str),
    /// The upstream token could not be interpreted.
    InvalidUpstream(&'a str),
    /// The port part of the upstream token is not a port number or pair.
    InvalidUpstreamPort(&'a str),
    /// The upstream port range starts above its end.
    InvalidUpstreamPortRange(&'a str),
    /// Listen and upstream ranges (in that order) differ in length.
    UnmatchedPortRange(RangeInclusive<u16>, RangeInclusive<u16>),
}

impl Display for MappingRuleParseError<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Empty => f.write_str("Empty line"),
            Self::MissingListenPort => f.write_str("Missing listen port"),
            Self::MissingUpstream => f.write_str("Missing upstream"),
            Self::MissingUpstreamPort => f.write_str("Missing upstream port"),
            Self::InvalidProtocol(protocol) => write!(f, "Invalid protocol: {}", protocol),
            Self::InvalidListenPort(port) => write!(f, "Invalid listen port: {}", port),
            Self::InvalidListenPortRange(range) => {
                write!(f, "Invalid listen port range: {}", range)
            }
            Self::InvalidUpstream(upstream) => write!(f, "Invalid upstream: {}", upstream),
            Self::InvalidUpstreamPort(port) => write!(f, "Invalid upstream port: {}", port),
            Self::InvalidUpstreamPortRange(range) => {
                write!(f, "Invalid upstream port range: {}", range)
            }
            Self::UnmatchedPortRange(listen, upstream) => write!(
                f,
                "Unmatched port range: {}-{} -> {}-{}",
                listen.start(),
                listen.end(),
                upstream.start(),
                upstream.end()
            ),
        }
    }
}

impl std::error::Error for MappingRuleParseError<'_> {}
52
53impl<'a> MappingRuleRaw<'a> {
54    pub fn parse(line: &'a str) -> Result<Self, MappingRuleParseError<'a>> {
55        // Skip empty lines and comments
56        let mut parts = line
57            .split('#')
58            .next()
59            .ok_or(MappingRuleParseError::Empty)?
60            .split_whitespace();
61        // Parse protocol
62        let protocol = match parts
63            .next()
64            .ok_or(MappingRuleParseError::Empty)?
65            .to_lowercase()
66            .as_str()
67        {
68            "udp" => ProtocolRaw::Udp,
69            "tcp" => ProtocolRaw::Tcp,
70            "t+u" => ProtocolRaw::TcpUdp,
71            input => {
72                return Err(MappingRuleParseError::InvalidProtocol(input.to_string()));
73            }
74        };
75        // Check listen
76        let listen = parts
77            .next()
78            .ok_or(MappingRuleParseError::MissingListenPort)?;
79        // Check upstream
80        let upstream = parts.next().ok_or(MappingRuleParseError::MissingUpstream)?;
81        // Parse listen port
82        let mut listen_parts = listen.splitn(2, '-');
83        let listen_from: u16 = listen_parts
84            .next()
85            .ok_or(MappingRuleParseError::InvalidListenPort(listen))?
86            .parse()
87            .map_err(|_| MappingRuleParseError::InvalidListenPort(listen))?;
88        let listen_to: u16 = listen_parts
89            .next()
90            .map(|s| s.parse())
91            .unwrap_or(Ok(listen_from))
92            .map_err(|_| MappingRuleParseError::InvalidListenPort(listen))?;
93        if listen_from > listen_to {
94            return Err(MappingRuleParseError::InvalidListenPortRange(listen));
95        }
96        // Parse upstream
97        let mut upstream_parts = upstream.splitn(2, ':');
98        let upstream_host = {
99            let t = upstream_parts
100                .next()
101                .ok_or(MappingRuleParseError::InvalidUpstream(upstream))?;
102            if t.is_empty() { "localhost" } else { t }
103        };
104        let mut upstream_port_parts = upstream_parts
105            .next()
106            .ok_or(MappingRuleParseError::MissingUpstreamPort)?
107            .splitn(2, '-');
108        let upstream_port_from: u16 = upstream_port_parts
109            .next()
110            .ok_or(MappingRuleParseError::InvalidUpstreamPort(upstream))?
111            .parse()
112            .map_err(|_| MappingRuleParseError::InvalidUpstreamPort(upstream))?;
113        let upstream_port_to: u16 = upstream_port_parts
114            .next()
115            .map(|s| s.parse())
116            .unwrap_or(Ok(upstream_port_from))
117            .map_err(|_| MappingRuleParseError::InvalidUpstreamPort(upstream))?;
118        if upstream_port_from > upstream_port_to {
119            return Err(MappingRuleParseError::InvalidUpstreamPortRange(upstream));
120        }
121        let listen_port = listen_from..=listen_to;
122        let upstream_port = upstream_port_from..=upstream_port_to;
123        if upstream_port_to - upstream_port_from != listen_to - listen_from {
124            return Err(MappingRuleParseError::UnmatchedPortRange(
125                listen_port,
126                upstream_port,
127            ));
128        }
129        Ok(Self {
130            protocol,
131            listen_port,
132            upstream_host,
133            upstream_port,
134        })
135    }
136}
137
138pub async fn read_mapping_file<T: Unpin + AsyncRead>(
139    mut reader: BufReader<T>,
140) -> std::io::Result<Vec<MappingRule>> {
141    let mut rules = HashMap::new();
142    let mut line = String::new();
143    while reader.read_line(&mut line).await? != 0 {
144        line = line.trim().to_string();
145        match MappingRuleRaw::parse(&line) {
146            Ok(entry) => {
147                match entry.protocol {
148                    ProtocolRaw::Tcp => {
149                        let upstream_port_from = entry.upstream_port.start();
150                        for (i, port) in entry.listen_port.enumerate() {
151                            if rules.contains_key(&(Protocol::Tcp, port)) {
152                                eprintln!("[warning][tcp] Port {port} will be overwritten")
153                            }
154                            rules.insert(
155                                (Protocol::Tcp, port),
156                                (
157                                    entry.upstream_host.to_string(),
158                                    upstream_port_from + i as u16,
159                                ),
160                            );
161                        }
162                    }
163                    ProtocolRaw::Udp => {
164                        let upstream_port_from = entry.upstream_port.start();
165                        for (i, port) in entry.listen_port.enumerate() {
166                            if rules.contains_key(&(Protocol::Udp, port)) {
167                                eprintln!("[warning][udp] Port {port} will be overwritten")
168                            }
169                            rules.insert(
170                                (Protocol::Udp, port),
171                                (
172                                    entry.upstream_host.to_string(),
173                                    upstream_port_from + i as u16,
174                                ),
175                            );
176                        }
177                    }
178                    ProtocolRaw::TcpUdp => {
179                        let upstream_port_from = entry.upstream_port.start();
180                        for (i, port) in entry.listen_port.enumerate() {
181                            if rules.contains_key(&(Protocol::Tcp, port)) {
182                                eprintln!("[warning][tcp] Port {port} will be overwritten")
183                            }
184                            rules.insert(
185                                (Protocol::Tcp, port),
186                                (
187                                    entry.upstream_host.to_string(),
188                                    upstream_port_from + i as u16,
189                                ),
190                            );
191                            if rules.contains_key(&(Protocol::Udp, port)) {
192                                eprintln!("[warning][udp] Port {port} will be overwritten")
193                            }
194                            rules.insert(
195                                (Protocol::Udp, port),
196                                (
197                                    entry.upstream_host.to_string(),
198                                    upstream_port_from + i as u16,
199                                ),
200                            );
201                        }
202                    }
203                };
204            }
205            Err(e) => match e {
206                MappingRuleParseError::Empty => (),
207                MappingRuleParseError::MissingListenPort => {
208                    eprintln!("[warning][parse] Missing listen port: {line}")
209                }
210                MappingRuleParseError::MissingUpstream => {
211                    eprintln!("[warning][parse] Missing upstream: {line}")
212                }
213                MappingRuleParseError::MissingUpstreamPort => {
214                    eprintln!("[warning][parse] Missing upstream port: {line}")
215                }
216                MappingRuleParseError::InvalidProtocol(protocol) => {
217                    eprintln!("[warning][parse] Invalid protocol: {protocol} in {line}")
218                }
219                MappingRuleParseError::InvalidListenPort(port) => {
220                    eprintln!("[warning][parse] Invalid listen port: {port} in {line}")
221                }
222                MappingRuleParseError::InvalidListenPortRange(range) => {
223                    eprintln!("[warning][parse] Invalid listen port range: {range} in {line}")
224                }
225                MappingRuleParseError::InvalidUpstream(upstream) => {
226                    eprintln!("[warning][parse] Invalid upstream: {upstream} in {line}")
227                }
228                MappingRuleParseError::InvalidUpstreamPort(port) => {
229                    eprintln!("[warning][parse] Invalid upstream port: {port} in {line}")
230                }
231                MappingRuleParseError::InvalidUpstreamPortRange(range) => {
232                    eprintln!("[warning][parse] Invalid upstream port range: {range} in {line}")
233                }
234                MappingRuleParseError::UnmatchedPortRange(
235                    listen_port_range,
236                    upstream_port_range,
237                ) => {
238                    eprintln!(
239                        "[warning][parse] Unmatched port range: {}-{} -> {}-{} in {line}",
240                        listen_port_range.start(),
241                        listen_port_range.end(),
242                        upstream_port_range.start(),
243                        upstream_port_range.end()
244                    )
245                }
246            },
247        }
248        line.clear();
249    }
250    Ok(rules
251        .into_iter()
252        .map(
253            |((protocol, listen), (upstream_host, upstream_port))| MappingRule {
254                protocol,
255                listen: format!("0.0.0.0:{listen}"),
256                upstream: format!("{upstream_host}:{upstream_port}"),
257            },
258        )
259        .collect())
260}