Skip to main content

memscope_rs/analysis/safety/
analyzer.rs

1use crate::analysis::safety::engine::RiskAssessmentEngine;
2use crate::analysis::safety::types::*;
3use crate::analysis::unsafe_ffi_tracker::{RiskLevel, SafetyViolation, StackFrame};
4use crate::capture::types::{AllocationInfo, TrackingResult};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::time::{SystemTime, UNIX_EPOCH};
9
10#[derive(Debug, Clone)]
11pub struct SafetyAnalysisConfig {
12    pub detailed_risk_assessment: bool,
13    pub enable_passport_tracking: bool,
14    pub min_risk_level: RiskLevel,
15    pub max_reports: usize,
16    pub enable_dynamic_violations: bool,
17}
18
19impl Default for SafetyAnalysisConfig {
20    fn default() -> Self {
21        Self {
22            detailed_risk_assessment: true,
23            enable_passport_tracking: true,
24            min_risk_level: RiskLevel::Low,
25            max_reports: 1000,
26            enable_dynamic_violations: true,
27        }
28    }
29}
30
31#[derive(Debug, Clone, Default, Serialize, Deserialize)]
32pub struct SafetyAnalysisStats {
33    pub total_reports: usize,
34    pub reports_by_risk_level: HashMap<String, usize>,
35    pub total_passports: usize,
36    pub passports_by_status: HashMap<String, usize>,
37    pub dynamic_violations: usize,
38    pub analysis_start_time: u64,
39}
40
41pub struct SafetyAnalyzer {
42    unsafe_reports: Arc<Mutex<HashMap<String, UnsafeReport>>>,
43    memory_passports: Arc<Mutex<HashMap<usize, MemoryPassport>>>,
44    risk_engine: RiskAssessmentEngine,
45    config: SafetyAnalysisConfig,
46    stats: Arc<Mutex<SafetyAnalysisStats>>,
47}
48
49impl SafetyAnalyzer {
50    pub fn new(config: SafetyAnalysisConfig) -> Self {
51        tracing::info!("🔒 Initializing Safety Analyzer");
52        tracing::info!(
53            "   • Detailed risk assessment: {}",
54            config.detailed_risk_assessment
55        );
56        tracing::info!(
57            "   • Passport tracking: {}",
58            config.enable_passport_tracking
59        );
60        tracing::info!("   • Min risk level: {:?}", config.min_risk_level);
61
62        Self {
63            unsafe_reports: Arc::new(Mutex::new(HashMap::new())),
64            memory_passports: Arc::new(Mutex::new(HashMap::new())),
65            risk_engine: RiskAssessmentEngine::new(),
66            config,
67            stats: Arc::new(Mutex::new(SafetyAnalysisStats {
68                analysis_start_time: SystemTime::now()
69                    .duration_since(UNIX_EPOCH)
70                    .unwrap_or_default()
71                    .as_secs(),
72                ..Default::default()
73            })),
74        }
75    }
76
77    pub fn generate_unsafe_report(
78        &self,
79        source: UnsafeSource,
80        allocations: &[AllocationInfo],
81        violations: &[SafetyViolation],
82    ) -> TrackingResult<String> {
83        let report_id = self.generate_report_id(&source);
84
85        tracing::info!("🔍 Generating unsafe report: {}", report_id);
86
87        let memory_context = self.create_memory_context(allocations);
88        let call_stack = self.capture_call_stack()?;
89
90        let risk_assessment = if self.config.detailed_risk_assessment {
91            self.risk_engine
92                .assess_risk(&source, &memory_context, &call_stack)
93        } else {
94            self.create_basic_risk_assessment(&source)
95        };
96
97        if !self.should_generate_report(&risk_assessment.risk_level) {
98            return Ok(report_id);
99        }
100
101        let dynamic_violations = self.convert_safety_violations(violations);
102
103        let related_passports = if self.config.enable_passport_tracking {
104            self.find_related_passports(&source, allocations)
105        } else {
106            Vec::new()
107        };
108
109        let report = UnsafeReport {
110            report_id: report_id.clone(),
111            source,
112            risk_assessment: risk_assessment.clone(),
113            dynamic_violations,
114            related_passports,
115            memory_context,
116            generated_at: SystemTime::now()
117                .duration_since(UNIX_EPOCH)
118                .unwrap_or_default()
119                .as_secs(),
120        };
121
122        if let Ok(mut reports) = self.unsafe_reports.lock() {
123            if reports.len() >= self.config.max_reports {
124                if let Some(oldest_id) = reports.keys().next().cloned() {
125                    reports.remove(&oldest_id);
126                }
127            }
128            reports.insert(report_id.clone(), report);
129        }
130
131        self.update_stats(&report_id, &risk_assessment.risk_level);
132
133        tracing::info!(
134            "✅ Generated unsafe report: {} (risk: {:?})",
135            report_id,
136            risk_assessment.risk_level
137        );
138
139        Ok(report_id)
140    }
141
142    pub fn create_memory_passport(
143        &self,
144        allocation_ptr: usize,
145        size_bytes: usize,
146        initial_event: PassportEventType,
147    ) -> TrackingResult<String> {
148        if !self.config.enable_passport_tracking {
149            return Ok(String::new());
150        }
151
152        let passport_id = format!(
153            "passport_{:x}_{}",
154            allocation_ptr,
155            SystemTime::now()
156                .duration_since(UNIX_EPOCH)
157                .unwrap_or_default()
158                .as_nanos()
159        );
160
161        let call_stack = self.capture_call_stack()?;
162        let current_time = SystemTime::now()
163            .duration_since(UNIX_EPOCH)
164            .unwrap_or_default()
165            .as_secs();
166
167        let initial_passport_event = PassportEvent {
168            event_type: initial_event,
169            timestamp: current_time,
170            context: "SafetyAnalyzer".to_string(),
171            call_stack,
172            metadata: HashMap::new(),
173        };
174
175        let memory_context = MemoryContext {
176            total_allocated: size_bytes,
177            active_allocations: 1,
178            memory_pressure: MemoryPressureLevel::Low,
179            allocation_patterns: Vec::new(),
180        };
181
182        let source = UnsafeSource::RawPointer {
183            operation: "passport_creation".to_string(),
184            location: format!("0x{allocation_ptr:x}"),
185        };
186
187        let risk_assessment = self.risk_engine.assess_risk(&source, &memory_context, &[]);
188
189        let passport = MemoryPassport {
190            passport_id: passport_id.clone(),
191            allocation_ptr,
192            size_bytes,
193            status_at_shutdown: PassportStatus::Unknown,
194            lifecycle_events: vec![initial_passport_event],
195            risk_assessment,
196            created_at: current_time,
197            updated_at: current_time,
198        };
199
200        if let Ok(mut passports) = self.memory_passports.lock() {
201            passports.insert(allocation_ptr, passport);
202        }
203
204        if let Ok(mut stats) = self.stats.lock() {
205            stats.total_passports += 1;
206        }
207
208        tracing::info!(
209            "📋 Created memory passport: {} for 0x{:x}",
210            passport_id,
211            allocation_ptr
212        );
213
214        Ok(passport_id)
215    }
216
217    pub fn record_passport_event(
218        &self,
219        allocation_ptr: usize,
220        event_type: PassportEventType,
221        context: String,
222    ) -> TrackingResult<()> {
223        if !self.config.enable_passport_tracking {
224            return Ok(());
225        }
226
227        let call_stack = self.capture_call_stack()?;
228        let current_time = SystemTime::now()
229            .duration_since(UNIX_EPOCH)
230            .unwrap_or_default()
231            .as_secs();
232
233        let event = PassportEvent {
234            event_type,
235            timestamp: current_time,
236            context,
237            call_stack,
238            metadata: HashMap::new(),
239        };
240
241        if let Ok(mut passports) = self.memory_passports.lock() {
242            if let Some(passport) = passports.get_mut(&allocation_ptr) {
243                passport.lifecycle_events.push(event);
244                passport.updated_at = current_time;
245
246                tracing::info!("📝 Recorded passport event for 0x{:x}", allocation_ptr);
247            }
248        }
249
250        Ok(())
251    }
252
253    pub fn finalize_passports_at_shutdown(&self) -> Vec<String> {
254        let mut leaked_passports = Vec::new();
255
256        if let Ok(mut passports) = self.memory_passports.lock() {
257            for (ptr, passport) in passports.iter_mut() {
258                let final_status = self.determine_final_passport_status(&passport.lifecycle_events);
259                passport.status_at_shutdown = final_status.clone();
260
261                if matches!(final_status, PassportStatus::InForeignCustody) {
262                    leaked_passports.push(passport.passport_id.clone());
263                    tracing::warn!(
264                        "🚨 Memory leak detected: passport {} (0x{:x}) in foreign custody",
265                        passport.passport_id,
266                        ptr
267                    );
268                }
269            }
270
271            if let Ok(mut stats) = self.stats.lock() {
272                for passport in passports.values() {
273                    let status_key = format!("{:?}", passport.status_at_shutdown);
274                    *stats.passports_by_status.entry(status_key).or_insert(0) += 1;
275                }
276            }
277        }
278
279        tracing::info!(
280            "🏁 Finalized {} passports, {} leaks detected",
281            self.get_passport_count(),
282            leaked_passports.len()
283        );
284
285        leaked_passports
286    }
287
288    pub fn get_unsafe_reports(&self) -> HashMap<String, UnsafeReport> {
289        self.unsafe_reports
290            .lock()
291            .unwrap_or_else(|e| {
292                tracing::warn!(
293                    "Mutex poisoned in get_unsafe_reports, recovering data: {}",
294                    e
295                );
296                e.into_inner()
297            })
298            .clone()
299    }
300
301    pub fn get_memory_passports(&self) -> HashMap<usize, MemoryPassport> {
302        self.memory_passports
303            .lock()
304            .unwrap_or_else(|e| {
305                tracing::warn!(
306                    "Mutex poisoned in get_memory_passports, recovering data: {}",
307                    e
308                );
309                e.into_inner()
310            })
311            .clone()
312    }
313
314    pub fn get_stats(&self) -> SafetyAnalysisStats {
315        self.stats
316            .lock()
317            .unwrap_or_else(|e| {
318                tracing::warn!("Mutex poisoned in get_stats, recovering data: {}", e);
319                e.into_inner()
320            })
321            .clone()
322    }
323
324    fn generate_report_id(&self, source: &UnsafeSource) -> String {
325        let timestamp = SystemTime::now()
326            .duration_since(UNIX_EPOCH)
327            .unwrap_or_default()
328            .as_nanos();
329
330        let source_type = match source {
331            UnsafeSource::UnsafeBlock { .. } => "UB",
332            UnsafeSource::FfiFunction { .. } => "FFI",
333            UnsafeSource::RawPointer { .. } => "PTR",
334            UnsafeSource::Transmute { .. } => "TX",
335        };
336
337        format!("UNSAFE-{}-{}", source_type, timestamp % 1000000)
338    }
339
340    fn create_memory_context(&self, allocations: &[AllocationInfo]) -> MemoryContext {
341        let total_allocated = allocations.iter().map(|a| a.size).sum();
342        let active_allocations = allocations
343            .iter()
344            .filter(|a| a.timestamp_dealloc.is_none())
345            .count();
346
347        let memory_pressure = if total_allocated > 1024 * 1024 * 1024 {
348            MemoryPressureLevel::Critical
349        } else if total_allocated > 512 * 1024 * 1024 {
350            MemoryPressureLevel::High
351        } else if total_allocated > 256 * 1024 * 1024 {
352            MemoryPressureLevel::Medium
353        } else {
354            MemoryPressureLevel::Low
355        };
356
357        MemoryContext {
358            total_allocated,
359            active_allocations,
360            memory_pressure,
361            allocation_patterns: Vec::new(),
362        }
363    }
364
365    fn capture_call_stack(&self) -> TrackingResult<Vec<StackFrame>> {
366        Ok(vec![StackFrame {
367            function_name: "safety_analyzer".to_string(),
368            file_name: Some("src/analysis/safety_analyzer.rs".to_string()),
369            line_number: Some(1),
370            is_unsafe: false,
371        }])
372    }
373
374    fn create_basic_risk_assessment(&self, source: &UnsafeSource) -> RiskAssessment {
375        let (risk_level, risk_score) = match source {
376            UnsafeSource::UnsafeBlock { .. } => (RiskLevel::Medium, 50.0),
377            UnsafeSource::FfiFunction { .. } => (RiskLevel::Medium, 45.0),
378            UnsafeSource::RawPointer { .. } => (RiskLevel::High, 70.0),
379            UnsafeSource::Transmute { .. } => (RiskLevel::High, 65.0),
380        };
381
382        RiskAssessment {
383            risk_level,
384            risk_score,
385            risk_factors: Vec::new(),
386            confidence_score: 0.5,
387            mitigation_suggestions: vec!["Review unsafe operation for safety".to_string()],
388            assessment_timestamp: SystemTime::now()
389                .duration_since(UNIX_EPOCH)
390                .unwrap_or_default()
391                .as_secs(),
392        }
393    }
394
395    fn should_generate_report(&self, risk_level: &RiskLevel) -> bool {
396        match (&self.config.min_risk_level, risk_level) {
397            (RiskLevel::Low, _) => true,
398            (RiskLevel::Medium, RiskLevel::Low) => false,
399            (RiskLevel::Medium, _) => true,
400            (RiskLevel::High, RiskLevel::Low | RiskLevel::Medium) => false,
401            (RiskLevel::High, _) => true,
402            (RiskLevel::Critical, RiskLevel::Critical) => true,
403            (RiskLevel::Critical, _) => false,
404        }
405    }
406
407    fn convert_safety_violations(&self, violations: &[SafetyViolation]) -> Vec<DynamicViolation> {
408        violations
409            .iter()
410            .map(|v| match v {
411                SafetyViolation::DoubleFree { timestamp, .. } => DynamicViolation {
412                    violation_type: ViolationType::DoubleFree,
413                    memory_address: 0,
414                    memory_size: 0,
415                    detected_at: (*timestamp as u64),
416                    call_stack: Vec::new(),
417                    severity: RiskLevel::Critical,
418                    context: "Double free detected".to_string(),
419                },
420                SafetyViolation::InvalidFree {
421                    attempted_pointer,
422                    timestamp,
423                    ..
424                } => DynamicViolation {
425                    violation_type: ViolationType::InvalidAccess,
426                    memory_address: *attempted_pointer,
427                    memory_size: 0,
428                    detected_at: (*timestamp as u64),
429                    call_stack: Vec::new(),
430                    severity: RiskLevel::High,
431                    context: "Invalid free attempted".to_string(),
432                },
433                SafetyViolation::PotentialLeak {
434                    leak_detection_timestamp,
435                    ..
436                } => DynamicViolation {
437                    violation_type: ViolationType::InvalidAccess,
438                    memory_address: 0,
439                    memory_size: 0,
440                    detected_at: (*leak_detection_timestamp as u64),
441                    call_stack: Vec::new(),
442                    severity: RiskLevel::Medium,
443                    context: "Potential memory leak".to_string(),
444                },
445                SafetyViolation::CrossBoundaryRisk { .. } => DynamicViolation {
446                    violation_type: ViolationType::FfiBoundaryViolation,
447                    memory_address: 0,
448                    memory_size: 0,
449                    detected_at: SystemTime::now()
450                        .duration_since(UNIX_EPOCH)
451                        .unwrap_or_default()
452                        .as_secs(),
453                    call_stack: Vec::new(),
454                    severity: RiskLevel::Medium,
455                    context: "Cross-boundary risk detected".to_string(),
456                },
457            })
458            .collect()
459    }
460
461    fn find_related_passports(
462        &self,
463        _source: &UnsafeSource,
464        _allocations: &[AllocationInfo],
465    ) -> Vec<String> {
466        Vec::new()
467    }
468
469    fn update_stats(&self, _report_id: &str, risk_level: &RiskLevel) {
470        if let Ok(mut stats) = self.stats.lock() {
471            stats.total_reports += 1;
472            let risk_key = format!("{risk_level:?}");
473            *stats.reports_by_risk_level.entry(risk_key).or_insert(0) += 1;
474        }
475    }
476
477    fn determine_final_passport_status(&self, events: &[PassportEvent]) -> PassportStatus {
478        let mut has_handover = false;
479        let mut has_reclaim = false;
480        let mut has_foreign_free = false;
481
482        for event in events {
483            match event.event_type {
484                PassportEventType::HandoverToFfi => has_handover = true,
485                PassportEventType::ReclaimedByRust => has_reclaim = true,
486                PassportEventType::FreedByForeign => has_foreign_free = true,
487                _ => {}
488            }
489        }
490
491        if has_handover && !has_reclaim && !has_foreign_free {
492            PassportStatus::InForeignCustody
493        } else if has_foreign_free {
494            PassportStatus::FreedByForeign
495        } else if has_reclaim {
496            PassportStatus::ReclaimedByRust
497        } else if has_handover {
498            PassportStatus::HandoverToFfi
499        } else {
500            PassportStatus::FreedByRust
501        }
502    }
503
504    fn get_passport_count(&self) -> usize {
505        self.memory_passports.lock().map(|p| p.len()).unwrap_or(0)
506    }
507}
508
509impl Default for SafetyAnalyzer {
510    fn default() -> Self {
511        Self::new(SafetyAnalysisConfig::default())
512    }
513}