1use regex::Regex;
2use std::fmt;
3use std::net::IpAddr;
4use std::str::FromStr;
5
/// Identifier of a client session.
///
/// A newtype over `String`, so a session id is a distinct type from other
/// strings even though it is stored as one.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SessionId(String);

impl SessionId {
    /// Wraps `id` as-is; no validation or normalisation is performed.
    pub fn new(id: String) -> Self {
        Self(id)
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for SessionId {
    /// Writes the raw identifier.
    ///
    /// Goes through `Formatter::pad` (the same path `str` uses) so that
    /// width, alignment and precision flags such as `{:>12}` or `{:.8}`
    /// are honoured instead of silently ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0)
    }
}

impl From<String> for SessionId {
    fn from(id: String) -> Self {
        Self::new(id)
    }
}
30
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct UpstreamAddr {
33 pub host: String,
34 pub port: u16,
35}
36
37impl UpstreamAddr {
38 pub fn new(host: String, port: u16) -> Self {
39 Self { host, port }
40 }
41}
42
43impl FromStr for UpstreamAddr {
44 type Err = ParseError;
45
46 fn from_str(s: &str) -> Result<Self, Self::Err> {
47 let parts: Vec<&str> = s.split(':').collect();
48 match parts.as_slice() {
49 [host, port_str] => {
50 let port = port_str
51 .parse::<u16>()
52 .map_err(|_| ParseError::InvalidPort(port_str.to_string()))?;
53 Ok(Self {
54 host: host.to_string(),
55 port,
56 })
57 }
58 [host] => Ok(Self {
59 host: host.to_string(),
60 port: 80,
61 }),
62 _ => Err(ParseError::InvalidFormat(s.to_string())),
63 }
64 }
65}
66
67impl fmt::Display for UpstreamAddr {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 write!(f, "{}:{}", self.host, self.port)
70 }
71}
72
/// Errors from parsing upstream addresses and compiling glob patterns.
#[derive(Debug, Clone)]
pub enum ParseError {
    /// The input did not have the expected overall shape; carries the input.
    InvalidFormat(String),
    /// The port component could not be parsed; carries the offending text.
    InvalidPort(String),
    /// A generated regular expression failed to compile; carries the reason.
    InvalidRegex(String),
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::InvalidFormat(ref input) => write!(f, "Invalid format: {}", input),
            Self::InvalidPort(ref port) => write!(f, "Invalid port: {}", port),
            Self::InvalidRegex(ref reason) => write!(f, "Invalid regex: {}", reason),
        }
    }
}

impl std::error::Error for ParseError {}
91
/// Errors reported by the proxy.
#[derive(Debug, Clone)]
pub enum ProxyError {
    /// No client IP address was found.
    MissingClientIp,
    /// An upstream address was rejected; carries the offending text.
    InvalidUpstream(String),
    /// The TOTP secret is invalid.
    TotpSecretInvalid,
    /// A TOTP instance could not be created.
    TotpCreationFailed,
    /// The session table has no room left.
    SessionTableFull,
    /// The IP limit table has no room left.
    IpLimitTableFull,
}

impl fmt::Display for ProxyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except `InvalidUpstream` has a fixed message.
        let message = match self {
            Self::InvalidUpstream(addr) => {
                return write!(f, "Invalid upstream address: {}", addr);
            }
            Self::MissingClientIp => "Client IP address not found",
            Self::TotpSecretInvalid => "TOTP secret is invalid",
            Self::TotpCreationFailed => "Failed to create TOTP instance",
            Self::SessionTableFull => "Session table is full",
            Self::IpLimitTableFull => "IP limit table is full",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ProxyError {}
116
/// Newtype wrapper around a client's [`IpAddr`].
///
/// Derives the same value-type traits as the wrapped `IpAddr`. It can
/// therefore be logged (`Debug`), copied freely, compared, and used as a
/// `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClientIp(IpAddr);

impl ClientIp {
    /// Wraps `ip`.
    pub fn new(ip: IpAddr) -> Self {
        Self(ip)
    }

    /// Returns the wrapped address by value (`IpAddr` is `Copy`).
    pub fn inner(&self) -> IpAddr {
        self.0
    }
}

impl From<IpAddr> for ClientIp {
    fn from(ip: IpAddr) -> Self {
        Self::new(ip)
    }
}
134
135pub fn glob_to_regex(pattern: &str) -> Result<Regex, ParseError> {
136 let mut regex_str = String::from("^");
137 for c in pattern.chars() {
138 match c {
139 '*' => regex_str.push_str(".*"),
140 '?' => regex_str.push('.'),
141 '.' | '+' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
142 regex_str.push('\\');
143 regex_str.push(c);
144 }
145 _ => regex_str.push(c),
146 }
147 }
148 regex_str.push('$');
149 Regex::new(®ex_str).map_err(|e| ParseError::InvalidRegex(e.to_string()))
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_upstream_addr_parse() {
        let explicit: UpstreamAddr = "127.0.0.1:8080".parse().unwrap();
        assert_eq!(explicit, UpstreamAddr::new("127.0.0.1".to_string(), 8080));

        let defaulted: UpstreamAddr = "example.com".parse().unwrap();
        assert_eq!(defaulted, UpstreamAddr::new("example.com".to_string(), 80));

        for bad in &["invalid:port:extra", "host:abc"] {
            assert!(
                bad.parse::<UpstreamAddr>().is_err(),
                "expected {:?} to be rejected",
                bad
            );
        }
    }

    /// Compiles `pattern` and checks it matches every input in `accepted`
    /// and none of the inputs in `rejected`.
    fn check_glob(pattern: &str, accepted: &[&str], rejected: &[&str]) {
        let re = glob_to_regex(pattern).unwrap();
        for input in accepted {
            assert!(re.is_match(input), "{:?} should match {:?}", pattern, input);
        }
        for input in rejected {
            assert!(!re.is_match(input), "{:?} should not match {:?}", pattern, input);
        }
    }

    #[test]
    fn test_glob_to_regex() {
        check_glob(
            "*.example.com",
            &["sub.example.com", "a.example.com"],
            &["example.com"],
        );
        check_glob("/api/*", &["/api/users", "/api/"], &["/api"]);
        check_glob(
            "test?.com",
            &["test1.com", "testa.com"],
            &["test.com", "test12.com"],
        );
    }
}