use decy_hir::{HirExpression, HirFunction, HirStatement, HirType};
/// How a function uses its `void*` parameters, as classified by
/// `VoidPtrAnalyzer::detect_pattern`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VoidPtrPattern {
    /// Fallback when no more specific pattern is recognized.
    Generic,
    /// Two `void*` parameters, a size-like parameter, and a function named `swap`.
    Swap,
    /// Two `void*` parameters, an `int` return, and a comparator-style name.
    Compare,
    /// Two `void*` parameters, a size-like parameter, and copy-style naming.
    Copy,
}
/// A requirement on the pointee type of a `void*` parameter, derived from how
/// the function body uses that parameter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TypeConstraint {
    /// The pointee is read.
    Readable,
    /// The pointee is written through the pointer.
    Mutable,
    /// The pointee must be bitwise-copyable.
    Copy,
    /// The pointee value is duplicated (e.g. `*a = *b`).
    Clone,
    /// The pointee takes part in relational comparisons (`<`, `>`, `<=`, `>=`).
    PartialOrd,
    /// The pointee takes part in equality comparisons (`==`, `!=`).
    PartialEq,
}
/// Analysis result for a single `void*` parameter.
#[derive(Clone, Debug)]
pub struct VoidPtrInfo {
    /// Name of the analyzed parameter.
    pub param_name: String,
    /// Usage pattern detected for the enclosing function.
    pub pattern: VoidPtrPattern,
    /// Pointee types the parameter is cast to (e.g. `int` for `(int*)p`), deduplicated.
    pub inferred_types: Vec<HirType>,
    /// Constraints on the pointee type gathered from the body, deduplicated.
    pub constraints: Vec<TypeConstraint>,
}
/// Stateless analyzer that infers generic-type information for `void*`
/// parameters of HIR functions.
#[derive(Debug, Clone)]
pub struct VoidPtrAnalyzer;
impl VoidPtrAnalyzer {
pub fn new() -> Self {
Self
}
pub fn analyze(&self, func: &HirFunction) -> Vec<VoidPtrInfo> {
let mut results = Vec::new();
let void_ptr_params: Vec<_> =
func.parameters().iter().filter(|p| self.is_void_ptr(p.param_type())).collect();
if void_ptr_params.is_empty() {
return results;
}
let pattern = self.detect_pattern(func, &void_ptr_params);
for param in void_ptr_params {
let mut info = VoidPtrInfo {
param_name: param.name().to_string(),
pattern: pattern.clone(),
inferred_types: Vec::new(),
constraints: Vec::new(),
};
self.analyze_body(func.body(), param.name(), &mut info);
results.push(info);
}
results
}
fn is_void_ptr(&self, ty: &HirType) -> bool {
matches!(ty, HirType::Pointer(inner) if matches!(inner.as_ref(), HirType::Void))
}
fn detect_pattern(
&self,
func: &HirFunction,
void_params: &[&decy_hir::HirParameter],
) -> VoidPtrPattern {
let param_count = void_params.len();
let has_size_param = func
.parameters()
.iter()
.any(|p| p.name().contains("size") || p.name() == "n" || p.name() == "len");
let returns_int = matches!(func.return_type(), HirType::Int);
if param_count == 2 && has_size_param && func.name() == "swap" {
return VoidPtrPattern::Swap;
}
if param_count == 2
&& returns_int
&& (func.name().contains("cmp") || func.name() == "compare")
{
return VoidPtrPattern::Compare;
}
if param_count == 2 && has_size_param {
let names: Vec<&str> = void_params.iter().map(|p| p.name()).collect();
if names.contains(&"dest") || names.contains(&"src") || func.name().contains("copy") {
return VoidPtrPattern::Copy;
}
}
VoidPtrPattern::Generic
}
fn analyze_body(&self, stmts: &[HirStatement], param_name: &str, info: &mut VoidPtrInfo) {
for stmt in stmts {
self.analyze_statement(stmt, param_name, info);
}
}
fn analyze_statement(&self, stmt: &HirStatement, param_name: &str, info: &mut VoidPtrInfo) {
match stmt {
HirStatement::VariableDeclaration { initializer: Some(init), .. } => {
self.analyze_expression(init, param_name, info);
}
HirStatement::DerefAssignment { target, value } => {
if self.expr_uses_param(target, param_name)
&& !info.constraints.contains(&TypeConstraint::Mutable)
{
info.constraints.push(TypeConstraint::Mutable);
}
if matches!(value, HirExpression::Dereference(_))
&& !info.constraints.contains(&TypeConstraint::Clone)
{
info.constraints.push(TypeConstraint::Clone);
}
self.analyze_expression(target, param_name, info);
self.analyze_expression(value, param_name, info);
}
HirStatement::If { condition, then_block, else_block, .. } => {
self.analyze_expression(condition, param_name, info);
self.analyze_body(then_block, param_name, info);
if let Some(else_stmts) = else_block {
self.analyze_body(else_stmts, param_name, info);
}
}
HirStatement::While { condition, body, .. } => {
self.analyze_expression(condition, param_name, info);
self.analyze_body(body, param_name, info);
}
HirStatement::For { body, .. } => {
self.analyze_body(body, param_name, info);
}
HirStatement::Expression(expr) => {
self.analyze_expression(expr, param_name, info);
}
HirStatement::Return(Some(expr)) => {
self.analyze_expression(expr, param_name, info);
}
_ => {}
}
}
fn analyze_expression(&self, expr: &HirExpression, param_name: &str, info: &mut VoidPtrInfo) {
match expr {
HirExpression::Cast { expr: inner, target_type } => {
if self.expr_uses_param(inner, param_name) {
if let HirType::Pointer(inner_type) = target_type {
if !info.inferred_types.contains(inner_type) {
info.inferred_types.push((**inner_type).clone());
}
}
}
}
HirExpression::BinaryOp { op, left, right } => {
let uses_param = self.expr_uses_param(left, param_name)
|| self.expr_uses_param(right, param_name);
if uses_param {
use decy_hir::BinaryOperator;
match op {
BinaryOperator::LessThan
| BinaryOperator::GreaterThan
| BinaryOperator::LessEqual
| BinaryOperator::GreaterEqual => {
if !info.constraints.contains(&TypeConstraint::PartialOrd) {
info.constraints.push(TypeConstraint::PartialOrd);
}
}
BinaryOperator::Equal | BinaryOperator::NotEqual => {
if !info.constraints.contains(&TypeConstraint::PartialEq) {
info.constraints.push(TypeConstraint::PartialEq);
}
}
_ => {}
}
}
self.analyze_expression(left, param_name, info);
self.analyze_expression(right, param_name, info);
}
HirExpression::Dereference(inner) => {
self.analyze_expression(inner, param_name, info);
}
HirExpression::FunctionCall { arguments, .. } => {
for arg in arguments {
self.analyze_expression(arg, param_name, info);
}
}
_ => {}
}
}
fn expr_uses_param(&self, expr: &HirExpression, param_name: &str) -> bool {
match expr {
HirExpression::Variable(name) => name == param_name,
HirExpression::Cast { expr: inner, .. } => self.expr_uses_param(inner, param_name),
HirExpression::Dereference(inner) => self.expr_uses_param(inner, param_name),
_ => false,
}
}
}
impl Default for VoidPtrAnalyzer {
fn default() -> Self {
Self::new()
}
}