1use crate::nat_traversal_api::PeerId;
7use serde::{Deserialize, Serialize};
8use std::time::SystemTime;
9use thiserror::Error;
10
/// Wire protocol version written into every outgoing envelope.
///
/// `ChatMessage::deserialize` rejects envelopes whose `version` differs from this value.
pub const CHAT_PROTOCOL_VERSION: u16 = 1;

/// Upper bound, in bytes (1 MiB), on an encoded chat envelope.
///
/// Enforced on the encoded output in `ChatMessage::serialize`, and checked on the raw input
/// before any JSON parsing in `ChatMessage::deserialize`.
pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
16
/// Errors produced while encoding or decoding chat messages.
#[derive(Error, Debug)]
pub enum ChatError {
    /// JSON encoding failed; holds the encoder's error text.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// The input could not be decoded as a chat envelope; holds the decoder's error text.
    #[error("Deserialization error: {0}")]
    Deserialization(String),

    /// A payload exceeded the size limit, as `(actual_len, max_len)` in bytes.
    #[error("Message too large: {0} bytes (max: {1})")]
    MessageTooLarge(usize, usize),

    /// The envelope's `version` did not match this build's protocol version; holds the
    /// version that was received.
    #[error("Invalid protocol version: {0}")]
    InvalidProtocolVersion(u16),

    /// Generic format error.
    ///
    /// NOTE(review): not constructed anywhere in this file; confirm whether code elsewhere
    /// produces it before relying on it.
    #[error("Invalid message format")]
    InvalidFormat,
}
35
/// A single message of the chat protocol.
///
/// Serialized as an internally tagged JSON object: the variant name, converted to
/// snake_case, is stored under a `"type"` key next to the variant's fields
/// (e.g. `"type": "peer_list_request"`). Timestamps are encoded by `timestamp_serde` as a
/// `(seconds, nanoseconds)` pair since the Unix epoch. Serde emits fields in declaration
/// order, so reordering fields changes the serialized bytes (though not their meaning).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChatMessage {
    /// Join notification for the peer `peer_id`, using display name `nickname`.
    Join {
        nickname: String,
        /// Raw 32-byte peer identifier (the inner value of a `PeerId`).
        peer_id: [u8; 32],
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// Leave notification for the peer `peer_id`.
    Leave {
        nickname: String,
        peer_id: [u8; 32],
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// A text message from `peer_id`.
    Text {
        nickname: String,
        peer_id: [u8; 32],
        text: String,
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// A free-form status string (e.g. "Away") published by `peer_id`.
    Status {
        nickname: String,
        peer_id: [u8; 32],
        status: String,
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// A text message from `from_peer_id` addressed to the single peer `to_peer_id`.
    Direct {
        from_nickname: String,
        from_peer_id: [u8; 32],
        to_peer_id: [u8; 32],
        text: String,
        #[serde(with = "timestamp_serde")]
        timestamp: SystemTime,
    },

    /// Typing indicator; unlike the other sender-bearing variants it carries no timestamp.
    Typing {
        nickname: String,
        peer_id: [u8; 32],
        is_typing: bool,
    },

    /// Request for the current peer list.
    /// NOTE(review): `peer_id` presumably identifies the requester; confirm with callers.
    PeerListRequest { peer_id: [u8; 32] },

    /// Reply carrying a list of known peers.
    PeerListResponse { peers: Vec<PeerInfo> },
}
97
/// Description of one peer, as carried in `ChatMessage::PeerListResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PeerInfo {
    /// Raw 32-byte peer identifier.
    pub peer_id: [u8; 32],
    /// Display name of the peer.
    pub nickname: String,
    /// Free-form status string for the peer.
    pub status: String,
    /// Join time, encoded by `timestamp_serde` as `(seconds, nanoseconds)` since the Unix epoch.
    #[serde(with = "timestamp_serde")]
    pub joined_at: SystemTime,
}
107
108mod timestamp_serde {
110 use serde::{Deserialize, Deserializer, Serialize, Serializer};
111 use std::time::{Duration, SystemTime, UNIX_EPOCH};
112
113 pub(super) fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
114 where
115 S: Serializer,
116 {
117 let duration = time
118 .duration_since(UNIX_EPOCH)
119 .map_err(serde::ser::Error::custom)?;
120 let secs = duration.as_secs();
122 let nanos = duration.subsec_nanos();
123 (secs, nanos).serialize(serializer)
124 }
125
126 pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
127 where
128 D: Deserializer<'de>,
129 {
130 let (secs, nanos): (u64, u32) = Deserialize::deserialize(deserializer)?;
131 Ok(UNIX_EPOCH + Duration::new(secs, nanos))
132 }
133}
134
/// Versioned envelope wrapping a chat message on the wire, encoded as JSON of the form
/// `{"version": <u16>, "message": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
struct ChatWireFormat {
    /// Protocol version of the sender; compared against `CHAT_PROTOCOL_VERSION` on receipt.
    version: u16,
    /// The wrapped chat message.
    message: ChatMessage,
}
143
144impl ChatMessage {
145 pub fn join(nickname: String, peer_id: PeerId) -> Self {
147 Self::Join {
148 nickname,
149 peer_id: peer_id.0,
150 timestamp: SystemTime::now(),
151 }
152 }
153
154 pub fn leave(nickname: String, peer_id: PeerId) -> Self {
156 Self::Leave {
157 nickname,
158 peer_id: peer_id.0,
159 timestamp: SystemTime::now(),
160 }
161 }
162
163 pub fn text(nickname: String, peer_id: PeerId, text: String) -> Self {
165 Self::Text {
166 nickname,
167 peer_id: peer_id.0,
168 text,
169 timestamp: SystemTime::now(),
170 }
171 }
172
173 pub fn status(nickname: String, peer_id: PeerId, status: String) -> Self {
175 Self::Status {
176 nickname,
177 peer_id: peer_id.0,
178 status,
179 timestamp: SystemTime::now(),
180 }
181 }
182
183 pub fn direct(
185 from_nickname: String,
186 from_peer_id: PeerId,
187 to_peer_id: PeerId,
188 text: String,
189 ) -> Self {
190 Self::Direct {
191 from_nickname,
192 from_peer_id: from_peer_id.0,
193 to_peer_id: to_peer_id.0,
194 text,
195 timestamp: SystemTime::now(),
196 }
197 }
198
199 pub fn typing(nickname: String, peer_id: PeerId, is_typing: bool) -> Self {
201 Self::Typing {
202 nickname,
203 peer_id: peer_id.0,
204 is_typing,
205 }
206 }
207
208 pub fn serialize(&self) -> Result<Vec<u8>, ChatError> {
210 let wire_format = ChatWireFormat {
211 version: CHAT_PROTOCOL_VERSION,
212 message: self.clone(),
213 };
214
215 let data = serde_json::to_vec(&wire_format)
216 .map_err(|e| ChatError::Serialization(e.to_string()))?;
217
218 if data.len() > MAX_MESSAGE_SIZE {
219 return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
220 }
221
222 Ok(data)
223 }
224
225 pub fn deserialize(data: &[u8]) -> Result<Self, ChatError> {
227 if data.len() > MAX_MESSAGE_SIZE {
228 return Err(ChatError::MessageTooLarge(data.len(), MAX_MESSAGE_SIZE));
229 }
230
231 let wire_format: ChatWireFormat =
232 serde_json::from_slice(data).map_err(|e| ChatError::Deserialization(e.to_string()))?;
233
234 if wire_format.version != CHAT_PROTOCOL_VERSION {
235 return Err(ChatError::InvalidProtocolVersion(wire_format.version));
236 }
237
238 Ok(wire_format.message)
239 }
240
241 pub fn peer_id(&self) -> Option<PeerId> {
243 match self {
244 Self::Join { peer_id, .. }
245 | Self::Leave { peer_id, .. }
246 | Self::Text { peer_id, .. }
247 | Self::Status { peer_id, .. }
248 | Self::Typing { peer_id, .. }
249 | Self::PeerListRequest { peer_id, .. } => Some(PeerId(*peer_id)),
250 Self::Direct { from_peer_id, .. } => Some(PeerId(*from_peer_id)),
251 Self::PeerListResponse { .. } => None,
252 }
253 }
254
255 pub fn nickname(&self) -> Option<&str> {
257 match self {
258 Self::Join { nickname, .. }
259 | Self::Leave { nickname, .. }
260 | Self::Text { nickname, .. }
261 | Self::Status { nickname, .. }
262 | Self::Typing { nickname, .. } => Some(nickname),
263 Self::Direct { from_nickname, .. } => Some(from_nickname),
264 Self::PeerListRequest { .. } | Self::PeerListResponse { .. } => None,
265 }
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, UNIX_EPOCH};

    #[test]
    fn test_message_serialization() {
        let peer_id = PeerId([1u8; 32]);
        let message = ChatMessage::text(
            "test-user".to_string(),
            peer_id,
            "Hello, world!".to_string(),
        );

        let data = message.serialize().unwrap();
        assert!(data.len() < MAX_MESSAGE_SIZE);

        let deserialized = ChatMessage::deserialize(&data).unwrap();
        assert_eq!(message, deserialized);
    }

    #[test]
    fn test_all_message_types() {
        let peer_id = PeerId([2u8; 32]);
        let messages = vec![
            ChatMessage::join("alice".to_string(), peer_id),
            ChatMessage::leave("alice".to_string(), peer_id),
            ChatMessage::text("alice".to_string(), peer_id, "Hello".to_string()),
            ChatMessage::status("alice".to_string(), peer_id, "Away".to_string()),
            ChatMessage::direct(
                "alice".to_string(),
                peer_id,
                PeerId([3u8; 32]),
                "Private message".to_string(),
            ),
            ChatMessage::typing("alice".to_string(), peer_id, true),
            ChatMessage::PeerListRequest { peer_id: peer_id.0 },
            ChatMessage::PeerListResponse {
                peers: vec![PeerInfo {
                    peer_id: peer_id.0,
                    nickname: "alice".to_string(),
                    status: "Online".to_string(),
                    joined_at: SystemTime::now(),
                }],
            },
        ];

        // Timestamps are encoded as whole seconds plus sub-second nanoseconds, so a round
        // trip reproduces every field exactly and full equality holds for every variant.
        for msg in messages {
            let data = msg.serialize().unwrap();
            let deserialized = ChatMessage::deserialize(&data).unwrap();
            assert_eq!(msg, deserialized);
        }
    }

    #[test]
    fn test_message_too_large() {
        let peer_id = PeerId([4u8; 32]);
        let large_text = "a".repeat(MAX_MESSAGE_SIZE);
        let message = ChatMessage::text("user".to_string(), peer_id, large_text);

        assert!(
            matches!(message.serialize(), Err(ChatError::MessageTooLarge(_, _))),
            "Expected MessageTooLarge error"
        );
    }

    #[test]
    fn test_deserialize_rejects_oversized_input() {
        let data = vec![b' '; MAX_MESSAGE_SIZE + 1];

        match ChatMessage::deserialize(&data) {
            Err(ChatError::MessageTooLarge(len, max)) => {
                assert_eq!(len, MAX_MESSAGE_SIZE + 1);
                assert_eq!(max, MAX_MESSAGE_SIZE);
            }
            other => panic!("Expected MessageTooLarge error, got {:?}", other),
        }
    }

    #[test]
    fn test_invalid_version() {
        let peer_id = PeerId([5u8; 32]);
        let message = ChatMessage::text("user".to_string(), peer_id, "test".to_string());

        let wire_format = ChatWireFormat {
            version: 999,
            message,
        };

        let data = serde_json::to_vec(&wire_format).unwrap();

        match ChatMessage::deserialize(&data) {
            Err(ChatError::InvalidProtocolVersion(999)) => {}
            _ => panic!("Expected InvalidProtocolVersion error"),
        }
    }

    #[test]
    fn test_malformed_payload() {
        match ChatMessage::deserialize(b"not json") {
            Err(ChatError::Deserialization(_)) => {}
            other => panic!("Expected Deserialization error, got {:?}", other),
        }
    }

    #[test]
    fn test_pre_epoch_timestamp_is_rejected() {
        let message = ChatMessage::Join {
            nickname: "user".to_string(),
            peer_id: [6u8; 32],
            timestamp: UNIX_EPOCH - Duration::from_secs(1),
        };

        assert!(matches!(
            message.serialize(),
            Err(ChatError::Serialization(_))
        ));
    }

    #[test]
    fn test_accessors() {
        let direct = ChatMessage::direct(
            "bob".to_string(),
            PeerId([7u8; 32]),
            PeerId([8u8; 32]),
            "hi".to_string(),
        );
        assert_eq!(direct.nickname(), Some("bob"));
        assert_eq!(direct.peer_id().map(|p| p.0), Some([7u8; 32]));

        let response = ChatMessage::PeerListResponse { peers: Vec::new() };
        assert_eq!(response.nickname(), None);
        assert!(response.peer_id().is_none());
    }
}