1use crate::error::{ProxyError, Result};
9use subtle::ConstantTimeEq;
10use tracing::warn;
11use zeroize::Zeroizing;
12
/// Number of random bytes in a session token (32 bytes = 256 bits).
/// `generate_session_token` hex-encodes them into `2 * TOKEN_BYTES` characters.
const TOKEN_BYTES: usize = 256 / 8;
15
16pub fn generate_session_token() -> Result<Zeroizing<String>> {
22 let mut bytes = [0u8; TOKEN_BYTES];
23 getrandom::fill(&mut bytes).map_err(|e| ProxyError::Config(format!("RNG failure: {}", e)))?;
24 let hex = hex_encode(&bytes);
25 bytes.fill(0);
27 Ok(Zeroizing::new(hex))
28}
29
30#[must_use]
36pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
37 if a.len() != b.len() {
38 return false;
39 }
40 a.ct_eq(b).into()
41}
42
/// Lowercase hexadecimal digits, indexed by nibble value (`0..=15`).
const HEX_CHARS: [char; 16] = [
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
];

/// Encodes `bytes` as lowercase hex: two characters per byte, high nibble first.
///
/// The output is allocated at its final size up front, so it is never
/// reallocated. This matters because `generate_session_token` wraps the result
/// in `Zeroizing`, which can only wipe the final allocation.
fn hex_encode(bytes: &[u8]) -> String {
    let mut out = String::with_capacity(bytes.len().saturating_mul(2));
    for &byte in bytes {
        let high = usize::from(byte >> 4);
        let low = usize::from(byte & 0x0f);
        out.extend([HEX_CHARS[high], HEX_CHARS[low]]);
    }
    out
}
56
57pub fn validate_proxy_auth(header_bytes: &[u8], session_token: &Zeroizing<String>) -> Result<()> {
65 let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
66
67 const BEARER_PREFIX: &str = "proxy-authorization: bearer ";
68 const BASIC_PREFIX: &str = "proxy-authorization: basic ";
69
70 for line in header_str.lines() {
71 let lower = line.to_lowercase();
72 if lower.starts_with(BEARER_PREFIX) {
73 let value = line[BEARER_PREFIX.len()..].trim();
74 if constant_time_eq(value.as_bytes(), session_token.as_bytes()) {
75 return Ok(());
76 }
77 warn!("Invalid proxy authorization token (Bearer)");
78 return Err(ProxyError::InvalidToken);
79 }
80 if lower.starts_with(BASIC_PREFIX) {
81 let encoded = line[BASIC_PREFIX.len()..].trim();
82 return validate_basic_auth(encoded, session_token);
83 }
84 }
85
86 warn!("Missing Proxy-Authorization header");
87 Err(ProxyError::InvalidToken)
88}
89
90fn validate_basic_auth(encoded: &str, session_token: &Zeroizing<String>) -> Result<()> {
95 use base64::engine::general_purpose::STANDARD;
96 use base64::Engine;
97
98 let decoded = STANDARD
99 .decode(encoded)
100 .map_err(|_| ProxyError::InvalidToken)?;
101 let decoded_str = std::str::from_utf8(&decoded).map_err(|_| ProxyError::InvalidToken)?;
102
103 let password = match decoded_str.split_once(':') {
104 Some((_, pw)) => pw,
105 None => {
106 warn!("Malformed Basic auth (no colon separator)");
107 return Err(ProxyError::InvalidToken);
108 }
109 };
110
111 if constant_time_eq(password.as_bytes(), session_token.as_bytes()) {
112 Ok(())
113 } else {
114 warn!("Invalid proxy authorization token (Basic)");
115 Err(ProxyError::InvalidToken)
116 }
117}
118
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Tests unwrap freely: a panic is simply a test failure.
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD;
    use base64::Engine;

    /// Builds a header block carrying `Basic` credentials for `user_pass`.
    fn basic_header(user_pass: &str) -> String {
        format!(
            "Proxy-Authorization: Basic {}\r\n\r\n",
            STANDARD.encode(user_pass)
        )
    }

    #[test]
    fn test_generate_token_length() {
        let token = generate_session_token().unwrap();
        assert_eq!(token.len(), 64);
    }

    #[test]
    fn test_generate_token_is_hex() {
        let token = generate_session_token().unwrap();
        assert!(token.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_generate_token_is_lowercase() {
        let token = generate_session_token().unwrap();
        assert!(token.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')));
    }

    #[test]
    fn test_generate_token_unique() {
        let t1 = generate_session_token().unwrap();
        let t2 = generate_session_token().unwrap();
        assert_ne!(*t1, *t2);
    }

    #[test]
    fn test_hex_encode_known_values() {
        assert_eq!(hex_encode(&[]), "");
        assert_eq!(hex_encode(&[0x00, 0x0f, 0xa5, 0xff]), "000fa5ff");
    }

    #[test]
    fn test_constant_time_eq_same() {
        let a = b"hello";
        let b = b"hello";
        assert!(constant_time_eq(a, b));
    }

    #[test]
    fn test_constant_time_eq_different() {
        let a = b"hello";
        let b = b"world";
        assert!(!constant_time_eq(a, b));
    }

    #[test]
    fn test_constant_time_eq_different_length() {
        let a = b"hello";
        let b = b"hi";
        assert!(!constant_time_eq(a, b));
    }

    #[test]
    fn test_constant_time_eq_empty() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn test_validate_proxy_auth_bearer() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Bearer abc123\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_bearer_case_insensitive() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"proxy-authorization: BEARER abc123\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_bearer_extra_whitespace() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Bearer   abc123  \r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_bearer_after_other_headers() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"CONNECT example.com:443 HTTP/1.1\r\n\
                       Host: example.com:443\r\n\
                       Proxy-Authorization: Bearer abc123\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_bearer_invalid() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_first_header_wins() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Bearer wrong\r\n\
                       Proxy-Authorization: Bearer abc123\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_basic() {
        let token = Zeroizing::new("abc123".to_string());
        let header = basic_header("nono:abc123");
        assert!(validate_proxy_auth(header.as_bytes(), &token).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_basic_wrong_password() {
        let token = Zeroizing::new("abc123".to_string());
        let header = basic_header("nono:wrong");
        assert!(validate_proxy_auth(header.as_bytes(), &token).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_basic_any_username() {
        let token = Zeroizing::new("abc123".to_string());
        let header = basic_header("whatever:abc123");
        assert!(validate_proxy_auth(header.as_bytes(), &token).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_basic_password_with_colon() {
        let token = Zeroizing::new("abc:123".to_string());
        let header = basic_header("nono:abc:123");
        assert!(validate_proxy_auth(header.as_bytes(), &token).is_ok());
    }

    #[test]
    fn test_validate_proxy_auth_basic_no_colon() {
        let token = Zeroizing::new("abc123".to_string());
        let header = basic_header("abc123");
        assert!(validate_proxy_auth(header.as_bytes(), &token).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_basic_invalid_base64() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Basic !!!not-base64!!!\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_unknown_scheme() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Digest abc123\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_non_utf8() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Proxy-Authorization: Bearer abc123\xff\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }

    #[test]
    fn test_validate_proxy_auth_missing() {
        let token = Zeroizing::new("abc123".to_string());
        let header = b"Host: example.com\r\n\r\n";
        assert!(validate_proxy_auth(header, &token).is_err());
    }
}