1use decy_hir::HirFunction;
24use std::collections::HashMap;
25
26#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct OutputParameter {
29 pub name: String,
31 pub kind: ParameterKind,
33 pub is_fallible: bool,
35}
36
/// How a detected pointer parameter is used by its function.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ParameterKind {
    /// Written through before any read; the caller's incoming value is unused.
    Output,
    /// Presumably both read and written through — note that
    /// `OutputParamDetector::detect` never produces this variant.
    InputOutput,
}
45
/// Stateless analysis that looks for C-style output parameters in a HIR
/// function; see `OutputParamDetector::detect`.
#[derive(Clone, Debug)]
pub struct OutputParamDetector;
49
50impl OutputParamDetector {
51 pub fn new() -> Self {
53 Self
54 }
55
56 pub fn detect(&self, func: &HirFunction) -> Vec<OutputParameter> {
77 let mut results = Vec::new();
78
79 let mut reads: HashMap<String, bool> = HashMap::new();
81 let mut writes: HashMap<String, bool> = HashMap::new();
82
83 for param in func.parameters() {
85 if Self::is_pointer_type(param.param_type()) {
86 reads.insert(param.name().to_string(), false);
87 writes.insert(param.name().to_string(), false);
88 }
89 }
90
91 for stmt in func.body() {
93 Self::analyze_statement_internal(stmt, &mut reads, &mut writes);
94 }
95
96 let is_fallible = self.is_fallible_function(func);
98
99 for param in func.parameters() {
101 let param_name = param.name();
102
103 if !Self::is_pointer_type(param.param_type()) {
104 continue;
105 }
106
107 let was_read = reads.get(param_name).copied().unwrap_or(false);
108 let was_written = writes.get(param_name).copied().unwrap_or(false);
109
110 if was_written && !was_read {
112 results.push(OutputParameter {
113 name: param_name.to_string(),
114 kind: ParameterKind::Output,
115 is_fallible,
116 });
117 }
118 }
119
120 results
121 }
122
123 fn is_pointer_type(ty: &decy_hir::HirType) -> bool {
125 matches!(ty, decy_hir::HirType::Pointer(_))
126 }
127
128 fn is_fallible_function(&self, func: &HirFunction) -> bool {
134 use decy_hir::HirType;
135
136 if matches!(func.return_type(), HirType::Void) {
138 return false;
139 }
140
141 matches!(func.return_type(), HirType::Int)
144 }
145
146 fn analyze_statement_internal(
148 stmt: &decy_hir::HirStatement,
149 reads: &mut HashMap<String, bool>,
150 writes: &mut HashMap<String, bool>,
151 ) {
152 use decy_hir::{HirExpression, HirStatement};
153
154 match stmt {
155 HirStatement::DerefAssignment { target, value } => {
157 if let HirExpression::Variable(var_name) = target {
159 if writes.contains_key(var_name) {
160 if !reads.get(var_name).copied().unwrap_or(false) {
162 writes.insert(var_name.clone(), true);
163 }
164 }
165 }
166
167 Self::analyze_expression_internal(value, reads);
169 }
170
171 HirStatement::VariableDeclaration { initializer: Some(expr), .. } => {
173 Self::analyze_expression_internal(expr, reads);
174 }
175
176 HirStatement::Assignment { value, .. } => {
178 Self::analyze_expression_internal(value, reads);
179 }
180
181 HirStatement::Return(Some(expr)) => {
183 Self::analyze_expression_internal(expr, reads);
184 }
185
186 HirStatement::If { condition, then_block, else_block } => {
188 Self::analyze_expression_internal(condition, reads);
189 for s in then_block {
190 Self::analyze_statement_internal(s, reads, writes);
191 }
192 if let Some(else_stmts) = else_block {
193 for s in else_stmts {
194 Self::analyze_statement_internal(s, reads, writes);
195 }
196 }
197 }
198
199 HirStatement::While { condition, body } => {
200 Self::analyze_expression_internal(condition, reads);
201 for s in body {
202 Self::analyze_statement_internal(s, reads, writes);
203 }
204 }
205
206 HirStatement::For { init, condition, increment, body } => {
207 for init_stmt in init {
209 Self::analyze_statement_internal(init_stmt, reads, writes);
210 }
211 if let Some(cond) = condition {
212 Self::analyze_expression_internal(cond, reads);
213 }
214 for inc_stmt in increment {
216 Self::analyze_statement_internal(inc_stmt, reads, writes);
217 }
218 for s in body {
219 Self::analyze_statement_internal(s, reads, writes);
220 }
221 }
222
223 HirStatement::Switch { condition, cases, default_case } => {
224 Self::analyze_expression_internal(condition, reads);
225 for case in cases {
226 for s in &case.body {
227 Self::analyze_statement_internal(s, reads, writes);
228 }
229 }
230 if let Some(default_stmts) = default_case {
231 for s in default_stmts {
232 Self::analyze_statement_internal(s, reads, writes);
233 }
234 }
235 }
236
237 HirStatement::Expression(expr) => {
238 Self::analyze_expression_internal(expr, reads);
239 }
240
241 _ => {}
242 }
243 }
244
245 fn analyze_expression_internal(
247 expr: &decy_hir::HirExpression,
248 reads: &mut HashMap<String, bool>,
249 ) {
250 use decy_hir::HirExpression;
251
252 match expr {
253 HirExpression::Dereference(inner) => {
255 if let HirExpression::Variable(var_name) = inner.as_ref() {
256 if reads.contains_key(var_name) {
257 reads.insert(var_name.clone(), true);
258 }
259 }
260 }
261
262 HirExpression::BinaryOp { left, right, .. } => {
264 Self::analyze_expression_internal(left, reads);
265 Self::analyze_expression_internal(right, reads);
266 }
267
268 HirExpression::UnaryOp { operand, .. } => {
270 Self::analyze_expression_internal(operand, reads);
271 }
272
273 HirExpression::FunctionCall { arguments, .. } => {
275 for arg in arguments {
276 Self::analyze_expression_internal(arg, reads);
277 }
278 }
279
280 HirExpression::FieldAccess { object, .. }
282 | HirExpression::PointerFieldAccess { pointer: object, .. } => {
283 Self::analyze_expression_internal(object, reads);
284 }
285
286 HirExpression::ArrayIndex { array, index }
288 | HirExpression::SliceIndex { slice: array, index, .. } => {
289 Self::analyze_expression_internal(array, reads);
290 Self::analyze_expression_internal(index, reads);
291 }
292
293 _ => {}
294 }
295 }
296}
297
298impl Default for OutputParamDetector {
299 fn default() -> Self {
300 Self::new()
301 }
302}