1use serde::{Deserialize, Serialize};
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use crate::constants::*;
7use crate::types::{SecurityViolation, SecurityViolationType, SecuritySeverity};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct SecurityConfig {
12 pub sandbox_enabled: bool,
14 pub reentrancy_protection: bool,
16 pub overflow_detection: bool,
18 pub access_control_verification: bool,
20 pub max_call_depth: u32,
22 pub max_external_calls: u32,
24 pub gas_limit_enforcement: bool,
26 pub max_gas_limit: u64,
28 pub memory_limit_enforcement: bool,
30 pub max_memory_bytes: u64,
32}
33
34impl Default for SecurityConfig {
35 fn default() -> Self {
36 Self {
37 sandbox_enabled: true,
38 reentrancy_protection: true,
39 overflow_detection: true,
40 access_control_verification: true,
41 max_call_depth: DEFAULT_MAX_CALL_DEPTH,
42 max_external_calls: DEFAULT_MAX_EXTERNAL_CALLS,
43 gas_limit_enforcement: true,
44 max_gas_limit: DEFAULT_MAX_GAS_LIMIT,
45 memory_limit_enforcement: true,
46 max_memory_bytes: DEFAULT_MAX_MEMORY_BYTES,
47 }
48 }
49}
50
51impl SecurityConfig {
52 pub fn new(
54 sandbox_enabled: bool,
55 reentrancy_protection: bool,
56 overflow_detection: bool,
57 access_control_verification: bool,
58 ) -> Self {
59 Self {
60 sandbox_enabled,
61 reentrancy_protection,
62 overflow_detection,
63 access_control_verification,
64 max_call_depth: DEFAULT_MAX_CALL_DEPTH,
65 max_external_calls: DEFAULT_MAX_EXTERNAL_CALLS,
66 gas_limit_enforcement: true,
67 max_gas_limit: DEFAULT_MAX_GAS_LIMIT,
68 memory_limit_enforcement: true,
69 max_memory_bytes: DEFAULT_MAX_MEMORY_BYTES,
70 }
71 }
72
73 pub fn permissive() -> Self {
75 Self {
76 sandbox_enabled: false,
77 reentrancy_protection: false,
78 overflow_detection: false,
79 access_control_verification: false,
80 max_call_depth: DEFAULT_MAX_CALL_DEPTH,
81 max_external_calls: DEFAULT_MAX_EXTERNAL_CALLS,
82 gas_limit_enforcement: false,
83 max_gas_limit: u64::MAX,
84 memory_limit_enforcement: false,
85 max_memory_bytes: u64::MAX,
86 }
87 }
88
89 pub fn strict() -> Self {
91 Self {
92 sandbox_enabled: true,
93 reentrancy_protection: true,
94 overflow_detection: true,
95 access_control_verification: true,
96 max_call_depth: 100, max_external_calls: 10, gas_limit_enforcement: true,
99 max_gas_limit: 1_000_000, memory_limit_enforcement: true,
101 max_memory_bytes: 10 * 1024 * 1024, }
103 }
104}
105
106pub struct SecurityValidator {
108 config: SecurityConfig,
109}
110
111impl SecurityValidator {
112 pub fn new(config: SecurityConfig) -> Self {
114 Self { config }
115 }
116
117 pub fn validate_call_depth(&self, call_depth: u32) -> Result<(), SecurityViolation> {
119 if call_depth > self.config.max_call_depth {
120 Err(self.create_violation(
121 SecurityViolationType::CallDepthExceeded,
122 format!("Call depth {} exceeds maximum {}", call_depth, self.config.max_call_depth),
123 SecuritySeverity::High,
124 ))
125 } else {
126 Ok(())
127 }
128 }
129
130 pub fn validate_external_calls(&self, call_count: u32) -> Result<(), SecurityViolation> {
132 if call_count > self.config.max_external_calls {
133 Err(self.create_violation(
134 SecurityViolationType::ExternalCallLimitExceeded,
135 format!("External call count {} exceeds maximum {}", call_count, self.config.max_external_calls),
136 SecuritySeverity::High,
137 ))
138 } else {
139 Ok(())
140 }
141 }
142
143 pub fn validate_gas_usage(&self, gas_used: u64) -> Result<(), SecurityViolation> {
145 if self.config.gas_limit_enforcement && gas_used > self.config.max_gas_limit {
146 Err(self.create_violation(
147 SecurityViolationType::GasLimitExceeded,
148 format!("Gas usage {} exceeds maximum {}", gas_used, self.config.max_gas_limit),
149 SecuritySeverity::High,
150 ))
151 } else {
152 Ok(())
153 }
154 }
155
156 pub fn validate_memory_usage(&self, memory_used: u64) -> Result<(), SecurityViolation> {
158 if self.config.memory_limit_enforcement && memory_used > self.config.max_memory_bytes {
159 Err(self.create_violation(
160 SecurityViolationType::MemoryLimitExceeded,
161 format!("Memory usage {} exceeds maximum {}", memory_used, self.config.max_memory_bytes),
162 SecuritySeverity::High,
163 ))
164 } else {
165 Ok(())
166 }
167 }
168
169 pub fn check_reentrancy(&self, function_name: &str, caller: &str, call_stack: &[String]) -> Result<bool, SecurityViolation> {
171 if !self.config.reentrancy_protection {
172 return Ok(false);
173 }
174
175 let recursive_calls = call_stack.iter().filter(|&f| f == function_name).count();
177 if recursive_calls > 1 {
178 Err(self.create_violation(
179 SecurityViolationType::ReentrancyAttack,
180 format!("Potential reentrancy attack in function {} called by {}", function_name, caller),
181 SecuritySeverity::Critical,
182 ))
183 } else {
184 Ok(false)
185 }
186 }
187
188 pub fn detect_overflow(&self, operation: &str, operands: &[i64]) -> Result<bool, SecurityViolation> {
190 if !self.config.overflow_detection {
191 return Ok(false);
192 }
193
194 match operation {
196 "add" | "+" => {
197 if operands.len() >= 2 {
198 let result = operands[0].checked_add(operands[1]);
199 if result.is_none() {
200 return Err(self.create_violation(
201 SecurityViolationType::IntegerOverflow,
202 format!("Integer overflow detected in addition: {} + {}", operands[0], operands[1]),
203 SecuritySeverity::High,
204 ));
205 }
206 }
207 }
208 "multiply" | "*" => {
209 if operands.len() >= 2 {
210 let result = operands[0].checked_mul(operands[1]);
211 if result.is_none() {
212 return Err(self.create_violation(
213 SecurityViolationType::IntegerOverflow,
214 format!("Integer overflow detected in multiplication: {} * {}", operands[0], operands[1]),
215 SecuritySeverity::High,
216 ));
217 }
218 }
219 }
220 _ => {}
221 }
222
223 Ok(false)
224 }
225
226 pub fn verify_access_control(&self, function_name: &str, caller: &str, required_role: Option<&str>) -> Result<bool, SecurityViolation> {
228 if !self.config.access_control_verification {
229 return Ok(true);
230 }
231
232 if let Some(role) = required_role {
234 if role == "admin" && !caller.ends_with("admin") {
235 return Err(self.create_violation(
236 SecurityViolationType::AccessControlViolation,
237 format!("Access denied: {} does not have {} role for function {}", caller, role, function_name),
238 SecuritySeverity::Medium,
239 ));
240 }
241 }
242
243 Ok(true)
244 }
245
246 fn create_violation(&self, violation_type: SecurityViolationType, description: String, severity: SecuritySeverity) -> SecurityViolation {
248 SecurityViolation {
249 violation_type,
250 description,
251 severity,
252 timestamp: SystemTime::now()
253 .duration_since(UNIX_EPOCH)
254 .unwrap()
255 .as_secs(),
256 context: std::collections::HashMap::new(),
257 }
258 }
259}
260
261pub struct SecurityContext {
263 validator: SecurityValidator,
264 violations: Vec<SecurityViolation>,
265}
266
267impl SecurityContext {
268 pub fn new(config: SecurityConfig) -> Self {
270 Self {
271 validator: SecurityValidator::new(config),
272 violations: Vec::new(),
273 }
274 }
275
276 pub fn validator(&self) -> &SecurityValidator {
278 &self.validator
279 }
280
281 pub fn add_violation(&mut self, violation: SecurityViolation) {
283 self.violations.push(violation);
284 }
285
286 pub fn violations(&self) -> &[SecurityViolation] {
288 &self.violations
289 }
290
291 pub fn clear_violations(&mut self) {
293 self.violations.clear();
294 }
295
296 pub fn has_critical_violations(&self) -> bool {
298 self.violations.iter().any(|v| matches!(v.severity, SecuritySeverity::Critical))
299 }
300
301 pub fn violation_count_by_severity(&self) -> (usize, usize, usize, usize) {
303 let mut critical = 0;
304 let mut high = 0;
305 let mut medium = 0;
306 let mut low = 0;
307
308 for violation in &self.violations {
309 match violation.severity {
310 SecuritySeverity::Critical => critical += 1,
311 SecuritySeverity::High => high += 1,
312 SecuritySeverity::Medium => medium += 1,
313 SecuritySeverity::Low => low += 1,
314 }
315 }
316
317 (critical, high, medium, low)
318 }
319}