use crate::declaration::{ConstructorVal, InductiveVal};
use crate::def_eq::DefEqChecker;
use crate::error::KernelError;
use crate::expr_util::{get_app_args, get_app_fn, has_any_fvar, mk_app};
use crate::instantiate::instantiate_type_lparams;
use crate::reduce::{Reducer, TransparencyMode};
use crate::subst::{abstract_expr, instantiate};
use crate::{BinderInfo, Environment, Expr, FVarId, Level, Literal, Name};
use std::collections::HashMap;
#[allow(dead_code)]
pub struct SimpleDag {
    /// Adjacency lists: `edges[v]` holds the successors of node `v`.
    edges: Vec<Vec<usize>>,
}
#[allow(dead_code)]
impl SimpleDag {
    /// Creates a graph with `n` nodes and no edges.
    pub fn new(n: usize) -> Self {
        Self {
            edges: vec![Vec::new(); n],
        }
    }
    /// Records a directed edge `from -> to`; out-of-range sources are ignored.
    pub fn add_edge(&mut self, from: usize, to: usize) {
        if let Some(succ) = self.edges.get_mut(from) {
            succ.push(to);
        }
    }
    /// Successors of `node`, or an empty slice for out-of-range nodes.
    pub fn successors(&self, node: usize) -> &[usize] {
        match self.edges.get(node) {
            Some(v) => v.as_slice(),
            None => &[],
        }
    }
    /// True when a directed path (possibly empty) connects `from` to `to`.
    pub fn can_reach(&self, from: usize, to: usize) -> bool {
        let mut visited = vec![false; self.num_nodes()];
        self.dfs(from, to, &mut visited)
    }
    /// Depth-first search used by `can_reach`.
    fn dfs(&self, cur: usize, target: usize, visited: &mut Vec<bool>) -> bool {
        if cur == target {
            return true;
        }
        // Out-of-range or already-seen nodes terminate this branch.
        match visited.get_mut(cur) {
            Some(seen) if !*seen => *seen = true,
            _ => return false,
        }
        self.successors(cur)
            .iter()
            .any(|&next| self.dfs(next, target, visited))
    }
    /// Kahn's algorithm; `None` signals a cycle among the nodes.
    pub fn topological_sort(&self) -> Option<Vec<usize>> {
        let n = self.edges.len();
        let mut in_degree = vec![0usize; n];
        self.edges.iter().flatten().for_each(|&s| {
            if s < n {
                in_degree[s] += 1;
            }
        });
        let mut ready: std::collections::VecDeque<usize> =
            (0..n).filter(|&i| in_degree[i] == 0).collect();
        let mut order = Vec::with_capacity(n);
        while let Some(node) = ready.pop_front() {
            order.push(node);
            for &s in self.successors(node) {
                if s >= n {
                    continue;
                }
                in_degree[s] -= 1;
                if in_degree[s] == 0 {
                    ready.push_back(s);
                }
            }
        }
        // A cycle leaves some node with positive in-degree, shortening `order`.
        (order.len() == n).then_some(order)
    }
    /// Number of nodes the graph was created with.
    pub fn num_nodes(&self) -> usize {
        self.edges.len()
    }
}
#[allow(dead_code)]
pub struct TransitiveClosure {
    /// Adjacency lists indexed by source node.
    adj: Vec<Vec<usize>>,
    /// Node count; all queries are clamped against it.
    n: usize,
}
#[allow(dead_code)]
impl TransitiveClosure {
    /// Creates a graph with `n` nodes and no edges.
    pub fn new(n: usize) -> Self {
        Self {
            adj: vec![Vec::new(); n],
            n,
        }
    }
    /// Records a directed edge; out-of-range sources are ignored.
    pub fn add_edge(&mut self, from: usize, to: usize) {
        if from < self.n {
            self.adj[from].push(to);
        }
    }
    /// All nodes reachable from `start` (including `start` itself when it is
    /// in range), in ascending order.
    pub fn reachable_from(&self, start: usize) -> Vec<usize> {
        let mut visited = vec![false; self.n];
        let mut pending = vec![start];
        while let Some(node) = pending.pop() {
            // Out-of-range targets are silently skipped.
            if let Some(flag) = visited.get_mut(node) {
                if !*flag {
                    *flag = true;
                    pending.extend(self.adj[node].iter().copied());
                }
            }
        }
        visited
            .iter()
            .enumerate()
            .filter_map(|(i, &v)| v.then_some(i))
            .collect()
    }
    /// True when `to` is reachable from `from` (trivially true for `from`
    /// itself when in range).
    pub fn can_reach(&self, from: usize, to: usize) -> bool {
        self.reachable_from(from).into_iter().any(|x| x == to)
    }
}
#[allow(dead_code)]
pub struct ConfigNode {
    /// Name of this node (one dotted-path component).
    key: String,
    /// Leaf value; `None` for section nodes.
    value: Option<String>,
    /// Nested nodes under this key.
    children: Vec<ConfigNode>,
}
#[allow(dead_code)]
impl ConfigNode {
    /// Leaf node: a key carrying a value and no children.
    pub fn leaf(key: impl Into<String>, value: impl Into<String>) -> Self {
        Self {
            key: key.into(),
            value: Some(value.into()),
            children: Vec::new(),
        }
    }
    /// Section node: a key grouping children, with no value of its own.
    pub fn section(key: impl Into<String>) -> Self {
        Self {
            key: key.into(),
            value: None,
            children: Vec::new(),
        }
    }
    /// Appends a child node.
    pub fn add_child(&mut self, child: ConfigNode) {
        self.children.push(child);
    }
    /// This node's key.
    pub fn key(&self) -> &str {
        &self.key
    }
    /// This node's value (`None` for sections).
    pub fn value(&self) -> Option<&str> {
        self.value.as_deref()
    }
    /// Number of direct children.
    pub fn num_children(&self) -> usize {
        self.children.len()
    }
    /// Resolves a dotted path (e.g. `"server.net.port"`); the first
    /// component must match this node's own key. Returns `None` when the
    /// path does not resolve or ends on a valueless section.
    pub fn lookup(&self, path: &str) -> Option<&str> {
        // Was a verbatim copy of `lookup_relative`; delegate instead so the
        // traversal logic lives in exactly one place.
        self.lookup_relative(path)
    }
    /// Recursive worker: matches the head component against `self.key`,
    /// then descends into the first child whose subtree resolves the rest.
    fn lookup_relative(&self, path: &str) -> Option<&str> {
        let mut parts = path.splitn(2, '.');
        let head = parts.next()?;
        let tail = parts.next();
        if head != self.key {
            return None;
        }
        match tail {
            // Path ends here: return this node's value (None for sections).
            None => self.value.as_deref(),
            Some(rest) => self.children.iter().find_map(|c| c.lookup_relative(rest)),
        }
    }
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct TypingJudgment {
    pub expr: Expr,
    pub ty: Expr,
    pub success: bool,
}
impl TypingJudgment {
    /// Successful judgment `expr : ty`.
    #[allow(dead_code)]
    pub fn ok(expr: Expr, ty: Expr) -> Self {
        TypingJudgment {
            success: true,
            expr,
            ty,
        }
    }
    /// Failed judgment; `ty` is a placeholder `Sort 0`.
    #[allow(dead_code)]
    pub fn fail(expr: Expr) -> Self {
        TypingJudgment {
            success: false,
            ty: Expr::Sort(Level::zero()),
            expr,
        }
    }
    /// Whether type checking succeeded.
    #[allow(dead_code)]
    pub fn is_ok(&self) -> bool {
        self.success
    }
}
#[allow(dead_code)]
pub struct SparseVec<T: Default + Clone + PartialEq> {
    /// Only non-default slots are materialized here.
    entries: std::collections::HashMap<usize, T>,
    /// Shared default returned for every unset slot.
    default_: T,
    /// Logical length, independent of how many slots are stored.
    logical_len: usize,
}
#[allow(dead_code)]
impl<T: Default + Clone + PartialEq> SparseVec<T> {
    /// Logical length `len`, every slot implicitly `T::default()`.
    pub fn new(len: usize) -> Self {
        SparseVec {
            entries: std::collections::HashMap::new(),
            default_: T::default(),
            logical_len: len,
        }
    }
    /// Writes `val` at `idx`; storing the default frees the slot so only
    /// non-default entries take memory.
    pub fn set(&mut self, idx: usize, val: T) {
        if val == self.default_ {
            self.entries.remove(&idx);
        } else {
            self.entries.insert(idx, val);
        }
    }
    /// Value at `idx`, or the shared default when the slot is unset.
    pub fn get(&self, idx: usize) -> &T {
        match self.entries.get(&idx) {
            Some(v) => v,
            None => &self.default_,
        }
    }
    /// Logical length.
    pub fn len(&self) -> usize {
        self.logical_len
    }
    /// True when the logical length is zero.
    pub fn is_empty(&self) -> bool {
        self.logical_len == 0
    }
    /// Number of explicitly stored (non-default) entries.
    pub fn nnz(&self) -> usize {
        self.entries.len()
    }
}
#[allow(dead_code)]
pub struct StackCalc {
    /// Operand stack; binary ops pop two values and push one result.
    stack: Vec<i64>,
}
#[allow(dead_code)]
impl StackCalc {
    /// Calculator with an empty operand stack.
    pub fn new() -> Self {
        Self { stack: Vec::new() }
    }
    /// Pushes an operand.
    pub fn push(&mut self, n: i64) {
        self.stack.push(n);
    }
    /// Pops `b` then `a` and pushes `f(a, b)`. Panics with `msg` when fewer
    /// than two operands are present. (Shared by add/sub/mul, which
    /// previously triplicated this pop-pop-push sequence.)
    fn apply_binop(&mut self, msg: &str, f: impl Fn(i64, i64) -> i64) {
        let b = self.stack.pop().expect(msg);
        let a = self.stack.pop().expect(msg);
        self.stack.push(f(a, b));
    }
    /// Replaces the top two operands with their sum.
    pub fn add(&mut self) {
        self.apply_binop("stack must have at least two values for add", |a, b| a + b);
    }
    /// Replaces the top two operands with (second-from-top − top).
    pub fn sub(&mut self) {
        self.apply_binop("stack must have at least two values for sub", |a, b| a - b);
    }
    /// Replaces the top two operands with their product.
    pub fn mul(&mut self) {
        self.apply_binop("stack must have at least two values for mul", |a, b| a * b);
    }
    /// Top of the stack without popping.
    pub fn peek(&self) -> Option<i64> {
        self.stack.last().copied()
    }
    /// Current number of operands.
    pub fn depth(&self) -> usize {
        self.stack.len()
    }
}
#[allow(dead_code)]
pub struct WindowIterator<'a, T> {
    pub(super) data: &'a [T],
    pub(super) pos: usize,
    pub(super) window: usize,
}
#[allow(dead_code)]
impl<'a, T> WindowIterator<'a, T> {
    /// Cursor over `data` for windows of length `window`, starting at
    /// position 0.
    /// NOTE(review): no `Iterator` impl is visible in this block — it is
    /// presumably provided by the parent module (fields are `pub(super)`);
    /// confirm.
    pub fn new(data: &'a [T], window: usize) -> Self {
        WindowIterator {
            pos: 0,
            data,
            window,
        }
    }
}
/// One memoized inference result: `ty` is the inferred type of `expr`.
/// Stored by `InferCache`, which compares `expr` with `==` on lookup.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct InferCacheEntry {
    /// Expression that was inferred (cache key).
    pub expr: Expr,
    /// Inferred type (cache value).
    pub ty: Expr,
}
#[allow(dead_code)]
#[allow(missing_docs)]
pub struct RewriteRule {
    pub name: String,
    pub lhs: String,
    pub rhs: String,
    pub conditional: bool,
}
#[allow(dead_code)]
impl RewriteRule {
    /// Shared constructor backing both public builders.
    fn build(name: String, lhs: String, rhs: String, conditional: bool) -> Self {
        Self {
            name,
            lhs,
            rhs,
            conditional,
        }
    }
    /// A rule that always applies.
    pub fn unconditional(
        name: impl Into<String>,
        lhs: impl Into<String>,
        rhs: impl Into<String>,
    ) -> Self {
        Self::build(name.into(), lhs.into(), rhs.into(), false)
    }
    /// A rule guarded by a side condition.
    pub fn conditional(
        name: impl Into<String>,
        lhs: impl Into<String>,
        rhs: impl Into<String>,
    ) -> Self {
        Self::build(name.into(), lhs.into(), rhs.into(), true)
    }
    /// Human-readable rendering: `name: lhs → rhs`.
    pub fn display(&self) -> String {
        format!("{}: {} → {}", self.name, self.lhs, self.rhs)
    }
}
#[allow(dead_code)]
pub struct RewriteRuleSet {
    /// Rules in insertion order; duplicates are not rejected.
    rules: Vec<RewriteRule>,
}
#[allow(dead_code)]
impl RewriteRuleSet {
    /// Empty rule set.
    pub fn new() -> Self {
        RewriteRuleSet { rules: Vec::new() }
    }
    /// Appends a rule.
    pub fn add(&mut self, rule: RewriteRule) {
        self.rules.push(rule);
    }
    /// Number of rules.
    pub fn len(&self) -> usize {
        self.rules.len()
    }
    /// True when no rules are registered.
    pub fn is_empty(&self) -> bool {
        self.rules.is_empty()
    }
    /// Rules whose application is guarded by a side condition.
    pub fn conditional_rules(&self) -> Vec<&RewriteRule> {
        self.rules_where(true)
    }
    /// Rules that always apply.
    pub fn unconditional_rules(&self) -> Vec<&RewriteRule> {
        self.rules_where(false)
    }
    /// Helper: rules whose `conditional` flag equals `flag`.
    fn rules_where(&self, flag: bool) -> Vec<&RewriteRule> {
        self.rules.iter().filter(|r| r.conditional == flag).collect()
    }
    /// First rule registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&RewriteRule> {
        self.rules.iter().find(|r| r.name == name)
    }
}
#[allow(dead_code)]
pub struct Stopwatch {
    /// Moment the stopwatch was created.
    start: std::time::Instant,
    /// Elapsed-millisecond readings taken at each `split`.
    splits: Vec<f64>,
}
#[allow(dead_code)]
impl Stopwatch {
    /// Starts timing immediately.
    pub fn start() -> Self {
        Stopwatch {
            start: std::time::Instant::now(),
            splits: Vec::new(),
        }
    }
    /// Records the current elapsed time as a split point.
    pub fn split(&mut self) {
        let now_ms = self.elapsed_ms();
        self.splits.push(now_ms);
    }
    /// Milliseconds elapsed since construction.
    pub fn elapsed_ms(&self) -> f64 {
        1000.0 * self.start.elapsed().as_secs_f64()
    }
    /// All recorded split times, oldest first.
    pub fn splits(&self) -> &[f64] {
        self.splits.as_slice()
    }
    /// Number of recorded splits.
    pub fn num_splits(&self) -> usize {
        self.splits.len()
    }
}
/// Counters for profiling the type checker's hot operations.
#[derive(Clone, Debug, Default)]
pub struct InferStats {
    pub infer_calls: usize,
    pub whnf_calls: usize,
    pub def_eq_calls: usize,
    pub const_lookups: usize,
    pub cache_hits: usize,
}
impl InferStats {
    /// Zeroes every counter.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
    /// Combined count of inference, WHNF and def-eq operations
    /// (constant lookups and cache hits are excluded).
    pub fn total_ops(&self) -> usize {
        [self.infer_calls, self.whnf_calls, self.def_eq_calls]
            .iter()
            .sum()
    }
}
#[allow(dead_code)]
pub struct FocusStack<T> {
    /// Stack of items; the last element is the current focus.
    items: Vec<T>,
}
#[allow(dead_code)]
impl<T> FocusStack<T> {
    /// Empty stack with nothing focused.
    pub fn new() -> Self {
        FocusStack { items: Vec::new() }
    }
    /// Makes `item` the new focus, stacking the previous one beneath it.
    pub fn focus(&mut self, item: T) {
        self.items.push(item);
    }
    /// Drops and returns the current focus; the item beneath (if any)
    /// becomes focused again.
    pub fn blur(&mut self) -> Option<T> {
        self.items.pop()
    }
    /// Currently focused item, if any.
    pub fn current(&self) -> Option<&T> {
        self.items.last()
    }
    /// Number of stacked items, including the focused one.
    pub fn depth(&self) -> usize {
        self.items.len()
    }
    /// True when nothing is focused.
    pub fn is_empty(&self) -> bool {
        self.depth() == 0
    }
}
#[allow(dead_code)]
pub struct NonEmptyVec<T> {
    /// First element; always present by construction.
    head: T,
    /// Remaining elements after the head.
    tail: Vec<T>,
}
#[allow(dead_code)]
impl<T> NonEmptyVec<T> {
    /// Creates a vector containing exactly one element.
    pub fn singleton(val: T) -> Self {
        Self {
            head: val,
            tail: Vec::new(),
        }
    }
    /// Appends an element after the current last one.
    pub fn push(&mut self, val: T) {
        self.tail.push(val);
    }
    /// First element (always exists).
    pub fn first(&self) -> &T {
        &self.head
    }
    /// Last element; falls back to the head when the tail is empty.
    pub fn last(&self) -> &T {
        self.tail.last().unwrap_or(&self.head)
    }
    /// Total element count: the head plus the tail.
    pub fn len(&self) -> usize {
        1 + self.tail.len()
    }
    /// Always false: the head guarantees at least one element. (Previously
    /// computed as `self.len() == 0`, which could never hold.)
    pub fn is_empty(&self) -> bool {
        false
    }
    /// Borrowed view of all elements, head first.
    pub fn to_vec(&self) -> Vec<&T> {
        std::iter::once(&self.head).chain(self.tail.iter()).collect()
    }
}
#[allow(dead_code)]
pub struct TransformStat {
    /// Summary of pre-transform samples.
    before: StatSummary,
    /// Summary of post-transform samples.
    after: StatSummary,
}
#[allow(dead_code)]
impl TransformStat {
    /// Two empty summaries, one per side of the transform.
    pub fn new() -> Self {
        Self {
            before: StatSummary::new(),
            after: StatSummary::new(),
        }
    }
    /// Records a pre-transform measurement.
    pub fn record_before(&mut self, v: f64) {
        self.before.record(v);
    }
    /// Records a post-transform measurement.
    pub fn record_after(&mut self, v: f64) {
        self.after.record(v);
    }
    /// `mean(after) / mean(before)`; `None` when either side has no samples
    /// or the baseline mean is within EPSILON of zero.
    pub fn mean_ratio(&self) -> Option<f64> {
        match (self.before.mean(), self.after.mean()) {
            (Some(b), Some(a)) if b.abs() >= f64::EPSILON => Some(a / b),
            _ => None,
        }
    }
}
#[allow(dead_code)]
pub struct TokenBucket {
    /// Maximum tokens the bucket can hold.
    capacity: u64,
    /// Tokens currently available.
    tokens: u64,
    /// Tokens credited per elapsed millisecond.
    refill_per_ms: u64,
    /// Last time `refill` credited tokens.
    last_refill: std::time::Instant,
}
#[allow(dead_code)]
impl TokenBucket {
    /// Starts full at `capacity`, refilling `refill_per_ms` tokens per
    /// millisecond of elapsed wall-clock time.
    pub fn new(capacity: u64, refill_per_ms: u64) -> Self {
        Self {
            capacity,
            tokens: capacity,
            refill_per_ms,
            last_refill: std::time::Instant::now(),
        }
    }
    /// Takes `n` tokens if available after refilling; returns whether the
    /// consumption succeeded.
    pub fn try_consume(&mut self, n: u64) -> bool {
        self.refill();
        if self.tokens >= n {
            self.tokens -= n;
            true
        } else {
            false
        }
    }
    /// Credits tokens for elapsed time. Saturating arithmetic prevents
    /// overflow after long idle periods or with large refill rates
    /// (previously `elapsed_ms * refill_per_ms` and the subsequent add could
    /// overflow u64); `try_from` replaces a silently truncating `as u64`
    /// cast of the u128 millisecond count.
    fn refill(&mut self) {
        let now = std::time::Instant::now();
        let elapsed_ms =
            u64::try_from(now.duration_since(self.last_refill).as_millis()).unwrap_or(u64::MAX);
        if elapsed_ms > 0 {
            let new_tokens = elapsed_ms.saturating_mul(self.refill_per_ms);
            self.tokens = self.tokens.saturating_add(new_tokens).min(self.capacity);
            self.last_refill = now;
        }
    }
    /// Tokens available as of the last refill (not refreshed here).
    pub fn available(&self) -> u64 {
        self.tokens
    }
    /// Configured maximum.
    pub fn capacity(&self) -> u64 {
        self.capacity
    }
}
#[allow(dead_code)]
pub struct FlatSubstitution {
    /// (from, to) replacement pairs, applied in insertion order.
    pairs: Vec<(String, String)>,
}
#[allow(dead_code)]
impl FlatSubstitution {
    /// Empty substitution list.
    pub fn new() -> Self {
        FlatSubstitution { pairs: Vec::new() }
    }
    /// Registers a replacement rule.
    pub fn add(&mut self, from: impl Into<String>, to: impl Into<String>) {
        self.pairs.push((from.into(), to.into()));
    }
    /// Applies every rule in order. Later rules see earlier rules' output,
    /// so chains such as a→b followed by b→c compose.
    pub fn apply(&self, s: &str) -> String {
        self.pairs
            .iter()
            .fold(s.to_string(), |acc, (from, to)| {
                acc.replace(from.as_str(), to.as_str())
            })
    }
    /// Number of registered rules.
    pub fn len(&self) -> usize {
        self.pairs.len()
    }
    /// True when no rules are registered.
    pub fn is_empty(&self) -> bool {
        self.pairs.is_empty()
    }
}
/// Coarse classification tags for expressions/types.
/// NOTE(review): no constructor or classifier logic is visible in this file;
/// the variant semantics below are inferred from the names — confirm against
/// the code that produces these tags.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeKind {
    Prop,
    Type0,
    LargeType,
    Universe,
    Pi,
    Lambda,
    Application,
    FreeVar,
    Constant,
    Literal,
    // Fallback when no other tag applies.
    Unknown,
}
#[allow(dead_code)]
pub struct WriteOnce<T> {
    /// Interior-mutable slot; `Copy` bound lets `Cell::get` hand out copies.
    value: std::cell::Cell<Option<T>>,
}
#[allow(dead_code)]
impl<T: Copy> WriteOnce<T> {
    /// Empty cell; only the first `write` takes effect.
    pub fn new() -> Self {
        WriteOnce {
            value: std::cell::Cell::new(None),
        }
    }
    /// Stores `val` if the cell is still empty; returns whether the write
    /// took effect. All subsequent writes are rejected.
    pub fn write(&self, val: T) -> bool {
        match self.value.get() {
            Some(_) => false,
            None => {
                self.value.set(Some(val));
                true
            }
        }
    }
    /// Copy of the stored value, or `None` when never written.
    pub fn read(&self) -> Option<T> {
        self.value.get()
    }
    /// True once a write has succeeded.
    pub fn is_written(&self) -> bool {
        self.read().is_some()
    }
}
/// Type inference and checking over a kernel `Environment`.
pub struct TypeChecker<'env> {
    /// Environment the checker reads declarations from.
    env: &'env Environment,
    /// Stack of free-variable declarations currently in scope.
    local_ctx: Vec<LocalDecl>,
    /// Reduction engine used for WHNF computations.
    reducer: Reducer,
    /// Definitional-equality checker built over the same environment.
    def_eq_checker: DefEqChecker<'env>,
    /// Counter for allocating fresh `FVarId`s.
    next_fvar: u64,
    /// When true, binder domains and argument types are verified during
    /// inference; when false the checker only infers (see `new` vs
    /// `new_infer_only`).
    check_mode: bool,
}
impl<'env> TypeChecker<'env> {
    /// Creates a checker that validates while inferring (`check_mode = true`).
    pub fn new(env: &'env Environment) -> Self {
        Self {
            env,
            local_ctx: Vec::new(),
            reducer: Reducer::new(),
            def_eq_checker: DefEqChecker::new(env),
            next_fvar: 0,
            check_mode: true,
        }
    }
    /// Creates an infer-only checker (`check_mode = false`): binder domains
    /// and argument types are not re-verified during inference.
    pub fn new_infer_only(env: &'env Environment) -> Self {
        Self {
            env,
            local_ctx: Vec::new(),
            reducer: Reducer::new(),
            def_eq_checker: DefEqChecker::new(env),
            next_fvar: 0,
            check_mode: false,
        }
    }
    /// Applies `mode` to both the reducer and the def-eq checker so their
    /// unfolding behavior stays consistent.
    pub fn set_transparency(&mut self, mode: TransparencyMode) {
        self.reducer.set_transparency(mode);
        self.def_eq_checker.set_transparency(mode);
    }
    /// The environment this checker reads declarations from.
    pub fn env(&self) -> &Environment {
        self.env
    }
    /// Allocates a fresh free variable of type `ty` and pushes its
    /// declaration onto the local context. Callers are responsible for the
    /// matching `pop_local`.
    pub fn fresh_fvar(&mut self, name: Name, ty: Expr) -> FVarId {
        let fvar = FVarId(self.next_fvar);
        self.next_fvar += 1;
        self.local_ctx.push(LocalDecl {
            fvar,
            name,
            ty,
            val: None,
        });
        fvar
    }
    /// Like `fresh_fvar`, but records a definition (`val`) for a let binder.
    pub fn fresh_fvar_let(&mut self, name: Name, ty: Expr, val: Expr) -> FVarId {
        let fvar = FVarId(self.next_fvar);
        self.next_fvar += 1;
        self.local_ctx.push(LocalDecl {
            fvar,
            name,
            ty,
            val: Some(val),
        });
        fvar
    }
    /// Pushes an externally built declaration onto the local context.
    pub fn push_local(&mut self, decl: LocalDecl) {
        self.local_ctx.push(decl);
    }
    /// Pops the innermost local declaration, if any.
    pub fn pop_local(&mut self) -> Option<LocalDecl> {
        self.local_ctx.pop()
    }
    /// Linear scan of the local context for `fvar`.
    #[allow(clippy::result_large_err)]
    fn lookup_fvar(&self, fvar: FVarId) -> Result<&LocalDecl, KernelError> {
        self.local_ctx
            .iter()
            .find(|decl| decl.fvar == fvar)
            .ok_or_else(|| KernelError::Other(format!("free variable not found: {:?}", fvar)))
    }
    /// Read-only view of the current local context.
    pub fn local_ctx(&self) -> &[LocalDecl] {
        &self.local_ctx
    }
    /// Weak-head normal form of `expr` relative to the environment.
    pub fn whnf(&mut self, expr: &Expr) -> Expr {
        self.reducer.whnf_env(expr, self.env)
    }
    /// Infers the type of `expr` and requires it to reduce to a `Sort`,
    /// returning that sort's universe level.
    #[allow(clippy::result_large_err)]
    pub fn ensure_sort(&mut self, expr: &Expr) -> Result<Level, KernelError> {
        let ty = self.infer_type(expr)?;
        let ty_whnf = self.whnf(&ty);
        match ty_whnf {
            Expr::Sort(l) => Ok(l),
            _ => Err(KernelError::NotASort(ty_whnf)),
        }
    }
    /// Infers the type of `expr` and requires it to reduce to a Pi type,
    /// which is returned in WHNF.
    #[allow(clippy::result_large_err)]
    pub fn ensure_pi(&mut self, expr: &Expr) -> Result<Expr, KernelError> {
        let ty = self.infer_type(expr)?;
        let ty_whnf = self.whnf(&ty);
        if ty_whnf.is_pi() {
            Ok(ty_whnf)
        } else {
            Err(KernelError::NotAFunction(ty_whnf))
        }
    }
    /// Definitional equality of `t` and `s`.
    pub fn is_def_eq(&mut self, t: &Expr, s: &Expr) -> bool {
        self.def_eq_checker.is_def_eq(t, s)
    }
    /// Fails with `TypeMismatch` unless `inferred` is def-eq to `expected`.
    /// `expr` is only used to build the error's context message.
    #[allow(clippy::result_large_err)]
    pub fn check_type(
        &mut self,
        expr: &Expr,
        inferred: &Expr,
        expected: &Expr,
    ) -> Result<(), KernelError> {
        if self.is_def_eq(inferred, expected) {
            Ok(())
        } else {
            Err(KernelError::TypeMismatch {
                expected: expected.clone(),
                got: inferred.clone(),
                context: format!("checking {}", expr),
            })
        }
    }
    /// Core type inference. In `check_mode` this also verifies binder
    /// domains and let-value types; otherwise it trusts them.
    #[allow(clippy::result_large_err)]
    pub fn infer_type(&mut self, expr: &Expr) -> Result<Expr, KernelError> {
        match expr {
            // Sort l : Sort (l+1).
            Expr::Sort(l) => Ok(Expr::Sort(Level::succ(l.clone()))),
            // A loose bound variable cannot be typed.
            Expr::BVar(idx) => Err(KernelError::UnboundVariable(*idx)),
            Expr::FVar(fvar) => {
                let decl = self.lookup_fvar(*fvar)?;
                Ok(decl.ty.clone())
            }
            Expr::Const(name, levels) => self.infer_const(name, levels),
            Expr::App(f, a) => self.infer_app(f, a),
            Expr::Lam(bi, name, ty, body) => {
                if self.check_mode {
                    // The binder's domain must itself be a type.
                    self.ensure_sort(ty)?;
                }
                // Open the body with a fresh free variable, infer its type,
                // then re-abstract the variable to close the result.
                let fvar = self.fresh_fvar(name.clone(), (**ty).clone());
                let body_open = instantiate(body, &Expr::FVar(fvar));
                let body_ty = self.infer_type(&body_open)?;
                self.pop_local();
                let body_ty_closed = abstract_expr(&body_ty, fvar);
                Ok(Expr::Pi(
                    *bi,
                    name.clone(),
                    ty.clone(),
                    Box::new(body_ty_closed),
                ))
            }
            Expr::Pi(_, _, dom, cod) => {
                // The sort of a Pi is imax(domain sort, codomain sort).
                let dom_sort = self.ensure_sort(dom)?;
                let fvar = self.fresh_fvar(Name::str("_"), (**dom).clone());
                let cod_open = instantiate(cod, &Expr::FVar(fvar));
                let cod_sort = self.ensure_sort(&cod_open)?;
                self.pop_local();
                Ok(Expr::Sort(Level::imax(dom_sort, cod_sort)))
            }
            Expr::Let(_, ty, val, body) => {
                if self.check_mode {
                    self.ensure_sort(ty)?;
                    let val_ty = self.infer_type(val)?;
                    self.check_type(val, &val_ty, ty)?;
                }
                // Zeta-expand: substitute the value into the body and infer.
                let body_inst = instantiate(body, val);
                self.infer_type(&body_inst)
            }
            // Literals are typed by their builtin type constants.
            Expr::Lit(Literal::Nat(_)) => Ok(Expr::Const(Name::str("Nat"), vec![])),
            Expr::Lit(Literal::Str(_)) => Ok(Expr::Const(Name::str("String"), vec![])),
            Expr::Proj(struct_name, idx, struct_expr) => {
                self.infer_proj(struct_name, *idx, struct_expr)
            }
        }
    }
    /// Type of a constant with its universe parameters instantiated.
    /// NOTE(review): a count mismatch is only reported when both lists are
    /// non-empty; an empty `levels` list silently returns the
    /// uninstantiated type — confirm this leniency is intended.
    #[allow(clippy::result_large_err)]
    fn infer_const(&self, name: &Name, levels: &[Level]) -> Result<Expr, KernelError> {
        if let Some(ci) = self.env.find(name) {
            let params = ci.level_params();
            if !params.is_empty() && !levels.is_empty() && params.len() != levels.len() {
                return Err(KernelError::Other(format!(
                    "universe parameter count mismatch for {}: expected {}, got {}",
                    name,
                    params.len(),
                    levels.len()
                )));
            }
            if params.is_empty() || levels.is_empty() {
                return Ok(ci.ty().clone());
            }
            return Ok(instantiate_type_lparams(ci.ty(), params, levels));
        }
        // Fallback lookup path when `find` misses; presumably a secondary
        // declaration table — no level instantiation is attempted here.
        let decl = self
            .env
            .get(name)
            .ok_or_else(|| KernelError::UnknownConstant(name.clone()))?;
        Ok(decl.ty().clone())
    }
    /// Type of the application `f a`: `f`'s type must reduce to a Pi; the
    /// codomain is instantiated with the argument.
    #[allow(clippy::result_large_err)]
    fn infer_app(&mut self, f: &Expr, a: &Expr) -> Result<Expr, KernelError> {
        let f_ty = self.infer_type(f)?;
        let f_ty_whnf = self.whnf(&f_ty);
        match &f_ty_whnf {
            Expr::Pi(_, _, dom, cod) => {
                if self.check_mode {
                    let a_ty = self.infer_type(a)?;
                    self.check_type(a, &a_ty, dom)?;
                }
                Ok(instantiate(cod, a))
            }
            _ => Err(KernelError::NotAFunction(f_ty_whnf)),
        }
    }
    /// Type of the projection `struct_expr.idx` for single-constructor
    /// inductives (structures). Validates the structure shape and the field
    /// index before delegating to `infer_proj_field_type`.
    #[allow(clippy::result_large_err)]
    fn infer_proj(
        &mut self,
        struct_name: &Name,
        idx: u32,
        struct_expr: &Expr,
    ) -> Result<Expr, KernelError> {
        let ind_val = self
            .env
            .get_inductive_val(struct_name)
            .ok_or_else(|| KernelError::Other(format!("not a structure type: {}", struct_name)))?
            .clone();
        // Structures are exactly the one-constructor inductives.
        if ind_val.ctors.len() != 1 {
            return Err(KernelError::Other(format!(
                "{} is not a structure (has {} constructors)",
                struct_name,
                ind_val.ctors.len()
            )));
        }
        let ctor_name = &ind_val.ctors[0];
        let ctor_val = self
            .env
            .get_constructor_val(ctor_name)
            .ok_or_else(|| KernelError::Other(format!("constructor not found: {}", ctor_name)))?
            .clone();
        if idx >= ctor_val.num_fields {
            return Err(KernelError::Other(format!(
                "field index {} out of range for {} (has {} fields)",
                idx, struct_name, ctor_val.num_fields
            )));
        }
        let struct_ty = self.infer_type(struct_expr)?;
        Ok(self.infer_proj_field_type(&ind_val, &ctor_val, idx, struct_expr, &struct_ty))
    }
    /// Computes the type of field `idx` by walking the constructor's
    /// telescope: inductive parameters are substituted from the structure
    /// type's arguments, then each earlier field is substituted as a
    /// projection of `struct_expr`. Malformed telescopes fall back to
    /// returning `struct_ty` unchanged.
    fn infer_proj_field_type(
        &mut self,
        ind_val: &InductiveVal,
        ctor_val: &ConstructorVal,
        idx: u32,
        struct_expr: &Expr,
        struct_ty: &Expr,
    ) -> Expr {
        let ctor_ty = ctor_val.common.ty.clone();
        let struct_ty_whnf = self.whnf(struct_ty);
        // Universe levels come from the head constant of the structure type.
        let levels: Vec<Level> = match get_app_fn(&struct_ty_whnf) {
            Expr::Const(_, lvls) => lvls.clone(),
            _ => vec![],
        };
        let level_params = &ind_val.common.level_params;
        let mut cur_ty = instantiate_type_lparams(&ctor_ty, level_params, &levels);
        let struct_args: Vec<Expr> = get_app_args(&struct_ty_whnf).into_iter().cloned().collect();
        // Consume one Pi binder per inductive parameter.
        for i in 0..ind_val.num_params as usize {
            match cur_ty {
                Expr::Pi(_, _, _, body) => {
                    // NOTE(review): BVar(0) is a placeholder used when the
                    // structure type has fewer arguments than declared
                    // parameters — confirm this degenerate case is intended.
                    let param = struct_args.get(i).cloned().unwrap_or(Expr::BVar(0));
                    cur_ty = instantiate(&body, &param);
                }
                _ => return struct_ty.clone(),
            }
        }
        // Substitute each earlier field as a projection of the same value so
        // dependent field types resolve correctly.
        for j in 0..idx {
            match cur_ty {
                Expr::Pi(_, _, _, body) => {
                    let field_val = Expr::Proj(
                        ind_val.common.name.clone(),
                        j,
                        Box::new(struct_expr.clone()),
                    );
                    cur_ty = instantiate(&body, &field_val);
                }
                _ => return struct_ty.clone(),
            }
        }
        // The domain of the next Pi is the requested field's type.
        match cur_ty {
            Expr::Pi(_, _, dom, _) => *dom,
            _ => struct_ty.clone(),
        }
    }
    /// True when `expr`'s type reduces to `Sort 0` (a proposition).
    /// Inference failures are treated as "not a prop".
    pub fn is_prop(&mut self, expr: &Expr) -> bool {
        if let Ok(ty) = self.infer_type(expr) {
            let ty_whnf = self.whnf(&ty);
            matches!(ty_whnf, Expr::Sort(l) if l.is_zero())
        } else {
            false
        }
    }
    /// True when `expr`'s type is a proposition (i.e. `expr` is a proof).
    pub fn is_proof(&mut self, expr: &Expr) -> bool {
        if let Ok(ty) = self.infer_type(expr) {
            self.is_prop(&ty)
        } else {
            false
        }
    }
    /// True when `expr`'s type reduces to some `Sort`.
    pub fn is_type(&mut self, expr: &Expr) -> bool {
        if let Ok(ty) = self.infer_type(expr) {
            let ty_whnf = self.whnf(&ty);
            ty_whnf.is_sort()
        } else {
            false
        }
    }
    /// Universe level of `expr`'s type; same reduction/check as
    /// `ensure_sort`.
    #[allow(clippy::result_large_err)]
    pub fn get_level(&mut self, expr: &Expr) -> Result<Level, KernelError> {
        let ty = self.infer_type(expr)?;
        let ty_whnf = self.whnf(&ty);
        match ty_whnf {
            Expr::Sort(l) => Ok(l),
            _ => Err(KernelError::NotASort(ty_whnf)),
        }
    }
    /// Delta-unfolds the head constant of an application spine when its
    /// reducibility hint allows, re-applying the original arguments to the
    /// (level-instantiated) definition body. Returns `None` when the head
    /// is not an unfoldable constant with a value.
    pub fn unfold_definition(&mut self, expr: &Expr) -> Option<Expr> {
        let head = get_app_fn(expr);
        if let Expr::Const(name, levels) = head {
            if let Some(ci) = self.env.find(name) {
                if let Some(val) = ci.value() {
                    let hint = ci.reducibility_hint();
                    if hint.should_unfold() {
                        let val_inst = if ci.level_params().is_empty() || levels.is_empty() {
                            val.clone()
                        } else {
                            instantiate_type_lparams(val, ci.level_params(), levels)
                        };
                        let args: Vec<Expr> = get_app_args(expr).into_iter().cloned().collect();
                        return Some(mk_app(val_inst, &args));
                    }
                }
            }
        }
        None
    }
}
impl<'env> TypeChecker<'env> {
    /// Infers the type of `f` applied to `args` left-to-right, reducing to
    /// WHNF at each step and (in `check_mode`) checking every argument
    /// against the corresponding Pi domain.
    #[allow(clippy::result_large_err)]
    pub fn infer_app_chain(
        &mut self,
        f: &Expr,
        args: &[Expr],
    ) -> Result<Expr, crate::error::KernelError> {
        let mut ty = self.infer_type(f)?;
        for arg in args {
            let whnf = self.whnf(&ty);
            match whnf {
                Expr::Pi(_, _, dom, cod) => {
                    if self.check_mode {
                        let arg_ty = self.infer_type(arg)?;
                        if !self.is_def_eq(&arg_ty, &dom) {
                            return Err(crate::error::KernelError::TypeMismatch {
                                expected: *dom,
                                got: arg_ty,
                                context: "application argument".to_string(),
                            });
                        }
                    }
                    ty = instantiate(&cod, arg);
                }
                other => return Err(crate::error::KernelError::NotAFunction(other)),
            }
        }
        Ok(ty)
    }
    /// Peels at most `max_pis` Pi binders off `ty`, introducing a fresh free
    /// variable per binder; returns the introduced declarations plus the
    /// remaining codomain.
    /// NOTE(review): the fresh fvars stay pushed on `local_ctx`; callers are
    /// presumably expected to pop them or close over them (e.g. via
    /// `close_type_over_fvars`) — confirm.
    pub fn telescope_type(&mut self, ty: &Expr, max_pis: usize) -> (Vec<LocalDecl>, Expr) {
        let mut fvars = Vec::new();
        let mut current = ty.clone();
        for _ in 0..max_pis {
            let whnf = self.whnf(&current);
            match whnf {
                Expr::Pi(bi, name, dom, cod) => {
                    let fvar_id = self.fresh_fvar(name.clone(), *dom.clone());
                    let decl = LocalDecl {
                        fvar: fvar_id,
                        name,
                        ty: *dom,
                        val: None,
                    };
                    let body = instantiate(&cod, &Expr::FVar(fvar_id));
                    fvars.push(decl);
                    current = body;
                    // Binder info is currently dropped.
                    let _ = bi;
                }
                _ => break,
            }
        }
        (fvars, current)
    }
    /// Rebinds `ty` over `fvars` as nested Pis (last declaration innermost).
    /// Binder info is reset to `Default`.
    pub fn close_type_over_fvars(&mut self, fvars: &[LocalDecl], ty: Expr) -> Expr {
        let mut result = ty;
        for decl in fvars.iter().rev() {
            result = abstract_expr(&result, decl.fvar);
            result = Expr::Pi(
                crate::BinderInfo::Default,
                decl.name.clone(),
                Box::new(decl.ty.clone()),
                Box::new(result),
            );
        }
        result
    }
    /// Rebinds `term` over `fvars` as nested lambdas, mirroring
    /// `close_type_over_fvars`.
    pub fn close_term_over_fvars(&mut self, fvars: &[LocalDecl], term: Expr) -> Expr {
        let mut result = term;
        for decl in fvars.iter().rev() {
            result = abstract_expr(&result, decl.fvar);
            result = Expr::Lam(
                crate::BinderInfo::Default,
                decl.name.clone(),
                Box::new(decl.ty.clone()),
                Box::new(result),
            );
        }
        result
    }
    /// Checks `expr : expected`; the error context carries a Debug
    /// rendering of `expr`.
    #[allow(clippy::result_large_err)]
    pub fn check(&mut self, expr: &Expr, expected: &Expr) -> Result<(), crate::error::KernelError> {
        let inferred = self.infer_type(expr)?;
        if self.is_def_eq(&inferred, expected) {
            Ok(())
        } else {
            Err(crate::error::KernelError::TypeMismatch {
                expected: expected.clone(),
                got: inferred,
                context: format!("check({:?})", expr),
            })
        }
    }
    /// Non-failing variant of `check`: inference errors become `false`.
    pub fn try_check(&mut self, expr: &Expr, expected: &Expr) -> bool {
        if let Ok(inferred) = self.infer_type(expr) {
            self.is_def_eq(&inferred, expected)
        } else {
            false
        }
    }
    /// Number of leading Pi binders of `ty`, reducing to WHNF before each
    /// step so definitional Pis are counted too.
    pub fn count_pi_binders(&mut self, ty: &Expr) -> usize {
        let mut count = 0;
        let mut current = ty.clone();
        loop {
            let whnf = self.whnf(&current);
            if let Expr::Pi(_, _, _, cod) = whnf {
                count += 1;
                current = *cod;
            } else {
                break;
            }
        }
        count
    }
    /// Verifies that `ty` is a type and, when a value is given, that
    /// `val : ty` holds definitionally.
    #[allow(clippy::result_large_err)]
    pub fn verify_declaration(
        &mut self,
        name: &Name,
        ty: &Expr,
        val: Option<&Expr>,
    ) -> Result<(), crate::error::KernelError> {
        self.ensure_sort(ty)?;
        if let Some(v) = val {
            let v_ty = self.infer_type(v)?;
            if !self.is_def_eq(&v_ty, ty) {
                return Err(crate::error::KernelError::TypeMismatch {
                    expected: ty.clone(),
                    got: v_ty,
                    context: format!("verifying declaration {}", name),
                });
            }
        }
        Ok(())
    }
    /// Full normalization: WHNF at the head, then recursive normalization of
    /// every subterm. Let bindings are preserved (not zeta-expanded here).
    pub fn normalize(&mut self, expr: &Expr) -> Expr {
        let whnf = self.whnf(expr);
        match &whnf {
            Expr::App(f, a) => {
                let f_norm = self.normalize(f);
                let a_norm = self.normalize(a);
                Expr::App(Box::new(f_norm), Box::new(a_norm))
            }
            Expr::Lam(bi, name, ty, body) => {
                let ty_norm = self.normalize(ty);
                let body_norm = self.normalize(body);
                Expr::Lam(*bi, name.clone(), Box::new(ty_norm), Box::new(body_norm))
            }
            Expr::Pi(bi, name, ty, body) => {
                let ty_norm = self.normalize(ty);
                let body_norm = self.normalize(body);
                Expr::Pi(*bi, name.clone(), Box::new(ty_norm), Box::new(body_norm))
            }
            Expr::Let(name, ty, val, body) => {
                let ty_norm = self.normalize(ty);
                let val_norm = self.normalize(val);
                let body_norm = self.normalize(body);
                Expr::Let(
                    name.clone(),
                    Box::new(ty_norm),
                    Box::new(val_norm),
                    Box::new(body_norm),
                )
            }
            Expr::Proj(sname, idx, inner) => {
                let inner_norm = self.normalize(inner);
                Expr::Proj(sname.clone(), *idx, Box::new(inner_norm))
            }
            // Sorts, variables, constants and literals are already normal.
            _ => whnf,
        }
    }
    /// Number of Pi binders in a constant's type, using the same dual
    /// lookup (`find` then `get`) as `infer_const`.
    pub fn constant_arity(&mut self, name: &Name) -> Option<usize> {
        let ty = if let Some(ci) = self.env.find(name) {
            ci.ty().clone()
        } else {
            self.env.get(name)?.ty().clone()
        };
        Some(self.count_pi_binders(&ty))
    }
    /// Syntactic level equality.
    /// NOTE(review): the `is_zero` clause looks redundant if `Level`'s
    /// `PartialEq` is structural on normalized levels — confirm whether
    /// distinct zero representations exist.
    pub fn is_level_eq(&self, l1: &Level, l2: &Level) -> bool {
        l1 == l2 || (l1.is_zero() && l2.is_zero())
    }
}
#[allow(dead_code)]
pub struct SmallMap<K: Ord + Clone, V: Clone> {
    /// Entries kept sorted by key; cheap for small maps.
    entries: Vec<(K, V)>,
}
#[allow(dead_code)]
impl<K: Ord + Clone, V: Clone> SmallMap<K, V> {
    /// Empty map backed by a sorted Vec.
    pub fn new() -> Self {
        SmallMap {
            entries: Vec::new(),
        }
    }
    /// Inserts or overwrites, keeping the backing Vec sorted by key.
    pub fn insert(&mut self, key: K, val: V) {
        match self.entries.binary_search_by(|(k, _)| k.cmp(&key)) {
            Ok(i) => self.entries[i].1 = val,
            Err(i) => self.entries.insert(i, (key, val)),
        }
    }
    /// Binary-search lookup.
    pub fn get(&self, key: &K) -> Option<&V> {
        let i = self.entries.binary_search_by(|(k, _)| k.cmp(key)).ok()?;
        Some(&self.entries[i].1)
    }
    /// Number of entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
    /// True when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Keys in ascending order.
    pub fn keys(&self) -> Vec<&K> {
        self.entries.iter().map(|(k, _)| k).collect()
    }
    /// Values in ascending key order.
    pub fn values(&self) -> Vec<&V> {
        self.entries.iter().map(|(_, v)| v).collect()
    }
}
#[allow(dead_code)]
pub struct VersionedRecord<T: Clone> {
    /// Every version ever stored; index = version number.
    history: Vec<T>,
}
#[allow(dead_code)]
impl<T: Clone> VersionedRecord<T> {
    /// Starts the history at version 0 with `initial`.
    pub fn new(initial: T) -> Self {
        VersionedRecord {
            history: vec![initial],
        }
    }
    /// Appends a new version; earlier versions stay readable.
    pub fn update(&mut self, val: T) {
        self.history.push(val);
    }
    /// Latest version of the value.
    pub fn current(&self) -> &T {
        // Invariant: `new` seeds one element and nothing ever removes one.
        self.history
            .last()
            .expect("VersionedRecord history is always non-empty after construction")
    }
    /// Value as of version `n` (0 = initial), if it exists.
    pub fn at_version(&self, n: usize) -> Option<&T> {
        self.history.get(n)
    }
    /// Index of the latest version.
    pub fn version(&self) -> usize {
        self.history.len() - 1
    }
    /// True once at least one update has been applied.
    pub fn has_history(&self) -> bool {
        self.version() > 0
    }
}
#[allow(dead_code)]
pub struct InferCache {
    /// FIFO list of memoized (expr, type) pairs, oldest first.
    entries: Vec<InferCacheEntry>,
    /// Maximum number of retained entries; 0 disables caching.
    capacity: usize,
}
impl InferCache {
    /// Cache retaining at most `capacity` entries (0 = caching disabled).
    #[allow(dead_code)]
    pub fn new(capacity: usize) -> Self {
        Self {
            entries: Vec::with_capacity(capacity),
            capacity,
        }
    }
    /// Linear scan for `expr`, newest entries first.
    #[allow(dead_code)]
    pub fn get(&self, expr: &Expr) -> Option<&Expr> {
        self.entries
            .iter()
            .rev()
            .find(|e| &e.expr == expr)
            .map(|e| &e.ty)
    }
    /// Memoizes `expr -> ty`, evicting the oldest entry when full.
    #[allow(dead_code)]
    pub fn insert(&mut self, expr: Expr, ty: Expr) {
        // A zero-capacity cache stores nothing. Previously this fell
        // through to `entries.remove(0)` on an empty Vec and panicked.
        if self.capacity == 0 {
            return;
        }
        if self.entries.len() >= self.capacity {
            self.entries.remove(0);
        }
        self.entries.push(InferCacheEntry { expr, ty });
    }
    /// Drops every cached entry (capacity is retained).
    #[allow(dead_code)]
    pub fn clear(&mut self) {
        self.entries.clear();
    }
    /// Number of cached entries.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        self.entries.len()
    }
    /// True when nothing is cached.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
#[allow(dead_code)]
pub struct SlidingSum {
    /// Ring buffer of the most recent samples.
    window: Vec<f64>,
    /// Ring size; always at least 1 (see `new`).
    capacity: usize,
    /// Next write position in the ring.
    pos: usize,
    /// Running sum of the samples currently in the window.
    sum: f64,
    /// Number of valid samples (saturates at `capacity`).
    count: usize,
}
#[allow(dead_code)]
impl SlidingSum {
    /// Sliding-window sum over the last `capacity` samples.
    /// Capacity is clamped to at least 1: a zero-length window previously
    /// made `push` panic (index into an empty Vec and `% 0`).
    pub fn new(capacity: usize) -> Self {
        let capacity = capacity.max(1);
        Self {
            window: vec![0.0; capacity],
            capacity,
            pos: 0,
            sum: 0.0,
            count: 0,
        }
    }
    /// Adds a sample, evicting the oldest one once the window is full.
    pub fn push(&mut self, val: f64) {
        // Slot still holds 0.0 until the ring wraps, so subtracting it is a
        // no-op while the window is filling.
        let oldest = self.window[self.pos];
        self.sum -= oldest;
        self.sum += val;
        self.window[self.pos] = val;
        self.pos = (self.pos + 1) % self.capacity;
        if self.count < self.capacity {
            self.count += 1;
        }
    }
    /// Sum of the samples currently in the window.
    pub fn sum(&self) -> f64 {
        self.sum
    }
    /// Mean of the samples in the window; `None` before any push.
    pub fn mean(&self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            Some(self.sum / self.count as f64)
        }
    }
    /// Number of samples currently in the window.
    pub fn count(&self) -> usize {
        self.count
    }
}
/// A local-context entry for a free variable introduced during type
/// checking (see `TypeChecker::fresh_fvar` / `fresh_fvar_let`).
#[derive(Clone, Debug)]
pub struct LocalDecl {
    /// Identifier of the free variable this declaration binds.
    pub fvar: FVarId,
    /// Binder name.
    pub name: Name,
    /// Declared type of the variable.
    pub ty: Expr,
    /// Definition for let-bound variables; `None` for plain binders.
    pub val: Option<Expr>,
}
#[allow(dead_code)]
pub struct StringPool {
    /// Cleared buffers available for reuse (capacity retained).
    free: Vec<String>,
}
#[allow(dead_code)]
impl StringPool {
    /// Pool with no pooled buffers.
    pub fn new() -> Self {
        StringPool { free: Vec::new() }
    }
    /// Hands out a pooled (empty) buffer, or a fresh one when the pool is
    /// dry.
    pub fn take(&mut self) -> String {
        match self.free.pop() {
            Some(s) => s,
            None => String::new(),
        }
    }
    /// Returns a buffer to the pool; contents are cleared but the
    /// allocation is kept for reuse.
    pub fn give(&mut self, mut s: String) {
        s.clear();
        self.free.push(s);
    }
    /// Number of buffers waiting in the pool.
    pub fn free_count(&self) -> usize {
        self.free.len()
    }
}
#[allow(dead_code)]
#[allow(missing_docs)]
pub enum DecisionNode {
    Leaf(String),
    Branch {
        key: String,
        val: String,
        yes_branch: Box<DecisionNode>,
        no_branch: Box<DecisionNode>,
    },
}
#[allow(dead_code)]
impl DecisionNode {
    /// Walks the tree: at each branch, compares `ctx[key]` (empty string
    /// when absent) against `val` and descends into the yes/no subtree;
    /// returns the action string of the leaf reached.
    pub fn evaluate(&self, ctx: &std::collections::HashMap<String, String>) -> &str {
        match self {
            DecisionNode::Leaf(action) => action,
            DecisionNode::Branch {
                key,
                val,
                yes_branch,
                no_branch,
            } => {
                let looked_up = ctx.get(key).map_or("", |s| s.as_str());
                let chosen = if looked_up == val.as_str() {
                    yes_branch
                } else {
                    no_branch
                };
                chosen.evaluate(ctx)
            }
        }
    }
    /// Longest root-to-leaf edge count; a lone leaf has depth 0.
    pub fn depth(&self) -> usize {
        match self {
            DecisionNode::Leaf(_) => 0,
            DecisionNode::Branch {
                yes_branch,
                no_branch,
                ..
            } => 1 + usize::max(yes_branch.depth(), no_branch.depth()),
        }
    }
}
#[allow(dead_code)]
pub struct StatSummary {
    /// Number of samples recorded.
    count: u64,
    /// Running total of all samples.
    sum: f64,
    /// Smallest sample seen (+inf until the first record).
    min: f64,
    /// Largest sample seen (-inf until the first record).
    max: f64,
}
#[allow(dead_code)]
impl StatSummary {
    /// Empty summary; min/max start at the identity values so the first
    /// sample always replaces them.
    pub fn new() -> Self {
        StatSummary {
            count: 0,
            sum: 0.0,
            min: f64::INFINITY,
            max: f64::NEG_INFINITY,
        }
    }
    /// Folds one sample into count, sum, min and max.
    pub fn record(&mut self, val: f64) {
        self.count += 1;
        self.sum += val;
        self.min = self.min.min(val);
        self.max = self.max.max(val);
    }
    /// Arithmetic mean, `None` when no samples were recorded.
    pub fn mean(&self) -> Option<f64> {
        if self.count > 0 {
            Some(self.sum / self.count as f64)
        } else {
            None
        }
    }
    /// Smallest sample, `None` when empty.
    pub fn min(&self) -> Option<f64> {
        (self.count > 0).then_some(self.min)
    }
    /// Largest sample, `None` when empty.
    pub fn max(&self) -> Option<f64> {
        (self.count > 0).then_some(self.max)
    }
    /// Number of samples recorded.
    pub fn count(&self) -> u64 {
        self.count
    }
}
#[allow(dead_code)]
pub struct LabelSet {
    /// Distinct labels in insertion order.
    labels: Vec<String>,
}
#[allow(dead_code)]
impl LabelSet {
    /// Empty set; insertion order is preserved.
    pub fn new() -> Self {
        LabelSet { labels: Vec::new() }
    }
    /// Inserts `label` unless an equal label is already present.
    pub fn add(&mut self, label: impl Into<String>) {
        let candidate = label.into();
        if !self.has(&candidate) {
            self.labels.push(candidate);
        }
    }
    /// Membership test by string equality.
    pub fn has(&self, label: &str) -> bool {
        self.labels.iter().any(|l| l == label)
    }
    /// Number of distinct labels.
    pub fn count(&self) -> usize {
        self.labels.len()
    }
    /// All labels in insertion order.
    pub fn all(&self) -> &[String] {
        self.labels.as_slice()
    }
}
#[allow(dead_code)]
pub struct PathBuf {
    /// Path segments, joined with `/` on render.
    components: Vec<String>,
}
#[allow(dead_code)]
impl PathBuf {
    /// Empty path. NOTE: this type shadows `std::path::PathBuf` within
    /// this module.
    pub fn new() -> Self {
        PathBuf {
            components: Vec::new(),
        }
    }
    /// Appends one component to the end of the path.
    pub fn push(&mut self, comp: impl Into<String>) {
        let component = comp.into();
        self.components.push(component);
    }
    /// Removes the last component (no-op when empty).
    pub fn pop(&mut self) {
        let _ = self.components.pop();
    }
    /// Renders the path with `/` separators.
    pub fn as_str(&self) -> String {
        self.components.join("/")
    }
    /// Number of components.
    pub fn depth(&self) -> usize {
        self.components.len()
    }
    /// Removes every component.
    pub fn clear(&mut self) {
        self.components.clear();
    }
}
#[allow(dead_code)]
pub struct RawFnPtr {
    /// Raw address value; never dereferenced by this type.
    ptr: usize,
    /// Number of arguments the function expects.
    arity: usize,
    /// Human-readable identifier.
    name: String,
}
#[allow(dead_code)]
impl RawFnPtr {
    /// Wraps a raw address with its arity and a display name. No validity
    /// guarantee is made about `ptr`.
    pub fn new(ptr: usize, arity: usize, name: impl Into<String>) -> Self {
        RawFnPtr {
            name: name.into(),
            arity,
            ptr,
        }
    }
    /// Number of arguments the function expects.
    pub fn arity(&self) -> usize {
        self.arity
    }
    /// Human-readable identifier.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
    /// The raw address value.
    pub fn raw(&self) -> usize {
        self.ptr
    }
}
#[allow(dead_code)]
pub enum Either2<A, B> {
    First(A),
    Second(B),
}
#[allow(dead_code)]
impl<A, B> Either2<A, B> {
    /// True for the `First` variant.
    pub fn is_first(&self) -> bool {
        !self.is_second()
    }
    /// True for the `Second` variant.
    pub fn is_second(&self) -> bool {
        matches!(self, Either2::Second(_))
    }
    /// Consumes self, yielding the first value when present.
    pub fn first(self) -> Option<A> {
        if let Either2::First(a) = self {
            Some(a)
        } else {
            None
        }
    }
    /// Consumes self, yielding the second value when present.
    pub fn second(self) -> Option<B> {
        if let Either2::Second(b) = self {
            Some(b)
        } else {
            None
        }
    }
    /// Maps the first value, passing a `Second` through untouched.
    pub fn map_first<C, F: FnOnce(A) -> C>(self, f: F) -> Either2<C, B> {
        match self {
            Either2::First(a) => Either2::First(f(a)),
            Either2::Second(b) => Either2::Second(b),
        }
    }
}