use allocative::Allocative;
use serde::Serialize;
use crate::model::label::LabelModelError;
/// Alignment unit, in bytes, used when padding [`U8StateVec`] storage.
///
/// This equals the target's pointer width: 4 on 32-bit targets and 8 on 64-bit targets,
/// matching the previous per-width `cfg` definitions. It is derived from
/// `size_of::<usize>()` so the crate also builds on targets with other pointer widths,
/// instead of failing with an undefined constant.
pub const OS_ALIGNED_STATE_LEN: usize = std::mem::size_of::<usize>();
/// A byte vector of at most `u8::MAX` elements whose logical length is stored in a `u8`.
///
/// The backing `state` buffer may be longer than the logical contents. `U8StateVec::new`
/// zero-pads it so that the buffer length plus one (for the `state_len` byte) is a
/// multiple of [`OS_ALIGNED_STATE_LEN`]. Only the first `state_len` bytes are meaningful.
/// The `PartialEq`/`Hash` impls and the accessors look only at that prefix.
///
/// Note: the derived `Serialize` emits every field, including the whole padded `state`
/// buffer, not just the live prefix.
#[derive(Debug, Clone, Serialize, Allocative)]
pub struct U8StateVec {
    // Backing buffer: `state_len` live bytes followed by zero padding.
    state: Vec<u8>,
    // Number of live bytes at the front of `state`.
    state_len: u8,
}
/// Two vectors are equal when their stored lengths match and their live prefixes hold the
/// same bytes; padding beyond `state_len` is never compared.
impl PartialEq for U8StateVec {
    fn eq(&self, other: &Self) -> bool {
        if self.state_len != other.state_len {
            return false;
        }
        let live = usize::from(self.state_len);
        self.state[..live] == other.state[..live]
    }
}

/// Byte-wise equality on the live prefix is a full equivalence relation.
impl Eq for U8StateVec {}
/// Hashes the stored length and then the live bytes. Padding never influences the hash,
/// which keeps it consistent with the prefix-only `PartialEq`.
impl std::hash::Hash for U8StateVec {
    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
        std::hash::Hash::hash(&self.state_len, hasher);
        let live_bytes = &self.state[..usize::from(self.state_len)];
        std::hash::Hash::hash(live_bytes, hasher);
    }
}
impl U8StateVec {
    /// Creates a `U8StateVec` holding a copy of `state`.
    ///
    /// The backing buffer is zero-padded so that its length plus one (the byte taken by
    /// the stored length) is a multiple of [`OS_ALIGNED_STATE_LEN`].
    ///
    /// # Errors
    ///
    /// Returns `LabelModelError::BadLabelVecSize(state.len(), 255)` when `state` is longer
    /// than `u8::MAX` bytes, because the length must fit in a `u8`.
    pub fn new(state: &[u8]) -> Result<Self, LabelModelError> {
        let count = state.len();
        let state_len: u8 = count
            .try_into()
            .map_err(|_| LabelModelError::BadLabelVecSize(count, usize::from(u8::MAX)))?;

        // Round `count + 1` up to the next multiple of the alignment unit. The extra one
        // accounts for the `state_len` byte, which is not stored in the buffer itself.
        let padding = match (count + 1) % OS_ALIGNED_STATE_LEN {
            0 => 0,
            overhang => OS_ALIGNED_STATE_LEN - overhang,
        };

        let mut buffer = vec![0u8; count + padding];
        buffer[..count].copy_from_slice(state);

        Ok(Self {
            state: buffer,
            state_len,
        })
    }

    /// Returns the byte at `index`, or `None` when `index` is at or past the logical
    /// length. Padding bytes are never exposed.
    pub fn get(&self, index: usize) -> Option<u8> {
        if index >= self.len() {
            return None;
        }
        self.state.get(index).copied()
    }

    /// Number of live (logical) bytes.
    pub fn len(&self) -> usize {
        usize::from(self.state_len)
    }

    /// `true` when the vector holds no live bytes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The live bytes, excluding any padding.
    pub fn as_slice(&self) -> &[u8] {
        &self.state[..self.len()]
    }

    /// Length of the backing buffer, padding included.
    pub fn storage_len(&self) -> usize {
        self.state.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty input is valid, reports no elements, and still gets padded storage.
    #[test]
    fn test_empty_vec() {
        let u8_state_vec = U8StateVec::new(&[]).expect("empty vec should be valid");
        assert!(u8_state_vec.is_empty());
        assert_eq!(u8_state_vec.len(), 0);
        assert_eq!(u8_state_vec.get(0), None);
        assert!(u8_state_vec.storage_len() >= 1);
    }

    /// 255 bytes is the largest accepted input; 256 is rejected with the size error.
    #[test]
    fn test_max_size_limit() {
        let max_valid = vec![0u8; 255];
        assert!(U8StateVec::new(&max_valid).is_ok());
        let too_big = vec![0u8; 256];
        let result = U8StateVec::new(&too_big);
        assert!(matches!(
            result,
            Err(LabelModelError::BadLabelVecSize(256, 255))
        ));
    }

    /// When `len + 1` is already aligned, no padding is added.
    #[test]
    fn test_padding_exact_fit() {
        let data_len = OS_ALIGNED_STATE_LEN - 1;
        let data = vec![0u8; data_len];
        let vec = U8StateVec::new(&data).unwrap();
        assert_eq!(vec.storage_len(), data_len);
    }

    /// The live slice round-trips the input exactly.
    #[test]
    fn test_new_u8_state_valid() {
        let state = vec![0u8; OS_ALIGNED_STATE_LEN];
        let u8_state_vec = U8StateVec::new(&state).expect("test failed");
        assert_eq!(u8_state_vec.as_slice(), state.as_slice());
    }

    /// A short input is padded up to the alignment boundary but keeps its logical length.
    #[test]
    fn test_new_u8_state_aligned() {
        let state = vec![1, 2, 3];
        let u8_state_vec = U8StateVec::new(&state).expect("test failed");
        assert_eq!(u8_state_vec.storage_len(), OS_ALIGNED_STATE_LEN - 1);
        assert_eq!(u8_state_vec.len(), 3);
        assert_eq!(u8_state_vec.as_slice(), &[1, 2, 3]);
    }

    /// `get` returns live bytes and `None` past the logical length.
    #[test]
    fn test_get() {
        let state = vec![10, 20, 30];
        let u8_state_vec = U8StateVec::new(&state).expect("test failed");
        assert_eq!(u8_state_vec.get(0), Some(10));
        assert_eq!(u8_state_vec.get(1), Some(20));
        assert_eq!(u8_state_vec.get(2), Some(30));
        assert_eq!(u8_state_vec.get(3), None);
    }

    /// Equal contents compare equal and hash equally; different contents compare unequal.
    #[test]
    fn test_eq_and_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let state1 = vec![1, 2, 3];
        let state2 = vec![1, 2, 3];
        let state3 = vec![1, 2, 4];
        let vec1 = U8StateVec::new(&state1).unwrap();
        let vec2 = U8StateVec::new(&state2).unwrap();
        let vec3 = U8StateVec::new(&state3).unwrap();
        assert_eq!(vec1, vec2);
        assert_ne!(vec1, vec3);
        let mut h1 = DefaultHasher::new();
        vec1.hash(&mut h1);
        let mut h2 = DefaultHasher::new();
        vec2.hash(&mut h2);
        assert_eq!(h1.finish(), h2.finish());
    }

    /// Checks the padding invariant established by `new` for every accepted length:
    /// - `storage_len() + 1` is a multiple of `OS_ALIGNED_STATE_LEN`;
    /// - padding is shorter than one alignment unit, and every padding byte is zero;
    /// - the live prefix round-trips, and nothing past it is exposed through `get`.
    #[test]
    fn test_storage_alignment_invariant() {
        for len in 0..=usize::from(u8::MAX) {
            let data: Vec<u8> = (0..=u8::MAX).take(len).collect();
            let v = U8StateVec::new(&data).expect("lengths up to u8::MAX are valid");
            assert_eq!(v.len(), len);
            assert_eq!(v.is_empty(), len == 0);
            assert_eq!(v.as_slice(), data.as_slice());
            assert_eq!((v.storage_len() + 1) % OS_ALIGNED_STATE_LEN, 0);
            assert!(v.storage_len() >= len);
            assert!(v.storage_len() - len < OS_ALIGNED_STATE_LEN);
            assert!(v.state[len..].iter().all(|&b| b == 0));
            assert_eq!(v.get(len), None);
        }
    }

    /// `[1, 2]` and `[1, 2, 0]` end up with byte-identical padded buffers. Only the stored
    /// length tells them apart, so equality must take it into account.
    #[test]
    fn test_length_distinguishes_trailing_zero() {
        let short = U8StateVec::new(&[1, 2]).unwrap();
        let long = U8StateVec::new(&[1, 2, 0]).unwrap();
        assert_ne!(short, long);
        assert_eq!(short.get(2), None);
        assert_eq!(long.get(2), Some(0));
    }
}