1use crate::nfc_uid::NfcUid;
11use bitmask_enum::bitmask;
12use byteorder::ReadBytesExt;
13use md5::Digest;
14use std::io::{self, Cursor, Read};
15use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
16use uuid::Uuid;
17
/// A raw protocol message: a one-byte type code plus an opaque payload.
///
/// This is the untyped form that [`ClientMessage`] is parsed from and that
/// [`ServerMessage`] is serialized into.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message {
    /// Type code that determines how `payload` is interpreted.
    pub message_type: u8,
    /// Payload bytes, without the type byte or the length prefix added when framing.
    pub payload: Vec<u8>,
}
26
27impl Message {
28 pub fn new(message_type: u8, payload: Vec<u8>) -> Self {
30 Self {
31 message_type,
32 payload,
33 }
34 }
35
36 pub fn into_bytes(self) -> Vec<u8> {
42 let mut bytes = Vec::with_capacity(1 + 4 + self.payload.len());
43 bytes.push(self.message_type);
44 bytes.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
45 bytes.extend(self.payload);
46
47 bytes
48 }
49}
50
51#[derive(Clone, Copy, Debug, PartialEq, Eq)]
53pub struct DataHash(Digest);
54
55impl From<Digest> for DataHash {
56 fn from(value: Digest) -> Self {
57 Self(value)
58 }
59}
60
61impl DataHash {
62 pub fn as_bytes(&self) -> &[u8] {
64 self.0.as_ref()
65 }
66
67 pub fn into_tagged_bytes(self) -> [u8; 17] {
70 let mut bytes = [0u8; 17];
71 bytes[0] = 16;
72 bytes[1..].copy_from_slice(self.0.as_slice());
73
74 bytes
75 }
76
77 fn from_cursor_opt(cursor: &mut Cursor<Vec<u8>>) -> Result<Option<Self>, io::Error> {
79 let length = cursor.read_u8()?;
80
81 if length == 0 {
82 return Ok(None);
83 }
84
85 if length != 16 {
86 return Err(io::Error::new(
87 io::ErrorKind::InvalidInput,
88 format!("invalid data hash length: {length}"),
89 ));
90 }
91
92 let mut bytes = [0u8; 16];
93 cursor.read_exact(&mut bytes)?;
94
95 Ok(Some(Self(Digest(bytes))))
96 }
97}
98
// Optional protocol features, expanded by the `bitmask_enum` macro into a `u64` bitmask.
// The server handshake serializes it as the little-endian bytes of `bits()`.
#[bitmask(u64)]
#[bitmask_config(vec_debug)]
pub enum Capabilities {
    // Bit 0. Presumably advertises support for the preload-check exchange
    // (ClientMessage::PreloadCheck / ServerMessage::PreloadMatch / PreloadMismatch) — TODO confirm.
    PreloadCheck = 0x1,
}
106
107#[derive(Debug)]
109pub enum ClientMessage {
110 ClientHandshake { min_version: u8, max_version: u8 },
112
113 Authentication {
115 client_id: String,
116 client_secret: String,
117 ip_addr: IpAddr,
118 },
119
120 Ping,
122
123 Quit,
125
126 Bloop { nfc_uid: NfcUid },
128
129 RetrieveAudio { achievement_id: Uuid },
131
132 PreloadCheck {
134 audio_manifest_hash: Option<DataHash>,
135 },
136
137 Unknown(Message),
139}
140
141impl TryFrom<Message> for ClientMessage {
142 type Error = io::Error;
143
144 fn try_from(message: Message) -> Result<Self, Self::Error> {
148 let mut cursor = Cursor::new(message.payload);
149
150 match message.message_type {
151 0x01 => {
152 let min_version = cursor.read_u8()?;
153 let max_version = cursor.read_u8()?;
154
155 Ok(Self::ClientHandshake {
156 min_version,
157 max_version,
158 })
159 }
160 0x03 => {
161 let client_id = read_string(&mut cursor)?;
162 let client_secret = read_string(&mut cursor)?;
163 let ip_addr = read_ip_addr(&mut cursor)?;
164
165 Ok(Self::Authentication {
166 client_id,
167 client_secret,
168 ip_addr,
169 })
170 }
171 0x05 => Ok(Self::Ping),
172 0x07 => Ok(Self::Quit),
173 0x08 => {
174 let length = cursor.read_u8()? as usize;
175 let mut buffer = vec![0u8; length];
176 cursor.read_exact(&mut buffer)?;
177
178 let nfc_uid = NfcUid::try_from(buffer.as_slice()).map_err(|_| {
179 io::Error::new(io::ErrorKind::InvalidData, "Invalid NFC UID length or data")
180 })?;
181
182 Ok(Self::Bloop { nfc_uid })
183 }
184 0x0a => {
185 let mut uuid = [0u8; 16];
186 cursor.read_exact(&mut uuid)?;
187
188 Ok(Self::RetrieveAudio {
189 achievement_id: Uuid::from_bytes(uuid),
190 })
191 }
192 0x0c => {
193 let hash = DataHash::from_cursor_opt(&mut cursor)?;
194
195 Ok(Self::PreloadCheck {
196 audio_manifest_hash: hash,
197 })
198 }
199 code => Ok(Self::Unknown(Message::new(code, cursor.into_inner()))),
200 }
201 }
202}
203
204fn read_string(cursor: &mut Cursor<Vec<u8>>) -> Result<String, io::Error> {
208 let length = cursor.read_u8()? as usize;
209 let mut buffer = vec![0; length];
210 cursor.read_exact(&mut buffer)?;
211
212 String::from_utf8(buffer)
213 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid UTF-8 string"))
214}
215
216fn read_ip_addr(cursor: &mut Cursor<Vec<u8>>) -> Result<IpAddr, io::Error> {
221 let kind = cursor.read_u8()?;
222
223 match kind {
224 4 => {
225 let mut bytes = [0u8; 4];
226 cursor.read_exact(&mut bytes)?;
227 Ok(IpAddr::V4(Ipv4Addr::from(bytes)))
228 }
229 6 => {
230 let mut bytes = [0u8; 16];
231 cursor.read_exact(&mut bytes)?;
232 Ok(IpAddr::V6(Ipv6Addr::from(bytes)))
233 }
234 _ => Err(io::Error::new(
235 io::ErrorKind::InvalidData,
236 format!("invalid IP address type: {kind}"),
237 )),
238 }
239}
240
241#[derive(Debug)]
244pub struct AchievementRecord {
245 pub id: Uuid,
247 pub audio_file_hash: Option<DataHash>,
249}
250
251impl AchievementRecord {
252 pub fn into_bytes(self) -> Vec<u8> {
257 let mut bytes = Vec::with_capacity(16 + 17);
258 bytes.extend_from_slice(&self.id.into_bytes());
259
260 match self.audio_file_hash {
261 Some(hash) => bytes.extend_from_slice(&hash.into_tagged_bytes()),
262 None => bytes.push(0),
263 }
264
265 bytes
266 }
267}
268
/// An error reported to the client in a `ServerMessage::Error` (type `0x00`).
///
/// Each variant maps to a one-byte wire code. `Custom` carries a caller-chosen code.
/// The type is a small value enum, so it derives the usual value traits (`Copy`,
/// equality, hashing) so callers can compare and reuse it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorResponse {
    /// Wire code 0.
    UnexpectedMessage,
    /// Wire code 1.
    MalformedMessage,
    /// Wire code 2.
    UnsupportedVersionRange,
    /// Wire code 3.
    InvalidCredentials,
    /// Wire code 4.
    UnknownNfcUid,
    /// Wire code 5.
    NfcUidThrottled,
    /// Wire code 6.
    AudioUnavailable,
    /// An arbitrary wire code, sent as-is. It is not checked against the codes above.
    Custom(u8),
}
282
283impl From<ErrorResponse> for u8 {
284 fn from(error: ErrorResponse) -> u8 {
285 error.into_error_code()
286 }
287}
288
289impl ErrorResponse {
290 fn into_error_code(self) -> u8 {
292 match self {
293 Self::UnexpectedMessage => 0,
294 Self::MalformedMessage => 1,
295 Self::UnsupportedVersionRange => 2,
296 Self::InvalidCredentials => 3,
297 Self::UnknownNfcUid => 4,
298 Self::NfcUidThrottled => 5,
299 Self::AudioUnavailable => 6,
300 Self::Custom(code) => code,
301 }
302 }
303}
304
305#[derive(Debug)]
307pub enum ServerMessage {
308 Error(ErrorResponse),
310
311 ServerHandshake {
314 accepted_version: u8,
315 capabilities: Capabilities,
316 },
317
318 AuthenticationAccepted,
320
321 Pong,
323
324 BloopAccepted {
326 achievements: Vec<AchievementRecord>,
327 },
328
329 AudioData { data: Vec<u8> },
331
332 PreloadMatch,
334
335 PreloadMismatch {
337 audio_manifest_hash: DataHash,
338 achievements: Vec<AchievementRecord>,
339 },
340
341 Custom(Message),
343}
344
345impl From<ServerMessage> for Message {
346 fn from(server_message: ServerMessage) -> Message {
349 match server_message {
350 ServerMessage::Error(error) => Message::new(0x00, vec![error.into_error_code()]),
351 ServerMessage::ServerHandshake {
352 accepted_version,
353 capabilities,
354 } => {
355 let mut payload = Vec::with_capacity(9);
356 payload.push(accepted_version);
357 payload.extend_from_slice(&capabilities.bits().to_le_bytes());
358 Message::new(0x02, payload)
359 }
360 ServerMessage::AuthenticationAccepted => Message::new(0x04, vec![]),
361 ServerMessage::Pong => Message::new(0x06, vec![]),
362 ServerMessage::BloopAccepted { achievements } => {
363 let mut payload = Vec::with_capacity(1 + achievements.len() * (16 + 17));
364 payload.push(achievements.len() as u8);
365
366 for achievement in achievements {
367 payload.extend(achievement.into_bytes())
368 }
369
370 Message::new(0x09, payload)
371 }
372 ServerMessage::AudioData { data } => {
373 let mut payload = Vec::with_capacity(4 + data.len());
374 payload.extend_from_slice(&(data.len() as u32).to_le_bytes());
375 payload.extend(data);
376
377 Message::new(0x0b, payload)
378 }
379 ServerMessage::PreloadMatch => Message::new(0x0d, vec![]),
380 ServerMessage::PreloadMismatch {
381 audio_manifest_hash,
382 achievements,
383 } => {
384 let mut payload = Vec::with_capacity(17 + 1 + achievements.len() * (16 + 17));
385 payload.extend_from_slice(&audio_manifest_hash.into_tagged_bytes());
386 payload.extend_from_slice(&(achievements.len() as u32).to_le_bytes());
387
388 for achievement in achievements {
389 payload.extend(achievement.into_bytes())
390 }
391
392 Message::new(0x0e, payload)
393 }
394 ServerMessage::Custom(message) => message,
395 }
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use uuid::Uuid;

    /// Builds a raw message with the given type code and a copy of `payload`.
    fn make_message(msg_type: u8, payload: &[u8]) -> Message {
        Message::new(msg_type, payload.to_vec())
    }

    // ---- Message framing ----

    #[test]
    fn message_into_bytes_writes_type_length_and_payload() {
        let bytes = Message::new(0x0b, vec![1, 2, 3]).into_bytes();
        assert_eq!(bytes, vec![0x0b, 3, 0, 0, 0, 1, 2, 3]);
    }

    // ---- ClientMessage parsing: success cases ----

    #[test]
    fn client_handshake_parses_correctly_from_message() {
        let payload = [1u8, 5];
        let msg = make_message(0x01, &payload);
        let client_msg = ClientMessage::try_from(msg).unwrap();

        match client_msg {
            ClientMessage::ClientHandshake {
                min_version,
                max_version,
            } => {
                assert_eq!(min_version, 1);
                assert_eq!(max_version, 5);
            }
            _ => panic!("Expected ClientHandshake variant"),
        }
    }

    #[test]
    fn authentication_parses_correctly_from_message() {
        let mut payload = vec![];
        payload.push(3);
        payload.extend(b"foo");
        payload.push(3);
        payload.extend(b"bar");
        payload.push(4);
        payload.extend(&[127, 0, 0, 1]);

        let msg = make_message(0x03, &payload);
        let client_msg = ClientMessage::try_from(msg).unwrap();

        match client_msg {
            ClientMessage::Authentication {
                client_id,
                client_secret,
                ip_addr,
            } => {
                assert_eq!(client_id, "foo");
                assert_eq!(client_secret, "bar");
                assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
            }
            _ => panic!("Expected Authentication variant"),
        }
    }

    #[test]
    fn authentication_parses_ipv6_address() {
        let mut payload = vec![3];
        payload.extend(b"foo");
        payload.push(3);
        payload.extend(b"bar");
        payload.push(6);
        payload.extend(&Ipv6Addr::LOCALHOST.octets());

        let msg = make_message(0x03, &payload);
        let client_msg = ClientMessage::try_from(msg).unwrap();

        match client_msg {
            ClientMessage::Authentication { ip_addr, .. } => {
                assert_eq!(ip_addr, IpAddr::V6(Ipv6Addr::LOCALHOST));
            }
            _ => panic!("Expected Authentication variant"),
        }
    }

    #[test]
    fn ping_message_parses_as_ping_variant() {
        let msg = make_message(0x05, &[]);
        let client_msg = ClientMessage::try_from(msg).unwrap();
        assert!(matches!(client_msg, ClientMessage::Ping));
    }

    #[test]
    fn quit_message_parses_as_quit_variant() {
        let msg = make_message(0x07, &[]);
        let client_msg = ClientMessage::try_from(msg).unwrap();
        assert!(matches!(client_msg, ClientMessage::Quit));
    }

    #[test]
    fn bloop_message_parses_single_nfc_uid_correctly() {
        let payload = [4u8, 1, 2, 3, 4];
        let msg = make_message(0x08, &payload);
        let client_msg = ClientMessage::try_from(msg).unwrap();

        match client_msg {
            ClientMessage::Bloop { nfc_uid } => {
                let expected = NfcUid::try_from(&[1, 2, 3, 4][..]).unwrap();
                assert_eq!(nfc_uid, expected);
            }
            _ => panic!("Expected Bloop variant"),
        }
    }

    #[test]
    fn retrieve_audio_message_parses_uuid_correctly() {
        let uuid = Uuid::new_v4();
        let payload = uuid.as_bytes();
        let msg = make_message(0x0a, payload);
        let client_msg = ClientMessage::try_from(msg).unwrap();

        match client_msg {
            ClientMessage::RetrieveAudio { achievement_id } => {
                assert_eq!(achievement_id, uuid);
            }
            _ => panic!("Expected RetrieveAudio variant"),
        }
    }

    #[test]
    fn preload_check_message_parses_with_some_hash() {
        // Non-zero bytes so the test proves the digest is actually copied.
        let expected = [0xABu8; 16];
        let mut payload = vec![16];
        payload.extend_from_slice(&expected);
        let msg = make_message(0x0c, &payload);
        let client_msg = ClientMessage::try_from(msg).unwrap();

        match client_msg {
            ClientMessage::PreloadCheck {
                audio_manifest_hash,
            } => {
                let hash = audio_manifest_hash.expect("expected a manifest hash");
                assert_eq!(hash.as_bytes(), &expected);
            }
            _ => panic!("Expected PreloadCheck variant"),
        }
    }

    #[test]
    fn preload_check_message_parses_with_none_hash() {
        let payload = [0];
        let msg = make_message(0x0c, &payload);
        let client_msg = ClientMessage::try_from(msg).unwrap();

        match client_msg {
            ClientMessage::PreloadCheck {
                audio_manifest_hash,
            } => {
                assert!(audio_manifest_hash.is_none());
            }
            _ => panic!("Expected PreloadCheck variant"),
        }
    }

    #[test]
    fn unknown_message_parses_correctly_with_payload_preserved() {
        let payload = [1, 2, 3];
        let msg = make_message(0xFF, &payload);
        let client_msg = ClientMessage::try_from(msg).unwrap();

        match client_msg {
            ClientMessage::Unknown(m) => {
                assert_eq!(m.message_type, 0xFF);
                assert_eq!(m.payload, payload);
            }
            _ => panic!("Expected Unknown variant"),
        }
    }

    // ---- ClientMessage parsing: failure cases ----

    #[test]
    fn client_handshake_fails_if_payload_too_short() {
        let msg = make_message(0x01, &[1]);
        assert!(ClientMessage::try_from(msg).is_err());
    }

    #[test]
    fn authentication_fails_on_invalid_utf8_client_id() {
        let mut payload = vec![2];
        payload.extend(&[0xff, 0xff]);
        payload.push(3);
        payload.extend(b"bar");
        payload.push(4);
        payload.extend(&[127, 0, 0, 1]);

        let msg = make_message(0x03, &payload);
        assert!(ClientMessage::try_from(msg).is_err());
    }

    #[test]
    fn authentication_fails_on_invalid_utf8_client_secret() {
        let mut payload = vec![3];
        payload.extend(b"foo");
        payload.push(2);
        payload.extend(&[0xff, 0xff]);
        payload.push(4);
        payload.extend(&[127, 0, 0, 1]);

        let msg = make_message(0x03, &payload);
        assert!(ClientMessage::try_from(msg).is_err());
    }

    #[test]
    fn authentication_fails_on_invalid_ip_kind() {
        let mut payload = vec![3];
        payload.extend(b"foo");
        payload.push(3);
        payload.extend(b"bar");
        payload.push(0xff);
        payload.extend(&[1, 2, 3, 4]);

        let msg = make_message(0x03, &payload);
        assert!(ClientMessage::try_from(msg).is_err());
    }

    #[test]
    fn bloop_fails_if_nfc_uid_length_mismatch() {
        // Declares 5 UID bytes but only 4 follow.
        let payload = [5u8, 1, 2, 3, 4];
        let msg = make_message(0x08, &payload);
        assert!(ClientMessage::try_from(msg).is_err());
    }

    #[test]
    fn bloop_fails_on_empty_payload() {
        let msg = make_message(0x08, &[]);
        assert!(ClientMessage::try_from(msg).is_err());
    }

    #[test]
    fn retrieve_audio_fails_if_uuid_too_short() {
        let payload = [0u8; 15];
        let msg = make_message(0x0a, &payload);
        assert!(ClientMessage::try_from(msg).is_err());
    }

    #[test]
    fn preload_check_fails_on_invalid_length() {
        let payload = [1, 0];
        let msg = make_message(0x0c, &payload);
        assert!(ClientMessage::try_from(msg).is_err());
    }

    // ---- ServerMessage serialization ----

    #[test]
    fn server_message_error_serializes_correctly() {
        let server_msg = ServerMessage::Error(ErrorResponse::InvalidCredentials);
        let message: Message = server_msg.into();
        assert_eq!(message.message_type, 0x00);
        assert_eq!(message.payload, vec![3]);
    }

    #[test]
    fn error_response_custom_code_passes_through() {
        assert_eq!(u8::from(ErrorResponse::Custom(42)), 42);
        assert_eq!(u8::from(ErrorResponse::AudioUnavailable), 6);
    }

    #[test]
    fn server_handshake_serializes_correctly() {
        let features = Capabilities::none();
        let server_msg = ServerMessage::ServerHandshake {
            accepted_version: 7,
            capabilities: features,
        };
        let message: Message = server_msg.into();
        assert_eq!(message.message_type, 0x02);
        assert_eq!(message.payload.len(), 9);
        assert_eq!(message.payload[0], 7);
        assert_eq!(&message.payload[1..], &features.bits().to_le_bytes());
    }

    #[test]
    fn server_handshake_serializes_capability_bits() {
        let server_msg = ServerMessage::ServerHandshake {
            accepted_version: 1,
            capabilities: Capabilities::PreloadCheck,
        };
        let message: Message = server_msg.into();
        assert_eq!(&message.payload[1..], &1u64.to_le_bytes());
    }

    #[test]
    fn authentication_accepted_serializes_to_empty_payload() {
        let server_msg = ServerMessage::AuthenticationAccepted;
        let message: Message = server_msg.into();
        assert_eq!(message.message_type, 0x04);
        assert!(message.payload.is_empty());
    }

    #[test]
    fn pong_serializes_to_empty_payload() {
        let server_msg = ServerMessage::Pong;
        let message: Message = server_msg.into();
        assert_eq!(message.message_type, 0x06);
        assert!(message.payload.is_empty());
    }

    #[test]
    fn bloop_accepted_serializes_with_achievements() {
        let uuid = Uuid::new_v4();
        let record = AchievementRecord {
            id: uuid,
            audio_file_hash: None,
        };

        let server_msg = ServerMessage::BloopAccepted {
            achievements: vec![record],
        };

        let message: Message = server_msg.into();
        assert_eq!(message.message_type, 0x09);
        assert_eq!(message.payload[0], 1);

        assert_eq!(&message.payload[1..17], uuid.as_bytes());
        assert_eq!(message.payload.len(), 1 + 16 + 1);
    }

    #[test]
    fn achievement_record_serializes_tagged_hash() {
        let uuid = Uuid::new_v4();
        let record = AchievementRecord {
            id: uuid,
            audio_file_hash: Some(DataHash::from(Digest([7u8; 16]))),
        };

        let bytes = record.into_bytes();
        assert_eq!(bytes.len(), 16 + 17);
        assert_eq!(&bytes[..16], uuid.as_bytes());
        assert_eq!(bytes[16], 16);
        assert_eq!(&bytes[17..], &[7u8; 16]);
    }

    #[test]
    fn audio_data_serializes_correctly() {
        let data = vec![1, 2, 3, 4, 5];
        let server_msg = ServerMessage::AudioData { data: data.clone() };
        let message: Message = server_msg.into();

        assert_eq!(message.message_type, 0x0b);
        assert_eq!(&message.payload[4..], &data[..]);

        let length = u32::from_le_bytes(message.payload[0..4].try_into().unwrap());
        assert_eq!(length as usize, data.len());
    }

    #[test]
    fn preload_match_serializes_to_empty_payload() {
        let server_msg = ServerMessage::PreloadMatch;
        let message: Message = server_msg.into();

        assert_eq!(message.message_type, 0x0d);
        assert!(message.payload.is_empty());
    }

    #[test]
    fn preload_mismatch_serializes_with_hash_and_achievements() {
        let hash = DataHash(Digest([1u8; 16]));
        let uuid = Uuid::new_v4();
        let record = AchievementRecord {
            id: uuid,
            audio_file_hash: None,
        };

        let server_msg = ServerMessage::PreloadMismatch {
            audio_manifest_hash: hash,
            achievements: vec![record],
        };

        let message: Message = server_msg.into();
        assert_eq!(message.message_type, 0x0e);

        // Tagged manifest hash.
        assert_eq!(message.payload[0], 16);
        assert_eq!(&message.payload[1..17], &hash.0.as_slice()[..]);

        // Little-endian u32 record count.
        let count = u32::from_le_bytes(message.payload[17..21].try_into().unwrap());
        assert_eq!(count, 1);

        // The record: UUID, then a single zero tag byte because it has no audio hash.
        assert_eq!(&message.payload[21..37], uuid.as_bytes());
        assert_eq!(message.payload[37], 0);
        assert_eq!(message.payload.len(), 17 + 4 + 16 + 1);
    }

    #[test]
    fn custom_server_message_passes_through_as_is() {
        let original = Message::new(0xAB, vec![9, 8, 7]);
        let server_msg = ServerMessage::Custom(original.clone());
        let message: Message = server_msg.into();

        assert_eq!(message.message_type, 0xAB);
        assert_eq!(message.payload, vec![9, 8, 7]);
    }
}