use super::Deserializer;
use crate::{buffer::BufReader, de::Deserialize, AlgebraicValue, ArrayValue, ProductValue, SumValue};
use core::{mem, slice};
pub fn eq_bsatn<'r>(lhs: &AlgebraicValue, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
match lhs {
AlgebraicValue::Sum(lhs) => eq_bsatn_sum(lhs, rhs),
AlgebraicValue::Product(lhs) => eq_bsatn_prod(lhs, rhs),
AlgebraicValue::Array(lhs) => eq_bsatn_array(lhs, rhs),
AlgebraicValue::Bool(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::I8(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::U8(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::I16(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::U16(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::I32(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::U32(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::I64(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::U64(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::I128(lhs) => eq_bsatn_de(&{ lhs.0 }, rhs),
AlgebraicValue::U128(lhs) => eq_bsatn_de(&{ lhs.0 }, rhs),
AlgebraicValue::I256(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::U256(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::F32(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::F64(lhs) => eq_bsatn_de(lhs, rhs),
AlgebraicValue::String(lhs) => eq_bsatn_str(lhs, rhs),
AlgebraicValue::Min | AlgebraicValue::Max => panic!("not defined for Min/Max"),
}
}
/// Equates the sum value `lhs` to the encoded value that `rhs` reads from.
///
/// The tag is decoded and compared first;
/// the payload is only looked at when the tags agree.
fn eq_bsatn_sum<'r>(lhs: &SumValue, mut rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
    // A different (or undecodable) tag settles the answer without touching the payload.
    if !eq_bsatn_de(&lhs.tag, rhs.reborrow()) {
        return false;
    }
    eq_bsatn(&lhs.value, rhs)
}
/// Equates the product value `lhs` to the encoded value that `rhs` reads from.
///
/// The fields of `lhs` are compared in order against consecutive values read from `rhs`,
/// stopping at the first field that differs.
/// No field count is read from `rhs`; exactly as many values as `lhs` has fields are consumed.
fn eq_bsatn_prod<'r>(lhs: &ProductValue, mut rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
    for field in lhs.elements.iter() {
        if !eq_bsatn(field, rhs.reborrow()) {
            return false;
        }
    }
    // Every field matched (vacuously so for a product without fields).
    true
}
fn eq_bsatn_array<'r>(lhs: &ArrayValue, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
match lhs {
ArrayValue::Sum(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_sum),
ArrayValue::Product(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_prod),
ArrayValue::Bool(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_de),
ArrayValue::F32(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_de),
ArrayValue::F64(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_de),
ArrayValue::String(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_str),
ArrayValue::Array(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_array),
ArrayValue::I8(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::U8(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::I16(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::U16(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::I32(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::U32(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::I64(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::U64(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::I128(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::U128(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::I256(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
ArrayValue::U256(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
}
}
/// Equates the slice `lhs` to a length-prefixed sequence read from `rhs`
/// by comparing the raw bytes of both sides.
///
/// Returns `false` when:
/// - the length prefix cannot be read from `rhs`,
/// - the length prefix differs from `lhs.len()`, or
/// - `rhs` cannot supply as many bytes as `lhs` occupies in memory.
///
/// When the length prefix differs, the payload bytes are not consumed from `rhs`.
///
/// # Safety
///
/// `lhs` is reinterpreted as a `&[u8]` covering all of its memory,
/// so every byte of every `T` in `lhs` must be initialized, i.e., `T` must have no padding.
/// The callers in this file only pass slices of integer types.
unsafe fn eq_bsatn_int_seq<'r, T>(lhs: &[T], mut rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
    // The prefix counts elements, not bytes.
    let Ok(len) = rhs.reborrow().deserialize_len() else {
        return false;
    };

    // Compare element counts before computing any byte length.
    // The prefix comes from the encoded data, so `len * size_of::<T>()` could overflow:
    // a panic with overflow checks on, or, when wrapping,
    // a byte length that happens to equal that of `lhs` despite a different prefix.
    // Once the counts agree, the byte length is just the in-memory size of `lhs`,
    // which cannot overflow.
    if len != lhs.len() {
        return false;
    }
    let byte_len = mem::size_of_val(lhs);

    let Ok(rhs_bytes) = rhs.get_slice(byte_len) else {
        return false;
    };

    // View `lhs` as bytes.
    let ptr = lhs.as_ptr().cast::<u8>();
    // SAFETY: `ptr` is derived from `lhs`, so it is non-null, valid for reads of
    // `size_of_val(lhs)` bytes, and `u8` has no alignment requirement.
    // The caller promised that `T` has no padding, so every one of those bytes is initialized.
    let lhs_bytes = unsafe { slice::from_raw_parts(ptr, byte_len) };

    // NOTE(review): this compares the host's in-memory representation to the encoded bytes,
    // which presumes the encoding's integer byte order matches the host's — confirm.
    lhs_bytes == rhs_bytes
}
/// Equates the string `lhs` to a string decoded from `rhs`.
///
/// Returns `false` when no string can be decoded from `rhs`.
// `&Box<str>` is kept on purpose: `eq_bsatn_array` passes this function to `eq_bsatn_seq`
// as the per-element comparison for its string array, and must match that item type.
#[allow(clippy::borrowed_box)]
fn eq_bsatn_str<'r>(lhs: &Box<str>, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
    match <&str>::deserialize(rhs) {
        Ok(decoded) => &**lhs == decoded,
        Err(_) => false,
    }
}
/// Equates the elements of `lhs` to a length-prefixed sequence read from `rhs`,
/// using `elem_eq` to compare each element against the next value in `rhs`.
///
/// Returns `false` when the length prefix cannot be read,
/// when it differs from the number of elements in `lhs`,
/// or as soon as `elem_eq` rejects an element.
fn eq_bsatn_seq<'r, T, I: ExactSizeIterator<Item = T>, R: BufReader<'r>>(
    lhs: impl IntoIterator<IntoIter = I>,
    mut rhs: Deserializer<'_, R>,
    elem_eq: impl Fn(T, Deserializer<'_, R>) -> bool,
) -> bool {
    let mut items = lhs.into_iter();

    let Ok(encoded_len) = rhs.reborrow().deserialize_len() else {
        return false;
    };

    // Sequences of different lengths can never be equal; skip the element comparison.
    if items.len() != encoded_len {
        return false;
    }

    items.all(|item| elem_eq(item, rhs.reborrow()))
}
/// Equates `lhs` to a `T` decoded from `rhs`.
///
/// Returns `false` when no `T` can be decoded from `rhs`.
fn eq_bsatn_de<'r, T: Eq + Deserialize<'r>>(lhs: &T, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
    match T::deserialize(rhs) {
        Ok(decoded) => *lhs == decoded,
        Err(_) => false,
    }
}
#[cfg(test)]
mod tests {
    use super::eq_bsatn;
    use crate::{
        bsatn::{to_vec, Deserializer},
        proptest::generate_typed_value,
    };
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(2048))]

        // Property: any generated value must compare equal to its own encoding.
        #[test]
        fn encoded_val_eq_to_self((_, val) in generate_typed_value()) {
            // Encode the value, then read the encoding back through a cursor over the bytes.
            let encoded = to_vec(&val).unwrap();
            let mut cursor = &*encoded;
            prop_assert!(eq_bsatn(&val, Deserializer::new(&mut cursor)));
        }
    }
}