Skip to main content

decy_codegen/
concurrency_transform.rs

//! Concurrency transformation module for pthread → Rust `std::sync` conversions.
//!
//! Transforms C pthread synchronization primitives to safe Rust equivalents:
//! - `pthread_mutex_t` + data → `Mutex<T>`
//! - `pthread_mutex_lock`/`pthread_mutex_unlock` → `.lock().unwrap()` with RAII
//!
//! Part of DECY-078: Transform `pthread_mutex` to `Mutex<T>`
8
9use decy_hir::{HirExpression, HirFunction, HirStatement};
10
11/// Detects if a function call is a pthread mutex lock operation.
12///
13/// Recognizes patterns:
14/// - pthread_mutex_lock(&mutex)
15/// - pthread_mutex_lock(&ptr->mutex)
16///
17/// Returns the name of the mutex variable if detected, None otherwise.
18pub fn is_pthread_lock(stmt: &HirStatement) -> Option<String> {
19    if let HirStatement::Expression(HirExpression::FunctionCall { function, arguments }) = stmt {
20        if function == "pthread_mutex_lock" && !arguments.is_empty() {
21            // Extract mutex name from &mutex or &ptr->field
22            if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
23                return extract_variable_name(inner);
24            }
25        }
26    }
27    None
28}
29
30/// Detects if a function call is a pthread mutex unlock operation.
31///
32/// Recognizes patterns:
33/// - pthread_mutex_unlock(&mutex)
34/// - pthread_mutex_unlock(&ptr->mutex)
35///
36/// Returns the name of the mutex variable if detected, None otherwise.
37pub fn is_pthread_unlock(stmt: &HirStatement) -> Option<String> {
38    if let HirStatement::Expression(HirExpression::FunctionCall { function, arguments }) = stmt {
39        if function == "pthread_mutex_unlock" && !arguments.is_empty() {
40            // Extract mutex name from &mutex or &ptr->field
41            if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
42                return extract_variable_name(inner);
43            }
44        }
45    }
46    None
47}
48
49/// Extracts the variable name from an expression.
50///
51/// Handles:
52/// - HirExpression::Variable(name) → Some(name)
53/// - HirExpression::PointerFieldAccess → Some(field_name)
54/// - HirExpression::FieldAccess → Some(field_name)
55/// - Other expressions → None
56fn extract_variable_name(expr: &HirExpression) -> Option<String> {
57    match expr {
58        HirExpression::Variable(name) => Some(name.clone()),
59        HirExpression::PointerFieldAccess { field, .. } => Some(field.clone()),
60        HirExpression::FieldAccess { field, .. } => Some(field.clone()),
61        _ => None,
62    }
63}
64
65/// Identifies lock regions in a function body.
66///
67/// A lock region is a sequence of statements between pthread_mutex_lock
68/// and pthread_mutex_unlock for the same mutex.
69///
70/// Returns a vector of (lock_name, start_index, end_index) tuples.
71pub fn identify_lock_regions(func: &HirFunction) -> Vec<(String, usize, usize)> {
72    let mut regions = Vec::new();
73    let body = func.body();
74    let mut active_locks: std::collections::HashMap<String, usize> =
75        std::collections::HashMap::new();
76
77    for (idx, stmt) in body.iter().enumerate() {
78        if let Some(lock_name) = is_pthread_lock(stmt) {
79            active_locks.insert(lock_name, idx);
80        } else if let Some(unlock_name) = is_pthread_unlock(stmt) {
81            if let Some(start_idx) = active_locks.remove(&unlock_name) {
82                regions.push((unlock_name, start_idx, idx));
83            }
84        }
85    }
86
87    regions
88}
89
90/// Checks if a function contains any pthread mutex operations.
91pub fn has_pthread_mutex_calls(func: &HirFunction) -> bool {
92    func.body()
93        .iter()
94        .any(|stmt| is_pthread_lock(stmt).is_some() || is_pthread_unlock(stmt).is_some())
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use decy_hir::HirType;

    /// Builds the statement `<function>(&<target>)`.
    fn mutex_call(function: &str, target: HirExpression) -> HirStatement {
        HirStatement::Expression(HirExpression::FunctionCall {
            function: function.to_string(),
            arguments: vec![HirExpression::AddressOf(Box::new(target))],
        })
    }

    /// Builds the statement `pthread_mutex_lock(&<lock_name>)`.
    fn lock_call(lock_name: &str) -> HirStatement {
        mutex_call("pthread_mutex_lock", HirExpression::Variable(lock_name.to_string()))
    }

    /// Builds the statement `pthread_mutex_unlock(&<lock_name>)`.
    fn unlock_call(lock_name: &str) -> HirStatement {
        mutex_call("pthread_mutex_unlock", HirExpression::Variable(lock_name.to_string()))
    }

    /// Builds the expression `ctx->lock`.
    fn ctx_lock_field() -> HirExpression {
        HirExpression::PointerFieldAccess {
            pointer: Box::new(HirExpression::Variable("ctx".to_string())),
            field: "lock".to_string(),
        }
    }

    /// Builds a parameterless `void test()` function with the given body.
    fn void_function(body: Vec<HirStatement>) -> HirFunction {
        HirFunction::new_with_body("test".to_string(), HirType::Void, vec![], body)
    }

    #[test]
    fn test_detect_pthread_lock_call() {
        let detected = is_pthread_lock(&lock_call("my_mutex"));
        assert_eq!(detected, Some("my_mutex".to_string()));
    }

    #[test]
    fn test_detect_pthread_unlock_call() {
        let detected = is_pthread_unlock(&unlock_call("my_mutex"));
        assert_eq!(detected, Some("my_mutex".to_string()));
    }

    #[test]
    fn test_non_pthread_call_not_detected() {
        let unrelated = HirStatement::Expression(HirExpression::FunctionCall {
            function: "some_other_function".to_string(),
            arguments: vec![],
        });
        assert_eq!(is_pthread_lock(&unrelated), None);
        assert_eq!(is_pthread_unlock(&unrelated), None);
    }

    #[test]
    fn test_identify_single_lock_region() {
        let guarded_write = HirStatement::Assignment {
            target: "data".to_string(),
            value: HirExpression::IntLiteral(42),
        };
        let func = void_function(vec![lock_call("lock"), guarded_write, unlock_call("lock")]);

        // One region: name, index of the lock statement, index of the unlock statement.
        let regions = identify_lock_regions(&func);
        assert_eq!(regions, vec![("lock".to_string(), 0, 2)]);
    }

    #[test]
    fn test_identify_multiple_lock_regions() {
        let func = void_function(vec![
            lock_call("lock1"),
            unlock_call("lock1"),
            lock_call("lock2"),
            unlock_call("lock2"),
        ]);

        let regions = identify_lock_regions(&func);
        let names: Vec<&str> = regions.iter().map(|(name, _, _)| name.as_str()).collect();
        assert_eq!(names, ["lock1", "lock2"]);
    }

    #[test]
    fn test_has_pthread_mutex_calls() {
        let with_mutex = void_function(vec![lock_call("lock"), unlock_call("lock")]);
        let without_mutex =
            void_function(vec![HirStatement::Return(Some(HirExpression::IntLiteral(0)))]);

        assert!(has_pthread_mutex_calls(&with_mutex));
        assert!(!has_pthread_mutex_calls(&without_mutex));
    }

    #[test]
    fn test_extract_variable_name_from_variable() {
        let variable = HirExpression::Variable("my_var".to_string());
        assert_eq!(extract_variable_name(&variable), Some("my_var".to_string()));
    }

    #[test]
    fn test_extract_variable_name_from_field_access() {
        let access = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable("obj".to_string())),
            field: "field_name".to_string(),
        };
        assert_eq!(extract_variable_name(&access), Some("field_name".to_string()));
    }

    #[test]
    fn test_extract_variable_name_from_literal_returns_none() {
        let literal = HirExpression::IntLiteral(42);
        assert_eq!(extract_variable_name(&literal), None);
    }

    #[test]
    fn test_lock_via_pointer_field_access() {
        let stmt = mutex_call("pthread_mutex_lock", ctx_lock_field());
        assert_eq!(is_pthread_lock(&stmt), Some("lock".to_string()));
    }

    #[test]
    fn test_unlock_via_pointer_field_access() {
        let stmt = mutex_call("pthread_mutex_unlock", ctx_lock_field());
        assert_eq!(is_pthread_unlock(&stmt), Some("lock".to_string()));
    }

    #[test]
    fn test_unmatched_lock_no_region() {
        let func = void_function(vec![lock_call("orphan")]);
        assert!(identify_lock_regions(&func).is_empty());
    }
}