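// nntp_proxy/auth/handler.rs

//! Client authentication handling
//!
//! Intercepts client `AUTHINFO USER` / `AUTHINFO PASS` exchanges and checks them
//! against the proxy's own configured credentials.
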
use tokio::io::{AsyncWrite, AsyncWriteExt};

use crate::command::AuthAction;
use crate::protocol::{AUTH_ACCEPTED, AUTH_REQUIRED};

/// Username/password pair that a client must present to authenticate.
///
/// Only `Clone` is derived. There is no `Debug` impl, so the stored password
/// cannot leak through `{:?}` formatting.
#[derive(Clone)]
struct Credentials {
    /// Expected `AUTHINFO USER` value (matched case-sensitively).
    username: String,
    /// Expected `AUTHINFO PASS` value (matched case-sensitively).
    password: String,
}

14impl Credentials {
15    fn new(username: String, password: String) -> Self {
16        Self { username, password }
17    }
18
19    fn validate(&self, username: &str, password: &str) -> bool {
20        self.username == username && self.password == password
21    }
22}
23
24/// Handles client-facing authentication interception
25#[derive(Default)]
26pub struct AuthHandler {
27    credentials: Option<Credentials>,
28}
29
30impl std::fmt::Debug for AuthHandler {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("AuthHandler")
33            .field("enabled", &self.credentials.is_some())
34            .finish_non_exhaustive()
35    }
36}
37
38impl AuthHandler {
39    /// Create a new auth handler with optional credentials
40    pub fn new(username: Option<String>, password: Option<String>) -> Self {
41        let credentials = match (username, password) {
42            (Some(u), Some(p)) => Some(Credentials::new(u, p)),
43            _ => None,
44        };
45        Self { credentials }
46    }
47
48    /// Check if authentication is enabled
49    #[inline]
50    pub const fn is_enabled(&self) -> bool {
51        self.credentials.is_some()
52    }
53
54    /// Validate credentials
55    pub fn validate(&self, username: &str, password: &str) -> bool {
56        match &self.credentials {
57            None => true, // Auth disabled, always accept
58            Some(creds) => creds.validate(username, password),
59        }
60    }
61
62    /// Handle an auth command - writes response to client and returns (bytes_written, auth_success)
63    /// This is the ONE place where auth interception happens
64    pub async fn handle_auth_command<W>(
65        &self,
66        auth_action: AuthAction,
67        writer: &mut W,
68        stored_username: Option<&str>,
69    ) -> std::io::Result<(usize, bool)>
70    where
71        W: AsyncWriteExt + Unpin,
72    {
73        match auth_action {
74            AuthAction::RequestPassword(_username) => {
75                // Always respond with password required
76                writer.write_all(AUTH_REQUIRED).await?;
77                Ok((AUTH_REQUIRED.len(), false))
78            }
79            AuthAction::ValidateAndRespond { password } => {
80                // Validate credentials
81                let auth_success = if let Some(username) = stored_username {
82                    self.validate(username, &password)
83                } else {
84                    // No username was stored (client sent AUTHINFO PASS without USER)
85                    false
86                };
87
88                let response = if auth_success {
89                    AUTH_ACCEPTED
90                } else {
91                    b"481 Authentication failed\r\n" as &[u8]
92                };
93                writer.write_all(response).await?;
94                Ok((response.len(), auth_success))
95            }
96        }
97    }
98
99    /// Get the AUTHINFO USER response
100    #[inline]
101    pub const fn user_response(&self) -> &'static [u8] {
102        AUTH_REQUIRED
103    }
104
105    /// Get the AUTHINFO PASS response
106    #[inline]
107    pub const fn pass_response(&self) -> &'static [u8] {
108        AUTH_ACCEPTED
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture shared by the response tests: a handler with auth disabled.
    fn test_handler() -> AuthHandler {
        AuthHandler::default()
    }

    mod credentials {
        use super::*;

        /// Build `Credentials` from string slices.
        fn creds(user: &str, pass: &str) -> Credentials {
            Credentials::new(user.to_owned(), pass.to_owned())
        }

        /// Check `validate` against a table of `(username, password, expected)` rows.
        fn assert_cases(subject: &Credentials, rows: &[(&str, &str, bool)]) {
            for &(user, pass, expected) in rows {
                assert_eq!(subject.validate(user, pass), expected);
            }
        }

        #[test]
        fn test_new() {
            let c = creds("user", "pass");
            assert_eq!(c.username, "user");
            assert_eq!(c.password, "pass");
        }

        #[test]
        fn test_validate_correct() {
            assert_cases(&creds("alice", "secret123"), &[("alice", "secret123", true)]);
        }

        #[test]
        fn test_validate_wrong_username() {
            assert_cases(&creds("alice", "secret123"), &[("bob", "secret123", false)]);
        }

        #[test]
        fn test_validate_wrong_password() {
            assert_cases(&creds("alice", "secret123"), &[("alice", "wrong", false)]);
        }

        #[test]
        fn test_validate_both_wrong() {
            assert_cases(&creds("alice", "secret123"), &[("bob", "wrong", false)]);
        }

        #[test]
        fn test_validate_empty_strings() {
            assert_cases(
                &creds("", ""),
                &[("", "", true), ("user", "", false), ("", "pass", false)],
            );
        }

        #[test]
        fn test_validate_case_sensitive() {
            assert_cases(
                &creds("Alice", "Secret"),
                &[
                    ("Alice", "Secret", true),
                    ("alice", "Secret", false),
                    ("Alice", "secret", false),
                    ("alice", "secret", false),
                ],
            );
        }

        #[test]
        fn test_validate_with_spaces() {
            assert_cases(
                &creds("user name", "pass word"),
                &[("user name", "pass word", true), ("username", "password", false)],
            );
        }

        #[test]
        fn test_validate_unicode() {
            assert_cases(
                &creds("用户", "密码"),
                &[("用户", "密码", true), ("user", "pass", false)],
            );
        }

        #[test]
        fn test_validate_special_chars() {
            assert_cases(
                &creds("user@example.com", "p@ss!w0rd#123"),
                &[
                    ("user@example.com", "p@ss!w0rd#123", true),
                    ("user", "p@ss!w0rd#123", false),
                ],
            );
        }

        #[test]
        fn test_clone() {
            let original = creds("user", "pass");
            let copy = original.clone();
            assert_eq!(original.username, copy.username);
            assert_eq!(original.password, copy.password);
            assert!(copy.validate("user", "pass"));
        }

        #[test]
        fn test_very_long_credentials() {
            let user = "u".repeat(1000);
            let pass = "p".repeat(1000);
            let subject = creds(&user, &pass);
            assert_cases(
                &subject,
                &[(user.as_str(), pass.as_str(), true), (user.as_str(), "wrong", false)],
            );
        }
    }

    mod auth_handler {
        use super::*;

        /// Build a handler from optional string slices.
        fn handler_from(user: Option<&str>, pass: Option<&str>) -> AuthHandler {
            AuthHandler::new(user.map(String::from), pass.map(String::from))
        }

        #[test]
        fn test_default() {
            assert!(!AuthHandler::default().is_enabled());
        }

        #[test]
        fn test_new_with_both_credentials() {
            assert!(handler_from(Some("user"), Some("pass")).is_enabled());
        }

        #[test]
        fn test_new_with_only_username() {
            assert!(!handler_from(Some("user"), None).is_enabled());
        }

        #[test]
        fn test_new_with_only_password() {
            assert!(!handler_from(None, Some("pass")).is_enabled());
        }

        #[test]
        fn test_new_with_neither() {
            assert!(!handler_from(None, None).is_enabled());
        }

        #[test]
        fn test_validate_when_disabled() {
            let handler = AuthHandler::default();
            for &(user, pass) in &[("any", "thing"), ("", ""), ("foo", "bar")] {
                assert!(handler.validate(user, pass));
            }
        }

        #[test]
        fn test_validate_when_enabled() {
            let handler = handler_from(Some("alice"), Some("secret"));
            let rows = [
                ("alice", "secret", true),
                ("alice", "wrong", false),
                ("bob", "secret", false),
                ("bob", "wrong", false),
            ];
            for &(user, pass, expected) in &rows {
                assert_eq!(handler.validate(user, pass), expected);
            }
        }

        #[test]
        fn test_is_enabled_consistent() {
            let disabled = handler_from(None, None);
            let enabled = handler_from(Some("u"), Some("p"));
            // Query each handler twice to make sure the answer does not drift.
            for _ in 0..2 {
                assert!(!disabled.is_enabled());
                assert!(enabled.is_enabled());
            }
        }
    }

    #[test]
    fn test_user_response() {
        let text = String::from_utf8_lossy(test_handler().user_response());

        // Should be 381 Password required
        assert!(text.starts_with("381"));
        assert!(text.contains("Password required") || text.contains("password"));
        assert!(text.ends_with("\r\n"));
    }

    #[test]
    fn test_pass_response() {
        let text = String::from_utf8_lossy(test_handler().pass_response());

        // Should be 281 Authentication accepted
        assert!(text.starts_with("281"));
        assert!(text.contains("accepted") || text.contains("Authentication"));
        assert!(text.ends_with("\r\n"));
    }

    #[test]
    fn test_responses_are_static() {
        // Repeated calls must hand back the very same static buffer.
        let handler = test_handler();
        let user_ptrs = (handler.user_response().as_ptr(), handler.user_response().as_ptr());
        assert_eq!(user_ptrs.0, user_ptrs.1);

        let pass_ptrs = (handler.pass_response().as_ptr(), handler.pass_response().as_ptr());
        assert_eq!(pass_ptrs.0, pass_ptrs.1);
    }

    #[test]
    fn test_responses_are_different() {
        // The USER and PASS replies must not be the same bytes.
        let handler = test_handler();
        assert_ne!(handler.user_response(), handler.pass_response());
    }

    #[test]
    fn test_responses_are_valid_utf8() {
        let handler = test_handler();
        for resp in [handler.user_response(), handler.pass_response()].iter() {
            assert!(std::str::from_utf8(resp).is_ok());
        }
    }

    #[test]
    fn test_auth_disabled_by_default() {
        let handler = AuthHandler::default();
        assert!(!handler.is_enabled());
        // With auth off, any pair is accepted.
        assert!(handler.validate("any", "thing"));
    }

    #[test]
    fn test_auth_new_none_none() {
        let handler = AuthHandler::new(None, None);
        assert!(!handler.is_enabled());
        assert!(handler.validate("any", "thing"));
    }

    #[test]
    fn test_auth_enabled_with_credentials() {
        let handler = AuthHandler::new(Some(String::from("mjc")), Some(String::from("nntp1337")));
        assert!(handler.is_enabled());
        let expectations = [
            ("mjc", "nntp1337", true),
            ("mjc", "wrong", false),
            ("wrong", "nntp1337", false),
        ];
        for &(user, pass, ok) in &expectations {
            assert_eq!(handler.validate(user, pass), ok);
        }
    }
}