use crate::error::AccountCompressionError;
use crate::events::ChangeLogEvent;
use anchor_lang::prelude::*;
use bytemuck::{cast_slice, cast_slice_mut};
use spl_concurrent_merkle_tree::node::{empty_node_cached, Node, EMPTY};
use std::mem::size_of;
/// Confirms that `canopy_bytes` holds a whole number of [`Node`]s, so that it can be
/// reinterpreted as a `&[Node]`.
///
/// # Errors
/// Logs the byte length alongside the node size and returns
/// `AccountCompressionError::CanopyLengthMismatch` when the length is not an exact
/// multiple of `size_of::<Node>()`.
#[inline(always)]
pub fn check_canopy_bytes(canopy_bytes: &[u8]) -> Result<()> {
    let node_size = size_of::<Node>();
    let byte_len = canopy_bytes.len();
    if byte_len % node_size == 0 {
        return Ok(());
    }
    msg!(
        "Canopy byte length {} is not a multiple of {}",
        byte_len,
        node_size
    );
    err!(AccountCompressionError::CanopyLengthMismatch)
}
/// Returns how many tree levels below the root are stored in `canopy`.
///
/// A canopy of `c` levels holds every node of those levels, i.e. `2^(c+1) - 2` nodes
/// (a complete binary tree with its root removed). Adding 2 to the length therefore
/// yields `2^(c+1)`, and the level count is that power's exponent minus one.
///
/// # Errors
/// Returns `AccountCompressionError::CanopyLengthMismatch`, after logging the reason, when
/// `canopy.len() + 2` is not a power of two, or when the canopy would hold more than the
/// `2^(max_depth+1) - 2` non-root nodes of a tree of depth `max_depth`.
#[inline(always)]
fn get_cached_path_length(canopy: &[Node], max_depth: u32) -> Result<u32> {
    let padded_len = (canopy.len() + 2) as u32;

    // Classic bit trick: a power of two has exactly one bit set.
    let is_power_of_two = (padded_len & (padded_len - 1)) == 0;
    if !is_power_of_two {
        msg!(
            "Canopy length {} is not 2 less than a power of 2",
            canopy.len()
        );
        return err!(AccountCompressionError::CanopyLengthMismatch);
    }

    // The canopy may not contain more nodes than the full tree has below its root.
    if padded_len > (1 << (max_depth + 1)) {
        msg!(
            "Canopy size is too large. Size: {}. Max size: {}",
            padded_len - 2,
            (1 << (max_depth + 1)) - 2
        );
        return err!(AccountCompressionError::CanopyLengthMismatch);
    }

    Ok(padded_len.trailing_zeros() - 1)
}
/// Copies the upper part of the most recent change-log path into the canopy.
///
/// `canopy_bytes` is reinterpreted as a slice of [`Node`]s in which the node with tree
/// index `i` is stored at canopy slot `i - 2`. The root is not stored; the tree uses
/// heap-style indexing with the root at index 1.
///
/// When `change_log` is `None`, the canopy is only validated and never modified.
///
/// # Errors
/// Returns `AccountCompressionError::CanopyLengthMismatch` in two cases:
/// - `canopy_bytes` is not a whole number of nodes.
/// - The node count is not `2^k - 2` for a `k` that fits a tree of depth `max_depth`.
pub fn update_canopy(
    canopy_bytes: &mut [u8],
    max_depth: u32,
    change_log: Option<&ChangeLogEvent>,
) -> Result<()> {
    check_canopy_bytes(canopy_bytes)?;
    let canopy = cast_slice_mut::<u8, Node>(canopy_bytes);
    let path_len = get_cached_path_length(canopy, max_depth)?;
    if let Some(cl_event) = change_log {
        // `cl_event` is already a reference, so match on it directly instead of
        // re-borrowing through `&*` (clippy::borrow_deref_ref).
        match cl_event {
            ChangeLogEvent::V1(cl) => {
                // NOTE(review): presumably `cl.path` is ordered leaf-to-root, so reversing
                // it visits the root first. `skip(1)` drops the root, which the canopy does
                // not store, and `take(path_len)` keeps only the cached levels. Confirm
                // against the ChangeLogEvent producer.
                for path_node in cl.path.iter().rev().skip(1).take(path_len as usize) {
                    // Tree index `i` maps to canopy slot `i - 2`.
                    canopy[(path_node.index - 2) as usize] = path_node.node;
                }
            }
        }
    }
    Ok(())
}
/// Appends to `proof` the sibling nodes that are cached in the canopy for leaf `index`.
///
/// `canopy_bytes` is reinterpreted as a slice of [`Node`]s in which tree index `i` is
/// stored at canopy slot `i - 2`. The tree uses heap-style indexing: the root is index 1
/// and the leaves are `2^max_depth + index`.
///
/// The walk starts at the node on the leaf's path at the lowest canopy level. For each
/// node on the way up to the root, its sibling is read from the canopy. A sibling stored
/// as `EMPTY` is replaced by the empty-subtree hash for its level.
///
/// Inferred nodes are appended only until `proof` reaches `max_depth` entries. If `proof`
/// already overlaps the canopy levels, the lowest inferred nodes are dropped. If `proof`
/// plus the canopy is still shorter than `max_depth`, every inferred node is appended and
/// `proof` stays shorter than the tree depth.
///
/// # Errors
/// Returns `AccountCompressionError::CanopyLengthMismatch` if the canopy byte length is
/// not a whole number of nodes, or if the node count is not a valid canopy size for
/// `max_depth`.
pub fn fill_in_proof_from_canopy(
    canopy_bytes: &[u8],
    max_depth: u32,
    index: u32,
    proof: &mut Vec<Node>,
) -> Result<()> {
    check_canopy_bytes(canopy_bytes)?;
    let canopy = cast_slice::<u8, Node>(canopy_bytes);
    let path_len = get_cached_path_length(canopy, max_depth)?;

    // Allocate the cache only once the canopy has been validated. The size 30 must match
    // the const generic passed to `empty_node_cached` below.
    // NOTE(review): presumably 30 is the maximum supported tree depth; confirm.
    let mut empty_node_cache = Box::new([EMPTY; 30]);

    // Tree index of the node on this leaf's path that sits on the lowest canopy level.
    let mut node_idx = ((1 << max_depth) + index) >> (max_depth - path_len);

    // For a leaf index below 2^max_depth, the loop climbs exactly `path_len` levels.
    // Reserving that capacity up front avoids regrowing the vector.
    let mut inferred_nodes = Vec::with_capacity(path_len as usize);
    while node_idx > 1 {
        // Tree index `i` lives at canopy slot `i - 2`. Flipping the low bit selects the
        // sibling slot (even -> +1, odd -> -1).
        let cached_idx = (node_idx as usize - 2) ^ 1;
        if canopy[cached_idx] == EMPTY {
            // `31 - leading_zeros` is the node's depth below the root, so this is the
            // sibling's height above the leaves.
            let level = max_depth - (31 - node_idx.leading_zeros());
            inferred_nodes.push(empty_node_cached::<30>(level, &mut empty_node_cache));
        } else {
            inferred_nodes.push(canopy[cached_idx]);
        }
        node_idx >>= 1;
    }

    // Skip the lowest inferred nodes that `proof` already covers, so the combined length
    // does not exceed `max_depth`.
    let overlap = (proof.len() + inferred_nodes.len()).saturating_sub(max_depth as usize);
    proof.extend(inferred_nodes.iter().skip(overlap));
    Ok(())
}