use super::functions::*;
use oxilean_kernel::Name;
use std::collections::{HashMap, HashSet};
/// Summary counters attached to an [`LcnfModule`].
///
/// All counters start at zero via `Default`; the passes that fill them in
/// live outside this file (TODO confirm which pass owns each counter).
#[derive(Default, Debug, Clone)]
pub struct LcnfModuleMetadata {
    /// Number of declarations (presumably function declarations in the module).
    pub decl_count: usize,
    /// Number of lambdas lifted to top-level declarations.
    pub lambdas_lifted: usize,
    /// Number of proof terms erased.
    pub proofs_erased: usize,
    /// Number of type terms erased.
    pub types_erased: usize,
    /// Number of `let` bindings.
    pub let_bindings: usize,
}
/// An atomic argument: the only operand form allowed in applications,
/// constructor fields and returns.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
pub enum LcnfArg {
    /// A reference to a bound variable.
    Var(LcnfVarId),
    /// A literal value.
    Lit(LcnfLit),
    /// A placeholder for an erased argument.
    Erased,
    /// A type passed as an argument.
    Type(LcnfType),
}
/// A compilation unit in LCNF form: its function and extern declarations,
/// its name, and summary metadata.
#[derive(Default, Debug, Clone)]
pub struct LcnfModule {
    /// Function declarations with LCNF bodies.
    pub fun_decls: Vec<LcnfFunDecl>,
    /// Externally implemented declarations (signatures only).
    pub extern_decls: Vec<LcnfExternDecl>,
    /// Module name.
    pub name: String,
    /// Summary counters for this module.
    pub metadata: LcnfModuleMetadata,
}
/// Options for pretty-printing LCNF.
///
/// The printer that consumes these options lives outside this file.
#[derive(Debug, Clone)]
pub struct PrettyConfig {
    /// Indentation width (presumably in spaces — TODO confirm).
    pub indent: usize,
    /// Target maximum line width.
    pub max_width: usize,
    /// Whether to print type annotations.
    pub show_types: bool,
    /// Whether to print erased terms.
    pub show_erased: bool,
}
/// Opaque identifier of an LCNF variable; cheap to copy, hash and compare.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct LcnfVarId(pub u64);
/// Computes the free variables of an LCNF expression.
///
/// `bound` holds the variables in scope at the current point of the walk, and
/// `free` accumulates every variable that is used while not in `bound`.
/// Callers may seed `bound` (e.g. with function parameters — TODO confirm
/// against callers) before calling [`FreeVarCollector::collect_expr`].
pub struct FreeVarCollector {
    pub(super) bound: HashSet<LcnfVarId>,
    pub(super) free: HashSet<LcnfVarId>,
}
impl FreeVarCollector {
    /// Creates a collector with empty `bound` and `free` sets.
    pub(super) fn new() -> Self {
        FreeVarCollector {
            bound: HashSet::new(),
            free: HashSet::new(),
        }
    }
    /// Records a use of `id`: it is free unless currently bound.
    fn collect_var(&mut self, id: LcnfVarId) {
        if !self.bound.contains(&id) {
            self.free.insert(id);
        }
    }
    /// Records the variable occurring in `arg`, if any.
    ///
    /// Literals, erased arguments and types contain no variables.
    pub(super) fn collect_from_arg(&mut self, arg: &LcnfArg) {
        if let LcnfArg::Var(id) = arg {
            self.collect_var(*id);
        }
    }
    /// Records every variable referenced by a let-bound value.
    ///
    /// `Reset` and `Reuse` carry a variable operand, and `Reuse` also carries
    /// argument operands; all of these are uses and must be reported.
    pub(super) fn collect_from_let_value(&mut self, val: &LcnfLetValue) {
        match val {
            LcnfLetValue::App(func, args) => {
                self.collect_from_arg(func);
                for arg in args {
                    self.collect_from_arg(arg);
                }
            }
            LcnfLetValue::Proj(_, _, var) => self.collect_var(*var),
            LcnfLetValue::Ctor(_, _, args) => {
                for arg in args {
                    self.collect_from_arg(arg);
                }
            }
            LcnfLetValue::FVar(id) => self.collect_var(*id),
            LcnfLetValue::Reset(var) => self.collect_var(*var),
            LcnfLetValue::Reuse(var, _, _, args) => {
                self.collect_var(*var);
                for arg in args {
                    self.collect_from_arg(arg);
                }
            }
            LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
        }
    }
    /// Walks `expr`, adding to `free` every variable used outside its binder.
    ///
    /// Scoping is lexical: a `Let` id is bound only within its body, and
    /// case-alternative parameters only within that alternative's body.
    /// On return `bound` is exactly what it was on entry, so one collector
    /// can be reused across sibling expressions.
    pub(super) fn collect_expr(&mut self, expr: &LcnfExpr) {
        match expr {
            LcnfExpr::Let {
                id, value, body, ..
            } => {
                // The value is evaluated before `id` comes into scope.
                self.collect_from_let_value(value);
                let newly_bound = self.bound.insert(*id);
                self.collect_expr(body);
                // Only unbind what this binder introduced; ids that were
                // already bound (e.g. seeded by the caller) must stay bound.
                if newly_bound {
                    self.bound.remove(id);
                }
            }
            LcnfExpr::Case {
                scrutinee,
                alts,
                default,
                ..
            } => {
                self.collect_var(*scrutinee);
                for alt in alts {
                    // Remember only the params that were not already bound,
                    // so removing them afterwards restores the entry state
                    // without cloning the whole set.
                    let newly_bound: Vec<LcnfVarId> = alt
                        .params
                        .iter()
                        .map(|param| param.id)
                        .filter(|id| self.bound.insert(*id))
                        .collect();
                    self.collect_expr(&alt.body);
                    for id in &newly_bound {
                        self.bound.remove(id);
                    }
                }
                if let Some(def) = default {
                    self.collect_expr(def);
                }
            }
            LcnfExpr::Return(arg) => self.collect_from_arg(arg),
            LcnfExpr::Unreachable => {}
            LcnfExpr::TailCall(func, args) => {
                self.collect_from_arg(func);
                for arg in args {
                    self.collect_from_arg(arg);
                }
            }
        }
    }
}
/// The simplified type language used in LCNF.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
pub enum LcnfType {
    /// A type that has been erased.
    Erased,
    /// A named type variable.
    Var(String),
    /// A function type (presumably parameter types and result type).
    Fun(Vec<LcnfType>, Box<LcnfType>),
    /// A named type constructor applied to type arguments (presumably).
    Ctor(String, Vec<LcnfType>),
    /// A generic object type.
    Object,
    /// The natural-number type.
    Nat,
    /// The string type.
    LcnfString,
    /// The unit type.
    Unit,
    /// A computationally irrelevant type.
    Irrelevant,
}
/// Counts how many times each variable is used in an LCNF expression.
///
/// Binding occurrences (`Let` ids, alternative parameters) are not counted;
/// only uses are.
pub struct UsageCounter {
    pub(super) counts: HashMap<LcnfVarId, usize>,
}
impl UsageCounter {
    /// Creates a counter with no recorded uses.
    pub(super) fn new() -> Self {
        UsageCounter {
            counts: HashMap::new(),
        }
    }
    /// Records one use of `id`.
    fn count_var(&mut self, id: LcnfVarId) {
        *self.counts.entry(id).or_insert(0) += 1;
    }
    /// Records one use of the variable in `arg`, if any.
    pub(super) fn count_arg(&mut self, arg: &LcnfArg) {
        if let LcnfArg::Var(id) = arg {
            self.count_var(*id);
        }
    }
    /// Records all variable uses inside a let-bound value.
    ///
    /// `Reset` and `Reuse` carry a variable operand, and `Reuse` also carries
    /// argument operands. Counting them keeps a variable that is still
    /// referenced from looking unused.
    pub(super) fn count_let_value(&mut self, val: &LcnfLetValue) {
        match val {
            LcnfLetValue::App(func, args) => {
                self.count_arg(func);
                for arg in args {
                    self.count_arg(arg);
                }
            }
            LcnfLetValue::Proj(_, _, var) => self.count_var(*var),
            LcnfLetValue::Ctor(_, _, args) => {
                for arg in args {
                    self.count_arg(arg);
                }
            }
            LcnfLetValue::FVar(id) => self.count_var(*id),
            LcnfLetValue::Reset(var) => self.count_var(*var),
            LcnfLetValue::Reuse(var, _, _, args) => {
                self.count_var(*var);
                for arg in args {
                    self.count_arg(arg);
                }
            }
            LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
        }
    }
    /// Records all variable uses in `expr`, including every case alternative
    /// and the default branch.
    pub(super) fn count_expr(&mut self, expr: &LcnfExpr) {
        match expr {
            LcnfExpr::Let { value, body, .. } => {
                self.count_let_value(value);
                self.count_expr(body);
            }
            LcnfExpr::Case {
                scrutinee,
                alts,
                default,
                ..
            } => {
                self.count_var(*scrutinee);
                for alt in alts {
                    self.count_expr(&alt.body);
                }
                if let Some(def) = default {
                    self.count_expr(def);
                }
            }
            LcnfExpr::Return(arg) => self.count_arg(arg),
            LcnfExpr::Unreachable => {}
            LcnfExpr::TailCall(func, args) => {
                self.count_arg(func);
                for arg in args {
                    self.count_arg(arg);
                }
            }
        }
    }
}
/// The right-hand side of an LCNF `let` binding.
#[derive(Debug, PartialEq, Clone)]
pub enum LcnfLetValue {
    /// Application of a function argument to argument operands.
    App(LcnfArg, Vec<LcnfArg>),
    /// Projection: type name, field index, and the variable projected from.
    Proj(String, u32, LcnfVarId),
    /// Constructor application: constructor name, tag, and field arguments.
    Ctor(String, u32, Vec<LcnfArg>),
    /// A literal.
    Lit(LcnfLit),
    /// An erased value.
    Erased,
    /// A reference to another variable.
    FVar(LcnfVarId),
    /// A reset of the given variable (presumably for memory reuse — TODO confirm).
    Reset(LcnfVarId),
    /// A reuse of the given variable's cell for a constructor: variable,
    /// constructor name, tag, and field arguments (presumably — TODO confirm).
    Reuse(LcnfVarId, String, u32, Vec<LcnfArg>),
}
/// A finite mapping from variables to the arguments that replace them.
#[derive(Debug, Default, Clone)]
pub struct Substitution(pub HashMap<LcnfVarId, LcnfArg>);
impl Substitution {
    /// Creates an empty substitution.
    pub fn new() -> Self {
        Self::default()
    }
    /// Maps `var` to `arg`, overwriting any previous mapping for `var`.
    pub fn insert(&mut self, var: LcnfVarId, arg: LcnfArg) {
        let _previous = self.0.insert(var, arg);
    }
    /// Looks up the replacement for `var`, if any.
    pub fn get(&self, var: &LcnfVarId) -> Option<&LcnfArg> {
        let Substitution(map) = self;
        map.get(var)
    }
    /// Returns whether `var` has a replacement.
    pub fn contains(&self, var: &LcnfVarId) -> bool {
        let Substitution(map) = self;
        map.contains_key(var)
    }
    /// Composes `self` with `other`, producing "apply `self`, then `other`".
    ///
    /// Every replacement in `self` is rewritten through `other` with
    /// `substitute_arg`. Mappings of `other` for variables that `self` does
    /// not map are then added unchanged. On overlapping keys, `self` wins.
    pub fn compose(&self, other: &Substitution) -> Substitution {
        let mut composed: HashMap<LcnfVarId, LcnfArg> = self
            .0
            .iter()
            .map(|(var, arg)| (*var, substitute_arg(arg, other)))
            .collect();
        // `composed` has exactly the keys of `self`, so skipping those keys
        // is the same as only filling in the missing ones.
        composed.extend(
            other
                .0
                .iter()
                .filter(|(var, _)| !self.0.contains_key(*var))
                .map(|(var, arg)| (*var, arg.clone())),
        );
        Substitution(composed)
    }
    /// Returns whether no variable is mapped.
    pub fn is_empty(&self) -> bool {
        let Substitution(map) = self;
        map.is_empty()
    }
}
/// A parameter of a function, extern declaration, or case alternative.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
pub struct LcnfParam {
    /// Variable id the parameter binds.
    pub id: LcnfVarId,
    /// Source-level name.
    pub name: String,
    /// Parameter type.
    pub ty: LcnfType,
    /// Whether the parameter has been erased.
    pub erased: bool,
    /// Whether the parameter is passed borrowed (presumably vs owned — TODO confirm).
    pub borrowed: bool,
}
/// Where a variable is defined: the variable, its name and type, and a depth.
#[derive(PartialEq, Debug, Clone)]
pub struct DefinitionSite {
    /// The defined variable.
    pub var: LcnfVarId,
    /// Its source-level name.
    pub name: String,
    /// Its type.
    pub ty: LcnfType,
    /// Nesting depth of the definition (presumably binder depth — TODO confirm).
    pub depth: usize,
}
/// One alternative of an [`LcnfExpr::Case`]: the constructor it matches and
/// the parameters bound within its body.
#[derive(Debug, PartialEq, Clone)]
pub struct LcnfAlt {
    /// Name of the matched constructor.
    pub ctor_name: String,
    /// Tag of the matched constructor.
    pub ctor_tag: u32,
    /// Variables bound to the constructor's fields; in scope only in `body`.
    pub params: Vec<LcnfParam>,
    /// The expression evaluated when this alternative matches.
    pub body: LcnfExpr,
}
/// An externally implemented function: signature only, no body.
#[derive(Debug, PartialEq, Clone)]
pub struct LcnfExternDecl {
    /// Declared name.
    pub name: String,
    /// Parameters.
    pub params: Vec<LcnfParam>,
    /// Return type.
    pub ret_type: LcnfType,
}
/// A function declaration with an LCNF body.
#[derive(Debug, PartialEq, Clone)]
pub struct LcnfFunDecl {
    /// Name of the declaration in LCNF.
    pub name: String,
    /// Kernel name this declaration came from, when there is one.
    pub original_name: Option<Name>,
    /// Parameters.
    pub params: Vec<LcnfParam>,
    /// Return type.
    pub ret_type: LcnfType,
    /// Function body.
    pub body: LcnfExpr,
    /// Whether the function is recursive.
    pub is_recursive: bool,
    /// Whether the function was produced by lambda lifting (presumably — TODO confirm).
    pub is_lifted: bool,
    /// Cost estimate (presumably consulted by inlining decisions — TODO confirm).
    pub inline_cost: usize,
}
/// Problems a validator can report about LCNF code.
#[derive(PartialEq, Debug, Clone)]
pub enum ValidationError {
    /// A variable is used without being bound.
    UnboundVariable(LcnfVarId),
    /// A variable is bound more than once.
    DuplicateBinding(LcnfVarId),
    /// A case expression has no alternatives (presumably also no default).
    EmptyCase,
    /// A constructor tag is invalid for the named constructor or type.
    InvalidTag(String, u32),
    /// A non-atomic term appears where an atomic argument is required.
    NonAtomicArgument,
}
pub struct LcnfBuilder {
pub(super) next_id: u64,
pub(super) bindings: Vec<(LcnfVarId, String, LcnfType, LcnfLetValue)>,
}
impl LcnfBuilder {
pub fn new() -> Self {
LcnfBuilder {
next_id: 0,
bindings: Vec::new(),
}
}
pub fn with_start_id(start: u64) -> Self {
LcnfBuilder {
next_id: start,
bindings: Vec::new(),
}
}
pub fn fresh_var(&mut self, _name: &str, _ty: LcnfType) -> LcnfVarId {
let id = LcnfVarId(self.next_id);
self.next_id += 1;
id
}
pub fn let_bind(&mut self, name: &str, ty: LcnfType, val: LcnfLetValue) -> LcnfVarId {
let id = LcnfVarId(self.next_id);
self.next_id += 1;
self.bindings.push((id, name.to_string(), ty, val));
id
}
pub fn let_app(
&mut self,
name: &str,
ty: LcnfType,
func: LcnfArg,
args: Vec<LcnfArg>,
) -> LcnfVarId {
self.let_bind(name, ty, LcnfLetValue::App(func, args))
}
pub fn let_ctor(
&mut self,
name: &str,
ty: LcnfType,
ctor: &str,
tag: u32,
args: Vec<LcnfArg>,
) -> LcnfVarId {
self.let_bind(name, ty, LcnfLetValue::Ctor(ctor.to_string(), tag, args))
}
pub fn let_proj(
&mut self,
name: &str,
ty: LcnfType,
type_name: &str,
idx: u32,
var: LcnfVarId,
) -> LcnfVarId {
self.let_bind(
name,
ty,
LcnfLetValue::Proj(type_name.to_string(), idx, var),
)
}
pub fn build_return(self, arg: LcnfArg) -> LcnfExpr {
self.wrap_bindings(LcnfExpr::Return(arg))
}
pub fn build_case(
self,
scrutinee: LcnfVarId,
scrutinee_ty: LcnfType,
alts: Vec<LcnfAlt>,
default: Option<LcnfExpr>,
) -> LcnfExpr {
self.wrap_bindings(LcnfExpr::Case {
scrutinee,
scrutinee_ty,
alts,
default: default.map(Box::new),
})
}
pub fn build_tail_call(self, func: LcnfArg, args: Vec<LcnfArg>) -> LcnfExpr {
self.wrap_bindings(LcnfExpr::TailCall(func, args))
}
pub(super) fn wrap_bindings(self, terminal: LcnfExpr) -> LcnfExpr {
let mut result = terminal;
for (id, name, ty, value) in self.bindings.into_iter().rev() {
result = LcnfExpr::Let {
id,
name,
ty,
value,
body: Box::new(result),
};
}
result
}
pub fn peek_next_id(&self) -> u64 {
self.next_id
}
pub fn binding_count(&self) -> usize {
self.bindings.len()
}
}
/// Per-construct weights for estimating the cost of LCNF code.
///
/// The estimator that reads these weights lives outside this file.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Weight of a `let` binding.
    pub let_cost: u64,
    /// Weight of an application.
    pub app_cost: u64,
    /// Weight of a case split.
    pub case_cost: u64,
    /// Weight of a return.
    pub return_cost: u64,
    /// Extra weight charged per branch (presumably per case alternative — TODO confirm).
    pub branch_penalty: u64,
}
/// A literal constant in LCNF.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
pub enum LcnfLit {
    /// A natural-number literal.
    Nat(u64),
    /// A string literal.
    Str(String),
}
/// An LCNF expression: a sequence of `let` bindings ending in a terminal
/// form (case split, return, tail call, or unreachable).
#[derive(Debug, PartialEq, Clone)]
pub enum LcnfExpr {
    /// `let id : ty := value; body`, where `id` is in scope only in `body`.
    Let {
        id: LcnfVarId,
        name: String,
        ty: LcnfType,
        value: LcnfLetValue,
        body: Box<LcnfExpr>,
    },
    /// A case split on a variable, with constructor alternatives and an
    /// optional default branch.
    Case {
        scrutinee: LcnfVarId,
        scrutinee_ty: LcnfType,
        alts: Vec<LcnfAlt>,
        default: Option<Box<LcnfExpr>>,
    },
    /// Returns an atomic argument.
    Return(LcnfArg),
    /// Marks a point that should never be reached.
    Unreachable,
    /// A call in tail position: function argument and call arguments.
    TailCall(LcnfArg, Vec<LcnfArg>),
}