1use std::{collections::HashMap, fmt::Display, ops::RangeInclusive};
2use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
3
/// Protocol selector exactly as written in a mapping file, before expansion.
///
/// `TcpUdp` corresponds to the `t+u` keyword and expands to one TCP and one
/// UDP mapping per port.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolRaw {
    /// The `tcp` keyword.
    Tcp,
    /// The `udp` keyword.
    Udp,
    /// The `t+u` keyword (both TCP and UDP).
    TcpUdp,
}
10
/// Transport protocol of a single, fully expanded mapping rule.
///
/// `Ord` follows declaration order (`Tcp < Udp`), which allows rules keyed by
/// protocol to be sorted deterministically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Protocol {
    /// TCP mapping.
    Tcp,
    /// UDP mapping.
    Udp,
}
16
/// One parsed line of a mapping file, before its port ranges are expanded.
///
/// Produced by [`MappingRuleRaw::parse`], which only returns a value when both
/// ranges are ascending and contain the same number of ports.
#[derive(Debug)]
pub struct MappingRuleRaw<'a> {
    /// Protocol selector as written on the line.
    pub protocol: ProtocolRaw,
    /// Inclusive range of local ports to listen on.
    pub listen_port: RangeInclusive<u16>,
    /// Upstream host, borrowed from the parsed line, or `"localhost"` when the
    /// line left the host part empty.
    pub upstream_host: &'a str,
    /// Inclusive range of upstream ports, paired in order with `listen_port`.
    pub upstream_port: RangeInclusive<u16>,
}
24
/// A single expanded port mapping: one protocol, one listen address and one
/// upstream address.
#[derive(Debug)]
pub struct MappingRule {
    /// Transport protocol of this mapping.
    pub protocol: Protocol,
    /// Address to listen on as `host:port` (`read_mapping_file` always uses
    /// `0.0.0.0` as the host).
    pub listen: String,
    /// Address to forward to as `host:port`.
    pub upstream: String,
}
31
32impl Display for MappingRule {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 write!(f, "{}->{}", self.listen, self.upstream)
35 }
36}
37
/// Reasons a single line of a mapping file can fail to parse.
///
/// Borrowed payloads point into the line that was being parsed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MappingRuleParseError<'a> {
    /// The line holds no rule (it is blank or contains only a comment).
    Empty,
    /// A protocol is present but no listen port follows it.
    MissingListenPort,
    /// A listen port is present but no upstream follows it.
    MissingUpstream,
    /// The upstream has no `:port` part.
    MissingUpstreamPort,
    /// The protocol token is not an accepted keyword (as reported by the parser,
    /// the token lower-cased).
    InvalidProtocol(String),
    /// The listen token is not a port or a `from-to` port range.
    InvalidListenPort(&'a str),
    /// The listen range starts after it ends.
    InvalidListenPortRange(&'a str),
    /// The upstream host part is malformed; holds the whole upstream token.
    InvalidUpstream(&'a str),
    /// The upstream port is not a port or a `from-to` port range; holds the
    /// whole upstream token.
    InvalidUpstreamPort(&'a str),
    /// The upstream port range starts after it ends; holds the whole upstream token.
    InvalidUpstreamPortRange(&'a str),
    /// The listen range and the upstream range contain different numbers of ports.
    UnmatchedPortRange(RangeInclusive<u16>, RangeInclusive<u16>),
}

impl Display for MappingRuleParseError<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Empty => f.write_str("empty rule"),
            Self::MissingListenPort => f.write_str("missing listen port"),
            Self::MissingUpstream => f.write_str("missing upstream"),
            Self::MissingUpstreamPort => f.write_str("missing upstream port"),
            Self::InvalidProtocol(protocol) => write!(f, "invalid protocol: {protocol}"),
            Self::InvalidListenPort(port) => write!(f, "invalid listen port: {port}"),
            Self::InvalidListenPortRange(range) => {
                write!(f, "invalid listen port range: {range}")
            }
            Self::InvalidUpstream(upstream) => write!(f, "invalid upstream: {upstream}"),
            Self::InvalidUpstreamPort(port) => write!(f, "invalid upstream port: {port}"),
            Self::InvalidUpstreamPortRange(range) => {
                write!(f, "invalid upstream port range: {range}")
            }
            Self::UnmatchedPortRange(listen, upstream) => write!(
                f,
                "unmatched port range: {}-{} -> {}-{}",
                listen.start(),
                listen.end(),
                upstream.start(),
                upstream.end()
            ),
        }
    }
}

impl std::error::Error for MappingRuleParseError<'_> {}
52
53impl<'a> MappingRuleRaw<'a> {
54 pub fn parse(line: &'a str) -> Result<Self, MappingRuleParseError<'a>> {
55 let mut parts = line
57 .split('#')
58 .next()
59 .ok_or(MappingRuleParseError::Empty)?
60 .split_whitespace();
61 let protocol = match parts
63 .next()
64 .ok_or(MappingRuleParseError::Empty)?
65 .to_lowercase()
66 .as_str()
67 {
68 "udp" => ProtocolRaw::Udp,
69 "tcp" => ProtocolRaw::Tcp,
70 "t+u" => ProtocolRaw::TcpUdp,
71 input => {
72 return Err(MappingRuleParseError::InvalidProtocol(input.to_string()));
73 }
74 };
75 let listen = parts
77 .next()
78 .ok_or(MappingRuleParseError::MissingListenPort)?;
79 let upstream = parts.next().ok_or(MappingRuleParseError::MissingUpstream)?;
81 let mut listen_parts = listen.splitn(2, '-');
83 let listen_from: u16 = listen_parts
84 .next()
85 .ok_or(MappingRuleParseError::InvalidListenPort(listen))?
86 .parse()
87 .map_err(|_| MappingRuleParseError::InvalidListenPort(listen))?;
88 let listen_to: u16 = listen_parts
89 .next()
90 .map(|s| s.parse())
91 .unwrap_or(Ok(listen_from))
92 .map_err(|_| MappingRuleParseError::InvalidListenPort(listen))?;
93 if listen_from > listen_to {
94 return Err(MappingRuleParseError::InvalidListenPortRange(listen));
95 }
96 let mut upstream_parts = upstream.splitn(2, ':');
98 let upstream_host = {
99 let t = upstream_parts
100 .next()
101 .ok_or(MappingRuleParseError::InvalidUpstream(upstream))?;
102 if t.is_empty() { "localhost" } else { t }
103 };
104 let mut upstream_port_parts = upstream_parts
105 .next()
106 .ok_or(MappingRuleParseError::MissingUpstreamPort)?
107 .splitn(2, '-');
108 let upstream_port_from: u16 = upstream_port_parts
109 .next()
110 .ok_or(MappingRuleParseError::InvalidUpstreamPort(upstream))?
111 .parse()
112 .map_err(|_| MappingRuleParseError::InvalidUpstreamPort(upstream))?;
113 let upstream_port_to: u16 = upstream_port_parts
114 .next()
115 .map(|s| s.parse())
116 .unwrap_or(Ok(upstream_port_from))
117 .map_err(|_| MappingRuleParseError::InvalidUpstreamPort(upstream))?;
118 if upstream_port_from > upstream_port_to {
119 return Err(MappingRuleParseError::InvalidUpstreamPortRange(upstream));
120 }
121 let listen_port = listen_from..=listen_to;
122 let upstream_port = upstream_port_from..=upstream_port_to;
123 if upstream_port_to - upstream_port_from != listen_to - listen_from {
124 return Err(MappingRuleParseError::UnmatchedPortRange(
125 listen_port,
126 upstream_port,
127 ));
128 }
129 Ok(Self {
130 protocol,
131 listen_port,
132 upstream_host,
133 upstream_port,
134 })
135 }
136}
137
138pub async fn read_mapping_file<T: Unpin + AsyncRead>(
139 mut reader: BufReader<T>,
140) -> std::io::Result<Vec<MappingRule>> {
141 let mut rules = HashMap::new();
142 let mut line = String::new();
143 while reader.read_line(&mut line).await? != 0 {
144 line = line.trim().to_string();
145 match MappingRuleRaw::parse(&line) {
146 Ok(entry) => {
147 match entry.protocol {
148 ProtocolRaw::Tcp => {
149 let upstream_port_from = entry.upstream_port.start();
150 for (i, port) in entry.listen_port.enumerate() {
151 if rules.contains_key(&(Protocol::Tcp, port)) {
152 eprintln!("[warning][tcp] Port {port} will be overwritten")
153 }
154 rules.insert(
155 (Protocol::Tcp, port),
156 (
157 entry.upstream_host.to_string(),
158 upstream_port_from + i as u16,
159 ),
160 );
161 }
162 }
163 ProtocolRaw::Udp => {
164 let upstream_port_from = entry.upstream_port.start();
165 for (i, port) in entry.listen_port.enumerate() {
166 if rules.contains_key(&(Protocol::Udp, port)) {
167 eprintln!("[warning][udp] Port {port} will be overwritten")
168 }
169 rules.insert(
170 (Protocol::Udp, port),
171 (
172 entry.upstream_host.to_string(),
173 upstream_port_from + i as u16,
174 ),
175 );
176 }
177 }
178 ProtocolRaw::TcpUdp => {
179 let upstream_port_from = entry.upstream_port.start();
180 for (i, port) in entry.listen_port.enumerate() {
181 if rules.contains_key(&(Protocol::Tcp, port)) {
182 eprintln!("[warning][tcp] Port {port} will be overwritten")
183 }
184 rules.insert(
185 (Protocol::Tcp, port),
186 (
187 entry.upstream_host.to_string(),
188 upstream_port_from + i as u16,
189 ),
190 );
191 if rules.contains_key(&(Protocol::Udp, port)) {
192 eprintln!("[warning][udp] Port {port} will be overwritten")
193 }
194 rules.insert(
195 (Protocol::Udp, port),
196 (
197 entry.upstream_host.to_string(),
198 upstream_port_from + i as u16,
199 ),
200 );
201 }
202 }
203 };
204 }
205 Err(e) => match e {
206 MappingRuleParseError::Empty => (),
207 MappingRuleParseError::MissingListenPort => {
208 eprintln!("[warning][parse] Missing listen port: {line}")
209 }
210 MappingRuleParseError::MissingUpstream => {
211 eprintln!("[warning][parse] Missing upstream: {line}")
212 }
213 MappingRuleParseError::MissingUpstreamPort => {
214 eprintln!("[warning][parse] Missing upstream port: {line}")
215 }
216 MappingRuleParseError::InvalidProtocol(protocol) => {
217 eprintln!("[warning][parse] Invalid protocol: {protocol} in {line}")
218 }
219 MappingRuleParseError::InvalidListenPort(port) => {
220 eprintln!("[warning][parse] Invalid listen port: {port} in {line}")
221 }
222 MappingRuleParseError::InvalidListenPortRange(range) => {
223 eprintln!("[warning][parse] Invalid listen port range: {range} in {line}")
224 }
225 MappingRuleParseError::InvalidUpstream(upstream) => {
226 eprintln!("[warning][parse] Invalid upstream: {upstream} in {line}")
227 }
228 MappingRuleParseError::InvalidUpstreamPort(port) => {
229 eprintln!("[warning][parse] Invalid upstream port: {port} in {line}")
230 }
231 MappingRuleParseError::InvalidUpstreamPortRange(range) => {
232 eprintln!("[warning][parse] Invalid upstream port range: {range} in {line}")
233 }
234 MappingRuleParseError::UnmatchedPortRange(
235 listen_port_range,
236 upstream_port_range,
237 ) => {
238 eprintln!(
239 "[warning][parse] Unmatched port range: {}-{} -> {}-{} in {line}",
240 listen_port_range.start(),
241 listen_port_range.end(),
242 upstream_port_range.start(),
243 upstream_port_range.end()
244 )
245 }
246 },
247 }
248 line.clear();
249 }
250 Ok(rules
251 .into_iter()
252 .map(
253 |((protocol, listen), (upstream_host, upstream_port))| MappingRule {
254 protocol,
255 listen: format!("0.0.0.0:{listen}"),
256 upstream: format!("{upstream_host}:{upstream_port}"),
257 },
258 )
259 .collect())
260}