use dtype_variant::DType;
use paste::paste;
use serde::{Deserialize, Serialize};
use serde_bare::{Int, Uint};
use std::{collections::BTreeMap, convert::identity, fmt::Write};
/// Ordered sequence of [`Variant`] values; the payload of [`Variant::List`].
pub type VariantList = Vec<Variant>;
/// String-keyed collection of [`Variant`] values; the payload of [`Variant::Map`].
///
/// Being a `BTreeMap`, it iterates in key order, so `Display` output lists
/// entries sorted by key regardless of insertion order.
pub type VariantMap = BTreeMap<String, Variant>;
/// Dynamically-typed, serializable value.
///
/// NOTE: the serialized tag of each variant is its declaration index
/// (`Empty` = 0 ... `Map` = 7, as pinned by the `ser_de_*` tests), so
/// reordering or inserting variants changes the wire format.
///
/// `DType` (dtype_variant) generates the `<Name>Variant` marker types used
/// with `downcast`/`downcast_ref`/`downcast_mut` by the accessors in this
/// file; presumably it also provides the per-payload `From` impls the tests
/// rely on (e.g. `Variant::from(true)`) — confirm against dtype_variant.
#[derive(
DType,
Default,
Serialize,
Deserialize,
PartialEq,
Clone,
derive_more::Debug,
derive_more::Display,
strum_macros::EnumIs,
)]
pub enum Variant {
/// No value. This is the `Default`; it displays as `()`.
#[display("()")]
#[default]
Empty,
/// Boolean flag; displays as `true` / `false`.
Boolean(bool),
/// Signed 64-bit integer wrapped in `serde_bare::Int`, so it serializes as a
/// variable-length BARE integer rather than a fixed 8 bytes. `Display` and
/// `Debug` print the bare number (`Signed(42)`, not `Signed(Int(42))`).
#[display("{0}", _0.0)]
#[debug("Signed({0})", _0.0)]
Signed(Int),
/// Unsigned 64-bit integer wrapped in `serde_bare::Uint` (variable-length on
/// the wire). `Display` and `Debug` print the bare number.
#[display("{0}", _0.0)]
#[debug("Unsigned({0})", _0.0)]
Unsigned(Uint),
/// Text. At top level it displays as the raw, unquoted text.
String(String),
/// Raw bytes; displays as `0x` followed by `hex::encode` of the bytes,
/// e.g. `0x0001020304`.
#[display("0x{}", hex::encode(_0))]
Bytes(Vec<u8>),
/// Ordered list; displays as `[a, b, ...]` via `VariantListWrap`.
#[display("{}", VariantListWrap(_0))]
List(VariantList),
/// Key-sorted map; displays as `{"k": v, ...}` via `VariantMapWrap`.
#[display("{}", VariantMapWrap(_0))]
Map(VariantMap),
}
/// Borrowing `Display` adapter that renders a [`VariantList`] as `[a, b, ...]`.
struct VariantListWrap<'a>(&'a VariantList);
/// Writes a single element of a list or map to `f`.
///
/// Strings are quoted so they can be told apart from other scalars. Their
/// contents are escaped with `str`'s `Debug` formatting, so embedded quotes,
/// backslashes or control characters cannot make the output ambiguous.
/// Without escaping, a one-element list holding the string `a", "b` would
/// render exactly like a two-element list. Every other variant is written
/// with its own `Display` impl, which recurses for nested containers.
fn fmt_container(f: &mut std::fmt::Formatter<'_>, it: &Variant) -> std::fmt::Result {
    match it {
        Variant::String(s) => write!(f, "{s:?}"),
        _ => write!(f, "{it}"),
    }
}
impl std::fmt::Display for VariantListWrap<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_char('[')?;
let mut first = true;
for it in self.0 {
if !first {
f.write_str(", ")?;
}
fmt_container(f, it)?;
first = false;
}
f.write_char(']')
}
}
/// Borrowing `Display` adapter that renders a [`VariantMap`] as
/// `{"key": value, ...}`.
struct VariantMapWrap<'a>(&'a VariantMap);
impl std::fmt::Display for VariantMapWrap<'_> {
    /// Entries are written in the map's (key-sorted) iteration order.
    ///
    /// Keys are always quoted and escaped with `str`'s `Debug` formatting, so
    /// a key containing `"` cannot be confused with the surrounding
    /// punctuation. Values are written through `fmt_container`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_char('{')?;
        for (index, (key, value)) in self.0.iter().enumerate() {
            if index > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{key:?}: ")?;
            fmt_container(f, value)?;
        }
        f.write_char('}')
    }
}
impl Variant {
    /// Builds a [`Variant::Signed`] from any type losslessly convertible to `i64`.
    #[must_use]
    pub fn signed<T: Into<i64>>(i: T) -> Self {
        Variant::Signed(Int(i.into()))
    }

    /// Builds a [`Variant::Unsigned`] from any type losslessly convertible to `u64`.
    #[must_use]
    pub fn unsigned<T: Into<u64>>(u: T) -> Self {
        Variant::Unsigned(Uint(u.into()))
    }

    /// Builds a [`Variant::Signed`] from any primitive using an `as`-style cast.
    ///
    /// Unlike [`Variant::signed`], this accepts types that do not fit
    /// losslessly (e.g. `u64`). The value is converted with
    /// `AsPrimitive<i64>`, so it may wrap, truncate or saturate exactly as
    /// `value as i64` would.
    #[must_use]
    pub fn signed_coerce<T>(i: T) -> Self
    where
        T: num_traits::cast::AsPrimitive<i64>,
    {
        Variant::Signed(Int(i.as_()))
    }

    /// Builds a [`Variant::Unsigned`] from any primitive using an `as`-style cast.
    ///
    /// Unlike [`Variant::unsigned`], this accepts signed or wider types. The
    /// value is converted with `AsPrimitive<u64>`, so negative inputs wrap
    /// exactly as `value as u64` would.
    #[must_use]
    pub fn unsigned_coerce<T>(u: T) -> Self
    where
        T: num_traits::cast::AsPrimitive<u64>,
    {
        Variant::Unsigned(Uint(u.as_()))
    }
}
/// Implements `From<$ty> for Variant` for every listed integer type. Each
/// value is widened with `Into` into the payload of `$wrapper` and stored in
/// variant `$variant`.
macro_rules! from_types {
    ($variant:ident, $wrapper:ident, $($ty:ty),+) => {
        $(
            impl From<$ty> for Variant {
                fn from(value: $ty) -> Self {
                    let widened = value.into();
                    Self::$variant($wrapper(widened))
                }
            }
        )+
    };
}
// Unsigned primitives widen into `Uint(u64)`.
from_types!(Unsigned, Uint, u64, u32, u16, u8);
// Signed primitives widen into `Int(i64)`.
from_types!(Signed, Int, i64, i32, i16, i8);
impl From<()> for Variant {
fn from((): ()) -> Self {
Variant::Empty
}
}
impl From<&str> for Variant {
fn from(value: &str) -> Self {
Variant::String(value.to_string())
}
}
impl<const N: usize> From<&[u8; N]> for Variant {
fn from(value: &[u8; N]) -> Self {
Variant::Bytes(value.to_vec())
}
}
impl From<&[u8]> for Variant {
fn from(value: &[u8]) -> Self {
Variant::Bytes(value.to_vec())
}
}
impl<const N: usize> From<&[Variant; N]> for Variant {
fn from(value: &[Variant; N]) -> Self {
Variant::List(value.to_vec())
}
}
/// Generates the typed accessors for one `Variant` payload kind.
///
/// For accessor stem `$name` this emits `into_*`, `try_into_*`, `as_*_ref`
/// and `as_*_mut` on [`Variant`]. The three projection expressions map the
/// stored payload onto `$inner`: `$owned` by value, `$by_ref` by shared
/// reference, and `$by_mut` by mutable reference.
macro_rules! as_variant_fn {
    ($name:ident, $variant:ident, $inner:ty, $owned:expr, $by_ref:expr, $by_mut:expr) => {
        paste! {
            impl Variant {
                /// Consumes the value, returning its payload if it holds this kind.
                #[must_use]
                pub fn [<into_ $name>](self) -> Option<$inner> {
                    let payload = self.downcast::<[<$variant Variant>]>();
                    payload.map($owned)
                }
                /// Consumes the value. On a kind mismatch, the untouched value is
                /// handed back inside the error.
                pub fn [<try_into_ $name>](self) -> Result<$inner, VariantConversionFailed> {
                    match self {
                        Variant::$variant(payload) => Ok($owned(payload)),
                        mismatched => Err(VariantConversionFailed(mismatched)),
                    }
                }
                /// Borrows the payload if the value holds this kind.
                #[must_use]
                pub fn [<as_ $name _ref>](&self) -> Option<&$inner> {
                    let payload = self.downcast_ref::<[<$variant Variant>]>();
                    payload.map($by_ref)
                }
                /// Mutably borrows the payload if the value holds this kind.
                #[must_use]
                pub fn [<as_ $name _mut>](&mut self) -> Option<&mut $inner> {
                    let payload = self.downcast_mut::<[<$variant Variant>]>();
                    payload.map($by_mut)
                }
            }
        }
    };
}
/// Shorthand for `as_variant_fn!` when the stored payload is exposed as-is:
/// all three projections are `identity`.
macro_rules! as_variant_id {
    ($name:ident, $variant:ident, $inner:ty) => {
        as_variant_fn!($name, $variant, $inner, identity, identity, identity);
    };
}
// Payloads that are exposed exactly as stored.
as_variant_id!(bool, Boolean, bool);
as_variant_id!(bytes, Bytes, Vec<u8>);
as_variant_id!(string, String, String);
as_variant_id!(list, List, VariantList);
as_variant_id!(map, Map, VariantMap);
// Integer payloads are unwrapped from their serde_bare newtypes.
as_variant_fn!(
    unsigned,
    Unsigned,
    u64,
    |n: Uint| n.0,
    |n| &n.0,
    |n| &mut n.0
);
as_variant_fn!(
    signed,
    Signed,
    i64,
    |n: Int| n.0,
    |n| &n.0,
    |n| &mut n.0
);
impl Variant {
    /// Borrows the text of a [`Variant::String`] as `&str`.
    #[must_use]
    pub fn as_str(&self) -> Option<&str> {
        self.as_string_ref().map(String::as_str)
    }

    /// Mutably borrows the text of a [`Variant::String`] as `&mut str`.
    #[must_use]
    pub fn as_str_mut(&mut self) -> Option<&mut str> {
        self.as_string_mut().map(String::as_mut_str)
    }

    /// Borrows the payload of a [`Variant::Bytes`] as a byte slice.
    #[must_use]
    pub fn as_slice_bytes(&self) -> Option<&[u8]> {
        self.as_bytes_ref().map(Vec::as_slice)
    }

    /// Mutably borrows the payload of a [`Variant::Bytes`] as a byte slice.
    /// The length cannot change through this view.
    #[must_use]
    pub fn as_slice_bytes_mut(&mut self) -> Option<&mut [u8]> {
        self.as_bytes_mut().map(Vec::as_mut_slice)
    }

    /// Borrows the elements of a [`Variant::List`] as a slice.
    #[must_use]
    pub fn as_slice_variant(&self) -> Option<&[Variant]> {
        self.as_list_ref().map(Vec::as_slice)
    }

    /// Mutably borrows the elements of a [`Variant::List`] as a slice.
    #[must_use]
    pub fn as_slice_variant_mut(&mut self) -> Option<&mut [Variant]> {
        self.as_list_mut().map(Vec::as_mut_slice)
    }
}
/// Error returned by the generated `try_into_*` accessors when the value does
/// not hold the requested kind.
///
/// The original value is handed back unchanged in field `0`, so the caller can
/// recover it. It displays as `VariantConversionFailed(<value>)`.
#[derive(thiserror::Error, Debug, derive_more::Display)]
#[display("VariantConversionFailed({_0})")]
pub struct VariantConversionFailed(pub Variant);
impl Variant {
#[must_use]
pub fn coerce_bool(&self) -> bool {
match self {
Variant::Boolean(b) => *b,
Variant::Unsigned(Uint(u)) => *u != 0,
Variant::Signed(Int(i)) => *i != 0,
_ => false,
}
}
#[must_use]
pub fn coerce_signed(&self) -> i64 {
use num_traits::AsPrimitive as _;
match self {
Variant::Signed(Int(i)) => *i,
Variant::Unsigned(Uint(u)) => (*u).as_(),
Variant::Boolean(b) => i64::from(*b),
_ => 0,
}
}
#[must_use]
pub fn coerce_unsigned(&self) -> u64 {
use num_traits::AsPrimitive as _;
match self {
Variant::Signed(Int(i)) => (*i).as_(),
Variant::Unsigned(Uint(u)) => *u,
Variant::Boolean(b) => u64::from(*b),
_ => 0,
}
}
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod test {
    use assertables::assert_matches;
    use pretty_assertions::assert_eq;
    use serde_bare::{Int, Uint};

    use crate::protocol::{VariantConversionFailed, common::ProtocolMessage};

    use super::{Variant, VariantMap};

    /// Two-entry map (`"foo" => "bar"`, `"baz" => 42`) shared by the display,
    /// downcast and serialization tests.
    fn sample_map() -> VariantMap {
        let mut map = VariantMap::new();
        let _ = map.insert("foo".into(), "bar".into());
        let _ = map.insert("baz".into(), Variant::signed(42));
        map
    }

    /// `From` construction plus the `Display` rendering of every variant kind.
    #[test]
    fn creation_and_stringify() {
        // Builds a `Variant` from `$val`, echoes it to stderr, and (in the
        // two-argument form) checks its `Display` output against `$expect`.
        macro_rules! test_var {
            ($val:expr) => {
                let v = Variant::from($val);
                eprintln!("{v}");
            };
            ($val:expr, $expect:expr) => {
                let v = Variant::from($val);
                eprintln!("{v}");
                assert_eq!(v.to_string(), $expect);
            };
        }
        test_var!((), "()");
        test_var!(true, "true");
        test_var!(false, "false");
        // Top-level strings render without quotes...
        test_var!("hello".to_string(), "hello");
        test_var!("hello", "hello");
        test_var!(vec![0, 1, 2, 3, 4], "0x0001020304");
        test_var!(&[7, 6, 5, 4], "0x07060504");
        let list = &[
            Variant::from(true),
            Variant::unsigned(0u8),
            Variant::from("whee"),
        ];
        // ...but strings nested inside a container are quoted.
        test_var!(list, r#"[true, 0, "whee"]"#);
        test_var!(list.to_vec(), r#"[true, 0, "whee"]"#);
        // Map entries are listed in key order, not insertion order.
        test_var!(sample_map(), r#"{"baz": 42, "foo": "bar"}"#);
    }

    #[test]
    fn construct_upcast_ints() {
        let v = Variant::signed(42i16);
        assert_matches!(v, Variant::Signed(Int(42)));
        let v = Variant::unsigned_coerce(42);
        assert_matches!(v, Variant::Unsigned(Uint(42)));
    }

    #[test]
    fn downcasting() {
        let mut v = Variant::from(false);
        let r = v.as_bool_ref();
        assert_matches!(r, Some(&false));
        let r = v.as_bool_mut().unwrap();
        *r = true;
        let r = v.into_bool().unwrap();
        assert!(r);

        let mut v = Variant::unsigned(42u8);
        let r = v.as_unsigned_mut().unwrap();
        *r = 1234;
        let r = v.as_unsigned_ref().unwrap();
        assert_eq!(*r, 1234);
        let r = v.into_unsigned();
        assert_matches!(r, Some(1234));

        let v = Variant::signed(-4);
        let r = v.into_signed();
        assert_matches!(r, Some(-4));

        // `unsigned` accepts anything `Into<u64>`, including `bool`.
        let v = Variant::unsigned(false);
        let r = v.into_unsigned();
        assert_matches!(r, Some(0));
    }

    #[test]
    fn downcast_list() {
        let mut v = Variant::from(vec![
            Variant::from(true),
            Variant::signed(-4),
            Variant::from("hi"),
        ]);
        assert!(v.is_list());
        let r = v.as_list_mut().unwrap();
        assert_eq!(r.len(), 3);
        assert_eq!(r[0].as_bool_ref(), Some(&true));
    }

    #[test]
    fn downcast_map() {
        let mut v = Variant::from(sample_map());
        let r = v.as_map_mut().unwrap();
        assert_eq!(r.len(), 2);
        assert_eq!(
            r.get_mut("foo").and_then(Variant::as_string_mut),
            Some(&mut "bar".to_string())
        );
        assert_eq!(r.get("baz").and_then(Variant::as_signed_ref), Some(&42));
    }

    /// A failed `try_into_*` hands the original value back in the error.
    #[test]
    fn conversion_fail() {
        let var = Variant::from(1234);
        let res = var.try_into_bool();
        assert_matches!(
            res,
            Err(VariantConversionFailed(Variant::Signed(Int(1234))))
        );
    }

    #[test]
    fn ref_inner_str() {
        let mut var = Variant::from("hello");
        assert_eq!(var.as_str(), Some("hello"));
        if let Some(text) = var.as_str_mut() {
            text.make_ascii_uppercase();
        }
        assert_eq!(var.as_str(), Some("HELLO"));
    }

    #[test]
    fn ref_inner_bytes() {
        let mut var = Variant::from(&[1, 2, 3, 4, 5]);
        let r = var.as_slice_bytes().unwrap();
        assert_eq!(r.len(), 5);
        assert_eq!(r[0], 1);
        let r = var.as_slice_bytes_mut().unwrap();
        r[0] = 42;
        assert_eq!(var.into_bytes(), Some(vec![42u8, 2, 3, 4, 5]));
    }

    #[test]
    fn ref_inner_variant_list() {
        let mut var = Variant::from(vec![
            Variant::from(true),
            Variant::signed(-4),
            Variant::from("hi"),
        ]);
        let r = var.as_slice_variant().unwrap();
        assert_eq!(r.len(), 3);
        assert_eq!(r[0].as_bool_ref(), Some(&true));
        let r = var.as_slice_variant_mut().unwrap();
        r[1] = Variant::from(false);
        assert_eq!(
            var.into_list(),
            Some(vec![
                Variant::from(true),
                Variant::from(false),
                Variant::from("hi"),
            ])
        );
    }

    impl ProtocolMessage for Variant {}

    /// Round-trips `v` through `ProtocolMessage::to_vec`/`from_slice`. It checks
    /// both the exact encoded bytes and that decoding gives back `v`.
    fn test_encode(v: &Variant, expected: &[u8]) {
        let encoded = v.to_vec().unwrap();
        assert_eq!(encoded, expected, "failing case: {v:?}");
        let decoded = Variant::from_slice(&encoded).unwrap();
        assert_eq!(*v, decoded, "failing case: {v:?}");
    }

    #[test]
    fn ser_de_empty() {
        test_encode(&Variant::Empty, &[0u8]);
    }

    #[test]
    fn ser_de_bool() {
        test_encode(&Variant::from(true), &[1u8, 1]);
        test_encode(&Variant::from(false), &[1u8, 0]);
    }

    #[test]
    fn ser_de_int() {
        // Signed payloads are zigzag-encoded varints (42 -> 84, -2 -> 3).
        test_encode(&Variant::signed(42), &[2u8, 84]);
        test_encode(&Variant::signed(-2), &[2u8, 3]);
        test_encode(&Variant::signed(0), &[2u8, 0]);
        test_encode(&Variant::signed(1234), &[2u8, 164, 19]);
        test_encode(
            &Variant::signed(-9_223_372_036_854_775_807i64),
            &[2u8, 253, 255, 255, 255, 255, 255, 255, 255, 255, 1],
        );
        // Unsigned payloads are plain varints.
        test_encode(&Variant::unsigned_coerce(42), &[3u8, 42]);
        test_encode(
            &Variant::unsigned(u64::MAX),
            &[3u8, 255, 255, 255, 255, 255, 255, 255, 255, 255, 1],
        );
    }

    #[test]
    fn ser_de_str() {
        test_encode(&Variant::from("hello"), &[4u8, 5, 104, 101, 108, 108, 111]);
    }

    #[test]
    fn ser_de_bytes() {
        test_encode(&Variant::from(&[1, 2, 3, 4, 5]), &[5u8, 5, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn ser_de_list() {
        let list = vec![
            Variant::from(true),
            Variant::unsigned(0u8),
            Variant::from("whee"),
        ];
        test_encode(
            &Variant::from(list),
            &[
                6u8, 3, // List tag, 3 elements
                1, 1, // Boolean(true)
                3, 0, // Unsigned(0)
                4, 4, 119, 104, 101, 101, // String("whee")
            ],
        );
    }

    #[test]
    fn ser_de_map() {
        test_encode(
            &Variant::from(sample_map()),
            &[
                7u8, 2, // Map tag, 2 entries (key-sorted)
                3, 98, 97, 122, 2, 84, // "baz" => Signed(42)
                3, 102, 111, 111, 4, 3, 98, 97, 114, // "foo" => String("bar")
            ],
        );
    }

    #[test]
    fn coerce() {
        let mut var = Variant::from(true);
        assert!(var.coerce_bool());
        assert_eq!(var.coerce_signed(), 1);
        assert_eq!(var.coerce_unsigned(), 1);

        var = Variant::from(false);
        assert!(!var.coerce_bool());
        assert_eq!(var.coerce_signed(), 0);
        assert_eq!(var.coerce_unsigned(), 0);

        var = Variant::signed(17);
        assert!(var.coerce_bool());
        assert_eq!(var.coerce_signed(), 17);
        assert_eq!(var.coerce_unsigned(), 17);

        // Negative signed values wrap when coerced to unsigned.
        var = Variant::signed(-1);
        assert!(var.coerce_bool());
        assert_eq!(var.coerce_signed(), -1);
        assert_eq!(var.coerce_unsigned(), u64::MAX);

        var = Variant::unsigned(78u8);
        assert!(var.coerce_bool());
        assert_eq!(var.coerce_signed(), 78);
        assert_eq!(var.coerce_unsigned(), 78);

        // Unsigned values above i64::MAX wrap when coerced to signed.
        var = Variant::unsigned(2u64.pow(63) + 1);
        assert!(var.coerce_bool());
        assert_eq!(var.coerce_signed(), -9_223_372_036_854_775_807);
        assert_eq!(var.coerce_unsigned(), 2u64.pow(63) + 1);

        // Non-numeric variants coerce to false / 0.
        var = Variant::from("hello");
        assert!(!var.coerce_bool());
        assert_eq!(var.coerce_signed(), 0);
        assert_eq!(var.coerce_unsigned(), 0);
    }
}