use std::fmt::Display;
use allocative::Allocative;
use serde::Serialize;
use crate::model::{label::label_model_error::LabelModelError, network::VertexId};
#[cfg(target_pointer_width = "32")]
pub const OS_ALIGNED_STATE_LEN: usize = 4;
#[cfg(target_pointer_width = "64")]
pub const OS_ALIGNED_STATE_LEN: usize = 8;
#[derive(PartialEq, Eq, Hash, Debug, Clone, Serialize, Allocative)]
pub enum Label {
Vertex(VertexId),
VertexWithIntState {
vertex_id: VertexId,
state: usize,
},
VertexWithIntStateVec {
vertex_id: VertexId,
state: Vec<usize>,
},
VertexWithU8StateVec {
vertex_id: VertexId,
state: Vec<u8>,
state_len: u8,
},
}
impl Label {
pub fn new_u8_state(vertex_id: VertexId, state: &[u8]) -> Result<Self, LabelModelError> {
let mut label_state = state.to_vec();
let state_len: u8 = state
.len()
.try_into()
.map_err(|_| LabelModelError::BadLabelVecSize(state.len(), u8::MAX as usize))?;
let total_data_len = label_state.len() + 1;
let remainder = total_data_len % OS_ALIGNED_STATE_LEN;
if remainder != 0 {
let padding_needed = OS_ALIGNED_STATE_LEN - remainder;
label_state.extend(vec![0u8; padding_needed]);
}
Ok(Label::VertexWithU8StateVec {
vertex_id,
state_len,
state: label_state,
})
}
pub fn get_u8_state(&self) -> Option<&[u8]> {
match self {
Label::VertexWithU8StateVec {
state, state_len, ..
} => {
let len: usize = (*state_len).into();
Some(&state[0..len])
}
_ => None,
}
}
pub fn vertex_id(&self) -> &VertexId {
match self {
Label::Vertex(vertex_id) => vertex_id,
Label::VertexWithIntState { vertex_id, .. } => vertex_id,
Label::VertexWithIntStateVec { vertex_id, .. } => vertex_id,
Label::VertexWithU8StateVec { vertex_id, .. } => vertex_id,
}
}
}
impl Display for Label {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Label::Vertex(vertex_id) => write!(f, "Vertex({vertex_id})"),
Label::VertexWithIntState { vertex_id, state } => {
write!(f, "VertexWithIntState({vertex_id}, {state})")
}
Label::VertexWithIntStateVec { vertex_id, state } => {
write!(f, "VertexWithIntStateVec({vertex_id}, {state:?})")
}
Label::VertexWithU8StateVec {
vertex_id,
state_len,
state,
} => {
write!(
f,
"VertexWithU8StateVec({vertex_id}, {state_len}, {state:?})"
)
}
}
}
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use super::*;
#[test]
fn test_e2e_display_trip() {
let modes = ["walk", "bike", "drive", "tnc", "transit"];
let trip_sequence = [0, 2, 4, 0, 4, 3];
let vertex_id = VertexId(1234);
let label = Label::new_u8_state(vertex_id, &trip_sequence).unwrap();
println!("label storage: {}", label);
let out = label.get_u8_state().unwrap();
let trip_modes = out
.iter()
.map(|idx| modes[*idx as usize].to_string())
.join(",");
println!("[{}]", trip_modes);
assert_eq!(
trip_modes,
"walk,drive,transit,walk,transit,tnc".to_string()
);
}
#[test]
fn test_new_u8_state_valid() {
let vertex_id = VertexId(42);
let state = vec![0u8; OS_ALIGNED_STATE_LEN];
let label = Label::new_u8_state(vertex_id, &state).expect("test failed");
assert_eq!(label.vertex_id(), &vertex_id);
assert_eq!(label.get_u8_state(), Some(state.as_slice()));
}
#[test]
fn test_new_u8_state_aligned() {
let vertex_id = VertexId(42);
let state = vec![1, 2, 3];
let label = Label::new_u8_state(vertex_id, &state).expect("test failed");
match label {
Label::VertexWithU8StateVec {
state: inner_state, ..
} => {
assert_eq!(inner_state.len(), super::OS_ALIGNED_STATE_LEN - 1);
}
_ => panic!("wrong label variant!"),
};
}
#[test]
fn test_u8_state_none_for_other_variants() {
let vertex_id = VertexId(42);
let vertex_label = Label::Vertex(vertex_id);
assert_eq!(vertex_label.get_u8_state(), None);
let int_state_label = Label::VertexWithIntState {
vertex_id,
state: 123,
};
assert_eq!(int_state_label.get_u8_state(), None);
let valid_state = vec![1, 2, 3];
let u8_vec_label = Label::new_u8_state(vertex_id, &valid_state).expect("test failed");
let result = u8_vec_label.get_u8_state().unwrap();
let expected = [1, 2, 3];
assert_eq!(result, &expected);
}
#[test]
fn test_u8_state_label_display() {
let vertex_id = VertexId(42);
let state = vec![1u8; OS_ALIGNED_STATE_LEN];
let label = Label::new_u8_state(vertex_id, &state).expect("test failed");
let display_string = format!("{}", label);
assert!(display_string.contains("VertexWithU8StateVec"));
assert!(display_string.contains("42"));
}
}