use super::ValueMut;
use crate::{
init_sized_vec,
node::{CollapsedNode, Link},
nodes_for_height, Error, Node, Root, DEFAULT_BIT_WIDTH, MAX_HEIGHT, MAX_INDEX,
};
use cid::{Cid, Code::Blake2b256};
use encoding::{de::DeserializeOwned, ser::Serialize};
use ipld_blockstore::BlockStore;
use itertools::sorted;
use std::error::Error as StdError;
/// Array Mapped Trie: a sparse array indexed by `usize` whose nodes are loaded
/// from, and flushed to, a [`BlockStore`].
///
/// The struct borrows its block store for `'db`, so the AMT cannot outlive it.
#[derive(Debug)]
pub struct Amt<'db, V, BS> {
    /// Root of the trie: carries the height, element count, bit width and the
    /// top-level node.
    root: Root<V>,
    /// Store that nodes are read from (on load / cache miss) and written to on flush.
    block_store: &'db BS,
}
/// Two AMTs are equal when their roots compare equal; which block store each
/// one is backed by does not take part in the comparison.
impl<'db, V, BS> PartialEq for Amt<'db, V, BS>
where
    V: PartialEq,
    BS: BlockStore,
{
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.root, &other.root);
        PartialEq::eq(lhs, rhs)
    }
}
impl<'db, V, BS> Amt<'db, V, BS>
where
    V: DeserializeOwned + Serialize,
    BS: BlockStore,
{
    /// Creates an empty AMT backed by `block_store`, using [`DEFAULT_BIT_WIDTH`].
    pub fn new(block_store: &'db BS) -> Self {
        Self::new_with_bit_width(block_store, DEFAULT_BIT_WIDTH)
    }

    /// Creates an empty AMT backed by `block_store` with the given `bit_width`.
    ///
    /// NOTE(review): `bit_width` is not validated here; presumably callers keep it
    /// within the range the node helpers expect — confirm.
    pub fn new_with_bit_width(block_store: &'db BS, bit_width: usize) -> Self {
        Self {
            root: Root::new(bit_width),
            block_store,
        }
    }

    /// Bit width stored in the root; forwarded to every node operation.
    fn bit_width(&self) -> usize {
        self.root.bit_width
    }

    /// Loads an AMT whose root block is stored in `block_store` under `cid`.
    ///
    /// # Errors
    /// - [`Error::CidNotFound`] if the store has no block for `cid`.
    /// - [`Error::MaxHeight`] if the decoded root is taller than [`MAX_HEIGHT`].
    /// - Any error propagated from the block store's `get`.
    pub fn load(cid: &Cid, block_store: &'db BS) -> Result<Self, Error> {
        let root: Root<V> = block_store
            .get(cid)?
            .ok_or_else(|| Error::CidNotFound(cid.to_string()))?;

        // Reject roots that claim an impossible height before using them.
        if root.height > MAX_HEIGHT {
            return Err(Error::MaxHeight(root.height, MAX_HEIGHT));
        }

        Ok(Self { root, block_store })
    }

    /// Current height of the trie (0 while the root is a leaf).
    pub fn height(&self) -> usize {
        self.root.height
    }

    /// Number of values currently stored.
    pub fn count(&self) -> usize {
        self.root.count
    }

    /// Builds an AMT holding `vals` at indices `0, 1, 2, …`, flushes it to
    /// `block_store` and returns the CID of the root.
    ///
    /// # Errors
    /// Propagates errors from [`Amt::batch_set`] and [`Amt::flush`].
    pub fn new_from_iter(
        block_store: &'db BS,
        vals: impl IntoIterator<Item = V>,
    ) -> Result<Cid, Error> {
        let mut t = Self::new(block_store);
        t.batch_set(vals)?;
        t.flush()
    }

    /// Returns a reference to the value at index `i`, or `None` if it is unset.
    ///
    /// # Errors
    /// - [`Error::OutOfRange`] if `i` exceeds [`MAX_INDEX`].
    /// - Any error from loading nodes out of the block store.
    pub fn get(&self, i: usize) -> Result<Option<&V>, Error> {
        if i > MAX_INDEX {
            return Err(Error::OutOfRange(i));
        }

        // Indices beyond the capacity of the current height cannot be present.
        if i >= nodes_for_height(self.bit_width(), self.height() + 1) {
            return Ok(None);
        }

        self.root
            .node
            .get(self.block_store, self.height(), self.bit_width(), i)
    }

    /// Stores `val` at index `i`, replacing any existing value.
    ///
    /// The trie grows in height as needed so that `i` fits. The element count
    /// is incremented only when the index was previously unset.
    ///
    /// # Errors
    /// - [`Error::OutOfRange`] if `i` exceeds [`MAX_INDEX`].
    /// - Any error from loading nodes out of the block store.
    pub fn set(&mut self, i: usize, val: V) -> Result<(), Error> {
        if i > MAX_INDEX {
            return Err(Error::OutOfRange(i));
        }

        // Grow the tree until index `i` fits under the root.
        while i >= nodes_for_height(self.bit_width(), self.height() + 1) {
            if !self.root.node.is_empty() {
                // Push the current root down to become the first child of a
                // new link node.
                let mut new_links: Vec<Option<Link<V>>> = init_sized_vec(self.bit_width());
                let node = std::mem::replace(&mut self.root.node, Node::empty());
                new_links[0] = Some(Link::Dirty(Box::new(node)));
                self.root.node = Node::Link { links: new_links };
            } else {
                // Nothing to preserve; just replace the root with an empty link node.
                self.root.node = Node::Link {
                    links: init_sized_vec(self.bit_width()),
                };
            }
            self.root.height += 1;
        }

        // `Node::set` returns the previous value; `None` means a new entry.
        if self
            .root
            .node
            .set(self.block_store, self.height(), self.bit_width(), i, val)?
            .is_none()
        {
            self.root.count += 1;
        }

        Ok(())
    }

    /// Sets `vals` at consecutive indices starting from 0.
    ///
    /// # Errors
    /// Stops at and returns the first error from [`Amt::set`]; values set before
    /// that point remain set.
    pub fn batch_set(&mut self, vals: impl IntoIterator<Item = V>) -> Result<(), Error> {
        for (i, val) in vals.into_iter().enumerate() {
            self.set(i, val)?;
        }
        Ok(())
    }

    /// Removes and returns the value at index `i`, or `None` if it was unset.
    ///
    /// After a successful removal the trie is shrunk. An empty root is reset to
    /// a height-0 leaf. Otherwise, while the root can collapse, the root is
    /// replaced by its first child.
    ///
    /// # Errors
    /// - [`Error::OutOfRange`] if `i` exceeds [`MAX_INDEX`].
    /// - [`Error::CidNotFound`] if a child needed while collapsing is missing
    ///   from the store.
    /// - Any error from loading or expanding nodes.
    pub fn delete(&mut self, i: usize) -> Result<Option<V>, Error> {
        if i > MAX_INDEX {
            return Err(Error::OutOfRange(i));
        }

        // Indices beyond the capacity of the current height cannot be present.
        if i >= nodes_for_height(self.bit_width(), self.height() + 1) {
            return Ok(None);
        }

        let deleted = self
            .root
            .node
            .delete(self.block_store, self.height(), self.bit_width(), i)?;

        if deleted.is_none() {
            return Ok(None);
        }

        self.root.count -= 1;

        if self.root.node.is_empty() {
            // Reset an emptied trie to a fresh height-0 leaf.
            self.root.node = Node::Leaf {
                vals: init_sized_vec(self.root.bit_width),
            };
            self.root.height = 0;
        } else {
            // Replace the root by its first child while the root can collapse.
            // Only `self.root.bit_width` and `self.block_store` are read
            // inside the match (not `self.bit_width()`), because `self.root.node`
            // is mutably borrowed there.
            while self.root.node.can_collapse() && self.height() > 0 {
                let sub_node: Node<V> = match &mut self.root.node {
                    Node::Link { links, .. } => match &mut links[0] {
                        Some(Link::Dirty(node)) => {
                            *std::mem::replace(node, Box::new(Node::empty()))
                        }
                        Some(Link::Cid { cid, cache }) => {
                            // Reuse the cached child if it was already loaded;
                            // otherwise fetch it from the store and expand it.
                            let cache_node = std::mem::take(cache);
                            if let Some(sn) = cache_node.into_inner() {
                                *sn
                            } else {
                                self.block_store
                                    .get::<CollapsedNode<V>>(cid)?
                                    .ok_or_else(|| Error::CidNotFound(cid.to_string()))?
                                    .expand(self.root.bit_width)?
                            }
                        }
                        _ => unreachable!("First index checked to be Some in `can_collapse`"),
                    },
                    Node::Leaf { .. } => unreachable!("Non zero height cannot be a leaf node"),
                };

                self.root.node = sub_node;
                self.root.height -= 1;
            }
        }

        Ok(deleted)
    }

    /// Deletes every index yielded by `iter`, processing them in ascending order.
    ///
    /// Returns `true` if at least one value was actually removed.
    ///
    /// When `strict` is set, an index with no value yields an error.
    /// Deletions performed before that index are *not* rolled back.
    ///
    /// # Errors
    /// - [`Error::Other`] in strict mode when an index is not present.
    /// - Any error from [`Amt::delete`].
    pub fn batch_delete(
        &mut self,
        iter: impl IntoIterator<Item = usize>,
        strict: bool,
    ) -> Result<bool, Error> {
        let mut modified = false;

        for i in sorted(iter) {
            // `delete` returns the removed value, so `Some` means the index existed.
            let found = self.delete(i)?.is_some();
            if strict && !found {
                return Err(Error::Other(format!(
                    "no such index {} in Amt for batch delete",
                    i
                )));
            }
            modified |= found;
        }

        Ok(modified)
    }

    /// Flushes all nodes to the block store, then stores the root (hashed with
    /// Blake2b-256) and returns its CID.
    ///
    /// # Errors
    /// Propagates any error from flushing nodes or writing the root block.
    pub fn flush(&mut self) -> Result<Cid, Error> {
        self.root.node.flush(self.block_store)?;
        Ok(self.block_store.put(&self.root, Blake2b256)?)
    }

    /// Calls `f` with each stored index and a reference to its value.
    ///
    /// # Errors
    /// Returns the first error produced by `f` or by the traversal.
    #[inline]
    pub fn for_each<F>(&self, mut f: F) -> Result<(), Box<dyn StdError>>
    where
        F: FnMut(usize, &V) -> Result<(), Box<dyn StdError>>,
    {
        self.for_each_while(|i, x| {
            f(i, x)?;
            Ok(true)
        })
    }

    /// Like [`Amt::for_each`], but `f` returns whether to keep going.
    ///
    /// `Ok(true)` continues the traversal, and `Ok(false)` is meant to stop it
    /// early. The traversal itself is performed by `Node::for_each_while`.
    ///
    /// # Errors
    /// Returns the first error produced by `f` or by the traversal.
    pub fn for_each_while<F>(&self, mut f: F) -> Result<(), Box<dyn StdError>>
    where
        F: FnMut(usize, &V) -> Result<bool, Box<dyn StdError>>,
    {
        self.root
            .node
            .for_each_while(self.block_store, self.height(), self.bit_width(), 0, &mut f)
            .map(|_| ())
    }

    /// Calls `f` with each stored index and a mutable handle to its value.
    ///
    /// # Errors
    /// Returns the first error produced by `f` or by the traversal.
    pub fn for_each_mut<F>(&mut self, mut f: F) -> Result<(), Box<dyn StdError>>
    where
        V: Clone,
        F: FnMut(usize, &mut ValueMut<'_, V>) -> Result<(), Box<dyn StdError>>,
    {
        self.for_each_while_mut(|i, x| {
            f(i, x)?;
            Ok(true)
        })
    }

    /// Mutable variant of [`Amt::for_each_while`].
    ///
    /// With the `go-interop` feature, values that `f` reports as changed are
    /// collected during the traversal and written back through [`Amt::set`]
    /// afterwards. NOTE(review): presumably this mirrors the Go implementation's
    /// write path — confirm.
    ///
    /// # Errors
    /// Returns the first error produced by `f`, the traversal, or (with
    /// `go-interop`) the subsequent `set` calls.
    pub fn for_each_while_mut<F>(&mut self, mut f: F) -> Result<(), Box<dyn StdError>>
    where
        V: Clone,
        F: FnMut(usize, &mut ValueMut<'_, V>) -> Result<bool, Box<dyn StdError>>,
    {
        #[cfg(not(feature = "go-interop"))]
        {
            self.root
                .node
                .for_each_while_mut(self.block_store, self.height(), self.bit_width(), 0, &mut f)
                .map(|_| ())
        }

        #[cfg(feature = "go-interop")]
        {
            let mut mutated = ahash::AHashMap::new();

            self.root.node.for_each_while_mut(
                self.block_store,
                self.height(),
                self.bit_width(),
                0,
                &mut |idx, value| {
                    let keep_going = f(idx, value)?;

                    // Record changed values so they can be re-applied through `set`.
                    if value.value_changed() {
                        value.mark_unchanged();
                        mutated.insert(idx, value.clone());
                    }

                    Ok(keep_going)
                },
            )?;

            for (i, v) in mutated.into_iter() {
                self.set(i, v)?;
            }

            Ok(())
        }
    }
}