// memscope_rs/analysis/safety/engine.rs

1use crate::analysis::safety::types::*;
2use crate::analysis::unsafe_ffi_tracker::StackFrame;
3use std::collections::HashMap;
4use std::collections::HashSet;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7pub struct RiskAssessmentEngine {
8    _risk_weights: HashMap<RiskFactorType, f64>,
9    _historical_data: HashMap<String, Vec<f64>>,
10}
11
12impl RiskAssessmentEngine {
13    pub fn new() -> Self {
14        let mut risk_weights = HashMap::new();
15        risk_weights.insert(RiskFactorType::RawPointerDereference, 8.5);
16        risk_weights.insert(RiskFactorType::UnsafeDataRace, 9.0);
17        risk_weights.insert(RiskFactorType::InvalidTransmute, 7.5);
18        risk_weights.insert(RiskFactorType::FfiCall, 6.0);
19        risk_weights.insert(RiskFactorType::ManualMemoryManagement, 7.0);
20        risk_weights.insert(RiskFactorType::CrossBoundaryTransfer, 6.5);
21        risk_weights.insert(RiskFactorType::UseAfterFree, 9.5);
22        risk_weights.insert(RiskFactorType::BufferOverflow, 9.0);
23        risk_weights.insert(RiskFactorType::LifetimeViolation, 8.0);
24
25        Self {
26            _risk_weights: risk_weights,
27            _historical_data: HashMap::new(),
28        }
29    }
30
31    pub fn assess_risk(
32        &self,
33        source: &UnsafeSource,
34        context: &MemoryContext,
35        call_stack: &[StackFrame],
36    ) -> RiskAssessment {
37        let mut risk_factors = Vec::new();
38        let mut total_risk_score = 0.0;
39        let mut total_confidence = 0.0;
40
41        match source {
42            UnsafeSource::UnsafeBlock { location, .. } => {
43                risk_factors.extend(self.analyze_unsafe_block(location, call_stack));
44            }
45            UnsafeSource::FfiFunction {
46                library, function, ..
47            } => {
48                risk_factors.extend(self.analyze_ffi_function(library, function, call_stack));
49            }
50            UnsafeSource::RawPointer { operation, .. } => {
51                risk_factors.extend(self.analyze_raw_pointer(operation, call_stack));
52            }
53            UnsafeSource::Transmute {
54                from_type, to_type, ..
55            } => {
56                risk_factors.extend(self.analyze_transmute(from_type, to_type, call_stack));
57            }
58        }
59
60        for factor in &risk_factors {
61            total_risk_score += factor.severity * factor.confidence;
62            total_confidence += factor.confidence;
63        }
64
65        let risk_count = risk_factors.len() as f64;
66        let average_confidence = if risk_count > 0.0 {
67            total_confidence / risk_count
68        } else {
69            0.0
70        };
71
72        let pressure_multiplier = match context.memory_pressure {
73            MemoryPressureLevel::Critical => 1.5,
74            MemoryPressureLevel::High => 1.2,
75            MemoryPressureLevel::Medium => 1.0,
76            MemoryPressureLevel::Low => 0.8,
77        };
78
79        total_risk_score *= pressure_multiplier;
80
81        let risk_level = if total_risk_score >= 80.0 {
82            crate::analysis::unsafe_ffi_tracker::RiskLevel::Critical
83        } else if total_risk_score >= 60.0 {
84            crate::analysis::unsafe_ffi_tracker::RiskLevel::High
85        } else if total_risk_score >= 40.0 {
86            crate::analysis::unsafe_ffi_tracker::RiskLevel::Medium
87        } else {
88            crate::analysis::unsafe_ffi_tracker::RiskLevel::Low
89        };
90
91        let mitigation_suggestions =
92            self.generate_mitigation_suggestions(&risk_factors, &risk_level);
93
94        RiskAssessment {
95            risk_level,
96            risk_score: total_risk_score.min(100.0),
97            risk_factors,
98            confidence_score: average_confidence,
99            mitigation_suggestions,
100            assessment_timestamp: SystemTime::now()
101                .duration_since(UNIX_EPOCH)
102                .unwrap_or_default()
103                .as_secs(),
104        }
105    }
106
107    fn analyze_unsafe_block(&self, location: &str, call_stack: &[StackFrame]) -> Vec<RiskFactor> {
108        let mut factors = Vec::new();
109
110        if location.contains("*") || location.contains("ptr::") {
111            factors.push(RiskFactor {
112                factor_type: RiskFactorType::RawPointerDereference,
113                severity: 7.5,
114                confidence: 0.8,
115                description: "Raw pointer dereference in unsafe block".to_string(),
116                source_location: Some(location.to_string()),
117                call_stack: call_stack.to_vec(),
118                mitigation: "Add bounds checking and null pointer validation".to_string(),
119            });
120        }
121
122        if location.contains("alloc") || location.contains("dealloc") || location.contains("free") {
123            factors.push(RiskFactor {
124                factor_type: RiskFactorType::ManualMemoryManagement,
125                severity: 6.5,
126                confidence: 0.9,
127                description: "Manual memory management in unsafe block".to_string(),
128                source_location: Some(location.to_string()),
129                call_stack: call_stack.to_vec(),
130                mitigation: "Use RAII patterns and smart pointers where possible".to_string(),
131            });
132        }
133
134        factors
135    }
136
137    fn analyze_ffi_function(
138        &self,
139        library: &str,
140        function: &str,
141        call_stack: &[StackFrame],
142    ) -> Vec<RiskFactor> {
143        let mut factors = Vec::new();
144
145        factors.push(RiskFactor {
146            factor_type: RiskFactorType::FfiCall,
147            severity: 5.5,
148            confidence: 0.7,
149            description: format!("FFI call to {library}::{function}"),
150            source_location: Some(format!("{library}::{function}")),
151            call_stack: call_stack.to_vec(),
152            mitigation: "Validate all parameters and return values".to_string(),
153        });
154
155        let risky_functions = ["malloc", "free", "strcpy", "strcat", "sprintf", "gets"];
156        if risky_functions.iter().any(|&f| function.contains(f)) {
157            factors.push(RiskFactor {
158                factor_type: RiskFactorType::BufferOverflow,
159                severity: 8.0,
160                confidence: 0.9,
161                description: format!("Call to potentially unsafe function: {function}"),
162                source_location: Some(format!("{library}::{function}")),
163                call_stack: call_stack.to_vec(),
164                mitigation: "Use safer alternatives or add explicit bounds checking".to_string(),
165            });
166        }
167
168        factors
169    }
170
171    fn analyze_raw_pointer(&self, operation: &str, call_stack: &[StackFrame]) -> Vec<RiskFactor> {
172        let mut factors = Vec::new();
173
174        factors.push(RiskFactor {
175            factor_type: RiskFactorType::RawPointerDereference,
176            severity: 8.0,
177            confidence: 0.85,
178            description: format!("Raw pointer operation: {operation}"),
179            source_location: Some(operation.to_string()),
180            call_stack: call_stack.to_vec(),
181            mitigation: "Add null checks and bounds validation".to_string(),
182        });
183
184        factors
185    }
186
187    fn analyze_transmute(
188        &self,
189        from_type: &str,
190        to_type: &str,
191        call_stack: &[StackFrame],
192    ) -> Vec<RiskFactor> {
193        let mut factors = Vec::new();
194
195        let severity = if from_type.contains("*") || to_type.contains("*") {
196            9.0
197        } else {
198            7.0
199        };
200
201        factors.push(RiskFactor {
202            factor_type: RiskFactorType::InvalidTransmute,
203            severity,
204            confidence: 0.8,
205            description: format!("Transmute from {from_type} to {to_type}"),
206            source_location: Some(format!("{from_type} -> {to_type}")),
207            call_stack: call_stack.to_vec(),
208            mitigation: "Verify size and alignment compatibility".to_string(),
209        });
210
211        factors
212    }
213
214    fn generate_mitigation_suggestions(
215        &self,
216        risk_factors: &[RiskFactor],
217        risk_level: &crate::analysis::unsafe_ffi_tracker::RiskLevel,
218    ) -> Vec<String> {
219        let mut suggestions = Vec::new();
220
221        match risk_level {
222            crate::analysis::unsafe_ffi_tracker::RiskLevel::Critical => {
223                suggestions.push(
224                    "URGENT: Critical safety issues detected - immediate review required"
225                        .to_string(),
226                );
227                suggestions.push(
228                    "Consider refactoring to eliminate unsafe code where possible".to_string(),
229                );
230            }
231            crate::analysis::unsafe_ffi_tracker::RiskLevel::High => {
232                suggestions.push(
233                    "High-risk operations detected - thorough testing recommended".to_string(),
234                );
235                suggestions.push("Add comprehensive error handling and validation".to_string());
236            }
237            crate::analysis::unsafe_ffi_tracker::RiskLevel::Medium => {
238                suggestions
239                    .push("Moderate risks detected - review and add safety checks".to_string());
240            }
241            crate::analysis::unsafe_ffi_tracker::RiskLevel::Low => {
242                suggestions.push("Low-level risks detected - monitor for issues".to_string());
243            }
244        }
245
246        let mut factor_types: HashSet<RiskFactorType> = HashSet::new();
247        for factor in risk_factors {
248            factor_types.insert(factor.factor_type.clone());
249        }
250
251        for factor_type in factor_types {
252            match factor_type {
253                RiskFactorType::RawPointerDereference => {
254                    suggestions.push("Add null pointer checks before dereferencing".to_string());
255                    suggestions.push("Validate pointer bounds and alignment".to_string());
256                }
257                RiskFactorType::UnsafeDataRace => {
258                    suggestions.push("Use proper synchronization primitives".to_string());
259                    suggestions.push("Consider using atomic operations".to_string());
260                }
261                RiskFactorType::FfiCall => {
262                    suggestions.push("Validate all FFI parameters and return values".to_string());
263                    suggestions.push("Handle FFI errors gracefully".to_string());
264                }
265                RiskFactorType::ManualMemoryManagement => {
266                    suggestions.push("Use RAII patterns to ensure cleanup".to_string());
267                    suggestions.push("Consider using smart pointers".to_string());
268                }
269                _ => {}
270            }
271        }
272
273        suggestions
274    }
275}
276
277impl Default for RiskAssessmentEngine {
278    fn default() -> Self {
279        Self::new()
280    }
281}