// ant_quic/chat.rs

// Copyright 2024 Saorsa Labs Ltd.
//
// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
//
// Full details available at https://saorsalabs.com/licenses

//! Chat protocol implementation for QUIC streams
//!
//! This module provides a structured chat protocol for P2P communication
//! over QUIC streams: the message types, their versioned JSON wire encoding,
//! and size and protocol-version validation on decode.

13use crate::nat_traversal_api::PeerId;
14use serde::{Deserialize, Serialize};
15use std::time::SystemTime;
16use thiserror::Error;
17
/// Chat protocol version written into every wire envelope; decoding rejects any other value.
pub const CHAT_PROTOCOL_VERSION: u16 = 1;

/// Maximum size of an encoded chat message: 1 MiB (1,048,576 bytes).
pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
23
24/// Chat protocol errors
25#[derive(Error, Debug)]
26pub enum ChatError {
27    /// Message serialization failed
28    #[error("Serialization error: {0}")]
29    Serialization(String),
30
31    /// Message deserialization failed
32    #[error("Deserialization error: {0}")]
33    Deserialization(String),
34
35    /// Message exceeded the maximum allowed size
36    #[error("Message too large: {0} bytes (max: {1})")]
37    MessageTooLarge(usize, usize),
38
39    /// Unsupported or invalid protocol version
40    #[error("Invalid protocol version: {0}")]
41    InvalidProtocolVersion(u16),
42
43    /// Message failed schema validation
44    #[error("Invalid message format")]
45    InvalidFormat,
46}
47
48/// Chat message types
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
50#[serde(tag = "type", rename_all = "snake_case")]
51pub enum ChatMessage {
52    /// User joined the chat
53    Join {
54        /// Display name of the user
55        nickname: String,
56        /// Sender's peer identifier
57        peer_id: [u8; 32],
58        #[serde(with = "timestamp_serde")]
59        /// Time the event occurred
60        timestamp: SystemTime,
61    },
62
63    /// User left the chat
64    Leave {
65        /// Display name of the user
66        nickname: String,
67        /// Sender's peer identifier
68        peer_id: [u8; 32],
69        #[serde(with = "timestamp_serde")]
70        /// Time the event occurred
71        timestamp: SystemTime,
72    },
73
74    /// Text message from user
75    Text {
76        /// Display name of the user
77        nickname: String,
78        /// Sender's peer identifier
79        peer_id: [u8; 32],
80        /// UTF-8 message body
81        text: String,
82        #[serde(with = "timestamp_serde")]
83        /// Time the message was sent
84        timestamp: SystemTime,
85    },
86
87    /// Status update from user
88    Status {
89        /// Display name of the user
90        nickname: String,
91        /// Sender's peer identifier
92        peer_id: [u8; 32],
93        /// Arbitrary status string
94        status: String,
95        #[serde(with = "timestamp_serde")]
96        /// Time the status was set
97        timestamp: SystemTime,
98    },
99
100    /// Direct message to specific peer
101    Direct {
102        /// Sender nickname
103        from_nickname: String,
104        /// Sender peer ID
105        from_peer_id: [u8; 32],
106        /// Recipient peer ID
107        to_peer_id: [u8; 32],
108        /// Encrypted or plain text body
109        text: String,
110        #[serde(with = "timestamp_serde")]
111        /// Time the message was sent
112        timestamp: SystemTime,
113    },
114
115    /// Typing indicator
116    Typing {
117        /// Display name of the user
118        nickname: String,
119        /// Sender's peer identifier
120        peer_id: [u8; 32],
121        /// Whether the user is currently typing
122        is_typing: bool,
123    },
124
125    /// Request peer list
126    /// Request current peer list from the node
127    PeerListRequest {
128        /// Requestor's peer identifier
129        peer_id: [u8; 32],
130    },
131
132    /// Response with peer list
133    /// Response containing current peers
134    PeerListResponse {
135        /// List of known peers and metadata
136        peers: Vec<PeerInfo>,
137    },
138}
139
140/// Information about a connected peer
141#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
142pub struct PeerInfo {
143    /// Unique peer identifier
144    pub peer_id: [u8; 32],
145    /// Display name
146    pub nickname: String,
147    /// User status string
148    pub status: String,
149    #[serde(with = "timestamp_serde")]
150    /// When this peer joined
151    pub joined_at: SystemTime,
152}
153
154/// Timestamp serialization module
155mod timestamp_serde {
156    use serde::{Deserialize, Deserializer, Serialize, Serializer};
157    use std::time::{Duration, SystemTime, UNIX_EPOCH};
158
159    pub(super) fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
160    where
161        S: Serializer,
162    {
163        let duration = time
164            .duration_since(UNIX_EPOCH)
165            .map_err(serde::ser::Error::custom)?;
166        // Serialize as a tuple of (seconds, nanoseconds) to preserve full precision
167        let secs = duration.as_secs();
168        let nanos = duration.subsec_nanos();
169        (secs, nanos).serialize(serializer)
170    }
171
172    pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
173    where
174        D: Deserializer<'de>,
175    {
176        let (secs, nanos): (u64, u32) = Deserialize::deserialize(deserializer)?;
177        Ok(UNIX_EPOCH + Duration::new(secs, nanos))
178    }
179}
180
181/// Wire format for chat messages
182#[derive(Debug, Serialize, Deserialize)]
183struct ChatWireFormat {
184    /// Protocol version
185    version: u16,
186    /// Message payload
187    message: ChatMessage,
188}
189
190impl ChatMessage {
191    /// Create a new join message
192    pub fn join(nickname: String, peer_id: PeerId) -> Self {
193        Self::Join {
194            nickname,
195            peer_id: peer_id.0,
196            timestamp: SystemTime::now(),
197        }
198    }
199
200    /// Create a new leave message
201    pub fn leave(nickname: String, peer_id: PeerId) -> Self {
202        Self::Leave {
203            nickname,
204            peer_id: peer_id.0,
205            timestamp: SystemTime::now(),
206        }
207    }
208
209    /// Create a new text message
210    pub fn text(nickname: String, peer_id: PeerId, text: String) -> Self {
211        Self::Text {
212            nickname,
213            peer_id: peer_id.0,
214            text,
215            timestamp: SystemTime::now(),
216        }
217    }
218
219    /// Create a new status message
220    pub fn status(nickname: String, peer_id: PeerId, status: String) -> Self {
221        Self::Status {
222            nickname,
223            peer_id: peer_id.0,
224            status,
225            timestamp: SystemTime::now(),
226        }
227    }
228
229    /// Create a new direct message
230    pub fn direct(
231        from_nickname: String,
232        from_peer_id: PeerId,
233        to_peer_id: PeerId,
234        text: String,
235    ) -> Self {
236        Self::Direct {
237            from_nickname,
238            from_peer_id: from_peer_id.0,
239            to_peer_id: to_peer_id.0,
240            text,
241            timestamp: SystemTime::now(),
242        }
243    }
244
245    /// Create a typing indicator
246    pub fn typing(nickname: String, peer_id: PeerId, is_typing: bool) -> Self {
247        Self::Typing {
248            nickname,
249            peer_id: peer_id.0,
250            is_typing,
251        }
252    }
253
254    /// Serialize message to bytes
255    pub fn serialize(&self) -> Result<Vec<u8>, ChatError> {
256        let wire_format = ChatWireFormat {
257            version: CHAT_PROTOCOL_VERSION,
258            message: self.clone(),
259        };
260
261        let data = serde_json::to_vec(&wire_format)
262            .map_err(|e| ChatError::Serialization(e.to_string()))?;
263
264        if data.len() > MAX_MESSAGE_SIZE {
265            return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
266        }
267
268        Ok(data)
269    }
270
271    /// Deserialize message from bytes
272    pub fn deserialize(data: &[u8]) -> Result<Self, ChatError> {
273        if data.len() > MAX_MESSAGE_SIZE {
274            return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
275        }
276
277        let wire_format: ChatWireFormat =
278            serde_json::from_slice(data).map_err(|e| ChatError::Deserialization(e.to_string()))?;
279
280        if wire_format.version != CHAT_PROTOCOL_VERSION {
281            return Err(ChatError::InvalidProtocolVersion(wire_format.version));
282        }
283
284        Ok(wire_format.message)
285    }
286
287    /// Get the peer ID from the message
288    pub fn peer_id(&self) -> Option<PeerId> {
289        match self {
290            Self::Join { peer_id, .. }
291            | Self::Leave { peer_id, .. }
292            | Self::Text { peer_id, .. }
293            | Self::Status { peer_id, .. }
294            | Self::Typing { peer_id, .. }
295            | Self::PeerListRequest { peer_id, .. } => Some(PeerId(*peer_id)),
296            Self::Direct { from_peer_id, .. } => Some(PeerId(*from_peer_id)),
297            Self::PeerListResponse { .. } => None,
298        }
299    }
300
301    /// Get the nickname from the message
302    pub fn nickname(&self) -> Option<&str> {
303        match self {
304            Self::Join { nickname, .. }
305            | Self::Leave { nickname, .. }
306            | Self::Text { nickname, .. }
307            | Self::Status { nickname, .. }
308            | Self::Typing { nickname, .. } => Some(nickname),
309            Self::Direct { from_nickname, .. } => Some(from_nickname),
310            Self::PeerListRequest { .. } | Self::PeerListResponse { .. } => None,
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// A text message round-trips unchanged and stays well under the size limit.
    #[test]
    fn test_message_serialization() {
        let peer_id = PeerId([1u8; 32]);
        let message = ChatMessage::text(
            "test-user".to_string(),
            peer_id,
            "Hello, world!".to_string(),
        );

        let data = message.serialize().unwrap();
        assert!(data.len() < MAX_MESSAGE_SIZE);

        let deserialized = ChatMessage::deserialize(&data).unwrap();
        assert_eq!(message, deserialized);
    }

    /// Every variant round-trips unchanged, including those carrying a `SystemTime`,
    /// because timestamps are encoded with full nanosecond precision.
    #[test]
    fn test_all_message_types() {
        let peer_id = PeerId([2u8; 32]);
        let messages = vec![
            ChatMessage::join("alice".to_string(), peer_id),
            ChatMessage::leave("alice".to_string(), peer_id),
            ChatMessage::text("alice".to_string(), peer_id, "Hello".to_string()),
            ChatMessage::status("alice".to_string(), peer_id, "Away".to_string()),
            ChatMessage::direct(
                "alice".to_string(),
                peer_id,
                PeerId([3u8; 32]),
                "Private message".to_string(),
            ),
            ChatMessage::typing("alice".to_string(), peer_id, true),
            ChatMessage::PeerListRequest { peer_id: peer_id.0 },
            ChatMessage::PeerListResponse {
                peers: vec![PeerInfo {
                    peer_id: peer_id.0,
                    nickname: "alice".to_string(),
                    status: "Online".to_string(),
                    joined_at: SystemTime::now(),
                }],
            },
        ];

        for msg in messages {
            let data = msg.serialize().unwrap();
            let deserialized = ChatMessage::deserialize(&data).unwrap();
            assert_eq!(msg, deserialized);
        }
    }

    /// Encoding refuses a message whose envelope exceeds `MAX_MESSAGE_SIZE`.
    #[test]
    fn test_message_too_large() {
        let peer_id = PeerId([4u8; 32]);
        let large_text = "a".repeat(MAX_MESSAGE_SIZE);
        let message = ChatMessage::text("user".to_string(), peer_id, large_text);

        match message.serialize() {
            Err(ChatError::MessageTooLarge(_, _)) => {}
            other => panic!("Expected MessageTooLarge error, got {:?}", other),
        }
    }

    /// Decoding rejects oversized input before attempting to parse it.
    #[test]
    fn test_deserialize_rejects_oversized_input() {
        let data = vec![b' '; MAX_MESSAGE_SIZE + 1];

        match ChatMessage::deserialize(&data) {
            Err(ChatError::MessageTooLarge(len, max)) => {
                assert_eq!(len, MAX_MESSAGE_SIZE + 1);
                assert_eq!(max, MAX_MESSAGE_SIZE);
            }
            other => panic!("Expected MessageTooLarge error, got {:?}", other),
        }
    }

    /// Non-JSON input is reported as a deserialization error.
    #[test]
    fn test_deserialize_rejects_malformed_json() {
        match ChatMessage::deserialize(b"not json") {
            Err(ChatError::Deserialization(_)) => {}
            other => panic!("Expected Deserialization error, got {:?}", other),
        }
    }

    /// An envelope with a foreign version is rejected with that version number.
    #[test]
    fn test_invalid_version() {
        let peer_id = PeerId([5u8; 32]);
        let message = ChatMessage::text("user".to_string(), peer_id, "test".to_string());

        // Build the envelope by hand so it can carry a wrong version.
        let wire_format = ChatWireFormat {
            version: 999,
            message,
        };

        let data = serde_json::to_vec(&wire_format).unwrap();

        match ChatMessage::deserialize(&data) {
            Err(ChatError::InvalidProtocolVersion(999)) => {}
            other => panic!("Expected InvalidProtocolVersion error, got {:?}", other),
        }
    }

    /// `peer_id()` and `nickname()` pick the sender fields for each variant shape.
    #[test]
    fn test_sender_accessors() {
        let alice = PeerId([6u8; 32]);
        let bob = PeerId([7u8; 32]);

        let direct = ChatMessage::direct("alice".to_string(), alice, bob, "hi".to_string());
        assert_eq!(direct.peer_id().map(|p| p.0), Some([6u8; 32]));
        assert_eq!(direct.nickname(), Some("alice"));

        let typing = ChatMessage::typing("bob".to_string(), bob, false);
        assert_eq!(typing.peer_id().map(|p| p.0), Some([7u8; 32]));
        assert_eq!(typing.nickname(), Some("bob"));

        let request = ChatMessage::PeerListRequest { peer_id: alice.0 };
        assert_eq!(request.peer_id().map(|p| p.0), Some([6u8; 32]));
        assert_eq!(request.nickname(), None);

        let response = ChatMessage::PeerListResponse { peers: Vec::new() };
        assert!(response.peer_id().is_none());
        assert_eq!(response.nickname(), None);
    }
}