use std::cmp::Ordering;
use std::convert::{TryFrom, TryInto};
use cid::Cid;
use libipld_core::ipld::Ipld;
use once_cell::unsync::OnceCell;
use serde::de::{self, DeserializeOwned};
use serde::{ser, Deserialize, Deserializer, Serialize, Serializer};
use super::node::Node;
use super::{Error, Hash, HashAlgorithm, KeyValuePair, MAX_ARRAY_WIDTH};
/// One slot in a `Node`'s pointer array (the `pointers` collection walked by
/// `clean`).
///
/// A slot either holds key/value pairs inline, refers to a stored child node
/// by CID, or owns an in-memory child node directly.
#[derive(Debug)]
pub(crate) enum Pointer<K, V, H> {
/// Key/value pairs held inline in this slot. Serialized as a list of pairs and
/// decoded from an `Ipld::List`.
Values(Vec<KeyValuePair<K, V>>),
/// Reference to a child node by its CID. Serialized as the bare CID, and
/// equality compares only the CID.
Link {
/// Content identifier of the referenced child node.
cid: Cid,
/// Slot for the resolved child node. It starts empty when the pointer is
/// decoded from IPLD, and it is ignored by `PartialEq` and `Serialize`.
/// NOTE(review): presumably filled lazily on first load of the child — confirm
/// against the `Node` code that accesses it.
cache: OnceCell<Box<Node<K, V, H>>>,
},
/// Child node owned in memory. It cannot be serialized directly (`serialize`
/// returns an error for it); `clean` may collapse it back into `Values`.
/// NOTE(review): the name suggests a modified node not yet persisted — confirm.
Dirty(Box<Node<K, V, H>>),
}
impl<K: PartialEq, V: PartialEq, H> PartialEq for Pointer<K, V, H> {
    /// Compares two pointers variant-by-variant.
    ///
    /// - Inline buckets are equal when their pair lists are equal.
    /// - Links are equal when their CIDs are equal; the cache is not consulted.
    /// - Dirty pointers are equal when their nodes are equal.
    /// - Pointers of different variants are never equal.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Pointer::Values(lhs), Pointer::Values(rhs)) => lhs == rhs,
            (Pointer::Link { cid: lhs, .. }, Pointer::Link { cid: rhs, .. }) => lhs == rhs,
            (Pointer::Dirty(lhs), Pointer::Dirty(rhs)) => lhs == rhs,
            _ => false,
        }
    }
}
impl<K, V, H> Serialize for Pointer<K, V, H>
where
K: Serialize,
V: Serialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Pointer::Values(vals) => vals.serialize(serializer),
Pointer::Link { cid, .. } => cid.serialize(serializer),
Pointer::Dirty(_) => Err(ser::Error::custom("Cannot serialize cached values")),
}
}
}
impl<K, V, H> TryFrom<Ipld> for Pointer<K, V, H>
where
K: DeserializeOwned,
V: DeserializeOwned,
{
type Error = String;
fn try_from(ipld: Ipld) -> Result<Self, Self::Error> {
match ipld {
ipld_list @ Ipld::List(_) => {
let values: Vec<KeyValuePair<K, V>> =
Deserialize::deserialize(ipld_list).map_err(|error| error.to_string())?;
Ok(Self::Values(values))
}
Ipld::Link(cid) => Ok(Self::Link {
cid,
cache: Default::default(),
}),
other => Err(format!(
"Expected `Ipld::List` or `Ipld::Link`, got {:#?}",
other
)),
}
}
}
impl<'de, K, V, H> Deserialize<'de> for Pointer<K, V, H>
where
K: DeserializeOwned,
V: DeserializeOwned,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Ipld::deserialize(deserializer).and_then(|ipld| ipld.try_into().map_err(de::Error::custom))
}
}
impl<K, V, H> Default for Pointer<K, V, H> {
fn default() -> Self {
Pointer::Values(Vec::new())
}
}
impl<K, V, H> Pointer<K, V, H>
where
K: Serialize + DeserializeOwned + Hash + PartialOrd,
V: Serialize + DeserializeOwned,
H: HashAlgorithm,
{
pub(crate) fn from_key_value(key: K, value: V) -> Self {
Pointer::Values(vec![KeyValuePair::new(key, value)])
}
pub(crate) fn clean(&mut self) -> Result<(), Error> {
match self {
Pointer::Dirty(n) => match n.pointers.len() {
0 => Err(Error::ZeroPointers),
1 => {
if let Pointer::Values(vals) = &mut n.pointers[0] {
let values = std::mem::take(vals);
*self = Pointer::Values(values)
}
Ok(())
}
2..=MAX_ARRAY_WIDTH => {
let mut children_len = 0;
for c in n.pointers.iter() {
if let Pointer::Values(vals) = c {
children_len += vals.len();
} else {
return Ok(());
}
}
if children_len > MAX_ARRAY_WIDTH {
return Ok(());
}
let mut child_vals: Vec<KeyValuePair<K, V>> = n
.pointers
.iter_mut()
.filter_map(|p| {
if let Pointer::Values(kvs) = p {
Some(std::mem::take(kvs))
} else {
None
}
})
.flatten()
.collect();
child_vals.sort_unstable_by(|a, b| {
a.key().partial_cmp(b.key()).unwrap_or(Ordering::Equal)
});
*self = Pointer::Values(child_vals);
Ok(())
}
_ => Ok(()),
},
_ => unreachable!("clean is only called on dirty pointer"),
}
}
}