Skip to main content

memscope_rs/analysis/safety/
types.rs

1use crate::analysis::unsafe_ffi_tracker::{RiskLevel, StackFrame};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
6pub enum RiskFactorType {
7    RawPointerDereference,
8    UnsafeDataRace,
9    InvalidTransmute,
10    FfiCall,
11    ManualMemoryManagement,
12    CrossBoundaryTransfer,
13    UseAfterFree,
14    BufferOverflow,
15    LifetimeViolation,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RiskFactor {
20    pub factor_type: RiskFactorType,
21    pub severity: f64,
22    pub confidence: f64,
23    pub description: String,
24    pub source_location: Option<String>,
25    pub call_stack: Vec<StackFrame>,
26    pub mitigation: String,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct RiskAssessment {
31    pub risk_level: RiskLevel,
32    pub risk_score: f64,
33    pub risk_factors: Vec<RiskFactor>,
34    pub confidence_score: f64,
35    pub mitigation_suggestions: Vec<String>,
36    pub assessment_timestamp: u64,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct UnsafeReport {
41    pub report_id: String,
42    pub source: UnsafeSource,
43    pub risk_assessment: RiskAssessment,
44    pub dynamic_violations: Vec<DynamicViolation>,
45    pub related_passports: Vec<String>,
46    pub memory_context: MemoryContext,
47    pub generated_at: u64,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub enum UnsafeSource {
52    UnsafeBlock {
53        location: String,
54        function: String,
55        file_path: Option<String>,
56        line_number: Option<u32>,
57    },
58    FfiFunction {
59        library: String,
60        function: String,
61        call_site: String,
62    },
63    RawPointer {
64        operation: String,
65        location: String,
66    },
67    Transmute {
68        from_type: String,
69        to_type: String,
70        location: String,
71    },
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct DynamicViolation {
76    pub violation_type: ViolationType,
77    pub memory_address: usize,
78    pub memory_size: usize,
79    pub detected_at: u64,
80    pub call_stack: Vec<StackFrame>,
81    pub severity: RiskLevel,
82    pub context: String,
83}
84
85#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub enum ViolationType {
87    DoubleFree,
88    UseAfterFree,
89    BufferOverflow,
90    InvalidAccess,
91    DataRace,
92    FfiBoundaryViolation,
93    MemoryLeak,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct MemoryContext {
98    pub total_allocated: usize,
99    pub active_allocations: usize,
100    pub memory_pressure: MemoryPressureLevel,
101    pub allocation_patterns: Vec<AllocationPattern>,
102}
103
104#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
105pub enum MemoryPressureLevel {
106    Low,
107    Medium,
108    High,
109    Critical,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct AllocationPattern {
114    pub pattern_type: String,
115    pub frequency: u32,
116    pub average_size: usize,
117    pub risk_level: RiskLevel,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct MemoryPassport {
122    pub passport_id: String,
123    pub allocation_ptr: usize,
124    pub size_bytes: usize,
125    pub status_at_shutdown: PassportStatus,
126    pub lifecycle_events: Vec<PassportEvent>,
127    pub risk_assessment: RiskAssessment,
128    pub created_at: u64,
129    pub updated_at: u64,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub enum PassportStatus {
134    FreedByRust,
135    HandoverToFfi,
136    FreedByForeign,
137    ReclaimedByRust,
138    InForeignCustody,
139    Unknown,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct PassportEvent {
144    pub event_type: PassportEventType,
145    pub timestamp: u64,
146    pub context: String,
147    pub call_stack: Vec<StackFrame>,
148    pub metadata: HashMap<String, String>,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub enum PassportEventType {
153    AllocatedInRust,
154    HandoverToFfi,
155    FreedByForeign,
156    ReclaimedByRust,
157    BoundaryAccess,
158    OwnershipTransfer,
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Objective: Verify RiskFactorType enum variants
    /// Invariants: All variants should be constructible, comparable and distinct
    #[test]
    fn test_risk_factor_type_variants() {
        let variants = [
            RiskFactorType::RawPointerDereference,
            RiskFactorType::UnsafeDataRace,
            RiskFactorType::InvalidTransmute,
            RiskFactorType::FfiCall,
            RiskFactorType::ManualMemoryManagement,
            RiskFactorType::CrossBoundaryTransfer,
            RiskFactorType::UseAfterFree,
            RiskFactorType::BufferOverflow,
            RiskFactorType::LifetimeViolation,
        ];

        for variant in &variants {
            let debug_str = format!("{variant:?}");
            assert!(
                !debug_str.is_empty(),
                "Variant should have debug representation"
            );
        }

        // RiskFactorType derives Eq + Hash, so a set exposes duplicate variants.
        let unique: HashSet<&RiskFactorType> = variants.iter().collect();
        assert_eq!(
            unique.len(),
            variants.len(),
            "All variants should be distinct"
        );
    }

    /// Objective: Verify RiskFactor creation and fields
    /// Invariants: All fields should be accessible
    #[test]
    fn test_risk_factor_creation() {
        let factor = RiskFactor {
            factor_type: RiskFactorType::BufferOverflow,
            severity: 0.9,
            confidence: 0.85,
            description: "Test buffer overflow".to_string(),
            source_location: Some("test.rs:10".to_string()),
            call_stack: vec![],
            mitigation: "Use bounds checking".to_string(),
        };

        assert_eq!(factor.severity, 0.9, "Severity should match");
        assert_eq!(factor.confidence, 0.85, "Confidence should match");
        assert!(
            factor.source_location.is_some(),
            "Source location should be present"
        );
    }

    /// Objective: Verify RiskAssessment creation
    /// Invariants: All fields should be properly initialized
    #[test]
    fn test_risk_assessment_creation() {
        let assessment = RiskAssessment {
            risk_level: RiskLevel::High,
            risk_score: 75.0,
            risk_factors: vec![],
            confidence_score: 0.9,
            mitigation_suggestions: vec!["Review code".to_string()],
            assessment_timestamp: 1000,
        };

        assert_eq!(
            assessment.risk_level,
            RiskLevel::High,
            "Risk level should be High"
        );
        assert_eq!(assessment.risk_score, 75.0, "Risk score should match");
        assert_eq!(
            assessment.mitigation_suggestions.len(),
            1,
            "Should have one suggestion"
        );
    }

    /// Objective: Verify UnsafeSource variants
    /// Invariants: All source types should be constructible
    #[test]
    fn test_unsafe_source_variants() {
        let block = UnsafeSource::UnsafeBlock {
            location: "test.rs:10".to_string(),
            function: "test_fn".to_string(),
            file_path: Some("test.rs".to_string()),
            line_number: Some(10),
        };

        let ffi = UnsafeSource::FfiFunction {
            library: "libc".to_string(),
            function: "malloc".to_string(),
            call_site: "test.rs:20".to_string(),
        };

        let raw = UnsafeSource::RawPointer {
            operation: "deref".to_string(),
            location: "0x1000".to_string(),
        };

        let transmute = UnsafeSource::Transmute {
            from_type: "u8".to_string(),
            to_type: "i8".to_string(),
            location: "test.rs:30".to_string(),
        };

        assert!(matches!(block, UnsafeSource::UnsafeBlock { .. }));
        assert!(matches!(ffi, UnsafeSource::FfiFunction { .. }));
        assert!(matches!(raw, UnsafeSource::RawPointer { .. }));
        assert!(matches!(transmute, UnsafeSource::Transmute { .. }));
    }

    /// Objective: Verify DynamicViolation creation
    /// Invariants: All fields should be accessible
    #[test]
    fn test_dynamic_violation_creation() {
        let violation = DynamicViolation {
            violation_type: ViolationType::UseAfterFree,
            memory_address: 0x1000,
            memory_size: 1024,
            detected_at: 1000,
            call_stack: vec![],
            severity: RiskLevel::Critical,
            context: "Use after free detected".to_string(),
        };

        assert_eq!(
            violation.memory_address, 0x1000,
            "Memory address should match"
        );
        assert_eq!(violation.memory_size, 1024, "Memory size should match");
        assert_eq!(
            violation.severity,
            RiskLevel::Critical,
            "Severity should be Critical"
        );
    }

    /// Objective: Verify ViolationType variants
    /// Invariants: Every variant equals itself and differs from every other variant
    #[test]
    fn test_violation_type_equality() {
        let types = [
            ViolationType::DoubleFree,
            ViolationType::UseAfterFree,
            ViolationType::BufferOverflow,
            ViolationType::InvalidAccess,
            ViolationType::DataRace,
            ViolationType::FfiBoundaryViolation,
            ViolationType::MemoryLeak,
        ];

        for (i, a) in types.iter().enumerate() {
            for (j, b) in types.iter().enumerate() {
                if i == j {
                    assert_eq!(a, b, "Same violation types should be equal");
                } else {
                    assert_ne!(a, b, "Different violation types should not be equal");
                }
            }
        }
    }

    /// Objective: Verify MemoryContext creation
    /// Invariants: All fields should be properly initialized
    #[test]
    fn test_memory_context_creation() {
        let context = MemoryContext {
            total_allocated: 1024 * 1024,
            active_allocations: 10,
            memory_pressure: MemoryPressureLevel::Medium,
            allocation_patterns: vec![],
        };

        assert_eq!(
            context.total_allocated,
            1024 * 1024,
            "Total allocated should match"
        );
        assert_eq!(
            context.active_allocations, 10,
            "Active allocations should match"
        );
    }

    /// Objective: Verify MemoryPressureLevel variants
    /// Invariants: All levels should be distinct
    #[test]
    fn test_memory_pressure_level() {
        let levels = [
            MemoryPressureLevel::Low,
            MemoryPressureLevel::Medium,
            MemoryPressureLevel::High,
            MemoryPressureLevel::Critical,
        ];

        for (i, level) in levels.iter().enumerate() {
            for (j, other) in levels.iter().enumerate() {
                if i == j {
                    assert_eq!(level, other, "Same levels should be equal");
                } else {
                    assert_ne!(level, other, "Different levels should not be equal");
                }
            }
        }
    }

    /// Objective: Verify AllocationPattern creation
    /// Invariants: All fields should be accessible
    #[test]
    fn test_allocation_pattern_creation() {
        let pattern = AllocationPattern {
            pattern_type: "repeated".to_string(),
            frequency: 100,
            average_size: 256,
            risk_level: RiskLevel::Medium,
        };

        assert_eq!(pattern.frequency, 100, "Frequency should match");
        assert_eq!(pattern.average_size, 256, "Average size should match");
    }

    /// Objective: Verify MemoryPassport creation
    /// Invariants: All fields should be properly initialized
    #[test]
    fn test_memory_passport_creation() {
        let passport = MemoryPassport {
            passport_id: "passport_123".to_string(),
            allocation_ptr: 0x1000,
            size_bytes: 1024,
            status_at_shutdown: PassportStatus::Unknown,
            lifecycle_events: vec![],
            risk_assessment: RiskAssessment {
                risk_level: RiskLevel::Low,
                risk_score: 10.0,
                risk_factors: vec![],
                confidence_score: 0.5,
                mitigation_suggestions: vec![],
                assessment_timestamp: 0,
            },
            created_at: 1000,
            updated_at: 1000,
        };

        assert_eq!(
            passport.passport_id, "passport_123",
            "Passport ID should match"
        );
        assert_eq!(
            passport.allocation_ptr, 0x1000,
            "Allocation pointer should match"
        );
        assert_eq!(passport.size_bytes, 1024, "Size should match");
    }

    /// Objective: Verify PassportStatus variants
    /// Invariants: All statuses should have a non-empty, distinct debug representation
    #[test]
    fn test_passport_status_variants() {
        let statuses = [
            PassportStatus::FreedByRust,
            PassportStatus::HandoverToFfi,
            PassportStatus::FreedByForeign,
            PassportStatus::ReclaimedByRust,
            PassportStatus::InForeignCustody,
            PassportStatus::Unknown,
        ];

        // Compare Debug output so this test relies only on the `Debug` derive.
        let debug_strs: HashSet<String> = statuses
            .iter()
            .map(|status| format!("{status:?}"))
            .collect();

        assert!(
            debug_strs.iter().all(|s| !s.is_empty()),
            "Status should have debug representation"
        );
        assert_eq!(
            debug_strs.len(),
            statuses.len(),
            "All statuses should be distinct"
        );
    }

    /// Objective: Verify PassportEvent creation
    /// Invariants: All fields should be accessible
    #[test]
    fn test_passport_event_creation() {
        let event = PassportEvent {
            event_type: PassportEventType::HandoverToFfi,
            timestamp: 1000,
            context: "ffi_transfer".to_string(),
            call_stack: vec![],
            metadata: HashMap::new(),
        };

        assert_eq!(event.timestamp, 1000, "Timestamp should match");
        assert_eq!(event.context, "ffi_transfer", "Context should match");
    }

    /// Objective: Verify PassportEventType variants
    /// Invariants: All event types should have a non-empty, distinct debug representation
    #[test]
    fn test_passport_event_type_variants() {
        let event_types = [
            PassportEventType::AllocatedInRust,
            PassportEventType::HandoverToFfi,
            PassportEventType::FreedByForeign,
            PassportEventType::ReclaimedByRust,
            PassportEventType::BoundaryAccess,
            PassportEventType::OwnershipTransfer,
        ];

        // Compare Debug output so this test relies only on the `Debug` derive.
        let debug_strs: HashSet<String> = event_types
            .iter()
            .map(|event_type| format!("{event_type:?}"))
            .collect();

        assert!(
            debug_strs.iter().all(|s| !s.is_empty()),
            "Event type should have debug representation"
        );
        assert_eq!(
            debug_strs.len(),
            event_types.len(),
            "All event types should be distinct"
        );
    }

    /// Objective: Verify UnsafeReport creation
    /// Invariants: All fields should be properly initialized
    #[test]
    fn test_unsafe_report_creation() {
        let report = UnsafeReport {
            report_id: "UNSAFE-UB-123".to_string(),
            source: UnsafeSource::UnsafeBlock {
                location: "test.rs".to_string(),
                function: "test".to_string(),
                file_path: None,
                line_number: None,
            },
            risk_assessment: RiskAssessment {
                risk_level: RiskLevel::Medium,
                risk_score: 50.0,
                risk_factors: vec![],
                confidence_score: 0.8,
                mitigation_suggestions: vec![],
                assessment_timestamp: 0,
            },
            dynamic_violations: vec![],
            related_passports: vec![],
            memory_context: MemoryContext {
                total_allocated: 0,
                active_allocations: 0,
                memory_pressure: MemoryPressureLevel::Low,
                allocation_patterns: vec![],
            },
            generated_at: 1000,
        };

        assert_eq!(report.report_id, "UNSAFE-UB-123", "Report ID should match");
        assert_eq!(
            report.generated_at, 1000,
            "Generated timestamp should match"
        );
    }

    /// Objective: Verify RiskFactor with edge case values
    /// Invariants: Should handle zero and max values
    #[test]
    fn test_risk_factor_edge_values() {
        let zero_factor = RiskFactor {
            factor_type: RiskFactorType::UseAfterFree,
            severity: 0.0,
            confidence: 0.0,
            description: String::new(),
            source_location: None,
            call_stack: vec![],
            mitigation: String::new(),
        };

        let max_factor = RiskFactor {
            factor_type: RiskFactorType::BufferOverflow,
            severity: 1.0,
            confidence: 1.0,
            description: "x".repeat(1000),
            source_location: Some("x".repeat(1000)),
            call_stack: vec![],
            mitigation: "x".repeat(1000),
        };

        assert_eq!(zero_factor.severity, 0.0, "Zero severity should be valid");
        assert_eq!(max_factor.severity, 1.0, "Max severity should be valid");
        assert_eq!(
            max_factor.description.len(),
            1000,
            "Long description should be preserved"
        );
    }

    /// Objective: Verify serialization of types
    /// Invariants: A JSON round-trip should preserve every field value
    #[test]
    fn test_serialization() {
        let assessment = RiskAssessment {
            risk_level: RiskLevel::High,
            risk_score: 75.0,
            risk_factors: vec![],
            confidence_score: 0.9,
            mitigation_suggestions: vec!["test".to_string()],
            assessment_timestamp: 1000,
        };

        let json = serde_json::to_string(&assessment).expect("Should serialize to JSON");
        let restored: RiskAssessment =
            serde_json::from_str(&json).expect("Should deserialize from JSON");

        // Checking `is_ok()` alone would miss a lossy round-trip; compare fields.
        assert_eq!(
            restored.risk_level, assessment.risk_level,
            "Risk level should survive round-trip"
        );
        assert_eq!(
            restored.risk_score, assessment.risk_score,
            "Risk score should survive round-trip"
        );
        assert_eq!(
            restored.confidence_score, assessment.confidence_score,
            "Confidence score should survive round-trip"
        );
        assert_eq!(
            restored.mitigation_suggestions, assessment.mitigation_suggestions,
            "Mitigation suggestions should survive round-trip"
        );
        assert_eq!(
            restored.assessment_timestamp, assessment.assessment_timestamp,
            "Timestamp should survive round-trip"
        );
        assert!(
            restored.risk_factors.is_empty(),
            "Empty risk factors should stay empty"
        );
    }
}