decy_codegen/
concurrency_transform.rs

1//! Concurrency transformation module for pthread → Rust std::sync conversions.
2//!
3//! Transforms C pthread synchronization primitives to safe Rust equivalents:
4//! - pthread_mutex_t + data → `Mutex<T>`
5//! - pthread_mutex_lock/unlock → `.lock().unwrap()` with RAII
6//!
7//! Part of DECY-078: Transform pthread_mutex to `Mutex<T>`
8
9use decy_hir::{HirExpression, HirFunction, HirStatement};
10
11/// Detects if a function call is a pthread mutex lock operation.
12///
13/// Recognizes patterns:
14/// - pthread_mutex_lock(&mutex)
15/// - pthread_mutex_lock(&ptr->mutex)
16///
17/// Returns the name of the mutex variable if detected, None otherwise.
18pub fn is_pthread_lock(stmt: &HirStatement) -> Option<String> {
19    if let HirStatement::Expression(HirExpression::FunctionCall {
20        function,
21        arguments,
22    }) = stmt
23    {
24        if function == "pthread_mutex_lock" && !arguments.is_empty() {
25            // Extract mutex name from &mutex or &ptr->field
26            if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
27                return extract_variable_name(inner);
28            }
29        }
30    }
31    None
32}
33
34/// Detects if a function call is a pthread mutex unlock operation.
35///
36/// Recognizes patterns:
37/// - pthread_mutex_unlock(&mutex)
38/// - pthread_mutex_unlock(&ptr->mutex)
39///
40/// Returns the name of the mutex variable if detected, None otherwise.
41pub fn is_pthread_unlock(stmt: &HirStatement) -> Option<String> {
42    if let HirStatement::Expression(HirExpression::FunctionCall {
43        function,
44        arguments,
45    }) = stmt
46    {
47        if function == "pthread_mutex_unlock" && !arguments.is_empty() {
48            // Extract mutex name from &mutex or &ptr->field
49            if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
50                return extract_variable_name(inner);
51            }
52        }
53    }
54    None
55}
56
57/// Extracts the variable name from an expression.
58///
59/// Handles:
60/// - HirExpression::Variable(name) → Some(name)
61/// - HirExpression::PointerFieldAccess → Some(field_name)
62/// - HirExpression::FieldAccess → Some(field_name)
63/// - Other expressions → None
64fn extract_variable_name(expr: &HirExpression) -> Option<String> {
65    match expr {
66        HirExpression::Variable(name) => Some(name.clone()),
67        HirExpression::PointerFieldAccess { field, .. } => Some(field.clone()),
68        HirExpression::FieldAccess { field, .. } => Some(field.clone()),
69        _ => None,
70    }
71}
72
73/// Identifies lock regions in a function body.
74///
75/// A lock region is a sequence of statements between pthread_mutex_lock
76/// and pthread_mutex_unlock for the same mutex.
77///
78/// Returns a vector of (lock_name, start_index, end_index) tuples.
79pub fn identify_lock_regions(func: &HirFunction) -> Vec<(String, usize, usize)> {
80    let mut regions = Vec::new();
81    let body = func.body();
82    let mut active_locks: std::collections::HashMap<String, usize> =
83        std::collections::HashMap::new();
84
85    for (idx, stmt) in body.iter().enumerate() {
86        if let Some(lock_name) = is_pthread_lock(stmt) {
87            active_locks.insert(lock_name, idx);
88        } else if let Some(unlock_name) = is_pthread_unlock(stmt) {
89            if let Some(start_idx) = active_locks.remove(&unlock_name) {
90                regions.push((unlock_name, start_idx, idx));
91            }
92        }
93    }
94
95    regions
96}
97
98/// Checks if a function contains any pthread mutex operations.
99pub fn has_pthread_mutex_calls(func: &HirFunction) -> bool {
100    func.body()
101        .iter()
102        .any(|stmt| is_pthread_lock(stmt).is_some() || is_pthread_unlock(stmt).is_some())
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use decy_hir::HirType;

    /// Builds `pthread_mutex_lock(&lock_name);`.
    fn lock_call(lock_name: &str) -> HirStatement {
        HirStatement::Expression(HirExpression::FunctionCall {
            function: "pthread_mutex_lock".to_string(),
            arguments: vec![HirExpression::AddressOf(Box::new(HirExpression::Variable(
                lock_name.to_string(),
            )))],
        })
    }

    /// Builds `pthread_mutex_unlock(&lock_name);`.
    fn unlock_call(lock_name: &str) -> HirStatement {
        HirStatement::Expression(HirExpression::FunctionCall {
            function: "pthread_mutex_unlock".to_string(),
            arguments: vec![HirExpression::AddressOf(Box::new(HirExpression::Variable(
                lock_name.to_string(),
            )))],
        })
    }

    /// Wraps `body` in a `void test()` function.
    fn func_with_body(body: Vec<HirStatement>) -> HirFunction {
        HirFunction::new_with_body("test".to_string(), HirType::Void, vec![], body)
    }

    #[test]
    fn test_detect_pthread_lock_call() {
        let stmt = lock_call("my_mutex");
        assert_eq!(is_pthread_lock(&stmt), Some("my_mutex".to_string()));
    }

    #[test]
    fn test_detect_pthread_unlock_call() {
        let stmt = unlock_call("my_mutex");
        assert_eq!(is_pthread_unlock(&stmt), Some("my_mutex".to_string()));
    }

    #[test]
    fn test_lock_is_not_unlock_and_vice_versa() {
        assert_eq!(is_pthread_unlock(&lock_call("m")), None);
        assert_eq!(is_pthread_lock(&unlock_call("m")), None);
    }

    #[test]
    fn test_non_pthread_call_not_detected() {
        let stmt = HirStatement::Expression(HirExpression::FunctionCall {
            function: "some_other_function".to_string(),
            arguments: vec![],
        });
        assert_eq!(is_pthread_lock(&stmt), None);
        assert_eq!(is_pthread_unlock(&stmt), None);
    }

    #[test]
    fn test_lock_without_arguments_not_detected() {
        let stmt = HirStatement::Expression(HirExpression::FunctionCall {
            function: "pthread_mutex_lock".to_string(),
            arguments: vec![],
        });
        assert_eq!(is_pthread_lock(&stmt), None);
    }

    #[test]
    fn test_lock_without_address_of_not_detected() {
        // pthread_mutex_lock(m) where `m` is already a pointer is not recognized.
        let stmt = HirStatement::Expression(HirExpression::FunctionCall {
            function: "pthread_mutex_lock".to_string(),
            arguments: vec![HirExpression::Variable("m".to_string())],
        });
        assert_eq!(is_pthread_lock(&stmt), None);
    }

    #[test]
    fn test_non_call_statement_not_detected() {
        let stmt = HirStatement::Return(Some(HirExpression::IntLiteral(0)));
        assert_eq!(is_pthread_lock(&stmt), None);
        assert_eq!(is_pthread_unlock(&stmt), None);
    }

    #[test]
    fn test_identify_single_lock_region() {
        let func = func_with_body(vec![
            lock_call("lock"),
            HirStatement::Assignment {
                target: "data".to_string(),
                value: HirExpression::IntLiteral(42),
            },
            unlock_call("lock"),
        ]);

        let regions = identify_lock_regions(&func);
        assert_eq!(regions.len(), 1);
        assert_eq!(regions[0].0, "lock");
        assert_eq!(regions[0].1, 0); // start index
        assert_eq!(regions[0].2, 2); // end index
    }

    #[test]
    fn test_identify_multiple_lock_regions() {
        let func = func_with_body(vec![
            lock_call("lock1"),
            unlock_call("lock1"),
            lock_call("lock2"),
            unlock_call("lock2"),
        ]);

        let regions = identify_lock_regions(&func);
        assert_eq!(regions.len(), 2);
        assert_eq!(regions[0].0, "lock1");
        assert_eq!(regions[1].0, "lock2");
    }

    #[test]
    fn test_identify_nested_regions_of_different_mutexes() {
        let func = func_with_body(vec![
            lock_call("a"),
            lock_call("b"),
            unlock_call("b"),
            unlock_call("a"),
        ]);

        // Regions are reported in unlock order.
        assert_eq!(
            identify_lock_regions(&func),
            vec![("b".to_string(), 1, 2), ("a".to_string(), 0, 3)]
        );
    }

    #[test]
    fn test_unmatched_lock_and_unlock_produce_no_region() {
        let func = func_with_body(vec![unlock_call("a"), lock_call("b")]);
        assert!(identify_lock_regions(&func).is_empty());
    }

    #[test]
    fn test_unlock_of_different_name_does_not_close_region() {
        let func = func_with_body(vec![lock_call("a"), unlock_call("b")]);
        assert!(identify_lock_regions(&func).is_empty());
    }

    #[test]
    fn test_empty_body_has_no_regions_or_calls() {
        let func = func_with_body(vec![]);
        assert!(identify_lock_regions(&func).is_empty());
        assert!(!has_pthread_mutex_calls(&func));
    }

    #[test]
    fn test_has_pthread_mutex_calls() {
        let func_with_mutex = func_with_body(vec![lock_call("lock"), unlock_call("lock")]);
        let func_without_mutex =
            func_with_body(vec![HirStatement::Return(Some(HirExpression::IntLiteral(0)))]);

        assert!(has_pthread_mutex_calls(&func_with_mutex));
        assert!(!has_pthread_mutex_calls(&func_without_mutex));
    }

    #[test]
    fn test_has_pthread_mutex_calls_with_only_unlock() {
        let func = func_with_body(vec![unlock_call("lock")]);
        assert!(has_pthread_mutex_calls(&func));
    }

    #[test]
    fn test_extract_variable_name_from_variable() {
        let expr = HirExpression::Variable("my_var".to_string());
        assert_eq!(extract_variable_name(&expr), Some("my_var".to_string()));
    }

    #[test]
    fn test_extract_variable_name_from_field_access() {
        let expr = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable("obj".to_string())),
            field: "field_name".to_string(),
        };
        assert_eq!(extract_variable_name(&expr), Some("field_name".to_string()));
    }

    #[test]
    fn test_extract_variable_name_from_literal_returns_none() {
        let expr = HirExpression::IntLiteral(42);
        assert_eq!(extract_variable_name(&expr), None);
    }
}