// ant_quic/chat.rs

//! Chat protocol implementation for QUIC streams.
//!
//! This module provides a structured chat protocol for peer-to-peer (P2P)
//! communication over QUIC streams, including message types, serialization,
//! and handling.

6use crate::nat_traversal_api::PeerId;
7use serde::{Deserialize, Serialize};
8use std::time::SystemTime;
9use thiserror::Error;
10
/// Version number stamped on every outgoing chat envelope.
///
/// Incoming envelopes whose version differs are rejected with
/// [`ChatError::InvalidProtocolVersion`].
pub const CHAT_PROTOCOL_VERSION: u16 = 1;

/// Upper bound, in bytes, on an encoded chat message (1 MiB).
///
/// Checked after encoding outgoing messages and before parsing incoming ones.
pub const MAX_MESSAGE_SIZE: usize = 1 << 20;
16
17/// Chat protocol errors
18#[derive(Error, Debug)]
19pub enum ChatError {
20    #[error("Serialization error: {0}")]
21    Serialization(String),
22
23    #[error("Deserialization error: {0}")]
24    Deserialization(String),
25
26    #[error("Message too large: {0} bytes (max: {1})")]
27    MessageTooLarge(usize, usize),
28
29    #[error("Invalid protocol version: {0}")]
30    InvalidProtocolVersion(u16),
31
32    #[error("Invalid message format")]
33    InvalidFormat,
34}
35
36/// Chat message types
37#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
38#[serde(tag = "type", rename_all = "snake_case")]
39pub enum ChatMessage {
40    /// User joined the chat
41    Join {
42        nickname: String,
43        peer_id: [u8; 32],
44        #[serde(with = "timestamp_serde")]
45        timestamp: SystemTime,
46    },
47
48    /// User left the chat
49    Leave {
50        nickname: String,
51        peer_id: [u8; 32],
52        #[serde(with = "timestamp_serde")]
53        timestamp: SystemTime,
54    },
55
56    /// Text message from user
57    Text {
58        nickname: String,
59        peer_id: [u8; 32],
60        text: String,
61        #[serde(with = "timestamp_serde")]
62        timestamp: SystemTime,
63    },
64
65    /// Status update from user
66    Status {
67        nickname: String,
68        peer_id: [u8; 32],
69        status: String,
70        #[serde(with = "timestamp_serde")]
71        timestamp: SystemTime,
72    },
73
74    /// Direct message to specific peer
75    Direct {
76        from_nickname: String,
77        from_peer_id: [u8; 32],
78        to_peer_id: [u8; 32],
79        text: String,
80        #[serde(with = "timestamp_serde")]
81        timestamp: SystemTime,
82    },
83
84    /// Typing indicator
85    Typing {
86        nickname: String,
87        peer_id: [u8; 32],
88        is_typing: bool,
89    },
90
91    /// Request peer list
92    PeerListRequest { peer_id: [u8; 32] },
93
94    /// Response with peer list
95    PeerListResponse { peers: Vec<PeerInfo> },
96}
97
98/// Information about a connected peer
99#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
100pub struct PeerInfo {
101    pub peer_id: [u8; 32],
102    pub nickname: String,
103    pub status: String,
104    #[serde(with = "timestamp_serde")]
105    pub joined_at: SystemTime,
106}
107
108/// Timestamp serialization module
109mod timestamp_serde {
110    use serde::{Deserialize, Deserializer, Serialize, Serializer};
111    use std::time::{Duration, SystemTime, UNIX_EPOCH};
112
113    pub(super) fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
114    where
115        S: Serializer,
116    {
117        let duration = time
118            .duration_since(UNIX_EPOCH)
119            .map_err(serde::ser::Error::custom)?;
120        // Serialize as a tuple of (seconds, nanoseconds) to preserve full precision
121        let secs = duration.as_secs();
122        let nanos = duration.subsec_nanos();
123        (secs, nanos).serialize(serializer)
124    }
125
126    pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
127    where
128        D: Deserializer<'de>,
129    {
130        let (secs, nanos): (u64, u32) = Deserialize::deserialize(deserializer)?;
131        Ok(UNIX_EPOCH + Duration::new(secs, nanos))
132    }
133}
134
135/// Wire format for chat messages
136#[derive(Debug, Serialize, Deserialize)]
137struct ChatWireFormat {
138    /// Protocol version
139    version: u16,
140    /// Message payload
141    message: ChatMessage,
142}
143
144impl ChatMessage {
145    /// Create a new join message
146    pub fn join(nickname: String, peer_id: PeerId) -> Self {
147        Self::Join {
148            nickname,
149            peer_id: peer_id.0,
150            timestamp: SystemTime::now(),
151        }
152    }
153
154    /// Create a new leave message
155    pub fn leave(nickname: String, peer_id: PeerId) -> Self {
156        Self::Leave {
157            nickname,
158            peer_id: peer_id.0,
159            timestamp: SystemTime::now(),
160        }
161    }
162
163    /// Create a new text message
164    pub fn text(nickname: String, peer_id: PeerId, text: String) -> Self {
165        Self::Text {
166            nickname,
167            peer_id: peer_id.0,
168            text,
169            timestamp: SystemTime::now(),
170        }
171    }
172
173    /// Create a new status message
174    pub fn status(nickname: String, peer_id: PeerId, status: String) -> Self {
175        Self::Status {
176            nickname,
177            peer_id: peer_id.0,
178            status,
179            timestamp: SystemTime::now(),
180        }
181    }
182
183    /// Create a new direct message
184    pub fn direct(
185        from_nickname: String,
186        from_peer_id: PeerId,
187        to_peer_id: PeerId,
188        text: String,
189    ) -> Self {
190        Self::Direct {
191            from_nickname,
192            from_peer_id: from_peer_id.0,
193            to_peer_id: to_peer_id.0,
194            text,
195            timestamp: SystemTime::now(),
196        }
197    }
198
199    /// Create a typing indicator
200    pub fn typing(nickname: String, peer_id: PeerId, is_typing: bool) -> Self {
201        Self::Typing {
202            nickname,
203            peer_id: peer_id.0,
204            is_typing,
205        }
206    }
207
208    /// Serialize message to bytes
209    pub fn serialize(&self) -> Result<Vec<u8>, ChatError> {
210        let wire_format = ChatWireFormat {
211            version: CHAT_PROTOCOL_VERSION,
212            message: self.clone(),
213        };
214
215        let data = serde_json::to_vec(&wire_format)
216            .map_err(|e| ChatError::Serialization(e.to_string()))?;
217
218        if data.len() > MAX_MESSAGE_SIZE {
219            return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
220        }
221
222        Ok(data)
223    }
224
225    /// Deserialize message from bytes
226    pub fn deserialize(data: &[u8]) -> Result<Self, ChatError> {
227        if data.len() > MAX_MESSAGE_SIZE {
228            return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
229        }
230
231        let wire_format: ChatWireFormat =
232            serde_json::from_slice(data).map_err(|e| ChatError::Deserialization(e.to_string()))?;
233
234        if wire_format.version != CHAT_PROTOCOL_VERSION {
235            return Err(ChatError::InvalidProtocolVersion(wire_format.version));
236        }
237
238        Ok(wire_format.message)
239    }
240
241    /// Get the peer ID from the message
242    pub fn peer_id(&self) -> Option<PeerId> {
243        match self {
244            Self::Join { peer_id, .. }
245            | Self::Leave { peer_id, .. }
246            | Self::Text { peer_id, .. }
247            | Self::Status { peer_id, .. }
248            | Self::Typing { peer_id, .. }
249            | Self::PeerListRequest { peer_id, .. } => Some(PeerId(*peer_id)),
250            Self::Direct { from_peer_id, .. } => Some(PeerId(*from_peer_id)),
251            Self::PeerListResponse { .. } => None,
252        }
253    }
254
255    /// Get the nickname from the message
256    pub fn nickname(&self) -> Option<&str> {
257        match self {
258            Self::Join { nickname, .. }
259            | Self::Leave { nickname, .. }
260            | Self::Text { nickname, .. }
261            | Self::Status { nickname, .. }
262            | Self::Typing { nickname, .. } => Some(nickname),
263            Self::Direct { from_nickname, .. } => Some(from_nickname),
264            Self::PeerListRequest { .. } | Self::PeerListResponse { .. } => None,
265        }
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_serialization() {
        let peer_id = PeerId([1u8; 32]);
        let message = ChatMessage::text(
            "test-user".to_string(),
            peer_id,
            "Hello, world!".to_string(),
        );

        let data = message.serialize().unwrap();
        assert!(data.len() < MAX_MESSAGE_SIZE);

        let deserialized = ChatMessage::deserialize(&data).unwrap();
        assert_eq!(message, deserialized);
    }

    #[test]
    fn test_all_message_types() {
        let peer_id = PeerId([2u8; 32]);
        let messages = vec![
            ChatMessage::join("alice".to_string(), peer_id),
            ChatMessage::leave("alice".to_string(), peer_id),
            ChatMessage::text("alice".to_string(), peer_id, "Hello".to_string()),
            ChatMessage::status("alice".to_string(), peer_id, "Away".to_string()),
            ChatMessage::direct(
                "alice".to_string(),
                peer_id,
                PeerId([3u8; 32]),
                "Private message".to_string(),
            ),
            ChatMessage::typing("alice".to_string(), peer_id, true),
            ChatMessage::PeerListRequest { peer_id: peer_id.0 },
            ChatMessage::PeerListResponse {
                peers: vec![PeerInfo {
                    peer_id: peer_id.0,
                    nickname: "alice".to_string(),
                    status: "Online".to_string(),
                    joined_at: SystemTime::now(),
                }],
            },
        ];

        // Timestamps are encoded as exact (secs, nanos) pairs, so every variant,
        // timestamped or not, must round-trip to an equal value.
        for msg in messages {
            let data = msg.serialize().unwrap();
            let deserialized = ChatMessage::deserialize(&data).unwrap();
            assert_eq!(msg, deserialized);
        }
    }

    #[test]
    fn test_wire_format_shape() {
        let msg = ChatMessage::join("alice".to_string(), PeerId([9u8; 32]));
        let data = msg.serialize().unwrap();
        let json: serde_json::Value = serde_json::from_slice(&data).unwrap();

        assert_eq!(json["version"], CHAT_PROTOCOL_VERSION);
        assert_eq!(json["message"]["type"], "join");
        assert_eq!(json["message"]["nickname"], "alice");
        assert_eq!(
            json["message"]["timestamp"].as_array().map(Vec::len),
            Some(2)
        );
    }

    #[test]
    fn test_message_too_large() {
        let peer_id = PeerId([4u8; 32]);
        let large_text = "a".repeat(MAX_MESSAGE_SIZE);
        let message = ChatMessage::text("user".to_string(), peer_id, large_text);

        assert!(matches!(
            message.serialize(),
            Err(ChatError::MessageTooLarge(len, MAX_MESSAGE_SIZE)) if len > MAX_MESSAGE_SIZE
        ));
    }

    #[test]
    fn test_deserialize_rejects_oversized_input() {
        let data = vec![b' '; MAX_MESSAGE_SIZE + 1];

        assert!(matches!(
            ChatMessage::deserialize(&data),
            Err(ChatError::MessageTooLarge(len, MAX_MESSAGE_SIZE)) if len == MAX_MESSAGE_SIZE + 1
        ));
    }

    #[test]
    fn test_deserialize_rejects_malformed_input() {
        assert!(matches!(
            ChatMessage::deserialize(b"not json"),
            Err(ChatError::Deserialization(_))
        ));

        let unknown_type = format!(
            r#"{{"version":{},"message":{{"type":"bogus"}}}}"#,
            CHAT_PROTOCOL_VERSION
        );
        assert!(matches!(
            ChatMessage::deserialize(unknown_type.as_bytes()),
            Err(ChatError::Deserialization(_))
        ));
    }

    #[test]
    fn test_invalid_version() {
        let peer_id = PeerId([5u8; 32]);
        let message = ChatMessage::text("user".to_string(), peer_id, "test".to_string());

        // Build an envelope by hand with a version this build does not speak.
        let wire_format = ChatWireFormat {
            version: 999,
            message,
        };
        let data = serde_json::to_vec(&wire_format).unwrap();

        assert!(matches!(
            ChatMessage::deserialize(&data),
            Err(ChatError::InvalidProtocolVersion(999))
        ));
    }

    #[test]
    fn test_accessors() {
        let alice = PeerId([7u8; 32]);
        let bob = PeerId([8u8; 32]);

        let text = ChatMessage::text("alice".to_string(), alice, "hi".to_string());
        assert_eq!(text.peer_id().map(|p| p.0), Some([7u8; 32]));
        assert_eq!(text.nickname(), Some("alice"));

        // Direct messages report the sender, not the recipient.
        let direct = ChatMessage::direct("alice".to_string(), alice, bob, "psst".to_string());
        assert_eq!(direct.peer_id().map(|p| p.0), Some([7u8; 32]));
        assert_eq!(direct.nickname(), Some("alice"));

        let request = ChatMessage::PeerListRequest { peer_id: bob.0 };
        assert_eq!(request.peer_id().map(|p| p.0), Some([8u8; 32]));
        assert_eq!(request.nickname(), None);

        let response = ChatMessage::PeerListResponse { peers: Vec::new() };
        assert!(response.peer_id().is_none());
        assert_eq!(response.nickname(), None);
    }
}