1use crate::nat_traversal_api::PeerId;
7use serde::{Deserialize, Serialize};
8use std::time::SystemTime;
9use thiserror::Error;
10
/// Protocol version stamped into every chat envelope; decoding rejects
/// envelopes that carry any other value.
pub const CHAT_PROTOCOL_VERSION: u16 = 1;

/// Upper bound, in bytes, on an encoded chat envelope (1 MiB), enforced both
/// when encoding and before decoding.
pub const MAX_MESSAGE_SIZE: usize = 1 << 20;
16
/// Errors produced while encoding or decoding chat messages.
#[derive(Error, Debug)]
pub enum ChatError {
    /// Encoding to JSON failed; in this module it carries the serde_json error text.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// Decoding JSON bytes failed; in this module it carries the serde_json error text.
    #[error("Deserialization error: {0}")]
    Deserialization(String),

    /// A payload exceeded the size limit: `(actual_len, max_len)`, both in bytes.
    #[error("Message too large: {0} bytes (max: {1})")]
    MessageTooLarge(usize, usize),

    /// The envelope's version differs from [`CHAT_PROTOCOL_VERSION`]; carries the
    /// version that was received.
    #[error("Invalid protocol version: {0}")]
    InvalidProtocolVersion(u16),

    /// The message format is invalid.
    /// NOTE(review): not constructed anywhere in this module; confirm whether
    /// callers elsewhere use it.
    #[error("Invalid message format")]
    InvalidFormat,
}
35
/// A single chat protocol message.
///
/// Serialized as an internally tagged object: a `"type"` field holds the
/// snake_case variant name (e.g. `"peer_list_request"`) next to the variant's
/// fields. Every `[u8; 32]` peer field holds the raw bytes wrapped by `PeerId`.
/// Every `SystemTime` field goes through `timestamp_serde`, which encodes it as
/// a `(secs, nanos)` pair relative to the Unix epoch.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChatMessage {
    /// A peer joining under `nickname`.
    Join {
        nickname: String,
        peer_id: [u8; 32],
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// A peer leaving; same fields as `Join`.
    Leave {
        nickname: String,
        peer_id: [u8; 32],
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// A text message from `peer_id`.
    Text {
        nickname: String,
        peer_id: [u8; 32],
        text: String,
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// A free-form status string from `peer_id`.
    Status {
        nickname: String,
        peer_id: [u8; 32],
        status: String,
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// A text message from `from_peer_id` addressed to `to_peer_id`.
    Direct {
        from_nickname: String,
        from_peer_id: [u8; 32],
        to_peer_id: [u8; 32],
        text: String,
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// A typing indicator. Unlike the other sender variants it has no timestamp.
    Typing {
        nickname: String,
        peer_id: [u8; 32],
        is_typing: bool,
    },

    /// A peer-list request sent by `peer_id`.
    /// NOTE(review): presumably answered with `PeerListResponse`; confirm in the caller.
    PeerListRequest {
        peer_id: [u8; 32],
    },

    /// A list of peers; carries no sender identity.
    PeerListResponse {
        peers: Vec<PeerInfo>,
    },
}
101
/// One entry in a `ChatMessage::PeerListResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PeerInfo {
    /// Raw 32-byte peer identifier (the bytes wrapped by `PeerId`).
    pub peer_id: [u8; 32],
    /// The peer's display name.
    pub nickname: String,
    /// Free-form status text.
    pub status: String,
    /// Join time, encoded as a `(secs, nanos)` pair since the Unix epoch by
    /// `timestamp_serde`.
    #[serde(with = "timestamp_serde")]
    pub joined_at: SystemTime,
}
111
112mod timestamp_serde {
114 use serde::{Deserialize, Deserializer, Serialize, Serializer};
115 use std::time::{Duration, SystemTime, UNIX_EPOCH};
116
117 pub fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
118 where
119 S: Serializer,
120 {
121 let duration = time.duration_since(UNIX_EPOCH)
122 .map_err(serde::ser::Error::custom)?;
123 let secs = duration.as_secs();
125 let nanos = duration.subsec_nanos();
126 (secs, nanos).serialize(serializer)
127 }
128
129 pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
130 where
131 D: Deserializer<'de>,
132 {
133 let (secs, nanos): (u64, u32) = Deserialize::deserialize(deserializer)?;
134 Ok(UNIX_EPOCH + Duration::new(secs, nanos))
135 }
136}
137
/// Versioned JSON envelope around a [`ChatMessage`]:
/// `{"version": <u16>, "message": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
struct ChatWireFormat {
    /// The sender's protocol version; checked against [`CHAT_PROTOCOL_VERSION`]
    /// when decoding.
    version: u16,
    /// The wrapped chat message.
    message: ChatMessage,
}
146
147impl ChatMessage {
148 pub fn join(nickname: String, peer_id: PeerId) -> Self {
150 ChatMessage::Join {
151 nickname,
152 peer_id: peer_id.0,
153 timestamp: SystemTime::now(),
154 }
155 }
156
157 pub fn leave(nickname: String, peer_id: PeerId) -> Self {
159 ChatMessage::Leave {
160 nickname,
161 peer_id: peer_id.0,
162 timestamp: SystemTime::now(),
163 }
164 }
165
166 pub fn text(nickname: String, peer_id: PeerId, text: String) -> Self {
168 ChatMessage::Text {
169 nickname,
170 peer_id: peer_id.0,
171 text,
172 timestamp: SystemTime::now(),
173 }
174 }
175
176 pub fn status(nickname: String, peer_id: PeerId, status: String) -> Self {
178 ChatMessage::Status {
179 nickname,
180 peer_id: peer_id.0,
181 status,
182 timestamp: SystemTime::now(),
183 }
184 }
185
186 pub fn direct(from_nickname: String, from_peer_id: PeerId, to_peer_id: PeerId, text: String) -> Self {
188 ChatMessage::Direct {
189 from_nickname,
190 from_peer_id: from_peer_id.0,
191 to_peer_id: to_peer_id.0,
192 text,
193 timestamp: SystemTime::now(),
194 }
195 }
196
197 pub fn typing(nickname: String, peer_id: PeerId, is_typing: bool) -> Self {
199 ChatMessage::Typing {
200 nickname,
201 peer_id: peer_id.0,
202 is_typing,
203 }
204 }
205
206 pub fn serialize(&self) -> Result<Vec<u8>, ChatError> {
208 let wire_format = ChatWireFormat {
209 version: CHAT_PROTOCOL_VERSION,
210 message: self.clone(),
211 };
212
213 let data = serde_json::to_vec(&wire_format)
214 .map_err(|e| ChatError::Serialization(e.to_string()))?;
215
216 if data.len() > MAX_MESSAGE_SIZE {
217 return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
218 }
219
220 Ok(data)
221 }
222
223 pub fn deserialize(data: &[u8]) -> Result<Self, ChatError> {
225 if data.len() > MAX_MESSAGE_SIZE {
226 return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
227 }
228
229 let wire_format: ChatWireFormat = serde_json::from_slice(data)
230 .map_err(|e| ChatError::Deserialization(e.to_string()))?;
231
232 if wire_format.version != CHAT_PROTOCOL_VERSION {
233 return Err(ChatError::InvalidProtocolVersion(wire_format.version));
234 }
235
236 Ok(wire_format.message)
237 }
238
239 pub fn peer_id(&self) -> Option<PeerId> {
241 match self {
242 ChatMessage::Join { peer_id, .. } |
243 ChatMessage::Leave { peer_id, .. } |
244 ChatMessage::Text { peer_id, .. } |
245 ChatMessage::Status { peer_id, .. } |
246 ChatMessage::Typing { peer_id, .. } |
247 ChatMessage::PeerListRequest { peer_id, .. } => Some(PeerId(*peer_id)),
248 ChatMessage::Direct { from_peer_id, .. } => Some(PeerId(*from_peer_id)),
249 ChatMessage::PeerListResponse { .. } => None,
250 }
251 }
252
253 pub fn nickname(&self) -> Option<&str> {
255 match self {
256 ChatMessage::Join { nickname, .. } |
257 ChatMessage::Leave { nickname, .. } |
258 ChatMessage::Text { nickname, .. } |
259 ChatMessage::Status { nickname, .. } |
260 ChatMessage::Typing { nickname, .. } => Some(nickname),
261 ChatMessage::Direct { from_nickname, .. } => Some(from_nickname),
262 ChatMessage::PeerListRequest { .. } |
263 ChatMessage::PeerListResponse { .. } => None,
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// A text message survives an encode/decode round trip unchanged.
    #[test]
    fn test_message_serialization() {
        let sender = PeerId([1u8; 32]);
        let original = ChatMessage::text(
            "test-user".to_string(),
            sender,
            "Hello, world!".to_string(),
        );

        let encoded = original.serialize().unwrap();
        assert!(encoded.len() < MAX_MESSAGE_SIZE);

        let decoded = ChatMessage::deserialize(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Every variant round-trips. `Join` is compared on nickname and peer id only.
    #[test]
    fn test_all_message_types() {
        let alice = PeerId([2u8; 32]);
        let samples = vec![
            ChatMessage::join("alice".to_string(), alice),
            ChatMessage::leave("alice".to_string(), alice),
            ChatMessage::text("alice".to_string(), alice, "Hello".to_string()),
            ChatMessage::status("alice".to_string(), alice, "Away".to_string()),
            ChatMessage::direct(
                "alice".to_string(),
                alice,
                PeerId([3u8; 32]),
                "Private message".to_string(),
            ),
            ChatMessage::typing("alice".to_string(), alice, true),
            ChatMessage::PeerListRequest { peer_id: alice.0 },
            ChatMessage::PeerListResponse {
                peers: vec![PeerInfo {
                    peer_id: alice.0,
                    nickname: "alice".to_string(),
                    status: "Online".to_string(),
                    joined_at: SystemTime::now(),
                }],
            },
        ];

        for original in samples {
            let encoded = original.serialize().unwrap();
            let decoded = ChatMessage::deserialize(&encoded).unwrap();

            if let (
                ChatMessage::Join { nickname: sent_nick, peer_id: sent_id, .. },
                ChatMessage::Join { nickname: got_nick, peer_id: got_id, .. },
            ) = (&original, &decoded)
            {
                assert_eq!(sent_nick, got_nick);
                assert_eq!(sent_id, got_id);
            } else {
                assert_eq!(original, decoded);
            }
        }
    }

    /// A text body as large as the limit pushes the envelope over it.
    #[test]
    fn test_message_too_large() {
        let sender = PeerId([4u8; 32]);
        let oversized = ChatMessage::text("user".to_string(), sender, "a".repeat(MAX_MESSAGE_SIZE));

        let outcome = oversized.serialize();
        assert!(
            matches!(outcome, Err(ChatError::MessageTooLarge(_, _))),
            "Expected MessageTooLarge error"
        );
    }

    /// An envelope stamped with a foreign version is rejected with that version.
    #[test]
    fn test_invalid_version() {
        let sender = PeerId([5u8; 32]);
        let envelope = ChatWireFormat {
            version: 999,
            message: ChatMessage::text("user".to_string(), sender, "test".to_string()),
        };
        let encoded = serde_json::to_vec(&envelope).unwrap();

        let outcome = ChatMessage::deserialize(&encoded);
        assert!(
            matches!(outcome, Err(ChatError::InvalidProtocolVersion(999))),
            "Expected InvalidProtocolVersion error"
        );
    }
}