use std::collections::HashMap;
use std::sync::Arc;
/// A constraint over a numeric value `x`, used to refine a base type.
///
/// Most variants are evaluated directly against `x`; `Dependent` additionally needs a
/// `RefinementContext` that binds the variable it refers to.
#[derive(Clone)]
pub enum RefinementPredicate {
    /// `x == value`, compared with an absolute tolerance of `f64::EPSILON`.
    Equal(f64),
    /// `x != value`, the tolerance-based complement of `Equal`.
    NotEqual(f64),
    /// `x > value`.
    GreaterThan(f64),
    /// `x >= value`.
    GreaterThanOrEqual(f64),
    /// `x < value`.
    LessThan(f64),
    /// `x <= value`.
    LessThanOrEqual(f64),
    /// Closed interval `min <= x <= max`.
    Range { min: f64, max: f64 },
    /// Half-open interval `min <= x < max`.
    RangeExclusive { min: f64, max: f64 },
    /// Integer remainder constraint, rendered as `x % divisor == remainder`.
    Modulo { divisor: i64, remainder: i64 },
    /// `x` matches (within `f64::EPSILON`) at least one element of the set.
    InSet(Vec<f64>),
    /// `x` matches (within `f64::EPSILON`) no element of the set.
    NotInSet(Vec<f64>),
    /// Every inner predicate holds.
    And(Vec<Self>),
    /// At least one inner predicate holds.
    Or(Vec<Self>),
    /// The inner predicate does not hold.
    Not(Box<Self>),
    /// A named, user-supplied test; `checker` decides whether `x` satisfies it.
    Custom {
        /// Short identifier, used when rendering the predicate.
        name: String,
        /// Human-readable explanation of what `checker` tests.
        description: String,
        /// The closure that performs the actual test.
        checker: Arc<dyn Fn(f64) -> bool + Send + Sync>,
    },
    /// Relates `x` to the value bound to `variable` in a `RefinementContext`.
    Dependent {
        /// Name of the context variable `x` is compared against.
        variable: String,
        /// How `x` must relate to that variable's value.
        relation: DependentRelation,
    },
    /// Bounds on a string length; the numeric checks treat it as satisfied.
    StringLength {
        /// Optional inclusive lower bound on the length.
        min: Option<usize>,
        /// Optional inclusive upper bound on the length.
        max: Option<usize>,
    },
    /// A textual pattern; the numeric checks treat it as satisfied.
    Pattern(String),
}
/// How a value `x` must relate to the value `n` of a context variable.
#[derive(Debug, Clone, PartialEq)]
pub enum DependentRelation {
    /// `x < n`.
    LessThan,
    /// `x <= n`.
    LessThanOrEqual,
    /// `x > n`.
    GreaterThan,
    /// `x >= n`.
    GreaterThanOrEqual,
    /// `x == n` (tolerance-based).
    Equal,
    /// `x != n` (tolerance-based).
    NotEqual,
    /// `x` divides `n`.
    Divides,
    /// `x` is divisible by `n`.
    DivisibleBy,
}
impl std::fmt::Debug for RefinementPredicate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RefinementPredicate::Equal(v) => f.debug_tuple("Equal").field(v).finish(),
RefinementPredicate::NotEqual(v) => f.debug_tuple("NotEqual").field(v).finish(),
RefinementPredicate::GreaterThan(v) => f.debug_tuple("GreaterThan").field(v).finish(),
RefinementPredicate::GreaterThanOrEqual(v) => {
f.debug_tuple("GreaterThanOrEqual").field(v).finish()
}
RefinementPredicate::LessThan(v) => f.debug_tuple("LessThan").field(v).finish(),
RefinementPredicate::LessThanOrEqual(v) => {
f.debug_tuple("LessThanOrEqual").field(v).finish()
}
RefinementPredicate::Range { min, max } => f
.debug_struct("Range")
.field("min", min)
.field("max", max)
.finish(),
RefinementPredicate::RangeExclusive { min, max } => f
.debug_struct("RangeExclusive")
.field("min", min)
.field("max", max)
.finish(),
RefinementPredicate::Modulo { divisor, remainder } => f
.debug_struct("Modulo")
.field("divisor", divisor)
.field("remainder", remainder)
.finish(),
RefinementPredicate::InSet(set) => f.debug_tuple("InSet").field(set).finish(),
RefinementPredicate::NotInSet(set) => f.debug_tuple("NotInSet").field(set).finish(),
RefinementPredicate::And(preds) => f.debug_tuple("And").field(preds).finish(),
RefinementPredicate::Or(preds) => f.debug_tuple("Or").field(preds).finish(),
RefinementPredicate::Not(pred) => f.debug_tuple("Not").field(pred).finish(),
RefinementPredicate::Custom {
name, description, ..
} => f
.debug_struct("Custom")
.field("name", name)
.field("description", description)
.finish(),
RefinementPredicate::Dependent { variable, relation } => f
.debug_struct("Dependent")
.field("variable", variable)
.field("relation", relation)
.finish(),
RefinementPredicate::StringLength { min, max } => f
.debug_struct("StringLength")
.field("min", min)
.field("max", max)
.finish(),
RefinementPredicate::Pattern(pattern) => {
f.debug_tuple("Pattern").field(pattern).finish()
}
}
}
}
impl RefinementPredicate {
pub fn greater_than(value: f64) -> Self {
RefinementPredicate::GreaterThan(value)
}
pub fn greater_than_or_equal(value: f64) -> Self {
RefinementPredicate::GreaterThanOrEqual(value)
}
pub fn less_than(value: f64) -> Self {
RefinementPredicate::LessThan(value)
}
pub fn less_than_or_equal(value: f64) -> Self {
RefinementPredicate::LessThanOrEqual(value)
}
pub fn range(min: f64, max: f64) -> Self {
RefinementPredicate::Range { min, max }
}
pub fn modulo(divisor: i64, remainder: i64) -> Self {
RefinementPredicate::Modulo { divisor, remainder }
}
pub fn in_set(values: Vec<f64>) -> Self {
RefinementPredicate::InSet(values)
}
pub fn and(predicates: Vec<RefinementPredicate>) -> Self {
RefinementPredicate::And(predicates)
}
pub fn or(predicates: Vec<RefinementPredicate>) -> Self {
RefinementPredicate::Or(predicates)
}
#[allow(clippy::should_implement_trait)]
pub fn not(predicate: RefinementPredicate) -> Self {
RefinementPredicate::Not(Box::new(predicate))
}
pub fn custom<F>(name: impl Into<String>, description: impl Into<String>, checker: F) -> Self
where
F: Fn(f64) -> bool + Send + Sync + 'static,
{
RefinementPredicate::Custom {
name: name.into(),
description: description.into(),
checker: Arc::new(checker),
}
}
pub fn dependent(variable: impl Into<String>, relation: DependentRelation) -> Self {
RefinementPredicate::Dependent {
variable: variable.into(),
relation,
}
}
pub fn check(&self, value: f64) -> bool {
match self {
RefinementPredicate::Equal(v) => (value - v).abs() < f64::EPSILON,
RefinementPredicate::NotEqual(v) => (value - v).abs() >= f64::EPSILON,
RefinementPredicate::GreaterThan(v) => value > *v,
RefinementPredicate::GreaterThanOrEqual(v) => value >= *v,
RefinementPredicate::LessThan(v) => value < *v,
RefinementPredicate::LessThanOrEqual(v) => value <= *v,
RefinementPredicate::Range { min, max } => value >= *min && value <= *max,
RefinementPredicate::RangeExclusive { min, max } => value >= *min && value < *max,
RefinementPredicate::Modulo { divisor, remainder } => {
(value as i64) % divisor == *remainder
}
RefinementPredicate::InSet(set) => set.iter().any(|v| (value - v).abs() < f64::EPSILON),
RefinementPredicate::NotInSet(set) => {
!set.iter().any(|v| (value - v).abs() < f64::EPSILON)
}
RefinementPredicate::And(preds) => preds.iter().all(|p| p.check(value)),
RefinementPredicate::Or(preds) => preds.iter().any(|p| p.check(value)),
RefinementPredicate::Not(pred) => !pred.check(value),
RefinementPredicate::Custom { checker, .. } => checker(value),
RefinementPredicate::Dependent { .. } => true, RefinementPredicate::StringLength { .. } => true, RefinementPredicate::Pattern(_) => true, }
}
pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
match self {
RefinementPredicate::Dependent { variable, relation } => {
if let Some(&other) = context.get_value(variable) {
match relation {
DependentRelation::LessThan => value < other,
DependentRelation::LessThanOrEqual => value <= other,
DependentRelation::GreaterThan => value > other,
DependentRelation::GreaterThanOrEqual => value >= other,
DependentRelation::Equal => (value - other).abs() < f64::EPSILON,
DependentRelation::NotEqual => (value - other).abs() >= f64::EPSILON,
DependentRelation::Divides => {
other != 0.0 && (other as i64) % (value as i64) == 0
}
DependentRelation::DivisibleBy => {
value != 0.0 && (value as i64) % (other as i64) == 0
}
}
} else {
false }
}
RefinementPredicate::And(preds) => {
preds.iter().all(|p| p.check_with_context(value, context))
}
RefinementPredicate::Or(preds) => {
preds.iter().any(|p| p.check_with_context(value, context))
}
RefinementPredicate::Not(pred) => !pred.check_with_context(value, context),
_ => self.check(value),
}
}
pub fn free_variables(&self) -> Vec<String> {
match self {
RefinementPredicate::Dependent { variable, .. } => vec![variable.clone()],
RefinementPredicate::And(preds) | RefinementPredicate::Or(preds) => {
let mut vars = Vec::new();
for pred in preds {
vars.extend(pred.free_variables());
}
vars.sort();
vars.dedup();
vars
}
RefinementPredicate::Not(pred) => pred.free_variables(),
_ => vec![],
}
}
pub fn simplify(&self) -> RefinementPredicate {
match self {
RefinementPredicate::And(preds) => {
let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
if simplified.len() == 1 {
simplified
.into_iter()
.next()
.expect("validated length == 1")
} else {
let mut min_val = f64::NEG_INFINITY;
let mut max_val = f64::INFINITY;
let mut others = Vec::new();
for pred in simplified {
match pred {
RefinementPredicate::GreaterThan(v) => {
min_val = min_val.max(v);
}
RefinementPredicate::GreaterThanOrEqual(v) => {
min_val = min_val.max(v);
}
RefinementPredicate::LessThan(v) => {
max_val = max_val.min(v);
}
RefinementPredicate::LessThanOrEqual(v) => {
max_val = max_val.min(v);
}
RefinementPredicate::Range { min, max } => {
min_val = min_val.max(min);
max_val = max_val.min(max);
}
other => others.push(other),
}
}
if min_val > f64::NEG_INFINITY || max_val < f64::INFINITY {
if min_val > f64::NEG_INFINITY && max_val < f64::INFINITY {
others.insert(
0,
RefinementPredicate::Range {
min: min_val,
max: max_val,
},
);
} else if min_val > f64::NEG_INFINITY {
others.insert(0, RefinementPredicate::GreaterThanOrEqual(min_val));
} else {
others.insert(0, RefinementPredicate::LessThanOrEqual(max_val));
}
}
if others.len() == 1 {
others.into_iter().next().expect("validated length == 1")
} else {
RefinementPredicate::And(others)
}
}
}
RefinementPredicate::Or(preds) => {
let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
if simplified.len() == 1 {
simplified
.into_iter()
.next()
.expect("validated length == 1")
} else {
RefinementPredicate::Or(simplified)
}
}
RefinementPredicate::Not(pred) => {
let inner = pred.simplify();
match inner {
RefinementPredicate::Not(p) => *p, other => RefinementPredicate::Not(Box::new(other)),
}
}
other => other.clone(),
}
}
pub fn to_string_repr(&self) -> String {
match self {
RefinementPredicate::Equal(v) => format!("x == {}", v),
RefinementPredicate::NotEqual(v) => format!("x != {}", v),
RefinementPredicate::GreaterThan(v) => format!("x > {}", v),
RefinementPredicate::GreaterThanOrEqual(v) => format!("x >= {}", v),
RefinementPredicate::LessThan(v) => format!("x < {}", v),
RefinementPredicate::LessThanOrEqual(v) => format!("x <= {}", v),
RefinementPredicate::Range { min, max } => format!("{} <= x <= {}", min, max),
RefinementPredicate::RangeExclusive { min, max } => format!("{} <= x < {}", min, max),
RefinementPredicate::Modulo { divisor, remainder } => {
format!("x % {} == {}", divisor, remainder)
}
RefinementPredicate::InSet(set) => format!("x in {:?}", set),
RefinementPredicate::NotInSet(set) => format!("x not in {:?}", set),
RefinementPredicate::And(preds) => {
let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
format!("({})", parts.join(" && "))
}
RefinementPredicate::Or(preds) => {
let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
format!("({})", parts.join(" || "))
}
RefinementPredicate::Not(pred) => format!("!({})", pred.to_string_repr()),
RefinementPredicate::Custom { name, .. } => format!("{}(x)", name),
RefinementPredicate::Dependent { variable, relation } => {
let rel_str = match relation {
DependentRelation::LessThan => "<",
DependentRelation::LessThanOrEqual => "<=",
DependentRelation::GreaterThan => ">",
DependentRelation::GreaterThanOrEqual => ">=",
DependentRelation::Equal => "==",
DependentRelation::NotEqual => "!=",
DependentRelation::Divides => "divides",
DependentRelation::DivisibleBy => "divisible_by",
};
format!("x {} {}", rel_str, variable)
}
RefinementPredicate::StringLength { min, max } => match (min, max) {
(Some(min), Some(max)) => format!("{} <= len(x) <= {}", min, max),
(Some(min), None) => format!("len(x) >= {}", min),
(None, Some(max)) => format!("len(x) <= {}", max),
(None, None) => "true".to_string(),
},
RefinementPredicate::Pattern(pattern) => format!("x matches \"{}\"", pattern),
}
}
}
/// A base type narrowed by a conjunction of predicates, e.g. `Int{x > 0}`.
#[derive(Debug, Clone)]
pub struct RefinementType {
    /// Name of the unrefined type (e.g. `"Int"`); subtyping requires equal base types.
    pub base_type: String,
    /// Optional display name, preferred over `base_type` by `type_name`.
    pub name: Option<String>,
    /// Predicates that must all hold for a value to inhabit the type.
    pub predicates: Vec<RefinementPredicate>,
    /// Optional human-readable description.
    pub description: Option<String>,
}

impl RefinementType {
    /// Creates an unrefined type over `base_type`, with no name, predicates or description.
    pub fn new(base_type: impl Into<String>) -> Self {
        RefinementType {
            base_type: base_type.into(),
            name: None,
            predicates: Vec::new(),
            description: None,
        }
    }

    /// Sets the display name (builder style).
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Adds one more predicate to the conjunction (builder style).
    pub fn with_predicate(mut self, predicate: RefinementPredicate) -> Self {
        self.predicates.push(predicate);
        self
    }

    /// Sets the description (builder style).
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// `true` if `value` satisfies every predicate, evaluated without a context.
    pub fn check(&self, value: f64) -> bool {
        self.predicates.iter().all(|p| p.check(value))
    }

    /// `true` if `value` satisfies every predicate, resolving dependent ones in `context`.
    pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
        self.predicates
            .iter()
            .all(|p| p.check_with_context(value, context))
    }

    /// The display name if set, otherwise the base type name.
    pub fn type_name(&self) -> &str {
        self.name.as_deref().unwrap_or(&self.base_type)
    }

    /// Conservatively decides whether every value of `self` is also a value of `other`.
    ///
    /// Base types must match, and each predicate of `other` must be provably implied by
    /// `self`'s predicates. A `false` result may mean "not provable" rather than "not a
    /// subtype".
    ///
    /// NOTE(review): a few edge cases are treated optimistically. Identity is detected via
    /// `Debug` rendering, which for `Custom` shows only name and description. `Equal` is
    /// treated as an exact point although `check` uses an `f64::EPSILON` tolerance.
    pub fn is_subtype_of(&self, other: &RefinementType) -> bool {
        self.base_type == other.base_type
            && other
                .predicates
                .iter()
                .all(|target| self.implies_predicate(target))
    }

    /// `true` if the conjunction of `self.predicates` provably implies `target`.
    fn implies_predicate(&self, target: &RefinementPredicate) -> bool {
        let mut sources = Vec::new();
        Self::flatten_conjuncts(&self.predicates, &mut sources);
        Self::conjuncts_imply(&sources, target)
    }

    /// Appends the leaves of (possibly nested) `And` predicates to `out`.
    fn flatten_conjuncts<'a>(
        predicates: &'a [RefinementPredicate],
        out: &mut Vec<&'a RefinementPredicate>,
    ) {
        for pred in predicates {
            match pred {
                RefinementPredicate::And(parts) => Self::flatten_conjuncts(parts, out),
                other => out.push(other),
            }
        }
    }

    /// `true` if the conjunction of `sources` provably implies `target`.
    fn conjuncts_imply(sources: &[&RefinementPredicate], target: &RefinementPredicate) -> bool {
        let target_repr = format!("{:?}", target);
        if sources
            .iter()
            .any(|source| format!("{:?}", source) == target_repr)
        {
            return true;
        }
        // A conjunction target holds exactly when each of its parts does.
        if let RefinementPredicate::And(parts) = target {
            return parts
                .iter()
                .all(|part| Self::conjuncts_imply(sources, part));
        }
        sources
            .iter()
            .any(|source| Self::semantic_implies(source, target))
            || Self::conjunction_implies(sources, target)
    }

    /// `true` if the single predicate `source` provably implies `target`.
    fn semantic_implies(source: &RefinementPredicate, target: &RefinementPredicate) -> bool {
        use RefinementPredicate::*;
        match (source, target) {
            (
                Range {
                    min: min1,
                    max: max1,
                },
                Range {
                    min: min2,
                    max: max2,
                },
            )
            | (
                RangeExclusive {
                    min: min1,
                    max: max1,
                },
                RangeExclusive {
                    min: min2,
                    max: max2,
                },
            ) => min1 >= min2 && max1 <= max2,
            (GreaterThan(v1), GreaterThan(v2))
            | (GreaterThanOrEqual(v1), GreaterThanOrEqual(v2))
            | (GreaterThan(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
            // `x >= v1` excludes `x == v2` only when v1 lies strictly above v2.
            (GreaterThanOrEqual(v1), GreaterThan(v2)) => v1 > v2,
            (LessThan(v1), LessThan(v2))
            | (LessThanOrEqual(v1), LessThanOrEqual(v2))
            | (LessThan(v1), LessThanOrEqual(v2)) => v1 <= v2,
            (LessThanOrEqual(v1), LessThan(v2)) => v1 < v2,
            (Equal(v1), GreaterThan(v2)) => v1 > v2,
            (Equal(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
            (Equal(v1), LessThan(v2)) => v1 < v2,
            (Equal(v1), LessThanOrEqual(v2)) => v1 <= v2,
            (Equal(v1), Range { min, max }) => v1 >= min && v1 <= max,
            (
                Modulo {
                    divisor: d1,
                    remainder: r1,
                },
                Modulo {
                    divisor: d2,
                    remainder: r2,
                },
            ) => {
                // If d2 divides d1, then x's remainder mod d2 is r1's remainder mod d2.
                // checked_rem makes a zero (or overflowing) divisor yield "not provable"
                // instead of panicking.
                d1.checked_rem(*d2) == Some(0) && r1.checked_rem(*d2) == Some(*r2)
            }
            (
                Dependent {
                    variable: v1,
                    relation: rel1,
                },
                Dependent {
                    variable: v2,
                    relation: rel2,
                },
            ) => v1 == v2 && Self::relation_implies(rel1, rel2),
            _ => false,
        }
    }

    /// `true` if `x <source> n` implies `x <target> n` for the same variable `n`.
    fn relation_implies(source: &DependentRelation, target: &DependentRelation) -> bool {
        use DependentRelation as R;
        matches!(
            (source, target),
            (R::Equal, R::Equal)
                | (R::GreaterThan, R::GreaterThan)
                | (R::GreaterThan, R::GreaterThanOrEqual)
                | (R::LessThan, R::LessThan)
                | (R::LessThan, R::LessThanOrEqual)
                | (R::GreaterThanOrEqual, R::GreaterThanOrEqual)
                | (R::LessThanOrEqual, R::LessThanOrEqual)
        )
    }

    /// `true` if the numeric bounds collected from all `sources` together imply `target`.
    ///
    /// Strictness is tracked, so an inclusive `x >= v` never proves a strict `x > v`.
    fn conjunction_implies(sources: &[&RefinementPredicate], target: &RefinementPredicate) -> bool {
        use RefinementPredicate::*;
        // Each bound is `(limit, strict)`: `strict` means the limit itself is excluded.
        let mut lower: Vec<(f64, bool)> = Vec::new();
        let mut upper: Vec<(f64, bool)> = Vec::new();
        for &source in sources {
            match source {
                GreaterThan(v) => lower.push((*v, true)),
                GreaterThanOrEqual(v) => lower.push((*v, false)),
                LessThan(v) => upper.push((*v, true)),
                LessThanOrEqual(v) => upper.push((*v, false)),
                Range { min, max } => {
                    lower.push((*min, false));
                    upper.push((*max, false));
                }
                RangeExclusive { min, max } => {
                    lower.push((*min, false));
                    upper.push((*max, true));
                }
                Equal(v) => {
                    lower.push((*v, false));
                    upper.push((*v, false));
                }
                _ => {}
            }
        }
        // Some lower bound forces `x > limit` (strict) or `x >= limit` (inclusive).
        let above = |limit: f64, strict: bool| {
            lower.iter().any(|&(bound, bound_strict)| {
                bound > limit || (bound == limit && (bound_strict || !strict))
            })
        };
        // Some upper bound forces `x < limit` (strict) or `x <= limit` (inclusive).
        let below = |limit: f64, strict: bool| {
            upper.iter().any(|&(bound, bound_strict)| {
                bound < limit || (bound == limit && (bound_strict || !strict))
            })
        };
        match target {
            GreaterThan(v) => above(*v, true),
            GreaterThanOrEqual(v) => above(*v, false),
            LessThan(v) => below(*v, true),
            LessThanOrEqual(v) => below(*v, false),
            Range { min, max } => above(*min, false) && below(*max, false),
            RangeExclusive { min, max } => above(*min, false) && below(*max, true),
            _ => false,
        }
    }

    /// Sorted, deduplicated context variables referenced by any predicate.
    pub fn free_variables(&self) -> Vec<String> {
        let mut vars: Vec<String> = self
            .predicates
            .iter()
            .flat_map(RefinementPredicate::free_variables)
            .collect();
        vars.sort();
        vars.dedup();
        vars
    }

    /// Renders the type as `Base{p1 && p2 ...}`, or just `Base` when unrefined.
    pub fn to_string_repr(&self) -> String {
        if self.predicates.is_empty() {
            return self.base_type.clone();
        }
        let pred_strs: Vec<_> = self.predicates.iter().map(|p| p.to_string_repr()).collect();
        format!("{}{{{}}}", self.base_type, pred_strs.join(" && "))
    }
}
#[derive(Debug, Clone, Default)]
pub struct RefinementContext {
values: HashMap<String, f64>,
types: HashMap<String, RefinementType>,
}
impl RefinementContext {
pub fn new() -> Self {
RefinementContext {
values: HashMap::new(),
types: HashMap::new(),
}
}
pub fn set_value(&mut self, var: impl Into<String>, value: f64) {
self.values.insert(var.into(), value);
}
pub fn get_value(&self, var: &str) -> Option<&f64> {
self.values.get(var)
}
pub fn set_type(&mut self, var: impl Into<String>, ty: RefinementType) {
self.types.insert(var.into(), ty);
}
pub fn get_type(&self, var: &str) -> Option<&RefinementType> {
self.types.get(var)
}
pub fn has_variable(&self, var: &str) -> bool {
self.values.contains_key(var) || self.types.contains_key(var)
}
pub fn variables(&self) -> Vec<&str> {
let mut vars: Vec<_> = self.values.keys().map(|s| s.as_str()).collect();
for key in self.types.keys() {
if !self.values.contains_key(key) {
vars.push(key.as_str());
}
}
vars
}
}
#[derive(Debug, Clone, Default)]
pub struct RefinementRegistry {
types: HashMap<String, RefinementType>,
}
impl RefinementRegistry {
pub fn new() -> Self {
RefinementRegistry {
types: HashMap::new(),
}
}
pub fn with_builtins() -> Self {
let mut registry = RefinementRegistry::new();
registry.register(
RefinementType::new("Int")
.with_name("PositiveInt")
.with_predicate(RefinementPredicate::GreaterThan(0.0))
.with_description("Strictly positive integer"),
);
registry.register(
RefinementType::new("Int")
.with_name("NonNegativeInt")
.with_predicate(RefinementPredicate::GreaterThanOrEqual(0.0))
.with_description("Non-negative integer (zero or positive)"),
);
registry.register(
RefinementType::new("Float")
.with_name("Probability")
.with_predicate(RefinementPredicate::Range { min: 0.0, max: 1.0 })
.with_description("Probability value between 0 and 1"),
);
registry.register(
RefinementType::new("Float")
.with_name("Percentage")
.with_predicate(RefinementPredicate::Range {
min: 0.0,
max: 100.0,
})
.with_description("Percentage value between 0 and 100"),
);
registry.register(
RefinementType::new("Float")
.with_name("Normalized")
.with_predicate(RefinementPredicate::Range {
min: -1.0,
max: 1.0,
})
.with_description("Normalized value between -1 and 1"),
);
registry.register(
RefinementType::new("Int")
.with_name("Natural")
.with_predicate(RefinementPredicate::And(vec![
RefinementPredicate::GreaterThanOrEqual(0.0),
RefinementPredicate::Modulo {
divisor: 1,
remainder: 0,
},
]))
.with_description("Natural number (non-negative integer)"),
);
registry.register(
RefinementType::new("Int")
.with_name("Even")
.with_predicate(RefinementPredicate::Modulo {
divisor: 2,
remainder: 0,
})
.with_description("Even integer"),
);
registry.register(
RefinementType::new("Int")
.with_name("Odd")
.with_predicate(RefinementPredicate::Modulo {
divisor: 2,
remainder: 1,
})
.with_description("Odd integer"),
);
registry
}
pub fn register(&mut self, refinement: RefinementType) {
let name = refinement.type_name().to_string();
self.types.insert(name, refinement);
}
pub fn get(&self, name: &str) -> Option<&RefinementType> {
self.types.get(name)
}
pub fn contains(&self, name: &str) -> bool {
self.types.contains_key(name)
}
pub fn type_names(&self) -> Vec<&str> {
self.types.keys().map(|s| s.as_str()).collect()
}
pub fn len(&self) -> usize {
self.types.len()
}
pub fn is_empty(&self) -> bool {
self.types.is_empty()
}
pub fn check(&self, type_name: &str, value: f64) -> Option<bool> {
self.types.get(type_name).map(|t| t.check(value))
}
pub fn iter(&self) -> impl Iterator<Item = (&str, &RefinementType)> {
self.types.iter().map(|(k, v)| (k.as_str(), v))
}
}
#[cfg(test)]
mod tests {
    use super::RefinementPredicate as P;
    use super::*;

    /// Builds an `Int` refinement carrying exactly one predicate.
    fn int_with(predicate: P) -> RefinementType {
        RefinementType::new("Int").with_predicate(predicate)
    }

    #[test]
    fn test_basic_predicates() {
        let positive = P::greater_than(0.0);
        assert!(positive.check(5.0));
        for &rejected in &[-1.0, 0.0] {
            assert!(!positive.check(rejected));
        }
    }

    #[test]
    fn test_range_predicate() {
        let unit = P::range(0.0, 1.0);
        for &inside in &[0.5, 0.0, 1.0] {
            assert!(unit.check(inside));
        }
        for &outside in &[-0.1, 1.1] {
            assert!(!unit.check(outside));
        }
    }

    #[test]
    fn test_modulo_predicate() {
        let even = P::modulo(2, 0);
        assert!(even.check(4.0));
        assert!(even.check(0.0));
        assert!(!even.check(3.0));
    }

    #[test]
    fn test_compound_predicates() {
        // Positive and even.
        let positive_even = P::and(vec![P::greater_than(0.0), P::modulo(2, 0)]);
        assert!(positive_even.check(4.0));
        // Even but not positive.
        assert!(!positive_even.check(-2.0));
        // Positive but odd.
        assert!(!positive_even.check(3.0));
    }

    #[test]
    fn test_in_set_predicate() {
        let small = P::in_set(vec![1.0, 2.0, 3.0]);
        assert!(small.check(1.0));
        assert!(small.check(2.0));
        assert!(!small.check(4.0));
    }

    #[test]
    fn test_custom_predicate() {
        let is_prime = P::custom("is_prime", "Checks if number is prime", |x| {
            if x < 2.0 {
                return false;
            }
            let n = x as i64;
            (2..).take_while(|d| d * d <= n).all(|d| n % d != 0)
        });
        for &prime in &[2.0, 7.0] {
            assert!(is_prime.check(prime));
        }
        for &not_prime in &[4.0, 1.0] {
            assert!(!is_prime.check(not_prime));
        }
    }

    #[test]
    fn test_refinement_type() {
        let positive_int = int_with(P::greater_than(0.0)).with_name("PositiveInt");
        assert_eq!(positive_int.type_name(), "PositiveInt");
        assert!(positive_int.check(5.0));
        assert!(!positive_int.check(-1.0));
    }

    #[test]
    fn test_dependent_predicate() {
        let below_n = P::dependent("n", DependentRelation::LessThan);
        let mut context = RefinementContext::new();
        context.set_value("n", 10.0);
        assert!(below_n.check_with_context(5.0, &context));
        assert!(!below_n.check_with_context(15.0, &context));
    }

    #[test]
    fn test_registry_builtins() {
        let registry = RefinementRegistry::with_builtins();
        let cases = [
            ("PositiveInt", 5.0, true),
            ("PositiveInt", -1.0, false),
            ("Probability", 0.5, true),
            ("Probability", 1.5, false),
            ("Even", 4.0, true),
            ("Even", 3.0, false),
        ];
        for &(name, value, expected) in &cases {
            assert_eq!(registry.check(name, value).expect("unwrap"), expected);
        }
    }

    #[test]
    fn test_predicate_simplification() {
        let simplified = P::and(vec![
            P::greater_than(0.0),
            P::less_than(10.0),
            P::greater_than_or_equal(1.0),
        ])
        .simplify();
        assert!(simplified.check(5.0));
        assert!(!simplified.check(0.0));
        // The tighter inclusive lower bound (>= 1) wins over > 0.
        assert!(simplified.check(1.0));
    }

    #[test]
    fn test_predicate_string_repr() {
        assert_eq!(P::range(0.0, 1.0).to_string_repr(), "0 <= x <= 1");
        let bounded = P::and(vec![P::greater_than(0.0), P::less_than(10.0)]);
        assert_eq!(bounded.to_string_repr(), "(x > 0 && x < 10)");
    }

    #[test]
    fn test_free_variables() {
        let pred = P::and(vec![
            P::greater_than(0.0),
            P::dependent("n", DependentRelation::LessThan),
            P::dependent("m", DependentRelation::GreaterThan),
        ]);
        let vars = pred.free_variables();
        assert_eq!(vars.len(), 2);
        for expected in &["m", "n"] {
            assert!(vars.iter().any(|v| v == expected));
        }
    }

    #[test]
    fn test_refinement_type_repr() {
        let bounded = int_with(P::range(0.0, 100.0)).with_name("BoundedInt");
        assert_eq!(bounded.to_string_repr(), "Int{0 <= x <= 100}");
    }

    #[test]
    fn test_context_operations() {
        let mut ctx = RefinementContext::new();
        for &(name, value) in &[("x", 5.0), ("y", 10.0)] {
            ctx.set_value(name, value);
        }
        assert_eq!(ctx.get_value("x"), Some(&5.0));
        assert!(ctx.has_variable("x"));
        assert!(!ctx.has_variable("z"));
        assert_eq!(ctx.variables().len(), 2);
    }

    #[test]
    fn test_negation_predicate() {
        let non_zero = P::not(P::Equal(0.0));
        assert!(non_zero.check(5.0));
        assert!(!non_zero.check(0.0));
    }

    #[test]
    fn test_or_predicate() {
        let outside = P::or(vec![P::less_than(0.0), P::greater_than(10.0)]);
        assert!(outside.check(-5.0));
        assert!(outside.check(15.0));
        assert!(!outside.check(5.0));
    }

    #[test]
    fn test_double_negation_simplification() {
        let simplified = P::not(P::not(P::greater_than(0.0))).simplify();
        assert!(simplified.check(5.0));
        assert!(!simplified.check(-1.0));
    }

    #[test]
    fn test_registry_custom_type() {
        let mut registry = RefinementRegistry::new();
        registry.register(
            RefinementType::new("Float")
                .with_name("SmallPositive")
                .with_predicate(P::range(0.0, 1e-6)),
        );
        assert!(registry.contains("SmallPositive"));
        assert!(registry.check("SmallPositive", 1e-7).expect("unwrap"));
        assert!(!registry.check("SmallPositive", 1.0).expect("unwrap"));
    }

    #[test]
    fn test_subtyping_basic() {
        let int_type = RefinementType::new("Int");
        let float_type = RefinementType::new("Float");
        // Different base types are never related.
        assert!(!int_type.is_subtype_of(&float_type));
        // Subtyping is reflexive.
        assert!(int_type.is_subtype_of(&int_type));
    }

    #[test]
    fn test_subtyping_range_implication() {
        let stricter = int_with(P::range(5.0, 10.0));
        let looser = int_with(P::range(0.0, 15.0));
        assert!(stricter.is_subtype_of(&looser));
        // A wider range is not contained in a narrower one.
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_greater_than_implication() {
        let stricter = int_with(P::greater_than(10.0));
        let looser = int_with(P::greater_than(5.0));
        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_less_than_implication() {
        let stricter = int_with(P::less_than(5.0));
        let looser = int_with(P::less_than(10.0));
        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_modulo_implication() {
        let divisible_by_4 = int_with(P::modulo(4, 0));
        let divisible_by_2 = int_with(P::modulo(2, 0));
        assert!(divisible_by_4.is_subtype_of(&divisible_by_2));
        assert!(!divisible_by_2.is_subtype_of(&divisible_by_4));
    }

    #[test]
    fn test_subtyping_conjunction() {
        let bounded = int_with(P::greater_than(5.0)).with_predicate(P::less_than(10.0));
        let positive = int_with(P::greater_than(0.0));
        assert!(bounded.is_subtype_of(&positive));
    }

    #[test]
    fn test_subtyping_equality_implies_bounds() {
        let exact = int_with(P::Equal(7.0));
        assert!(exact.is_subtype_of(&int_with(P::greater_than(5.0))));
        assert!(exact.is_subtype_of(&int_with(P::less_than(10.0))));
    }

    #[test]
    fn test_subtyping_no_implication() {
        let even = int_with(P::modulo(2, 0));
        let gt_5 = int_with(P::greater_than(5.0));
        assert!(!even.is_subtype_of(&gt_5));
        assert!(!gt_5.is_subtype_of(&even));
    }
}