Skip to main content

legion_protocol/
message.rs

1//! IRC message parsing and serialization
2//!
3//! This module provides the core `IrcMessage` type and related functionality
4//! for parsing and serializing IRC messages according to the IRCv3 specification.
5
6use crate::error::{IronError, Result};
7use crate::constants::*;
8use std::collections::HashMap;
9use std::str::FromStr;
10
11#[cfg(feature = "chrono")]
12use std::time::SystemTime;
13
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16
17/// An IRC message with optional tags, prefix, command, and parameters
18#[derive(Debug, Clone, PartialEq)]
19#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
20pub struct IrcMessage {
21    /// Message tags (IRCv3)
22    pub tags: HashMap<String, Option<String>>,
23    /// Message prefix (source)
24    pub prefix: Option<String>,
25    /// IRC command
26    pub command: String,
27    /// Command parameters
28    pub params: Vec<String>,
29}
30
31impl IrcMessage {
32    /// Create a new IRC message with the given command
33    pub fn new(command: impl Into<String>) -> Self {
34        Self {
35            tags: HashMap::new(),
36            prefix: None,
37            command: command.into(),
38            params: Vec::new(),
39        }
40    }
41
42    /// Add parameters to the message
43    pub fn with_params(mut self, params: Vec<String>) -> Self {
44        self.params = params;
45        self
46    }
47
48    /// Add a prefix to the message
49    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
50        self.prefix = Some(prefix.into());
51        self
52    }
53
54    /// Add a tag to the message
55    pub fn with_tag(mut self, key: impl Into<String>, value: Option<String>) -> Self {
56        self.tags.insert(key.into(), value);
57        self
58    }
59
60    /// Add multiple tags to the message
61    pub fn with_tags(mut self, tags: HashMap<String, Option<String>>) -> Self {
62        self.tags.extend(tags);
63        self
64    }
65
66    /// Create a raw message (for debugging/testing)
67    pub fn raw(data: &str) -> Self {
68        Self {
69            tags: HashMap::new(),
70            prefix: None,
71            command: "RAW".to_string(),
72            params: vec![data.to_string()],
73        }
74    }
75
76    /// Extract server timestamp from message tags, fallback to current time
77    #[cfg(feature = "chrono")]
78    pub fn get_timestamp(&self) -> SystemTime {
79        if let Some(Some(time_str)) = self.tags.get("time") {
80            // Parse ISO 8601 timestamp from server-time capability
81            if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(time_str) {
82                return SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(dt.timestamp() as u64);
83            }
84        }
85        SystemTime::now()
86    }
87
88    /// Get the message ID from tags (if present)
89    pub fn get_msgid(&self) -> Option<&str> {
90        self.tags.get("msgid").and_then(|v| v.as_deref())
91    }
92
93    /// Get the account tag (if present)
94    pub fn get_account(&self) -> Option<&str> {
95        self.tags.get("account").and_then(|v| v.as_deref())
96    }
97
98    /// Check if this message has a specific tag
99    pub fn has_tag(&self, key: &str) -> bool {
100        self.tags.contains_key(key)
101    }
102
103    /// Get a tag value
104    pub fn get_tag(&self, key: &str) -> Option<&Option<String>> {
105        self.tags.get(key)
106    }
107
108    /// Check if this is a PRIVMSG or NOTICE
109    pub fn is_message(&self) -> bool {
110        matches!(self.command.as_str(), "PRIVMSG" | "NOTICE")
111    }
112
113    /// Check if this is a channel message (target starts with # or &)
114    pub fn is_channel_message(&self) -> bool {
115        self.is_message() && 
116        self.params.first()
117            .map(|target| target.starts_with('#') || target.starts_with('&'))
118            .unwrap_or(false)
119    }
120
121    /// Get the target of a message (first parameter)
122    pub fn target(&self) -> Option<&str> {
123        self.params.first().map(|s| s.as_str())
124    }
125
126    /// Get the message text (last parameter, typically)
127    pub fn text(&self) -> Option<&str> {
128        self.params.last().map(|s| s.as_str())
129    }
130
131    /// Validate the message for security issues
132    fn validate_security(&self) -> Result<()> {
133        // Validate command length
134        if self.command.len() > 32 {
135            return Err(IronError::SecurityViolation(
136                "Command too long".to_string()
137            ));
138        }
139
140        // Validate parameter count
141        if self.params.len() > MAX_PARAMS {
142            return Err(IronError::SecurityViolation(
143                "Too many parameters".to_string()
144            ));
145        }
146
147        // Validate each parameter
148        for param in &self.params {
149            // CAP messages can have very long capability lists, allow up to 4KB for them
150            let max_param_len = if self.command == "CAP" {
151                4096
152            } else {
153                MAX_MESSAGE_LENGTH
154            };
155            
156            if param.len() > max_param_len {
157                return Err(IronError::SecurityViolation(
158                    "Parameter too long".to_string()
159                ));
160            }
161            
162            // Check for invalid characters
163            if param.contains('\0') || param.contains('\r') || param.contains('\n') {
164                return Err(IronError::SecurityViolation(
165                    "Invalid characters in parameter".to_string()
166                ));
167            }
168            
169            // Validate ASCII characters only (for now)
170            if !param.is_ascii() {
171                return Err(IronError::SecurityViolation(
172                    "Non-ASCII characters in parameter".to_string()
173                ));
174            }
175        }
176
177        // Validate prefix
178        if let Some(prefix) = &self.prefix {
179            if prefix.len() > 255 || prefix.contains('\0') || prefix.contains(' ') {
180                return Err(IronError::SecurityViolation(
181                    "Invalid prefix".to_string()
182                ));
183            }
184        }
185
186        // Validate total tag length
187        let total_tag_length: usize = self.tags.iter()
188            .map(|(k, v)| k.len() + v.as_ref().map_or(0, |s| s.len()) + 2)
189            .sum();
190        
191        if total_tag_length > MAX_TAG_LENGTH {
192            return Err(IronError::SecurityViolation(
193                "Tags too long".to_string()
194            ));
195        }
196
197        Ok(())
198    }
199}
200
201impl FromStr for IrcMessage {
202    type Err = IronError;
203
204    fn from_str(line: &str) -> Result<Self> {
205        // Check total message length
206        if line.len() > MAX_MESSAGE_LENGTH + MAX_TAG_LENGTH {
207            return Err(IronError::SecurityViolation(
208                "Message too long".to_string()
209            ));
210        }
211
212        let line = line.trim_end_matches("\r\n");
213        let mut message = IrcMessage::new("");
214        let mut remaining = line;
215
216        // Parse tags if present
217        if remaining.starts_with('@') {
218            let space_pos = remaining.find(' ')
219                .ok_or_else(|| IronError::Parse("No space after tags".to_string()))?;
220            
221            let tag_str = &remaining[1..space_pos];
222            
223            // Check total tag length before parsing
224            if tag_str.len() > MAX_TAG_LENGTH {
225                return Err(IronError::SecurityViolation(
226                    "Tag section exceeds maximum length".to_string()
227                ));
228            }
229            
230            remaining = &remaining[space_pos + 1..];
231
232            // Parse individual tags
233            for tag in tag_str.split(';') {
234                if tag.is_empty() {
235                    continue;
236                }
237
238                let (key, value) = if let Some(eq_pos) = tag.find('=') {
239                    let key = &tag[..eq_pos];
240                    let value_str = &tag[eq_pos + 1..];
241                    let value = if value_str.is_empty() {
242                        None
243                    } else {
244                        Some(unescape_tag_value(value_str))
245                    };
246                    (key, value)
247                } else {
248                    (tag, None)
249                };
250
251                if !is_valid_tag_key(key) {
252                    return Err(IronError::SecurityViolation(
253                        format!("Invalid tag key: {}", key)
254                    ));
255                }
256
257                message.tags.insert(key.to_string(), value);
258            }
259        }
260
261        // Parse prefix if present
262        if remaining.starts_with(':') {
263            let space_pos = remaining.find(' ')
264                .ok_or_else(|| IronError::Parse("No space after prefix".to_string()))?;
265            
266            let prefix = &remaining[1..space_pos];
267            // Validate prefix doesn't contain spaces
268            if prefix.contains(' ') {
269                return Err(IronError::SecurityViolation(
270                    "Space in prefix".to_string()
271                ));
272            }
273            
274            message.prefix = Some(prefix.to_string());
275            remaining = &remaining[space_pos + 1..];
276        }
277
278        // Parse command and parameters
279        let mut parts: Vec<&str> = remaining.splitn(15, ' ').collect();
280        
281        if parts.is_empty() {
282            return Err(IronError::Parse("No command found".to_string()));
283        }
284
285        message.command = parts.remove(0).to_uppercase();
286
287        if !is_valid_command(&message.command) {
288            return Err(IronError::SecurityViolation(
289                format!("Invalid command: {}", message.command)
290            ));
291        }
292
293        // Parse parameters
294        for (i, part) in parts.iter().enumerate() {
295            if part.starts_with(':') && i > 0 {
296                // Trailing parameter - combine all remaining parts
297                let trailing = parts[i..].join(" ");
298                message.params.push(trailing[1..].to_string());
299                break;
300            } else {
301                message.params.push(part.to_string());
302            }
303        }
304
305        message.validate_security()?;
306        Ok(message)
307    }
308}
309
310impl std::fmt::Display for IrcMessage {
311    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312        // Write tags if present
313        if !self.tags.is_empty() {
314            write!(f, "@")?;
315            let mut first = true;
316            for (key, value) in &self.tags {
317                if !first {
318                    write!(f, ";")?;
319                }
320                first = false;
321                write!(f, "{}", key)?;
322                if let Some(val) = value {
323                    write!(f, "={}", escape_tag_value(val))?;
324                }
325            }
326            write!(f, " ")?;
327        }
328
329        // Write prefix if present
330        if let Some(prefix) = &self.prefix {
331            write!(f, ":{} ", prefix)?;
332        }
333
334        // Write command
335        write!(f, "{}", self.command)?;
336
337        // Write parameters
338        for (i, param) in self.params.iter().enumerate() {
339            if i == self.params.len() - 1 && (param.contains(' ') || param.starts_with(':')) {
340                write!(f, " :{}", param)?;
341            } else {
342                write!(f, " {}", param)?;
343            }
344        }
345
346        write!(f, "\r\n")
347    }
348}
349
/// Unescape an IRCv3 tag value as received on the wire.
///
/// Escapes recognised: `\:` -> `;`, `\s` -> space, `\\` -> `\`,
/// `\r` -> CR, `\n` -> LF. Following the message-tags spec, a backslash
/// before any other character is dropped (`\b` -> `b`), and a lone trailing
/// backslash is dropped.
///
/// The input is processed in one left-to-right pass. That way an escaped
/// backslash is never re-read as the start of another escape: `\\s` becomes
/// `\s`, not a backslash followed by a space.
fn unescape_tag_value(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    let mut chars = value.chars();

    while let Some(ch) = chars.next() {
        if ch != '\\' {
            out.push(ch);
            continue;
        }
        match chars.next() {
            Some(':') => out.push(';'),
            Some('s') => out.push(' '),
            Some('\\') => out.push('\\'),
            Some('r') => out.push('\r'),
            Some('n') => out.push('\n'),
            // Unknown escape: keep the character, drop the backslash.
            Some(other) => out.push(other),
            // Trailing lone backslash: emit nothing.
            None => {}
        }
    }

    out
}
359
/// Escape a tag value for the wire using the IRCv3 message-tags rules:
/// `\` -> `\\`, `;` -> `\:`, space -> `\s`, CR -> `\r`, LF -> `\n`.
/// All other characters are copied unchanged.
fn escape_tag_value(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            ';' => escaped.push_str("\\:"),
            ' ' => escaped.push_str("\\s"),
            '\r' => escaped.push_str("\\r"),
            '\n' => escaped.push_str("\\n"),
            other => escaped.push(other),
        }
    }
    escaped
}
369
370/// Check if a tag key is valid
371fn is_valid_tag_key(key: &str) -> bool {
372    if key.is_empty() || key.len() > MAX_CAPABILITY_NAME_LENGTH {
373        return false;
374    }
375
376    key.chars().all(|c| {
377        c.is_ascii_alphanumeric() || 
378        c == '-' || c == '/' || c == '.' || c == '_' || c == '+'
379    })
380}
381
/// Check whether `command` is an acceptable IRC command token.
///
/// A command is accepted when all of these hold:
/// - it is 1 to 32 bytes long;
/// - it is either an all-ASCII-letter word (e.g. `PRIVMSG`) or a
///   three-digit numeric reply (e.g. `001`);
/// - it is not a verb of another line-based protocol (HTTP, SMTP, POP3,
///   IMAP), so misdirected traffic is not mistaken for IRC.
///
/// The blocklist holds uppercase verbs and is compared case-sensitively.
fn is_valid_command(command: &str) -> bool {
    const FOREIGN_PROTOCOL_VERBS: &[&str] = &[
        "GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH", // HTTP
        "HELO", "EHLO", "MAIL", "RCPT", "DATA", "RSET", "VRFY", // SMTP
        "SYST", "STAT", "RETR", "DELE", "UIDL", "APOP", // POP3
        "AUTH", "LOGIN", "SELECT", "EXAMINE", "CREATE", "RENAME", // IMAP
    ];

    let length_ok = !command.is_empty() && command.len() <= 32;
    if !length_ok {
        return false;
    }

    let is_word = command.bytes().all(|b| b.is_ascii_alphabetic());
    let is_numeric = command.len() == 3 && command.bytes().all(|b| b.is_ascii_digit());

    (is_word || is_numeric) && !FOREIGN_PROTOCOL_VERBS.contains(&command)
}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `line`, panicking with the offending input if it is rejected.
    fn parse_ok(line: &str) -> IrcMessage {
        match line.parse::<IrcMessage>() {
            Ok(msg) => msg,
            Err(err) => panic!("failed to parse {:?}: {:?}", line, err),
        }
    }

    #[test]
    fn test_basic_message_parsing() {
        let parsed = parse_ok("PRIVMSG #channel :Hello world");
        assert_eq!(parsed.command, "PRIVMSG");
        assert_eq!(parsed.params, ["#channel", "Hello world"]);
        assert!(parsed.tags.is_empty());
        assert_eq!(parsed.prefix, None);
    }

    #[test]
    fn test_message_with_tags() {
        let parsed = parse_ok("@time=2023-01-01T00:00:00.000Z PRIVMSG #channel :Hello");
        assert!(parsed.tags.contains_key("time"));
        assert_eq!(parsed.command, "PRIVMSG");
        assert_eq!(parsed.params, ["#channel", "Hello"]);
    }

    #[test]
    fn test_message_with_prefix() {
        let parsed = parse_ok(":nick!user@host PRIVMSG #channel :Hello");
        assert_eq!(parsed.prefix.as_deref(), Some("nick!user@host"));
        assert_eq!(parsed.command, "PRIVMSG");
        assert_eq!(parsed.params, ["#channel", "Hello"]);
    }

    #[test]
    fn test_message_formatting() {
        let params = vec!["#channel".to_string(), "Hello world".to_string()];
        let wire = IrcMessage::new("PRIVMSG").with_params(params).to_string();
        assert_eq!(wire, "PRIVMSG #channel :Hello world\r\n");
    }

    #[test]
    fn test_security_validation() {
        let line = format!("{} #channel :test", "A".repeat(100));
        match line.parse::<IrcMessage>() {
            Err(IronError::SecurityViolation(_)) => {}
            other => panic!("expected SecurityViolation, got {:?}", other),
        }
    }

    #[test]
    fn test_helper_methods() {
        let parsed = parse_ok("PRIVMSG #channel :Hello world");
        assert!(parsed.is_message());
        assert!(parsed.is_channel_message());
        assert_eq!(parsed.target(), Some("#channel"));
        assert_eq!(parsed.text(), Some("Hello world"));
    }
}