padlock_source/
concurrency.rs1use padlock_core::ir::{AccessPattern, StructLayout};
7
8use crate::SourceLanguage;
9
10pub fn annotate_concurrency(layout: &mut StructLayout, language: &SourceLanguage) {
16 for field in &mut layout.fields {
17 let ty_name = match &field.ty {
18 padlock_core::ir::TypeInfo::Primitive { name, .. }
19 | padlock_core::ir::TypeInfo::Opaque { name, .. } => name.clone(),
20 _ => continue,
21 };
22
23 if is_concurrent_type(&ty_name, language) {
24 let is_atomic = is_atomic_type(&ty_name, language);
25 if matches!(field.access, AccessPattern::Unknown) {
26 field.access = AccessPattern::Concurrent {
30 guard: Some(field.name.clone()),
31 is_atomic,
32 };
33 }
34 } else if is_read_mostly_type(&ty_name, language)
35 && matches!(field.access, AccessPattern::Unknown)
36 {
37 field.access = AccessPattern::ReadMostly;
38 }
39 }
40}
41
42pub fn has_concurrent_fields(layout: &StructLayout) -> bool {
44 layout
45 .fields
46 .iter()
47 .any(|f| matches!(f.access, AccessPattern::Concurrent { .. }))
48}
49
50fn is_concurrent_type(name: &str, lang: &SourceLanguage) -> bool {
51 match lang {
52 SourceLanguage::Rust => {
53 name.starts_with("Mutex")
54 || name.starts_with("RwLock")
55 || name.starts_with("Arc")
56 || name.contains("Atomic")
57 || name.starts_with("Condvar")
58 || name.starts_with("Once")
59 }
60 SourceLanguage::C | SourceLanguage::Cpp => {
61 name.contains("mutex")
62 || name.contains("atomic")
63 || name.contains("spinlock")
64 || name.contains("critical_section")
65 || name.contains("pthread_mutex")
66 }
67 SourceLanguage::Go => {
68 name == "sync.Mutex"
69 || name == "sync.RWMutex"
70 || name == "Mutex"
71 || name == "RWMutex"
72 || name.contains("atomic")
73 }
74 }
75}
76
77fn is_atomic_type(name: &str, lang: &SourceLanguage) -> bool {
78 match lang {
79 SourceLanguage::Rust => name.contains("Atomic"),
80 SourceLanguage::C | SourceLanguage::Cpp => name.contains("atomic"),
81 SourceLanguage::Go => name.contains("atomic"),
82 }
83}
84
85fn is_read_mostly_type(name: &str, lang: &SourceLanguage) -> bool {
86 match lang {
87 SourceLanguage::Rust => name.starts_with("RwLock"),
88 SourceLanguage::C | SourceLanguage::Cpp => {
89 name.contains("rwlock") || name.contains("shared_mutex")
90 }
91 SourceLanguage::Go => name == "sync.RWMutex" || name == "RWMutex",
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::arch::X86_64_SYSV;
    use padlock_core::ir::{AccessPattern, Field, StructLayout, TypeInfo};

    /// Builds an 8-byte, 8-aligned field at offset 0 whose type is the
    /// primitive `ty_name` and whose access pattern is still `Unknown`.
    fn field_with_type(name: &str, ty_name: &str) -> Field {
        Field {
            name: name.into(),
            ty: TypeInfo::Primitive {
                name: ty_name.into(),
                size: 8,
                align: 8,
            },
            offset: 0,
            size: 8,
            align: 8,
            source_file: None,
            source_line: None,
            access: AccessPattern::Unknown,
        }
    }

    /// Wraps `fields` in a 64-byte, 8-aligned, non-packed struct layout named
    /// `T` using the `X86_64_SYSV` arch descriptor.
    fn layout_with_fields(fields: Vec<Field>) -> StructLayout {
        StructLayout {
            name: "T".into(),
            total_size: 64,
            align: 8,
            fields,
            source_file: None,
            source_line: None,
            arch: &X86_64_SYSV,
            is_packed: false,
            is_union: false,
        }
    }

    #[test]
    fn rust_mutex_field_is_annotated() {
        let mut layout = layout_with_fields(vec![field_with_type("counter", "Mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        assert!(matches!(
            layout.fields[0].access,
            AccessPattern::Concurrent { .. }
        ));
    }

    #[test]
    fn rust_mutex_is_not_atomic() {
        let mut layout = layout_with_fields(vec![field_with_type("counter", "Mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        if let AccessPattern::Concurrent { is_atomic, .. } = &layout.fields[0].access {
            assert!(!*is_atomic);
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn rust_atomic_is_atomic() {
        let mut layout = layout_with_fields(vec![field_with_type("count", "AtomicU64")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        if let AccessPattern::Concurrent { is_atomic, .. } = &layout.fields[0].access {
            assert!(is_atomic);
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn concurrent_field_is_its_own_guard() {
        let mut layout = layout_with_fields(vec![field_with_type("count", "AtomicU64")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        if let AccessPattern::Concurrent { guard, .. } = &layout.fields[0].access {
            assert_eq!(guard.as_deref(), Some("count"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn cpp_mutex_annotated() {
        let mut layout = layout_with_fields(vec![field_with_type("mu", "std::mutex")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Cpp);
        assert!(has_concurrent_fields(&layout));
    }

    #[test]
    fn go_atomic_is_atomic() {
        let mut layout = layout_with_fields(vec![field_with_type("n", "atomic.Int64")]);
        annotate_concurrency(&mut layout, &SourceLanguage::Go);
        if let AccessPattern::Concurrent { is_atomic, .. } = &layout.fields[0].access {
            assert!(*is_atomic);
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn c_rwlock_is_read_mostly() {
        let mut layout = layout_with_fields(vec![field_with_type("lock", "pthread_rwlock_t")]);
        annotate_concurrency(&mut layout, &SourceLanguage::C);
        assert!(matches!(layout.fields[0].access, AccessPattern::ReadMostly));
        assert!(!has_concurrent_fields(&layout));
    }

    #[test]
    fn existing_access_pattern_is_not_overwritten() {
        let mut field = field_with_type("mu", "Mutex");
        field.access = AccessPattern::ReadMostly;
        let mut layout = layout_with_fields(vec![field]);
        annotate_concurrency(&mut layout, &SourceLanguage::Rust);
        assert!(matches!(layout.fields[0].access, AccessPattern::ReadMostly));
    }

    #[test]
    fn unknown_field_stays_unknown() {
        let mut layout = layout_with_fields(vec![field_with_type("x", "int")]);
        annotate_concurrency(&mut layout, &SourceLanguage::C);
        assert!(matches!(layout.fields[0].access, AccessPattern::Unknown));
    }

    #[test]
    fn has_concurrent_fields_false_when_none() {
        let layout = layout_with_fields(vec![field_with_type("x", "int")]);
        assert!(!has_concurrent_fields(&layout));
    }
}