use crate::unary::*;
use std::convert::{TryFrom, TryInto};
use thiserror::Error;
/// An index selecting one of `N` alternatives, kept strictly below `N`.
///
/// The index is stored in a single `u8`, so at most 256 distinct values are representable
/// regardless of how large `N` is. The field is private; outside this module, values are
/// obtained through checked conversions such as `TryFrom<u8>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct Choice<const N: usize> {
    /// Zero-based index of the selected alternative.
    choice: u8,
}
/// The default `Choice` is the first alternative, index `0`.
///
/// The `S<M>` bound on `Number<N>`'s unary form presumably restricts this impl to `N >= 1`
/// (TODO confirm against `crate::unary`). That restriction is what makes the `expect`
/// below unreachable.
impl<M: Unary, const N: usize> Default for Choice<N>
where
    Number<N>: ToUnary<AsUnary = S<M>>,
{
    fn default() -> Self {
        let first_index: u8 = 0;
        first_index
            .try_into()
            .expect("0 is in bounds for all non-zero-bounded `Choice`s")
    }
}
impl<const N: usize> PartialEq<u8> for Choice<N> {
fn eq(&self, other: &u8) -> bool {
self.choice == *other
}
}
/// Error produced when a raw `u8` cannot become a `Choice<N>` because it is not below `N`.
///
/// Its human-readable message comes from a hand-written `Display` impl rather than a
/// thiserror `#[error]` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Error)]
pub struct OutOfBoundsChoiceError {
    /// The index that was rejected.
    choice: u8,
    /// The exclusive upper bound (`N`) the index was checked against.
    bound: usize,
}
/// Renders the rejected index together with the exclusive bound it failed to satisfy.
impl std::fmt::Display for OutOfBoundsChoiceError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Self { choice, bound } = self;
        write!(
            f,
            "choice {} is invalid for exclusive upper bound {}",
            choice, bound
        )
    }
}
/// Checked conversion from a raw index. It succeeds only when `choice < N`.
///
/// # Errors
///
/// Returns an `OutOfBoundsChoiceError` recording the rejected value and `N` whenever
/// `choice >= N`. When `N == 0`, every input fails.
impl<const N: usize> TryFrom<u8> for Choice<N> {
    type Error = OutOfBoundsChoiceError;

    fn try_from(choice: u8) -> Result<Self, Self::Error> {
        if usize::from(choice) >= N {
            return Err(OutOfBoundsChoiceError { choice, bound: N });
        }
        Ok(Choice { choice })
    }
}
impl<const N: usize> From<Choice<N>> for u8 {
fn from(Choice { choice, .. }: Choice<N>) -> u8 {
choice
}
}
/// `serde` support for `Choice`.
///
/// A `Choice<N>` is written as its raw `u8` index. It is read back through the same bounds
/// check as `TryFrom<u8>`, so out-of-range input is rejected at deserialization time.
#[cfg(feature = "serde")]
mod serialization {
    use super::*;
    use serde::{
        de::{self, Visitor},
        Deserialize, Deserializer, Serialize, Serializer,
    };

    impl<const N: usize> Serialize for Choice<N> {
        /// Emits the underlying index as a `u8`.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_u8(self.choice)
        }
    }

    /// Stateless visitor that turns an integer into a `Choice<N>`, rejecting values outside
    /// `0..min(N, 256)`.
    #[derive(Debug, Clone, Copy)]
    struct ChoiceVisitor<const N: usize>;

    impl<const N: usize> Default for ChoiceVisitor<N> {
        fn default() -> Self {
            ChoiceVisitor
        }
    }

    impl<'de, const N: usize> Visitor<'de> for ChoiceVisitor<N> {
        type Value = Choice<N>;

        fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            // The index lives in a `u8`, so for N > 256 the accepted values are 0..=255.
            // The effective *exclusive* bound is therefore 256. Using `u8::MAX` (255) here
            // would wrongly claim that 255 is rejected.
            let effective_bound = N.min(usize::from(u8::MAX) + 1);
            write!(
                f,
                "a non-negative integer strictly less than {}",
                effective_bound
            )?;
            if N == 0 {
                f.write_str(" (since that strict upper bound is 0, this is impossible)")?;
            }
            Ok(())
        }

        fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Both failure modes report the value exactly as it was received: too large for
            // a `u8`, or representable but `>= N`.
            let narrowed: Option<u8> = v.try_into().ok();
            narrowed
                .and_then(|choice| Choice::<N>::try_from(choice).ok())
                .ok_or_else(|| E::invalid_value(de::Unexpected::Unsigned(v), &self))
        }

        fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Some self-describing formats (TOML, for instance) report every integer as
            // signed. serde's default `visit_i64` would reject such input as the wrong type,
            // so non-negative values are routed through the unsigned path instead.
            match u64::try_from(v) {
                Ok(unsigned) => self.visit_u64(unsigned),
                Err(_) => Err(E::invalid_value(de::Unexpected::Signed(v), &self)),
            }
        }
    }

    impl<'de, const N: usize> Deserialize<'de> for Choice<N> {
        /// Requests a `u8` from the deserializer and validates it against `N`.
        fn deserialize<D>(deserializer: D) -> Result<Choice<N>, D::Error>
        where
            D: Deserializer<'de>,
        {
            let visitor: ChoiceVisitor<N> = ChoiceVisitor::default();
            deserializer.deserialize_u8(visitor)
        }
    }
}