use super::smt::{SmtBackend, SmtResult, SmtSolver};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// A base type optionally narrowed by a logical predicate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefinementType {
/// The underlying, unrefined type.
pub base: BaseType,
/// Refinement predicate; `None` means the base type carries no extra constraint.
pub predicate: Option<Predicate>,
/// Parameter names attached to the type.
/// NOTE(review): nothing in this file reads these outside of tests — confirm intended use.
pub params: Vec<String>,
}
/// Unrefined type shapes; compared structurally via the derived `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BaseType {
/// Integer type, displayed as `Int`.
Int,
/// Boolean type, displayed as `Bool`.
Bool,
/// String type, displayed as `String`.
String,
/// Floating-point type, displayed as `Float`.
Float,
/// Array with the given element type, displayed as `[T]`.
Array(Box<BaseType>),
/// Tuple of element types, displayed as `(A, B, ...)`.
Tuple(Vec<BaseType>),
/// Function from parameter types to a return type, displayed as `(A, B) -> R`.
Function(Vec<BaseType>, Box<BaseType>),
/// A type identified only by name; displayed as that name.
Custom(String),
}
/// A predicate over a single named value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Predicate {
/// Name the expression uses for the refined value (the built-in constructors use `x` or `a`).
pub var: String,
/// Predicate text; the built-in constructors write it as a prefix s-expression such as `(> x 0)`.
pub expr: String,
}
impl RefinementType {
    /// Assembles a refined type from its base, the binder name and the
    /// predicate expression. The parameter list is always empty.
    fn refined(base: BaseType, var: &str, expr: String) -> Self {
        let predicate = Predicate {
            var: var.to_owned(),
            expr,
        };
        Self {
            base,
            predicate: Some(predicate),
            params: Vec::new(),
        }
    }

    /// Integer constrained to the inclusive range `min..=max`, bound to `x`.
    ///
    /// NOTE(review): negative bounds are interpolated verbatim (e.g. `-10`);
    /// strict SMT-LIB spells negatives as `(- 10)` — confirm the solver front
    /// end accepts the bare form.
    pub fn bounded_int(min: i64, max: i64) -> Self {
        let range = format!("(and (>= x {min}) (<= x {max}))");
        Self::refined(BaseType::Int, "x", range)
    }

    /// Integer constrained to be strictly greater than zero, bound to `x`.
    pub fn positive_int() -> Self {
        Self::refined(BaseType::Int, "x", String::from("(> x 0)"))
    }

    /// Array of `elem_type` whose `len` is strictly greater than zero, bound to `a`.
    pub fn non_empty_array(elem_type: BaseType) -> Self {
        let array = BaseType::Array(Box::new(elem_type));
        Self::refined(array, "a", String::from("(> (len a) 0)"))
    }

    /// Integer array carrying the `(sorted a)` predicate, bound to `a`.
    pub fn sorted_array() -> Self {
        let int_array = BaseType::Array(Box::new(BaseType::Int));
        Self::refined(int_array, "a", String::from("(sorted a)"))
    }
}
impl fmt::Display for RefinementType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(pred) = &self.predicate {
write!(f, "{} where {}", self.base, pred.expr)
} else {
write!(f, "{}", self.base)
}
}
}
impl fmt::Display for BaseType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Int => write!(f, "Int"),
Self::Bool => write!(f, "Bool"),
Self::String => write!(f, "String"),
Self::Float => write!(f, "Float"),
Self::Array(t) => write!(f, "[{t}]"),
Self::Tuple(ts) => {
write!(f, "(")?;
for (i, t) in ts.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{t}")?;
}
write!(f, ")")
}
Self::Function(params, ret) => {
write!(f, "(")?;
for (i, p) in params.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{p}")?;
}
write!(f, ") -> {ret}")
}
Self::Custom(name) => write!(f, "{name}"),
}
}
}
/// A named function signature annotated with refinement types and logical conditions.
#[derive(Debug, Clone)]
pub struct TypeRefinement {
/// Name given to the signature at construction.
pub name: String,
/// Refinement type of the input.
pub input: RefinementType,
/// Refinement type of the output.
pub output: RefinementType,
/// Named arguments with their refinement types, in insertion order.
pub args: Vec<(String, RefinementType)>,
/// Conditions asserted as assumptions when the function is verified.
pub preconditions: Vec<String>,
/// Conditions the function body is checked against when it is verified.
pub postconditions: Vec<String>,
/// Recorded invariants.
/// NOTE(review): no code in this file reads these outside of tests — confirm intended use.
pub invariants: Vec<String>,
}
impl TypeRefinement {
    /// Creates a signature called `name` with the given input and output
    /// types and no arguments, preconditions, postconditions or invariants.
    pub fn new(name: &str, input: RefinementType, output: RefinementType) -> Self {
        Self {
            name: String::from(name),
            input,
            output,
            args: vec![],
            preconditions: vec![],
            postconditions: vec![],
            invariants: vec![],
        }
    }

    /// Appends a named argument with its refinement type.
    pub fn add_arg(&mut self, name: &str, ty: RefinementType) {
        let entry = (name.to_owned(), ty);
        self.args.push(entry);
    }

    /// Appends a precondition expression.
    pub fn add_precondition(&mut self, pred: &str) {
        self.preconditions.push(pred.to_owned());
    }

    /// Appends a postcondition expression.
    pub fn add_postcondition(&mut self, pred: &str) {
        self.postconditions.push(pred.to_owned());
    }

    /// Appends an invariant expression.
    pub fn add_invariant(&mut self, inv: &str) {
        self.invariants.push(inv.to_owned());
    }
}
/// Checks refinement-type relationships by issuing queries to an SMT solver.
pub struct RefinementChecker {
// Backend handed to every `SmtSolver` this checker creates.
backend: SmtBackend,
// Declared variables, keyed by name.
env: HashMap<String, RefinementType>,
// Declared function signatures, keyed by name.
signatures: HashMap<String, TypeRefinement>,
}
impl RefinementChecker {
    /// Creates a checker that uses the Z3 backend and has no declared
    /// variables or function signatures.
    pub fn new() -> Self {
        Self {
            backend: SmtBackend::Z3,
            env: HashMap::new(),
            signatures: HashMap::new(),
        }
    }

    /// Selects the SMT backend used by subsequent queries.
    pub fn set_backend(&mut self, backend: SmtBackend) {
        self.backend = backend;
    }

    /// Records (or replaces) the refinement type of variable `name`.
    pub fn declare_var(&mut self, name: &str, ty: RefinementType) {
        self.env.insert(name.to_string(), ty);
    }

    /// Records (or replaces) the signature of function `name`.
    pub fn declare_function(&mut self, name: &str, refinement: TypeRefinement) {
        self.signatures.insert(name.to_string(), refinement);
    }

    /// Decides whether `sub_type` is a subtype of `super_type`.
    ///
    /// The base types must be equal. Beyond that:
    /// - both refined: the sub predicate must imply the super predicate
    ///   (decided by the solver);
    /// - only the sub type refined, or neither refined: `true`;
    /// - only the super type refined: `false`, without consulting the solver.
    ///
    /// Only the predicates' `expr` text is compared; the `var` binder names
    /// are not consulted, so both predicates are expected to name the refined
    /// value identically.
    ///
    /// # Errors
    /// Propagates any error returned by the solver.
    pub fn is_subtype(
        &self,
        sub_type: &RefinementType,
        super_type: &RefinementType,
    ) -> Result<bool> {
        if sub_type.base != super_type.base {
            return Ok(false);
        }
        match (&sub_type.predicate, &super_type.predicate) {
            (Some(sub_pred), Some(super_pred)) => {
                self.check_implication(&sub_pred.expr, &super_pred.expr)
            }
            (Some(_), None) => Ok(true),
            (None, Some(_)) => Ok(false),
            (None, None) => Ok(true),
        }
    }

    /// Returns `true` when `antecedent ∧ ¬consequent` is reported
    /// unsatisfiable, i.e. the implication holds. Any other solver answer
    /// yields `false`.
    fn check_implication(&self, antecedent: &str, consequent: &str) -> Result<bool> {
        let mut solver = SmtSolver::new(self.backend);
        solver.assert(antecedent);
        solver.assert(&format!("(not {consequent})"));
        match solver.check_sat()? {
            SmtResult::Unsat => Ok(true),
            _ => Ok(false),
        }
    }

    /// Verifies that `body`, under the declared preconditions of `name`,
    /// establishes every declared postcondition.
    ///
    /// The query asserted is `pre_1 ∧ … ∧ body ∧ ¬(post_1 ∧ … ∧ post_n)`:
    /// - unsatisfiable → [`VerificationResult::Valid`];
    /// - satisfiable → [`VerificationResult::Invalid`];
    /// - anything else → [`VerificationResult::Unknown`].
    ///
    /// A signature without postconditions has nothing to prove and is
    /// reported `Valid` without consulting the solver.
    ///
    /// # Errors
    /// Returns an error if `name` was never declared, or if the solver fails.
    pub fn verify_function(&self, name: &str, body: &str) -> Result<VerificationResult> {
        let refinement = self
            .signatures
            .get(name)
            .ok_or_else(|| anyhow::anyhow!("Unknown function: {name}"))?;

        // The goal is the conjunction of all postconditions, so its negation
        // is a single `(not (and ...))`. Asserting each `(not post)` on its
        // own would demand that every postcondition fail at once and would
        // miss a violation of just one of them.
        let negated_goal = match refinement.postconditions.as_slice() {
            [] => return Ok(VerificationResult::Valid),
            [only] => format!("(not {only})"),
            many => format!("(not (and {}))", many.join(" ")),
        };

        let mut solver = SmtSolver::new(self.backend);
        for pre in &refinement.preconditions {
            solver.assert(pre);
        }
        solver.assert(body);
        solver.assert(&negated_goal);
        match solver.check_sat()? {
            SmtResult::Unsat => Ok(VerificationResult::Valid),
            SmtResult::Sat => Ok(VerificationResult::Invalid(
                "Postcondition violation".to_string(),
            )),
            _ => Ok(VerificationResult::Unknown),
        }
    }

    /// Asks the solver whether `invariant ∧ body ∧ ¬invariant` is
    /// unsatisfiable and returns `true` if so.
    ///
    /// NOTE(review): the invariant and its negation are asserted over the
    /// same symbols, so the query is contradictory regardless of `body`; a
    /// sound solver will answer unsat for any well-formed input. A real
    /// preservation check needs the post-state invariant written over
    /// distinct (e.g. primed) symbols, which this string-based API cannot
    /// express — TODO decide the intended encoding with callers.
    ///
    /// # Errors
    /// Propagates any error returned by the solver.
    pub fn check_invariant(&self, invariant: &str, body: &str) -> Result<bool> {
        let mut solver = SmtSolver::new(self.backend);
        solver.assert(invariant);
        solver.assert(body);
        solver.assert(&format!("(not {invariant})"));
        match solver.check_sat()? {
            SmtResult::Unsat => Ok(true),
            _ => Ok(false),
        }
    }
}
impl Default for RefinementChecker {
fn default() -> Self {
Self::new()
}
}
/// Outcome of a verification query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VerificationResult {
/// The property was established.
Valid,
/// The property was refuted; carries a diagnostic message.
Invalid(String),
/// The solver gave neither a proof nor a refutation.
Unknown,
}
impl VerificationResult {
    /// `true` only for [`VerificationResult::Valid`].
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Invalid(_) | Self::Unknown => false,
        }
    }

    /// The diagnostic message of an `Invalid` result; `None` otherwise.
    pub fn error(&self) -> Option<&str> {
        if let Self::Invalid(message) = self {
            Some(message.as_str())
        } else {
            None
        }
    }
}
/// Collects constraints and infers refinement types for simple literals.
pub struct LiquidTypeInference {
// Supplies the SMT backend used when solving the collected constraints.
checker: RefinementChecker,
// Constraint expressions, asserted in insertion order by `solve`.
constraints: Vec<String>,
}
impl LiquidTypeInference {
    /// Creates an inference engine with a default checker and no constraints.
    pub fn new() -> Self {
        Self {
            checker: RefinementChecker::new(),
            constraints: Vec::new(),
        }
    }

    /// Infers a refinement type for the literal `expr`.
    ///
    /// - Anything `str::parse::<i64>` accepts becomes `Int` refined by
    ///   `(= x <value>)`.
    /// - `true` / `false` become an unrefined `Bool`.
    /// - Everything else becomes the unrefined custom type `Unknown`.
    ///
    /// # Errors
    /// Never fails at present; the `Result` is part of the signature.
    pub fn infer(&mut self, expr: &str) -> Result<RefinementType> {
        // Parse once and bind the value, rather than testing with a guard
        // and re-parsing behind an `expect`.
        if let Ok(n) = expr.parse::<i64>() {
            return Ok(RefinementType {
                base: BaseType::Int,
                predicate: Some(Predicate {
                    var: "x".to_string(),
                    expr: format!("(= x {n})"),
                }),
                params: Vec::new(),
            });
        }

        let base = match expr {
            "true" | "false" => BaseType::Bool,
            _ => BaseType::Custom("Unknown".to_string()),
        };
        Ok(RefinementType {
            base,
            predicate: None,
            params: Vec::new(),
        })
    }

    /// Records a constraint expression to be checked by [`Self::solve`].
    pub fn add_constraint(&mut self, constraint: &str) {
        self.constraints.push(constraint.to_string());
    }

    /// Asserts every recorded constraint and returns `true` only when the
    /// solver reports them satisfiable; any other answer yields `false`.
    ///
    /// # Errors
    /// Propagates any error returned by the solver.
    pub fn solve(&self) -> Result<bool> {
        let mut solver = SmtSolver::new(self.checker.backend);
        for constraint in &self.constraints {
            solver.assert(constraint);
        }
        match solver.check_sat()? {
            SmtResult::Sat => Ok(true),
            _ => Ok(false),
        }
    }
}
impl Default for LiquidTypeInference {
fn default() -> Self {
Self::new()
}
}
/// Unit tests for the refinement type model and the solver-free paths of
/// `RefinementChecker`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_refinement_type_display() {
        let ty = RefinementType::positive_int();
        assert_eq!(ty.to_string(), "Int where (> x 0)");
        let bounded = RefinementType::bounded_int(0, 100);
        assert_eq!(bounded.to_string(), "Int where (and (>= x 0) (<= x 100))");
    }

    #[test]
    fn test_base_type_display() {
        assert_eq!(BaseType::Int.to_string(), "Int");
        assert_eq!(
            BaseType::Array(Box::new(BaseType::Int)).to_string(),
            "[Int]"
        );
        let func = BaseType::Function(
            vec![BaseType::Int, BaseType::Bool],
            Box::new(BaseType::String),
        );
        assert_eq!(func.to_string(), "(Int, Bool) -> String");
    }

    #[test]
    fn test_bounded_int_creation() {
        let bounded = RefinementType::bounded_int(-10, 10);
        assert_eq!(bounded.base, BaseType::Int);
        assert!(bounded.predicate.is_some());
        let pred = bounded.predicate.expect("operation should succeed in test");
        assert_eq!(pred.var, "x");
        // Pin the whole expression: `contains("10")` alone would be satisfied
        // by the `-10` lower bound and says nothing about the upper bound.
        assert_eq!(pred.expr, "(and (>= x -10) (<= x 10))");
    }

    #[test]
    fn test_positive_int() {
        let pos = RefinementType::positive_int();
        assert_eq!(pos.base, BaseType::Int);
        assert!(pos.predicate.is_some());
        let pred = pos.predicate.expect("operation should succeed in test");
        assert_eq!(pred.var, "x");
        assert_eq!(pred.expr, "(> x 0)");
    }

    #[test]
    fn test_non_empty_array() {
        let non_empty = RefinementType::non_empty_array(BaseType::String);
        match &non_empty.base {
            BaseType::Array(elem) => assert_eq!(**elem, BaseType::String),
            _ => panic!("Expected array type"),
        }
        assert!(non_empty.predicate.is_some());
        let pred = non_empty
            .predicate
            .expect("operation should succeed in test");
        assert_eq!(pred.var, "a");
        assert!(pred.expr.contains("len"));
    }

    #[test]
    fn test_sorted_array() {
        let sorted = RefinementType::sorted_array();
        match &sorted.base {
            BaseType::Array(elem) => assert_eq!(**elem, BaseType::Int),
            _ => panic!("Expected array type"),
        }
        assert!(sorted.predicate.is_some());
        let pred = sorted.predicate.expect("operation should succeed in test");
        assert!(pred.expr.contains("sorted"));
    }

    #[test]
    fn test_type_refinement_creation() {
        let mut refinement = TypeRefinement::new(
            "add",
            RefinementType::positive_int(),
            RefinementType::positive_int(),
        );
        assert_eq!(refinement.name, "add");
        refinement.add_arg("x", RefinementType::bounded_int(0, 10));
        assert_eq!(refinement.args.len(), 1);
        assert_eq!(refinement.args[0].0, "x");
        refinement.add_precondition("x >= 0");
        assert_eq!(refinement.preconditions.len(), 1);
        refinement.add_postcondition("result > x");
        assert_eq!(refinement.postconditions.len(), 1);
        refinement.add_invariant("x >= 0");
        assert_eq!(refinement.invariants.len(), 1);
    }

    #[test]
    fn test_refinement_checker_new() {
        let checker = RefinementChecker::new();
        assert!(checker.env.is_empty());
        assert!(checker.signatures.is_empty());
    }

    #[test]
    fn test_refinement_checker_declare_var() {
        let mut checker = RefinementChecker::new();
        let ty = RefinementType::positive_int();
        checker.declare_var("count", ty);
        assert!(checker.env.contains_key("count"));
        let stored = &checker.env["count"];
        assert_eq!(stored.base, BaseType::Int);
    }

    #[test]
    fn test_refinement_checker_declare_function() {
        let mut checker = RefinementChecker::new();
        let refinement = TypeRefinement::new(
            "increment",
            RefinementType::positive_int(),
            RefinementType::positive_int(),
        );
        checker.declare_function("increment", refinement);
        assert!(checker.signatures.contains_key("increment"));
        let stored = &checker.signatures["increment"];
        assert_eq!(stored.name, "increment");
    }

    #[test]
    fn test_is_subtype_same_base() {
        let checker = RefinementChecker::new();
        let ty1 = RefinementType {
            base: BaseType::Int,
            predicate: None,
            params: vec![],
        };
        let ty2 = ty1.clone();
        assert!(checker
            .is_subtype(&ty1, &ty2)
            .expect("operation should succeed in test"));
    }

    #[test]
    fn test_is_subtype_different_base() {
        let checker = RefinementChecker::new();
        let int_ty = RefinementType {
            base: BaseType::Int,
            predicate: None,
            params: vec![],
        };
        let bool_ty = RefinementType {
            base: BaseType::Bool,
            predicate: None,
            params: vec![],
        };
        assert!(!checker
            .is_subtype(&int_ty, &bool_ty)
            .expect("operation should succeed in test"));
    }

    #[test]
    fn test_is_subtype_with_predicates() {
        let checker = RefinementChecker::new();
        let pos_int = RefinementType::positive_int();
        let plain_int = RefinementType {
            base: BaseType::Int,
            predicate: None,
            params: vec![],
        };
        // Refined -> unrefined is always accepted.
        assert!(checker
            .is_subtype(&pos_int, &plain_int)
            .expect("operation should succeed in test"));
        // Unrefined -> refined is rejected without consulting the solver.
        assert!(!checker
            .is_subtype(&plain_int, &pos_int)
            .expect("operation should succeed in test"));
    }

    #[test]
    fn test_base_type_equality() {
        assert_eq!(BaseType::Int, BaseType::Int);
        assert_ne!(BaseType::Int, BaseType::Bool);
        assert_eq!(BaseType::String, BaseType::String);
        assert_ne!(BaseType::Float, BaseType::String);
    }

    #[test]
    fn test_array_type() {
        let array_int = BaseType::Array(Box::new(BaseType::Int));
        let array_bool = BaseType::Array(Box::new(BaseType::Bool));
        assert_ne!(array_int, array_bool);
        match array_int {
            BaseType::Array(elem) => assert_eq!(*elem, BaseType::Int),
            _ => panic!("Expected array type"),
        }
    }

    #[test]
    fn test_tuple_type() {
        let tuple = BaseType::Tuple(vec![BaseType::Int, BaseType::Bool, BaseType::String]);
        match tuple {
            BaseType::Tuple(elems) => {
                assert_eq!(elems.len(), 3);
                assert_eq!(elems[0], BaseType::Int);
                assert_eq!(elems[1], BaseType::Bool);
                assert_eq!(elems[2], BaseType::String);
            }
            _ => panic!("Expected tuple type"),
        }
    }

    #[test]
    fn test_function_type() {
        let func = BaseType::Function(vec![BaseType::Int, BaseType::Int], Box::new(BaseType::Bool));
        match func {
            BaseType::Function(args, ret) => {
                assert_eq!(args.len(), 2);
                assert_eq!(args[0], BaseType::Int);
                assert_eq!(*ret, BaseType::Bool);
            }
            _ => panic!("Expected function type"),
        }
    }

    #[test]
    fn test_custom_type() {
        let custom = BaseType::Custom("MyType".to_string());
        match custom {
            BaseType::Custom(name) => assert_eq!(name, "MyType"),
            _ => panic!("Expected custom type"),
        }
    }

    #[test]
    fn test_predicate_creation() {
        let pred = Predicate {
            var: "n".to_string(),
            expr: "(>= n 0)".to_string(),
        };
        assert_eq!(pred.var, "n");
        assert_eq!(pred.expr, "(>= n 0)");
    }

    #[test]
    fn test_refinement_type_with_params() {
        let ty = RefinementType {
            base: BaseType::Custom("Map".to_string()),
            predicate: None,
            params: vec!["K".to_string(), "V".to_string()],
        };
        assert_eq!(ty.params.len(), 2);
        assert_eq!(ty.params[0], "K");
        assert_eq!(ty.params[1], "V");
    }

    #[test]
    fn test_nested_array() {
        let nested = BaseType::Array(Box::new(BaseType::Array(Box::new(BaseType::Int))));
        let formatted = nested.to_string();
        assert_eq!(formatted, "[[Int]]");
    }

    #[test]
    fn test_complex_function_type() {
        let func = BaseType::Function(
            vec![
                BaseType::Array(Box::new(BaseType::Int)),
                BaseType::Tuple(vec![BaseType::Bool, BaseType::String]),
            ],
            Box::new(BaseType::Float),
        );
        let formatted = func.to_string();
        assert!(formatted.contains("Int"));
        assert!(formatted.contains("Bool"));
        assert!(formatted.contains("String"));
        assert!(formatted.contains("Float"));
    }

    #[test]
    fn test_refinement_checker_set_backend() {
        let mut checker = RefinementChecker::new();
        // Check the stored backend after each call so the test fails if
        // `set_backend` stops writing the field.
        checker.set_backend(SmtBackend::CVC5);
        assert!(matches!(checker.backend, SmtBackend::CVC5));
        checker.set_backend(SmtBackend::Z3);
        assert!(matches!(checker.backend, SmtBackend::Z3));
    }
}
/// Property tests for the refinement type constructors.
#[cfg(test)]
mod property_tests_refinement {
    use super::{BaseType, RefinementType};
    use proptest::proptest;

    proptest! {
        // `bounded_int` must accept every pair of bounds (including
        // `min > max` and the `i64` extremes) without panicking, and must
        // embed both bounds verbatim in the predicate. The previous version
        // sliced an arbitrary string at byte 100, which can panic on a
        // non-char-boundary, and never called the constructor.
        #[test]
        fn test_bounded_int_never_panics(min: i64, max: i64) {
            let bounded = RefinementType::bounded_int(min, max);
            assert_eq!(bounded.base, BaseType::Int);
            assert!(bounded.params.is_empty());
            let pred = bounded
                .predicate
                .expect("bounded_int always attaches a predicate");
            assert_eq!(pred.var, "x");
            assert_eq!(pred.expr, format!("(and (>= x {min}) (<= x {max}))"));
        }
    }
}