// eqlog 0.9.0
//
// Datalog with equality
// Documentation
use crate::flat_eqlog::*;
use crate::fmt_util::*;
use crate::rust_gen::flat_eqlog::display_flat_rule;
use crate::rust_gen::*;
use convert_case::{Case, Casing};
use indoc::writedoc;
use itertools::Itertools;
use std::fmt::{Display, Formatter, Result};

use Case::{Snake, UpperCamel};

fn display_imports<'a>() -> impl 'a + Display {
    FmtFn(move |f: &mut Formatter| -> Result {
        writedoc! {f, "
            #[allow(unused)]
            use eqlog_runtime::*;
            #[allow(unused)]
            use std::cell::LazyCell;
        "}
    })
}

/// Renders the name of the environment struct generated for `ram_module`:
/// the module name converted to upper camel case, followed by `Env`.
pub fn display_module_env_struct_name<'a>(ram_module: &'a RamModule) -> impl 'a + Display {
    FmtFn(move |f| write!(f, "{}Env", ram_module.name.to_case(UpperCamel)))
}

/// Collects every `(relation, index spec)` pair that the module reads.
///
/// A pair is included when some routine contains a `DefineSet` statement whose
/// expression is a `GetIndex`. `Restrict` expressions and all other statement
/// kinds contribute nothing. The result is a set, so repeated pairs collapse.
pub fn module_env_in_rels(ram_module: &RamModule) -> BTreeSet<(FlatInRel, IndexSpec)> {
    let mut in_rels = BTreeSet::new();

    for routine in ram_module.routines.iter() {
        for stmt in routine.stmts.iter() {
            // Both matches are exhaustive on purpose: a new variant must be
            // classified here explicitly.
            match stmt {
                RamStmt::DefineSet(define_set_stmt) => match &define_set_stmt.expr {
                    InSetExpr::GetIndex(GetIndexExpr { rel, index_spec }) => {
                        in_rels.insert((rel.clone(), index_spec.clone()));
                    }
                    InSetExpr::Restrict(_) => {}
                },
                RamStmt::Iter(_) | RamStmt::Insert(_) | RamStmt::GuardInhabited(_) => {}
            }
        }
    }

    in_rels
}

/// Collects every output relation that some `Insert` statement of the module
/// writes to. Statements of any other kind are ignored.
pub fn module_env_out_rels(ram_module: &RamModule) -> BTreeSet<FlatOutRel> {
    ram_module
        .routines
        .iter()
        .flat_map(|routine| routine.stmts.iter())
        .filter_map(|stmt| match stmt {
            RamStmt::Insert(InsertStmt { rel, args: _ }) => Some(rel.clone()),
            RamStmt::DefineSet(_) | RamStmt::Iter(_) | RamStmt::GuardInhabited(_) => None,
        })
        .collect()
}

/// Renders the definition of the environment struct for `ram_module`.
///
/// The struct has a `PhantomData` field tying it to the lifetime `'a`, one
/// shared-reference field per input index (see `module_env_in_rels`) and one
/// mutable-reference field per output relation (see `module_env_out_rels`).
pub fn display_module_env_struct<'a>(
    ram_module: &'a RamModule,
    ctx: &'a RustGenCtx<'a>,
) -> impl 'a + Display {
    FmtFn(move |f: &mut Formatter| -> Result {
        let struct_name = display_module_env_struct_name(ram_module);

        // One `name: &'a Type,` line per index that the module reads.
        let index_fields = module_env_in_rels(ram_module)
            .into_iter()
            .map(|(in_rel, spec)| {
                FmtFn(move |f| {
                    let field = display_index_field_name(&in_rel, &spec, ctx);
                    let field_type = display_index_type(&in_rel, ctx);
                    write!(f, "{field}: &'a {field_type},")
                })
            })
            .format("\n");

        // One `name: &'a mut Type,` line per relation that the module writes.
        let out_set_fields = module_env_out_rels(ram_module)
            .into_iter()
            .map(|out_rel| {
                FmtFn(move |f| {
                    let field = display_out_set_field_name(&out_rel, ctx);
                    let field_type = display_out_set_type(&out_rel, ctx);
                    write!(f, "{field}: &'a mut {field_type},")
                })
            })
            .format("\n");

        writedoc! {f, "
            #[allow(unused)]
            pub struct {struct_name}<'a> {{
            phantom: std::marker::PhantomData<&'a ()>,

            {index_fields}

            {out_set_fields}
            }}
        "}
    })
}

/// Renders the identifier used for `set_var` in generated code.
///
/// The identifier has the shape `set{stmt_index}_{field}_r{restricted}`, where
/// `field` is the index field name of the variable's relation and index.
fn display_set_var<'a>(set_var: &'a SetVar, ctx: &'a RustGenCtx<'a>) -> impl 'a + Display {
    FmtFn(move |f| {
        // Borrow the name's fields rather than cloning the whole `SetVarName`
        // each time the value is formatted; the fields are only read.
        let SetVarName {
            stmt_index,
            rel,
            index,
            restricted,
        } = &set_var.name;
        let field_name = display_index_field_name(rel, index, ctx);
        write!(f, "set{stmt_index}_{field_name}_r{restricted}")
    })
}

/// Renders the Rust expression that evaluates `expr` in generated code.
///
/// * `GetIndex` becomes an access to the matching index field of `env`.
/// * `Restrict` becomes a `.get(..)` lookup on the source set variable that
///   falls back to an empty `PrefixTree{N}`, where `N` is the source set's
///   arity minus one.
fn display_in_set_expr<'a>(expr: &'a InSetExpr, ctx: &'a RustGenCtx<'a>) -> impl 'a + Display {
    FmtFn(move |f| {
        match expr {
            InSetExpr::GetIndex(GetIndexExpr { rel, index_spec }) => {
                let env_field = display_index_field_name(rel, index_spec, ctx);
                write!(f, "env.{env_field}")
            }
            InSetExpr::Restrict(RestrictExpr {
                set,
                first_column_var,
            }) => {
                let source_set = display_set_var(set, ctx);
                // Restricting on the first column drops one column.
                let restricted_arity = set.arity - 1;
                write!(f, "{source_set}.get({first_column_var}).unwrap_or_else(|| PrefixTree{restricted_arity}::empty())")
            }
        }
    })
}

/// Renders the code that `ram_stmt` contributes before the statements nested
/// after it.
///
/// * `DefineSet` emits a `let` binding, wrapped in `LazyCell::new` when the
///   defined variable is lazy.
/// * `Iter` emits the header and opening brace of a `for` loop over the
///   chained `iter_restrictions()` of all its sets.
/// * `Insert` emits a `push` onto the output set field of `env`.
/// * `GuardInhabited` emits the header and opening brace of an `if` that holds
///   when at least one of its sets is non-empty.
///
/// The braces opened by `Iter` and `GuardInhabited` are closed by
/// `display_stmt_post`.
///
/// # Panics
///
/// Panics for an `Iter` statement whose loop set variable is lazy, or whose
/// list of sets is empty.
fn display_stmt_pre<'a>(ram_stmt: &'a RamStmt, ctx: &'a RustGenCtx<'a>) -> impl 'a + Display {
    FmtFn(move |f| match ram_stmt {
        RamStmt::DefineSet(DefineSetStmt { defined_var, expr }) => {
            let var_name = display_set_var(defined_var, ctx);
            let value = display_in_set_expr(expr, ctx);
            match defined_var.strictness {
                Strictness::Strict => {
                    writedoc! {f, "
                        let {var_name} =
                        {value}
                        ;
                    "}
                }
                Strictness::Lazy => {
                    writedoc! {f, "
                        let {var_name} =
                        LazyCell::new(|| {{
                        {value}
                        }});
                    "}
                }
            }
        }
        RamStmt::Iter(IterStmt {
            sets,
            loop_var_el,
            loop_var_set,
        }) => {
            match loop_var_set.strictness {
                Strictness::Strict => {}
                Strictness::Lazy => {
                    panic!("Loop set variables for iter statements must be strict")
                }
            }

            // The first set starts the iterator; every further set is chained on.
            let mut remaining_sets = sets.iter();
            let first_set = remaining_sets
                .next()
                .expect("Expected at least one set in IterStmt");
            let first_set_name = display_set_var(first_set, ctx);
            let chained_iters = remaining_sets
                .map(|chained_set| {
                    FmtFn(move |f| {
                        let chained_name = display_set_var(chained_set, ctx);
                        write!(f, ".chain({chained_name}.iter_restrictions())")
                    })
                })
                .format("\n");
            let loop_set_name = display_set_var(loop_var_set, ctx);

            writedoc! {f, "
                #[allow(unused_variables)]
                for
                ({loop_var_el}, {loop_set_name})
                in
                {first_set_name}.iter_restrictions()
                {chained_iters}
                {{
            "}
        }
        RamStmt::Insert(InsertStmt { rel, args }) => {
            // TODO: Check that this row doesn't exist already in indices.
            let out_field = display_out_set_field_name(rel, ctx);
            let row = args.iter().format(", ");
            writeln!(f, "env.{out_field}.push([{row}]);")
        }
        RamStmt::GuardInhabited(GuardInhabitedStmt { sets }) => {
            // Starting from `false` lets every set contribute a uniform
            // `|| !set.is_empty()` term.
            let non_empty_terms = sets
                .iter()
                .map(|guarded_set| {
                    FmtFn(move |f| {
                        let guarded_name = display_set_var(guarded_set, ctx);
                        write!(f, "|| !{guarded_name}.is_empty()")
                    })
                })
                .format("");
            writeln!(f, "if false {non_empty_terms} {{")
        }
    })
}

fn display_stmt_post<'a>(ram_stmt: &'a RamStmt) -> impl 'a + Display {
    FmtFn(move |f| match ram_stmt {
        RamStmt::DefineSet(_) | RamStmt::Insert(_) => Ok(()),
        RamStmt::Iter(_) | RamStmt::GuardInhabited(_) => {
            writedoc! {f, "
                    }}
                "}
        }
    })
}

/// Renders one routine as a Rust function taking `env: &mut <module env struct>`.
///
/// The routine's flat rule is rendered above the function as `//` comment
/// lines. The body is the "pre" part of every statement in order, followed by
/// the "post" part of every statement in reverse order, so that blocks opened
/// by earlier statements are closed last.
fn display_routine<'a>(
    RamRoutine {
        name,
        flat_rule,
        stmts,
    }: &'a RamRoutine,
    ram_module: &'a RamModule,
    ctx: &'a RustGenCtx<'a>,
) -> impl 'a + Display {
    FmtFn(move |f| {
        // Render the rule to a string first so it can be split into lines.
        let rule_text = display_flat_rule(flat_rule, ctx).to_string();
        let rule_comment = rule_text
            .lines()
            .map(|rule_line| FmtFn(move |f| write!(f, "// {rule_line}")))
            .format("\n");

        let env_struct = display_module_env_struct_name(ram_module);

        let body_open = stmts
            .iter()
            .map(|stmt| display_stmt_pre(stmt, ctx))
            .format("\n");
        let body_close = stmts
            .iter()
            .rev()
            .map(|stmt| display_stmt_post(stmt))
            .format("\n");

        writedoc! {f, "
            {rule_comment}
            fn {name}(env: &mut {env_struct}) {{
            {body_open}
            {body_close}
            }}
        "}
    })
}

/// Renders the name of the module's main function: the module name converted
/// to snake case.
pub fn display_module_main_fn_name<'a>(ram_module: &'a RamModule) -> impl 'a + Display {
    let module_name = &ram_module.name;
    module_name.to_case(Snake)
}

pub fn display_module_main_fn_decl<'a>(
    ram_module: &'a RamModule,
    symbol_prefix: &'a str,
) -> impl 'a + Display {
    FmtFn(move |f| {
        let fn_name = display_module_main_fn_name(ram_module);
        let env_name = display_module_env_struct_name(ram_module);
        writedoc! {f, r#"
            #[link_name = "{symbol_prefix}_{fn_name}"]
            safe fn {fn_name}(env: {env_name});
        "#}
    })
}

/// Renders the definition of the module's main function.
///
/// The function is emitted with `#[unsafe(no_mangle)]` under the name
/// `{symbol_prefix}_{main function name}`. It takes the environment struct by
/// value and calls every routine of the module in order with `&mut env`.
fn display_module_main_fn<'a>(
    ram_module: &'a RamModule,
    symbol_prefix: &'a str,
) -> impl 'a + Display {
    FmtFn(move |f| {
        let routine_calls = ram_module
            .routines
            .iter()
            .map(|routine| FmtFn(move |f| write!(f, "{}(&mut env);", routine.name)))
            .format("\n");
        let entry_point = display_module_main_fn_name(ram_module);
        let env_struct = display_module_env_struct_name(ram_module);

        writedoc! {f, r#"
            #[unsafe(no_mangle)]
            pub fn {symbol_prefix}_{entry_point}(mut env: {env_struct}) {{
            {routine_calls}
            }}
        "#}
    })
}

/// Renders the complete generated source for `ram_module`.
///
/// The sections appear in this order: imports, the environment struct, one
/// function per routine, and the main function that calls the routines.
/// `_index_selection` is accepted but not used.
pub fn display_ram_module<'a>(
    ram_module: &'a RamModule,
    _index_selection: &'a IndexSelection,
    ctx: &'a RustGenCtx<'a>,
    symbol_prefix: &'a str,
) -> impl 'a + Display {
    FmtFn(move |f: &mut Formatter| -> Result {
        let routine_defs = ram_module
            .routines
            .iter()
            .map(|routine| display_routine(routine, ram_module, ctx))
            .format("\n");

        writeln!(f, "{}", display_imports())?;
        writeln!(f, "{}", display_module_env_struct(ram_module, ctx))?;
        writeln!(f, "{routine_defs}")?;
        writeln!(f, "{}", display_module_main_fn(ram_module, symbol_prefix))
    })
}