1use decy_hir::{HirExpression, HirFunction, HirStatement};
7use std::collections::{HashMap, HashSet};
8
/// A span of top-level statements bracketed by a matched
/// `pthread_mutex_lock(&m)` / `pthread_mutex_unlock(&m)` pair.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LockRegion {
    /// Name of the mutex variable whose address is passed to the lock/unlock calls.
    pub lock_name: String,
    /// Index, within the function body, of the lock statement.
    pub start_index: usize,
    /// Index, within the function body, of the matching unlock statement.
    pub end_index: usize,
}
19
/// Mapping from each lock (mutex variable name) to the set of variable names
/// accessed while that lock was held.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockDataMapping {
    /// Lock name -> names of the variables accessed inside that lock's regions.
    lock_to_data: HashMap<String, HashSet<String>>,
}

impl LockDataMapping {
    /// Creates an empty mapping.
    pub fn new() -> Self {
        Self {
            lock_to_data: HashMap::new(),
        }
    }

    /// Returns `true` if `data` was recorded as accessed while `lock` was held.
    ///
    /// Unknown locks protect nothing, so they yield `false`.
    pub fn is_protected_by(&self, data: &str, lock: &str) -> bool {
        self.lock_to_data
            .get(lock)
            .is_some_and(|vars| vars.contains(data))
    }

    /// Returns the variables recorded under `lock`, sorted ascending.
    ///
    /// Returns an empty vector for an unknown lock. Sorting makes the result
    /// independent of `HashSet` iteration order, which is unspecified.
    pub fn get_protected_data(&self, lock: &str) -> Vec<String> {
        let mut data: Vec<String> = self
            .lock_to_data
            .get(lock)
            .map(|vars| vars.iter().cloned().collect())
            .unwrap_or_default();
        data.sort_unstable();
        data
    }

    /// Returns every lock that has at least one recorded variable, sorted ascending.
    ///
    /// Sorting makes the result independent of `HashMap` iteration order.
    pub fn get_locks(&self) -> Vec<String> {
        let mut locks: Vec<String> = self.lock_to_data.keys().cloned().collect();
        locks.sort_unstable();
        locks
    }

    /// Records that `data` is accessed while `lock` is held. Duplicates are ignored.
    fn add_protected_data(&mut self, lock: String, data: String) {
        self.lock_to_data.entry(lock).or_default().insert(data);
    }
}

impl Default for LockDataMapping {
    fn default() -> Self {
        Self::new()
    }
}
67
/// Stateless analyzer for `pthread_mutex_lock` / `pthread_mutex_unlock` usage
/// in HIR function bodies.
#[derive(Debug, Clone, Copy)]
pub struct LockAnalyzer;
70
71impl LockAnalyzer {
72 pub fn new() -> Self {
74 Self
75 }
76
77 pub fn find_lock_regions(&self, func: &HirFunction) -> Vec<LockRegion> {
82 let mut regions = Vec::new();
83 let body = func.body();
84
85 let mut active_locks: HashMap<String, usize> = HashMap::new();
87
88 for (idx, stmt) in body.iter().enumerate() {
89 if let Some(lock_name) = Self::extract_lock_call(stmt) {
91 active_locks.insert(lock_name, idx);
92 }
93 else if let Some(unlock_name) = Self::extract_unlock_call(stmt) {
95 if let Some(start_idx) = active_locks.remove(&unlock_name) {
96 regions.push(LockRegion {
97 lock_name: unlock_name,
98 start_index: start_idx,
99 end_index: idx,
100 });
101 }
102 }
103 }
104
105 regions
106 }
107
108 fn extract_lock_call(stmt: &HirStatement) -> Option<String> {
110 if let HirStatement::Expression(HirExpression::FunctionCall {
111 function,
112 arguments,
113 }) = stmt
114 {
115 if function == "pthread_mutex_lock" {
116 if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
118 if let HirExpression::Variable(name) = &**inner {
119 return Some(name.clone());
120 }
121 }
122 }
123 }
124 None
125 }
126
127 fn extract_unlock_call(stmt: &HirStatement) -> Option<String> {
129 if let HirStatement::Expression(HirExpression::FunctionCall {
130 function,
131 arguments,
132 }) = stmt
133 {
134 if function == "pthread_mutex_unlock" {
135 if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
137 if let HirExpression::Variable(name) = &**inner {
138 return Some(name.clone());
139 }
140 }
141 }
142 }
143 None
144 }
145
146 pub fn analyze_lock_data_mapping(&self, func: &HirFunction) -> LockDataMapping {
151 let mut mapping = LockDataMapping::new();
152 let regions = self.find_lock_regions(func);
153 let body = func.body();
154
155 for region in regions {
157 let protected_vars = self.find_accessed_variables_in_region(body, ®ion);
158 for var in protected_vars {
159 mapping.add_protected_data(region.lock_name.clone(), var);
160 }
161 }
162
163 mapping
164 }
165
166 fn find_accessed_variables_in_region(
168 &self,
169 body: &[HirStatement],
170 region: &LockRegion,
171 ) -> HashSet<String> {
172 let mut accessed = HashSet::new();
173
174 for idx in (region.start_index + 1)..region.end_index {
176 if let Some(stmt) = body.get(idx) {
177 self.collect_accessed_variables(stmt, &mut accessed);
178 }
179 }
180
181 accessed
182 }
183
184 fn collect_accessed_variables(&self, stmt: &HirStatement, accessed: &mut HashSet<String>) {
186 match stmt {
187 HirStatement::Assignment { target, value } => {
188 accessed.insert(target.clone());
189 self.collect_variables_from_expr(value, accessed);
190 }
191 HirStatement::VariableDeclaration {
192 initializer: Some(init),
193 ..
194 } => {
195 self.collect_variables_from_expr(init, accessed);
198 }
200 HirStatement::VariableDeclaration {
201 initializer: None, ..
202 } => {
203 }
205 HirStatement::Return(Some(e)) => {
206 self.collect_variables_from_expr(e, accessed);
207 }
208 HirStatement::Return(None) => {
209 }
211 HirStatement::If {
212 condition,
213 then_block,
214 else_block,
215 } => {
216 self.collect_variables_from_expr(condition, accessed);
217 for s in then_block {
218 self.collect_accessed_variables(s, accessed);
219 }
220 if let Some(else_stmts) = else_block {
221 for s in else_stmts {
222 self.collect_accessed_variables(s, accessed);
223 }
224 }
225 }
226 HirStatement::While { condition, body } => {
227 self.collect_variables_from_expr(condition, accessed);
228 for s in body {
229 self.collect_accessed_variables(s, accessed);
230 }
231 }
232 HirStatement::Expression(expr) => {
233 self.collect_variables_from_expr(expr, accessed);
234 }
235 HirStatement::DerefAssignment { target, value } => {
236 self.collect_variables_from_expr(target, accessed);
237 self.collect_variables_from_expr(value, accessed);
238 }
239 HirStatement::ArrayIndexAssignment {
240 array,
241 index,
242 value,
243 } => {
244 self.collect_variables_from_expr(array, accessed);
245 self.collect_variables_from_expr(index, accessed);
246 self.collect_variables_from_expr(value, accessed);
247 }
248 HirStatement::FieldAssignment {
249 object,
250 field: _,
251 value,
252 } => {
253 self.collect_variables_from_expr(object, accessed);
254 self.collect_variables_from_expr(value, accessed);
255 }
256 _ => {
257 }
259 }
260 }
261
262 #[allow(clippy::only_used_in_recursion)]
264 fn collect_variables_from_expr(&self, expr: &HirExpression, accessed: &mut HashSet<String>) {
265 match expr {
266 HirExpression::Variable(name) => {
267 accessed.insert(name.clone());
268 }
269 HirExpression::BinaryOp { left, right, .. } => {
270 self.collect_variables_from_expr(left, accessed);
271 self.collect_variables_from_expr(right, accessed);
272 }
273 HirExpression::UnaryOp { operand, .. } => {
274 self.collect_variables_from_expr(operand, accessed);
275 }
276 HirExpression::FunctionCall { arguments, .. } => {
277 for arg in arguments {
278 self.collect_variables_from_expr(arg, accessed);
279 }
280 }
281 HirExpression::AddressOf(inner) | HirExpression::Dereference(inner) => {
282 self.collect_variables_from_expr(inner, accessed);
283 }
284 HirExpression::ArrayIndex { array, index } => {
285 self.collect_variables_from_expr(array, accessed);
286 self.collect_variables_from_expr(index, accessed);
287 }
288 HirExpression::FieldAccess { object, .. } => {
289 self.collect_variables_from_expr(object, accessed);
290 }
291 HirExpression::Cast { expr, .. } => {
292 self.collect_variables_from_expr(expr, accessed);
293 }
294 _ => {}
296 }
297 }
298
299 pub fn check_lock_discipline(&self, func: &HirFunction) -> Vec<String> {
308 let mut violations = Vec::new();
309 let body = func.body();
310
311 let mut active_locks: HashMap<String, usize> = HashMap::new();
313
314 for (idx, stmt) in body.iter().enumerate() {
315 if let Some(lock_name) = Self::extract_lock_call(stmt) {
317 active_locks.insert(lock_name, idx);
318 }
319 else if let Some(unlock_name) = Self::extract_unlock_call(stmt) {
321 if active_locks.remove(&unlock_name).is_none() {
322 violations.push(format!(
324 "Unlock without lock: pthread_mutex_unlock(&{}) at statement {}",
325 unlock_name, idx
326 ));
327 }
328 }
329 }
330
331 for (lock_name, start_idx) in active_locks {
333 violations.push(format!(
334 "Unmatched lock: pthread_mutex_lock(&{}) at statement {} has no corresponding unlock",
335 lock_name, start_idx
336 ));
337 }
338
339 violations
340 }
341}
342
343impl Default for LockAnalyzer {
344 fn default() -> Self {
345 Self::new()
346 }
347}