1use crate::nat_traversal_api::PeerId;
14use serde::{Deserialize, Serialize};
15use std::time::SystemTime;
16use thiserror::Error;
17
/// Version number written into every serialized message envelope.
///
/// [`ChatMessage::deserialize`] rejects any envelope whose version differs
/// from this value with [`ChatError::InvalidProtocolVersion`].
pub const CHAT_PROTOCOL_VERSION: u16 = 1;

/// Upper bound, in bytes, on the encoded size of a single message (1 MiB).
///
/// Enforced on the encoded bytes by both [`ChatMessage::serialize`] and
/// [`ChatMessage::deserialize`].
pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
23
/// Errors produced while encoding or decoding chat messages.
#[derive(Error, Debug)]
pub enum ChatError {
    /// Encoding a message failed; carries the encoder's error text.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// Decoding a message failed; carries the decoder's error text.
    #[error("Deserialization error: {0}")]
    Deserialization(String),

    /// The encoded message exceeds the size limit.
    ///
    /// Fields are `(actual_size_in_bytes, maximum_allowed_bytes)`.
    #[error("Message too large: {0} bytes (max: {1})")]
    MessageTooLarge(usize, usize),

    /// The envelope carried a protocol version other than
    /// [`CHAT_PROTOCOL_VERSION`]; carries the version that was received.
    #[error("Invalid protocol version: {0}")]
    InvalidProtocolVersion(u16),

    /// The message did not have the expected structure.
    // NOTE(review): no code in this file constructs this variant; presumably
    // it is used by callers elsewhere — confirm before removing.
    #[error("Invalid message format")]
    InvalidFormat,
}
47
/// A single chat protocol message.
///
/// Serialized as an internally tagged enum: the variant is stored in a
/// `"type"` field using the snake_case form of the variant name (for example
/// `"peer_list_request"`). Peer ids are carried as raw 32-byte arrays, and
/// timestamps are encoded through the `timestamp_serde` helper module.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChatMessage {
    /// Join announcement for a peer.
    Join {
        /// Nickname of the peer.
        nickname: String,
        /// Raw bytes of the peer's id.
        peer_id: [u8; 32],
        /// Time attached to the message.
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// Leave announcement for a peer.
    Leave {
        /// Nickname of the peer.
        nickname: String,
        /// Raw bytes of the peer's id.
        peer_id: [u8; 32],
        /// Time attached to the message.
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// Text message from a peer.
    Text {
        /// Nickname of the sender.
        nickname: String,
        /// Raw bytes of the sender's peer id.
        peer_id: [u8; 32],
        /// Message body.
        text: String,
        /// Time attached to the message.
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// Status update from a peer.
    Status {
        /// Nickname of the peer.
        nickname: String,
        /// Raw bytes of the peer's id.
        peer_id: [u8; 32],
        /// Free-form status string.
        status: String,
        /// Time attached to the message.
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// Text message addressed from one peer to another specific peer.
    Direct {
        /// Nickname of the sender.
        from_nickname: String,
        /// Raw bytes of the sender's peer id.
        from_peer_id: [u8; 32],
        /// Raw bytes of the recipient's peer id.
        to_peer_id: [u8; 32],
        /// Message body.
        text: String,
        /// Time attached to the message.
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// Typing indicator; carries no timestamp.
    Typing {
        /// Nickname of the peer.
        nickname: String,
        /// Raw bytes of the peer's id.
        peer_id: [u8; 32],
        /// Typing flag supplied by the sender.
        is_typing: bool,
    },

    /// Request for the peer list.
    PeerListRequest {
        /// Raw bytes of a peer id.
        // NOTE(review): presumably the id of the requesting peer — confirm
        // against the code that handles this message.
        peer_id: [u8; 32],
    },

    /// Response carrying a list of peers.
    PeerListResponse {
        /// One entry per listed peer.
        peers: Vec<PeerInfo>,
    },
}
139
/// Description of a single peer, as carried in
/// [`ChatMessage::PeerListResponse`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PeerInfo {
    /// Raw bytes of the peer's id.
    pub peer_id: [u8; 32],
    /// Nickname of the peer.
    pub nickname: String,
    /// Free-form status string of the peer.
    pub status: String,
    /// Join time recorded for the peer; encoded through `timestamp_serde`.
    #[serde(with = "timestamp_serde")]
    pub joined_at: SystemTime,
}
153
154mod timestamp_serde {
156 use serde::{Deserialize, Deserializer, Serialize, Serializer};
157 use std::time::{Duration, SystemTime, UNIX_EPOCH};
158
159 pub(super) fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
160 where
161 S: Serializer,
162 {
163 let duration = time
164 .duration_since(UNIX_EPOCH)
165 .map_err(serde::ser::Error::custom)?;
166 let secs = duration.as_secs();
168 let nanos = duration.subsec_nanos();
169 (secs, nanos).serialize(serializer)
170 }
171
172 pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
173 where
174 D: Deserializer<'de>,
175 {
176 let (secs, nanos): (u64, u32) = Deserialize::deserialize(deserializer)?;
177 Ok(UNIX_EPOCH + Duration::new(secs, nanos))
178 }
179}
180
/// Envelope that is actually written to and read from the wire: a protocol
/// version followed by the message itself.
///
/// The field names and their order define the encoded layout.
#[derive(Debug, Serialize, Deserialize)]
struct ChatWireFormat {
    /// Protocol version of the sender; compared against
    /// [`CHAT_PROTOCOL_VERSION`] when decoding.
    version: u16,
    /// The wrapped chat message.
    message: ChatMessage,
}
189
190impl ChatMessage {
191 pub fn join(nickname: String, peer_id: PeerId) -> Self {
193 Self::Join {
194 nickname,
195 peer_id: peer_id.0,
196 timestamp: SystemTime::now(),
197 }
198 }
199
200 pub fn leave(nickname: String, peer_id: PeerId) -> Self {
202 Self::Leave {
203 nickname,
204 peer_id: peer_id.0,
205 timestamp: SystemTime::now(),
206 }
207 }
208
209 pub fn text(nickname: String, peer_id: PeerId, text: String) -> Self {
211 Self::Text {
212 nickname,
213 peer_id: peer_id.0,
214 text,
215 timestamp: SystemTime::now(),
216 }
217 }
218
219 pub fn status(nickname: String, peer_id: PeerId, status: String) -> Self {
221 Self::Status {
222 nickname,
223 peer_id: peer_id.0,
224 status,
225 timestamp: SystemTime::now(),
226 }
227 }
228
229 pub fn direct(
231 from_nickname: String,
232 from_peer_id: PeerId,
233 to_peer_id: PeerId,
234 text: String,
235 ) -> Self {
236 Self::Direct {
237 from_nickname,
238 from_peer_id: from_peer_id.0,
239 to_peer_id: to_peer_id.0,
240 text,
241 timestamp: SystemTime::now(),
242 }
243 }
244
245 pub fn typing(nickname: String, peer_id: PeerId, is_typing: bool) -> Self {
247 Self::Typing {
248 nickname,
249 peer_id: peer_id.0,
250 is_typing,
251 }
252 }
253
254 pub fn serialize(&self) -> Result<Vec<u8>, ChatError> {
256 let wire_format = ChatWireFormat {
257 version: CHAT_PROTOCOL_VERSION,
258 message: self.clone(),
259 };
260
261 let data = serde_json::to_vec(&wire_format)
262 .map_err(|e| ChatError::Serialization(e.to_string()))?;
263
264 if data.len() > MAX_MESSAGE_SIZE {
265 return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
266 }
267
268 Ok(data)
269 }
270
271 pub fn deserialize(data: &[u8]) -> Result<Self, ChatError> {
273 if data.len() > MAX_MESSAGE_SIZE {
274 return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
275 }
276
277 let wire_format: ChatWireFormat =
278 serde_json::from_slice(data).map_err(|e| ChatError::Deserialization(e.to_string()))?;
279
280 if wire_format.version != CHAT_PROTOCOL_VERSION {
281 return Err(ChatError::InvalidProtocolVersion(wire_format.version));
282 }
283
284 Ok(wire_format.message)
285 }
286
287 pub fn peer_id(&self) -> Option<PeerId> {
289 match self {
290 Self::Join { peer_id, .. }
291 | Self::Leave { peer_id, .. }
292 | Self::Text { peer_id, .. }
293 | Self::Status { peer_id, .. }
294 | Self::Typing { peer_id, .. }
295 | Self::PeerListRequest { peer_id, .. } => Some(PeerId(*peer_id)),
296 Self::Direct { from_peer_id, .. } => Some(PeerId(*from_peer_id)),
297 Self::PeerListResponse { .. } => None,
298 }
299 }
300
301 pub fn nickname(&self) -> Option<&str> {
303 match self {
304 Self::Join { nickname, .. }
305 | Self::Leave { nickname, .. }
306 | Self::Text { nickname, .. }
307 | Self::Status { nickname, .. }
308 | Self::Typing { nickname, .. } => Some(nickname),
309 Self::Direct { from_nickname, .. } => Some(from_nickname),
310 Self::PeerListRequest { .. } | Self::PeerListResponse { .. } => None,
311 }
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// A text message survives an encode/decode round trip unchanged.
    #[test]
    fn test_message_serialization() {
        let original = ChatMessage::text(
            String::from("test-user"),
            PeerId([1u8; 32]),
            String::from("Hello, world!"),
        );

        let encoded = original.serialize().unwrap();
        assert!(encoded.len() < MAX_MESSAGE_SIZE);

        let decoded = ChatMessage::deserialize(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Every message variant can be encoded and decoded again.
    #[test]
    fn test_all_message_types() {
        let sender = PeerId([2u8; 32]);
        let alice = || String::from("alice");

        let samples = vec![
            ChatMessage::join(alice(), sender),
            ChatMessage::leave(alice(), sender),
            ChatMessage::text(alice(), sender, String::from("Hello")),
            ChatMessage::status(alice(), sender, String::from("Away")),
            ChatMessage::direct(
                alice(),
                sender,
                PeerId([3u8; 32]),
                String::from("Private message"),
            ),
            ChatMessage::typing(alice(), sender, true),
            ChatMessage::PeerListRequest { peer_id: sender.0 },
            ChatMessage::PeerListResponse {
                peers: vec![PeerInfo {
                    peer_id: sender.0,
                    nickname: alice(),
                    status: String::from("Online"),
                    joined_at: SystemTime::now(),
                }],
            },
        ];

        for original in samples {
            let encoded = original.serialize().unwrap();
            let decoded = ChatMessage::deserialize(&encoded).unwrap();

            // Join pairs are compared on nickname and peer id only; every
            // other combination must be fully equal.
            if let (
                ChatMessage::Join {
                    nickname: sent_nick,
                    peer_id: sent_id,
                    ..
                },
                ChatMessage::Join {
                    nickname: got_nick,
                    peer_id: got_id,
                    ..
                },
            ) = (&original, &decoded)
            {
                assert_eq!(sent_nick, got_nick);
                assert_eq!(sent_id, got_id);
            } else {
                assert_eq!(original, decoded);
            }
        }
    }

    /// A body of exactly `MAX_MESSAGE_SIZE` bytes pushes the envelope over the limit.
    #[test]
    fn test_message_too_large() {
        let oversized = ChatMessage::text(
            String::from("user"),
            PeerId([4u8; 32]),
            "a".repeat(MAX_MESSAGE_SIZE),
        );

        let outcome = oversized.serialize();
        if !matches!(outcome, Err(ChatError::MessageTooLarge(_, _))) {
            panic!("Expected MessageTooLarge error");
        }
    }

    /// An envelope with an unknown version is rejected with that version reported.
    #[test]
    fn test_invalid_version() {
        let envelope = ChatWireFormat {
            version: 999,
            message: ChatMessage::text(
                String::from("user"),
                PeerId([5u8; 32]),
                String::from("test"),
            ),
        };

        let encoded = serde_json::to_vec(&envelope).unwrap();

        let outcome = ChatMessage::deserialize(&encoded);
        if !matches!(outcome, Err(ChatError::InvalidProtocolVersion(999))) {
            panic!("Expected InvalidProtocolVersion error");
        }
    }
}