1use crate::{ChainType, Error, Result, TransactionRequest};
29use chrono::{DateTime, Datelike, Timelike, Utc};
30use parking_lot::RwLock;
31use serde::{Deserialize, Serialize};
32use std::collections::{HashMap, HashSet};
33use std::sync::Arc;
34
/// Outcome of evaluating a transaction against a policy.
///
/// Serialized by serde as an internally tagged value: the variant name is written in
/// `snake_case` under the `"type"` key, next to the variant's own fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PolicyDecision {
    /// No policy check objected to the transaction.
    Approve,
    /// A policy check refused the transaction; `reason` is a human-readable explanation.
    Reject { reason: String },
    /// The transaction was not rejected, but `PolicyEngine::evaluate` found its value above the
    /// configured additional-approval threshold; `reason` explains why.
    RequireAdditionalApproval { reason: String },
}
46
47impl PolicyDecision {
48 pub fn is_approved(&self) -> bool {
50 matches!(self, PolicyDecision::Approve)
51 }
52
53 pub fn requires_additional_approval(&self) -> bool {
55 matches!(self, PolicyDecision::RequireAdditionalApproval { .. })
56 }
57}
58
/// Spending caps applied to the transactions of one chain.
///
/// A cap of `None` means that window is not limited. Amounts are compared directly against the
/// integer that `PolicyEngine` parses from a transaction's `value` string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpendingLimits {
    /// Largest value allowed for a single transaction.
    pub per_transaction: Option<u128>,
    /// Largest cumulative recorded value allowed per UTC calendar day.
    pub daily: Option<u128>,
    /// Largest cumulative recorded value allowed per week.
    pub weekly: Option<u128>,
    /// Currency label; within this file it is only used in rejection messages.
    pub currency: String,
}
71
72impl Default for SpendingLimits {
73 fn default() -> Self {
74 Self {
75 per_transaction: None,
76 daily: None,
77 weekly: None,
78 currency: "ETH".to_string(),
79 }
80 }
81}
82
83impl SpendingLimits {
84 pub fn with_per_tx(amount: u128, currency: impl Into<String>) -> Self {
86 Self {
87 per_transaction: Some(amount),
88 daily: None,
89 weekly: None,
90 currency: currency.into(),
91 }
92 }
93
94 pub fn daily(mut self, amount: u128) -> Self {
96 self.daily = Some(amount);
97 self
98 }
99
100 pub fn weekly(mut self, amount: u128) -> Self {
102 self.weekly = Some(amount);
103 self
104 }
105}
106
/// Window of hours and weekdays (both evaluated in UTC) during which transactions are allowed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeBounds {
    /// First hour of the window, inclusive.
    pub start_hour: u8,
    /// Hour at which the window closes, exclusive. When it is smaller than `start_hour` the
    /// window wraps past midnight.
    pub end_hour: u8,
    /// Allowed weekdays, numbered from Sunday (`0` = Sunday … `6` = Saturday).
    pub allowed_days: Vec<u8>,
}
117
118impl Default for TimeBounds {
119 fn default() -> Self {
120 Self {
121 start_hour: 0,
122 end_hour: 24,
123 allowed_days: vec![0, 1, 2, 3, 4, 5, 6], }
125 }
126}
127
128impl TimeBounds {
129 pub fn business_hours() -> Self {
131 Self {
132 start_hour: 9,
133 end_hour: 17,
134 allowed_days: vec![1, 2, 3, 4, 5], }
136 }
137
138 pub fn is_allowed(&self, timestamp: DateTime<Utc>) -> bool {
140 let hour = timestamp.hour() as u8;
141 let day = timestamp.weekday().num_days_from_sunday() as u8;
142
143 let hour_ok = if self.start_hour <= self.end_hour {
144 hour >= self.start_hour && hour < self.end_hour
145 } else {
146 hour >= self.start_hour || hour < self.end_hour
148 };
149
150 hour_ok && self.allowed_days.contains(&day)
151 }
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct ContractRestriction {
157 pub allowed_contracts: HashSet<String>,
159 pub allowed_selectors: HashSet<String>,
161 pub blocked_selectors: HashSet<String>,
163}
164
165impl Default for ContractRestriction {
166 fn default() -> Self {
167 Self {
168 allowed_contracts: HashSet::new(),
169 allowed_selectors: HashSet::new(),
170 blocked_selectors: HashSet::new(),
171 }
172 }
173}
174
175impl ContractRestriction {
176 pub fn allow_contract(mut self, address: impl Into<String>) -> Self {
178 self.allowed_contracts.insert(address.into().to_lowercase());
179 self
180 }
181
182 pub fn allow_selector(mut self, selector: impl Into<String>) -> Self {
184 self.allowed_selectors
185 .insert(selector.into().to_lowercase());
186 self
187 }
188
189 pub fn block_selector(mut self, selector: impl Into<String>) -> Self {
191 self.blocked_selectors
192 .insert(selector.into().to_lowercase());
193 self
194 }
195}
196
/// Full set of rules a `PolicyEngine` evaluates transactions against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyConfig {
    /// Spending caps, keyed by chain. A chain without an entry has no spending caps.
    pub spending_limits: HashMap<ChainType, SpendingLimits>,
    /// When `Some`, only recipients contained in the set are accepted; `None` disables the
    /// check. The engine looks up the lowercased recipient, so entries must be lowercase.
    pub whitelist: Option<HashSet<String>>,
    /// Recipients that are always rejected. The engine looks up the lowercased recipient, so
    /// entries must be lowercase.
    pub blacklist: HashSet<String>,
    /// When `Some`, transactions are only accepted inside this time window.
    pub time_bounds: Option<TimeBounds>,
    /// When `Some`, applied to transactions for which `is_contract_call()` is true.
    pub contract_restrictions: Option<ContractRestriction>,
    /// When `Some`, values strictly above this amount yield
    /// `PolicyDecision::RequireAdditionalApproval` instead of approval.
    pub additional_approval_threshold: Option<u128>,
    /// Defaults to 10. Not read anywhere in this file; presumably a cap on queued approval
    /// requests that is enforced elsewhere — TODO confirm.
    pub max_pending_requests: usize,
    /// When `false`, the engine approves every transaction without running any check.
    pub enabled: bool,
}
217
218impl Default for PolicyConfig {
219 fn default() -> Self {
220 Self {
221 spending_limits: HashMap::new(),
222 whitelist: None,
223 blacklist: HashSet::new(),
224 time_bounds: None,
225 contract_restrictions: None,
226 additional_approval_threshold: None,
227 max_pending_requests: 10,
228 enabled: true,
229 }
230 }
231}
232
233impl PolicyConfig {
234 pub fn new() -> Self {
236 Self::default()
237 }
238
239 pub fn disabled() -> Self {
241 Self {
242 enabled: false,
243 ..Default::default()
244 }
245 }
246
247 pub fn with_spending_limits(mut self, chain: ChainType, limits: SpendingLimits) -> Self {
249 self.spending_limits.insert(chain, limits);
250 self
251 }
252
253 pub fn with_per_tx_limit(mut self, amount: u128, currency: impl Into<String>) -> Self {
255 let limits = self
256 .spending_limits
257 .entry(ChainType::Evm)
258 .or_insert_with(SpendingLimits::default);
259 limits.per_transaction = Some(amount);
260 limits.currency = currency.into();
261 self
262 }
263
264 pub fn with_daily_limit(mut self, amount: u128) -> Self {
266 let limits = self
267 .spending_limits
268 .entry(ChainType::Evm)
269 .or_insert_with(SpendingLimits::default);
270 limits.daily = Some(amount);
271 self
272 }
273
274 pub fn with_weekly_limit(mut self, amount: u128) -> Self {
276 let limits = self
277 .spending_limits
278 .entry(ChainType::Evm)
279 .or_insert_with(SpendingLimits::default);
280 limits.weekly = Some(amount);
281 self
282 }
283
284 pub fn with_whitelist(mut self, addresses: Vec<String>) -> Self {
286 self.whitelist = Some(addresses.into_iter().map(|a| a.to_lowercase()).collect());
287 self
288 }
289
290 pub fn with_blacklist(mut self, addresses: Vec<String>) -> Self {
292 self.blacklist = addresses.into_iter().map(|a| a.to_lowercase()).collect();
293 self
294 }
295
296 pub fn with_time_bounds(mut self, bounds: TimeBounds) -> Self {
298 self.time_bounds = Some(bounds);
299 self
300 }
301
302 pub fn with_contract_restrictions(mut self, restrictions: ContractRestriction) -> Self {
304 self.contract_restrictions = Some(restrictions);
305 self
306 }
307
308 pub fn with_additional_approval_threshold(mut self, amount: u128) -> Self {
310 self.additional_approval_threshold = Some(amount);
311 self
312 }
313}
314
/// Running spending totals for one chain, bucketed by caller-supplied day and week keys.
#[derive(Debug, Default)]
struct SpendingTracker {
    /// Total recorded per day key.
    daily: HashMap<String, u128>,
    /// Total recorded per week key.
    weekly: HashMap<String, u128>,
}

impl SpendingTracker {
    /// Creates a tracker with no recorded spending.
    fn new() -> Self {
        Self::default()
    }

    /// Returns the total recorded under the day key `date`, or 0 when nothing was recorded.
    fn get_daily_spent(&self, date: &str) -> u128 {
        self.daily.get(date).copied().unwrap_or(0)
    }

    /// Returns the total recorded under the week key `week`, or 0 when nothing was recorded.
    fn get_weekly_spent(&self, week: &str) -> u128 {
        self.weekly.get(week).copied().unwrap_or(0)
    }

    /// Adds `amount` to the totals stored under `date` and `week`.
    ///
    /// Totals saturate at `u128::MAX`. A plain `+=` would panic in debug builds and wrap in
    /// release builds; wrapping would shrink a running total and defeat the limits checked
    /// against it.
    fn record_spending(&mut self, date: &str, week: &str, amount: u128) {
        let day_total = self.daily.entry(date.to_string()).or_insert(0);
        *day_total = day_total.saturating_add(amount);

        let week_total = self.weekly.entry(week.to_string()).or_insert(0);
        *week_total = week_total.saturating_add(amount);
    }

    /// Drops every bucket except the ones for `current_date` and `current_week`.
    fn cleanup_old_entries(&mut self, current_date: &str, current_week: &str) {
        self.daily.retain(|key, _| key == current_date);
        self.weekly.retain(|key, _| key == current_week);
    }
}
347
/// Evaluates transactions against a [`PolicyConfig`] and keeps track of recorded spending.
#[derive(Debug)]
pub struct PolicyEngine {
    /// Rules applied by `evaluate`.
    config: PolicyConfig,
    /// Recorded spending per chain, behind a read/write lock so it can be updated through
    /// `&self`.
    spending: Arc<RwLock<HashMap<ChainType, SpendingTracker>>>,
}
356
357impl PolicyEngine {
358 pub fn new(config: PolicyConfig) -> Self {
360 Self {
361 config,
362 spending: Arc::new(RwLock::new(HashMap::new())),
363 }
364 }
365
366 pub fn config(&self) -> &PolicyConfig {
368 &self.config
369 }
370
371 pub fn update_config(&mut self, config: PolicyConfig) {
373 self.config = config;
374 }
375
376 pub fn evaluate(&self, tx: &TransactionRequest) -> Result<PolicyDecision> {
378 if !self.config.enabled {
380 return Ok(PolicyDecision::Approve);
381 }
382
383 if self.config.blacklist.contains(&tx.to.to_lowercase()) {
385 return Ok(PolicyDecision::Reject {
386 reason: format!("Address {} is blacklisted", tx.to),
387 });
388 }
389
390 if let Some(ref whitelist) = self.config.whitelist {
392 if !whitelist.contains(&tx.to.to_lowercase()) {
393 return Ok(PolicyDecision::Reject {
394 reason: format!("Address {} is not whitelisted", tx.to),
395 });
396 }
397 }
398
399 if let Some(ref bounds) = self.config.time_bounds {
401 let now = Utc::now();
402 if !bounds.is_allowed(now) {
403 return Ok(PolicyDecision::Reject {
404 reason: format!(
405 "Transaction outside allowed time window ({}:00-{}:00 UTC)",
406 bounds.start_hour, bounds.end_hour
407 ),
408 });
409 }
410 }
411
412 if tx.is_contract_call() {
414 if let Some(ref restrictions) = self.config.contract_restrictions {
415 if !restrictions.allowed_contracts.is_empty()
417 && !restrictions
418 .allowed_contracts
419 .contains(&tx.to.to_lowercase())
420 {
421 return Ok(PolicyDecision::Reject {
422 reason: format!("Contract {} is not in allowed list", tx.to),
423 });
424 }
425
426 if let Some(selector) = tx.function_selector() {
428 let selector_hex = hex::encode(selector);
429
430 if restrictions.blocked_selectors.contains(&selector_hex) {
432 return Ok(PolicyDecision::Reject {
433 reason: format!("Function selector 0x{} is blocked", selector_hex),
434 });
435 }
436
437 if !restrictions.allowed_selectors.is_empty()
439 && !restrictions.allowed_selectors.contains(&selector_hex)
440 {
441 return Ok(PolicyDecision::Reject {
442 reason: format!(
443 "Function selector 0x{} is not in allowed list",
444 selector_hex
445 ),
446 });
447 }
448 }
449 }
450 }
451
452 let value = self.parse_value(&tx.value)?;
454
455 if let Some(limits) = self.config.spending_limits.get(&tx.chain) {
457 if let Some(per_tx) = limits.per_transaction {
459 if value > per_tx {
460 return Ok(PolicyDecision::Reject {
461 reason: format!(
462 "Transaction value {} exceeds per-transaction limit {}",
463 tx.value, per_tx
464 ),
465 });
466 }
467 }
468
469 let now = Utc::now();
471 let date_key = now.format("%Y-%m-%d").to_string();
472 let week_key = now.format("%Y-W%W").to_string();
473
474 let spending = self.spending.read();
475 if let Some(tracker) = spending.get(&tx.chain) {
476 if let Some(daily_limit) = limits.daily {
478 let spent = tracker.get_daily_spent(&date_key);
479 if spent + value > daily_limit {
480 return Ok(PolicyDecision::Reject {
481 reason: format!(
482 "Transaction would exceed daily limit of {} {} (already spent: {})",
483 daily_limit, limits.currency, spent
484 ),
485 });
486 }
487 }
488
489 if let Some(weekly_limit) = limits.weekly {
491 let spent = tracker.get_weekly_spent(&week_key);
492 if spent + value > weekly_limit {
493 return Ok(PolicyDecision::Reject {
494 reason: format!(
495 "Transaction would exceed weekly limit of {} {} (already spent: {})",
496 weekly_limit, limits.currency, spent
497 ),
498 });
499 }
500 }
501 }
502 }
503
504 if let Some(threshold) = self.config.additional_approval_threshold {
506 if value > threshold {
507 return Ok(PolicyDecision::RequireAdditionalApproval {
508 reason: format!(
509 "Transaction value {} exceeds additional approval threshold {}",
510 tx.value, threshold
511 ),
512 });
513 }
514 }
515
516 Ok(PolicyDecision::Approve)
517 }
518
519 pub fn record_transaction(&self, tx: &TransactionRequest) -> Result<()> {
521 let value = self.parse_value(&tx.value)?;
522 let now = Utc::now();
523 let date_key = now.format("%Y-%m-%d").to_string();
524 let week_key = now.format("%Y-W%W").to_string();
525
526 let mut spending = self.spending.write();
527 let tracker = spending
528 .entry(tx.chain)
529 .or_insert_with(SpendingTracker::new);
530
531 tracker.cleanup_old_entries(&date_key, &week_key);
533
534 tracker.record_spending(&date_key, &week_key, value);
536
537 Ok(())
538 }
539
540 fn parse_value(&self, value: &str) -> Result<u128> {
542 if value.contains('.') {
544 let parts: Vec<&str> = value.split('.').collect();
545 if parts.len() != 2 {
546 return Err(Error::PolicyViolation(format!(
547 "Invalid value format: {}",
548 value
549 )));
550 }
551
552 let whole: u128 = parts[0]
553 .parse()
554 .map_err(|_| Error::PolicyViolation(format!("Invalid value: {}", value)))?;
555
556 let mut decimal_str = parts[1].to_string();
557 while decimal_str.len() < 18 {
559 decimal_str.push('0');
560 }
561 decimal_str.truncate(18);
562
563 let decimal: u128 = decimal_str
564 .parse()
565 .map_err(|_| Error::PolicyViolation(format!("Invalid value: {}", value)))?;
566
567 Ok(whole * 10u128.pow(18) + decimal)
568 } else {
569 value
570 .parse()
571 .map_err(|_| Error::PolicyViolation(format!("Invalid value: {}", value)))
572 }
573 }
574
575 pub fn daily_spending(&self, chain: ChainType) -> u128 {
577 let date_key = Utc::now().format("%Y-%m-%d").to_string();
578 let spending = self.spending.read();
579 spending
580 .get(&chain)
581 .map(|t| t.get_daily_spent(&date_key))
582 .unwrap_or(0)
583 }
584
585 pub fn weekly_spending(&self, chain: ChainType) -> u128 {
587 let week_key = Utc::now().format("%Y-W%W").to_string();
588 let spending = self.spending.read();
589 spending
590 .get(&chain)
591 .map(|t| t.get_weekly_spent(&week_key))
592 .unwrap_or(0)
593 }
594
595 pub fn reset_spending(&self) {
597 let mut spending = self.spending.write();
598 spending.clear();
599 }
600}
601
/// Builder that assembles a [`PolicyConfig`], starting from `PolicyConfig::default()`.
#[derive(Default)]
pub struct PolicyBuilder {
    /// Configuration accumulated so far; handed out by `build`.
    config: PolicyConfig,
}
607
608impl PolicyBuilder {
609 pub fn new() -> Self {
611 Self::default()
612 }
613
614 pub fn spending_limits(mut self, chain: ChainType, limits: SpendingLimits) -> Self {
616 self.config.spending_limits.insert(chain, limits);
617 self
618 }
619
620 pub fn whitelist(mut self, addresses: impl IntoIterator<Item = impl Into<String>>) -> Self {
622 let set: HashSet<String> = addresses
623 .into_iter()
624 .map(|a| a.into().to_lowercase())
625 .collect();
626 self.config.whitelist = Some(set);
627 self
628 }
629
630 pub fn blacklist(mut self, addresses: impl IntoIterator<Item = impl Into<String>>) -> Self {
632 self.config.blacklist = addresses
633 .into_iter()
634 .map(|a| a.into().to_lowercase())
635 .collect();
636 self
637 }
638
639 pub fn time_bounds(mut self, bounds: TimeBounds) -> Self {
641 self.config.time_bounds = Some(bounds);
642 self
643 }
644
645 pub fn contract_restrictions(mut self, restrictions: ContractRestriction) -> Self {
647 self.config.contract_restrictions = Some(restrictions);
648 self
649 }
650
651 pub fn additional_approval_threshold(mut self, amount: u128) -> Self {
653 self.config.additional_approval_threshold = Some(amount);
654 self
655 }
656
657 pub fn build(self) -> PolicyConfig {
659 self.config
660 }
661}
662
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an RFC 3339 timestamp so time-window tests do not depend on the wall clock.
    fn at(timestamp: &str) -> DateTime<Utc> {
        timestamp.parse().expect("valid RFC 3339 timestamp")
    }

    /// Unwraps the reason of a rejection and fails the test for any other decision, so a
    /// wrong decision kind cannot slip through unnoticed.
    fn rejection_reason(decision: PolicyDecision) -> String {
        match decision {
            PolicyDecision::Reject { reason } => reason,
            other => panic!("expected a rejection, got {:?}", other),
        }
    }

    #[test]
    fn test_policy_approve_basic() {
        let engine = PolicyEngine::new(PolicyConfig::default());
        let tx = TransactionRequest::new(ChainType::Evm, "0x1234", "1000000000000000000");

        let decision = engine.evaluate(&tx).unwrap();
        assert!(decision.is_approved());
    }

    #[test]
    fn test_policy_disabled() {
        let engine = PolicyEngine::new(PolicyConfig::disabled());
        let tx = TransactionRequest::new(ChainType::Evm, "0x1234", "999999999999999999999999");

        let decision = engine.evaluate(&tx).unwrap();
        assert!(decision.is_approved());
    }

    #[test]
    fn test_blacklist_rejection() {
        let config = PolicyConfig::default().with_blacklist(vec!["0xBAD".to_string()]);
        let engine = PolicyEngine::new(config);
        let tx = TransactionRequest::new(ChainType::Evm, "0xbad", "1000");

        let reason = rejection_reason(engine.evaluate(&tx).unwrap());
        assert!(reason.contains("blacklisted"));
    }

    #[test]
    fn test_whitelist_rejection() {
        let config = PolicyConfig::default().with_whitelist(vec!["0xGOOD".to_string()]);
        let engine = PolicyEngine::new(config);
        let tx = TransactionRequest::new(ChainType::Evm, "0xOTHER", "1000");

        let reason = rejection_reason(engine.evaluate(&tx).unwrap());
        assert!(reason.contains("not whitelisted"));
    }

    #[test]
    fn test_whitelist_approval() {
        let config = PolicyConfig::default().with_whitelist(vec!["0xGOOD".to_string()]);
        let engine = PolicyEngine::new(config);
        let tx = TransactionRequest::new(ChainType::Evm, "0xgood", "1000");

        let decision = engine.evaluate(&tx).unwrap();
        assert!(decision.is_approved());
    }

    #[test]
    fn test_per_tx_limit() {
        let limits = SpendingLimits::with_per_tx(1_000_000_000_000_000_000u128, "ETH");
        let config = PolicyConfig::default().with_spending_limits(ChainType::Evm, limits);
        let engine = PolicyEngine::new(config);

        let tx = TransactionRequest::new(ChainType::Evm, "0x1234", "500000000000000000");
        assert!(engine.evaluate(&tx).unwrap().is_approved());

        let tx_over = TransactionRequest::new(ChainType::Evm, "0x1234", "2000000000000000000");
        let reason = rejection_reason(engine.evaluate(&tx_over).unwrap());
        assert!(reason.contains("per-transaction limit"));
    }

    #[test]
    fn test_daily_limit() {
        let limits = SpendingLimits::default().daily(2_000_000_000_000_000_000u128);
        let config = PolicyConfig::default().with_spending_limits(ChainType::Evm, limits);
        let engine = PolicyEngine::new(config);

        let tx1 = TransactionRequest::new(ChainType::Evm, "0x1234", "1000000000000000000");
        assert!(engine.evaluate(&tx1).unwrap().is_approved());
        engine.record_transaction(&tx1).unwrap();

        let tx2 = TransactionRequest::new(ChainType::Evm, "0x1234", "500000000000000000");
        assert!(engine.evaluate(&tx2).unwrap().is_approved());
        engine.record_transaction(&tx2).unwrap();

        // 1.5 already recorded, so another 1.0 would go over the 2.0 daily cap.
        let tx3 = TransactionRequest::new(ChainType::Evm, "0x1234", "1000000000000000000");
        let reason = rejection_reason(engine.evaluate(&tx3).unwrap());
        assert!(reason.contains("daily limit"));
    }

    #[test]
    fn test_additional_approval_threshold() {
        let config = PolicyConfig::default()
            .with_additional_approval_threshold(5_000_000_000_000_000_000u128);
        let engine = PolicyEngine::new(config);

        let tx = TransactionRequest::new(ChainType::Evm, "0x1234", "1000000000000000000");
        assert!(engine.evaluate(&tx).unwrap().is_approved());

        let tx_over = TransactionRequest::new(ChainType::Evm, "0x1234", "10000000000000000000");
        let decision = engine.evaluate(&tx_over).unwrap();
        assert!(decision.requires_additional_approval());
        assert!(!decision.is_approved());
    }

    #[test]
    fn test_time_bounds() {
        let bounds = TimeBounds::business_hours();
        assert_eq!(bounds.start_hour, 9);
        assert_eq!(bounds.end_hour, 17);
        assert_eq!(bounds.allowed_days, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_time_bounds_is_allowed() {
        let bounds = TimeBounds::business_hours();

        // 2024-01-01 is a Monday.
        assert!(bounds.is_allowed(at("2024-01-01T09:00:00Z")));
        assert!(bounds.is_allowed(at("2024-01-01T16:59:59Z")));
        assert!(!bounds.is_allowed(at("2024-01-01T08:59:59Z")));
        // The end hour is exclusive.
        assert!(!bounds.is_allowed(at("2024-01-01T17:00:00Z")));
        // Saturday and Sunday are not allowed days.
        assert!(!bounds.is_allowed(at("2024-01-06T10:00:00Z")));
        assert!(!bounds.is_allowed(at("2024-01-07T10:00:00Z")));

        assert!(TimeBounds::default().is_allowed(at("2024-01-07T23:59:59Z")));
    }

    #[test]
    fn test_time_bounds_overnight_window() {
        let bounds = TimeBounds {
            start_hour: 22,
            end_hour: 6,
            allowed_days: vec![0, 1, 2, 3, 4, 5, 6],
        };

        assert!(bounds.is_allowed(at("2024-01-01T22:00:00Z")));
        assert!(bounds.is_allowed(at("2024-01-01T23:30:00Z")));
        assert!(bounds.is_allowed(at("2024-01-02T03:00:00Z")));
        assert!(!bounds.is_allowed(at("2024-01-02T06:00:00Z")));
        assert!(!bounds.is_allowed(at("2024-01-02T12:00:00Z")));
    }

    #[test]
    fn test_contract_restrictions() {
        let restrictions = ContractRestriction::default()
            .allow_contract("0xUniswap")
            .block_selector("a9059cbb");
        let config = PolicyConfig::default().with_contract_restrictions(restrictions);
        let engine = PolicyEngine::new(config);

        let mut tx = TransactionRequest::new(ChainType::Evm, "0xuniswap", "0");
        tx.data = Some(vec![0x12, 0x34, 0x56, 0x78]);
        assert!(engine.evaluate(&tx).unwrap().is_approved());

        let mut tx_blocked = TransactionRequest::new(ChainType::Evm, "0xuniswap", "0");
        tx_blocked.data = Some(vec![0xa9, 0x05, 0x9c, 0xbb, 0x00]);
        assert!(!engine.evaluate(&tx_blocked).unwrap().is_approved());
    }

    #[test]
    fn test_policy_builder() {
        let policy = PolicyBuilder::new()
            .spending_limits(
                ChainType::Evm,
                SpendingLimits::with_per_tx(1_000_000_000_000_000_000, "ETH"),
            )
            .whitelist(["0x1234", "0x5678"])
            .blacklist(["0xBAD"])
            .time_bounds(TimeBounds::business_hours())
            .additional_approval_threshold(10_000_000_000_000_000_000)
            .build();

        let whitelist = policy.whitelist.expect("whitelist should be set");
        assert!(whitelist.contains("0x1234"));
        assert!(whitelist.contains("0x5678"));
        assert!(policy.blacklist.contains("0xbad"));
        assert!(policy.time_bounds.is_some());
        assert!(policy.spending_limits.contains_key(&ChainType::Evm));
        assert_eq!(
            policy.additional_approval_threshold,
            Some(10_000_000_000_000_000_000)
        );
    }

    #[test]
    fn test_parse_decimal_value() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        let value = engine.parse_value("1.5").unwrap();
        assert_eq!(value, 1_500_000_000_000_000_000u128);

        let value = engine.parse_value("0.001").unwrap();
        assert_eq!(value, 1_000_000_000_000_000u128);

        let value = engine.parse_value("1000000000000000000").unwrap();
        assert_eq!(value, 1_000_000_000_000_000_000u128);
    }

    #[test]
    fn test_parse_invalid_value() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        assert!(engine.parse_value("").is_err());
        assert!(engine.parse_value("abc").is_err());
        assert!(engine.parse_value("1.2.3").is_err());
        assert!(engine.parse_value("1.x").is_err());
        assert!(engine.parse_value("-1").is_err());
    }
}