padlock_source/
concurrency.rs1use padlock_core::ir::{AccessPattern, StructLayout};
7
8use crate::SourceLanguage;
9
10pub fn annotate_concurrency(layout: &mut StructLayout, language: &SourceLanguage) {
16 for field in &mut layout.fields {
17 let ty_name = match &field.ty {
18 padlock_core::ir::TypeInfo::Primitive { name, .. }
19 | padlock_core::ir::TypeInfo::Opaque { name, .. } => name.clone(),
20 _ => continue,
21 };
22
23 if is_concurrent_type(&ty_name, language) {
24 let is_atomic = is_atomic_type(&ty_name, language);
25 if matches!(field.access, AccessPattern::Unknown) {
26 field.access = AccessPattern::Concurrent {
30 guard: Some(field.name.clone()),
31 is_atomic,
32 is_annotated: false,
33 };
34 }
35 } else if is_read_mostly_type(&ty_name, language)
36 && matches!(field.access, AccessPattern::Unknown)
37 {
38 field.access = AccessPattern::ReadMostly;
39 }
40 }
41}
42
43pub fn has_concurrent_fields(layout: &StructLayout) -> bool {
45 layout
46 .fields
47 .iter()
48 .any(|f| matches!(f.access, AccessPattern::Concurrent { .. }))
49}
50
51fn is_concurrent_type(name: &str, lang: &SourceLanguage) -> bool {
52 match lang {
53 SourceLanguage::Rust => {
54 name.starts_with("Mutex")
55 || name.starts_with("RwLock")
56 || name.starts_with("Arc")
57 || name.contains("Atomic")
58 || name.starts_with("Condvar")
59 || name.starts_with("Once")
60 }
61 SourceLanguage::C | SourceLanguage::Cpp => {
62 name.contains("mutex")
63 || name.contains("atomic")
64 || name.contains("spinlock")
65 || name.contains("critical_section")
66 || name.contains("pthread_mutex")
67 }
68 SourceLanguage::Go => {
69 name == "sync.Mutex"
70 || name == "sync.RWMutex"
71 || name == "Mutex"
72 || name == "RWMutex"
73 || name.contains("atomic")
74 }
75 SourceLanguage::Zig => {
76 name.contains("Mutex")
77 || name.contains("RwLock")
78 || name.contains("atomic.Value")
79 || name.contains("Atomic")
80 }
81 }
82}
83
84fn is_atomic_type(name: &str, lang: &SourceLanguage) -> bool {
85 match lang {
86 SourceLanguage::Rust => name.contains("Atomic"),
87 SourceLanguage::C | SourceLanguage::Cpp => name.contains("atomic"),
88 SourceLanguage::Go => name.contains("atomic"),
89 SourceLanguage::Zig => name.contains("atomic.Value") || name.contains("Atomic"),
90 }
91}
92
93fn is_read_mostly_type(name: &str, lang: &SourceLanguage) -> bool {
94 match lang {
95 SourceLanguage::Rust => name.starts_with("RwLock"),
96 SourceLanguage::C | SourceLanguage::Cpp => {
97 name.contains("rwlock") || name.contains("shared_mutex")
98 }
99 SourceLanguage::Go => name == "sync.RWMutex" || name == "RWMutex",
100 SourceLanguage::Zig => name.contains("RwLock"),
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::arch::X86_64_SYSV;
    use padlock_core::ir::{AccessPattern, Field, StructLayout, TypeInfo};

    /// Builds an 8-byte, 8-aligned primitive field at offset 0 whose access pattern
    /// is still `Unknown`.
    fn field_with_type(name: &str, ty_name: &str) -> Field {
        Field {
            name: name.into(),
            ty: TypeInfo::Primitive {
                name: ty_name.into(),
                size: 8,
                align: 8,
            },
            offset: 0,
            size: 8,
            align: 8,
            source_file: None,
            source_line: None,
            access: AccessPattern::Unknown,
        }
    }

    /// Wraps `fields` in a 64-byte, 8-aligned x86-64 SysV struct named `T`.
    fn layout_with_fields(fields: Vec<Field>) -> StructLayout {
        StructLayout {
            name: "T".into(),
            total_size: 64,
            align: 8,
            fields,
            source_file: None,
            source_line: None,
            arch: &X86_64_SYSV,
            is_packed: false,
            is_union: false,
            is_repr_rust: false,
            suppressed_findings: Vec::new(),
        }
    }

    #[test]
    fn rust_mutex_field_is_annotated() {
        let mut layout = layout_with_fields(vec![field_with_type("counter", "Mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        assert!(matches!(
            layout.fields[0].access,
            AccessPattern::Concurrent { .. }
        ));
    }

    #[test]
    fn rust_mutex_guard_is_own_name_and_not_atomic() {
        let mut layout = layout_with_fields(vec![field_with_type("counter", "Mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        match &layout.fields[0].access {
            AccessPattern::Concurrent {
                guard,
                is_atomic,
                is_annotated,
                ..
            } => {
                assert_eq!(guard.as_deref(), Some("counter"));
                assert!(!*is_atomic);
                assert!(!*is_annotated);
            }
            _ => panic!("expected Concurrent"),
        }
    }

    #[test]
    fn rust_atomic_is_atomic() {
        let mut layout = layout_with_fields(vec![field_with_type("count", "AtomicU64")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        if let AccessPattern::Concurrent { is_atomic, .. } = &layout.fields[0].access {
            assert!(*is_atomic);
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn existing_access_pattern_is_preserved() {
        // A field that already has a pattern must not be reclassified, even when its
        // type would otherwise be recognised as concurrent.
        let mut field = field_with_type("mu", "Mutex");
        field.access = AccessPattern::ReadMostly;
        let mut layout = layout_with_fields(vec![field]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        assert!(matches!(layout.fields[0].access, AccessPattern::ReadMostly));
    }

    #[test]
    fn cpp_mutex_annotated() {
        let mut layout = layout_with_fields(vec![field_with_type("mu", "std::mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Cpp);
        assert!(has_concurrent_fields(&layout));
    }

    #[test]
    fn c_rwlock_is_read_mostly() {
        let mut layout = layout_with_fields(vec![field_with_type("lock", "pthread_rwlock_t")]);
        annotate_concurrency(&mut layout, &SourceLanguage::C);
        assert!(matches!(layout.fields[0].access, AccessPattern::ReadMostly));
        assert!(!has_concurrent_fields(&layout));
    }

    #[test]
    fn go_sync_mutex_is_concurrent() {
        let mut layout = layout_with_fields(vec![field_with_type("mu", "sync.Mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Go);
        assert!(has_concurrent_fields(&layout));
    }

    #[test]
    fn unknown_field_stays_unknown() {
        let mut layout = layout_with_fields(vec![field_with_type("x", "int")]);
        annotate_concurrency(&mut layout, &SourceLanguage::C);
        assert!(matches!(layout.fields[0].access, AccessPattern::Unknown));
    }

    #[test]
    fn has_concurrent_fields_false_when_none() {
        let layout = layout_with_fields(vec![field_with_type("x", "int")]);
        assert!(!has_concurrent_fields(&layout));
    }
}