1use crate::analysis::unsafe_ffi_tracker::{RiskLevel, StackFrame};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
6pub enum RiskFactorType {
7 RawPointerDereference,
8 UnsafeDataRace,
9 InvalidTransmute,
10 FfiCall,
11 ManualMemoryManagement,
12 CrossBoundaryTransfer,
13 UseAfterFree,
14 BufferOverflow,
15 LifetimeViolation,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RiskFactor {
20 pub factor_type: RiskFactorType,
21 pub severity: f64,
22 pub confidence: f64,
23 pub description: String,
24 pub source_location: Option<String>,
25 pub call_stack: Vec<StackFrame>,
26 pub mitigation: String,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct RiskAssessment {
31 pub risk_level: RiskLevel,
32 pub risk_score: f64,
33 pub risk_factors: Vec<RiskFactor>,
34 pub confidence_score: f64,
35 pub mitigation_suggestions: Vec<String>,
36 pub assessment_timestamp: u64,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct UnsafeReport {
41 pub report_id: String,
42 pub source: UnsafeSource,
43 pub risk_assessment: RiskAssessment,
44 pub dynamic_violations: Vec<DynamicViolation>,
45 pub related_passports: Vec<String>,
46 pub memory_context: MemoryContext,
47 pub generated_at: u64,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub enum UnsafeSource {
52 UnsafeBlock {
53 location: String,
54 function: String,
55 file_path: Option<String>,
56 line_number: Option<u32>,
57 },
58 FfiFunction {
59 library: String,
60 function: String,
61 call_site: String,
62 },
63 RawPointer {
64 operation: String,
65 location: String,
66 },
67 Transmute {
68 from_type: String,
69 to_type: String,
70 location: String,
71 },
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct DynamicViolation {
76 pub violation_type: ViolationType,
77 pub memory_address: usize,
78 pub memory_size: usize,
79 pub detected_at: u64,
80 pub call_stack: Vec<StackFrame>,
81 pub severity: RiskLevel,
82 pub context: String,
83}
84
85#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub enum ViolationType {
87 DoubleFree,
88 UseAfterFree,
89 BufferOverflow,
90 InvalidAccess,
91 DataRace,
92 FfiBoundaryViolation,
93 MemoryLeak,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct MemoryContext {
98 pub total_allocated: usize,
99 pub active_allocations: usize,
100 pub memory_pressure: MemoryPressureLevel,
101 pub allocation_patterns: Vec<AllocationPattern>,
102}
103
104#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
105pub enum MemoryPressureLevel {
106 Low,
107 Medium,
108 High,
109 Critical,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct AllocationPattern {
114 pub pattern_type: String,
115 pub frequency: u32,
116 pub average_size: usize,
117 pub risk_level: RiskLevel,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct MemoryPassport {
122 pub passport_id: String,
123 pub allocation_ptr: usize,
124 pub size_bytes: usize,
125 pub status_at_shutdown: PassportStatus,
126 pub lifecycle_events: Vec<PassportEvent>,
127 pub risk_assessment: RiskAssessment,
128 pub created_at: u64,
129 pub updated_at: u64,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub enum PassportStatus {
134 FreedByRust,
135 HandoverToFfi,
136 FreedByForeign,
137 ReclaimedByRust,
138 InForeignCustody,
139 Unknown,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct PassportEvent {
144 pub event_type: PassportEventType,
145 pub timestamp: u64,
146 pub context: String,
147 pub call_stack: Vec<StackFrame>,
148 pub metadata: HashMap<String, String>,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub enum PassportEventType {
153 AllocatedInRust,
154 HandoverToFfi,
155 FreedByForeign,
156 ReclaimedByRust,
157 BoundaryAccess,
158 OwnershipTransfer,
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Tolerance for comparing `f64` values after a JSON round trip.
    const FLOAT_TOLERANCE: f64 = 1e-12;

    /// Returns `true` when `a` and `b` differ by at most [`FLOAT_TOLERANCE`].
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() <= FLOAT_TOLERANCE
    }

    /// Every `RiskFactorType` variant is distinct and survives a JSON round trip.
    #[test]
    fn test_risk_factor_type_variants() {
        let variants = [
            RiskFactorType::RawPointerDereference,
            RiskFactorType::UnsafeDataRace,
            RiskFactorType::InvalidTransmute,
            RiskFactorType::FfiCall,
            RiskFactorType::ManualMemoryManagement,
            RiskFactorType::CrossBoundaryTransfer,
            RiskFactorType::UseAfterFree,
            RiskFactorType::BufferOverflow,
            RiskFactorType::LifetimeViolation,
        ];

        // `RiskFactorType` is `Eq + Hash`, so a set detects duplicates directly.
        let unique: HashSet<&RiskFactorType> = variants.iter().collect();
        assert_eq!(
            unique.len(),
            variants.len(),
            "Variants should be pairwise distinct"
        );

        for variant in &variants {
            let debug_str = format!("{variant:?}");
            assert!(
                !debug_str.is_empty(),
                "Variant should have debug representation"
            );

            let json = serde_json::to_string(variant).expect("Variant should serialize to JSON");
            let restored: RiskFactorType =
                serde_json::from_str(&json).expect("Variant should deserialize from JSON");
            assert_eq!(&restored, variant, "Variant should survive a JSON round trip");
        }
    }

    /// A `RiskFactor` keeps every field it was constructed with.
    #[test]
    fn test_risk_factor_creation() {
        let factor = RiskFactor {
            factor_type: RiskFactorType::BufferOverflow,
            severity: 0.9,
            confidence: 0.85,
            description: "Test buffer overflow".to_string(),
            source_location: Some("test.rs:10".to_string()),
            call_stack: vec![],
            mitigation: "Use bounds checking".to_string(),
        };

        assert_eq!(
            factor.factor_type,
            RiskFactorType::BufferOverflow,
            "Factor type should match"
        );
        assert_eq!(factor.severity, 0.9, "Severity should match");
        assert_eq!(factor.confidence, 0.85, "Confidence should match");
        assert_eq!(
            factor.source_location.as_deref(),
            Some("test.rs:10"),
            "Source location should be present"
        );
        assert!(factor.call_stack.is_empty(), "Call stack should be empty");
        assert_eq!(
            factor.mitigation, "Use bounds checking",
            "Mitigation should match"
        );
    }

    /// A `RiskAssessment` keeps every field it was constructed with.
    #[test]
    fn test_risk_assessment_creation() {
        let assessment = RiskAssessment {
            risk_level: RiskLevel::High,
            risk_score: 75.0,
            risk_factors: vec![],
            confidence_score: 0.9,
            mitigation_suggestions: vec!["Review code".to_string()],
            assessment_timestamp: 1000,
        };

        assert_eq!(
            assessment.risk_level,
            RiskLevel::High,
            "Risk level should be High"
        );
        assert_eq!(assessment.risk_score, 75.0, "Risk score should match");
        assert!(
            assessment.risk_factors.is_empty(),
            "Risk factors should be empty"
        );
        assert_eq!(
            assessment.mitigation_suggestions,
            vec!["Review code".to_string()],
            "Should have one suggestion"
        );
        assert_eq!(
            assessment.assessment_timestamp, 1000,
            "Timestamp should match"
        );
    }

    /// Each `UnsafeSource` variant can be built and matched, and fields are preserved.
    #[test]
    fn test_unsafe_source_variants() {
        let block = UnsafeSource::UnsafeBlock {
            location: "test.rs:10".to_string(),
            function: "test_fn".to_string(),
            file_path: Some("test.rs".to_string()),
            line_number: Some(10),
        };

        let ffi = UnsafeSource::FfiFunction {
            library: "libc".to_string(),
            function: "malloc".to_string(),
            call_site: "test.rs:20".to_string(),
        };

        let raw = UnsafeSource::RawPointer {
            operation: "deref".to_string(),
            location: "0x1000".to_string(),
        };

        let transmute = UnsafeSource::Transmute {
            from_type: "u8".to_string(),
            to_type: "i8".to_string(),
            location: "test.rs:30".to_string(),
        };

        assert!(matches!(block, UnsafeSource::UnsafeBlock { .. }));
        assert!(matches!(ffi, UnsafeSource::FfiFunction { .. }));
        assert!(matches!(raw, UnsafeSource::RawPointer { .. }));
        assert!(matches!(transmute, UnsafeSource::Transmute { .. }));

        match &block {
            UnsafeSource::UnsafeBlock {
                function,
                file_path,
                line_number,
                ..
            } => {
                assert_eq!(function, "test_fn", "Function should match");
                assert_eq!(file_path.as_deref(), Some("test.rs"), "File path should match");
                assert_eq!(*line_number, Some(10), "Line number should match");
            }
            other => panic!("Expected UnsafeBlock, got {other:?}"),
        }
    }

    /// A `DynamicViolation` keeps every field it was constructed with.
    #[test]
    fn test_dynamic_violation_creation() {
        let violation = DynamicViolation {
            violation_type: ViolationType::UseAfterFree,
            memory_address: 0x1000,
            memory_size: 1024,
            detected_at: 1000,
            call_stack: vec![],
            severity: RiskLevel::Critical,
            context: "Use after free detected".to_string(),
        };

        assert_eq!(
            violation.violation_type,
            ViolationType::UseAfterFree,
            "Violation type should match"
        );
        assert_eq!(
            violation.memory_address, 0x1000,
            "Memory address should match"
        );
        assert_eq!(violation.memory_size, 1024, "Memory size should match");
        assert_eq!(violation.detected_at, 1000, "Detection time should match");
        assert_eq!(
            violation.severity,
            RiskLevel::Critical,
            "Severity should be Critical"
        );
    }

    /// Every `ViolationType` variant (including `MemoryLeak`) equals itself and
    /// differs from every other variant.
    #[test]
    fn test_violation_type_equality() {
        let types = [
            ViolationType::DoubleFree,
            ViolationType::UseAfterFree,
            ViolationType::BufferOverflow,
            ViolationType::InvalidAccess,
            ViolationType::DataRace,
            ViolationType::FfiBoundaryViolation,
            ViolationType::MemoryLeak,
        ];

        for (i, a) in types.iter().enumerate() {
            for (j, b) in types.iter().enumerate() {
                if i == j {
                    assert_eq!(a, b, "Same violation types should be equal");
                } else {
                    assert_ne!(a, b, "Different violation types should not be equal");
                }
            }
        }
    }

    /// A `MemoryContext` keeps every field it was constructed with.
    #[test]
    fn test_memory_context_creation() {
        let context = MemoryContext {
            total_allocated: 1024 * 1024,
            active_allocations: 10,
            memory_pressure: MemoryPressureLevel::Medium,
            allocation_patterns: vec![],
        };

        assert_eq!(
            context.total_allocated,
            1024 * 1024,
            "Total allocated should match"
        );
        assert_eq!(
            context.active_allocations, 10,
            "Active allocations should match"
        );
        assert_eq!(
            context.memory_pressure,
            MemoryPressureLevel::Medium,
            "Memory pressure should match"
        );
        assert!(
            context.allocation_patterns.is_empty(),
            "Allocation patterns should be empty"
        );
    }

    /// Every `MemoryPressureLevel` equals itself and differs from the others.
    #[test]
    fn test_memory_pressure_level() {
        let levels = [
            MemoryPressureLevel::Low,
            MemoryPressureLevel::Medium,
            MemoryPressureLevel::High,
            MemoryPressureLevel::Critical,
        ];

        for (i, level) in levels.iter().enumerate() {
            for (j, other) in levels.iter().enumerate() {
                if i == j {
                    assert_eq!(level, other, "Same levels should be equal");
                } else {
                    assert_ne!(level, other, "Different levels should not be equal");
                }
            }
        }
    }

    /// An `AllocationPattern` keeps every field it was constructed with.
    #[test]
    fn test_allocation_pattern_creation() {
        let pattern = AllocationPattern {
            pattern_type: "repeated".to_string(),
            frequency: 100,
            average_size: 256,
            risk_level: RiskLevel::Medium,
        };

        assert_eq!(pattern.pattern_type, "repeated", "Pattern type should match");
        assert_eq!(pattern.frequency, 100, "Frequency should match");
        assert_eq!(pattern.average_size, 256, "Average size should match");
        assert_eq!(
            pattern.risk_level,
            RiskLevel::Medium,
            "Risk level should match"
        );
    }

    /// A `MemoryPassport` keeps every field it was constructed with.
    #[test]
    fn test_memory_passport_creation() {
        let passport = MemoryPassport {
            passport_id: "passport_123".to_string(),
            allocation_ptr: 0x1000,
            size_bytes: 1024,
            status_at_shutdown: PassportStatus::Unknown,
            lifecycle_events: vec![],
            risk_assessment: RiskAssessment {
                risk_level: RiskLevel::Low,
                risk_score: 10.0,
                risk_factors: vec![],
                confidence_score: 0.5,
                mitigation_suggestions: vec![],
                assessment_timestamp: 0,
            },
            created_at: 1000,
            updated_at: 1000,
        };

        assert_eq!(
            passport.passport_id, "passport_123",
            "Passport ID should match"
        );
        assert_eq!(
            passport.allocation_ptr, 0x1000,
            "Allocation pointer should match"
        );
        assert_eq!(passport.size_bytes, 1024, "Size should match");
        assert!(
            matches!(passport.status_at_shutdown, PassportStatus::Unknown),
            "Status should be Unknown"
        );
        assert!(
            passport.lifecycle_events.is_empty(),
            "Lifecycle events should be empty"
        );
        assert_eq!(
            passport.risk_assessment.risk_level,
            RiskLevel::Low,
            "Nested risk level should match"
        );
        assert_eq!(passport.created_at, 1000, "Creation time should match");
        assert_eq!(passport.updated_at, 1000, "Update time should match");
    }

    /// Every `PassportStatus` variant has a distinct, non-empty Debug representation.
    #[test]
    fn test_passport_status_variants() {
        let statuses = [
            PassportStatus::FreedByRust,
            PassportStatus::HandoverToFfi,
            PassportStatus::FreedByForeign,
            PassportStatus::ReclaimedByRust,
            PassportStatus::InForeignCustody,
            PassportStatus::Unknown,
        ];

        let debug_strs: HashSet<String> =
            statuses.iter().map(|status| format!("{status:?}")).collect();

        assert!(
            debug_strs.iter().all(|s| !s.is_empty()),
            "Status should have debug representation"
        );
        assert_eq!(
            debug_strs.len(),
            statuses.len(),
            "Statuses should have distinct debug representations"
        );
    }

    /// A `PassportEvent` keeps every field it was constructed with.
    #[test]
    fn test_passport_event_creation() {
        let event = PassportEvent {
            event_type: PassportEventType::HandoverToFfi,
            timestamp: 1000,
            context: "ffi_transfer".to_string(),
            call_stack: vec![],
            metadata: HashMap::new(),
        };

        assert!(
            matches!(event.event_type, PassportEventType::HandoverToFfi),
            "Event type should match"
        );
        assert_eq!(event.timestamp, 1000, "Timestamp should match");
        assert_eq!(event.context, "ffi_transfer", "Context should match");
        assert!(event.call_stack.is_empty(), "Call stack should be empty");
        assert!(event.metadata.is_empty(), "Metadata should be empty");
    }

    /// Every `PassportEventType` variant has a distinct, non-empty Debug representation.
    #[test]
    fn test_passport_event_type_variants() {
        let event_types = [
            PassportEventType::AllocatedInRust,
            PassportEventType::HandoverToFfi,
            PassportEventType::FreedByForeign,
            PassportEventType::ReclaimedByRust,
            PassportEventType::BoundaryAccess,
            PassportEventType::OwnershipTransfer,
        ];

        let debug_strs: HashSet<String> = event_types
            .iter()
            .map(|event_type| format!("{event_type:?}"))
            .collect();

        assert!(
            debug_strs.iter().all(|s| !s.is_empty()),
            "Event type should have debug representation"
        );
        assert_eq!(
            debug_strs.len(),
            event_types.len(),
            "Event types should have distinct debug representations"
        );
    }

    /// An `UnsafeReport` keeps every field it was constructed with.
    #[test]
    fn test_unsafe_report_creation() {
        let report = UnsafeReport {
            report_id: "UNSAFE-UB-123".to_string(),
            source: UnsafeSource::UnsafeBlock {
                location: "test.rs".to_string(),
                function: "test".to_string(),
                file_path: None,
                line_number: None,
            },
            risk_assessment: RiskAssessment {
                risk_level: RiskLevel::Medium,
                risk_score: 50.0,
                risk_factors: vec![],
                confidence_score: 0.8,
                mitigation_suggestions: vec![],
                assessment_timestamp: 0,
            },
            dynamic_violations: vec![],
            related_passports: vec![],
            memory_context: MemoryContext {
                total_allocated: 0,
                active_allocations: 0,
                memory_pressure: MemoryPressureLevel::Low,
                allocation_patterns: vec![],
            },
            generated_at: 1000,
        };

        assert_eq!(report.report_id, "UNSAFE-UB-123", "Report ID should match");
        assert!(
            matches!(report.source, UnsafeSource::UnsafeBlock { .. }),
            "Source should be an UnsafeBlock"
        );
        assert_eq!(
            report.risk_assessment.risk_level,
            RiskLevel::Medium,
            "Risk level should match"
        );
        assert!(
            report.dynamic_violations.is_empty(),
            "Dynamic violations should be empty"
        );
        assert!(
            report.related_passports.is_empty(),
            "Related passports should be empty"
        );
        assert_eq!(
            report.memory_context.memory_pressure,
            MemoryPressureLevel::Low,
            "Memory pressure should match"
        );
        assert_eq!(
            report.generated_at, 1000,
            "Generated timestamp should match"
        );
    }

    /// Boundary severities/confidences and empty/long strings are stored unchanged.
    #[test]
    fn test_risk_factor_edge_values() {
        let zero_factor = RiskFactor {
            factor_type: RiskFactorType::UseAfterFree,
            severity: 0.0,
            confidence: 0.0,
            description: String::new(),
            source_location: None,
            call_stack: vec![],
            mitigation: String::new(),
        };

        let max_factor = RiskFactor {
            factor_type: RiskFactorType::BufferOverflow,
            severity: 1.0,
            confidence: 1.0,
            description: "x".repeat(1000),
            source_location: Some("x".repeat(1000)),
            call_stack: vec![],
            mitigation: "x".repeat(1000),
        };

        assert_eq!(zero_factor.severity, 0.0, "Zero severity should be valid");
        assert_eq!(zero_factor.confidence, 0.0, "Zero confidence should be valid");
        assert!(
            zero_factor.description.is_empty(),
            "Empty description should be preserved"
        );
        assert!(
            zero_factor.source_location.is_none(),
            "Missing source location should be preserved"
        );

        assert_eq!(max_factor.severity, 1.0, "Max severity should be valid");
        assert_eq!(max_factor.confidence, 1.0, "Max confidence should be valid");
        assert_eq!(
            max_factor.description.len(),
            1000,
            "Long description should be preserved"
        );
        assert_eq!(
            max_factor.source_location.as_ref().map(String::len),
            Some(1000),
            "Long source location should be preserved"
        );
    }

    /// A `RiskAssessment` survives a JSON round trip with all fields intact.
    #[test]
    fn test_serialization() {
        let assessment = RiskAssessment {
            risk_level: RiskLevel::High,
            risk_score: 75.0,
            risk_factors: vec![],
            confidence_score: 0.9,
            mitigation_suggestions: vec!["test".to_string()],
            assessment_timestamp: 1000,
        };

        let json = serde_json::to_string(&assessment).expect("Should serialize to JSON");
        let restored: RiskAssessment =
            serde_json::from_str(&json).expect("Should deserialize from JSON");

        assert_eq!(
            restored.risk_level, assessment.risk_level,
            "Risk level should survive the round trip"
        );
        assert!(
            approx_eq(restored.risk_score, assessment.risk_score),
            "Risk score should survive the round trip"
        );
        assert!(
            restored.risk_factors.is_empty(),
            "Risk factors should survive the round trip"
        );
        assert!(
            approx_eq(restored.confidence_score, assessment.confidence_score),
            "Confidence score should survive the round trip"
        );
        assert_eq!(
            restored.mitigation_suggestions, assessment.mitigation_suggestions,
            "Mitigation suggestions should survive the round trip"
        );
        assert_eq!(
            restored.assessment_timestamp, assessment.assessment_timestamp,
            "Timestamp should survive the round trip"
        );
    }
}