use solana_program::{
msg,
program_error::ProgramError,
program_memory::sol_memcmp,
pubkey::{Pubkey, PUBKEY_BYTES},
};
use super::{try_cast_slice, try_from_bytes, Constraint, ConstraintType, RuleV2, Str32, U64_BYTES};
use crate::{
error::RuleSetError,
types::{Assertable, LibVersion, RuleSet},
};
/// Serialized length of an empty rule list, used by `RuleSetV2::serialize` when sizing
/// its output buffer.
const EMPTY: usize = 0;
/// Version 2 rule set decoded (by `from_bytes`) from a serialized byte buffer; the
/// referenced fields and the decoded rules borrow from that buffer for lifetime `'a`.
pub struct RuleSetV2<'a> {
    /// Two-word header: the low byte of `header[0]` is the library version and
    /// `header[1]` is the number of operations (one rule is decoded per operation).
    header: &'a [u32; 2],
    /// Owner of the rule set.
    pub owner: &'a Pubkey,
    /// Rule set name, stored in a fixed `Str32::SIZE`-byte field.
    pub rule_set_name: &'a Str32,
    /// Operation names; `operations[i]` is the lookup key for `rules[i]`.
    pub operations: &'a [Str32],
    /// Decoded rules, in the same order as `operations`.
    pub rules: Vec<RuleV2<'a>>,
}
impl<'a> RuleSetV2<'a> {
pub fn size(&self) -> u32 {
self.header[1]
}
pub fn from_bytes(bytes: &'a [u8]) -> Result<Self, RuleSetError> {
let header = try_from_bytes::<[u32; 2]>(0, U64_BYTES, bytes)?;
let mut cursor = U64_BYTES;
let owner = try_from_bytes::<Pubkey>(cursor, PUBKEY_BYTES, bytes)?;
cursor += PUBKEY_BYTES;
let rule_set_name = try_from_bytes::<Str32>(cursor, Str32::SIZE, bytes)?;
cursor += Str32::SIZE;
let size = header[1] as usize;
let slice_end = cursor
+ Str32::SIZE
.checked_mul(size)
.ok_or(RuleSetError::NumericalOverflow)?;
if size > 0 && (slice_end + 1) > bytes.len() {
msg!("Invalid slice end: {} > {}", slice_end, bytes.len());
return Err(RuleSetError::RuleSetReadFailed);
}
let operations = try_cast_slice(&bytes[cursor..slice_end])?;
cursor = slice_end;
let mut rules = Vec::with_capacity(size);
for _ in 0..size {
let rule = RuleV2::from_bytes(&bytes[cursor..]).unwrap();
cursor += rule.length();
rules.push(rule);
}
Ok(Self {
header,
owner,
rule_set_name,
operations,
rules,
})
}
pub fn serialize(
owner: Pubkey,
name: &str,
operations: &[String],
rules: &[&[u8]],
) -> Result<Vec<u8>, RuleSetError> {
let length = U64_BYTES
+ PUBKEY_BYTES
+ Str32::SIZE
+ (operations.len() * Str32::SIZE)
+ rules
.iter()
.map(|v| v.len())
.reduce(|accum, item| accum + item)
.unwrap_or(EMPTY);
let mut data = Vec::with_capacity(length);
data.extend([LibVersion::V2 as u8, 0, 0, 0]);
data.extend(u32::to_le_bytes(operations.len() as u32));
data.extend(owner.as_ref());
let mut field_bytes = [0u8; Str32::SIZE];
field_bytes[..name.len()].copy_from_slice(name.as_bytes());
data.extend(field_bytes);
if (1..operations.len()).any(|i| operations[i..].contains(&operations[i - 1])) {
return Err(RuleSetError::DuplicatedOperationName);
}
operations.iter().for_each(|x| {
let mut field_bytes = [0u8; Str32::SIZE];
field_bytes[..x.len()].copy_from_slice(x.as_bytes());
data.extend(field_bytes);
});
rules.iter().for_each(|x| data.extend(x.iter()));
Ok(data)
}
pub fn get(&self, operation: String) -> Option<&RuleV2<'a>> {
let mut bytes = [0u8; Str32::SIZE];
bytes[..operation.len()].copy_from_slice(operation.as_bytes());
for (i, operation) in self.operations.iter().enumerate() {
if sol_memcmp(&operation.value, &bytes, bytes.len()) == 0 {
return Some(&self.rules[i]);
}
}
None
}
}
impl<'a> RuleSet<'a> for RuleSetV2<'a> {
    /// Returns the rule set name as an owned `String`.
    fn name(&self) -> String {
        self.rule_set_name.to_string()
    }

    /// Returns the rule set owner.
    fn owner(&self) -> &Pubkey {
        self.owner
    }

    /// Returns the library version tag, kept in the least significant byte of the first
    /// header word.
    fn lib_version(&self) -> u8 {
        self.header[0].to_le_bytes()[0]
    }

    /// Resolves the rule to evaluate for `operation`.
    ///
    /// A `Namespace` rule does not evaluate anything itself. It defers to the rule stored
    /// under the operation's prefix, which is the text before the first ':'. Fails with
    /// `OperationNotFound` when the operation is unknown, or when it resolves to a namespace
    /// but contains no ':' separator to fall back on.
    fn get_rule(&self, operation: String) -> Result<&dyn Assertable<'a>, ProgramError> {
        let rule = match self.get(operation.clone()) {
            Some(found) => found,
            None => return Err(RuleSetError::OperationNotFound.into()),
        };

        if !matches!(rule.constraint_type(), ConstraintType::Namespace) {
            return Ok(rule);
        }

        // Each recursion strips everything from the first ':' on, so it always terminates.
        match operation.split_once(':') {
            Some((prefix, _)) => self.get_rule(prefix.to_owned()),
            None => Err(RuleSetError::OperationNotFound.into()),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        error::RuleSetError,
        state::v2::{Amount, Operator, ProgramOwnedList, RuleSetV2},
        types::{LibVersion, RuleSet},
    };
    use solana_program::pubkey::Pubkey;

    #[test]
    fn test_create_amount() {
        // Two serialized rules: an amount check and a program-owned-list check.
        let amount_rule = Amount::serialize(String::from("Destination"), Operator::Eq, 1).unwrap();
        let allowed_programs = [Pubkey::default(); 2];
        let owned_rule =
            ProgramOwnedList::serialize(String::from("Destination"), &allowed_programs).unwrap();

        let operation_names = ["deletage_transfer".to_string(), "transfer".to_string()];
        let encoded = RuleSetV2::serialize(
            Pubkey::default(),
            "Royalties",
            &operation_names,
            &[&amount_rule, &owned_rule],
        )
        .unwrap();

        // Round trip: decoding must recover both operations, both rules and the version tag.
        let decoded = RuleSetV2::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.operations.len(), 2);
        assert_eq!(decoded.rules.len(), 2);
        assert_eq!(decoded.lib_version(), LibVersion::V2 as u8);
    }

    #[test]
    fn test_duplicated_operation_name() {
        let amount_rule = Amount::serialize(String::from("Destination"), Operator::Eq, 1).unwrap();
        let allowed_programs = [Pubkey::default(); 2];
        let owned_rule =
            ProgramOwnedList::serialize(String::from("Destination"), &allowed_programs).unwrap();

        // The same operation name twice must be rejected at serialization time.
        let operation_names = ["transfer".to_string(), "transfer".to_string()];
        let failure = RuleSetV2::serialize(
            Pubkey::default(),
            "Royalties",
            &operation_names,
            &[&amount_rule, &owned_rule],
        )
        .unwrap_err();

        assert_eq!(failure, RuleSetError::DuplicatedOperationName);
    }
}