decy_analyzer/
lock_analysis.rs

1//! Lock-to-data binding analysis for pthread synchronization (DECY-077).
2//!
3//! Analyzes C code with pthread_mutex locks to determine which locks
4//! protect which data variables, enabling safe `Mutex<T>` generation.
5
6use decy_hir::{HirExpression, HirFunction, HirStatement};
7use std::collections::{HashMap, HashSet};
8
/// A span of statements bracketed by a lock call and its matching unlock call.
///
/// Both indices are positions in the statement list the region was found in
/// and point at the lock/unlock calls themselves, so the statements strictly
/// between them form the guarded interior.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockRegion {
    /// Name of the mutex variable that guards this region.
    pub lock_name: String,
    /// Index of the statement holding the lock call that opens the region.
    pub start_index: usize,
    /// Index of the statement holding the unlock call that closes the region.
    pub end_index: usize,
}
19
/// Mapping from locks to protected data variables.
///
/// Each key is a lock variable name; the associated set holds the names of
/// the variables recorded as protected by that lock. Query results are
/// returned in sorted order so that callers observe the same output on every
/// run, independent of hash-map iteration order.
#[derive(Debug, Clone, Default)]
pub struct LockDataMapping {
    /// Maps lock name → set of protected variable names
    lock_to_data: HashMap<String, HashSet<String>>,
}

impl LockDataMapping {
    /// Create a new empty mapping.
    pub fn new() -> Self {
        Self::default()
    }

    /// Check if a variable is protected by a specific lock.
    ///
    /// Returns `false` both when `lock` is unknown to the mapping and when
    /// `lock` is known but `data` is not in its protection set.
    pub fn is_protected_by(&self, data: &str, lock: &str) -> bool {
        self.lock_to_data
            .get(lock)
            .is_some_and(|vars| vars.contains(data))
    }

    /// Get all data variables protected by a lock.
    ///
    /// The names are returned sorted in ascending order. An unknown lock
    /// yields an empty vector.
    pub fn get_protected_data(&self, lock: &str) -> Vec<String> {
        let mut names: Vec<String> = self
            .lock_to_data
            .get(lock)
            .map(|vars| vars.iter().cloned().collect())
            .unwrap_or_default();
        // The backing set iterates in arbitrary order; sort for reproducible output.
        names.sort_unstable();
        names
    }

    /// Get all locks tracked in this mapping.
    ///
    /// The lock names are returned sorted in ascending order.
    pub fn get_locks(&self) -> Vec<String> {
        let mut locks: Vec<String> = self.lock_to_data.keys().cloned().collect();
        // The backing map iterates in arbitrary order; sort for reproducible output.
        locks.sort_unstable();
        locks
    }

    /// Add a data variable to a lock's protection set.
    ///
    /// Creates the lock's entry on first use; adding the same variable twice
    /// has no further effect because the entry is a set.
    fn add_protected_data(&mut self, lock: String, data: String) {
        self.lock_to_data.entry(lock).or_default().insert(data);
    }
}
67
68/// Analyzes pthread lock usage and protected data.
69pub struct LockAnalyzer;
70
71impl LockAnalyzer {
72    /// Create a new lock analyzer.
73    pub fn new() -> Self {
74        Self
75    }
76
77    /// Find all locked regions in a function.
78    ///
79    /// Identifies pthread_mutex_lock/unlock pairs and returns
80    /// the code regions they protect.
81    pub fn find_lock_regions(&self, func: &HirFunction) -> Vec<LockRegion> {
82        let mut regions = Vec::new();
83        let body = func.body();
84
85        // Track active locks (lock name -> start index)
86        let mut active_locks: HashMap<String, usize> = HashMap::new();
87
88        for (idx, stmt) in body.iter().enumerate() {
89            // Check for pthread_mutex_lock calls
90            if let Some(lock_name) = Self::extract_lock_call(stmt) {
91                active_locks.insert(lock_name, idx);
92            }
93            // Check for pthread_mutex_unlock calls
94            else if let Some(unlock_name) = Self::extract_unlock_call(stmt) {
95                if let Some(start_idx) = active_locks.remove(&unlock_name) {
96                    regions.push(LockRegion {
97                        lock_name: unlock_name,
98                        start_index: start_idx,
99                        end_index: idx,
100                    });
101                }
102            }
103        }
104
105        regions
106    }
107
108    /// Extract lock name from pthread_mutex_lock call.
109    fn extract_lock_call(stmt: &HirStatement) -> Option<String> {
110        if let HirStatement::Expression(HirExpression::FunctionCall {
111            function,
112            arguments,
113        }) = stmt
114        {
115            if function == "pthread_mutex_lock" {
116                // Extract lock name from &lock argument
117                if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
118                    if let HirExpression::Variable(name) = &**inner {
119                        return Some(name.clone());
120                    }
121                }
122            }
123        }
124        None
125    }
126
127    /// Extract lock name from pthread_mutex_unlock call.
128    fn extract_unlock_call(stmt: &HirStatement) -> Option<String> {
129        if let HirStatement::Expression(HirExpression::FunctionCall {
130            function,
131            arguments,
132        }) = stmt
133        {
134            if function == "pthread_mutex_unlock" {
135                // Extract lock name from &lock argument
136                if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
137                    if let HirExpression::Variable(name) = &**inner {
138                        return Some(name.clone());
139                    }
140                }
141            }
142        }
143        None
144    }
145
146    /// Analyze lock-to-data mapping for a function.
147    ///
148    /// Determines which locks protect which data variables based
149    /// on variable accesses within locked regions.
150    pub fn analyze_lock_data_mapping(&self, func: &HirFunction) -> LockDataMapping {
151        let mut mapping = LockDataMapping::new();
152        let regions = self.find_lock_regions(func);
153        let body = func.body();
154
155        // For each lock region, find all accessed variables
156        for region in regions {
157            let protected_vars = self.find_accessed_variables_in_region(body, &region);
158            for var in protected_vars {
159                mapping.add_protected_data(region.lock_name.clone(), var);
160            }
161        }
162
163        mapping
164    }
165
166    /// Find all variables accessed in a locked region.
167    fn find_accessed_variables_in_region(
168        &self,
169        body: &[HirStatement],
170        region: &LockRegion,
171    ) -> HashSet<String> {
172        let mut accessed = HashSet::new();
173
174        // Scan statements in the region (excluding lock/unlock calls)
175        for idx in (region.start_index + 1)..region.end_index {
176            if let Some(stmt) = body.get(idx) {
177                self.collect_accessed_variables(stmt, &mut accessed);
178            }
179        }
180
181        accessed
182    }
183
184    /// Recursively collect all variable names accessed in a statement.
185    fn collect_accessed_variables(&self, stmt: &HirStatement, accessed: &mut HashSet<String>) {
186        match stmt {
187            HirStatement::Assignment { target, value } => {
188                accessed.insert(target.clone());
189                self.collect_variables_from_expr(value, accessed);
190            }
191            HirStatement::VariableDeclaration {
192                initializer: Some(init),
193                ..
194            } => {
195                // Local variable declarations don't count as protected data
196                // But if the initializer reads from other variables, those count
197                self.collect_variables_from_expr(init, accessed);
198                // Don't add the variable name itself - it's local to this scope
199            }
200            HirStatement::VariableDeclaration {
201                initializer: None, ..
202            } => {
203                // No initializer, nothing to track
204            }
205            HirStatement::Return(Some(e)) => {
206                self.collect_variables_from_expr(e, accessed);
207            }
208            HirStatement::Return(None) => {
209                // No return value, nothing to track
210            }
211            HirStatement::If {
212                condition,
213                then_block,
214                else_block,
215            } => {
216                self.collect_variables_from_expr(condition, accessed);
217                for s in then_block {
218                    self.collect_accessed_variables(s, accessed);
219                }
220                if let Some(else_stmts) = else_block {
221                    for s in else_stmts {
222                        self.collect_accessed_variables(s, accessed);
223                    }
224                }
225            }
226            HirStatement::While { condition, body } => {
227                self.collect_variables_from_expr(condition, accessed);
228                for s in body {
229                    self.collect_accessed_variables(s, accessed);
230                }
231            }
232            HirStatement::Expression(expr) => {
233                self.collect_variables_from_expr(expr, accessed);
234            }
235            HirStatement::DerefAssignment { target, value } => {
236                self.collect_variables_from_expr(target, accessed);
237                self.collect_variables_from_expr(value, accessed);
238            }
239            HirStatement::ArrayIndexAssignment {
240                array,
241                index,
242                value,
243            } => {
244                self.collect_variables_from_expr(array, accessed);
245                self.collect_variables_from_expr(index, accessed);
246                self.collect_variables_from_expr(value, accessed);
247            }
248            HirStatement::FieldAssignment {
249                object,
250                field: _,
251                value,
252            } => {
253                self.collect_variables_from_expr(object, accessed);
254                self.collect_variables_from_expr(value, accessed);
255            }
256            _ => {
257                // Break, Continue, etc. don't access variables
258            }
259        }
260    }
261
262    /// Collect variable names from an expression.
263    #[allow(clippy::only_used_in_recursion)]
264    fn collect_variables_from_expr(&self, expr: &HirExpression, accessed: &mut HashSet<String>) {
265        match expr {
266            HirExpression::Variable(name) => {
267                accessed.insert(name.clone());
268            }
269            HirExpression::BinaryOp { left, right, .. } => {
270                self.collect_variables_from_expr(left, accessed);
271                self.collect_variables_from_expr(right, accessed);
272            }
273            HirExpression::UnaryOp { operand, .. } => {
274                self.collect_variables_from_expr(operand, accessed);
275            }
276            HirExpression::FunctionCall { arguments, .. } => {
277                for arg in arguments {
278                    self.collect_variables_from_expr(arg, accessed);
279                }
280            }
281            HirExpression::AddressOf(inner) | HirExpression::Dereference(inner) => {
282                self.collect_variables_from_expr(inner, accessed);
283            }
284            HirExpression::ArrayIndex { array, index } => {
285                self.collect_variables_from_expr(array, accessed);
286                self.collect_variables_from_expr(index, accessed);
287            }
288            HirExpression::FieldAccess { object, .. } => {
289                self.collect_variables_from_expr(object, accessed);
290            }
291            HirExpression::Cast { expr, .. } => {
292                self.collect_variables_from_expr(expr, accessed);
293            }
294            // Literals and other expressions don't reference variables
295            _ => {}
296        }
297    }
298
299    /// Check for lock discipline violations.
300    ///
301    /// Detects:
302    /// - Locks without unlocks
303    /// - Unlocks without locks
304    /// - Mismatched lock/unlock pairs
305    ///
306    /// Returns a list of violation descriptions.
307    pub fn check_lock_discipline(&self, func: &HirFunction) -> Vec<String> {
308        let mut violations = Vec::new();
309        let body = func.body();
310
311        // Track active locks
312        let mut active_locks: HashMap<String, usize> = HashMap::new();
313
314        for (idx, stmt) in body.iter().enumerate() {
315            // Check for lock calls
316            if let Some(lock_name) = Self::extract_lock_call(stmt) {
317                active_locks.insert(lock_name, idx);
318            }
319            // Check for unlock calls
320            else if let Some(unlock_name) = Self::extract_unlock_call(stmt) {
321                if active_locks.remove(&unlock_name).is_none() {
322                    // Unlock without corresponding lock
323                    violations.push(format!(
324                        "Unlock without lock: pthread_mutex_unlock(&{}) at statement {}",
325                        unlock_name, idx
326                    ));
327                }
328            }
329        }
330
331        // Check for unmatched locks (locks without unlocks)
332        for (lock_name, start_idx) in active_locks {
333            violations.push(format!(
334                "Unmatched lock: pthread_mutex_lock(&{}) at statement {} has no corresponding unlock",
335                lock_name, start_idx
336            ));
337        }
338
339        violations
340    }
341}
342
343impl Default for LockAnalyzer {
344    fn default() -> Self {
345        Self::new()
346    }
347}