// kora_lib/usage_limit/rules/instruction.rs

1use std::collections::HashMap;
2
3use solana_sdk::pubkey::Pubkey;
4use solana_system_interface::program::ID as SYSTEM_PROGRAM_ID;
5use spl_associated_token_account_interface::program::ID as ATA_PROGRAM_ID;
6
7use crate::transaction::{ParsedSystemInstructionData, ParsedSystemInstructionType};
8
9use super::super::limiter::LimiterContext;
10
/// Namespace prefix shared by every instruction-limit storage key.
const IX_KEY_PREFIX: &str = "kora:ix";

// Lower-case names of the System Program instructions this module recognizes.
const SYSTEM_CREATE_ACCOUNT: &str = "createaccount";
const SYSTEM_CREATE_ACCOUNT_WITH_SEED: &str = "createaccountwithseed";

// Lower-case names of the Associated Token Account Program instructions this module recognizes.
const ATA_CREATE: &str = "create";
const ATA_CREATE_IDEMPOTENT: &str = "createidempotent";
20
21/// Rule that limits specific instruction types per wallet
22///
23/// Counts matching instructions in each transaction and enforces limits.
24/// Supports both lifetime limits (never resets) and time-windowed limits (resets periodically).
25///
26/// Currently supported instruction types:
27/// - System: CreateAccount / CreateAccountWithSeed
28/// - ATA: CreateIdempotent / Create
29#[derive(Debug)]
30pub struct InstructionRule {
31    program: Pubkey,
32    instruction: String,
33    max: u64,
34    window_seconds: Option<u64>,
35}
36
37impl InstructionRule {
38    pub fn new(
39        program: Pubkey,
40        instruction: String,
41        max: u64,
42        window_seconds: Option<u64>,
43    ) -> Self {
44        let lowered = instruction.to_lowercase();
45        Self { program, instruction: lowered, max, window_seconds }
46    }
47
48    /// Create a lifetime instruction limit (never resets)
49    pub fn lifetime(program: Pubkey, instruction: String, max: u64) -> Self {
50        Self::new(program, instruction, max, None)
51    }
52
53    /// Create a time-windowed instruction limit
54    pub fn windowed(program: Pubkey, instruction: String, max: u64, window_seconds: u64) -> Self {
55        Self::new(program, instruction, max, Some(window_seconds))
56    }
57
58    /// Count matching instructions for one or more rules in a single pass
59    /// Only counts instructions where Kora is the payer (subsidized operations)
60    pub fn count_all_rules(rules: &[&InstructionRule], ctx: &mut LimiterContext<'_>) -> Vec<u64> {
61        if rules.is_empty() {
62            return vec![];
63        }
64
65        // Group rules by program ID
66        let mut system_rules: Vec<(usize, &InstructionRule)> = vec![];
67        let mut ata_rules: Vec<(usize, &InstructionRule)> = vec![];
68        let mut other_rules: Vec<(usize, &InstructionRule)> = vec![];
69
70        for (idx, rule) in rules.iter().enumerate() {
71            if rule.program == SYSTEM_PROGRAM_ID {
72                system_rules.push((idx, rule));
73            } else if rule.program == ATA_PROGRAM_ID {
74                ata_rules.push((idx, rule));
75            } else {
76                other_rules.push((idx, rule));
77            }
78        }
79
80        let mut counts = vec![0u64; rules.len()];
81
82        // Count System instructions
83        if !system_rules.is_empty() {
84            match ctx.transaction.get_or_parse_system_instructions() {
85                Ok(parsed) => {
86                    let kora_signer = ctx.kora_signer;
87                    Self::count_batch_system_instructions(
88                        &system_rules,
89                        parsed,
90                        kora_signer,
91                        &mut counts,
92                    );
93                }
94                Err(_) => {
95                    Self::count_batch_manual(&system_rules, ctx, &mut counts);
96                }
97            }
98        }
99
100        // Count ATA instructions (manual parsing)
101        if !ata_rules.is_empty() {
102            Self::count_batch_manual(&ata_rules, ctx, &mut counts);
103        }
104
105        // Count other program instructions (manual parsing)
106        if !other_rules.is_empty() {
107            Self::count_batch_manual(&other_rules, ctx, &mut counts);
108        }
109
110        counts
111    }
112
113    /// Batch count system instructions for multiple rules
114    /// Only counts instructions where Kora is the payer (subsidized operations)
115    fn count_batch_system_instructions(
116        rules: &[(usize, &InstructionRule)],
117        parsed: &HashMap<ParsedSystemInstructionType, Vec<ParsedSystemInstructionData>>,
118        kora_signer: Option<Pubkey>,
119        counts: &mut [u64],
120    ) {
121        for (idx, rule) in rules {
122            let matching_type = match rule.instruction.as_str() {
123                SYSTEM_CREATE_ACCOUNT | SYSTEM_CREATE_ACCOUNT_WITH_SEED => {
124                    Some(ParsedSystemInstructionType::SystemCreateAccount)
125                }
126                _ => None,
127            };
128
129            if let Some(ix_type) = matching_type {
130                if let Some(instructions) = parsed.get(&ix_type) {
131                    let count = instructions
132                        .iter()
133                        .filter(|ix_data| {
134                            match ix_data {
135                                ParsedSystemInstructionData::SystemCreateAccount {
136                                    payer, ..
137                                } => {
138                                    // Count instructions where Kora IS the payer
139                                    // This tracks subsidized account creations
140                                    kora_signer == Some(*payer)
141                                }
142                                _ => false,
143                            }
144                        })
145                        .count() as u64;
146                    counts[*idx] = count;
147                } else {
148                    counts[*idx] = 0;
149                }
150            }
151        }
152    }
153
154    /// Batch count using manual parsing
155    /// Only counts instructions where Kora is the payer (subsidized operations)
156    fn count_batch_manual(
157        rules: &[(usize, &InstructionRule)],
158        ctx: &LimiterContext<'_>,
159        counts: &mut [u64],
160    ) {
161        let kora_signer = ctx.kora_signer;
162
163        for instruction in ctx.transaction.all_instructions.iter() {
164            for (idx, rule) in rules {
165                if instruction.program_id != rule.program {
166                    continue;
167                }
168
169                if let Some(instr_type) =
170                    InstructionIdentifier::identify(&instruction.program_id, &instruction.data)
171                {
172                    if instr_type == rule.instruction {
173                        // For ATA instructions, check if Kora is the payer (first account)
174                        if rule.program == ATA_PROGRAM_ID {
175                            match (instruction.accounts.first(), kora_signer) {
176                                (Some(payer), Some(kora)) if payer.pubkey == kora => {
177                                    counts[*idx] += 1;
178                                }
179                                _ => {}
180                            }
181                        } else {
182                            // For other programs, count all matching instructions
183                            counts[*idx] += 1;
184                        }
185                    }
186                }
187            }
188        }
189    }
190
191    pub fn storage_key(&self, user_id: &str, timestamp: u64) -> String {
192        let base = format!("{IX_KEY_PREFIX}:{user_id}:{}:{}", self.program, self.instruction);
193        match self.window_seconds {
194            Some(window) if window > 0 => format!("{base}:{}", timestamp / window),
195            _ => base,
196        }
197    }
198
199    /// How many units to increment for this transaction
200    pub fn count_increment(&self, ctx: &mut LimiterContext<'_>) -> u64 {
201        Self::count_all_rules(&[self], ctx).into_iter().next().unwrap_or(0)
202    }
203
204    /// Maximum allowed count within the window (or lifetime)
205    pub fn max(&self) -> u64 {
206        self.max
207    }
208
209    /// Time window in seconds
210    pub fn window_seconds(&self) -> Option<u64> {
211        self.window_seconds
212    }
213
214    pub fn description(&self) -> String {
215        let window = self.window_seconds.map_or("lifetime".to_string(), |w| format!("per {w}s"));
216        format!("{} on {} ({window})", self.instruction, self.program)
217    }
218}
219
220pub struct InstructionIdentifier;
221
222impl InstructionIdentifier {
223    pub fn identify(program_id: &Pubkey, data: &[u8]) -> Option<String> {
224        match *program_id {
225            _ if *program_id == SYSTEM_PROGRAM_ID => Self::system(data),
226            _ if *program_id == ATA_PROGRAM_ID => Self::ata(data),
227            _ => None,
228        }
229    }
230
231    fn system(data: &[u8]) -> Option<String> {
232        let discriminator = u32::from_le_bytes(data.get(..4)?.try_into().ok()?);
233        match discriminator {
234            0 => Some(SYSTEM_CREATE_ACCOUNT.to_string()),
235            3 => Some(SYSTEM_CREATE_ACCOUNT_WITH_SEED.to_string()),
236            _ => None,
237        }
238    }
239
240    fn ata(data: &[u8]) -> Option<String> {
241        match data.first().copied() {
242            None | Some(0) => Some(ATA_CREATE.to_string()),
243            Some(1) => Some(ATA_CREATE_IDEMPOTENT.to_string()),
244            _ => None,
245        }
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::transaction_mock::create_mock_resolved_transaction;

    #[test]
    fn test_instruction_rule_lifetime_key() {
        let rule =
            InstructionRule::lifetime(SYSTEM_PROGRAM_ID, SYSTEM_CREATE_ACCOUNT.to_string(), 10);
        let user_id = "test-user-123";

        let key = rule.storage_key(user_id, 1000000);
        assert_eq!(key, format!("kora:ix:{}:{}:createaccount", user_id, SYSTEM_PROGRAM_ID));
    }

    #[test]
    fn test_instruction_rule_windowed_key() {
        let rule = InstructionRule::windowed(
            SYSTEM_PROGRAM_ID,
            SYSTEM_CREATE_ACCOUNT.to_string(),
            10,
            3600,
        );
        let user_id = "test-user-456";

        // 3600..7199 share bucket 1; 7200 starts bucket 2.
        assert!(rule.storage_key(user_id, 3600).ends_with(":1"));
        assert!(rule.storage_key(user_id, 7199).ends_with(":1"));
        assert!(rule.storage_key(user_id, 7200).ends_with(":2"));
    }

    #[test]
    fn test_zero_window_uses_lifetime_key() {
        let zero_window =
            InstructionRule::windowed(SYSTEM_PROGRAM_ID, SYSTEM_CREATE_ACCOUNT.to_string(), 10, 0);
        let lifetime =
            InstructionRule::lifetime(SYSTEM_PROGRAM_ID, SYSTEM_CREATE_ACCOUNT.to_string(), 10);

        assert_eq!(zero_window.storage_key("u", 12345), lifetime.storage_key("u", 12345));
        assert_eq!(zero_window.storage_key("u", 12345), zero_window.storage_key("u", 99999));
    }

    #[test]
    fn test_instruction_rule_count_no_match() {
        let rule =
            InstructionRule::lifetime(SYSTEM_PROGRAM_ID, SYSTEM_CREATE_ACCOUNT.to_string(), 10);
        let mut tx = create_mock_resolved_transaction();
        let mut ctx = LimiterContext {
            transaction: &mut tx,
            user_id: "test-user-789".to_string(),
            kora_signer: None,
            timestamp: 1000000,
        };

        assert_eq!(rule.count_increment(&mut ctx), 0);
    }

    #[test]
    fn test_instruction_rule_description() {
        let lifetime =
            InstructionRule::lifetime(SYSTEM_PROGRAM_ID, SYSTEM_CREATE_ACCOUNT.to_string(), 10);
        assert!(lifetime.description().contains(SYSTEM_CREATE_ACCOUNT));
        assert!(lifetime.description().contains("lifetime"));

        let windowed =
            InstructionRule::windowed(ATA_PROGRAM_ID, ATA_CREATE_IDEMPOTENT.to_string(), 5, 86400);
        assert!(windowed.description().contains(ATA_CREATE_IDEMPOTENT));
        assert!(windowed.description().contains("per 86400s"));
    }

    #[test]
    fn test_instruction_case_insensitive() {
        let rule = InstructionRule::new(SYSTEM_PROGRAM_ID, "CreateAccount".to_string(), 10, None);
        assert_eq!(rule.instruction, SYSTEM_CREATE_ACCOUNT);

        let seeded =
            InstructionRule::new(SYSTEM_PROGRAM_ID, "CreateAccountWithSeed".to_string(), 10, None);
        assert_eq!(seeded.instruction, SYSTEM_CREATE_ACCOUNT_WITH_SEED);
    }

    #[test]
    fn test_identify_system_instructions() {
        assert_eq!(
            InstructionIdentifier::system(&[0, 0, 0, 0]),
            Some(SYSTEM_CREATE_ACCOUNT.to_string())
        );
        assert_eq!(
            InstructionIdentifier::system(&[3, 0, 0, 0]),
            Some(SYSTEM_CREATE_ACCOUNT_WITH_SEED.to_string())
        );
    }

    #[test]
    fn test_identify_system_rejects_short_or_unknown_data() {
        assert_eq!(InstructionIdentifier::system(&[]), None);
        assert_eq!(InstructionIdentifier::system(&[0, 0, 0]), None);
        // Discriminator 2 is not a recognized (create) instruction.
        assert_eq!(InstructionIdentifier::system(&[2, 0, 0, 0]), None);
    }

    #[test]
    fn test_identify_ata_instructions() {
        assert_eq!(InstructionIdentifier::ata(&[]), Some(ATA_CREATE.to_string()));
        assert_eq!(InstructionIdentifier::ata(&[0]), Some(ATA_CREATE.to_string()));
        assert_eq!(InstructionIdentifier::ata(&[1]), Some(ATA_CREATE_IDEMPOTENT.to_string()));
        assert_eq!(InstructionIdentifier::ata(&[2]), None);
    }

    #[test]
    fn test_identify_dispatches_by_program() {
        assert_eq!(
            InstructionIdentifier::identify(&SYSTEM_PROGRAM_ID, &[0, 0, 0, 0]),
            Some(SYSTEM_CREATE_ACCOUNT.to_string())
        );
        assert_eq!(
            InstructionIdentifier::identify(&ATA_PROGRAM_ID, &[1]),
            Some(ATA_CREATE_IDEMPOTENT.to_string())
        );
        assert_eq!(InstructionIdentifier::identify(&Pubkey::new_unique(), &[0, 0, 0, 0]), None);
    }

    #[test]
    fn test_batch_counting_empty_rules() {
        let mut tx = create_mock_resolved_transaction();
        let mut ctx = LimiterContext {
            transaction: &mut tx,
            user_id: "test-user-batch".to_string(),
            kora_signer: None,
            timestamp: 1000000,
        };

        let rules: Vec<&InstructionRule> = vec![];
        assert!(InstructionRule::count_all_rules(&rules, &mut ctx).is_empty());
    }

    #[test]
    fn test_batch_counting_matches_individual() {
        let user_id = "test-user-individual".to_string();

        let rule1 =
            InstructionRule::lifetime(SYSTEM_PROGRAM_ID, SYSTEM_CREATE_ACCOUNT.to_string(), 10);
        let rule2 = InstructionRule::lifetime(ATA_PROGRAM_ID, ATA_CREATE_IDEMPOTENT.to_string(), 5);

        // Count each rule individually, on its own fresh transaction.
        let mut tx1 = create_mock_resolved_transaction();
        let mut ctx1 = LimiterContext {
            transaction: &mut tx1,
            user_id: user_id.clone(),
            kora_signer: None,
            timestamp: 1000000,
        };
        let count1 = rule1.count_increment(&mut ctx1);

        let mut tx2 = create_mock_resolved_transaction();
        let mut ctx2 = LimiterContext {
            transaction: &mut tx2,
            user_id: user_id.clone(),
            kora_signer: None,
            timestamp: 1000000,
        };
        let count2 = rule2.count_increment(&mut ctx2);

        // Count both rules in one batch.
        let mut tx_batch = create_mock_resolved_transaction();
        let mut ctx_batch = LimiterContext {
            transaction: &mut tx_batch,
            user_id,
            kora_signer: None,
            timestamp: 1000000,
        };
        let rules = vec![&rule1, &rule2];
        let batch_counts = InstructionRule::count_all_rules(&rules, &mut ctx_batch);

        assert_eq!(batch_counts, vec![count1, count2]);
    }
}