Skip to main content

legion_protocol/
validation.rs

//! Protocol validation and security checks
2
3use crate::error::{IronError, Result};
4use crate::constants::*;
5use crate::utils::{is_valid_nick, is_valid_channel};
6
7/// Validate an IRC nickname
8pub fn validate_nickname(nick: &str) -> Result<()> {
9    if !is_valid_nick(nick) {
10        return Err(IronError::InvalidInput(
11            format!("Invalid nickname: {}", nick)
12        ));
13    }
14    Ok(())
15}
16
17/// Validate an IRC channel name
18pub fn validate_channel_name(channel: &str) -> Result<()> {
19    if !is_valid_channel(channel) {
20        return Err(IronError::InvalidInput(
21            format!("Invalid channel name: {}", channel)
22        ));
23    }
24    Ok(())
25}
26
27/// Validate message content for security issues
28pub fn validate_message_content(content: &str) -> Result<()> {
29    // Check for null bytes
30    if content.contains('\0') {
31        return Err(IronError::SecurityViolation(
32            "Message contains null bytes".to_string()
33        ));
34    }
35
36    // Check for control characters that could cause issues
37    if content.chars().any(|c| c.is_control() && c != '\t') {
38        return Err(IronError::SecurityViolation(
39            "Message contains dangerous control characters".to_string()
40        ));
41    }
42
43    // Check message length
44    if content.len() > MAX_MESSAGE_LENGTH {
45        return Err(IronError::InvalidInput(
46            format!("Message too long: {} > {}", content.len(), MAX_MESSAGE_LENGTH)
47        ));
48    }
49
50    Ok(())
51}
52
53/// Validate hostname/server name
54pub fn validate_hostname(hostname: &str) -> Result<()> {
55    if hostname.is_empty() {
56        return Err(IronError::InvalidInput(
57            "Hostname cannot be empty".to_string()
58        ));
59    }
60
61    if hostname.len() > 255 {
62        return Err(IronError::InvalidInput(
63            "Hostname too long".to_string()
64        ));
65    }
66
67    // Basic hostname validation
68    if !hostname.chars().all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-') {
69        return Err(IronError::InvalidInput(
70            "Invalid characters in hostname".to_string()
71        ));
72    }
73
74    // Cannot start or end with hyphen
75    if hostname.starts_with('-') || hostname.ends_with('-') {
76        return Err(IronError::InvalidInput(
77            "Hostname cannot start or end with hyphen".to_string()
78        ));
79    }
80
81    Ok(())
82}
83
84/// Validate user information for registration
85pub fn validate_user_info(username: &str, realname: &str) -> Result<()> {
86    // Validate username
87    if username.is_empty() || username.len() > 32 {
88        return Err(IronError::InvalidInput(
89            "Invalid username length".to_string()
90        ));
91    }
92
93    if !username.chars().all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.') {
94        return Err(IronError::InvalidInput(
95            "Invalid characters in username".to_string()
96        ));
97    }
98
99    // Validate realname
100    if realname.len() > 255 {
101        return Err(IronError::InvalidInput(
102            "Real name too long".to_string()
103        ));
104    }
105
106    // Check for dangerous characters in realname
107    if realname.contains('\0') || realname.contains('\r') || realname.contains('\n') {
108        return Err(IronError::SecurityViolation(
109            "Real name contains invalid characters".to_string()
110        ));
111    }
112
113    Ok(())
114}
115
116/// Validate CTCP message
117pub fn validate_ctcp_message(message: &str) -> Result<()> {
118    // CTCP messages should be wrapped in \x01
119    if !message.starts_with('\x01') || !message.ends_with('\x01') {
120        return Err(IronError::InvalidInput(
121            "Invalid CTCP message format".to_string()
122        ));
123    }
124
125    // Extract the inner content
126    let content = &message[1..message.len()-1];
127    
128    // Basic length check
129    if content.len() > MAX_MESSAGE_LENGTH - 20 { // Leave room for formatting
130        return Err(IronError::InvalidInput(
131            "CTCP message too long".to_string()
132        ));
133    }
134
135    // Check for nested CTCP delimiters
136    if content.contains('\x01') {
137        return Err(IronError::SecurityViolation(
138            "Nested CTCP delimiters not allowed".to_string()
139        ));
140    }
141
142    Ok(())
143}
144
145/// Check for flood/spam patterns
146pub fn check_flood_protection(messages: &[&str], _time_window: std::time::Duration) -> Result<()> {
147    if messages.len() > 10 {
148        return Err(IronError::RateLimit(
149            "Too many messages in time window".to_string()
150        ));
151    }
152
153    // Check for repeated messages (simple spam detection)
154    if messages.len() >= 3 {
155        let last_three: Vec<&str> = messages.iter().rev().take(3).cloned().collect();
156        if last_three.iter().all(|&msg| msg == last_three[0]) {
157            return Err(IronError::RateLimit(
158                "Repeated message spam detected".to_string()
159            ));
160        }
161    }
162
163    Ok(())
164}
165
166/// Validate IRC mode string
167pub fn validate_mode_string(mode_string: &str) -> Result<()> {
168    if mode_string.is_empty() {
169        return Ok(());
170    }
171
172    let mut chars = mode_string.chars();
173    
174    // First character should be + or -
175    match chars.next() {
176        Some('+') | Some('-') => {},
177        _ => return Err(IronError::InvalidInput(
178            "Mode string must start with + or -".to_string()
179        )),
180    }
181
182    // Remaining characters should be valid mode letters
183    for c in chars {
184        if !c.is_ascii_alphabetic() {
185            return Err(IronError::InvalidInput(
186                format!("Invalid mode character: {}", c)
187            ));
188        }
189    }
190
191    Ok(())
192}
193
194/// Sanitize user input to prevent injection attacks
195pub fn sanitize_user_input(input: &str) -> String {
196    input
197        .replace('\0', "")        // Remove null bytes
198        .replace('\r', "")        // Remove carriage returns
199        .replace('\n', " ")       // Replace newlines with spaces
200        .replace('\t', " ")       // Replace tabs with spaces
201        .chars()
202        .filter(|c| !c.is_control() || *c == ' ') // Remove other control characters
203        .take(MAX_MESSAGE_LENGTH) // Truncate to max length
204        .collect()
205}
206
/// Check if a string contains potentially dangerous content.
///
/// Returns `true` when `content` contains:
///
/// * a carriage return or a line feed, alone or as a pair, since either one
///   can be used to smuggle an extra protocol line into a message;
/// * the CTCP delimiter `\x01`;
/// * one of the command keywords `PRIVMSG`, `NOTICE`, `JOIN`, `PART`, `QUIT`,
///   `KICK` or `MODE`, matched case-insensitively.
///
/// The keyword match is a plain substring search, so ordinary words that
/// happen to contain a keyword (for example "joined" or "model") are flagged
/// as well.
pub fn contains_dangerous_content(content: &str) -> bool {
    const COMMAND_KEYWORDS: [&str; 7] = [
        "PRIVMSG", "NOTICE", "JOIN", "PART", "QUIT", "KICK", "MODE",
    ];

    // Checking only for the "\r\n" pair would miss a bare LF or CR, which is
    // just as usable for line injection.
    if content.contains(|c: char| matches!(c, '\r' | '\n' | '\x01')) {
        return true;
    }

    // Upper-case the input once instead of once per keyword; the keywords are
    // already upper-case.
    let upper = content.to_uppercase();
    COMMAND_KEYWORDS.iter().any(|keyword| upper.contains(*keyword))
}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Nicknames: plain, alphanumeric and bracketed names pass; empty,
    /// digit-leading and space-containing names fail.
    #[test]
    fn test_nickname_validation() {
        for nick in &["Alice", "Bot123", "[Server]"] {
            assert!(validate_nickname(nick).is_ok(), "expected Ok for {:?}", nick);
        }

        // Empty, starts with a number, contains a space.
        for nick in &["", "123user", "user name"] {
            assert!(validate_nickname(nick).is_err(), "expected Err for {:?}", nick);
        }
    }

    /// Channel names must carry a `#` or `&` prefix and contain no spaces.
    #[test]
    fn test_channel_validation() {
        for channel in &["#general", "&local"] {
            assert!(validate_channel_name(channel).is_ok(), "expected Ok for {:?}", channel);
        }

        // Missing prefix, contains a space.
        for channel in &["general", "#test channel"] {
            assert!(validate_channel_name(channel).is_err(), "expected Err for {:?}", channel);
        }
    }

    /// Message content: ordinary text passes; NUL bytes and oversized text fail.
    #[test]
    fn test_message_content_validation() {
        assert!(validate_message_content("Hello world").is_ok());

        let with_nul = "Bad\0message";
        assert!(validate_message_content(with_nul).is_err());

        let too_long = "x".repeat(600);
        assert!(validate_message_content(&too_long).is_err());
    }

    /// Hostnames: dotted names pass; empty, hyphen-leading and
    /// underscore-containing names fail.
    #[test]
    fn test_hostname_validation() {
        for host in &["irc.example.com", "server1.chat"] {
            assert!(validate_hostname(host).is_ok(), "expected Ok for {:?}", host);
        }

        // Empty, starts with a hyphen, invalid character.
        for host in &["", "-invalid.com", "bad_host"] {
            assert!(validate_hostname(host).is_err(), "expected Err for {:?}", host);
        }
    }

    /// User info: a simple username/realname pair passes; an empty username,
    /// a username with `@`, and a realname with NUL fail.
    #[test]
    fn test_user_info_validation() {
        assert!(validate_user_info("alice", "Alice Smith").is_ok());

        let rejected = [("", "Real Name"), ("user@host", "Name"), ("user", "Bad\0name")];
        for (username, realname) in &rejected {
            assert!(
                validate_user_info(username, realname).is_err(),
                "expected Err for {:?} / {:?}",
                username,
                realname
            );
        }
    }

    /// CTCP: delimited messages pass; undelimited or nested-delimiter
    /// messages fail.
    #[test]
    fn test_ctcp_validation() {
        for message in &["\x01VERSION\x01", "\x01ACTION waves\x01"] {
            assert!(validate_ctcp_message(message).is_ok(), "expected Ok for {:?}", message);
        }

        // Missing delimiters, nested delimiters.
        for message in &["VERSION", "\x01BAD\x01MESSAGE\x01"] {
            assert!(validate_ctcp_message(message).is_err(), "expected Err for {:?}", message);
        }
    }

    /// Mode strings: signed letter groups and the empty string pass; a missing
    /// sign or a digit fails.
    #[test]
    fn test_mode_string_validation() {
        for modes in &["+nt", "-i", ""] {
            assert!(validate_mode_string(modes).is_ok(), "expected Ok for {:?}", modes);
        }

        // Missing +/-, invalid character.
        for modes in &["nt", "+n2t"] {
            assert!(validate_mode_string(modes).is_err(), "expected Err for {:?}", modes);
        }
    }

    /// Sanitizing strips NUL/CR, turns LF into a space, leaves clean text
    /// alone and caps the output length.
    #[test]
    fn test_sanitize_user_input() {
        let cases = [
            ("Hello\0world\r\n", "Helloworld "),
            ("Normal text", "Normal text"),
        ];
        for (raw, expected) in &cases {
            assert_eq!(sanitize_user_input(raw), *expected);
        }

        let oversized = "x".repeat(1000);
        let capped = sanitize_user_input(&oversized);
        assert!(capped.len() <= MAX_MESSAGE_LENGTH);
    }

    /// Dangerous-content detection flags command keywords, CRLF and CTCP
    /// delimiters, and leaves ordinary sentences alone.
    #[test]
    fn test_dangerous_content_detection() {
        for flagged in &["PRIVMSG #test :hello", "Some\r\nmessage", "\x01ACTION test\x01"] {
            assert!(contains_dangerous_content(flagged), "expected {:?} to be flagged", flagged);
        }

        for harmless in &["Normal message", "Hello world"] {
            assert!(!contains_dangerous_content(harmless), "expected {:?} to pass", harmless);
        }
    }
}