1use std::{collections::HashSet, sync::Arc, time::SystemTime};
2
3use super::{
4 limiter::{LimiterContext, LimiterResult},
5 rules::{InstructionRule, UsageRule},
6 usage_store::{InMemoryUsageStore, RedisUsageStore},
7 UsageStore,
8};
9use crate::{
10 cache::CacheUtil,
11 config::Config,
12 error::KoraError,
13 sanitize_error,
14 state::get_signer_pool,
15 token::token::TokenType,
16 transaction::{
17 ParsedSPLInstructionData, ParsedSPLInstructionType, VersionedTransactionResolved,
18 },
19};
20use deadpool_redis::Runtime;
21use redis::AsyncCommands;
22use solana_client::nonblocking::rpc_client::RpcClient;
23use solana_sdk::pubkey::Pubkey;
24use tokio::sync::OnceCell;
25
26#[cfg(not(test))]
27use crate::state::get_config;
28
29#[cfg(test)]
30use crate::tests::config_mock::mock_state::get_config;
31
/// Process-wide usage tracker, set exactly once by `UsageTracker::init_usage_limiter`.
///
/// An empty cell means initialization has not run yet. A cell holding `None`
/// means initialization ran but usage limiting is off (disabled in config, or
/// enabled with no rules configured).
static USAGE_LIMITER: OnceCell<Option<UsageTracker>> = OnceCell::const_new();

/// Enforces per-user usage rules on incoming transactions.
///
/// Counters live in a [`UsageStore`] (Redis or in-memory, chosen in
/// `init_usage_limiter`); each rule derives its own storage key from the
/// user id and the request timestamp.
pub struct UsageTracker {
    /// Master switch; checks are skipped when false or when `rules` is empty.
    enabled: bool,
    /// Backend holding the usage counters.
    store: Arc<dyn UsageStore>,
    /// Configured rules, evaluated in order on every check.
    rules: Vec<UsageRule>,
    /// Positions in `rules` of the `UsageRule::Instruction` entries (computed in `new`).
    instruction_rule_indices: Vec<usize>,
    /// Public keys of Kora's signers, used to find which Kora signer a transaction uses.
    kora_signers: HashSet<Pubkey>,
    /// When true, `InternalServerError`s raised while checking/recording usage
    /// are logged and the transaction is allowed instead of rejected.
    fallback_if_unavailable: bool,
}
43
44impl UsageTracker {
45 pub fn new(
46 enabled: bool,
47 store: Arc<dyn UsageStore>,
48 rules: Vec<UsageRule>,
49 kora_signers: HashSet<Pubkey>,
50 fallback_if_unavailable: bool,
51 ) -> Self {
52 let instruction_rule_indices: Vec<usize> =
54 rules
55 .iter()
56 .enumerate()
57 .filter_map(|(idx, rule)| {
58 if matches!(rule, UsageRule::Instruction(_)) {
59 Some(idx)
60 } else {
61 None
62 }
63 })
64 .collect();
65
66 Self {
67 enabled,
68 store,
69 rules,
70 instruction_rule_indices,
71 kora_signers,
72 fallback_if_unavailable,
73 }
74 }
75
76 fn get_usage_limiter() -> Result<Option<&'static UsageTracker>, KoraError> {
77 match USAGE_LIMITER.get() {
78 Some(limiter) => Ok(limiter.as_ref()),
79 None => {
80 Err(KoraError::InternalServerError("Usage limiter not initialized".to_string()))
81 }
82 }
83 }
84
85 fn is_enabled(&self) -> bool {
86 self.enabled && !self.rules.is_empty()
87 }
88
89 fn has_instruction_rules(&self) -> bool {
90 !self.instruction_rule_indices.is_empty()
91 }
92
93 async fn extract_user_from_payment_instruction(
94 &self,
95 transaction: &mut VersionedTransactionResolved,
96 config: &Config,
97 fee_payer: &Pubkey,
98 rpc_client: &RpcClient,
99 ) -> Result<Option<Pubkey>, KoraError> {
100 let payment_destination = config.kora.get_payment_address(fee_payer)?;
101 let parsed_spl_instructions = transaction.get_or_parse_spl_instructions()?;
102
103 for instruction in parsed_spl_instructions
104 .get(&ParsedSPLInstructionType::SplTokenTransfer)
105 .unwrap_or(&vec![])
106 {
107 if let ParsedSPLInstructionData::SplTokenTransfer {
108 destination_address, owner, ..
109 } = instruction
110 {
111 let destination_account =
114 match CacheUtil::get_account(config, rpc_client, destination_address, true)
115 .await
116 {
117 Ok(account) => account,
118 Err(KoraError::AccountNotFound(_)) => continue,
119 Err(e) => return Err(e),
120 };
121
122 let token_program =
123 TokenType::get_token_program_from_owner(&destination_account.owner)?;
124 let token_account =
125 token_program.unpack_token_account(&destination_account.data)?;
126
127 if token_account.owner() == payment_destination {
129 return Ok(Some(*owner));
130 }
131 }
132 }
133
134 Ok(None)
135 }
136
137 fn extract_kora_signer(&self, transaction: &VersionedTransactionResolved) -> Option<Pubkey> {
139 let account_keys = transaction.message.static_account_keys();
140 let num_signers = transaction.message.header().num_required_signatures as usize;
141
142 account_keys
143 .iter()
144 .take(num_signers)
145 .find(|signer| self.kora_signers.contains(signer))
146 .copied()
147 }
148
149 fn current_timestamp() -> u64 {
150 SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).map(|d| d.as_secs()).unwrap_or(0)
151 }
152
153 async fn check_and_record(
156 &self,
157 ctx: &mut LimiterContext<'_>,
158 ) -> Result<LimiterResult, KoraError> {
159 if !self.is_enabled() {
160 return Ok(LimiterResult::Allowed);
161 }
162
163 let instruction_rules: Vec<&InstructionRule> = self
165 .instruction_rule_indices
166 .iter()
167 .filter_map(|&idx| self.rules[idx].as_instruction())
168 .collect();
169
170 let instruction_counts = if !instruction_rules.is_empty() {
172 InstructionRule::count_all_rules(&instruction_rules, ctx)
173 } else {
174 Vec::new()
175 };
176
177 let ix_idx_set: HashSet<usize> = self.instruction_rule_indices.iter().copied().collect();
179
180 let mut pending_increments: Vec<(String, u64, Option<u64>)> = Vec::new();
183 let mut instruction_count_idx = 0;
184
185 for (idx, rule) in self.rules.iter().enumerate() {
186 let increment_count = if ix_idx_set.contains(&idx) {
187 let count = instruction_counts[instruction_count_idx];
189 instruction_count_idx += 1;
190 count
191 } else {
192 1
194 };
195
196 if increment_count == 0 {
197 continue;
198 }
199
200 let key = rule.storage_key(&ctx.user_id, ctx.timestamp);
201
202 let current = self.store.get(&key).await?;
203 let new_count = current as u64 + increment_count;
204
205 if new_count > rule.max() {
206 return Ok(LimiterResult::Denied {
207 reason: format!(
208 "User {} exceeded {} limit: {}/{}",
209 ctx.user_id,
210 rule.description(),
211 new_count,
212 rule.max()
213 ),
214 });
215 }
216
217 pending_increments.push((key, increment_count, rule.window_seconds()));
219
220 log::debug!(
221 "[rule] User {} {}: {}/{} ({})",
222 ctx.user_id,
223 rule.description(),
224 new_count,
225 rule.max(),
226 rule.window_seconds().map_or("lifetime".to_string(), |w| format!("{}s window", w))
227 );
228 }
229
230 for (key, increment_count, window_seconds) in pending_increments {
231 if let Some(window) = window_seconds.filter(|&w| w > 0) {
232 let expires_at = (ctx.timestamp / window + 1) * window;
235 self.store.increment_with_expiry(&key, expires_at).await?;
237 for _ in 1..increment_count {
239 self.store.increment(&key).await?;
240 }
241 } else {
242 for _ in 0..increment_count {
243 self.store.increment(&key).await?;
244 }
245 }
246 }
247
248 Ok(LimiterResult::Allowed)
249 }
250
251 pub async fn init_usage_limiter() -> Result<(), KoraError> {
252 let config = get_config()?;
253 let usage_config = &config.kora.usage_limit;
254
255 let set_limiter = |limiter| {
256 USAGE_LIMITER.set(limiter).map_err(|_| {
257 KoraError::InternalServerError("Usage limiter already initialized".to_string())
258 })
259 };
260
261 if !usage_config.enabled {
262 log::info!("Usage limiting disabled");
263 return set_limiter(None);
264 }
265
266 let rules = usage_config.build_rules()?;
267 if rules.is_empty() {
268 log::info!("Usage limiting enabled but no rules configured - disabled");
269 return set_limiter(None);
270 }
271
272 let kora_signers = get_signer_pool()?
273 .get_signers_info()
274 .iter()
275 .filter_map(|info| info.public_key.parse().ok())
276 .collect();
277
278 let (store, backend): (Arc<dyn UsageStore>, &str) =
279 if let Some(cache_url) = &usage_config.cache_url {
280 let cfg = deadpool_redis::Config::from_url(cache_url);
281 let pool = cfg.create_pool(Some(Runtime::Tokio1)).map_err(|e| {
282 KoraError::InternalServerError(format!(
283 "Failed to create Redis pool: {}",
284 sanitize_error!(e)
285 ))
286 })?;
287
288 let mut conn = pool.get().await.map_err(|e| {
289 KoraError::InternalServerError(format!(
290 "Failed to connect to Redis: {}",
291 sanitize_error!(e)
292 ))
293 })?;
294
295 let _: Option<String> = conn.get("__usage_limiter_test__").await.map_err(|e| {
296 KoraError::InternalServerError(format!(
297 "Redis connection test failed: {}",
298 sanitize_error!(e)
299 ))
300 })?;
301
302 (Arc::new(RedisUsageStore::new(pool)), "Redis")
303 } else {
304 log::warn!(
305 "Usage limiting configured with in-memory store. \
306 Limits will NOT be shared across instances and will reset on restart. \
307 Configure 'cache_url' in [kora.usage_limit] for production deployments."
308 );
309 (Arc::new(InMemoryUsageStore::new()), "in-memory")
310 };
311
312 log::info!("Usage limiting initialized with {} rules ({backend})", rules.len());
313
314 set_limiter(Some(UsageTracker::new(
315 usage_config.enabled,
316 store,
317 rules,
318 kora_signers,
319 usage_config.fallback_if_unavailable,
320 )))
321 }
322
323 pub async fn check_transaction_usage_limit(
324 config: &Config,
325 transaction: &mut VersionedTransactionResolved,
326 user_id: Option<&str>,
327 fee_payer: &Pubkey,
328 rpc_client: &RpcClient,
329 ) -> Result<(), KoraError> {
330 if config.kora.usage_limit.enabled
332 && matches!(&config.validation.price.model, crate::fee::price::PriceModel::Free)
333 && user_id.is_none()
334 {
335 return Err(KoraError::ValidationError(
336 "user_id is required when usage tracking is enabled and pricing is free"
337 .to_string(),
338 ));
339 }
340
341 let Some(tracker) = Self::get_usage_limiter()? else {
342 if config.kora.usage_limit.enabled && !config.kora.usage_limit.fallback_if_unavailable {
343 return Err(KoraError::InternalServerError(
344 "Usage limiter unavailable and fallback disabled".to_string(),
345 ));
346 }
347 return Ok(());
348 };
349
350 if tracker.has_instruction_rules() {
351 transaction.get_or_parse_system_instructions()?;
352 transaction.get_or_parse_spl_instructions()?;
353 }
354
355 let resolved_user_id = if let Some(user_id_str) = user_id {
359 user_id_str.to_string()
360 } else {
361 tracker
363 .extract_user_from_payment_instruction(transaction, config, fee_payer, rpc_client)
364 .await?
365 .ok_or_else(|| {
366 KoraError::ValidationError(
367 "Could not resolve user_id: no payment instruction found".to_string(),
368 )
369 })?
370 .to_string()
371 };
372
373 let kora_signer = tracker.extract_kora_signer(transaction);
374
375 let mut ctx = LimiterContext {
376 transaction,
377 user_id: resolved_user_id,
378 kora_signer,
379 timestamp: Self::current_timestamp(),
380 };
381
382 match tracker.check_and_record(&mut ctx).await {
383 Ok(LimiterResult::Allowed) => Ok(()),
384 Ok(LimiterResult::Denied { reason }) => Err(KoraError::UsageLimitExceeded(reason)),
385 Err(e)
386 if tracker.fallback_if_unavailable
387 && matches!(e, KoraError::InternalServerError(_)) =>
388 {
389 log::warn!("Usage limiter error (fallback enabled): {e}");
390 Ok(())
391 }
392 Err(e) => Err(e),
393 }
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        tests::{
            config_mock::ConfigMockBuilder, rpc_mock::RpcMockBuilder,
            transaction_mock::create_mock_resolved_transaction,
        },
        usage_limit::{InMemoryUsageStore, UsageLimitConfig, UsageLimitRuleConfig},
    };
    use std::sync::Arc;

    /// Fixed timestamp for tests that do not depend on wall-clock time.
    const FIXED_TS: u64 = 1000000;

    /// Builds an enabled tracker over a fresh in-memory store from `rule_configs`.
    fn tracker_with_rules(rule_configs: Vec<UsageLimitRuleConfig>) -> UsageTracker {
        let usage_config = UsageLimitConfig {
            enabled: true,
            cache_url: None,
            fallback_if_unavailable: false,
            rules: rule_configs,
        };
        let built_rules = usage_config.build_rules().unwrap();
        UsageTracker::new(
            true,
            Arc::new(InMemoryUsageStore::new()),
            built_rules,
            HashSet::new(),
            false,
        )
    }

    /// Tracker with a single lifetime transaction-count rule.
    fn create_test_tracker(max_transactions: u64) -> UsageTracker {
        tracker_with_rules(vec![UsageLimitRuleConfig::Transaction {
            max: max_transactions,
            window_seconds: None,
        }])
    }

    /// Runs one fresh mock transaction for `user` through `check_and_record`.
    async fn submit(tracker: &UsageTracker, user: &str, timestamp: u64) -> LimiterResult {
        let mut tx = create_mock_resolved_transaction();
        let mut ctx = LimiterContext {
            transaction: &mut tx,
            user_id: user.to_string(),
            kora_signer: None,
            timestamp,
        };
        tracker.check_and_record(&mut ctx).await.unwrap()
    }

    /// Initializes the global limiter (ignoring "already initialized") and runs a
    /// usage check without an explicit user id against a fresh mock transaction.
    async fn init_and_check_without_user() -> Result<(), KoraError> {
        let _ = UsageTracker::init_usage_limiter().await;

        let config = get_config().unwrap();
        let mut tx = create_mock_resolved_transaction();
        let rpc_client = Arc::new(RpcMockBuilder::new().build());
        let fee_payer = Pubkey::new_unique();
        UsageTracker::check_transaction_usage_limit(&config, &mut tx, None, &fee_payer, &rpc_client)
            .await
    }

    #[tokio::test]
    async fn test_usage_limit_enforcement() {
        let tracker = create_test_tracker(2);
        let user = "test-user-enforcement";

        // Two transactions fit within the limit; the third is rejected.
        for _ in 0..2 {
            assert!(matches!(submit(&tracker, user, FIXED_TS).await, LimiterResult::Allowed));
        }
        assert!(matches!(
            submit(&tracker, user, FIXED_TS).await,
            LimiterResult::Denied { .. }
        ));
    }

    #[tokio::test]
    async fn test_independent_user_limits() {
        let tracker = create_test_tracker(2);

        // Each user exhausts their own quota without affecting the other.
        for user in ["test-user-1", "test-user-2"] {
            for _ in 0..2 {
                assert!(matches!(submit(&tracker, user, FIXED_TS).await, LimiterResult::Allowed));
            }
            assert!(matches!(
                submit(&tracker, user, FIXED_TS).await,
                LimiterResult::Denied { .. }
            ));
        }
    }

    #[tokio::test]
    async fn test_unlimited_usage() {
        let tracker = tracker_with_rules(Vec::new());
        let user = "test-user-unlimited";

        // With no rules configured, nothing is ever denied.
        for _ in 0..10 {
            assert!(matches!(submit(&tracker, user, FIXED_TS).await, LimiterResult::Allowed));
        }
    }

    #[tokio::test]
    async fn test_multiple_rules() {
        let tracker = tracker_with_rules(vec![
            UsageLimitRuleConfig::Transaction { max: 10, window_seconds: None },
            UsageLimitRuleConfig::Transaction { max: 2, window_seconds: Some(100) },
        ]);
        let user = "test-user-multiple-rules";
        let now = UsageTracker::current_timestamp();

        // The stricter windowed rule (max 2) trips before the lifetime rule (max 10).
        for _ in 0..2 {
            assert!(matches!(submit(&tracker, user, now).await, LimiterResult::Allowed));
        }
        assert!(matches!(submit(&tracker, user, now).await, LimiterResult::Denied { .. }));
    }

    #[tokio::test]
    async fn test_usage_limiter_disabled_fallback() {
        let _m = ConfigMockBuilder::new().with_usage_limit_enabled(false).build_and_setup();

        let result = init_and_check_without_user().await;
        if let Err(e) = &result {
            println!("Test failed with error: {e}");
        }
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_allowed() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(true)
            .build_and_setup();

        assert!(init_and_check_without_user().await.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_denied() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(false)
            .build_and_setup();

        let result = init_and_check_without_user().await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Usage limiter unavailable and fallback disabled"));
    }
}