use serde::{Deserialize, Deserializer, Serialize, Serializer};
use smallvec::SmallVec;
use crate::{CoreError, CoreResult, DbString, Value};
#[derive(Clone, Debug, PartialEq)]
pub struct LabelDiff {
pub added: SmallVec<[DbString; 2]>,
pub removed: SmallVec<[DbString; 2]>,
}
impl LabelDiff {
pub fn new(
added: impl IntoIterator<Item = DbString>,
removed: impl IntoIterator<Item = DbString>,
) -> CoreResult<Self> {
let added = sorted_deduped(added);
let removed = sorted_deduped(removed);
ensure_disjoint("label", &added, &removed)?;
Ok(Self { added, removed })
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.added.is_empty() && self.removed.is_empty()
}
}
#[derive(Deserialize, Serialize)]
struct LabelDiffWire {
added: SmallVec<[DbString; 2]>,
removed: SmallVec<[DbString; 2]>,
}
impl Serialize for LabelDiff {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut added = self.added.clone();
let mut removed = self.removed.clone();
added.sort_by(|lhs, rhs| lhs.as_str().cmp(rhs.as_str()));
removed.sort_by(|lhs, rhs| lhs.as_str().cmp(rhs.as_str()));
LabelDiffWire { added, removed }.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for LabelDiff {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let wire = LabelDiffWire::deserialize(deserializer)?;
validate_sorted_unique(&wire.added, "LabelDiff.added")?;
validate_sorted_unique(&wire.removed, "LabelDiff.removed")?;
validate_disjoint(&wire.added, &wire.removed, "label")?;
Ok(Self {
added: wire.added,
removed: wire.removed,
})
}
}
#[derive(Clone, Debug, PartialEq)]
pub struct PropertyDiff {
pub set: SmallVec<[(DbString, Value); 4]>,
pub removed: SmallVec<[DbString; 2]>,
}
impl PropertyDiff {
pub fn new(
set: impl IntoIterator<Item = (DbString, Value)>,
removed: impl IntoIterator<Item = DbString>,
) -> CoreResult<Self> {
let mut set: SmallVec<[(DbString, Value); 4]> = set.into_iter().collect();
if set.len() > 1 {
set.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
let mut deduped = SmallVec::new();
for (key, value) in set {
if let Some((last_key, last_value)) = deduped.last_mut()
&& last_key == &key
{
*last_value = value;
continue;
}
deduped.push((key, value));
}
set = deduped;
}
let removed = sorted_deduped(removed);
for (key, _) in set.iter() {
if removed.binary_search(key).is_ok() {
return Err(CoreError::OverlappingDiff {
kind: "property",
key: key.clone(),
});
}
}
Ok(Self { set, removed })
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.set.is_empty() && self.removed.is_empty()
}
}
#[derive(Deserialize, Serialize)]
struct PropertyDiffWire {
set: SmallVec<[(DbString, Value); 4]>,
removed: SmallVec<[DbString; 2]>,
}
impl Serialize for PropertyDiff {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut set = self.set.clone();
let mut removed = self.removed.clone();
set.sort_by(|(lhs, _), (rhs, _)| lhs.as_str().cmp(rhs.as_str()));
removed.sort_by(|lhs, rhs| lhs.as_str().cmp(rhs.as_str()));
PropertyDiffWire { set, removed }.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for PropertyDiff {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let wire = PropertyDiffWire::deserialize(deserializer)?;
for window in wire.set.windows(2) {
if window[0].0 >= window[1].0 {
return Err(serde::de::Error::custom(
"PropertyDiff.set entries must be sorted by DbString order with no duplicate keys",
));
}
}
validate_sorted_unique(&wire.removed, "PropertyDiff.removed")?;
for (key, _) in wire.set.iter() {
if wire.removed.binary_search(key).is_ok() {
return Err(serde::de::Error::custom(format!(
"PropertyDiff: key {key} appears in both set and removed",
)));
}
}
Ok(Self {
set: wire.set,
removed: wire.removed,
})
}
}
fn sorted_deduped(values: impl IntoIterator<Item = DbString>) -> SmallVec<[DbString; 2]> {
let mut values: SmallVec<[DbString; 2]> = values.into_iter().collect();
values.sort();
values.dedup();
values
}
fn ensure_disjoint(
kind: &'static str,
added: &SmallVec<[DbString; 2]>,
removed: &SmallVec<[DbString; 2]>,
) -> CoreResult<()> {
for label in added.iter() {
if removed.binary_search(label).is_ok() {
return Err(CoreError::OverlappingDiff {
kind,
key: label.clone(),
});
}
}
Ok(())
}
fn validate_sorted_unique<E: serde::de::Error>(
values: &SmallVec<[DbString; 2]>,
label: &'static str,
) -> Result<(), E> {
for window in values.windows(2) {
if window[0] >= window[1] {
return Err(E::custom(format!(
"{label} must be sorted by DbString order with no duplicates"
)));
}
}
Ok(())
}
fn validate_disjoint<E: serde::de::Error>(
added: &SmallVec<[DbString; 2]>,
removed: &SmallVec<[DbString; 2]>,
kind: &'static str,
) -> Result<(), E> {
for label in added.iter() {
if removed.binary_search(label).is_ok() {
return Err(E::custom(format!(
"overlapping {kind} diff: {label} appears in both add/set and remove",
)));
}
}
Ok(())
}