1use super::classifier::NntpCommand;
21
/// The action the proxy should take for a single client command line.
///
/// Produced by `CommandHandler::classify`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CommandAction<'a> {
    /// The command is part of the `AUTHINFO` exchange and is handled by the
    /// proxy itself; the payload describes which credential step it is.
    InterceptAuth(AuthAction<'a>),
    /// The command is refused. The payload is the complete response line to
    /// send to the client: status code, text, and trailing CRLF.
    Reject(&'static str),
    /// The command can be forwarded as-is, without per-connection state.
    ForwardStateless,
}

/// A credential step extracted from an `AUTHINFO` command line.
///
/// Payloads borrow from the command line they were parsed from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthAction<'a> {
    /// `AUTHINFO USER` was received; the payload is the supplied username
    /// (empty when no argument was given).
    RequestPassword(&'a str),
    /// `AUTHINFO PASS` was received; `password` is the supplied password
    /// (empty when no argument was given).
    ValidateAndRespond { password: &'a str },
}
43
/// Stateless classifier that maps raw NNTP command lines to [`CommandAction`]s.
///
/// The type carries no data; its functionality is exposed through associated
/// functions such as `CommandHandler::classify`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CommandHandler;
46
47impl CommandHandler {
48 pub fn classify(command: &str) -> CommandAction<'_> {
50 match NntpCommand::parse(command) {
51 NntpCommand::AuthUser => {
52 let username = command
54 .trim()
55 .strip_prefix("AUTHINFO USER")
56 .or_else(|| command.trim().strip_prefix("authinfo user"))
57 .unwrap_or("")
58 .trim();
59 CommandAction::InterceptAuth(AuthAction::RequestPassword(username))
60 }
61 NntpCommand::AuthPass => {
62 let password = command
64 .trim()
65 .strip_prefix("AUTHINFO PASS")
66 .or_else(|| command.trim().strip_prefix("authinfo pass"))
67 .unwrap_or("")
68 .trim();
69 CommandAction::InterceptAuth(AuthAction::ValidateAndRespond { password })
70 }
71 NntpCommand::Stateful => {
72 CommandAction::Reject("502 Command not implemented in stateless proxy mode\r\n")
75 }
76 NntpCommand::NonRoutable => {
77 CommandAction::Reject("502 Command not implemented in per-command routing mode\r\n")
80 }
81 NntpCommand::ArticleByMessageId => CommandAction::ForwardStateless,
82 NntpCommand::Stateless => CommandAction::ForwardStateless,
83 }
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_user_command() {
        let action = CommandHandler::classify("AUTHINFO USER test");
        assert!(matches!(
            action,
            CommandAction::InterceptAuth(AuthAction::RequestPassword(username)) if username == "test"
        ));
    }

    #[test]
    fn test_auth_pass_command() {
        let action = CommandHandler::classify("AUTHINFO PASS secret");
        assert!(matches!(
            action,
            CommandAction::InterceptAuth(AuthAction::ValidateAndRespond { password }) if password == "secret"
        ));
    }

    #[test]
    fn test_stateful_command_rejected() {
        let action = CommandHandler::classify("GROUP alt.test");
        assert!(
            matches!(action, CommandAction::Reject(msg) if msg.contains("stateless")),
            "Expected Reject with 'stateless' in message"
        );
    }

    #[test]
    fn test_article_by_message_id() {
        let action = CommandHandler::classify("ARTICLE <test@example.com>");
        assert_eq!(action, CommandAction::ForwardStateless);
    }

    #[test]
    fn test_stateless_command() {
        let action = CommandHandler::classify("LIST");
        assert_eq!(action, CommandAction::ForwardStateless);

        let action = CommandHandler::classify("HELP");
        assert_eq!(action, CommandAction::ForwardStateless);
    }

    #[test]
    fn test_all_stateful_commands_rejected() {
        let stateful_commands = [
            "GROUP alt.test",
            "NEXT",
            "LAST",
            "LISTGROUP alt.test",
            "ARTICLE 123",
            "HEAD 456",
            "BODY 789",
            "STAT",
            "XOVER 1-100",
        ];

        for cmd in stateful_commands {
            match CommandHandler::classify(cmd) {
                CommandAction::Reject(msg) => {
                    assert!(msg.contains("stateless") || msg.contains("not supported"));
                }
                other => panic!("Expected Reject for '{}', got {:?}", cmd, other),
            }
        }
    }

    #[test]
    fn test_all_article_by_msgid_forwarded() {
        let msgid_commands = [
            "ARTICLE <test@example.com>",
            "BODY <msg@server.org>",
            "HEAD <id@host.net>",
            "STAT <unique@domain.com>",
        ];

        for cmd in msgid_commands {
            assert_eq!(
                CommandHandler::classify(cmd),
                CommandAction::ForwardStateless,
                "Command '{}' should be forwarded as stateless",
                cmd
            );
        }
    }

    #[test]
    fn test_various_stateless_commands() {
        let stateless_commands = [
            "HELP",
            "LIST",
            "LIST ACTIVE",
            "LIST NEWSGROUPS",
            "DATE",
            "CAPABILITIES",
            "QUIT",
        ];

        for cmd in stateless_commands {
            assert_eq!(
                CommandHandler::classify(cmd),
                CommandAction::ForwardStateless,
                "Command '{}' should be stateless",
                cmd
            );
        }
    }

    #[test]
    fn test_case_insensitive_handling() {
        assert_eq!(
            CommandHandler::classify("list"),
            CommandAction::ForwardStateless
        );
        assert_eq!(
            CommandHandler::classify("LiSt"),
            CommandAction::ForwardStateless
        );
        assert_eq!(
            CommandHandler::classify("QUIT"),
            CommandAction::ForwardStateless
        );
        assert_eq!(
            CommandHandler::classify("quit"),
            CommandAction::ForwardStateless
        );
    }

    #[test]
    fn test_empty_command() {
        let action = CommandHandler::classify("");
        assert_eq!(action, CommandAction::ForwardStateless);
    }

    #[test]
    fn test_whitespace_handling() {
        let action = CommandHandler::classify(" LIST ");
        assert_eq!(action, CommandAction::ForwardStateless);

        let action = CommandHandler::classify(" AUTHINFO USER test ");
        assert!(matches!(
            action,
            CommandAction::InterceptAuth(AuthAction::RequestPassword(username)) if username == "test"
        ));
    }

    #[test]
    fn test_malformed_auth_commands() {
        let action = CommandHandler::classify("AUTHINFO");
        assert_eq!(action, CommandAction::ForwardStateless);

        let action = CommandHandler::classify("AUTHINFO INVALID");
        assert_eq!(action, CommandAction::ForwardStateless);
    }

    #[test]
    fn test_auth_commands_without_arguments() {
        let action = CommandHandler::classify("AUTHINFO USER");
        assert!(matches!(
            action,
            CommandAction::InterceptAuth(AuthAction::RequestPassword(username)) if username.is_empty()
        ));

        let action = CommandHandler::classify("AUTHINFO PASS");
        assert!(matches!(
            action,
            CommandAction::InterceptAuth(AuthAction::ValidateAndRespond { password }) if password.is_empty()
        ));
    }

    #[test]
    fn test_article_commands_with_newlines() {
        let action = CommandHandler::classify("ARTICLE <msg@test.com>\r\n");
        assert_eq!(action, CommandAction::ForwardStateless);

        let action = CommandHandler::classify("LIST\n");
        assert_eq!(action, CommandAction::ForwardStateless);
    }

    #[test]
    fn test_very_long_commands() {
        let long_cmd = format!("LIST {}", "A".repeat(10000));
        let action = CommandHandler::classify(&long_cmd);
        assert_eq!(action, CommandAction::ForwardStateless);

        let long_group = format!("GROUP {}", "alt.".repeat(1000));
        match CommandHandler::classify(&long_group) {
            CommandAction::Reject(_) => {}
            other => panic!("Expected Reject for long GROUP, got {:?}", other),
        }
    }

    #[test]
    fn test_command_action_equality() {
        assert_eq!(
            CommandAction::ForwardStateless,
            CommandAction::ForwardStateless
        );
        assert_eq!(
            CommandAction::InterceptAuth(AuthAction::RequestPassword("test")),
            CommandAction::InterceptAuth(AuthAction::RequestPassword("test"))
        );

        assert_ne!(
            CommandAction::InterceptAuth(AuthAction::RequestPassword("user1")),
            CommandAction::InterceptAuth(AuthAction::ValidateAndRespond { password: "pass1" })
        );
    }

    #[test]
    fn test_reject_messages() {
        assert!(
            matches!(
                CommandHandler::classify("GROUP alt.test"),
                CommandAction::Reject(msg) if msg.len() > 10
            ),
            "Expected Reject with meaningful message"
        );
    }

    #[test]
    fn test_unknown_commands_forwarded() {
        let unknown_commands = ["INVALIDCOMMAND", "XYZABC", "RANDOM DATA", "12345"];

        assert!(
            unknown_commands
                .iter()
                .all(|cmd| CommandHandler::classify(cmd) == CommandAction::ForwardStateless),
            "All unknown commands should be forwarded as stateless"
        );
    }

    #[test]
    fn test_non_routable_commands_rejected() {
        assert!(
            matches!(
                CommandHandler::classify("POST"),
                CommandAction::Reject(msg) if msg.contains("routing")
            ),
            "Expected Reject for POST"
        );

        assert!(
            matches!(
                CommandHandler::classify("IHAVE <test@example.com>"),
                CommandAction::Reject(msg) if msg.contains("routing")
            ),
            "Expected Reject for IHAVE"
        );

        assert!(
            matches!(
                CommandHandler::classify("NEWGROUPS 20240101 000000 GMT"),
                CommandAction::Reject(msg) if msg.contains("routing")
            ),
            "Expected Reject for NEWGROUPS"
        );

        assert!(
            matches!(
                CommandHandler::classify("NEWNEWS * 20240101 000000 GMT"),
                CommandAction::Reject(msg) if msg.contains("routing")
            ),
            "Expected Reject for NEWNEWS"
        );
    }

    #[test]
    fn test_reject_message_content() {
        let CommandAction::Reject(stateful_reject) = CommandHandler::classify("GROUP alt.test")
        else {
            panic!("Expected Reject")
        };

        let CommandAction::Reject(routing_reject) = CommandHandler::classify("POST") else {
            panic!("Expected Reject")
        };

        assert!(stateful_reject.contains("stateless"));
        assert!(routing_reject.contains("routing"));
        assert_ne!(stateful_reject, routing_reject);
    }

    #[test]
    fn test_reject_response_format() {
        let CommandAction::Reject(response) = CommandHandler::classify("GROUP alt.test") else {
            panic!("Expected Reject")
        };

        assert!(response.len() >= 3, "Response too short");
        assert!(
            response[0..3].chars().all(|c| c.is_ascii_digit()),
            "First 3 chars must be digits, got: {}",
            &response[0..3]
        );

        assert_eq!(&response[3..4], " ", "Must have space after status code");

        assert!(response.ends_with("\r\n"), "Response must end with CRLF");

        assert!(
            response.starts_with("502 "),
            "Expected 502 status code, got: {}",
            response
        );
    }

    #[test]
    fn test_all_reject_responses_are_valid_nntp() {
        let reject_commands = [
            "GROUP alt.test",
            "NEXT",
            "LAST",
            "POST",
            "IHAVE <test@example.com>",
            "NEWGROUPS 20240101 000000 GMT",
        ];

        for cmd in reject_commands {
            let CommandAction::Reject(response) = CommandHandler::classify(cmd) else {
                panic!("Expected Reject for command: {}", cmd);
            };

            assert!(
                response.len() >= 5,
                "Response too short for {}: {}",
                cmd,
                response
            );
            assert!(
                response.starts_with(|c: char| c.is_ascii_digit()),
                "Must start with digit for {}: {}",
                cmd,
                response
            );
            assert!(
                response.ends_with("\r\n"),
                "Must end with CRLF for {}: {}",
                cmd,
                response
            );
            assert!(
                response.contains(' '),
                "Must have space separator for {}: {}",
                cmd,
                response
            );
        }
    }

    #[test]
    fn test_502_status_code_usage() {
        let CommandAction::Reject(response) = CommandHandler::classify("GROUP alt.test") else {
            panic!("Expected Reject");
        };
        assert!(
            response.starts_with("502 "),
            "Stateful commands should return 502, got: {}",
            response
        );

        let CommandAction::Reject(response) = CommandHandler::classify("POST") else {
            panic!("Expected Reject");
        };
        assert!(
            response.starts_with("502 "),
            "Non-routable commands should return 502, got: {}",
            response
        );
    }

    #[test]
    fn test_response_messages_are_descriptive() {
        let CommandAction::Reject(stateful) = CommandHandler::classify("GROUP alt.test") else {
            panic!("Expected Reject");
        };
        assert!(
            stateful.to_lowercase().contains("stateless")
                || stateful.to_lowercase().contains("mode"),
            "Should explain stateless mode restriction: {}",
            stateful
        );

        let CommandAction::Reject(routing) = CommandHandler::classify("POST") else {
            panic!("Expected Reject");
        };
        assert!(
            routing.to_lowercase().contains("routing") || routing.to_lowercase().contains("mode"),
            "Should explain routing mode restriction: {}",
            routing
        );
    }
}