Skip to main content

decy_analyzer/
lock_analysis.rs

1//! Lock-to-data binding analysis for pthread synchronization (DECY-077).
2//!
3//! Analyzes C code with pthread_mutex locks to determine which locks
4//! protect which data variables, enabling safe `Mutex<T>` generation.
5
6use decy_hir::{HirExpression, HirFunction, HirStatement};
7use std::collections::{HashMap, HashSet};
8
/// Represents a locked code region.
///
/// A region pairs a `pthread_mutex_lock(&lock)` statement with the later
/// `pthread_mutex_unlock(&lock)` statement on the same lock variable, as
/// produced by `LockAnalyzer::find_lock_regions`. Both indices are positions
/// in the function body's statement list; the statements strictly between
/// them are the ones treated as running under the lock.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockRegion {
    /// Name of the lock variable
    pub lock_name: String,
    /// Starting statement index (lock call)
    pub start_index: usize,
    /// Ending statement index (unlock call)
    pub end_index: usize,
}
19
/// Mapping from locks to protected data variables.
///
/// Each lock name is associated with the set of variable names that were
/// recorded as protected by it. Query results are returned in sorted order so
/// that callers (and anything generated from them) see the same output on
/// every run, independent of hash-map iteration order.
#[derive(Debug, Clone, Default)]
pub struct LockDataMapping {
    /// Maps lock name → set of protected variable names
    lock_to_data: HashMap<String, HashSet<String>>,
}

impl LockDataMapping {
    /// Create a new empty mapping.
    pub fn new() -> Self {
        Self::default()
    }

    /// Check if a variable is protected by a specific lock.
    ///
    /// Returns `false` when the lock is unknown or when `data` has not been
    /// recorded under it.
    pub fn is_protected_by(&self, data: &str, lock: &str) -> bool {
        self.lock_to_data.get(lock).map_or(false, |vars| vars.contains(data))
    }

    /// Get all data variables protected by a lock.
    ///
    /// The names are returned sorted lexicographically. An unknown lock
    /// yields an empty vector.
    pub fn get_protected_data(&self, lock: &str) -> Vec<String> {
        let mut names: Vec<String> = self
            .lock_to_data
            .get(lock)
            .map(|vars| vars.iter().cloned().collect())
            .unwrap_or_default();
        // The backing set has no stable iteration order; sort for determinism.
        names.sort_unstable();
        names
    }

    /// Get all locks tracked in this mapping.
    ///
    /// The lock names are returned sorted lexicographically.
    pub fn get_locks(&self) -> Vec<String> {
        let mut locks: Vec<String> = self.lock_to_data.keys().cloned().collect();
        // The backing map has no stable iteration order; sort for determinism.
        locks.sort_unstable();
        locks
    }

    /// Add a data variable to a lock's protection set.
    ///
    /// Creates the lock's entry on first use; adding the same variable twice
    /// has no further effect.
    fn add_protected_data(&mut self, lock: String, data: String) {
        self.lock_to_data.entry(lock).or_default().insert(data);
    }
}
59
/// Analyzes pthread lock usage and protected data.
///
/// The analyzer is a stateless unit type: every analysis takes the function
/// to inspect as an argument, so a single value can be reused or copied
/// freely.
#[derive(Debug, Clone, Copy)]
pub struct LockAnalyzer;
62
63impl LockAnalyzer {
64    /// Create a new lock analyzer.
65    pub fn new() -> Self {
66        Self
67    }
68
69    /// Find all locked regions in a function.
70    ///
71    /// Identifies pthread_mutex_lock/unlock pairs and returns
72    /// the code regions they protect.
73    pub fn find_lock_regions(&self, func: &HirFunction) -> Vec<LockRegion> {
74        let mut regions = Vec::new();
75        let body = func.body();
76
77        // Track active locks (lock name -> start index)
78        let mut active_locks: HashMap<String, usize> = HashMap::new();
79
80        for (idx, stmt) in body.iter().enumerate() {
81            // Check for pthread_mutex_lock calls
82            if let Some(lock_name) = Self::extract_lock_call(stmt) {
83                active_locks.insert(lock_name, idx);
84            }
85            // Check for pthread_mutex_unlock calls
86            else if let Some(unlock_name) = Self::extract_unlock_call(stmt) {
87                if let Some(start_idx) = active_locks.remove(&unlock_name) {
88                    regions.push(LockRegion {
89                        lock_name: unlock_name,
90                        start_index: start_idx,
91                        end_index: idx,
92                    });
93                }
94            }
95        }
96
97        regions
98    }
99
100    /// Extract lock name from pthread_mutex_lock call.
101    fn extract_lock_call(stmt: &HirStatement) -> Option<String> {
102        if let HirStatement::Expression(HirExpression::FunctionCall { function, arguments }) = stmt
103        {
104            if function == "pthread_mutex_lock" {
105                // Extract lock name from &lock argument
106                if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
107                    if let HirExpression::Variable(name) = &**inner {
108                        return Some(name.clone());
109                    }
110                }
111            }
112        }
113        None
114    }
115
116    /// Extract lock name from pthread_mutex_unlock call.
117    fn extract_unlock_call(stmt: &HirStatement) -> Option<String> {
118        if let HirStatement::Expression(HirExpression::FunctionCall { function, arguments }) = stmt
119        {
120            if function == "pthread_mutex_unlock" {
121                // Extract lock name from &lock argument
122                if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
123                    if let HirExpression::Variable(name) = &**inner {
124                        return Some(name.clone());
125                    }
126                }
127            }
128        }
129        None
130    }
131
132    /// Analyze lock-to-data mapping for a function.
133    ///
134    /// Determines which locks protect which data variables based
135    /// on variable accesses within locked regions.
136    pub fn analyze_lock_data_mapping(&self, func: &HirFunction) -> LockDataMapping {
137        let mut mapping = LockDataMapping::new();
138        let regions = self.find_lock_regions(func);
139        let body = func.body();
140
141        // For each lock region, find all accessed variables
142        for region in regions {
143            let protected_vars = self.find_accessed_variables_in_region(body, &region);
144            for var in protected_vars {
145                mapping.add_protected_data(region.lock_name.clone(), var);
146            }
147        }
148
149        mapping
150    }
151
152    /// Find all variables accessed in a locked region.
153    fn find_accessed_variables_in_region(
154        &self,
155        body: &[HirStatement],
156        region: &LockRegion,
157    ) -> HashSet<String> {
158        let mut accessed = HashSet::new();
159
160        // Scan statements in the region (excluding lock/unlock calls)
161        for idx in (region.start_index + 1)..region.end_index {
162            if let Some(stmt) = body.get(idx) {
163                self.collect_accessed_variables(stmt, &mut accessed);
164            }
165        }
166
167        accessed
168    }
169
170    /// Recursively collect all variable names accessed in a statement.
171    fn collect_accessed_variables(&self, stmt: &HirStatement, accessed: &mut HashSet<String>) {
172        match stmt {
173            HirStatement::Assignment { target, value } => {
174                accessed.insert(target.clone());
175                self.collect_variables_from_expr(value, accessed);
176            }
177            HirStatement::VariableDeclaration { initializer: Some(init), .. } => {
178                // Local variable declarations don't count as protected data
179                // But if the initializer reads from other variables, those count
180                self.collect_variables_from_expr(init, accessed);
181                // Don't add the variable name itself - it's local to this scope
182            }
183            HirStatement::VariableDeclaration { initializer: None, .. } => {
184                // No initializer, nothing to track
185            }
186            HirStatement::Return(Some(e)) => {
187                self.collect_variables_from_expr(e, accessed);
188            }
189            HirStatement::Return(None) => {
190                // No return value, nothing to track
191            }
192            HirStatement::If { condition, then_block, else_block } => {
193                self.collect_variables_from_expr(condition, accessed);
194                for s in then_block {
195                    self.collect_accessed_variables(s, accessed);
196                }
197                if let Some(else_stmts) = else_block {
198                    for s in else_stmts {
199                        self.collect_accessed_variables(s, accessed);
200                    }
201                }
202            }
203            HirStatement::While { condition, body } => {
204                self.collect_variables_from_expr(condition, accessed);
205                for s in body {
206                    self.collect_accessed_variables(s, accessed);
207                }
208            }
209            HirStatement::Expression(expr) => {
210                self.collect_variables_from_expr(expr, accessed);
211            }
212            HirStatement::DerefAssignment { target, value } => {
213                self.collect_variables_from_expr(target, accessed);
214                self.collect_variables_from_expr(value, accessed);
215            }
216            HirStatement::ArrayIndexAssignment { array, index, value } => {
217                self.collect_variables_from_expr(array, accessed);
218                self.collect_variables_from_expr(index, accessed);
219                self.collect_variables_from_expr(value, accessed);
220            }
221            HirStatement::FieldAssignment { object, field: _, value } => {
222                self.collect_variables_from_expr(object, accessed);
223                self.collect_variables_from_expr(value, accessed);
224            }
225            _ => {
226                // Break, Continue, etc. don't access variables
227            }
228        }
229    }
230
231    /// Collect variable names from an expression.
232    #[allow(clippy::only_used_in_recursion)]
233    fn collect_variables_from_expr(&self, expr: &HirExpression, accessed: &mut HashSet<String>) {
234        match expr {
235            HirExpression::Variable(name) => {
236                accessed.insert(name.clone());
237            }
238            HirExpression::BinaryOp { left, right, .. } => {
239                self.collect_variables_from_expr(left, accessed);
240                self.collect_variables_from_expr(right, accessed);
241            }
242            HirExpression::UnaryOp { operand, .. } => {
243                self.collect_variables_from_expr(operand, accessed);
244            }
245            HirExpression::FunctionCall { arguments, .. } => {
246                for arg in arguments {
247                    self.collect_variables_from_expr(arg, accessed);
248                }
249            }
250            HirExpression::AddressOf(inner) | HirExpression::Dereference(inner) => {
251                self.collect_variables_from_expr(inner, accessed);
252            }
253            HirExpression::ArrayIndex { array, index } => {
254                self.collect_variables_from_expr(array, accessed);
255                self.collect_variables_from_expr(index, accessed);
256            }
257            HirExpression::FieldAccess { object, .. } => {
258                self.collect_variables_from_expr(object, accessed);
259            }
260            HirExpression::Cast { expr, .. } => {
261                self.collect_variables_from_expr(expr, accessed);
262            }
263            // Literals and other expressions don't reference variables
264            _ => {}
265        }
266    }
267
268    /// Check for lock discipline violations.
269    ///
270    /// Detects:
271    /// - Locks without unlocks
272    /// - Unlocks without locks
273    /// - Mismatched lock/unlock pairs
274    ///
275    /// Returns a list of violation descriptions.
276    pub fn check_lock_discipline(&self, func: &HirFunction) -> Vec<String> {
277        let mut violations = Vec::new();
278        let body = func.body();
279
280        // Track active locks
281        let mut active_locks: HashMap<String, usize> = HashMap::new();
282
283        for (idx, stmt) in body.iter().enumerate() {
284            // Check for lock calls
285            if let Some(lock_name) = Self::extract_lock_call(stmt) {
286                active_locks.insert(lock_name, idx);
287            }
288            // Check for unlock calls
289            else if let Some(unlock_name) = Self::extract_unlock_call(stmt) {
290                if active_locks.remove(&unlock_name).is_none() {
291                    // Unlock without corresponding lock
292                    violations.push(format!(
293                        "Unlock without lock: pthread_mutex_unlock(&{}) at statement {}",
294                        unlock_name, idx
295                    ));
296                }
297            }
298        }
299
300        // Check for unmatched locks (locks without unlocks)
301        for (lock_name, start_idx) in active_locks {
302            violations.push(format!(
303                "Unmatched lock: pthread_mutex_lock(&{}) at statement {} has no corresponding unlock",
304                lock_name, start_idx
305            ));
306        }
307
308        violations
309    }
310}
311
312impl Default for LockAnalyzer {
313    fn default() -> Self {
314        Self::new()
315    }
316}