use std::convert::{TryFrom, TryInto};
use anyhow::anyhow;
use cid::multihash::Code;
use cid::Cid;
use fvm_ipld_blockstore::Blockstore;
use fvm_ipld_encoding::{strict_bytes, BytesSer, CborStore};
use once_cell::unsync::OnceCell;
use serde::de::{self, DeserializeOwned};
use serde::{ser, Deserialize, Serialize};
use super::ValueMut;
use crate::{bmap_bytes, init_sized_vec, nodes_for_height, Error};
/// A reference from an internal node to one of its children.
#[derive(Debug)]
pub(super) enum Link<V> {
/// Child identified by the CID of its stored form.
Cid {
/// CID of the child node.
cid: Cid,
/// In-memory copy of the child. It starts out empty and is filled the
/// first time the child is loaded from the blockstore.
cache: OnceCell<Box<Node<V>>>,
},
/// Child held only in memory. `Node::flush` stores such children and turns
/// them into `Link::Cid` entries.
Dirty(Box<Node<V>>),
}
impl<'de, V> Deserialize<'de> for Link<V>
where
V: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: de::Deserializer<'de>,
{
let cid: Cid = Deserialize::deserialize(deserializer)?;
Ok(Link::Cid {
cid,
cache: Default::default(),
})
}
}
impl<V> PartialEq for Link<V>
where
    V: PartialEq,
{
    /// Two links are equal when they are the same variant and either their
    /// CIDs match (`Cid`, the cache is ignored) or their in-memory nodes
    /// match (`Dirty`). Links of different variants are never equal.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Link::Cid { cid: lhs, .. }, Link::Cid { cid: rhs, .. }) => lhs == rhs,
            (Link::Dirty(lhs), Link::Dirty(rhs)) => lhs == rhs,
            (Link::Cid { .. }, Link::Dirty(_)) | (Link::Dirty(_), Link::Cid { .. }) => false,
        }
    }
}
// Marker impl: `Link<V>` is `Eq` whenever the value type `V` is `Eq`.
impl<V> Eq for Link<V> where V: Eq {}
impl<V> From<Cid> for Link<V> {
fn from(cid: Cid) -> Link<V> {
Link::Cid {
cid,
cache: Default::default(),
}
}
}
/// A tree node: either an internal node holding links to child nodes or a
/// leaf holding values. In both variants a `None` entry marks an empty slot.
#[derive(PartialEq, Eq, Debug)]
#[allow(clippy::large_enum_variant)]
pub(super) enum Node<V> {
/// Internal node; each occupied slot links to a child node.
Link { links: Vec<Option<Link<V>>> },
/// Leaf node; each occupied slot stores a value.
Leaf { vals: Vec<Option<V>> },
}
impl<V> Serialize for Node<V>
where
V: Serialize,
{
fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
where
S: ser::Serializer,
{
match &self {
Node::Leaf { vals } => {
let mut values = Vec::<&V>::with_capacity(vals.len());
let mut bmap = vec![0u8; ((vals.len().saturating_sub(1)) / 8) + 1];
for (i, v) in vals.iter().enumerate() {
if let Some(val) = v {
values.push(val);
bmap[i / 8] |= 1 << (i % 8);
}
}
(BytesSer(&bmap), Vec::<&Cid>::new(), values).serialize(s)
}
Node::Link { links } => {
let mut collapsed = Vec::<&Cid>::with_capacity(links.len());
let mut bmap = vec![0u8; ((links.len().saturating_sub(1)) / 8) + 1];
for (i, v) in links.iter().enumerate() {
if let Some(val) = v {
if let Link::Cid { cid, .. } = val {
collapsed.push(cid);
bmap[i / 8] |= 1 << (i % 8);
} else {
return Err(ser::Error::custom(Error::Cached));
}
}
}
(BytesSer(&bmap), collapsed, Vec::<&V>::new()).serialize(s)
}
}
}
}
/// Compact, serializable form of a node: a bitmap byte string, the list of
/// child CIDs, and the list of values, in that order. This is the same tuple
/// layout that `Node`'s `Serialize` impl writes; `expand` turns it back into
/// a `Node`.
#[derive(Serialize, Deserialize)]
pub(crate) struct CollapsedNode<V>(#[serde(with = "strict_bytes")] Vec<u8>, Vec<Cid>, Vec<V>);
impl<V> CollapsedNode<V> {
    /// Rebuilds a full-width [`Node`] from its collapsed form.
    ///
    /// Slots are created with `init_sized_vec(bit_width)` and filled, in slot
    /// order, with the next link or value for every bit that is set in the
    /// bitmap. A collapsed node with no links expands to a leaf, otherwise to
    /// an internal node.
    ///
    /// # Errors
    ///
    /// * `Error::LinksAndValues` if both the link list and value list are
    ///   non-empty.
    /// * An error if the bitmap length differs from `bmap_bytes(bit_width)`.
    /// * An error if the number of set bits does not match the number of
    ///   links/values supplied.
    pub(crate) fn expand(self, bit_width: u32) -> Result<Node<V>, Error> {
        let CollapsedNode(bitmap, cids, values) = self;

        if !(cids.is_empty() || values.is_empty()) {
            return Err(Error::LinksAndValues);
        }

        let expected_len = bmap_bytes(bit_width);
        if bitmap.len() != expected_len {
            return Err(anyhow!(
                "expected bitfield of length {}, found bitfield with length {}",
                expected_len,
                bitmap.len()
            )
            .into());
        }

        // True when the bitmap marks `slot` as occupied.
        let is_set = |slot: usize| bitmap[slot / 8] & (1 << (slot % 8)) != 0;

        if cids.is_empty() {
            let mut pending = values.into_iter();
            let mut vals = init_sized_vec::<V>(bit_width);
            for (slot, entry) in vals.iter_mut().enumerate() {
                if is_set(slot) {
                    let value = pending.next().ok_or_else(|| {
                        anyhow!("Bitmap contained more set bits than values provided")
                    })?;
                    *entry = Some(value);
                }
            }
            if pending.next().is_some() {
                return Err(anyhow!("Bitmap contained less set bits than values provided").into());
            }
            return Ok(Node::Leaf { vals });
        }

        let mut pending = cids.into_iter();
        let mut links = init_sized_vec::<Link<V>>(bit_width);
        for (slot, entry) in links.iter_mut().enumerate() {
            if is_set(slot) {
                let cid = pending.next().ok_or_else(|| {
                    anyhow!("Bitmap contained more set bits than links provided")
                })?;
                *entry = Some(Link::from(cid));
            }
        }
        if pending.next().is_some() {
            return Err(anyhow!("Bitmap contained less set bits than links provided").into());
        }
        Ok(Node::Link { links })
    }
}
impl<V> Node<V>
where
V: Serialize + DeserializeOwned,
{
/// Creates a leaf node with no value slots at all.
pub(super) fn empty() -> Self {
    let vals = Vec::default();
    Node::Leaf { vals }
}
/// Writes every in-memory (`Link::Dirty`) child of this node to `bs`.
///
/// Each dirty child is flushed recursively first, then stored with a
/// Blake2b-256 CID and replaced by a `Link::Cid` whose cache keeps the
/// in-memory node. Leaves have no children, so for them this does nothing.
/// The node itself is not written by this call.
///
/// # Errors
///
/// Returns any error produced while flushing a child or storing it.
pub(super) fn flush<DB: Blockstore>(&mut self, bs: &DB) -> Result<(), Error> {
    let links = match self {
        Node::Link { links } => links,
        Node::Leaf { .. } => return Ok(()),
    };

    for link in links.iter_mut().flatten() {
        if let Link::Dirty(child) = link {
            child.flush(bs)?;
            let cid = bs.put_cbor(child, Code::Blake2b256)?;
            // Move the node out of the dirty link so it can seed the cache.
            let flushed = std::mem::replace(child, Box::new(Node::empty()));
            *link = Link::Cid {
                cid,
                cache: OnceCell::from(flushed),
            };
        }
    }
    Ok(())
}
/// Returns `true` when this is an internal node whose first slot is
/// occupied and every other slot is empty. Leaves, and internal nodes with
/// no slots, always yield `false`.
pub(super) fn can_collapse(&self) -> bool {
    match self {
        Node::Leaf { .. } => false,
        Node::Link { links } => match links.split_first() {
            Some((head, rest)) => head.is_some() && rest.iter().all(Option::is_none),
            None => false,
        },
    }
}
/// Returns `true` when no slot of this node is occupied (this includes a
/// node that has no slots at all).
pub(super) fn is_empty(&self) -> bool {
    match self {
        Node::Leaf { vals } => !vals.iter().any(Option::is_some),
        Node::Link { links } => !links.iter().any(Option::is_some),
    }
}
pub(super) fn get<DB: Blockstore>(
&self,
bs: &DB,
height: u32,
bit_width: u32,
i: u64,
) -> Result<Option<&V>, Error> {
match self {
Node::Leaf { vals, .. } => Ok(vals.get(i as usize).and_then(|v| v.as_ref())),
Node::Link { links, .. } => {
let sub_i: usize = (i / nodes_for_height(bit_width, height))
.try_into()
.unwrap();
match links.get(sub_i).and_then(|v| v.as_ref()) {
Some(Link::Cid { cid, cache }) => {
let cached_node = cache.get_or_try_init(|| {
bs.get_cbor::<CollapsedNode<V>>(cid)?
.ok_or_else(|| Error::CidNotFound(cid.to_string()))?
.expand(bit_width)
.map(Box::new)
})?;
cached_node.get(
bs,
height - 1,
bit_width,
i % nodes_for_height(bit_width, height),
)
}
Some(Link::Dirty(n)) => n.get(
bs,
height - 1,
bit_width,
i % nodes_for_height(bit_width, height),
),
None => Ok(None),
}
}
}
}
pub(super) fn set<DB: Blockstore>(
&mut self,
bs: &DB,
height: u32,
bit_width: u32,
i: u64,
val: V,
) -> Result<Option<V>, Error> {
if height == 0 {
return Ok(self.set_leaf(i, val));
}
let nfh = nodes_for_height(bit_width, height);
let idx: usize = (i / nfh).try_into().expect("index overflow");
if let Node::Link { links } = self {
links[idx] = match &mut links[idx] {
Some(Link::Cid { cid, cache }) => {
let cache_node = std::mem::take(cache);
let sub_node = if let Some(sn) = cache_node.into_inner() {
sn
} else {
bs.get_cbor::<CollapsedNode<V>>(cid)?
.ok_or_else(|| Error::CidNotFound(cid.to_string()))?
.expand(bit_width)
.map(Box::new)?
};
Some(Link::Dirty(sub_node))
}
None => {
let node = match height {
1 => Node::Leaf {
vals: init_sized_vec(bit_width),
},
_ => Node::Link {
links: init_sized_vec(bit_width),
},
};
Some(Link::Dirty(Box::new(node)))
}
Some(Link::Dirty(node)) => {
return node.set(bs, height - 1, bit_width, i % nfh, val)
}
};
if let Some(Link::Dirty(n)) = &mut links[idx] {
n.set(bs, height - 1, bit_width, i % nfh, val)
} else {
unreachable!("Value is set as cached")
}
} else {
unreachable!("should not be handled");
}
}
/// Writes `val` into slot `i` of this leaf and returns the slot's previous
/// content.
///
/// # Panics
///
/// Panics if this node is not a leaf, if `i` does not fit in `usize`, or if
/// `i` is outside the leaf's slots.
fn set_leaf(&mut self, i: u64, val: V) -> Option<V> {
    let vals = match self {
        Node::Leaf { vals } => vals,
        Node::Link { .. } => panic!("set_leaf should never be called on a shard of links"),
    };
    let idx = usize::try_from(i).unwrap();
    vals.get_mut(idx).unwrap().replace(val)
}
/// Removes the value at index `i` from the subtree rooted at this node and
/// returns it, or `Ok(None)` when nothing was stored there.
///
/// For an internal node the deletion is delegated to the child covering
/// `i` (loading it from `bs` when it is referenced by CID and not cached).
/// If the child becomes empty its link is cleared; a CID-referenced child
/// that was modified but is not empty becomes a `Link::Dirty` link.
///
/// # Errors
///
/// Fails if a child cannot be read from the blockstore
/// (`Error::CidNotFound` when it is missing) or cannot be expanded.
///
/// # Panics
///
/// Panics if an index does not fit in `usize`, or if the child index is out
/// of bounds for an internal node.
pub(super) fn delete<DB: Blockstore>(
    &mut self,
    bs: &DB,
    height: u32,
    bit_width: u32,
    i: u64,
) -> Result<Option<V>, Error> {
    let links = match self {
        Self::Leaf { vals } => {
            let idx = usize::try_from(i).unwrap();
            return Ok(vals.get_mut(idx).and_then(std::mem::take));
        }
        Self::Link { links } => links,
    };

    // Number of indices covered by each child of this node.
    let nfh = nodes_for_height(bit_width, height);
    let sub_i: usize = (i / nfh).try_into().unwrap();

    let (deleted, replacement) = match &mut links[sub_i] {
        None => return Ok(None),
        Some(Link::Dirty(child)) => {
            let deleted = child.delete(bs, height - 1, bit_width, i % nfh)?;
            // Nothing removed, or the child still holds data: link stays as is.
            if deleted.is_none() || !child.is_empty() {
                return Ok(deleted);
            }
            (deleted, None)
        }
        Some(Link::Cid { cid, cache }) => {
            cache.get_or_try_init(|| {
                bs.get_cbor::<CollapsedNode<V>>(cid)?
                    .ok_or_else(|| Error::CidNotFound(cid.to_string()))?
                    .expand(bit_width)
                    .map(Box::new)
            })?;
            let cached = cache.get_mut().expect("filled line above");
            let deleted = cached.delete(bs, height - 1, bit_width, i % nfh)?;
            if deleted.is_none() {
                return Ok(None);
            }
            // The child changed, so take it out of the cache and own it.
            let child = std::mem::replace(cached, Box::new(Node::empty()));
            let replacement = if child.is_empty() {
                None
            } else {
                Some(Link::Dirty(child))
            };
            (deleted, replacement)
        }
    };

    links[sub_i] = replacement;
    Ok(deleted)
}
pub(super) fn for_each_while<S, F>(
&self,
bs: &S,
height: u32,
bit_width: u32,
offset: u64,
f: &mut F,
) -> Result<bool, Error>
where
F: FnMut(u64, &V) -> anyhow::Result<bool>,
S: Blockstore,
{
match self {
Node::Leaf { vals } => {
for (i, v) in (0..).zip(vals.iter()) {
if let Some(v) = v {
let keep_going = f(offset + i, v)?;
if !keep_going {
return Ok(false);
}
}
}
}
Node::Link { links } => {
for (i, l) in (0..).zip(links.iter()) {
if let Some(l) = l {
let offs = offset + (i * nodes_for_height(bit_width, height));
let keep_going = match l {
Link::Dirty(sub) => {
sub.for_each_while(bs, height - 1, bit_width, offs, f)?
}
Link::Cid { cid, cache } => {
let cached_node = cache.get_or_try_init(|| {
bs.get_cbor::<CollapsedNode<V>>(cid)?
.ok_or_else(|| Error::CidNotFound(cid.to_string()))?
.expand(bit_width)
.map(Box::new)
})?;
cached_node.for_each_while(bs, height - 1, bit_width, offs, f)?
}
};
if !keep_going {
return Ok(false);
}
}
}
}
}
Ok(true)
}
pub(super) fn for_each_while_mut<S, F>(
&mut self,
bs: &S,
height: u32,
bit_width: u32,
offset: u64,
f: &mut F,
) -> Result<(bool, bool), Error>
where
F: FnMut(u64, &mut ValueMut<'_, V>) -> anyhow::Result<bool>,
S: Blockstore,
{
let mut did_mutate = false;
match self {
Node::Leaf { vals } => {
for (i, v) in (0..).zip(vals.iter_mut()) {
if let Some(v) = v {
let mut value_mut = ValueMut::new(v);
let keep_going = f(offset + i, &mut value_mut)?;
did_mutate |= value_mut.value_changed();
if !keep_going {
return Ok((false, did_mutate));
}
}
}
}
Node::Link { links } => {
for (i, l) in (0..).zip(links.iter_mut()) {
if let Some(link) = l {
let offs = offset + (i * nodes_for_height(bit_width, height));
let (keep_going, did_mutate_node) = match link {
Link::Dirty(sub) => {
sub.for_each_while_mut(bs, height - 1, bit_width, offs, f)?
}
Link::Cid { cid, cache } => {
cache.get_or_try_init(|| {
bs.get_cbor::<CollapsedNode<V>>(cid)?
.ok_or_else(|| Error::CidNotFound(cid.to_string()))?
.expand(bit_width)
.map(Box::new)
})?;
let node = cache.get_mut().expect("cache filled on line above");
let (keep_going, did_mutate_node) =
node.for_each_while_mut(bs, height - 1, bit_width, offs, f)?;
if did_mutate_node {
*link = Link::Dirty(std::mem::replace(
node,
Box::new(Node::empty()),
));
}
(keep_going, did_mutate_node)
}
};
did_mutate |= did_mutate_node;
if !keep_going {
return Ok((false, did_mutate));
}
}
}
}
}
Ok((true, did_mutate))
}
}
#[cfg(test)]
mod tests {
    use fvm_ipld_encoding::{from_slice, to_vec};

    use super::*;

    /// A leaf with a single empty slot must survive a serialize /
    /// deserialize / expand round trip unchanged.
    #[test]
    fn serialize_node_symmetric() {
        let original = Node::Leaf { vals: vec![None] };
        let encoded = to_vec(&original).unwrap();
        let collapsed = from_slice::<CollapsedNode<u8>>(&encoded).unwrap();
        let decoded = collapsed.expand(0).unwrap();
        assert_eq!(decoded, original);
    }
}