use std::collections::HashMap;

use decy_hir::{HirExpression, HirFunction, HirStatement};
10
11pub fn is_pthread_lock(stmt: &HirStatement) -> Option<String> {
19 if let HirStatement::Expression(HirExpression::FunctionCall { function, arguments }) = stmt {
20 if function == "pthread_mutex_lock" && !arguments.is_empty() {
21 if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
23 return extract_variable_name(inner);
24 }
25 }
26 }
27 None
28}
29
30pub fn is_pthread_unlock(stmt: &HirStatement) -> Option<String> {
38 if let HirStatement::Expression(HirExpression::FunctionCall { function, arguments }) = stmt {
39 if function == "pthread_mutex_unlock" && !arguments.is_empty() {
40 if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
42 return extract_variable_name(inner);
43 }
44 }
45 }
46 None
47}
48
49fn extract_variable_name(expr: &HirExpression) -> Option<String> {
57 match expr {
58 HirExpression::Variable(name) => Some(name.clone()),
59 HirExpression::PointerFieldAccess { field, .. } => Some(field.clone()),
60 HirExpression::FieldAccess { field, .. } => Some(field.clone()),
61 _ => None,
62 }
63}
64
65pub fn identify_lock_regions(func: &HirFunction) -> Vec<(String, usize, usize)> {
72 let mut regions = Vec::new();
73 let body = func.body();
74 let mut active_locks: std::collections::HashMap<String, usize> =
75 std::collections::HashMap::new();
76
77 for (idx, stmt) in body.iter().enumerate() {
78 if let Some(lock_name) = is_pthread_lock(stmt) {
79 active_locks.insert(lock_name, idx);
80 } else if let Some(unlock_name) = is_pthread_unlock(stmt) {
81 if let Some(start_idx) = active_locks.remove(&unlock_name) {
82 regions.push((unlock_name, start_idx, idx));
83 }
84 }
85 }
86
87 regions
88}
89
90pub fn has_pthread_mutex_calls(func: &HirFunction) -> bool {
92 func.body()
93 .iter()
94 .any(|stmt| is_pthread_lock(stmt).is_some() || is_pthread_unlock(stmt).is_some())
95}

#[cfg(test)]
mod tests {
    use super::*;
    use decy_hir::HirType;

    /// Builds the statement `function(&lock_name);`.
    fn mutex_call(function: &str, lock_name: &str) -> HirStatement {
        HirStatement::Expression(HirExpression::FunctionCall {
            function: function.to_string(),
            arguments: vec![HirExpression::AddressOf(Box::new(HirExpression::Variable(
                lock_name.to_string(),
            )))],
        })
    }

    /// Builds the statement `function(&pointer->field);`.
    fn mutex_call_via_pointer(function: &str, pointer: &str, field: &str) -> HirStatement {
        HirStatement::Expression(HirExpression::FunctionCall {
            function: function.to_string(),
            arguments: vec![HirExpression::AddressOf(Box::new(
                HirExpression::PointerFieldAccess {
                    pointer: Box::new(HirExpression::Variable(pointer.to_string())),
                    field: field.to_string(),
                },
            ))],
        })
    }

    fn lock_call(lock_name: &str) -> HirStatement {
        mutex_call("pthread_mutex_lock", lock_name)
    }

    fn unlock_call(lock_name: &str) -> HirStatement {
        mutex_call("pthread_mutex_unlock", lock_name)
    }

    /// Wraps `body` in a `void test()` function with no parameters.
    fn func_with_body(body: Vec<HirStatement>) -> HirFunction {
        HirFunction::new_with_body("test".to_string(), HirType::Void, vec![], body)
    }

    #[test]
    fn test_detect_pthread_lock_call() {
        let stmt = lock_call("my_mutex");
        assert_eq!(is_pthread_lock(&stmt), Some("my_mutex".to_string()));
        assert_eq!(is_pthread_unlock(&stmt), None);
    }

    #[test]
    fn test_detect_pthread_unlock_call() {
        let stmt = unlock_call("my_mutex");
        assert_eq!(is_pthread_unlock(&stmt), Some("my_mutex".to_string()));
        assert_eq!(is_pthread_lock(&stmt), None);
    }

    #[test]
    fn test_non_pthread_call_not_detected() {
        let stmt = HirStatement::Expression(HirExpression::FunctionCall {
            function: "some_other_function".to_string(),
            arguments: vec![],
        });
        assert_eq!(is_pthread_lock(&stmt), None);
        assert_eq!(is_pthread_unlock(&stmt), None);
    }

    #[test]
    fn test_lock_with_no_arguments_not_detected() {
        let stmt = HirStatement::Expression(HirExpression::FunctionCall {
            function: "pthread_mutex_lock".to_string(),
            arguments: vec![],
        });
        assert_eq!(is_pthread_lock(&stmt), None);
    }

    #[test]
    fn test_lock_without_address_of_not_detected() {
        // `pthread_mutex_lock(m)` (no `&`) is outside the recognised shape.
        let stmt = HirStatement::Expression(HirExpression::FunctionCall {
            function: "pthread_mutex_lock".to_string(),
            arguments: vec![HirExpression::Variable("m".to_string())],
        });
        assert_eq!(is_pthread_lock(&stmt), None);
    }

    #[test]
    fn test_non_expression_statement_not_detected() {
        let stmt = HirStatement::Return(Some(HirExpression::IntLiteral(0)));
        assert_eq!(is_pthread_lock(&stmt), None);
        assert_eq!(is_pthread_unlock(&stmt), None);
    }

    #[test]
    fn test_identify_single_lock_region() {
        let func = func_with_body(vec![
            lock_call("lock"),
            HirStatement::Assignment {
                target: "data".to_string(),
                value: HirExpression::IntLiteral(42),
            },
            unlock_call("lock"),
        ]);

        let regions = identify_lock_regions(&func);
        assert_eq!(regions, vec![("lock".to_string(), 0, 2)]);
    }

    #[test]
    fn test_identify_multiple_lock_regions() {
        let func = func_with_body(vec![
            lock_call("lock1"),
            unlock_call("lock1"),
            lock_call("lock2"),
            unlock_call("lock2"),
        ]);

        let regions = identify_lock_regions(&func);
        assert_eq!(
            regions,
            vec![("lock1".to_string(), 0, 1), ("lock2".to_string(), 2, 3)]
        );
    }

    #[test]
    fn test_identify_interleaved_lock_regions() {
        // Regions are reported in unlock order, and may overlap.
        let func = func_with_body(vec![
            lock_call("a"),
            lock_call("b"),
            unlock_call("a"),
            unlock_call("b"),
        ]);

        let regions = identify_lock_regions(&func);
        assert_eq!(
            regions,
            vec![("a".to_string(), 0, 2), ("b".to_string(), 1, 3)]
        );
    }

    #[test]
    fn test_has_pthread_mutex_calls() {
        let func_with_mutex = func_with_body(vec![lock_call("lock"), unlock_call("lock")]);
        let func_without_mutex =
            func_with_body(vec![HirStatement::Return(Some(HirExpression::IntLiteral(0)))]);

        assert!(has_pthread_mutex_calls(&func_with_mutex));
        assert!(!has_pthread_mutex_calls(&func_without_mutex));
    }

    #[test]
    fn test_has_pthread_mutex_calls_with_only_unlock() {
        let func = func_with_body(vec![unlock_call("lock")]);
        assert!(has_pthread_mutex_calls(&func));
    }

    #[test]
    fn test_extract_variable_name_from_variable() {
        let expr = HirExpression::Variable("my_var".to_string());
        assert_eq!(extract_variable_name(&expr), Some("my_var".to_string()));
    }

    #[test]
    fn test_extract_variable_name_from_field_access() {
        let expr = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable("obj".to_string())),
            field: "field_name".to_string(),
        };
        assert_eq!(extract_variable_name(&expr), Some("field_name".to_string()));
    }

    #[test]
    fn test_extract_variable_name_from_pointer_field_access() {
        let expr = HirExpression::PointerFieldAccess {
            pointer: Box::new(HirExpression::Variable("ctx".to_string())),
            field: "lock".to_string(),
        };
        assert_eq!(extract_variable_name(&expr), Some("lock".to_string()));
    }

    #[test]
    fn test_extract_variable_name_from_literal_returns_none() {
        let expr = HirExpression::IntLiteral(42);
        assert_eq!(extract_variable_name(&expr), None);
    }

    #[test]
    fn test_lock_via_pointer_field_access() {
        let stmt = mutex_call_via_pointer("pthread_mutex_lock", "ctx", "lock");
        assert_eq!(is_pthread_lock(&stmt), Some("lock".to_string()));
    }

    #[test]
    fn test_unlock_via_pointer_field_access() {
        let stmt = mutex_call_via_pointer("pthread_mutex_unlock", "ctx", "lock");
        assert_eq!(is_pthread_unlock(&stmt), Some("lock".to_string()));
    }

    #[test]
    fn test_unmatched_lock_no_region() {
        let func = func_with_body(vec![lock_call("orphan")]);
        assert!(identify_lock_regions(&func).is_empty());
    }

    #[test]
    fn test_unmatched_unlock_no_region() {
        let func = func_with_body(vec![unlock_call("orphan")]);
        assert!(identify_lock_regions(&func).is_empty());
    }
}