use crate::Category;
use serde::{Deserialize, Serialize};
use std::convert::{From, TryFrom};
use std::hash::Hash;
use thiserror::Error;
/// A single observed value, tagged with the kind of data it holds.
///
/// With the serde attributes below, variants are (de)serialized under
/// snake_case names, e.g. `{"continuous": 1.2}`, `{"count": 277}` or the bare
/// string `"missing"`.
///
/// `PartialOrd` is derived, so two values are ordered by variant declaration
/// order first and by payload second.
///
/// NOTE(review): the hand-written `PartialEq` in this file treats two
/// `Continuous(NaN)` values as equal, whereas the derived `partial_cmp`
/// returns `None` for them — confirm no caller relies on the two agreeing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialOrd)]
#[serde(rename_all = "snake_case")]
pub enum Datum {
    /// A boolean value.
    Binary(bool),
    /// A 64-bit floating point value.
    Continuous(f64),
    /// A categorical value; see `Category` for the supported payloads.
    Categorical(Category),
    /// An unsigned 32-bit count.
    Count(u32),
    /// No value is present.
    Missing,
}
/// Reasons a `Datum` could not be converted into a primitive value.
///
/// NOTE(review): as wired up in this file, each `InvalidTypeRequestedFrom*`
/// variant is returned by the conversion *into* that variant's payload type
/// when the datum holds some other non-missing variant (for example,
/// `f64::try_from` on a `Count` yields `InvalidTypeRequestedFromContinuous`).
/// That reads oddly against the message text — confirm this is intended.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum DatumConversionError {
    /// Returned by `bool::try_from(Datum)` for a datum that is neither
    /// `Binary` nor `Missing`.
    #[error("tried to convert Binary into a type other than bool")]
    InvalidTypeRequestedFromBinary,
    /// Returned by `f64::try_from(Datum)` for a datum that is neither
    /// `Continuous` nor `Missing`.
    #[error("tried to convert Continuous into a type other than f64")]
    InvalidTypeRequestedFromContinuous,
    /// Returned by `u8::try_from(Datum)` for any datum other than a `U8` or
    /// `Bool` categorical or `Missing` (this includes string categoricals).
    #[error("tried to convert Categorical into non-categorical type")]
    InvalidTypeRequestedFromCategorical,
    /// Returned by `u32::try_from(Datum)` for a datum that is neither `Count`
    /// nor `Missing`.
    #[error("tried to convert Count into a type other than u32")]
    InvalidTypeRequestedFromCount,
    /// Returned by every `TryFrom<Datum>` conversion in this file when the
    /// datum is `Missing`.
    #[error("cannot convert Missing into a value of any type")]
    CannotConvertMissing,
}
/// Feeds `float` into `state` so that floats which `Datum` treats as equal
/// also hash identically.
///
/// `Datum`'s `PartialEq` considers every NaN equal to every other NaN and,
/// through `==`, considers `0.0` equal to `-0.0`. The `Hash` contract requires
/// equal values to produce equal hashes, so both families are collapsed onto a
/// single bit pattern before hashing. Every other value is hashed by its raw
/// IEEE-754 bits.
fn hash_float<H: std::hash::Hasher>(float: f64, state: &mut H) {
    let canonical: f64 = if float.is_nan() {
        // NaN has many bit patterns (sign and payload); always hash one of them.
        std::f64::NAN
    } else if float == 0.0 {
        // `-0.0 == 0.0` is true but the sign bit differs; always hash `+0.0`.
        0.0
    } else {
        float
    };
    canonical.to_bits().hash(state);
}
impl Hash for Datum {
    /// Hashes only the payload of the variant; the variant itself is not
    /// mixed into the hash.
    ///
    /// Floats go through `hash_float`, and `Missing` is hashed exactly like a
    /// `Continuous` NaN.
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        match *self {
            Self::Binary(flag) => flag.hash(state),
            Self::Continuous(value) => hash_float(value, state),
            Self::Categorical(ref category) => category.hash(state),
            Self::Count(count) => count.hash(state),
            Self::Missing => hash_float(std::f64::NAN, state),
        }
    }
}
// Expands to a `bool` expression: `true` when `$rhs` is the `Datum::$variant`
// variant and its payload equals `$lhs`, `false` for any other variant.
macro_rules! datum_peq {
    ($lhs: ident, $rhs: ident, $variant: ident) => {
        match $rhs {
            Datum::$variant(inner) => $lhs == inner,
            _ => false,
        }
    };
}
impl PartialEq for Datum {
    /// Two data are equal when they are the same variant with equal payloads.
    ///
    /// Unlike plain `f64` comparison, two `Continuous` NaN values compare
    /// equal. `Missing` equals only `Missing`.
    fn eq(&self, other: &Self) -> bool {
        match self {
            Self::Binary(lhs) => datum_peq!(lhs, other, Binary),
            Self::Categorical(lhs) => datum_peq!(lhs, other, Categorical),
            Self::Count(lhs) => datum_peq!(lhs, other, Count),
            Self::Continuous(lhs) => match other {
                // `==` is false for NaN, so the NaN/NaN case is accepted explicitly.
                Self::Continuous(rhs) => lhs == rhs || (lhs.is_nan() && rhs.is_nan()),
                _ => false,
            },
            Self::Missing => matches!(other, Self::Missing),
        }
    }
}
// Generates `impl TryFrom<Datum> for $target`.
//
// * `$target`   - the primitive type to convert into.
// * `$variant`  - path of the `Datum` variant whose payload is a `$target`.
// * `$mismatch` - error returned for any other non-missing variant.
//
// `Datum::Missing` always maps to `DatumConversionError::CannotConvertMissing`.
macro_rules! impl_try_from_datum {
    ($target: ty, $variant: path, $mismatch: expr) => {
        impl TryFrom<Datum> for $target {
            type Error = DatumConversionError;

            fn try_from(datum: Datum) -> Result<Self, Self::Error> {
                match datum {
                    $variant(value) => Ok(value),
                    Datum::Missing => Err(DatumConversionError::CannotConvertMissing),
                    _ => Err($mismatch),
                }
            }
        }
    };
}
impl TryFrom<Datum> for u8 {
type Error = DatumConversionError;
fn try_from(datum: Datum) -> Result<u8, Self::Error> {
match datum {
Datum::Categorical(Category::U8(x)) => Ok(x),
Datum::Categorical(Category::Bool(x)) => Ok(x as u8),
Datum::Missing => Err(DatumConversionError::CannotConvertMissing),
_ => Err(DatumConversionError::InvalidTypeRequestedFromCategorical),
}
}
}
// `bool::try_from(Datum)`: succeeds only for `Datum::Binary`; any other
// non-missing datum yields `InvalidTypeRequestedFromBinary`.
impl_try_from_datum!(
    bool,
    Datum::Binary,
    DatumConversionError::InvalidTypeRequestedFromBinary
);
// `f64::try_from(Datum)`: succeeds only for `Datum::Continuous`; any other
// non-missing datum yields `InvalidTypeRequestedFromContinuous`.
impl_try_from_datum!(
    f64,
    Datum::Continuous,
    DatumConversionError::InvalidTypeRequestedFromContinuous
);
// `u32::try_from(Datum)`: succeeds only for `Datum::Count`; any other
// non-missing datum yields `InvalidTypeRequestedFromCount`.
impl_try_from_datum!(
    u32,
    Datum::Count,
    DatumConversionError::InvalidTypeRequestedFromCount
);
impl Datum {
    /// Converts the datum to an `f64` where a numeric reading exists.
    ///
    /// Booleans (binary or categorical) map to `0.0`/`1.0`; `u8` categories
    /// and counts are widened losslessly. String categories and `Missing`
    /// return `None`.
    pub fn to_f64_opt(&self) -> Option<f64> {
        match self {
            Datum::Continuous(value) => Some(*value),
            Datum::Count(count) => Some(f64::from(*count)),
            Datum::Binary(flag) | Datum::Categorical(Category::Bool(flag)) => {
                Some(f64::from(u8::from(*flag)))
            }
            Datum::Categorical(Category::U8(code)) => Some(f64::from(*code)),
            Datum::Categorical(Category::String(_)) | Datum::Missing => None,
        }
    }

    /// Converts a categorical datum to a `u8`.
    ///
    /// Only `U8` and `Bool` categories produce a value (`Bool` as `0`/`1`);
    /// string categories and every non-categorical datum return `None`.
    pub fn to_u8_opt(&self) -> Option<u8> {
        if let Datum::Categorical(category) = self {
            match category {
                Category::U8(code) => Some(*code),
                Category::Bool(flag) => Some(u8::from(*flag)),
                Category::String(_) => None,
            }
        } else {
            None
        }
    }

    /// Returns `true` if this is the `Binary` variant.
    pub fn is_binary(&self) -> bool {
        matches!(self, Self::Binary(..))
    }

    /// Returns `true` if this is the `Continuous` variant.
    pub fn is_continuous(&self) -> bool {
        matches!(self, Self::Continuous(..))
    }

    /// Returns `true` if this is the `Categorical` variant.
    pub fn is_categorical(&self) -> bool {
        matches!(self, Self::Categorical(..))
    }

    /// Returns `true` if this is the `Count` variant.
    pub fn is_count(&self) -> bool {
        matches!(self, Self::Count(..))
    }

    /// Returns `true` if this is the `Missing` variant.
    pub fn is_missing(&self) -> bool {
        matches!(self, Self::Missing)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;

    /// Deserializes a `Datum` from JSON text, panicking if parsing fails.
    fn datum_from_json(json: &str) -> Datum {
        serde_json::from_str(json).unwrap()
    }

    // --- TryFrom / TryInto conversions -------------------------------------

    #[test]
    fn continuous_datum_try_into_f64() {
        let converted: Result<f64, DatumConversionError> = Datum::Continuous(1.1).try_into();
        let _ = converted.unwrap();
    }

    #[test]
    #[should_panic]
    fn continuous_datum_try_into_u8_panics() {
        let converted: Result<u8, DatumConversionError> = Datum::Continuous(1.1).try_into();
        let _ = converted.unwrap();
    }

    #[test]
    #[should_panic]
    fn missing_datum_try_into_u8_panics() {
        let converted: Result<u8, DatumConversionError> = Datum::Missing.try_into();
        let _ = converted.unwrap();
    }

    #[test]
    #[should_panic]
    fn missing_datum_try_into_f64_panics() {
        let converted: Result<f64, DatumConversionError> = Datum::Missing.try_into();
        let _ = converted.unwrap();
    }

    #[test]
    fn categorical_datum_try_into_u8() {
        let converted: Result<u8, DatumConversionError> =
            Datum::Categorical(Category::U8(7)).try_into();
        let _ = converted.unwrap();
    }

    #[test]
    #[should_panic]
    fn categorical_datum_try_into_f64_panics() {
        let converted: Result<f64, DatumConversionError> =
            Datum::Categorical(Category::U8(7)).try_into();
        let _ = converted.unwrap();
    }

    #[test]
    fn count_data_into_f64() {
        assert_eq!(Datum::Count(12).to_f64_opt(), Some(12.0));
    }

    #[test]
    fn count_data_try_into_u32() {
        let converted: Result<u32, DatumConversionError> = Datum::Count(12).try_into();
        let _ = converted.unwrap();
    }

    #[test]
    #[should_panic]
    fn count_data_try_into_u8_fails() {
        let converted: Result<u8, DatumConversionError> = Datum::Count(12).try_into();
        let _ = converted.unwrap();
    }

    // --- serde round-trips from JSON ---------------------------------------

    #[test]
    fn serde_continuous() {
        let parsed = datum_from_json(
            r#"
{
"continuous": 1.2
}"#,
        );
        assert_eq!(parsed, Datum::Continuous(1.2));
    }

    #[test]
    fn serde_categorical_u8() {
        let parsed = datum_from_json(
            r#"
{
"categorical": 2
}"#,
        );
        assert_eq!(parsed, Datum::Categorical(Category::U8(2)));
    }

    #[test]
    fn serde_categorical_bool() {
        let parsed = datum_from_json(
            r#"
{
"categorical": true
}"#,
        );
        assert_eq!(parsed, Datum::Categorical(Category::Bool(true)));
    }

    #[test]
    fn serde_categorical_string() {
        let parsed = datum_from_json(
            r#"
{
"categorical": "zoidberg"
}"#,
        );
        assert_eq!(parsed, Datum::Categorical("zoidberg".into()));
    }

    #[test]
    fn serde_count() {
        let parsed = datum_from_json(
            r#"
{
"count": 277
}"#,
        );
        assert_eq!(parsed, Datum::Count(277));
    }

    #[test]
    fn serde_missing() {
        let parsed = datum_from_json(r#""missing""#);
        assert_eq!(parsed, Datum::Missing);
    }
}