use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt;
use crate::{IrError, ParametricType, TLExpr, Term};
/// A standalone logical refinement: a `predicate` about the bound variable
/// `var_name`, with no base type attached (compare [`RefinementType`]).
///
/// `var_name` is treated as bound: it is excluded from [`Refinement::free_vars`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Refinement {
/// Name of the refined (bound) variable the predicate talks about.
pub var_name: String,
/// Predicate that is required to hold for `var_name`.
pub predicate: TLExpr,
}
impl Refinement {
    /// Create a refinement that binds `var_name` in `predicate`.
    pub fn new(var_name: impl Into<String>, predicate: TLExpr) -> Self {
        Refinement {
            var_name: var_name.into(),
            predicate,
        }
    }

    /// Free variables of the predicate, excluding the bound `var_name`.
    pub fn free_vars(&self) -> HashSet<String> {
        let mut vars = self.predicate.free_vars();
        vars.remove(&self.var_name);
        vars
    }

    /// Apply a substitution to the free variables of this refinement.
    ///
    /// NOTE: substitution is not implemented yet. This currently returns an
    /// unchanged copy of `self` whatever the contents of the map. When it is
    /// implemented, any mapping for the bound `var_name` must be ignored,
    /// because the binder shadows the outer name.
    pub fn substitute(&self, _subst: &HashMap<String, Term>) -> Refinement {
        // TODO: apply `_subst` (minus `self.var_name`) to `self.predicate` once a
        // TLExpr substitution helper is available. The previous implementation
        // cloned and filtered the whole map, then discarded the result; that
        // allocation bought nothing, so it has been removed.
        self.clone()
    }

    /// Return a copy whose predicate has been passed through
    /// `crate::optimize_expr`.
    pub fn simplify(&self) -> Refinement {
        use crate::optimize_expr;
        Refinement {
            var_name: self.var_name.clone(),
            predicate: optimize_expr(&self.predicate),
        }
    }

    /// Conservative, purely syntactic implication check.
    ///
    /// Returns `true` only when both refinements bind the same variable name and
    /// have structurally equal predicates. A `false` result means "not proven",
    /// not "does not imply". For example, `x > 0` versus `x >= 0` yields `false`.
    pub fn implies(&self, other: &Refinement) -> bool {
        self.var_name == other.var_name && self.predicate == other.predicate
    }
}
impl fmt::Display for Refinement {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{{{}: | {}}}", self.var_name, self.predicate)
}
}
/// A refinement type `{var_name: base_type | refinement}`: the values of
/// `base_type` for which `refinement` holds, where `var_name` names the value.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RefinementType {
/// Name bound to the value inside `refinement`.
pub var_name: String,
/// The underlying (possibly parametric) type being refined.
pub base_type: ParametricType,
/// Predicate over `var_name` restricting the inhabitants of `base_type`.
pub refinement: TLExpr,
}
impl RefinementType {
pub fn new(
var_name: impl Into<String>,
base_type: impl Into<String>,
refinement: TLExpr,
) -> Self {
RefinementType {
var_name: var_name.into(),
base_type: ParametricType::concrete(base_type),
refinement,
}
}
pub fn from_parametric(
var_name: impl Into<String>,
base_type: ParametricType,
refinement: TLExpr,
) -> Self {
RefinementType {
var_name: var_name.into(),
base_type,
refinement,
}
}
pub fn positive_int(var_name: impl Into<String>) -> Self {
let var_name = var_name.into();
RefinementType::new(
var_name.clone(),
"Int",
TLExpr::gt(TLExpr::pred(&var_name, vec![]), TLExpr::constant(0.0)),
)
}
pub fn nat(var_name: impl Into<String>) -> Self {
let var_name = var_name.into();
RefinementType::new(
var_name.clone(),
"Int",
TLExpr::gte(TLExpr::pred(&var_name, vec![]), TLExpr::constant(0.0)),
)
}
pub fn probability(var_name: impl Into<String>) -> Self {
let var_name = var_name.into();
RefinementType::new(
var_name.clone(),
"Float",
TLExpr::and(
TLExpr::gte(TLExpr::pred(&var_name, vec![]), TLExpr::constant(0.0)),
TLExpr::lte(TLExpr::pred(&var_name, vec![]), TLExpr::constant(1.0)),
),
)
}
pub fn non_empty_vec(var_name: impl Into<String>, element_type: impl Into<String>) -> Self {
let var_name = var_name.into();
use crate::TypeConstructor;
let elem_type = ParametricType::concrete(element_type);
let vec_type = ParametricType::apply(TypeConstructor::List, vec![elem_type]);
RefinementType::from_parametric(
var_name.clone(),
vec_type,
TLExpr::gt(TLExpr::pred("length", vec![]), TLExpr::constant(0.0)),
)
}
pub fn free_vars(&self) -> HashSet<String> {
let mut vars = self.refinement.free_vars();
vars.remove(&self.var_name);
vars
}
pub fn is_subtype_of(&self, other: &RefinementType) -> bool {
if self.base_type != other.base_type {
return false;
}
if self.var_name != other.var_name {
return false;
}
self.refinement == other.refinement
}
pub fn weaken(&self) -> RefinementType {
RefinementType {
var_name: self.var_name.clone(),
base_type: self.base_type.clone(),
refinement: TLExpr::constant(1.0), }
}
pub fn strengthen(&self, additional: TLExpr) -> RefinementType {
RefinementType {
var_name: self.var_name.clone(),
base_type: self.base_type.clone(),
refinement: TLExpr::and(self.refinement.clone(), additional),
}
}
}
impl fmt::Display for RefinementType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{{{}: {} | {}}}",
self.var_name, self.base_type, self.refinement
)
}
}
/// Typing context for refinement checking: named bindings plus a list of
/// facts (assumptions) that are taken to hold.
#[derive(Clone, Debug, Default)]
pub struct RefinementContext {
/// Declared refinement type for each bound name.
bindings: HashMap<String, RefinementType>,
/// Facts assumed to hold; populated by binding a type (its refinement) and by
/// explicit assumptions.
assumptions: Vec<TLExpr>,
}
impl RefinementContext {
    /// Create an empty context with no bindings and no assumptions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Bind `name` to `typ` and record `typ.refinement` as an assumption.
    ///
    /// Rebinding an existing name replaces its type. The previous binding's
    /// refinement is withdrawn from the assumptions, so it no longer satisfies
    /// [`RefinementContext::check_refinement`]. Exactly one copy of it is
    /// removed, so an identical fact added separately via `assume` survives.
    ///
    /// NOTE(review): the recorded assumption is phrased in terms of
    /// `typ.var_name`, not `name`. Binding a type under a different name does
    /// not rename the variable inside the predicate.
    pub fn bind(&mut self, name: impl Into<String>, typ: RefinementType) {
        let name = name.into();
        if let Some(previous) = self.bindings.get(&name) {
            if let Some(pos) = self
                .assumptions
                .iter()
                .position(|fact| *fact == previous.refinement)
            {
                self.assumptions.remove(pos);
            }
        }
        self.assumptions.push(typ.refinement.clone());
        self.bindings.insert(name, typ);
    }

    /// Look up the refinement type bound to `name`, if any.
    pub fn get_type(&self, name: &str) -> Option<&RefinementType> {
        self.bindings.get(name)
    }

    /// Add `fact` to the assumptions without binding any name.
    pub fn assume(&mut self, fact: TLExpr) {
        self.assumptions.push(fact);
    }

    /// Syntactic check: `true` iff `refinement` is structurally equal to one
    /// of the current assumptions. No logical reasoning (implication,
    /// arithmetic) is performed.
    pub fn check_refinement(&self, refinement: &TLExpr) -> bool {
        self.assumptions.contains(refinement)
    }

    /// Placeholder verifier: always returns `Ok(())` without inspecting the
    /// value or the type.
    pub fn verify(&self, _value: &Term, _typ: &RefinementType) -> Result<(), IrError> {
        Ok(())
    }
}
/// Minimal liquid-type inference driver. Each unknown refinement has a list
/// of candidate predicates to choose from.
#[derive(Clone, Debug)]
pub struct LiquidTypeInference {
/// Context candidates would be checked against (exposed via `context()`;
/// `infer` does not currently consult it).
context: RefinementContext,
/// Candidate predicates for each unknown refinement, keyed by unknown name.
unknowns: HashMap<String, Vec<TLExpr>>,
}
impl LiquidTypeInference {
    /// Create an engine with an empty context and no registered unknowns.
    pub fn new() -> Self {
        Self {
            context: RefinementContext::new(),
            unknowns: HashMap::new(),
        }
    }

    /// Register (or replace) the candidate predicates for unknown `name`.
    pub fn add_unknown(&mut self, name: impl Into<String>, candidates: Vec<TLExpr>) {
        let key: String = name.into();
        self.unknowns.insert(key, candidates);
    }

    /// Pick a refinement for every unknown.
    ///
    /// Placeholder strategy: the first candidate of each unknown is chosen.
    /// Unknowns with no candidates are omitted from the result.
    pub fn infer(&mut self) -> HashMap<String, TLExpr> {
        self.unknowns
            .iter()
            .filter_map(|(unknown, options)| {
                options
                    .first()
                    .map(|chosen| (unknown.clone(), chosen.clone()))
            })
            .collect()
    }

    /// Borrow the refinement context owned by this engine.
    pub fn context(&self) -> &RefinementContext {
        &self.context
    }
}
impl Default for LiquidTypeInference {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_refinement_creation() {
        let predicate = TLExpr::gt(TLExpr::pred("x", vec![]), TLExpr::constant(0.0));
        let refinement = Refinement::new("x", predicate.clone());
        assert_eq!(refinement.var_name, "x");
        assert_eq!(refinement.predicate, predicate);
    }

    #[test]
    fn test_refinement_type_positive_int() {
        let pos_int = RefinementType::positive_int("x");
        assert_eq!(pos_int.var_name, "x");
        assert_eq!(pos_int.base_type, ParametricType::concrete("Int"));
        assert!(pos_int.free_vars().is_empty());
    }

    #[test]
    fn test_refinement_type_nat() {
        let nat = RefinementType::nat("n");
        assert_eq!(nat.to_string(), "{n: Int | (n() ≥ 0)}");
    }

    #[test]
    fn test_refinement_type_probability() {
        let prob = RefinementType::probability("p");
        let s = prob.to_string();
        assert!(s.contains("Float"));
        assert!(s.contains("≥") || s.contains(">="));
        assert!(s.contains("≤") || s.contains("<="));
    }

    #[test]
    fn test_refinement_context() {
        let mut ctx = RefinementContext::new();
        let pos_int = RefinementType::positive_int("x");
        ctx.bind("x", pos_int.clone());
        assert_eq!(ctx.get_type("x"), Some(&pos_int));
        // Names that were never bound are not visible.
        assert!(ctx.get_type("y").is_none());
    }

    #[test]
    fn test_refinement_type_weaken() {
        let pos_int = RefinementType::positive_int("x");
        let weakened = pos_int.weaken();
        assert_eq!(weakened.var_name, pos_int.var_name);
        assert_eq!(weakened.base_type, pos_int.base_type);
        assert_eq!(weakened.refinement, TLExpr::constant(1.0));
    }

    #[test]
    fn test_refinement_type_strengthen() {
        let pos_int = RefinementType::positive_int("x");
        let additional = TLExpr::lt(TLExpr::pred("x", vec![]), TLExpr::constant(100.0));
        let strengthened = pos_int.strengthen(additional.clone());
        assert_eq!(strengthened.var_name, pos_int.var_name);
        assert_eq!(strengthened.base_type, pos_int.base_type);
        match &strengthened.refinement {
            TLExpr::And(left, right) => {
                // Both the original refinement and the new conjunct must be
                // present, in whichever order `TLExpr::and` stores them.
                let (l, r) = (&**left, &**right);
                assert!(
                    (*l == pos_int.refinement && *r == additional)
                        || (*l == additional && *r == pos_int.refinement)
                );
            }
            other => panic!("Expected AND expression, got {:?}", other),
        }
    }

    #[test]
    fn test_liquid_type_inference() {
        let mut inference = LiquidTypeInference::new();
        let candidates = vec![
            TLExpr::gt(TLExpr::pred("x", vec![]), TLExpr::constant(0.0)),
            TLExpr::gte(TLExpr::pred("x", vec![]), TLExpr::constant(0.0)),
        ];
        inference.add_unknown("x_refinement", candidates);
        // An unknown without candidates cannot be resolved.
        inference.add_unknown("no_candidates", Vec::new());
        let inferred = inference.infer();
        assert!(inferred.contains_key("x_refinement"));
        assert!(!inferred.contains_key("no_candidates"));
    }

    #[test]
    fn test_refinement_free_vars() {
        let predicate = TLExpr::and(
            TLExpr::gt(TLExpr::pred("x", vec![]), TLExpr::constant(0.0)),
            TLExpr::lt(TLExpr::pred("x", vec![]), TLExpr::pred("y", vec![])),
        );
        let refinement = Refinement::new("x", predicate);
        let free_vars = refinement.free_vars();
        assert!(!free_vars.contains("x"));
        assert!(free_vars.contains("y") || free_vars.is_empty());
    }

    #[test]
    fn test_non_empty_vec() {
        let non_empty = RefinementType::non_empty_vec("v", "Int");
        assert!(non_empty.to_string().contains("List"));
        assert!(non_empty.to_string().contains("length"));
    }

    #[test]
    fn test_refinement_context_assumptions() {
        let mut ctx = RefinementContext::new();
        let fact = TLExpr::gt(TLExpr::pred("x", vec![]), TLExpr::constant(0.0));
        ctx.assume(fact.clone());
        assert!(ctx.check_refinement(&fact));
    }

    #[test]
    fn test_refinement_type_subtyping() {
        let pos_int = RefinementType::positive_int("x");
        let nat = RefinementType::nat("x");
        // The check is syntactic: `x > 0` is not recognised as implying `x >= 0`.
        assert!(!pos_int.is_subtype_of(&nat));
        // Subtyping is reflexive.
        assert!(pos_int.is_subtype_of(&pos_int));
    }
}