1use crate::command::AuthAction;
4use crate::protocol::{AUTH_ACCEPTED, AUTH_REQUIRED};
5use tokio::io::AsyncWriteExt;
6
/// Configured login credentials for AUTHINFO-style authentication.
///
/// `Debug` is intentionally not derived so the password cannot be printed via
/// `{:?}` formatting.
#[derive(Clone)]
struct Credentials {
    username: String,
    password: String,
}

impl Credentials {
    /// Wraps an owned username/password pair.
    fn new(username: String, password: String) -> Self {
        Self { username, password }
    }

    /// Returns `true` when both `username` and `password` exactly match the
    /// stored values (byte-wise, case-sensitive).
    ///
    /// Both fields are always compared, without short-circuiting, and neither
    /// comparison stops at the first differing byte. This way response timing
    /// reveals less about how much of a guess was correct.
    fn validate(&self, username: &str, password: &str) -> bool {
        let user_ok = Self::constant_time_eq(self.username.as_bytes(), username.as_bytes());
        let pass_ok = Self::constant_time_eq(self.password.as_bytes(), password.as_bytes());
        // Non-short-circuiting `&` so the password check runs even when the
        // username is wrong.
        user_ok & pass_ok
    }

    /// Byte-wise equality that does not exit early on the first difference.
    ///
    /// Best-effort only. The loop runs over the longer of the two inputs, so
    /// lengths are not hidden, and the optimizer gives no formal constant-time
    /// guarantee.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        let len = a.len().max(b.len());
        // Seed with the length difference so inputs of different length can
        // never compare equal, even if one is a prefix of the other.
        let mut diff = a.len() ^ b.len();
        for i in 0..len {
            let x = a.get(i).copied().unwrap_or(0);
            let y = b.get(i).copied().unwrap_or(0);
            diff |= usize::from(x ^ y);
        }
        diff == 0
    }
}
23
24#[derive(Default)]
26pub struct AuthHandler {
27 credentials: Option<Credentials>,
28}
29
30impl std::fmt::Debug for AuthHandler {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("AuthHandler")
33 .field("enabled", &self.credentials.is_some())
34 .finish_non_exhaustive()
35 }
36}
37
38impl AuthHandler {
39 pub fn new(username: Option<String>, password: Option<String>) -> Self {
41 let credentials = match (username, password) {
42 (Some(u), Some(p)) => Some(Credentials::new(u, p)),
43 _ => None,
44 };
45 Self { credentials }
46 }
47
48 #[inline]
50 pub const fn is_enabled(&self) -> bool {
51 self.credentials.is_some()
52 }
53
54 pub fn validate(&self, username: &str, password: &str) -> bool {
56 match &self.credentials {
57 None => true, Some(creds) => creds.validate(username, password),
59 }
60 }
61
62 pub async fn handle_auth_command<W>(
65 &self,
66 auth_action: AuthAction,
67 writer: &mut W,
68 stored_username: Option<&str>,
69 ) -> std::io::Result<(usize, bool)>
70 where
71 W: AsyncWriteExt + Unpin,
72 {
73 match auth_action {
74 AuthAction::RequestPassword(_username) => {
75 writer.write_all(AUTH_REQUIRED).await?;
77 Ok((AUTH_REQUIRED.len(), false))
78 }
79 AuthAction::ValidateAndRespond { password } => {
80 let auth_success = if let Some(username) = stored_username {
82 self.validate(username, &password)
83 } else {
84 false
86 };
87
88 let response = if auth_success {
89 AUTH_ACCEPTED
90 } else {
91 b"481 Authentication failed\r\n" as &[u8]
92 };
93 writer.write_all(response).await?;
94 Ok((response.len(), auth_success))
95 }
96 }
97 }
98
99 #[inline]
101 pub const fn user_response(&self) -> &'static [u8] {
102 AUTH_REQUIRED
103 }
104
105 #[inline]
107 pub const fn pass_response(&self) -> &'static [u8] {
108 AUTH_ACCEPTED
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handler with authentication disabled. The response tests use
    /// it because they do not depend on configured credentials.
    fn test_handler() -> AuthHandler {
        AuthHandler::new(None, None)
    }

    mod credentials {
        use super::*;

        #[test]
        fn test_new() {
            let creds = Credentials::new("user".to_string(), "pass".to_string());
            assert_eq!(creds.username, "user");
            assert_eq!(creds.password, "pass");
        }

        #[test]
        fn test_validate_correct() {
            let creds = Credentials::new("alice".to_string(), "secret123".to_string());
            assert!(creds.validate("alice", "secret123"));
        }

        #[test]
        fn test_validate_wrong_username() {
            let creds = Credentials::new("alice".to_string(), "secret123".to_string());
            assert!(!creds.validate("bob", "secret123"));
        }

        #[test]
        fn test_validate_wrong_password() {
            let creds = Credentials::new("alice".to_string(), "secret123".to_string());
            assert!(!creds.validate("alice", "wrong"));
        }

        #[test]
        fn test_validate_both_wrong() {
            let creds = Credentials::new("alice".to_string(), "secret123".to_string());
            assert!(!creds.validate("bob", "wrong"));
        }

        #[test]
        fn test_validate_empty_strings() {
            let creds = Credentials::new("".to_string(), "".to_string());
            assert!(creds.validate("", ""));
            assert!(!creds.validate("user", ""));
            assert!(!creds.validate("", "pass"));
        }

        #[test]
        fn test_validate_case_sensitive() {
            let creds = Credentials::new("Alice".to_string(), "Secret".to_string());
            assert!(creds.validate("Alice", "Secret"));
            assert!(!creds.validate("alice", "Secret"));
            assert!(!creds.validate("Alice", "secret"));
            assert!(!creds.validate("alice", "secret"));
        }

        #[test]
        fn test_validate_with_spaces() {
            let creds = Credentials::new("user name".to_string(), "pass word".to_string());
            assert!(creds.validate("user name", "pass word"));
            assert!(!creds.validate("username", "password"));
        }

        #[test]
        fn test_validate_unicode() {
            let creds = Credentials::new("用户".to_string(), "密码".to_string());
            assert!(creds.validate("用户", "密码"));
            assert!(!creds.validate("user", "pass"));
        }

        #[test]
        fn test_validate_special_chars() {
            let creds =
                Credentials::new("user@example.com".to_string(), "p@ss!w0rd#123".to_string());
            assert!(creds.validate("user@example.com", "p@ss!w0rd#123"));
            assert!(!creds.validate("user", "p@ss!w0rd#123"));
        }

        #[test]
        fn test_validate_prefix_and_extension_rejected() {
            // Inputs that share a prefix with the stored value must not match.
            // This guards against length-handling mistakes in the comparison.
            let creds = Credentials::new("alice".to_string(), "secret".to_string());
            assert!(!creds.validate("alice", "secre"));
            assert!(!creds.validate("alice", "secret1"));
            assert!(!creds.validate("alice", ""));
            assert!(!creds.validate("alic", "secret"));
            assert!(!creds.validate("alicex", "secret"));
        }

        #[test]
        fn test_clone() {
            let creds1 = Credentials::new("user".to_string(), "pass".to_string());
            let creds2 = creds1.clone();
            assert_eq!(creds1.username, creds2.username);
            assert_eq!(creds1.password, creds2.password);
            assert!(creds2.validate("user", "pass"));
        }

        #[test]
        fn test_very_long_credentials() {
            let long_user = "u".repeat(1000);
            let long_pass = "p".repeat(1000);
            let creds = Credentials::new(long_user.clone(), long_pass.clone());
            assert!(creds.validate(&long_user, &long_pass));
            assert!(!creds.validate(&long_user, "wrong"));
        }
    }

    mod auth_handler {
        use super::*;

        #[test]
        fn test_default() {
            let handler = AuthHandler::default();
            assert!(!handler.is_enabled());
        }

        #[test]
        fn test_new_with_both_credentials() {
            let handler = AuthHandler::new(Some("user".to_string()), Some("pass".to_string()));
            assert!(handler.is_enabled());
        }

        #[test]
        fn test_new_with_only_username() {
            let handler = AuthHandler::new(Some("user".to_string()), None);
            assert!(!handler.is_enabled());
        }

        #[test]
        fn test_new_with_only_password() {
            let handler = AuthHandler::new(None, Some("pass".to_string()));
            assert!(!handler.is_enabled());
        }

        #[test]
        fn test_new_with_neither() {
            let handler = AuthHandler::new(None, None);
            assert!(!handler.is_enabled());
        }

        #[test]
        fn test_validate_when_disabled() {
            let handler = AuthHandler::default();
            assert!(handler.validate("any", "thing"));
            assert!(handler.validate("", ""));
            assert!(handler.validate("foo", "bar"));
        }

        #[test]
        fn test_validate_when_enabled() {
            let handler = AuthHandler::new(Some("alice".to_string()), Some("secret".to_string()));
            assert!(handler.validate("alice", "secret"));
            assert!(!handler.validate("alice", "wrong"));
            assert!(!handler.validate("bob", "secret"));
            assert!(!handler.validate("bob", "wrong"));
        }

        #[test]
        fn test_is_enabled_consistent() {
            // Repeated calls must return the same answer.
            let disabled = AuthHandler::new(None, None);
            assert!(!disabled.is_enabled());
            assert!(!disabled.is_enabled());

            let enabled = AuthHandler::new(Some("u".to_string()), Some("p".to_string()));
            assert!(enabled.is_enabled());
            assert!(enabled.is_enabled());
        }
    }

    #[test]
    fn test_user_response() {
        let handler = test_handler();
        let response = handler.user_response();
        let response_str = String::from_utf8_lossy(response);

        assert!(response_str.starts_with("381"));
        assert!(response_str.contains("Password required") || response_str.contains("password"));
        assert!(response_str.ends_with("\r\n"));
    }

    #[test]
    fn test_pass_response() {
        let handler = test_handler();
        let response = handler.pass_response();
        let response_str = String::from_utf8_lossy(response);

        assert!(response_str.starts_with("281"));
        assert!(response_str.contains("accepted") || response_str.contains("Authentication"));
        assert!(response_str.ends_with("\r\n"));
    }

    #[test]
    fn test_responses_are_static() {
        // The same static buffer is returned on every call; nothing is allocated.
        let handler = test_handler();
        let response1 = handler.user_response();
        let response2 = handler.user_response();
        assert_eq!(response1.as_ptr(), response2.as_ptr());

        let response3 = handler.pass_response();
        let response4 = handler.pass_response();
        assert_eq!(response3.as_ptr(), response4.as_ptr());
    }

    #[test]
    fn test_responses_are_different() {
        let handler = test_handler();
        let user_resp = handler.user_response();
        let pass_resp = handler.pass_response();
        assert_ne!(user_resp, pass_resp);
    }

    #[test]
    fn test_responses_are_valid_utf8() {
        let handler = test_handler();
        let user_resp = handler.user_response();
        assert!(std::str::from_utf8(user_resp).is_ok());

        let pass_resp = handler.pass_response();
        assert!(std::str::from_utf8(pass_resp).is_ok());
    }

    #[test]
    fn test_auth_disabled_by_default() {
        let handler = AuthHandler::default();
        assert!(!handler.is_enabled());
        assert!(handler.validate("any", "thing"));
    }

    #[test]
    fn test_auth_new_none_none() {
        let handler = AuthHandler::new(None, None);
        assert!(!handler.is_enabled());
        assert!(handler.validate("any", "thing"));
    }

    #[test]
    fn test_auth_enabled_with_credentials() {
        let handler = AuthHandler::new(Some("mjc".to_string()), Some("nntp1337".to_string()));
        assert!(handler.is_enabled());
        assert!(handler.validate("mjc", "nntp1337"));
        assert!(!handler.validate("mjc", "wrong"));
        assert!(!handler.validate("wrong", "nntp1337"));
    }
}