1use serde::{Deserialize, Deserializer, Serialize, Serializer};
17use serde_repr::{Deserialize_repr, Serialize_repr};
18
19use crate::{
20 utils::{
21 byte_and_inner_idx, check_status_against_bits, compress_and_encode, decode_and_decompress,
22 },
23 Error, Result, UriBuf,
24};
25
26#[derive(Debug, Clone, Copy, PartialEq, Serialize_repr, Deserialize_repr)]
29#[repr(u8)]
30pub enum StatusBits {
31 One = 1,
33 Two = 2,
35 Four = 4,
37 Eight = 8,
39}
40
41impl StatusBits {
42 pub fn from_u8(bits: u8) -> Option<Self> {
45 match bits {
46 b if b == Self::One as u8 => Some(Self::One),
47 b if b == Self::Two as u8 => Some(Self::Two),
48 b if b == Self::Four as u8 => Some(Self::Four),
49 b if b == Self::Eight as u8 => Some(Self::Eight),
50 _ => None,
51 }
52 }
53}
54
55impl std::fmt::Display for StatusBits {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 write!(f, "{}", *self as u8)
58 }
59}
60
61#[derive(Debug)]
77pub struct StatusListInternal {
78 status_list: StatusList,
80
81 size: usize,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct StatusList {
101 bits: StatusBits,
104
105 #[serde(serialize_with = "serialize_lst", deserialize_with = "deserialize_lst")]
116 lst: Vec<u8>,
117
118 #[serde(skip_serializing_if = "Option::is_none")]
127 aggregation_uri: Option<UriBuf>,
128}
129
130fn serialize_lst<S: Serializer>(lst: &[u8], s: S) -> std::result::Result<S::Ok, S::Error> {
131 let lst_encoded =
132 compress_and_encode(lst).map_err(|e| serde::ser::Error::custom(format!("{:?}", e)))?;
133
134 s.serialize_str(&lst_encoded)
135}
136
137fn deserialize_lst<'de, D>(d: D) -> std::result::Result<Vec<u8>, D::Error>
138where
139 D: Deserializer<'de>,
140{
141 let lst_encoded = String::deserialize(d)?;
142
143 let lst = decode_and_decompress(lst_encoded)
144 .map_err(|e| serde::de::Error::custom(format!("{:?}", e)))?;
145
146 Ok(lst)
147}
148
149impl StatusListInternal {
150 pub fn new(bits: StatusBits, aggregation_uri: Option<UriBuf>) -> Self {
152 Self {
153 status_list: StatusList {
154 bits,
155 lst: Vec::new(),
156 aggregation_uri,
157 },
158 size: 0,
159 }
160 }
161
162 pub fn size(&self) -> usize {
164 self.size
165 }
166
167 pub fn status_list(&self) -> &StatusList {
169 &self.status_list
170 }
171
172 pub fn new_from_parts(
192 bits: StatusBits,
193 lst: Vec<u8>,
194 aggregation_uri: Option<UriBuf>,
195 size: usize,
196 ) -> Result<Self> {
197 if lst.is_empty() ^ (size == 0) {
200 return Err(bherror::Error::root(Error::InconsistentSize)
201 .ctx("`lst` not empty but the `size` is 0 or vice-versa"));
202 }
203
204 if size > 0 {
207 let (byte_idx, inner_idx) = byte_and_inner_idx(bits, size - 1);
209
210 if byte_idx + 1 != lst.len() {
212 return Err(bherror::Error::root(Error::InconsistentSize)
213 .ctx("`size` does not point to the last `byte`"));
214 }
215
216 let last_byte = lst.last().unwrap();
218
219 if *last_byte as u16 >> ((inner_idx + 1) * bits as u8) != 0 {
226 return Err(bherror::Error::root(Error::InconsistentSize)
227 .ctx("last `byte` is not empty after `size` elements"));
228 }
229 }
230
231 Ok(Self {
232 status_list: StatusList {
233 bits,
234 lst,
235 aggregation_uri,
236 },
237 size,
238 })
239 }
240
241 pub fn push(&mut self, status: u8) -> Result<usize> {
249 let list = &mut self.status_list;
250
251 check_status_against_bits(list.bits, status)?;
253
254 let (_, inner_idx) = byte_and_inner_idx(list.bits, self.size);
255
256 if inner_idx == 0 {
265 list.lst.push(0);
266 }
267
268 let last_byte = list.lst.last_mut().unwrap();
270
271 *last_byte |= status << (inner_idx * list.bits as u8);
273
274 self.size += 1;
275
276 Ok(self.size - 1)
277 }
278
279 pub fn update(&mut self, index: usize, status: u8) -> Result<()> {
287 if index >= self.size {
288 return Err(bherror::Error::root(Error::IndexOutOfBounds(
289 self.size, index,
290 )));
291 }
292
293 let list = &mut self.status_list;
294
295 check_status_against_bits(list.bits, status)?;
297
298 let (byte_idx, inner_idx) = byte_and_inner_idx(list.bits, index);
299
300 let byte = list
301 .lst
302 .get_mut(byte_idx)
303 .ok_or_else(|| bherror::Error::root(Error::IndexOutOfBounds(self.size, index)))?;
305
306 let bits_per_status = list.bits as u8;
307
308 let shift = inner_idx * bits_per_status;
310
311 let mask = ((1u16 << bits_per_status) - 1) as u8;
314
315 *byte = *byte & !(mask << shift) | (status << shift);
321
322 Ok(())
323 }
324}
325
326impl StatusList {
327 pub fn bits(&self) -> StatusBits {
329 self.bits
330 }
331
332 pub fn lst(&self) -> &[u8] {
334 &self.lst
335 }
336
337 pub fn aggregation_uri(&self) -> Option<&UriBuf> {
339 self.aggregation_uri.as_ref()
340 }
341
342 pub fn get(&self, index: usize) -> Option<u8> {
347 let (byte_idx, inner_idx) = byte_and_inner_idx(self.bits, index);
348
349 let mut byte = *self.lst.get(byte_idx)?;
350
351 let bits_per_status = self.bits as u8;
352
353 byte <<= 8 - (inner_idx + 1) * bits_per_status;
358 byte >>= 8 - bits_per_status;
359
360 Some(byte)
361 }
362}
363
364impl From<StatusListInternal> for StatusList {
365 fn from(list: StatusListInternal) -> Self {
366 list.status_list
367 }
368}
369
370#[cfg(test)]
371mod tests {
372 use std::ops::RangeInclusive;
373
374 use rand::{
375 distributions::{Distribution as _, Uniform},
376 thread_rng, Rng,
377 };
378 use serde_json::{Map, Value};
379
380 use super::*;
381 use crate::utils::statuses_per_byte;
382
383 fn json_value_to_object<T: Serialize>(value: T) -> Map<String, Value> {
384 if let Value::Object(obj) = serde_json::to_value(value).unwrap() {
385 obj
386 } else {
387 panic!("JSON value is not a JSON object")
388 }
389 }
390
391 fn get_lst(status_list: &StatusList) -> String {
392 if let Value::String(lst) = json_value_to_object(status_list).remove("lst").unwrap() {
393 lst
394 } else {
395 panic!("StatusList's `lst` field is not a `String`")
396 }
397 }
398
399 fn create_json_object(
400 bits: StatusBits,
401 lst: &str,
402 aggregation_uri: &Option<UriBuf>,
403 ) -> Map<String, Value> {
404 let mut expected = Map::new();
405
406 expected.insert("bits".to_owned(), (bits as u8).into());
407 expected.insert("lst".to_owned(), lst.into());
408
409 if let Some(aggregation_uri) = aggregation_uri {
410 expected.insert(
411 "aggregation_uri".to_owned(),
412 aggregation_uri.to_string().into(),
413 );
414 }
415
416 expected
417 }
418
419 fn random_statuses(
420 rng: &mut impl Rng,
421 size: usize,
422 bits: StatusBits,
423 ) -> impl Iterator<Item = u8> + '_ {
424 let max_status = ((1u16 << bits as u8) - 1) as u8;
425
426 let dist = Uniform::new_inclusive(0, max_status);
427
428 dist.sample_iter(rng).take(size)
429 }
430
431 fn test_status_list_push(
432 bits: StatusBits,
433 statuses: &[u8],
434 expected_lst: &str,
435 aggregation_uri: Option<UriBuf>,
436 ) {
437 let expected_json = create_json_object(bits, expected_lst, &aggregation_uri);
438
439 let mut status_list = StatusListInternal::new(bits, aggregation_uri);
440
441 for (i, &status) in statuses.iter().enumerate() {
442 let new_i = status_list.push(status).unwrap();
443
444 assert_eq!(i, new_i, "invalid returned index from `push`");
446 }
447
448 for (i, &status) in statuses.iter().enumerate() {
450 let get_res = status_list.status_list.get(i).unwrap();
451 assert_eq!(status, get_res, "`get` did not return the correct value");
452 }
453
454 assert_eq!(
456 expected_lst,
457 get_lst(&status_list.status_list),
458 "invalid `lst` value"
459 );
460
461 let actual_json = json_value_to_object(&status_list.status_list);
463 assert_eq!(
464 expected_json, actual_json,
465 "invalid `StatusList` serialization"
466 );
467 }
468
469 #[test]
473 fn test_status_list_one_bit_push() {
474 let bits = StatusBits::One;
475 let statuses = [1u8, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1];
476 let expected_lst = "eNrbuRgAAhcBXQ";
477
478 test_status_list_push(bits, &statuses, expected_lst, None);
479 }
480
481 #[test]
485 fn test_status_list_two_bits_push() {
486 let bits = StatusBits::Two;
487 let statuses = [1u8, 2, 0, 3, 0, 1, 0, 1, 1, 2, 3, 3];
488 let expected_lst = "eNo76fITAAPfAgc";
489
490 test_status_list_push(bits, &statuses, expected_lst, None);
491 }
492
493 fn test_status_list_push_large_status(bits: StatusBits) {
494 let mut status_list = StatusListInternal::new(bits, None);
495 let status = 1 << bits as u8;
496
497 let idx = status_list.push(status - 1).unwrap();
499 assert_eq!(0, idx);
500
501 let get_res = status_list.status_list.get(0).unwrap();
503 assert_eq!(status - 1, get_res);
504
505 let err = status_list.push(status).unwrap_err();
507 assert!(matches!(err.error, Error::StatusTooLarge(b, s) if b == bits && s == status));
508 }
509
510 #[test]
511 fn test_status_list_push_large_status_fails() {
512 test_status_list_push_large_status(StatusBits::One);
513 test_status_list_push_large_status(StatusBits::Two);
514 test_status_list_push_large_status(StatusBits::Four);
515 }
516
517 #[test]
518 fn test_status_list_eight_bits_push_max_value() {
519 let bits = StatusBits::Eight;
520 let status = u8::MAX;
521
522 let mut status_list = StatusListInternal::new(bits, None);
523
524 let idx = status_list.push(status).unwrap();
525 assert_eq!(0, idx);
526
527 let get_res = status_list.status_list.get(0).unwrap();
529 assert_eq!(status, get_res);
530 }
531
532 #[test]
536 fn test_status_list_deserialize() {
537 let bits = StatusBits::One;
538 let lst = "eNrbuRgAAhcBXQ";
539 let expected_lst = [0xb9u8, 0xa3];
541
542 let json_value = Value::Object(create_json_object(bits, lst, &None));
543
544 let status_list: StatusList = serde_json::from_value(json_value).unwrap();
545
546 assert_eq!(bits, status_list.bits);
547 assert_eq!(expected_lst, *status_list.lst);
548 }
549
550 #[test]
551 fn test_status_list_new_success() {
552 let bits = StatusBits::Four;
553 let aggregation_uri = None;
554
555 let status_list = StatusListInternal::new(bits, aggregation_uri.clone());
556
557 assert_eq!(0, status_list.size);
558 assert!(status_list.status_list.lst.is_empty());
559 assert_eq!(aggregation_uri, status_list.status_list.aggregation_uri);
560 assert_eq!(bits, status_list.status_list.bits);
561
562 let get_res = status_list.status_list.get(0);
564 assert!(get_res.is_none());
565 }
566
567 fn test_status_list_new_from_parts(
568 bits: StatusBits,
569 lst: Vec<u8>,
570 size: usize,
571 aggregation_uri: Option<UriBuf>,
572 ) {
573 let status_list =
574 StatusListInternal::new_from_parts(bits, lst.clone(), aggregation_uri.clone(), size)
575 .unwrap();
576
577 assert_eq!(size, status_list.size);
578 assert_eq!(lst, status_list.status_list.lst);
579 assert_eq!(aggregation_uri, status_list.status_list.aggregation_uri);
580 assert_eq!(bits, status_list.status_list.bits);
581 }
582
583 #[test]
584 fn test_status_list_new_from_parts_full_byte_success() {
585 let bits = StatusBits::Two;
586 let lst = vec![0xa3u8, 0x8f];
588 let size = 8;
589
590 test_status_list_new_from_parts(bits, lst, size, None);
591 }
592
593 #[test]
594 fn test_status_list_new_from_parts_partial_byte_success() {
595 let bits = StatusBits::Four;
596 let lst = vec![0xa3u8, 0x09];
598 let size = 3;
599
600 test_status_list_new_from_parts(bits, lst, size, None);
601 }
602
603 #[test]
604 fn test_status_list_new_from_parts_full_byte_fill_zeros_success() {
605 let bits = StatusBits::Four;
606 let lst = vec![0xa3u8, 0x09];
608 let size = 4;
609
610 test_status_list_new_from_parts(bits, lst, size, None);
611 }
612
613 #[test]
614 fn test_status_list_new_from_parts_empty_success() {
615 let bits = StatusBits::Eight;
616 let lst = Vec::new();
617 let size = 0;
618
619 test_status_list_new_from_parts(bits, lst, size, None);
620 }
621
622 #[test]
623 fn test_status_list_new_from_parts_empty_lst_not_size_fail() {
624 let bits = StatusBits::One;
625 let lst = Vec::new();
626 let size = 1;
627
628 let err = StatusListInternal::new_from_parts(bits, lst, None, size).unwrap_err();
629
630 assert!(matches!(err.error, Error::InconsistentSize));
631 }
632
633 #[test]
634 fn test_status_list_new_from_parts_zero_size_full_lst_fail() {
635 let bits = StatusBits::Eight;
636 let lst = vec![0x97u8, 0x03, 0xa1];
638 let size = 0;
639
640 let err = StatusListInternal::new_from_parts(bits, lst, None, size).unwrap_err();
641
642 assert!(matches!(err.error, Error::InconsistentSize));
643 }
644
645 #[test]
646 fn test_status_list_new_from_parts_size_not_last_byte_fail() {
647 let bits = StatusBits::Four;
648 let lst = vec![0x97u8, 0x03, 0xa1];
650 let size = 4;
651
652 let err = StatusListInternal::new_from_parts(bits, lst, None, size).unwrap_err();
653
654 assert!(matches!(err.error, Error::InconsistentSize));
655 }
656
657 #[test]
658 fn test_status_list_new_from_parts_size_mid_last_byte_fail() {
659 let bits = StatusBits::One;
660 let lst = vec![0x97u8, 0x03, 0xa1];
662 let size = 23;
663
664 let err = StatusListInternal::new_from_parts(bits, lst, None, size).unwrap_err();
665
666 assert!(matches!(err.error, Error::InconsistentSize));
667 }
668
669 #[test]
670 fn test_status_list_get_index_too_large() {
671 let bits = StatusBits::One;
672 let lst = vec![0x97u8, 0x03, 0xa1];
674 let size = lst.len() * statuses_per_byte(bits) as usize;
675
676 let list = StatusListInternal::new_from_parts(bits, lst, None, size)
677 .unwrap()
678 .status_list;
679
680 let get_res = list.get(size - 1).unwrap();
682 assert_eq!(1, get_res);
683
684 let get_res = list.get(size);
686 assert!(get_res.is_none());
687 }
688
689 #[test]
690 fn test_status_list_update_empty_list_fails() {
691 let mut status_list = StatusListInternal::new(StatusBits::Four, None);
692
693 let err = status_list.update(0, 3).unwrap_err();
694 assert!(matches!(err.error, Error::IndexOutOfBounds(0, 0)));
695 }
696
697 #[test]
698 fn test_status_list_update_index_too_large_fails() {
699 let mut status_list = StatusListInternal::new(StatusBits::Two, None);
700
701 status_list.push(3).unwrap();
702
703 let err = status_list.update(1, 1).unwrap_err();
704 assert!(matches!(err.error, Error::IndexOutOfBounds(1, 1)));
705
706 let err = status_list.update(5, 2).unwrap_err();
707 assert!(matches!(err.error, Error::IndexOutOfBounds(1, 5)));
708 }
709
710 fn test_status_list_update_status_too_large(bits: StatusBits) {
711 let mut status_list = StatusListInternal::new(bits, None);
712 status_list.push(0).unwrap();
713
714 let status = 1 << bits as u8;
715
716 status_list.update(0, status - 1).unwrap();
718
719 let get_res = status_list.status_list.get(0).unwrap();
721 assert_eq!(status - 1, get_res, "bits={}", bits);
722
723 let err = status_list.update(0, status).unwrap_err();
725 assert!(
726 matches!(err.error, Error::StatusTooLarge(b, s) if b == bits && s == status),
727 "bits={}",
728 bits
729 );
730 }
731
732 #[test]
733 fn test_status_list_update_status_too_large_fails() {
734 test_status_list_update_status_too_large(StatusBits::One);
735 test_status_list_update_status_too_large(StatusBits::Two);
736 test_status_list_update_status_too_large(StatusBits::Four);
737 }
738
739 #[test]
740 fn test_status_list_eight_bits_update_max_value_success() {
741 let status = u8::MAX;
742
743 let mut status_list = StatusListInternal::new(StatusBits::Eight, None);
744 status_list.push(0).unwrap();
745
746 status_list.update(0, status).unwrap();
747
748 let get_res = status_list.status_list.get(0).unwrap();
749 assert_eq!(status, get_res);
750 }
751
752 fn test_status_list_update_success(rng: &mut impl Rng, bits: StatusBits, statuses: &[u8]) {
753 let mut status_list = StatusListInternal::new(bits, None);
754
755 for status in random_statuses(rng, statuses.len(), bits) {
757 status_list.push(status).unwrap();
758 }
759
760 for (i, &status) in statuses.iter().enumerate() {
761 status_list.update(i, status).unwrap();
762
763 let get_res = status_list.status_list.get(i).unwrap();
764 assert_eq!(
765 status, get_res,
766 "index={}, statuses={:?}, status_list={:?}",
767 i, statuses, status_list
768 );
769 }
770
771 for (i, &status) in statuses.iter().enumerate() {
772 let get_res = status_list.status_list.get(i).unwrap();
773 assert_eq!(
774 status, get_res,
775 "index={}, statuses={:?}, status_list={:?}",
776 i, statuses, status_list
777 );
778 }
779 }
780
781 fn test_status_list_update_random(
782 rng: &mut impl Rng,
783 bits: StatusBits,
784 size_range: RangeInclusive<usize>,
785 ) {
786 let size = rng.gen_range(size_range);
788
789 let statuses: Vec<u8> = random_statuses(rng, size, bits).collect();
790
791 test_status_list_update_success(rng, bits, &statuses);
792 }
793
794 #[test]
795 fn test_status_list_update_random_success() {
796 let mut rng = thread_rng();
797
798 test_status_list_update_random(&mut rng, StatusBits::One, 50..=100);
799 test_status_list_update_random(&mut rng, StatusBits::Two, 50..=100);
800 test_status_list_update_random(&mut rng, StatusBits::Four, 50..=100);
801 test_status_list_update_random(&mut rng, StatusBits::Eight, 50..=100);
802 }
803}