use std::fmt;
use std::hash::{Hash, Hasher};
use serde::{Deserialize, Serialize};
/// Error reported when a `NodeId`'s generation cannot be advanced because the
/// next value would exceed `NodeId::MAX_GENERATION`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GenerationOverflowError {
    /// Index of the id whose generation could not be incremented.
    pub index: u32,
    /// Generation the id held when the increment was attempted.
    pub generation: u64,
}

impl fmt::Display for GenerationOverflowError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { index, generation } = self;
        write!(
            f,
            "generation overflow at index {index}: generation {generation} would exceed MAX_GENERATION"
        )
    }
}

/// Marker impl so the error composes with `?`, `Box<dyn Error>`, and error-reporting crates.
impl std::error::Error for GenerationOverflowError {}
/// Identifier made of a slot `index` and a `generation` counter.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare `index`
/// first and then `generation`, and the serde derives emit the fields in this
/// declaration order, which fixes the serialized layout.
#[derive(Clone, Copy)]
#[derive(PartialOrd, Ord)]
#[derive(Serialize, Deserialize)]
pub struct NodeId {
    /// Slot index; `u32::MAX` is reserved as the invalid sentinel.
    index: u32,
    /// Generation counter; participates in equality and hashing alongside `index`.
    generation: u64,
}
impl NodeId {
    /// Sentinel id with index `u32::MAX` and generation 0.
    ///
    /// [`is_invalid`](Self::is_invalid) inspects only the index, so every id whose
    /// index is `u32::MAX` is treated as invalid, not just this exact value.
    pub const INVALID: NodeId = NodeId {
        index: u32::MAX,
        generation: 0,
    };

    /// Largest generation that
    /// [`try_increment_generation`](Self::try_increment_generation) will return.
    pub const MAX_GENERATION: u64 = u64::MAX / 2;

    /// Headroom threshold for [`is_near_overflow`](Self::is_near_overflow).
    ///
    /// An id counts as near overflow once fewer than this many increments remain
    /// before [`MAX_GENERATION`](Self::MAX_GENERATION).
    pub const NEAR_OVERFLOW_MARGIN: u64 = 1000;

    /// Creates an id from a slot index and a generation.
    ///
    /// No validation is performed. `index` may be `u32::MAX`, which yields an
    /// invalid id. `generation` may exceed [`MAX_GENERATION`](Self::MAX_GENERATION).
    #[inline]
    #[must_use]
    pub const fn new(index: u32, generation: u64) -> Self {
        Self { index, generation }
    }

    /// Returns the slot index.
    #[inline]
    #[must_use]
    pub const fn index(self) -> u32 {
        self.index
    }

    /// Returns the generation counter.
    #[inline]
    #[must_use]
    pub const fn generation(self) -> u64 {
        self.generation
    }

    /// Returns `true` if the index is the `u32::MAX` sentinel, regardless of generation.
    #[inline]
    #[must_use]
    pub const fn is_invalid(self) -> bool {
        self.index == u32::MAX
    }

    /// Returns `true` unless the index is the `u32::MAX` sentinel.
    ///
    /// This is always the negation of [`is_invalid`](Self::is_invalid).
    #[inline]
    #[must_use]
    pub const fn is_valid(self) -> bool {
        !self.is_invalid()
    }

    /// Computes the generation that follows this id's generation.
    ///
    /// `self` is taken by value and left unchanged. Only the new generation
    /// value is returned.
    ///
    /// # Errors
    ///
    /// Returns [`GenerationOverflowError`] when the current generation is already
    /// at or above [`MAX_GENERATION`](Self::MAX_GENERATION), i.e. when the next
    /// value would exceed it. The error carries this id's index and its current
    /// (not incremented) generation.
    pub const fn try_increment_generation(self) -> Result<u64, GenerationOverflowError> {
        // `MAX_GENERATION < u64::MAX`, so passing this guard also rules out `u64`
        // overflow. The addition below therefore cannot wrap, and no separate
        // `checked_add` is needed.
        if self.generation >= Self::MAX_GENERATION {
            return Err(GenerationOverflowError {
                index: self.index,
                generation: self.generation,
            });
        }
        Ok(self.generation + 1)
    }

    /// Returns `true` when the generation is close to exhausting its range.
    ///
    /// Specifically, it is `true` when fewer than
    /// [`NEAR_OVERFLOW_MARGIN`](Self::NEAR_OVERFLOW_MARGIN) increments remain
    /// before [`MAX_GENERATION`](Self::MAX_GENERATION), or when the generation is
    /// already beyond it.
    #[inline]
    #[must_use]
    pub const fn is_near_overflow(self) -> bool {
        self.generation > Self::MAX_GENERATION.saturating_sub(Self::NEAR_OVERFLOW_MARGIN)
    }
}
/// Two ids are equal only when both the slot index and the generation match.
impl PartialEq for NodeId {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        (self.index, self.generation) == (other.index, other.generation)
    }
}

/// Equality on the `(u32, u64)` field pair is a total equivalence relation.
impl Eq for NodeId {}
/// Feeds the slot index and then the generation into the hasher.
///
/// These are the same two fields `PartialEq` compares, so equal ids hash identically.
impl Hash for NodeId {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { index, generation } = self;
        (index, generation).hash(state);
    }
}
/// Renders as `NodeId(<index>:<generation>)`.
///
/// Any id whose index is the `u32::MAX` sentinel renders as `NodeId(INVALID)`,
/// and its generation is not shown.
impl fmt::Debug for NodeId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_invalid() {
            return f.write_str("NodeId(INVALID)");
        }
        let Self { index, generation } = self;
        write!(f, "NodeId({index}:{generation})")
    }
}
/// Renders as `<index>:<generation>`.
///
/// Any id whose index is the `u32::MAX` sentinel renders as `INVALID`.
impl fmt::Display for NodeId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if !self.is_invalid() {
            let Self { index, generation } = self;
            return write!(f, "{index}:{generation}");
        }
        f.write_str("INVALID")
    }
}
impl Default for NodeId {
#[inline]
fn default() -> Self {
Self::INVALID
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_node_id_creation() {
        let id = NodeId::new(42, 1);
        assert_eq!(id.index(), 42);
        assert_eq!(id.generation(), 1);
        assert!(!id.is_invalid());
        assert!(id.is_valid());
    }

    #[test]
    fn test_node_id_invalid_sentinel() {
        assert!(NodeId::INVALID.is_invalid());
        assert!(!NodeId::INVALID.is_valid());
        assert_eq!(NodeId::INVALID.index(), u32::MAX);
        assert_eq!(NodeId::INVALID.generation(), 0);
    }

    /// Validity is decided by the index alone: a `u32::MAX` index is invalid
    /// whatever the generation.
    #[test]
    fn test_node_id_invalid_index_any_generation() {
        let id = NodeId::new(u32::MAX, 3);
        assert!(id.is_invalid());
        assert!(!id.is_valid());
    }

    #[test]
    fn test_node_id_default() {
        let default_id: NodeId = NodeId::default();
        assert_eq!(default_id, NodeId::INVALID);
    }

    #[test]
    fn test_node_id_equality() {
        let id1 = NodeId::new(5, 10);
        let id2 = NodeId::new(5, 10);
        let id3 = NodeId::new(5, 11);
        let id4 = NodeId::new(6, 10);
        assert_eq!(id1, id2);
        // Same index, different generation.
        assert_ne!(id1, id3);
        // Same generation, different index.
        assert_ne!(id1, id4);
    }

    /// The derived `Ord` compares `index` first, then `generation`.
    #[test]
    fn test_node_id_ordering() {
        assert!(NodeId::new(1, 9) < NodeId::new(2, 0));
        assert!(NodeId::new(1, 1) < NodeId::new(1, 2));
        assert_eq!(NodeId::new(3, 3).cmp(&NodeId::new(3, 3)), std::cmp::Ordering::Equal);
    }

    #[test]
    fn test_node_id_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(NodeId::new(1, 1));
        set.insert(NodeId::new(1, 2));
        set.insert(NodeId::new(2, 1));
        assert!(set.contains(&NodeId::new(1, 1)));
        assert!(set.contains(&NodeId::new(1, 2)));
        assert!(set.contains(&NodeId::new(2, 1)));
        assert!(!set.contains(&NodeId::new(3, 1)));
        assert_eq!(set.len(), 3);
    }

    #[test]
    // Calling `clone()` on a `Copy` type is deliberate: this test checks that the
    // `Clone` impl agrees with a plain copy.
    #[allow(clippy::clone_on_copy)]
    fn test_node_id_copy_clone() {
        let id = NodeId::new(10, 20);
        let copied = id;
        let cloned = id.clone();
        assert_eq!(id, copied);
        assert_eq!(id, cloned);
    }

    #[test]
    fn test_generation_increment_success() {
        let id = NodeId::new(5, 1);
        let next_gen = id.try_increment_generation().unwrap();
        assert_eq!(next_gen, 2);
    }

    /// The last permitted increment lands exactly on `MAX_GENERATION`.
    #[test]
    fn test_generation_increment_just_below_max() {
        let id = NodeId::new(5, NodeId::MAX_GENERATION - 1);
        assert_eq!(id.try_increment_generation(), Ok(NodeId::MAX_GENERATION));
    }

    #[test]
    fn test_generation_increment_at_max() {
        let id = NodeId::new(5, NodeId::MAX_GENERATION);
        let result = id.try_increment_generation();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.index, 5);
        assert_eq!(err.generation, NodeId::MAX_GENERATION);
    }

    #[test]
    fn test_generation_overflow_at_u64_max() {
        let id = NodeId::new(5, u64::MAX);
        let err = id.try_increment_generation().unwrap_err();
        assert_eq!(err.index, 5);
        assert_eq!(err.generation, u64::MAX);
    }

    #[test]
    fn test_generation_overflow_error_display() {
        let err = GenerationOverflowError {
            index: 3,
            generation: 9,
        };
        assert_eq!(
            err.to_string(),
            "generation overflow at index 3: generation 9 would exceed MAX_GENERATION"
        );
    }

    #[test]
    fn test_is_near_overflow() {
        let safe_id = NodeId::new(5, 1000);
        assert!(!safe_id.is_near_overflow());
        let near_limit = NodeId::new(5, NodeId::MAX_GENERATION - 500);
        assert!(near_limit.is_near_overflow());
        let at_limit = NodeId::new(5, NodeId::MAX_GENERATION);
        assert!(at_limit.is_near_overflow());
    }

    /// The comparison is strict: exactly 1000 increments of headroom is not "near".
    #[test]
    fn test_is_near_overflow_threshold() {
        assert!(!NodeId::new(5, NodeId::MAX_GENERATION - 1000).is_near_overflow());
        assert!(NodeId::new(5, NodeId::MAX_GENERATION - 999).is_near_overflow());
    }

    #[test]
    fn test_debug_display_format() {
        let id = NodeId::new(42, 7);
        assert_eq!(format!("{id:?}"), "NodeId(42:7)");
        assert_eq!(format!("{id}"), "42:7");
        assert_eq!(format!("{:?}", NodeId::INVALID), "NodeId(INVALID)");
        assert_eq!(format!("{}", NodeId::INVALID), "INVALID");
    }

    #[test]
    fn test_serde_roundtrip() {
        let original = NodeId::new(123, 456);
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: NodeId = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
        let bytes = postcard::to_allocvec(&original).unwrap();
        let deserialized: NodeId = postcard::from_bytes(&bytes).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_size_of_node_id() {
        assert!(std::mem::size_of::<NodeId>() <= 16);
        assert!(std::mem::size_of::<NodeId>() >= 12);
    }
}