1use decy_hir::HirFunction;
24use std::collections::HashMap;
25
26#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct OutputParameter {
29 pub name: String,
31 pub kind: ParameterKind,
33 pub is_fallible: bool,
35}
36
/// How a pointer parameter is used by the body of the function declaring it.
#[derive(Clone, Copy, Debug)]
#[derive(Eq, PartialEq)]
pub enum ParameterKind {
    /// Stored through (`*p = ...`) and never loaded through.
    Output,
    /// Intended for pointers that are both loaded and stored through.
    ///
    /// NOTE(review): `OutputParamDetector::detect` does not currently emit
    /// this variant.
    InputOutput,
}
45
/// Stateless analysis that finds pointer parameters used only as outputs.
#[derive(Clone, Debug)]
pub struct OutputParamDetector;
49
50impl OutputParamDetector {
51 pub fn new() -> Self {
53 Self
54 }
55
56 pub fn detect(&self, func: &HirFunction) -> Vec<OutputParameter> {
77 let mut results = Vec::new();
78
79 let mut reads: HashMap<String, bool> = HashMap::new();
81 let mut writes: HashMap<String, bool> = HashMap::new();
82
83 for param in func.parameters() {
85 if Self::is_pointer_type(param.param_type()) {
86 reads.insert(param.name().to_string(), false);
87 writes.insert(param.name().to_string(), false);
88 }
89 }
90
91 for stmt in func.body() {
93 Self::analyze_statement_internal(stmt, &mut reads, &mut writes);
94 }
95
96 let is_fallible = self.is_fallible_function(func);
98
99 for param in func.parameters() {
101 let param_name = param.name();
102
103 if !Self::is_pointer_type(param.param_type()) {
104 continue;
105 }
106
107 let was_read = reads.get(param_name).copied().unwrap_or(false);
108 let was_written = writes.get(param_name).copied().unwrap_or(false);
109
110 if was_written && !was_read {
112 results.push(OutputParameter {
113 name: param_name.to_string(),
114 kind: ParameterKind::Output,
115 is_fallible,
116 });
117 }
118 }
119
120 results
121 }
122
123 fn is_pointer_type(ty: &decy_hir::HirType) -> bool {
125 matches!(ty, decy_hir::HirType::Pointer(_))
126 }
127
128 fn is_fallible_function(&self, func: &HirFunction) -> bool {
134 use decy_hir::HirType;
135
136 if matches!(func.return_type(), HirType::Void) {
138 return false;
139 }
140
141 matches!(func.return_type(), HirType::Int)
144 }
145
146 fn analyze_statement_internal(
148 stmt: &decy_hir::HirStatement,
149 reads: &mut HashMap<String, bool>,
150 writes: &mut HashMap<String, bool>,
151 ) {
152 use decy_hir::{HirExpression, HirStatement};
153
154 match stmt {
155 HirStatement::DerefAssignment { target, value } => {
157 if let HirExpression::Variable(var_name) = target {
159 if writes.contains_key(var_name) {
160 if !reads.get(var_name).copied().unwrap_or(false) {
162 writes.insert(var_name.clone(), true);
163 }
164 }
165 }
166
167 Self::analyze_expression_internal(value, reads);
169 }
170
171 HirStatement::VariableDeclaration {
173 initializer: Some(expr),
174 ..
175 } => {
176 Self::analyze_expression_internal(expr, reads);
177 }
178
179 HirStatement::Assignment { value, .. } => {
181 Self::analyze_expression_internal(value, reads);
182 }
183
184 HirStatement::Return(Some(expr)) => {
186 Self::analyze_expression_internal(expr, reads);
187 }
188
189 HirStatement::If {
191 condition,
192 then_block,
193 else_block,
194 } => {
195 Self::analyze_expression_internal(condition, reads);
196 for s in then_block {
197 Self::analyze_statement_internal(s, reads, writes);
198 }
199 if let Some(else_stmts) = else_block {
200 for s in else_stmts {
201 Self::analyze_statement_internal(s, reads, writes);
202 }
203 }
204 }
205
206 HirStatement::While { condition, body } => {
207 Self::analyze_expression_internal(condition, reads);
208 for s in body {
209 Self::analyze_statement_internal(s, reads, writes);
210 }
211 }
212
213 HirStatement::For {
214 init,
215 condition,
216 increment,
217 body,
218 } => {
219 if let Some(init_stmt) = init {
220 Self::analyze_statement_internal(init_stmt, reads, writes);
221 }
222 Self::analyze_expression_internal(condition, reads);
223 if let Some(inc_stmt) = increment {
224 Self::analyze_statement_internal(inc_stmt, reads, writes);
225 }
226 for s in body {
227 Self::analyze_statement_internal(s, reads, writes);
228 }
229 }
230
231 HirStatement::Switch {
232 condition,
233 cases,
234 default_case,
235 } => {
236 Self::analyze_expression_internal(condition, reads);
237 for case in cases {
238 for s in &case.body {
239 Self::analyze_statement_internal(s, reads, writes);
240 }
241 }
242 if let Some(default_stmts) = default_case {
243 for s in default_stmts {
244 Self::analyze_statement_internal(s, reads, writes);
245 }
246 }
247 }
248
249 HirStatement::Expression(expr) => {
250 Self::analyze_expression_internal(expr, reads);
251 }
252
253 _ => {}
254 }
255 }
256
257 fn analyze_expression_internal(
259 expr: &decy_hir::HirExpression,
260 reads: &mut HashMap<String, bool>,
261 ) {
262 use decy_hir::HirExpression;
263
264 match expr {
265 HirExpression::Dereference(inner) => {
267 if let HirExpression::Variable(var_name) = inner.as_ref() {
268 if reads.contains_key(var_name) {
269 reads.insert(var_name.clone(), true);
270 }
271 }
272 }
273
274 HirExpression::BinaryOp { left, right, .. } => {
276 Self::analyze_expression_internal(left, reads);
277 Self::analyze_expression_internal(right, reads);
278 }
279
280 HirExpression::UnaryOp { operand, .. } => {
282 Self::analyze_expression_internal(operand, reads);
283 }
284
285 HirExpression::FunctionCall { arguments, .. } => {
287 for arg in arguments {
288 Self::analyze_expression_internal(arg, reads);
289 }
290 }
291
292 HirExpression::FieldAccess { object, .. }
294 | HirExpression::PointerFieldAccess {
295 pointer: object, ..
296 } => {
297 Self::analyze_expression_internal(object, reads);
298 }
299
300 HirExpression::ArrayIndex { array, index }
302 | HirExpression::SliceIndex {
303 slice: array,
304 index,
305 ..
306 } => {
307 Self::analyze_expression_internal(array, reads);
308 Self::analyze_expression_internal(index, reads);
309 }
310
311 _ => {}
312 }
313 }
314}
315
316impl Default for OutputParamDetector {
317 fn default() -> Self {
318 Self::new()
319 }
320}