aver-lang 0.15.0

VM and transpiler for Aver, a statically-typed language designed for AI-assisted development
Documentation
use super::expr::aver_name_to_lean;
use crate::ast::{BinOp, Expr, FnDef, Literal, MatchArm, Pattern, Spanned};

/// Integer affine combination `prev2_coeff * prev2 + prev1_coeff * prev1 + constant`
/// over the two previous values of a second-order recurrence (`prev2` is the value
/// two steps back, `prev1` the value one step back).
///
/// Coefficient arithmetic uses plain `i64` operators, so overflow follows Rust's
/// default rules (panic when overflow checks are enabled, wrap otherwise).
/// NOTE(review): coefficients originate from source literals; checked arithmetic
/// would let extreme literals be rejected instead of panicking or wrapping.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct AffinePairExpr {
    /// Coefficient of the value two steps back.
    pub prev2_coeff: i64,
    /// Coefficient of the value one step back.
    pub prev1_coeff: i64,
    /// Constant term.
    pub constant: i64,
}

impl AffinePairExpr {
    /// The expression `prev2`.
    fn prev2() -> Self {
        Self {
            prev2_coeff: 1,
            ..Self::constant(0)
        }
    }

    /// The expression `prev1`.
    fn prev1() -> Self {
        Self {
            prev1_coeff: 1,
            ..Self::constant(0)
        }
    }

    /// The constant expression `value` (both coefficients zero).
    fn constant(value: i64) -> Self {
        Self {
            constant: value,
            prev1_coeff: 0,
            prev2_coeff: 0,
        }
    }

    /// Term-wise sum `self + rhs`.
    fn add(self, rhs: Self) -> Self {
        let Self {
            prev2_coeff: lhs2,
            prev1_coeff: lhs1,
            constant: lhs0,
        } = self;
        Self {
            prev2_coeff: lhs2 + rhs.prev2_coeff,
            prev1_coeff: lhs1 + rhs.prev1_coeff,
            constant: lhs0 + rhs.constant,
        }
    }

    /// Term-wise difference `self - rhs`.
    fn sub(self, rhs: Self) -> Self {
        let Self {
            prev2_coeff: lhs2,
            prev1_coeff: lhs1,
            constant: lhs0,
        } = self;
        Self {
            prev2_coeff: lhs2 - rhs.prev2_coeff,
            prev1_coeff: lhs1 - rhs.prev1_coeff,
            constant: lhs0 - rhs.constant,
        }
    }

    /// Multiplies every term (both coefficients and the constant) by `factor`.
    fn scale(self, factor: i64) -> Self {
        let Self {
            prev2_coeff,
            prev1_coeff,
            constant,
        } = self;
        Self {
            prev2_coeff: prev2_coeff * factor,
            prev1_coeff: prev1_coeff * factor,
            constant: constant * factor,
        }
    }
}

/// Shape recognised by [`detect_second_order_int_linear_recurrence`]: a function
/// with a single `Int` parameter, guarded by `param < 0`, whose non-negative branch
/// matches on the parameter with base cases `0` and `1` and a wildcard arm that is
/// an affine combination of `f(param - 1)`, `f(param - 2)` and `Int` literals.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SecondOrderIntLinearRecurrenceShape {
    /// Name of the single `Int` parameter.
    pub param_name: String,
    /// Body of the `true` arm of the `param < 0` guard.
    pub negative_branch: Spanned<Expr>,
    /// Body of the `0` arm of the inner match.
    pub base0: Spanned<Expr>,
    /// Body of the `1` arm of the inner match.
    pub base1: Spanned<Expr>,
    /// Wildcard-arm body, with `f(param - 2)` as `prev2` and `f(param - 1)` as `prev1`.
    pub recurrence: AffinePairExpr,
}

/// Shape recognised by [`detect_tailrec_int_linear_pair_worker`]: a
/// `(count, prev, curr)` worker over `Int`s that returns `prev` when `count` is `0`
/// and otherwise recurses with `(count - 1, curr, <recurrence>)`.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct TailrecIntLinearPairWorkerShape {
    /// Name of the countdown parameter (first parameter).
    pub count_param_name: String,
    /// Name of the second parameter; returned when the count reaches `0`.
    pub prev_param_name: String,
    /// Name of the third parameter; passed as the next `prev`.
    pub curr_param_name: String,
    /// Next `curr` value, with `prev` as `prev2` and `curr` as `prev1`.
    pub recurrence: AffinePairExpr,
}

/// Shape recognised by [`detect_tailrec_int_linear_pair_wrapper`]: a function with a
/// single `Int` parameter, guarded by `param < 0`, whose non-negative branch is a
/// call `helper(param, seed_prev, seed_curr)`.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct TailrecIntLinearPairWrapperShape {
    /// Name of the single `Int` parameter.
    pub param_name: String,
    /// Body of the `true` arm of the `param < 0` guard.
    pub negative_branch: Spanned<Expr>,
    /// Name of the called helper as written at the call site. The detector does
    /// not check that it actually is a linear-pair worker.
    pub helper_fn_name: String,
    /// Second argument of the helper call.
    pub seed_prev: Spanned<Expr>,
    /// Third argument of the helper call.
    pub seed_curr: Spanned<Expr>,
}

/// Builds the helper name `<lean name>__nat` for the Aver function `fn_name`.
pub(crate) fn recurrence_nat_helper_name(fn_name: &str) -> String {
    let lean_name = aver_name_to_lean(fn_name);
    format!("{lean_name}__nat")
}

/// Builds the helper name `<lean name>__fuel` for the Aver function `fn_name`.
pub(crate) fn fuel_helper_name(fn_name: &str) -> String {
    let lean_name = aver_name_to_lean(fn_name);
    format!("{lean_name}__fuel")
}

/// Renders `expr` as an infix sum over the operand strings `prev2` and `prev1`.
///
/// Non-zero terms appear in the order `prev2`, `prev1`, constant, joined with
/// `" + "`. Negative coefficients are kept inside the term (e.g. `x + -5`). The
/// all-zero expression renders as `"0"`.
pub(crate) fn render_affine_pair_expr(expr: AffinePairExpr, prev2: &str, prev1: &str) -> String {
    let mut parts: Vec<String> = Vec::new();
    for (coefficient, operand) in [(expr.prev2_coeff, prev2), (expr.prev1_coeff, prev1)] {
        push_affine_term(&mut parts, coefficient, operand);
    }
    // Emit the constant when it is non-zero, or when it is the only thing left to print.
    let needs_constant = expr.constant != 0 || parts.is_empty();
    if needs_constant {
        parts.push(expr.constant.to_string());
    }
    parts.join(" + ")
}

/// Appends `coeff * value` to `terms`. Zero terms are skipped, `±1` is written
/// without an explicit multiplication, and everything else is parenthesised.
fn push_affine_term(terms: &mut Vec<String>, coeff: i64, value: &str) {
    let rendered = match coeff {
        0 => return,
        1 => value.to_string(),
        -1 => format!("-({value})"),
        other => format!("({other}) * ({value})"),
    };
    terms.push(rendered);
}

pub(crate) fn detect_second_order_int_linear_recurrence(
    fd: &FnDef,
) -> Option<SecondOrderIntLinearRecurrenceShape> {
    let [(param_name, param_type)] = fd.params.as_slice() else {
        return None;
    };
    if param_type != "Int" || fd.return_type != "Int" {
        return None;
    }

    let body = fd.body.tail_expr()?;
    let (negative_branch, nonnegative_branch) = split_negative_guard(&body.node, param_name)?;
    let Expr::Match {
        subject: inner_subject,
        arms,
    } = nonnegative_branch
    else {
        return None;
    };
    if !matches!(&inner_subject.node, Expr::Ident(name) if name == param_name) {
        return None;
    }

    let base0 = match_arm_body_for_int(arms, 0)?;
    let base1 = match_arm_body_for_int(arms, 1)?;
    let recursive_branch = match_arm_body_for_wildcard(arms)?;
    let recurrence = parse_recurrence_affine(&recursive_branch.node, &fd.name, param_name)?;

    Some(SecondOrderIntLinearRecurrenceShape {
        param_name: param_name.clone(),
        negative_branch: negative_branch.clone(),
        base0: base0.clone(),
        base1: base1.clone(),
        recurrence,
    })
}

pub(crate) fn detect_tailrec_int_linear_pair_worker(
    fd: &FnDef,
) -> Option<TailrecIntLinearPairWorkerShape> {
    let [
        (count_param_name, count_ty),
        (prev_param_name, prev_ty),
        (curr_param_name, curr_ty),
    ] = fd.params.as_slice()
    else {
        return None;
    };
    if count_ty != "Int" || prev_ty != "Int" || curr_ty != "Int" || fd.return_type != "Int" {
        return None;
    }

    let body = fd.body.tail_expr()?;
    let Expr::Match { subject, arms, .. } = &body.node else {
        return None;
    };
    if !matches!(&subject.node, Expr::Ident(name) if name == count_param_name) {
        return None;
    }

    let zero_branch = match_arm_body_for_int(arms, 0)?;
    if !matches!(&zero_branch.node, Expr::Ident(name) if name == prev_param_name) {
        return None;
    }

    let recursive_branch = match_arm_body_for_wildcard(arms)?;
    let args: &[Spanned<Expr>] = match &recursive_branch.node {
        Expr::FnCall(callee, args) if matches!(&callee.node, Expr::Ident(name) if name == &fd.name) => {
            args.as_slice()
        }
        Expr::TailCall(call) if call.target == fd.name => call.args.as_slice(),
        _ => return None,
    };
    if args.len() != 3 {
        return None;
    }
    if !matches_int_sub_positive(&args[0].node, count_param_name, 1)
        || !matches!(&args[1].node, Expr::Ident(name) if name == curr_param_name.as_str())
    {
        return None;
    }

    let recurrence = parse_affine_expr(&args[2].node, prev_param_name, curr_param_name)?;
    Some(TailrecIntLinearPairWorkerShape {
        count_param_name: count_param_name.clone(),
        prev_param_name: prev_param_name.clone(),
        curr_param_name: curr_param_name.clone(),
        recurrence,
    })
}

/// Recognises a single-`Int`-parameter wrapper of the (schematic) form
///
/// ```text
/// fn f(n: Int) -> Int
///     match n < 0
///         true  -> <negative_branch>
///         false -> helper(n, <seed_prev>, <seed_curr>)
/// ```
///
/// Only the call shape is checked here. Whether `helper` is really a linear-pair
/// worker is not verified by this function.
pub(crate) fn detect_tailrec_int_linear_pair_wrapper(
    fd: &FnDef,
) -> Option<TailrecIntLinearPairWrapperShape> {
    let [(param_name, param_type)] = fd.params.as_slice() else {
        return None;
    };
    if !(param_type == "Int" && fd.return_type == "Int") {
        return None;
    }

    let body = fd.body.tail_expr()?;
    let (negative_branch, guarded_call) = split_negative_guard(&body.node, param_name)?;
    let Expr::FnCall(callee, call_args) = guarded_call else {
        return None;
    };
    let (Expr::Ident(helper), [count_arg, seed_prev, seed_curr]) =
        (&callee.node, call_args.as_slice())
    else {
        return None;
    };

    // The parameter must be forwarded unchanged as the helper's count argument.
    let forwards_param = matches!(&count_arg.node, Expr::Ident(name) if name == param_name);
    forwards_param.then(|| TailrecIntLinearPairWrapperShape {
        param_name: param_name.clone(),
        negative_branch: negative_branch.clone(),
        helper_fn_name: helper.clone(),
        seed_prev: seed_prev.clone(),
        seed_curr: seed_curr.clone(),
    })
}

/// Splits `match <param> < 0 { true -> neg, false -> nonneg }` into its branches.
///
/// Returns `(negative_branch, nonnegative_branch_expr)`. The subject must be
/// exactly `param_name < 0` (literal `0` on the right). The match must have
/// exactly one `true` arm and one `false` arm, in either order; anything else
/// yields `None`.
fn split_negative_guard<'a>(
    expr: &'a Expr,
    param_name: &str,
) -> Option<(&'a Spanned<Expr>, &'a Expr)> {
    let Expr::Match { subject, arms, .. } = expr else {
        return None;
    };
    let Expr::BinOp(BinOp::Lt, left, right) = &subject.node else {
        return None;
    };
    if !matches!(&left.node, Expr::Ident(name) if name == param_name)
        || !matches!(&right.node, Expr::Literal(Literal::Int(0)))
    {
        return None;
    }

    // Exactly two arms. With a duplicated `true`/`false` arm, a last-wins scan would
    // pick a different body than first-match evaluation, so such matches are rejected.
    if arms.len() != 2 {
        return None;
    }

    let mut negative_branch = None;
    let mut nonnegative_branch = None;
    for arm in arms {
        match &arm.pattern {
            Pattern::Literal(Literal::Bool(true)) => negative_branch = Some(arm.body.as_ref()),
            Pattern::Literal(Literal::Bool(false)) => nonnegative_branch = Some(arm.body.as_ref()),
            _ => return None,
        }
    }

    // Two `true` (or two `false`) arms leave the other branch `None` here.
    Some((negative_branch?, &nonnegative_branch?.node))
}

/// Body of the first arm whose pattern is the `Int` literal `expected`, if any.
fn match_arm_body_for_int(arms: &[MatchArm], expected: i64) -> Option<&Spanned<Expr>> {
    arms.iter()
        .find(|arm| matches!(&arm.pattern, Pattern::Literal(Literal::Int(v)) if *v == expected))
        .map(|arm| arm.body.as_ref())
}

/// Body of the first `_` arm, if any.
fn match_arm_body_for_wildcard(arms: &[MatchArm]) -> Option<&Spanned<Expr>> {
    arms.iter()
        .find(|arm| matches!(arm.pattern, Pattern::Wildcard))
        .map(|arm| arm.body.as_ref())
}

/// Parses the recursive arm of a second-order recurrence into an [`AffinePairExpr`].
///
/// Leaves are `Int` literals and self-calls `fn_name(param_name - 1)` (mapped to
/// `prev1`) or `fn_name(param_name - 2)` (mapped to `prev2`). Leaves may be combined
/// with `+`, `-`, and with `*` when at least one factor is an `Int` literal.
/// Any other expression yields `None`.
fn parse_recurrence_affine(expr: &Expr, fn_name: &str, param_name: &str) -> Option<AffinePairExpr> {
    let recurse = |sub: &Expr| parse_recurrence_affine(sub, fn_name, param_name);
    match expr {
        Expr::Literal(Literal::Int(value)) => Some(AffinePairExpr::constant(*value)),
        Expr::FnCall(callee, args) => {
            let calls_self = matches!(&callee.node, Expr::Ident(name) if name == fn_name);
            if !calls_self {
                return None;
            }
            let [arg] = args.as_slice() else {
                return None;
            };
            match int_sub_positive_offset(&arg.node, param_name)? {
                1 => Some(AffinePairExpr::prev1()),
                2 => Some(AffinePairExpr::prev2()),
                _ => None,
            }
        }
        Expr::BinOp(BinOp::Add, lhs, rhs) => Some(recurse(&lhs.node)?.add(recurse(&rhs.node)?)),
        Expr::BinOp(BinOp::Sub, lhs, rhs) => Some(recurse(&lhs.node)?.sub(recurse(&rhs.node)?)),
        Expr::BinOp(BinOp::Mul, lhs, rhs) => {
            // Only multiplication by a literal keeps the expression affine.
            match (int_literal(&lhs.node), int_literal(&rhs.node)) {
                (Some(factor), _) => Some(recurse(&rhs.node)?.scale(factor)),
                (None, Some(factor)) => Some(recurse(&lhs.node)?.scale(factor)),
                (None, None) => None,
            }
        }
        _ => None,
    }
}

/// Parses an affine expression over two named variables into an [`AffinePairExpr`].
///
/// Leaves are `Int` literals, the identifier `prev2_name` (mapped to `prev2`) and
/// the identifier `prev1_name` (mapped to `prev1`). Leaves may be combined with `+`,
/// `-`, and with `*` when at least one factor is an `Int` literal. Any other
/// expression yields `None`.
fn parse_affine_expr(expr: &Expr, prev2_name: &str, prev1_name: &str) -> Option<AffinePairExpr> {
    let recurse = |sub: &Expr| parse_affine_expr(sub, prev2_name, prev1_name);
    match expr {
        Expr::Literal(Literal::Int(value)) => Some(AffinePairExpr::constant(*value)),
        Expr::Ident(name) if name == prev2_name => Some(AffinePairExpr::prev2()),
        Expr::Ident(name) if name == prev1_name => Some(AffinePairExpr::prev1()),
        Expr::BinOp(BinOp::Add, lhs, rhs) => {
            let sum = recurse(&lhs.node)?.add(recurse(&rhs.node)?);
            Some(sum)
        }
        Expr::BinOp(BinOp::Sub, lhs, rhs) => {
            let difference = recurse(&lhs.node)?.sub(recurse(&rhs.node)?);
            Some(difference)
        }
        Expr::BinOp(BinOp::Mul, lhs, rhs) => {
            // Only multiplication by a literal keeps the expression affine.
            match (int_literal(&lhs.node), int_literal(&rhs.node)) {
                (Some(factor), _) => Some(recurse(&rhs.node)?.scale(factor)),
                (None, Some(factor)) => Some(recurse(&lhs.node)?.scale(factor)),
                (None, None) => None,
            }
        }
        _ => None,
    }
}

/// The value of `expr` if it is an `Int` literal.
fn int_literal(expr: &Expr) -> Option<i64> {
    if let Expr::Literal(Literal::Int(value)) = expr {
        Some(*value)
    } else {
        None
    }
}

/// Whether `expr` is exactly `param_name - expected` with a positive literal offset.
fn matches_int_sub_positive(expr: &Expr, param_name: &str, expected: i64) -> bool {
    matches!(int_sub_positive_offset(expr, param_name), Some(offset) if offset == expected)
}

/// For `expr` of the form `param_name - k`, where `k` is an `Int` literal greater
/// than zero, returns `k`. Otherwise returns `None`.
fn int_sub_positive_offset(expr: &Expr, param_name: &str) -> Option<i64> {
    match expr {
        Expr::BinOp(BinOp::Sub, minuend, subtrahend)
            if matches!(&minuend.node, Expr::Ident(name) if name == param_name) =>
        {
            int_literal(&subtrahend.node).filter(|offset| *offset > 0)
        }
        _ => None,
    }
}