1use decy_hir::{HirExpression, HirFunction, HirStatement};
7use std::collections::{HashMap, HashSet};
8
/// A critical section: the top-level statements between a mutex lock call and the
/// matching unlock call on the same mutex name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LockRegion {
    /// Name of the mutex variable whose address is passed to the lock/unlock calls.
    pub lock_name: String,
    /// Index (within the function body) of the `pthread_mutex_lock` statement opening the region.
    pub start_index: usize,
    /// Index (within the function body) of the `pthread_mutex_unlock` statement closing the region.
    pub end_index: usize,
}
19
/// Association from each mutex name to the set of variable names accessed while it is held.
///
/// Query it with [`is_protected_by`](Self::is_protected_by),
/// [`get_protected_data`](Self::get_protected_data) and [`get_locks`](Self::get_locks).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LockDataMapping {
    /// Mutex name -> names of the variables recorded as accessed under that mutex.
    lock_to_data: HashMap<String, HashSet<String>>,
}

impl LockDataMapping {
    /// Creates an empty mapping with no locks recorded.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if `data` was recorded as accessed while `lock` was held.
    ///
    /// An unknown lock or an unknown variable both yield `false`.
    pub fn is_protected_by(&self, data: &str, lock: &str) -> bool {
        self.lock_to_data
            .get(lock)
            .is_some_and(|vars| vars.contains(data))
    }

    /// Returns the variables recorded for `lock`, sorted by name.
    ///
    /// Returns an empty vector when `lock` is unknown. The result is sorted so that it does
    /// not depend on `HashSet` iteration order, which varies from run to run.
    pub fn get_protected_data(&self, lock: &str) -> Vec<String> {
        let mut data: Vec<String> = self
            .lock_to_data
            .get(lock)
            .map(|vars| vars.iter().cloned().collect())
            .unwrap_or_default();
        data.sort_unstable();
        data
    }

    /// Returns every recorded lock name, sorted by name.
    ///
    /// Each returned lock has at least one protected variable, because entries are only
    /// created by [`add_protected_data`](Self::add_protected_data).
    pub fn get_locks(&self) -> Vec<String> {
        let mut locks: Vec<String> = self.lock_to_data.keys().cloned().collect();
        locks.sort_unstable();
        locks
    }

    /// Records that `data` is accessed while `lock` is held (duplicates are ignored).
    fn add_protected_data(&mut self, lock: String, data: String) {
        self.lock_to_data.entry(lock).or_default().insert(data);
    }
}
59
/// Stateless analyzer for pthread mutex usage in HIR functions.
///
/// Finds lock regions, infers which variables each mutex protects, and checks that
/// lock and unlock calls are paired.
#[derive(Debug, Clone, Copy)]
pub struct LockAnalyzer;

63impl LockAnalyzer {
64 pub fn new() -> Self {
66 Self
67 }
68
69 pub fn find_lock_regions(&self, func: &HirFunction) -> Vec<LockRegion> {
74 let mut regions = Vec::new();
75 let body = func.body();
76
77 let mut active_locks: HashMap<String, usize> = HashMap::new();
79
80 for (idx, stmt) in body.iter().enumerate() {
81 if let Some(lock_name) = Self::extract_lock_call(stmt) {
83 active_locks.insert(lock_name, idx);
84 }
85 else if let Some(unlock_name) = Self::extract_unlock_call(stmt) {
87 if let Some(start_idx) = active_locks.remove(&unlock_name) {
88 regions.push(LockRegion {
89 lock_name: unlock_name,
90 start_index: start_idx,
91 end_index: idx,
92 });
93 }
94 }
95 }
96
97 regions
98 }
99
100 fn extract_lock_call(stmt: &HirStatement) -> Option<String> {
102 if let HirStatement::Expression(HirExpression::FunctionCall { function, arguments }) = stmt
103 {
104 if function == "pthread_mutex_lock" {
105 if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
107 if let HirExpression::Variable(name) = &**inner {
108 return Some(name.clone());
109 }
110 }
111 }
112 }
113 None
114 }
115
116 fn extract_unlock_call(stmt: &HirStatement) -> Option<String> {
118 if let HirStatement::Expression(HirExpression::FunctionCall { function, arguments }) = stmt
119 {
120 if function == "pthread_mutex_unlock" {
121 if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
123 if let HirExpression::Variable(name) = &**inner {
124 return Some(name.clone());
125 }
126 }
127 }
128 }
129 None
130 }
131
132 pub fn analyze_lock_data_mapping(&self, func: &HirFunction) -> LockDataMapping {
137 let mut mapping = LockDataMapping::new();
138 let regions = self.find_lock_regions(func);
139 let body = func.body();
140
141 for region in regions {
143 let protected_vars = self.find_accessed_variables_in_region(body, ®ion);
144 for var in protected_vars {
145 mapping.add_protected_data(region.lock_name.clone(), var);
146 }
147 }
148
149 mapping
150 }
151
152 fn find_accessed_variables_in_region(
154 &self,
155 body: &[HirStatement],
156 region: &LockRegion,
157 ) -> HashSet<String> {
158 let mut accessed = HashSet::new();
159
160 for idx in (region.start_index + 1)..region.end_index {
162 if let Some(stmt) = body.get(idx) {
163 self.collect_accessed_variables(stmt, &mut accessed);
164 }
165 }
166
167 accessed
168 }
169
170 fn collect_accessed_variables(&self, stmt: &HirStatement, accessed: &mut HashSet<String>) {
172 match stmt {
173 HirStatement::Assignment { target, value } => {
174 accessed.insert(target.clone());
175 self.collect_variables_from_expr(value, accessed);
176 }
177 HirStatement::VariableDeclaration { initializer: Some(init), .. } => {
178 self.collect_variables_from_expr(init, accessed);
181 }
183 HirStatement::VariableDeclaration { initializer: None, .. } => {
184 }
186 HirStatement::Return(Some(e)) => {
187 self.collect_variables_from_expr(e, accessed);
188 }
189 HirStatement::Return(None) => {
190 }
192 HirStatement::If { condition, then_block, else_block } => {
193 self.collect_variables_from_expr(condition, accessed);
194 for s in then_block {
195 self.collect_accessed_variables(s, accessed);
196 }
197 if let Some(else_stmts) = else_block {
198 for s in else_stmts {
199 self.collect_accessed_variables(s, accessed);
200 }
201 }
202 }
203 HirStatement::While { condition, body } => {
204 self.collect_variables_from_expr(condition, accessed);
205 for s in body {
206 self.collect_accessed_variables(s, accessed);
207 }
208 }
209 HirStatement::Expression(expr) => {
210 self.collect_variables_from_expr(expr, accessed);
211 }
212 HirStatement::DerefAssignment { target, value } => {
213 self.collect_variables_from_expr(target, accessed);
214 self.collect_variables_from_expr(value, accessed);
215 }
216 HirStatement::ArrayIndexAssignment { array, index, value } => {
217 self.collect_variables_from_expr(array, accessed);
218 self.collect_variables_from_expr(index, accessed);
219 self.collect_variables_from_expr(value, accessed);
220 }
221 HirStatement::FieldAssignment { object, field: _, value } => {
222 self.collect_variables_from_expr(object, accessed);
223 self.collect_variables_from_expr(value, accessed);
224 }
225 _ => {
226 }
228 }
229 }
230
231 #[allow(clippy::only_used_in_recursion)]
233 fn collect_variables_from_expr(&self, expr: &HirExpression, accessed: &mut HashSet<String>) {
234 match expr {
235 HirExpression::Variable(name) => {
236 accessed.insert(name.clone());
237 }
238 HirExpression::BinaryOp { left, right, .. } => {
239 self.collect_variables_from_expr(left, accessed);
240 self.collect_variables_from_expr(right, accessed);
241 }
242 HirExpression::UnaryOp { operand, .. } => {
243 self.collect_variables_from_expr(operand, accessed);
244 }
245 HirExpression::FunctionCall { arguments, .. } => {
246 for arg in arguments {
247 self.collect_variables_from_expr(arg, accessed);
248 }
249 }
250 HirExpression::AddressOf(inner) | HirExpression::Dereference(inner) => {
251 self.collect_variables_from_expr(inner, accessed);
252 }
253 HirExpression::ArrayIndex { array, index } => {
254 self.collect_variables_from_expr(array, accessed);
255 self.collect_variables_from_expr(index, accessed);
256 }
257 HirExpression::FieldAccess { object, .. } => {
258 self.collect_variables_from_expr(object, accessed);
259 }
260 HirExpression::Cast { expr, .. } => {
261 self.collect_variables_from_expr(expr, accessed);
262 }
263 _ => {}
265 }
266 }
267
268 pub fn check_lock_discipline(&self, func: &HirFunction) -> Vec<String> {
277 let mut violations = Vec::new();
278 let body = func.body();
279
280 let mut active_locks: HashMap<String, usize> = HashMap::new();
282
283 for (idx, stmt) in body.iter().enumerate() {
284 if let Some(lock_name) = Self::extract_lock_call(stmt) {
286 active_locks.insert(lock_name, idx);
287 }
288 else if let Some(unlock_name) = Self::extract_unlock_call(stmt) {
290 if active_locks.remove(&unlock_name).is_none() {
291 violations.push(format!(
293 "Unlock without lock: pthread_mutex_unlock(&{}) at statement {}",
294 unlock_name, idx
295 ));
296 }
297 }
298 }
299
300 for (lock_name, start_idx) in active_locks {
302 violations.push(format!(
303 "Unmatched lock: pthread_mutex_lock(&{}) at statement {} has no corresponding unlock",
304 lock_name, start_idx
305 ));
306 }
307
308 violations
309 }
310}
311
312impl Default for LockAnalyzer {
313 fn default() -> Self {
314 Self::new()
315 }
316}