use decy_hir::HirFunction;
use std::collections::HashMap;
/// A function parameter reported by the output-parameter detector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OutputParameter {
    /// Name of the parameter as it appears in the analyzed function.
    pub name: String,
    /// How the parameter is classified (see [`ParameterKind`]).
    pub kind: ParameterKind,
    /// Whether the function owning this parameter was judged fallible by the
    /// detector. The same value is attached to every parameter reported for
    /// one function.
    pub is_fallible: bool,
}
/// Classification of a pointer parameter by how the function body uses it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParameterKind {
    /// The parameter is written through (`*p = ...`) and never read through.
    Output,
    /// NOTE(review): presumably a parameter that is both read and written
    /// through; no code visible in this file constructs this variant, so
    /// confirm the intended meaning against its users.
    InputOutput,
}
/// Stateless analyzer that finds pointer parameters used purely as outputs.
///
/// This is a unit struct: it holds no configuration or cached results.
#[derive(Debug, Clone)]
pub struct OutputParamDetector;
impl OutputParamDetector {
    /// Creates a new detector. The detector carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Finds the pointer parameters of `func` that are used purely as outputs.
    ///
    /// A parameter is reported when all of the following hold:
    /// * its type is `HirType::Pointer`,
    /// * the body contains at least one `*param = ...` dereference assignment,
    /// * no dereference *read* of the parameter (`*param` inside an
    ///   expression) is found anywhere in the analyzed statements.
    ///
    /// Every reported entry has kind [`ParameterKind::Output`]. The
    /// `is_fallible` flag is computed once per function and copied into each
    /// entry. Entries appear in the order yielded by `func.parameters()`.
    pub fn detect(&self, func: &HirFunction) -> Vec<OutputParameter> {
        // Both maps are keyed by parameter name and contain *only* pointer
        // parameters; presence in the map is what marks a name as "tracked".
        let mut reads: HashMap<String, bool> = HashMap::new();
        let mut writes: HashMap<String, bool> = HashMap::new();
        for param in func.parameters() {
            if Self::is_pointer_type(param.param_type()) {
                reads.insert(param.name().to_string(), false);
                writes.insert(param.name().to_string(), false);
            }
        }

        for stmt in func.body() {
            Self::analyze_statement_internal(stmt, &mut reads, &mut writes);
        }

        let is_fallible = self.is_fallible_function(func);

        let mut results = Vec::new();
        for param in func.parameters() {
            if !Self::is_pointer_type(param.param_type()) {
                continue;
            }
            let param_name = param.name();
            let was_read = reads.get(param_name).copied().unwrap_or(false);
            let was_written = writes.get(param_name).copied().unwrap_or(false);
            if was_written && !was_read {
                results.push(OutputParameter {
                    name: param_name.to_string(),
                    kind: ParameterKind::Output,
                    is_fallible,
                });
            }
        }
        results
    }

    /// Returns `true` when `ty` is a `HirType::Pointer`.
    fn is_pointer_type(ty: &decy_hir::HirType) -> bool {
        matches!(ty, decy_hir::HirType::Pointer(_))
    }

    /// Returns `true` when `func` returns `HirType::Int`.
    ///
    /// Every other return type, including `HirType::Void`, yields `false`.
    fn is_fallible_function(&self, func: &HirFunction) -> bool {
        use decy_hir::HirType;
        matches!(func.return_type(), HirType::Int)
    }

    /// Records dereference writes and reads performed by `stmt`.
    ///
    /// `reads` and `writes` are only ever updated for names that are already
    /// keys (the tracked pointer parameters); flags go from `false` to `true`
    /// and are never reset. Nested statement lists (`if`, `while`, `for`,
    /// `switch`) are visited recursively. Statement kinds not listed in the
    /// `match` are ignored.
    fn analyze_statement_internal(
        stmt: &decy_hir::HirStatement,
        reads: &mut HashMap<String, bool>,
        writes: &mut HashMap<String, bool>,
    ) {
        use decy_hir::{HirExpression, HirStatement};
        match stmt {
            HirStatement::DerefAssignment { target, value } => {
                match target {
                    // `*name = value`: a write through a (possibly tracked) name.
                    HirExpression::Variable(var_name) => {
                        // A write is only recorded if no read has been seen so far.
                        let already_read = reads.get(var_name).copied().unwrap_or(false);
                        if !already_read {
                            if let Some(written) = writes.get_mut(var_name) {
                                *written = true;
                            }
                        }
                    }
                    // Any other target is an expression that gets evaluated to
                    // find the destination, so dereferences inside it are reads.
                    other => Self::analyze_expression_internal(other, reads),
                }
                Self::analyze_expression_internal(value, reads);
            }
            HirStatement::VariableDeclaration { initializer: Some(expr), .. } => {
                Self::analyze_expression_internal(expr, reads);
            }
            HirStatement::Assignment { value, .. } => {
                Self::analyze_expression_internal(value, reads);
            }
            HirStatement::Return(Some(expr)) => {
                Self::analyze_expression_internal(expr, reads);
            }
            HirStatement::If { condition, then_block, else_block } => {
                Self::analyze_expression_internal(condition, reads);
                for s in then_block {
                    Self::analyze_statement_internal(s, reads, writes);
                }
                if let Some(else_stmts) = else_block {
                    for s in else_stmts {
                        Self::analyze_statement_internal(s, reads, writes);
                    }
                }
            }
            HirStatement::While { condition, body } => {
                Self::analyze_expression_internal(condition, reads);
                for s in body {
                    Self::analyze_statement_internal(s, reads, writes);
                }
            }
            HirStatement::For { init, condition, increment, body } => {
                for init_stmt in init {
                    Self::analyze_statement_internal(init_stmt, reads, writes);
                }
                if let Some(cond) = condition {
                    Self::analyze_expression_internal(cond, reads);
                }
                for inc_stmt in increment {
                    Self::analyze_statement_internal(inc_stmt, reads, writes);
                }
                for s in body {
                    Self::analyze_statement_internal(s, reads, writes);
                }
            }
            HirStatement::Switch { condition, cases, default_case } => {
                Self::analyze_expression_internal(condition, reads);
                for case in cases {
                    for s in &case.body {
                        Self::analyze_statement_internal(s, reads, writes);
                    }
                }
                if let Some(default_stmts) = default_case {
                    for s in default_stmts {
                        Self::analyze_statement_internal(s, reads, writes);
                    }
                }
            }
            HirStatement::Expression(expr) => {
                Self::analyze_expression_internal(expr, reads);
            }
            _ => {}
        }
    }

    /// Marks tracked pointer parameters that `expr` dereferences for reading.
    ///
    /// Only the `*name` form sets a flag, and only when `name` is already a
    /// key of `reads`. The walk descends through dereferences, binary and
    /// unary operators, call arguments, field accesses and index expressions;
    /// all other expression kinds are ignored.
    ///
    /// NOTE(review): `name->field` and `name[i]` end at a bare `Variable` and
    /// therefore record nothing — confirm that this is intended.
    fn analyze_expression_internal(
        expr: &decy_hir::HirExpression,
        reads: &mut HashMap<String, bool>,
    ) {
        use decy_hir::HirExpression;
        match expr {
            HirExpression::Dereference(inner) => match inner.as_ref() {
                HirExpression::Variable(var_name) => {
                    if let Some(read) = reads.get_mut(var_name) {
                        *read = true;
                    }
                }
                // e.g. `**pp` or `*(f(*p))`: the operand is itself evaluated,
                // so keep looking for dereference reads inside it.
                other => Self::analyze_expression_internal(other, reads),
            },
            HirExpression::BinaryOp { left, right, .. } => {
                Self::analyze_expression_internal(left, reads);
                Self::analyze_expression_internal(right, reads);
            }
            HirExpression::UnaryOp { operand, .. } => {
                Self::analyze_expression_internal(operand, reads);
            }
            HirExpression::FunctionCall { arguments, .. } => {
                for arg in arguments {
                    Self::analyze_expression_internal(arg, reads);
                }
            }
            HirExpression::FieldAccess { object, .. }
            | HirExpression::PointerFieldAccess { pointer: object, .. } => {
                Self::analyze_expression_internal(object, reads);
            }
            HirExpression::ArrayIndex { array, index }
            | HirExpression::SliceIndex { slice: array, index, .. } => {
                Self::analyze_expression_internal(array, reads);
                Self::analyze_expression_internal(index, reads);
            }
            _ => {}
        }
    }
}
impl Default for OutputParamDetector {
fn default() -> Self {
Self::new()
}
}