decy_codegen/
concurrency_transform.rs1use decy_hir::{HirExpression, HirFunction, HirStatement};
10
11pub fn is_pthread_lock(stmt: &HirStatement) -> Option<String> {
19 if let HirStatement::Expression(HirExpression::FunctionCall {
20 function,
21 arguments,
22 }) = stmt
23 {
24 if function == "pthread_mutex_lock" && !arguments.is_empty() {
25 if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
27 return extract_variable_name(inner);
28 }
29 }
30 }
31 None
32}
33
34pub fn is_pthread_unlock(stmt: &HirStatement) -> Option<String> {
42 if let HirStatement::Expression(HirExpression::FunctionCall {
43 function,
44 arguments,
45 }) = stmt
46 {
47 if function == "pthread_mutex_unlock" && !arguments.is_empty() {
48 if let Some(HirExpression::AddressOf(inner)) = arguments.first() {
50 return extract_variable_name(inner);
51 }
52 }
53 }
54 None
55}
56
57fn extract_variable_name(expr: &HirExpression) -> Option<String> {
65 match expr {
66 HirExpression::Variable(name) => Some(name.clone()),
67 HirExpression::PointerFieldAccess { field, .. } => Some(field.clone()),
68 HirExpression::FieldAccess { field, .. } => Some(field.clone()),
69 _ => None,
70 }
71}
72
73pub fn identify_lock_regions(func: &HirFunction) -> Vec<(String, usize, usize)> {
80 let mut regions = Vec::new();
81 let body = func.body();
82 let mut active_locks: std::collections::HashMap<String, usize> =
83 std::collections::HashMap::new();
84
85 for (idx, stmt) in body.iter().enumerate() {
86 if let Some(lock_name) = is_pthread_lock(stmt) {
87 active_locks.insert(lock_name, idx);
88 } else if let Some(unlock_name) = is_pthread_unlock(stmt) {
89 if let Some(start_idx) = active_locks.remove(&unlock_name) {
90 regions.push((unlock_name, start_idx, idx));
91 }
92 }
93 }
94
95 regions
96}
97
98pub fn has_pthread_mutex_calls(func: &HirFunction) -> bool {
100 func.body()
101 .iter()
102 .any(|stmt| is_pthread_lock(stmt).is_some() || is_pthread_unlock(stmt).is_some())
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use decy_hir::HirType;

    /// Builds the statement `<function>(&<lock_name>);`.
    fn mutex_call(function: &str, lock_name: &str) -> HirStatement {
        HirStatement::Expression(HirExpression::FunctionCall {
            function: function.to_string(),
            arguments: vec![HirExpression::AddressOf(Box::new(HirExpression::Variable(
                lock_name.to_string(),
            )))],
        })
    }

    fn lock_call(lock_name: &str) -> HirStatement {
        mutex_call("pthread_mutex_lock", lock_name)
    }

    fn unlock_call(lock_name: &str) -> HirStatement {
        mutex_call("pthread_mutex_unlock", lock_name)
    }

    /// Wraps `body` in a parameterless `void` function named `test`.
    fn void_fn(body: Vec<HirStatement>) -> HirFunction {
        HirFunction::new_with_body("test".to_string(), HirType::Void, vec![], body)
    }

    #[test]
    fn test_detect_pthread_lock_call() {
        let stmt = lock_call("my_mutex");
        assert_eq!(is_pthread_lock(&stmt), Some("my_mutex".to_string()));
    }

    #[test]
    fn test_detect_pthread_unlock_call() {
        let stmt = unlock_call("my_mutex");
        assert_eq!(is_pthread_unlock(&stmt), Some("my_mutex".to_string()));
    }

    #[test]
    fn test_non_pthread_call_not_detected() {
        let stmt = HirStatement::Expression(HirExpression::FunctionCall {
            function: "some_other_function".to_string(),
            arguments: vec![],
        });
        assert_eq!(is_pthread_lock(&stmt), None);
        assert_eq!(is_pthread_unlock(&stmt), None);
    }

    #[test]
    fn test_lock_call_without_arguments_not_detected() {
        let stmt = HirStatement::Expression(HirExpression::FunctionCall {
            function: "pthread_mutex_lock".to_string(),
            arguments: vec![],
        });
        assert_eq!(is_pthread_lock(&stmt), None);
    }

    #[test]
    fn test_lock_is_not_an_unlock() {
        assert_eq!(is_pthread_unlock(&lock_call("m")), None);
        assert_eq!(is_pthread_lock(&unlock_call("m")), None);
    }

    #[test]
    fn test_identify_single_lock_region() {
        let func = void_fn(vec![
            lock_call("lock"),
            HirStatement::Assignment {
                target: "data".to_string(),
                value: HirExpression::IntLiteral(42),
            },
            unlock_call("lock"),
        ]);

        let regions = identify_lock_regions(&func);
        assert_eq!(regions.len(), 1);
        assert_eq!(regions[0].0, "lock");
        assert_eq!(regions[0].1, 0);
        assert_eq!(regions[0].2, 2);
    }

    #[test]
    fn test_identify_multiple_lock_regions() {
        let func = void_fn(vec![
            lock_call("lock1"),
            unlock_call("lock1"),
            lock_call("lock2"),
            unlock_call("lock2"),
        ]);

        let regions = identify_lock_regions(&func);
        assert_eq!(regions.len(), 2);
        assert_eq!(regions[0].0, "lock1");
        assert_eq!(regions[1].0, "lock2");
    }

    #[test]
    fn test_identify_interleaved_lock_regions() {
        let func = void_fn(vec![
            lock_call("a"),
            lock_call("b"),
            unlock_call("a"),
            unlock_call("b"),
        ]);

        // Regions are reported in unlock order, each with its own start index.
        assert_eq!(
            identify_lock_regions(&func),
            vec![("a".to_string(), 0, 2), ("b".to_string(), 1, 3)]
        );
    }

    #[test]
    fn test_unmatched_lock_or_unlock_produces_no_region() {
        let only_unlock = void_fn(vec![unlock_call("lock")]);
        assert!(identify_lock_regions(&only_unlock).is_empty());

        let only_lock = void_fn(vec![lock_call("lock")]);
        assert!(identify_lock_regions(&only_lock).is_empty());
    }

    #[test]
    fn test_has_pthread_mutex_calls() {
        let func_with_mutex = void_fn(vec![lock_call("lock"), unlock_call("lock")]);
        let func_without_mutex =
            void_fn(vec![HirStatement::Return(Some(HirExpression::IntLiteral(0)))]);

        assert!(has_pthread_mutex_calls(&func_with_mutex));
        assert!(!has_pthread_mutex_calls(&func_without_mutex));
    }

    #[test]
    fn test_has_pthread_mutex_calls_with_unlock_only() {
        let func = void_fn(vec![unlock_call("lock")]);
        assert!(has_pthread_mutex_calls(&func));
    }

    #[test]
    fn test_extract_variable_name_from_variable() {
        let expr = HirExpression::Variable("my_var".to_string());
        assert_eq!(extract_variable_name(&expr), Some("my_var".to_string()));
    }

    #[test]
    fn test_extract_variable_name_from_field_access() {
        let expr = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable("obj".to_string())),
            field: "field_name".to_string(),
        };
        assert_eq!(extract_variable_name(&expr), Some("field_name".to_string()));
    }

    #[test]
    fn test_extract_variable_name_from_literal_returns_none() {
        let expr = HirExpression::IntLiteral(42);
        assert_eq!(extract_variable_name(&expr), None);
    }
}