use serde::{Deserialize, Serialize};
use std::{
error::Error,
fmt::{Display, Formatter, Result as FmtResult},
num::NonZero,
};
/// Error returned when a [`ShardId`] cannot be built from a `[number, total]` pair.
///
/// Produced by the `TryFrom<[u32; 2]>` conversion for [`ShardId`] (and therefore by
/// deserialization) when the pair does not satisfy `number < total`.
#[derive(Debug)]
pub struct ShardIdParseError {
    kind: ShardIdParseErrorType,
}

impl ShardIdParseError {
    /// Immutable reference to the type of error that occurred.
    #[must_use = "retrieving the type has no effect if left unused"]
    pub const fn kind(&self) -> &ShardIdParseErrorType {
        &self.kind
    }

    /// Consume the error, returning the source error if there is any.
    ///
    /// This error never wraps another error, so this always returns `None`.
    // `self` is taken by value to mirror `into_parts`, even though no data is needed.
    #[allow(clippy::unused_self)]
    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
        None
    }

    /// Consume the error, returning the owned error type and the source error.
    ///
    /// The source is always `None` for this error.
    #[must_use = "consuming the error into its parts has no effect if left unused"]
    pub fn into_parts(self) -> (ShardIdParseErrorType, Option<Box<dyn Error + Send + Sync>>) {
        (self.kind, None)
    }
}

impl Display for ShardIdParseError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self.kind {
            ShardIdParseErrorType::NumberGreaterOrEqualTotal { number, total } => {
                f.write_str("ShardId's number (")?;
                Display::fmt(&number, f)?;
                f.write_str(") was greater or equal to its total (")?;
                Display::fmt(&total, f)?;
                f.write_str(")")
            }
        }
    }
}

// No underlying cause is stored, so the default `source()` (returning `None`) is correct.
impl Error for ShardIdParseError {}

/// Type of [`ShardIdParseError`] that occurred.
#[derive(Debug)]
pub enum ShardIdParseErrorType {
    /// The shard's number was greater than or equal to its total.
    ///
    /// This also covers a total of zero, since no number is less than zero.
    NumberGreaterOrEqualTotal {
        /// Value of the rejected shard number.
        number: u32,
        /// Value of the rejected shard total.
        total: u32,
    },
}
/// Identifier of a single shard: its zero-based `number` within a group of `total` shards.
///
/// Serialized as, and deserialized from, a two-element `[number, total]` array. Deserialization
/// goes through the `TryFrom<[u32; 2]>` conversion, so pairs with `number >= total` are rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(try_from = "[u32; 2]", into = "[u32; 2]")]
pub struct ShardId {
    /// Zero-based index of this shard; kept strictly below `total` by every constructor.
    number: u32,
    /// Total number of shards; `NonZero` encodes that it can never be zero.
    total: NonZero<u32>,
}
impl ShardId {
    /// Shard `0` of a single-shard (`total == 1`) configuration.
    pub const ONE: ShardId = ShardId::new(0, 1);

    /// Create a new shard identifier.
    ///
    /// Use [`new_checked`](Self::new_checked) for a non-panicking alternative.
    ///
    /// # Panics
    ///
    /// Panics with `"number must be less than total"` if `number` is greater than
    /// or equal to `total`. This includes `total == 0`.
    pub const fn new(number: u32, total: u32) -> Self {
        // Delegate so the validation rule lives in exactly one place.
        match Self::new_checked(number, total) {
            Some(id) => id,
            None => panic!("number must be less than total"),
        }
    }

    /// Create a new shard identifier if `number` is less than `total`.
    ///
    /// Returns `None` otherwise, including when `total` is zero.
    pub const fn new_checked(number: u32, total: u32) -> Option<Self> {
        // A zero total yields `None` from `NonZero::new`, and `number < total` rejects the
        // rest, so there is no unreachable panic branch to maintain.
        match NonZero::new(total) {
            Some(total) if number < total.get() => Some(Self { number, total }),
            _ => None,
        }
    }

    /// Zero-based number of this shard; always less than [`total`](Self::total).
    pub const fn number(self) -> u32 {
        self.number
    }

    /// Total number of shards; always at least 1.
    pub const fn total(self) -> u32 {
        self.total.get()
    }
}
/// Renders the identifier as a debug-style list of `[number, total]`.
impl Display for ShardId {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let parts: [u32; 2] = (*self).into();
        let mut list = f.debug_list();
        for part in &parts {
            list.entry(part);
        }
        list.finish()
    }
}
impl TryFrom<[u32; 2]> for ShardId {
    type Error = ShardIdParseError;

    /// Builds a [`ShardId`] from a `[number, total]` pair.
    ///
    /// Fails with [`ShardIdParseErrorType::NumberGreaterOrEqualTotal`] unless
    /// `number < total`.
    fn try_from(pair: [u32; 2]) -> Result<Self, Self::Error> {
        let [number, total] = pair;
        match Self::new_checked(number, total) {
            Some(id) => Ok(id),
            None => Err(ShardIdParseError {
                kind: ShardIdParseErrorType::NumberGreaterOrEqualTotal { number, total },
            }),
        }
    }
}
impl From<ShardId> for [u32; 2] {
fn from(id: ShardId) -> Self {
[id.number(), id.total()]
}
}
#[cfg(test)]
mod tests {
    use super::{ShardId, ShardIdParseErrorType};
    use serde::{Serialize, de::DeserializeOwned};
    use serde_test::Token;
    use static_assertions::{assert_impl_all, const_assert_eq};
    use std::{fmt::Debug, hash::Hash};

    const_assert_eq!(ShardId::ONE.number(), 0);
    const_assert_eq!(ShardId::ONE.total(), 1);
    assert_impl_all!(
        ShardId: Clone,
        Copy,
        Debug,
        DeserializeOwned,
        Eq,
        Hash,
        PartialEq,
        Send,
        Serialize,
        Sync
    );

    #[test]
    const fn checked_invalid() {
        assert!(ShardId::new_checked(0, 1).is_some());
        assert!(ShardId::new_checked(1, 1).is_none());
        assert!(ShardId::new_checked(2, 1).is_none());
        assert!(ShardId::new_checked(0, 0).is_none());
    }

    #[test]
    const fn getters() {
        let id = ShardId::new(2, 4);
        assert!(id.number() == 2);
        assert!(id.total() == 4);
    }

    #[test]
    fn display() {
        assert_eq!(ShardId::new(0, 1).to_string(), "[0, 1]");
        assert_eq!(ShardId::new(2, 4).to_string(), "[2, 4]");
    }

    #[test]
    fn array_round_trip() {
        let id = ShardId::new(3, 5);
        let parts: [u32; 2] = id.into();
        assert_eq!(parts, [3, 5]);
        assert_eq!(ShardId::try_from(parts).ok(), Some(id));
    }

    #[test]
    fn try_from_invalid() {
        let error = ShardId::try_from([4, 4]).unwrap_err();
        assert!(matches!(
            error.kind(),
            ShardIdParseErrorType::NumberGreaterOrEqualTotal { number: 4, total: 4 }
        ));
        assert_eq!(
            error.to_string(),
            "ShardId's number (4) was greater or equal to its total (4)"
        );
    }

    #[test]
    fn serde() {
        let value = ShardId::new(0, 1);
        serde_test::assert_tokens(
            &value,
            &[
                Token::Tuple { len: 2 },
                Token::U32(0),
                Token::U32(1),
                Token::TupleEnd,
            ],
        );
    }

    /// Deserialization must go through `TryFrom`, so invalid pairs are rejected.
    #[test]
    fn serde_invalid() {
        serde_test::assert_de_tokens_error::<ShardId>(
            &[
                Token::Tuple { len: 2 },
                Token::U32(1),
                Token::U32(1),
                Token::TupleEnd,
            ],
            "ShardId's number (1) was greater or equal to its total (1)",
        );
        serde_test::assert_de_tokens_error::<ShardId>(
            &[
                Token::Tuple { len: 2 },
                Token::U32(0),
                Token::U32(0),
                Token::TupleEnd,
            ],
            "ShardId's number (0) was greater or equal to its total (0)",
        );
    }

    #[should_panic(expected = "number must be less than total")]
    #[test]
    const fn number_equal() {
        ShardId::new(1, 1);
    }

    #[should_panic(expected = "number must be less than total")]
    #[test]
    const fn number_greater() {
        ShardId::new(2, 1);
    }

    #[should_panic(expected = "number must be less than total")]
    #[test]
    const fn total_zero() {
        ShardId::new(0, 0);
    }
}