mpl_token_auth_rules/state/v2/rule_v2.rs

1use bytemuck::{Pod, Zeroable};
2use solana_program::{account_info::AccountInfo, entrypoint::ProgramResult, pubkey::Pubkey};
3use std::collections::HashMap;
4
5use crate::{
6    error::RuleSetError,
7    payload::Payload,
8    state::{constraint::*, Constraint, ConstraintType, RuleResult, U64_BYTES},
9    types::Assertable,
10};
11
12use super::try_from_bytes;
13
/// Size (in bytes) of the header section that precedes every rule's constraint data.
///
/// Defined as `U64_BYTES`; the header itself is two `u32` fields (see `Header`).
pub const HEADER_SECTION: usize = U64_BYTES;
16
/// Deserializes a boxed constraint from a byte slice, dispatching on its constraint type.
///
/// Usage: `constraint_from_bytes!(constraint_type, slice, TypeA, TypeB, ...)`.
///
/// Each listed identifier must name both a `ConstraintType` variant and a type, in scope at
/// the call site, with a fallible `from_bytes` associated function. The expansion evaluates to
/// a `Box<dyn Constraint>`.
///
/// The expansion uses `?` and `return`, so it must be invoked inside a function returning
/// `Result<_, RuleSetError>`. A constraint type that is not listed makes that function return
/// `RuleSetError::InvalidConstraintType`.
///
/// The error and trait paths are `$crate`-qualified, like `ConstraintType` already was, so
/// the expansion does not depend on the call site's imports.
macro_rules! constraint_from_bytes {
    ( $constraint_type:ident, $slice:expr, $( $available:ident ),+ $(,)? ) => {
        match $constraint_type {
            $(
                $crate::state::ConstraintType::$available => {
                    Box::new($available::from_bytes($slice)?)
                        as Box<dyn $crate::state::Constraint>
                }
            )+
            _ => return Err($crate::error::RuleSetError::InvalidConstraintType),
        }
    };
}
30
/// A deserialized version-2 rule.
///
/// A rule is a `Header` (constraint type and data length) combined with the constraint it
/// describes. Both borrow from the byte slice the rule was read from (lifetime `'a`).
pub struct RuleV2<'a> {
    /// Header of the rule, borrowed from the serialized bytes.
    pub header: &'a Header,
    /// Constraint represented by the rule, decoded from the data section after the header.
    pub constraint: Box<dyn Constraint<'a> + 'a>,
}
40
41impl<'a> RuleV2<'a> {
42    /// Deserialize a constraint from a byte array.
43    pub fn from_bytes(bytes: &'a [u8]) -> Result<Self, RuleSetError> {
44        let (header, data) = bytes.split_at(HEADER_SECTION);
45        let header = try_from_bytes::<Header>(0, HEADER_SECTION, header)?;
46
47        let constraint_type = header.constraint_type();
48        let length = header.length();
49
50        let constraint = constraint_from_bytes!(
51            constraint_type,
52            &data[..length],
53            AdditionalSigner,
54            All,
55            Amount,
56            Any,
57            Frequency,
58            IsWallet,
59            Namespace,
60            Not,
61            Pass,
62            PDAMatch,
63            ProgramOwnedList,
64            ProgramOwnedTree,
65            ProgramOwned,
66            PubkeyListMatch,
67            PubkeyMatch,
68            PubkeyTreeMatch
69        );
70
71        Ok(Self { header, constraint })
72    }
73
74    /// Length (in bytes) of the serialized rule.
75    pub fn length(&self) -> usize {
76        HEADER_SECTION + self.header.length()
77    }
78}
79
80impl<'a> Assertable<'a> for RuleV2<'a> {
81    fn validate(
82        &self,
83        accounts: &HashMap<Pubkey, &AccountInfo>,
84        payload: &Payload,
85        update_rule_state: bool,
86        rule_set_state_pda: &Option<&AccountInfo>,
87        rule_authority: &Option<&AccountInfo>,
88    ) -> ProgramResult {
89        let result = self.constraint.validate(
90            accounts,
91            payload,
92            update_rule_state,
93            rule_set_state_pda,
94            rule_authority,
95        );
96
97        match result {
98            RuleResult::Success(_) => Ok(()),
99            RuleResult::Failure(err) => Err(err),
100            RuleResult::Error(err) => Err(err),
101        }
102    }
103}
104
105impl<'a> Constraint<'a> for RuleV2<'a> {
106    fn constraint_type(&self) -> ConstraintType {
107        self.constraint.constraint_type()
108    }
109
110    fn validate(
111        &self,
112        accounts: &std::collections::HashMap<
113            solana_program::pubkey::Pubkey,
114            &solana_program::account_info::AccountInfo,
115        >,
116        payload: &crate::payload::Payload,
117        update_rule_state: bool,
118        rule_set_state_pda: &Option<&solana_program::account_info::AccountInfo>,
119        rule_authority: &Option<&solana_program::account_info::AccountInfo>,
120    ) -> RuleResult {
121        self.constraint.validate(
122            accounts,
123            payload,
124            update_rule_state,
125            rule_set_state_pda,
126            rule_authority,
127        )
128    }
129}
130
131/// Header for the rule.
132#[repr(C)]
133#[derive(Clone, Copy, Pod, Zeroable)]
134pub struct Header {
135    /// Header data.
136    pub data: [u32; 2],
137}
138
139impl Header {
140    /// Returns the type of the constraint.
141    pub fn constraint_type(&self) -> ConstraintType {
142        ConstraintType::try_from(self.data[0]).unwrap()
143    }
144
145    /// Returns the length of the data section.
146    pub fn length(&self) -> usize {
147        self.data[1] as usize
148    }
149
150    /// Serialize the header.
151    pub fn serialize(constraint_type: ConstraintType, length: u32, data: &mut Vec<u8>) {
152        // constraint type
153        data.extend(u32::to_le_bytes(constraint_type as u32));
154        // length
155        data.extend(u32::to_le_bytes(length));
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::RuleV2;
    use crate::state::v2::{Amount, Any, Operator, ProgramOwnedList, Str32};
    use solana_program::pubkey::Pubkey;

    /// Field name shared by every constraint serialized in these tests.
    const FIELD: &str = "Destination";

    #[test]
    fn test_create_amount() {
        let serialized = Amount::serialize(FIELD.to_string(), Operator::Eq, 1).unwrap();

        // Read the rule back out of its raw bytes.
        let rule = RuleV2::from_bytes(&serialized).unwrap();
        assert_eq!(rule.header.length(), 48);
    }

    #[test]
    fn test_create_program_owned_list() {
        let programs = [Pubkey::default(); 2];
        let serialized = ProgramOwnedList::serialize(FIELD.to_string(), &programs).unwrap();

        // Read the rule back out of its raw bytes.
        let rule = RuleV2::from_bytes(&serialized).unwrap();
        assert_eq!(rule.header.length(), 96);
    }

    #[test]
    fn test_create_large_program_owned_list() {
        const SIZE: usize = 1000;

        let programs = vec![Pubkey::default(); SIZE];
        let serialized =
            ProgramOwnedList::serialize(FIELD.to_string(), programs.as_slice()).unwrap();

        // Read the rule back out of its raw bytes.
        let rule = RuleV2::from_bytes(&serialized).unwrap();
        assert_eq!(rule.header.length(), Str32::SIZE + (SIZE * 32));
    }

    #[test]
    fn test_create_any() {
        let first = ProgramOwnedList::serialize(FIELD.to_string(), &[Pubkey::default()]).unwrap();
        let second =
            ProgramOwnedList::serialize(FIELD.to_string(), &[Pubkey::default(); 3]).unwrap();
        let any = Any::serialize(&[&first, &second]).unwrap();

        // Read the rule back out of its raw bytes.
        let rule = RuleV2::from_bytes(&any).unwrap();

        // The `Any` data section holds both nested rules plus 8 further bytes.
        let expected = 8 + first.len() + second.len();
        assert_eq!(rule.header.length(), expected);
    }
}