use std::collections::VecDeque;
use std::fmt::{self, Display};
use std::marker::PhantomData;
use std::mem;
use derive_deftly::{Deftly, define_derive_deftly, derive_deftly_adhoc};
use paste::paste;
use serde::de::{self, DeserializeSeed, Deserializer, Error as _, IgnoredAny, MapAccess, Visitor};
use serde::{Deserialize, Serialize, Serializer};
use serde_value::Value;
use thiserror::Error;
define_derive_deftly! {
// Template: derives `Flattenable` for a struct
// (applied as `#[derive_deftly(Flattenable)]`).
//
// The generated `has_field` asks serde for the struct's field names, via
// `flattenable_extract_fields`, and reports whether `s` is one of them.
// The field list is recomputed on every call.
//
// Paths in the expansion are written `tor_config::...`, so the deriving
// crate must be able to name this crate as `tor_config`.
//
// The template also emits a `#[test]`, named after the type, which calls
// `has_field` once: `flattenable_extract_fields` panics if the type does not
// deserialize as a struct, so an unsuitable type is caught by `cargo test`.
export Flattenable for struct, expect items:
impl tor_config::Flattenable for $ttype {
fn has_field(s: &str) -> bool {
let fnames = tor_config::flattenable_extract_fields::<'_, Self>();
IntoIterator::into_iter(fnames).any(|f| *f == s)
}
}
// Detect, at test time, types for which `flattenable_extract_fields` panics.
#[test]
fn $<flattenable_test_ ${snake_case $tname}>() {
let _: bool = <$ttype as tor_config::Flattenable>::has_field("");
}
}
/// Combines two flattenable types `T` and `U` into one, as if all of their
/// fields belonged to a single struct.
///
/// Serializes as one map containing the fields of both components.
/// When deserializing from a map, each key is handed to the first of `T`, `U`
/// whose [`Flattenable::has_field`] accepts it. Keys that neither accepts are
/// ignored, and so can be reported by tools such as `serde_ignored`.
///
/// `Flatten` is itself `Flattenable`, so it can be nested to combine more
/// than two types.
#[derive(Deftly, Debug, Clone, Copy, Hash, Ord, PartialOrd, Eq, PartialEq, Default)]
#[derive_deftly_adhoc]
#[allow(clippy::exhaustive_structs)]
pub struct Flatten<T, U>(pub T, pub U);
/// A type whose (serde-visible) field names can be queried, so that it can be
/// a component of a [`Flatten`].
///
/// Structs can get an implementation from the `Flattenable` derive-deftly
/// template in this module; `Flatten` also implements it, which allows nesting.
pub trait Flattenable {
    /// Returns `true` if `field` is the name of one of this type's fields.
    fn has_field(field: &str) -> bool;
}
// Defines one `deserialize_<kind>` method which just forwards to
// `self.deserialize_any`.
//
// `$kind` is the method-name suffix. Any further tokens are spliced in as
// extra parameters right after `self`, so they must start with a comma,
// e.g. `call_any!(tuple, _: usize)`.
macro_rules! call_any {
    { $kind:ident $( $extra:tt )* } => {
        paste! {
            fn [<deserialize_ $kind>]<V>(
                self $( $extra )*,
                visitor: V,
            ) -> Result<V::Value, Self::Error>
            where
                V: Visitor<'de>,
            {
                self.deserialize_any(visitor)
            }
        }
    };
}
// Defines, via `call_any!`, every `deserialize_*` method listed here. That is
// all of them except `deserialize_any` and `deserialize_struct`, which the
// implementing type provides itself. Each one forwards to `deserialize_any`.
macro_rules! call_any_for_rest {
    () => {
        // Scalars
        call_any!(bool);
        call_any!(i8);
        call_any!(i16);
        call_any!(i32);
        call_any!(i64);
        call_any!(i128);
        call_any!(u8);
        call_any!(u16);
        call_any!(u32);
        call_any!(u64);
        call_any!(u128);
        call_any!(f32);
        call_any!(f64);
        call_any!(char);

        // Text and bytes
        call_any!(str);
        call_any!(string);
        call_any!(bytes);
        call_any!(byte_buf);

        // Option and unit-like values
        call_any!(option);
        call_any!(unit);
        call_any!(unit_struct, _: &'static str );
        call_any!(newtype_struct, _: &'static str );

        // Compound values
        call_any!(seq);
        call_any!(tuple, _: usize );
        call_any!(tuple_struct, _: &'static str, _: usize );
        call_any!(map);
        call_any!(enum, _: &'static str, _: FieldList);

        // Keys and skipped values
        call_any!(identifier);
        call_any!(ignored_any);
    };
}
// `Serialize` and `Flattenable` impls for `Flatten`, generated from its
// definition: `$( ... )` repeats over its fields, whose types (`$ftype`)
// are `T` and `U`.
derive_deftly_adhoc! {
Flatten expect items:
impl<T, U> Serialize for Flatten<T, U>
where $( $ftype: Serialize, )
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
{
// A local named-field mirror of `Flatten` (shadowing the outer name),
// holding references to the components, so that `#[serde(flatten)]` can be
// applied to each one. `$fpatname` supplies an identifier for each tuple
// field. The output is one map with every component's fields (see the
// `ser` test).
#[derive(Serialize)]
struct Flatten<'r, T, U> {
$(
#[serde(flatten)]
$fpatname: &'r $ftype,
)
}
Flatten {
$(
$fpatname: &self.$fname,
)
}
.serialize(serializer)
}
}
// A field belongs to a `Flatten` if it belongs to any of its components.
// Because the components may themselves be `Flatten`s, this supports nesting.
impl<T, U> Flattenable for Flatten<T, U>
where $( $ftype: Flattenable, )
{
fn has_field(f: &str) -> bool {
// Expands to `T::has_field(f) || U::has_field(f) || false`.
$(
$ftype::has_field(f)
||
)
false
}
}
}
/// The `(key, value)` entries collected for one component of a `Flatten`,
/// in the order they arrived.
///
/// Implements serde `Deserializer` and `MapAccess`, so that a component can be
/// deserialized from just the entries routed to it.
#[derive(Default)]
struct Portion(VecDeque<(String, Value)>);
/// Serde visitor that builds a `Flatten<T, U>` from a map.
///
/// Carries the component types only through `PhantomData`; it holds no data.
struct FlattenVisitor<T, U>(PhantomData<(T, U)>);
/// A single map key, handed to whichever `Deserialize` impl asks for it
/// as an owned string.
struct Key(String);
/// Error from deserializing a `Flatten` component out of its collected entries.
///
/// This is `serde_value`'s deserializer error, because the entry values are
/// `serde_value::Value`s.
type FlattenError = serde_value::DeserializerError;
// `Deserialize` impl for `Flatten`, plus the map visitor that does the real work.
derive_deftly_adhoc! {
Flatten expect items:
// Shared bound list: every component must be `Deserialize` and `Flattenable`.
${define FLATTENABLE $( $ftype: Deserialize<'de> + Flattenable, )}
impl<'de, T, U> Deserialize<'de> for Flatten<T, U>
where $FLATTENABLE
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de>
{
// `Flatten` is always read from a map; see `FlattenVisitor::visit_map`.
deserializer.deserialize_map(FlattenVisitor(PhantomData))
}
}
impl<'de, T, U> Visitor<'de> for FlattenVisitor<T,U>
where $FLATTENABLE
{
type Value = Flatten<T, U>;
fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "map (for struct)")
}
fn visit_map<A>(self, mut mapa: A) -> Result<Self::Value, A::Error>
where A: MapAccess<'de>
{
// `$P` names a per-component local `Portion` that collects that
// component's entries.
${define P $<p_ $fname>}
${for fields { let mut $P = Portion::default(); }}
// Route each entry to the first component whose `has_field` claims its key.
// An unclaimed value is consumed as `IgnoredAny`, so tools that track
// ignored keys (e.g. `serde_ignored`, as in the tests) can report it.
// The repetition ends each `if` with a dangling `else`, chaining them
// into the final block; hence the clippy allow.
#[allow(clippy::suspicious_else_formatting)] while let Some(k) = mapa.next_key::<String>()? {
$(
if $ftype::has_field(&k) {
let v: Value = mapa.next_value()?;
$P.0.push_back((k, v));
continue;
}
else
)
{
let _: IgnoredAny = mapa.next_value()?;
}
}
// Deserialize each component from its own portion. A component error is
// converted into the outer deserializer's error type.
Flatten::assemble( ${for fields { $P, }} )
.map_err(A::Error::custom)
}
}
}
// `Flatten::assemble`: builds a `Flatten` from one `Portion` per component
// (in field order), deserializing each component from its own portion.
// Returns the first component's error, if any.
derive_deftly_adhoc! {
Flatten expect items:
impl<'de, T, U> Flatten<T, U>
where $( $ftype: Deserialize<'de>, )
{
fn assemble(
$(
$fpatname: Portion,
)
) -> Result<Self, FlattenError> {
Ok(Flatten(
$(
$ftype::deserialize($fpatname)?,
)
))
}
}
}
impl<'de> Deserializer<'de> for Portion {
type Error = FlattenError;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_map(self)
}
call_any!(struct, _: &'static str, _: FieldList);
call_any_for_rest!();
}
// Yields the collected entries front to back.
//
// `next_key_seed` looks at the front entry and moves its key string out
// (leaving an empty string behind). `next_value_seed` then removes that entry
// and deserializes its value.
impl<'de> MapAccess<'de> for Portion {
    type Error = FlattenError;

    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
    where
        K: DeserializeSeed<'de>,
    {
        match self.0.front_mut() {
            None => Ok(None),
            Some((key, _)) => seed.deserialize(Key(mem::take(key))).map(Some),
        }
    }

    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
    where
        V: DeserializeSeed<'de>,
    {
        // Per the `MapAccess` protocol a key was just produced, so an entry
        // must remain; running out means the caller broke the protocol.
        let (_, value) = self
            .0
            .pop_front()
            .expect("next_value called inappropriately");
        seed.deserialize(value)
    }
}
impl<'de> Deserializer<'de> for Key {
type Error = FlattenError;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_string(self.0)
}
call_any!(struct, _: &'static str, _: FieldList);
call_any_for_rest!();
}
/// A static list of names, as serde passes to `deserialize_struct` (field
/// names) and to `deserialize_enum` (variant names).
type FieldList = &'static [&'static str];
/// Dummy serde input, used by `flattenable_extract_fields` to find out a
/// struct's field names without needing any real data.
struct FieldExtractor;
/// The "error" that `FieldExtractor` returns on the expected path.
///
/// It carries the field names serde supplied. Returning it as an error aborts
/// deserialization before any actual data is needed.
#[derive(Error, Debug)]
#[error("Flattenable macro test gave error, so test passed successfully")]
struct FieldExtractorSuccess(FieldList);
pub fn flattenable_extract_fields<'de, T: Deserialize<'de>>() -> FieldList {
let notional_input = FieldExtractor;
let FieldExtractorSuccess(fields) = T::deserialize(notional_input)
.map(|_| ())
.expect_err("unexpected success deserializing from FieldExtractor!");
fields
}
// `FieldExtractor` only ever produces `FieldExtractorSuccess` itself. If serde
// tries to build any other error, the type being examined is unsuitable, so
// treat that as a bug and panic.
impl de::Error for FieldExtractorSuccess {
    fn custom<E: Display>(e: E) -> Self {
        panic!("Flattenable macro test failed - some *other* serde error: {e}");
    }
}
// The dummy input behind `flattenable_extract_fields`.
//
// `deserialize_struct` immediately "fails" with `FieldExtractorSuccess`, which
// carries the field names serde passed in. Every other `deserialize_*` method
// forwards to `deserialize_any`, which panics: reaching it means the type being
// examined does not deserialize as a struct.
impl<'de> Deserializer<'de> for FieldExtractor {
    type Error = FieldExtractorSuccess;

    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        fields: FieldList,
        _visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        // Not a real failure: this is how the field list gets back out.
        Err(FieldExtractorSuccess(fields))
    }

    fn deserialize_any<V>(self, _: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        panic!("test failed: Flattenable misimplemented by macros!");
    }

    call_any_for_rest!();
}
// Tests: deserialize one TOML document into `Flatten`s of several structs, and
// check the values, the keys reported as ignored, and serialization.
#[cfg(test)]
mod test {
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_duration_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
use super::*;
// The `Flattenable` template's expansion refers to `tor_config::...`;
// make that path resolve to this crate.
use crate as tor_config;
use std::collections::HashMap;
// Three structs with disjoint field names, used as `Flatten` components.
#[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
#[derive_deftly(Flattenable)]
struct A {
a: i32,
m: HashMap<String, String>,
}
#[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
#[derive_deftly(Flattenable)]
struct B {
b: i32,
v: Vec<String>,
}
#[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
#[derive_deftly(Flattenable)]
struct C {
c: HashMap<String, String>,
}
// Input with fields for `A`, `B` and `C`, plus `spurious`, which belongs
// to none of them.
const TEST_INPUT: &str = r#"
a = 42
m.one = "unum"
m.two = "bis"
b = 99
v = ["hi", "ho"]
spurious = 66
c.zed = "final"
"#;
// Parses `TEST_INPUT` into a generic TOML value.
fn test_input() -> toml::Value {
toml::from_str(TEST_INPUT).unwrap()
}
// Deserializes the whole input directly as `T`. The input contains keys
// that `T` does not declare, so this relies on unknown fields being accepted.
fn simply<'de, T: Deserialize<'de>>() -> T {
test_input().try_into().unwrap()
}
// Deserializes the input as `T` via `serde_ignored`, also returning the
// paths of the keys that were ignored.
fn with_ignored<'de, T: Deserialize<'de>>() -> (T, Vec<String>) {
let mut ignored = vec![];
let f = serde_ignored::deserialize(
test_input(), |path| ignored.push(path.to_string()),
)
.unwrap();
(f, ignored)
}
// `Flatten<A, B>` gives the same components as deserializing `A` and `B` separately.
#[test]
fn plain() {
let f: Flatten<A, B> = test_input().try_into().unwrap();
assert_eq!(f, Flatten(simply(), simply()));
}
// Keys claimed by neither component are reported as ignored.
#[test]
fn ignored() {
let (f, ignored) = with_ignored::<Flatten<A, B>>();
assert_eq!(f, simply());
assert_eq!(ignored, ["c", "spurious"]);
}
// A nested `Flatten` also claims `C`'s field, leaving only `spurious` ignored.
#[test]
fn nested() {
let (f, ignored) = with_ignored::<Flatten<A, Flatten<B, C>>>();
assert_eq!(f, simply());
assert_eq!(ignored, ["spurious"]);
}
// Serialization yields a single flat object with every component's fields.
#[test]
fn ser() {
let f: Flatten<A, Flatten<B, C>> = simply();
assert_eq!(
serde_json::to_value(f).unwrap(),
serde_json::json!({
"a": 42,
"m": {
"one": "unum",
"two": "bis"
},
"b": 99,
"v": [
"hi",
"ho"
],
"c": {
"zed": "final"
}
}),
);
}
// Non-generic wrapper around `flattenable_extract_fields` for `A`.
fn flattenable_extract_fields_a() -> FieldList {
flattenable_extract_fields::<'_, A>()
}
// Calls the wrapper through `black_box` and prints the extracted fields.
// NOTE(review): presumably `black_box` is there to keep the call from being
// optimised away, so that the generated code can be inspected; confirm intent.
#[test]
fn flattenable_extract_fields_a_test() {
use std::hint::black_box;
let f: fn() -> _ = black_box(flattenable_extract_fields_a);
eprintln!("{:?}", f());
}
}