padlock_source/
concurrency.rs1use padlock_core::ir::{AccessPattern, StructLayout};
7
8use crate::SourceLanguage;
9
10pub fn annotate_concurrency(layout: &mut StructLayout, language: &SourceLanguage) {
16 for field in &mut layout.fields {
17 let ty_name = match &field.ty {
18 padlock_core::ir::TypeInfo::Primitive { name, .. }
19 | padlock_core::ir::TypeInfo::Opaque { name, .. } => name.clone(),
20 _ => continue,
21 };
22
23 if is_concurrent_type(&ty_name, language) {
24 let is_atomic = is_atomic_type(&ty_name, language);
25 if matches!(field.access, AccessPattern::Unknown) {
26 field.access = AccessPattern::Concurrent {
30 guard: Some(field.name.clone()),
31 is_atomic,
32 };
33 }
34 } else if is_read_mostly_type(&ty_name, language)
35 && matches!(field.access, AccessPattern::Unknown)
36 {
37 field.access = AccessPattern::ReadMostly;
38 }
39 }
40}
41
42pub fn has_concurrent_fields(layout: &StructLayout) -> bool {
44 layout
45 .fields
46 .iter()
47 .any(|f| matches!(f.access, AccessPattern::Concurrent { .. }))
48}
49
50fn is_concurrent_type(name: &str, lang: &SourceLanguage) -> bool {
51 match lang {
52 SourceLanguage::Rust => {
53 name.starts_with("Mutex")
54 || name.starts_with("RwLock")
55 || name.starts_with("Arc")
56 || name.contains("Atomic")
57 || name.starts_with("Condvar")
58 || name.starts_with("Once")
59 }
60 SourceLanguage::C | SourceLanguage::Cpp => {
61 name.contains("mutex")
62 || name.contains("atomic")
63 || name.contains("spinlock")
64 || name.contains("critical_section")
65 || name.contains("pthread_mutex")
66 }
67 SourceLanguage::Go => {
68 name == "sync.Mutex"
69 || name == "sync.RWMutex"
70 || name == "Mutex"
71 || name == "RWMutex"
72 || name.contains("atomic")
73 }
74 SourceLanguage::Zig => {
75 name.contains("Mutex")
76 || name.contains("RwLock")
77 || name.contains("atomic.Value")
78 || name.contains("Atomic")
79 }
80 }
81}
82
83fn is_atomic_type(name: &str, lang: &SourceLanguage) -> bool {
84 match lang {
85 SourceLanguage::Rust => name.contains("Atomic"),
86 SourceLanguage::C | SourceLanguage::Cpp => name.contains("atomic"),
87 SourceLanguage::Go => name.contains("atomic"),
88 SourceLanguage::Zig => name.contains("atomic.Value") || name.contains("Atomic"),
89 }
90}
91
92fn is_read_mostly_type(name: &str, lang: &SourceLanguage) -> bool {
93 match lang {
94 SourceLanguage::Rust => name.starts_with("RwLock"),
95 SourceLanguage::C | SourceLanguage::Cpp => {
96 name.contains("rwlock") || name.contains("shared_mutex")
97 }
98 SourceLanguage::Go => name == "sync.RWMutex" || name == "RWMutex",
99 SourceLanguage::Zig => name.contains("RwLock"),
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::arch::X86_64_SYSV;
    use padlock_core::ir::{AccessPattern, Field, StructLayout, TypeInfo};

    /// Builds an unannotated 8-byte, 8-aligned field at offset 0 whose type is the primitive
    /// named `ty_name`.
    fn field_with_type(name: &str, ty_name: &str) -> Field {
        Field {
            name: name.into(),
            ty: TypeInfo::Primitive {
                name: ty_name.into(),
                size: 8,
                align: 8,
            },
            offset: 0,
            size: 8,
            align: 8,
            source_file: None,
            source_line: None,
            access: AccessPattern::Unknown,
        }
    }

    /// Wraps `fields` in a 64-byte, 8-aligned, non-packed struct `T` laid out for x86-64 SysV.
    fn layout_with_fields(fields: Vec<Field>) -> StructLayout {
        StructLayout {
            name: "T".into(),
            total_size: 64,
            align: 8,
            fields,
            source_file: None,
            source_line: None,
            arch: &X86_64_SYSV,
            is_packed: false,
            is_union: false,
            is_repr_rust: false,
        }
    }

    #[test]
    fn rust_mutex_field_is_annotated() {
        let mut layout = layout_with_fields(vec![field_with_type("counter", "Mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        assert!(matches!(
            layout.fields[0].access,
            AccessPattern::Concurrent { .. }
        ));
    }

    #[test]
    fn rust_mutex_guard_is_field_name_and_not_atomic() {
        let mut layout = layout_with_fields(vec![field_with_type("counter", "Mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        let field = &layout.fields[0];
        if let AccessPattern::Concurrent {
            guard, is_atomic, ..
        } = &field.access
        {
            assert!(guard.as_ref() == Some(&field.name));
            assert!(!*is_atomic);
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn rust_atomic_is_atomic() {
        let mut layout = layout_with_fields(vec![field_with_type("count", "AtomicU64")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        assert!(matches!(
            layout.fields[0].access,
            AccessPattern::Concurrent {
                is_atomic: true,
                ..
            }
        ));
    }

    #[test]
    fn cpp_mutex_annotated() {
        let mut layout = layout_with_fields(vec![field_with_type("mu", "std::mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Cpp);
        assert!(has_concurrent_fields(&layout));
    }

    #[test]
    fn c_rwlock_is_read_mostly() {
        let mut layout = layout_with_fields(vec![field_with_type("lock", "pthread_rwlock_t")]);
        annotate_concurrency(&mut layout, &SourceLanguage::C);
        assert!(matches!(layout.fields[0].access, AccessPattern::ReadMostly));
        assert!(!has_concurrent_fields(&layout));
    }

    #[test]
    fn existing_annotation_is_not_overwritten() {
        let mut field = field_with_type("mu", "Mutex");
        field.access = AccessPattern::ReadMostly;
        let mut layout = layout_with_fields(vec![field]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        assert!(matches!(layout.fields[0].access, AccessPattern::ReadMostly));
    }

    #[test]
    fn go_sync_mutex_annotated() {
        let mut layout = layout_with_fields(vec![field_with_type("mu", "sync.Mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Go);
        assert!(matches!(
            layout.fields[0].access,
            AccessPattern::Concurrent {
                is_atomic: false,
                ..
            }
        ));
    }

    #[test]
    fn zig_atomic_value_is_atomic() {
        let mut layout =
            layout_with_fields(vec![field_with_type("n", "std.atomic.Value(u32)")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Zig);
        assert!(matches!(
            layout.fields[0].access,
            AccessPattern::Concurrent {
                is_atomic: true,
                ..
            }
        ));
    }

    #[test]
    fn unknown_field_stays_unknown() {
        let mut layout = layout_with_fields(vec![field_with_type("x", "int")]);
        annotate_concurrency(&mut layout, &SourceLanguage::C);
        assert!(matches!(layout.fields[0].access, AccessPattern::Unknown));
    }

    #[test]
    fn has_concurrent_fields_false_when_none() {
        let layout = layout_with_fields(vec![field_with_type("x", "int")]);
        assert!(!has_concurrent_fields(&layout));
    }
}