totp_gateway/
utils.rs

1use regex::Regex;
2use std::fmt;
3use std::net::IpAddr;
4use std::str::FromStr;
5
/// Newtype around the string identifier of a session.
///
/// Equality and hashing compare the underlying string, so a `SessionId` can be
/// used directly as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SessionId(String);

impl SessionId {
    /// Wraps `id` as-is; no validation or normalisation is performed.
    pub fn new(id: String) -> Self {
        SessionId(id)
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl fmt::Display for SessionId {
    /// Writes the raw identifier text (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for SessionId {
    /// Equivalent to [`SessionId::new`].
    fn from(raw: String) -> Self {
        SessionId::new(raw)
    }
}
30
/// A `host` / `port` pair identifying an upstream server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpstreamAddr {
    /// Host component (e.g. a DNS name or an IP literal).
    pub host: String,
    /// Port number.
    pub port: u16,
}

impl UpstreamAddr {
    /// Builds an address from an already-separated host and port; no validation
    /// is performed on either value.
    pub fn new(host: String, port: u16) -> Self {
        UpstreamAddr { port, host }
    }
}
42
43impl FromStr for UpstreamAddr {
44    type Err = ParseError;
45
46    fn from_str(s: &str) -> Result<Self, Self::Err> {
47        let parts: Vec<&str> = s.split(':').collect();
48        match parts.as_slice() {
49            [host, port_str] => {
50                let port = port_str
51                    .parse::<u16>()
52                    .map_err(|_| ParseError::InvalidPort(port_str.to_string()))?;
53                Ok(Self {
54                    host: host.to_string(),
55                    port,
56                })
57            }
58            [host] => Ok(Self {
59                host: host.to_string(),
60                port: 80,
61            }),
62            _ => Err(ParseError::InvalidFormat(s.to_string())),
63        }
64    }
65}
66
67impl fmt::Display for UpstreamAddr {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        write!(f, "{}:{}", self.host, self.port)
70    }
71}
72
/// Errors produced while parsing textual configuration values.
#[derive(Debug, Clone)]
pub enum ParseError {
    /// The input did not have the expected structure; carries the offending text.
    InvalidFormat(String),
    /// A port component could not be parsed; carries the offending text.
    InvalidPort(String),
    /// A regular expression could not be built; carries the error description.
    InvalidRegex(String),
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidFormat(text) => write!(f, "Invalid format: {}", text),
            Self::InvalidPort(text) => write!(f, "Invalid port: {}", text),
            Self::InvalidRegex(text) => write!(f, "Invalid regex: {}", text),
        }
    }
}

impl std::error::Error for ParseError {}
91
/// Errors reported by the proxy layer.
#[derive(Debug, Clone)]
pub enum ProxyError {
    /// No client IP address was available.
    MissingClientIp,
    /// An upstream address was rejected; carries the offending text.
    InvalidUpstream(String),
    /// The configured TOTP secret was rejected.
    TotpSecretInvalid,
    /// A TOTP instance could not be created.
    TotpCreationFailed,
    /// The session table has no room for another entry.
    SessionTableFull,
    /// The IP limit table has no room for another entry.
    IpLimitTableFull,
}

impl fmt::Display for ProxyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except `InvalidUpstream` maps to a fixed message.
        let message = match self {
            Self::InvalidUpstream(addr) => {
                return write!(f, "Invalid upstream address: {}", addr);
            }
            Self::MissingClientIp => "Client IP address not found",
            Self::TotpSecretInvalid => "TOTP secret is invalid",
            Self::TotpCreationFailed => "Failed to create TOTP instance",
            Self::SessionTableFull => "Session table is full",
            Self::IpLimitTableFull => "IP limit table is full",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ProxyError {}
116
/// Newtype wrapper marking an [`IpAddr`] as the address of a client.
///
/// Like `IpAddr` itself, the wrapper is `Copy` and can be compared, hashed,
/// debug-printed and displayed, so it can be used directly as a map key or in
/// log output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClientIp(IpAddr);

impl ClientIp {
    /// Wraps `ip`.
    pub fn new(ip: IpAddr) -> Self {
        Self(ip)
    }

    /// Returns the wrapped address (a cheap copy; `IpAddr` is `Copy`).
    pub fn inner(&self) -> IpAddr {
        self.0
    }
}

impl From<IpAddr> for ClientIp {
    /// Equivalent to [`ClientIp::new`].
    fn from(ip: IpAddr) -> Self {
        Self(ip)
    }
}

impl fmt::Display for ClientIp {
    /// Delegates to the wrapped [`IpAddr`]'s `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
134
135pub fn glob_to_regex(pattern: &str) -> Result<Regex, ParseError> {
136    let mut regex_str = String::from("^");
137    for c in pattern.chars() {
138        match c {
139            '*' => regex_str.push_str(".*"),
140            '?' => regex_str.push('.'),
141            '.' | '+' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
142                regex_str.push('\\');
143                regex_str.push(c);
144            }
145            _ => regex_str.push(c),
146        }
147    }
148    regex_str.push('$');
149    Regex::new(&regex_str).map_err(|e| ParseError::InvalidRegex(e.to_string()))
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `input` parses successfully to exactly `host` and `port`.
    fn assert_upstream(input: &str, host: &str, port: u16) {
        let addr: UpstreamAddr = input.parse().unwrap();
        assert_eq!(addr.host, host, "host parsed from {:?}", input);
        assert_eq!(addr.port, port, "port parsed from {:?}", input);
    }

    /// Asserts that the glob `pattern` matches every entry of `accepted` and
    /// none of the entries of `rejected`.
    fn assert_glob(pattern: &str, accepted: &[&str], rejected: &[&str]) {
        let re = glob_to_regex(pattern).unwrap();
        for text in accepted {
            assert!(re.is_match(text), "{:?} should match {:?}", pattern, text);
        }
        for text in rejected {
            assert!(!re.is_match(text), "{:?} should not match {:?}", pattern, text);
        }
    }

    #[test]
    fn test_upstream_addr_parse() {
        assert_upstream("127.0.0.1:8080", "127.0.0.1", 8080);
        assert_upstream("example.com", "example.com", 80);

        for bad in ["invalid:port:extra", "host:abc"].iter() {
            assert!(bad.parse::<UpstreamAddr>().is_err(), "{:?} should be rejected", bad);
        }
    }

    #[test]
    fn test_glob_to_regex() {
        assert_glob(
            "*.example.com",
            &["sub.example.com", "a.example.com"],
            &["example.com"],
        );
        assert_glob("/api/*", &["/api/users", "/api/"], &["/api"]);
        assert_glob(
            "test?.com",
            &["test1.com", "testa.com"],
            &["test.com", "test12.com"],
        );
    }
}