use crate::error::{LogicError, LogicResult};
use std::collections::HashMap;
use std::fmt;
/// A small symbolic expression tree over `f32` values.
///
/// Comparison nodes (`Le`, `Ge`, `Eq`) evaluate to the indicator `1.0` when the
/// comparison holds and `0.0` otherwise, so they can be combined arithmetically.
#[derive(Debug, Clone, PartialEq)]
pub enum SymbolicExpr {
    /// A named variable, resolved from the bindings given to [`SymbolicExpr::evaluate`].
    Var(String),
    /// A literal constant.
    Const(f32),
    /// `lhs + rhs`.
    Add(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// `lhs - rhs`.
    Sub(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// `lhs * rhs`.
    Mul(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// `lhs / rhs`; evaluation fails when `|rhs| < f32::EPSILON`.
    Div(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// Indicator of `lhs <= rhs`.
    Le(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// Indicator of `lhs >= rhs`.
    Ge(Box<SymbolicExpr>, Box<SymbolicExpr>),
    /// Indicator of `|lhs - rhs| < 1e-6`.
    Eq(Box<SymbolicExpr>, Box<SymbolicExpr>),
}

impl SymbolicExpr {
    /// Builds a variable reference.
    pub fn var(name: impl Into<String>) -> Self {
        Self::Var(name.into())
    }

    /// Builds a literal constant.
    pub fn constant(value: f32) -> Self {
        Self::Const(value)
    }

    /// Builds `self + other`.
    // The builders deliberately reuse the `std::ops` method names without
    // implementing the operator traits, hence the lint allowance.
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: SymbolicExpr) -> Self {
        Self::Add(Box::new(self), Box::new(other))
    }

    /// Builds `self - other`.
    #[allow(clippy::should_implement_trait)]
    pub fn sub(self, other: SymbolicExpr) -> Self {
        Self::Sub(Box::new(self), Box::new(other))
    }

    /// Builds `self * other`.
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, other: SymbolicExpr) -> Self {
        Self::Mul(Box::new(self), Box::new(other))
    }

    /// Builds `self / other`.
    #[allow(clippy::should_implement_trait)]
    pub fn div(self, other: SymbolicExpr) -> Self {
        Self::Div(Box::new(self), Box::new(other))
    }

    /// Builds the indicator `self <= other`.
    pub fn le(self, other: SymbolicExpr) -> Self {
        Self::Le(Box::new(self), Box::new(other))
    }

    /// Builds the indicator `self >= other`.
    pub fn ge(self, other: SymbolicExpr) -> Self {
        Self::Ge(Box::new(self), Box::new(other))
    }

    /// Evaluates the expression numerically, looking variables up in `bindings`.
    ///
    /// Operands of binary nodes are evaluated left to right, so the first
    /// failing operand determines the reported error.
    ///
    /// # Errors
    ///
    /// Returns `LogicError::InvalidConstraint` when a variable has no binding,
    /// or when a divisor evaluates to a magnitude below `f32::EPSILON`.
    pub fn evaluate(&self, bindings: &HashMap<String, f32>) -> LogicResult<f32> {
        match self {
            Self::Var(name) => bindings.get(name).copied().ok_or_else(|| {
                LogicError::InvalidConstraint(format!("Unbound variable: {}", name))
            }),
            Self::Const(v) => Ok(*v),
            Self::Add(a, b) => Ok(a.evaluate(bindings)? + b.evaluate(bindings)?),
            Self::Sub(a, b) => Ok(a.evaluate(bindings)? - b.evaluate(bindings)?),
            Self::Mul(a, b) => Ok(a.evaluate(bindings)? * b.evaluate(bindings)?),
            Self::Div(a, b) => {
                // Left-to-right like the other operators (previously the divisor
                // was evaluated first, which changed the reported error).
                let dividend = a.evaluate(bindings)?;
                let divisor = b.evaluate(bindings)?;
                if divisor.abs() < f32::EPSILON {
                    Err(LogicError::InvalidConstraint("Division by zero".into()))
                } else {
                    Ok(dividend / divisor)
                }
            }
            Self::Le(a, b) => Ok(Self::indicator(
                a.evaluate(bindings)? <= b.evaluate(bindings)?,
            )),
            Self::Ge(a, b) => Ok(Self::indicator(
                a.evaluate(bindings)? >= b.evaluate(bindings)?,
            )),
            Self::Eq(a, b) => {
                let diff = (a.evaluate(bindings)? - b.evaluate(bindings)?).abs();
                Ok(Self::indicator(diff < 1e-6))
            }
        }
    }

    /// Rewrites algebraic identities bottom-up.
    ///
    /// Applied rules: `x + 0`, `0 + x`, `x - 0`, `x * 1`, `1 * x` and `x / 1`
    /// become `x`; `x * 0` and `0 * x` become `0`. Children are simplified
    /// before the parent's rules are checked, so identities exposed by a
    /// child (e.g. `x * (0 + 1)`) are removed as well.
    ///
    /// Note: the multiply-by-zero rule discards the other operand entirely, so
    /// an expression that would fail to evaluate (unbound variable, division
    /// by zero) or yield NaN may simplify to the constant `0`.
    pub fn simplify(self) -> Self {
        match self {
            Self::Add(a, b) => match (a.simplify(), b.simplify()) {
                (lhs, rhs) if rhs.is_const(0.0) => lhs,
                (lhs, rhs) if lhs.is_const(0.0) => rhs,
                (lhs, rhs) => Self::Add(Box::new(lhs), Box::new(rhs)),
            },
            Self::Sub(a, b) => match (a.simplify(), b.simplify()) {
                (lhs, rhs) if rhs.is_const(0.0) => lhs,
                (lhs, rhs) => Self::Sub(Box::new(lhs), Box::new(rhs)),
            },
            Self::Mul(a, b) => match (a.simplify(), b.simplify()) {
                (lhs, rhs) if rhs.is_const(1.0) => lhs,
                (lhs, rhs) if lhs.is_const(1.0) => rhs,
                (lhs, rhs) if rhs.is_const(0.0) || lhs.is_const(0.0) => Self::Const(0.0),
                (lhs, rhs) => Self::Mul(Box::new(lhs), Box::new(rhs)),
            },
            Self::Div(a, b) => match (a.simplify(), b.simplify()) {
                (lhs, rhs) if rhs.is_const(1.0) => lhs,
                (lhs, rhs) => Self::Div(Box::new(lhs), Box::new(rhs)),
            },
            Self::Le(a, b) => Self::Le(Box::new(a.simplify()), Box::new(b.simplify())),
            Self::Ge(a, b) => Self::Ge(Box::new(a.simplify()), Box::new(b.simplify())),
            Self::Eq(a, b) => Self::Eq(Box::new(a.simplify()), Box::new(b.simplify())),
            Self::Var(name) => Self::Var(name),
            Self::Const(v) => Self::Const(v),
        }
    }

    /// Returns `true` when `self` is exactly the literal constant `value`.
    fn is_const(&self, value: f32) -> bool {
        matches!(self, Self::Const(v) if *v == value)
    }

    /// Maps a boolean to the `1.0` / `0.0` indicator used by comparison nodes.
    fn indicator(holds: bool) -> f32 {
        if holds {
            1.0
        } else {
            0.0
        }
    }
}

impl fmt::Display for SymbolicExpr {
    /// Renders the expression in fully parenthesised infix form, e.g. `(x <= 10)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (lhs, op, rhs) = match self {
            Self::Var(name) => return write!(f, "{}", name),
            Self::Const(v) => return write!(f, "{}", v),
            Self::Add(a, b) => (a, "+", b),
            Self::Sub(a, b) => (a, "-", b),
            Self::Mul(a, b) => (a, "*", b),
            Self::Div(a, b) => (a, "/", b),
            Self::Le(a, b) => (a, "<=", b),
            Self::Ge(a, b) => (a, ">=", b),
            Self::Eq(a, b) => (a, "==", b),
        };
        write!(f, "({} {} {})", lhs, op, rhs)
    }
}
/// Learns simple numeric constraints from labelled example vectors.
///
/// Positive examples are points that should satisfy the constraint; negative
/// examples are points that should violate it.
#[derive(Debug, Clone)]
pub struct ConstraintLearner {
    positive_examples: Vec<Vec<f32>>,
    negative_examples: Vec<Vec<f32>>,
}

impl ConstraintLearner {
    /// Creates a learner with no examples.
    pub fn new() -> Self {
        Self {
            positive_examples: Vec::new(),
            negative_examples: Vec::new(),
        }
    }

    /// Records an example that should satisfy the learned constraint.
    pub fn add_positive(&mut self, example: Vec<f32>) {
        self.positive_examples.push(example);
    }

    /// Records an example that should violate the learned constraint.
    pub fn add_negative(&mut self, example: Vec<f32>) {
        self.negative_examples.push(example);
    }

    /// Learns an interval `(lo, hi)` for coordinate `dimension` of the positive examples.
    ///
    /// The interval is the observed `[min, max]` widened on each side by 10% of
    /// its width. Positive examples too short to have coordinate `dimension`,
    /// and NaN values, are skipped. Negative examples are not consulted.
    ///
    /// # Errors
    ///
    /// Returns `LogicError::InvalidConstraint` when there are no positive
    /// examples, or when none of them has a non-NaN value at `dimension`.
    pub fn learn_box_constraints(&self, dimension: usize) -> LogicResult<(f32, f32)> {
        if self.positive_examples.is_empty() {
            return Err(LogicError::InvalidConstraint(
                "No positive examples provided".into(),
            ));
        }
        let mut values = self
            .positive_examples
            .iter()
            .filter_map(|example| example.get(dimension).copied())
            .filter(|value| !value.is_nan());
        // Previously an absent dimension produced Ok((inf, -inf)); now it is an error.
        let first = values.next().ok_or_else(|| {
            LogicError::InvalidConstraint(format!(
                "No positive example has a value for dimension {}",
                dimension
            ))
        })?;
        let (min_val, max_val) =
            values.fold((first, first), |(lo, hi), value| (lo.min(value), hi.max(value)));
        let margin = 0.1 * (max_val - min_val);
        Ok((min_val - margin, max_val + margin))
    }

    /// Learns a unit direction pointing from the negative centroid to the positive centroid.
    ///
    /// The result is `(mean(pos) - mean(neg)) / |mean(pos) - mean(neg)|`; no
    /// offset/bias term is computed.
    ///
    /// # Errors
    ///
    /// Returns `LogicError::InvalidConstraint` when either example set is empty,
    /// when the examples do not all have the same length as the first positive
    /// example, or when the two centroids are closer than `1e-6`.
    pub fn learn_linear_separator(&self) -> LogicResult<Vec<f32>> {
        if self.positive_examples.is_empty() || self.negative_examples.is_empty() {
            return Err(LogicError::InvalidConstraint(
                "Need both positive and negative examples".into(),
            ));
        }
        let dim = self.positive_examples[0].len();
        let pos_centroid = Self::centroid(&self.positive_examples, dim)?;
        let neg_centroid = Self::centroid(&self.negative_examples, dim)?;
        let mut separator: Vec<f32> = pos_centroid
            .iter()
            .zip(&neg_centroid)
            .map(|(p, n)| p - n)
            .collect();
        let norm = separator.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm < 1e-6 {
            return Err(LogicError::InvalidConstraint(
                "Cannot separate examples".into(),
            ));
        }
        for value in &mut separator {
            *value /= norm;
        }
        Ok(separator)
    }

    /// Component-wise mean of `examples`, each of which must have length `dim`.
    ///
    /// Callers must pass a non-empty slice (an empty one would divide by zero).
    /// A length mismatch is an error rather than the out-of-bounds panic or
    /// silent zero-padding the inline loops used to produce.
    fn centroid(examples: &[Vec<f32>], dim: usize) -> LogicResult<Vec<f32>> {
        let mut sum = vec![0.0_f32; dim];
        for example in examples {
            if example.len() != dim {
                return Err(LogicError::InvalidConstraint(format!(
                    "Inconsistent example dimensions: expected {}, found {}",
                    dim,
                    example.len()
                )));
            }
            for (acc, &value) in sum.iter_mut().zip(example) {
                *acc += value;
            }
        }
        let count = examples.len() as f32;
        for acc in &mut sum {
            *acc /= count;
        }
        Ok(sum)
    }
}

impl Default for ConstraintLearner {
    /// Equivalent to [`ConstraintLearner::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Synthesizes [`SymbolicExpr`] constraints from labelled examples.
///
/// Coordinate `i` of every example vector is bound to `variables[i]`.
#[derive(Debug, Clone)]
pub struct ConstraintSynthesizer {
    variables: Vec<String>,
}

impl ConstraintSynthesizer {
    /// Creates a synthesizer over the given variable names, in coordinate order.
    pub fn new(variables: Vec<String>) -> Self {
        Self { variables }
    }

    /// Synthesizes a constraint of the shape described by `template`.
    ///
    /// Each entry of `examples` pairs a point with `true` when it should
    /// satisfy the constraint and `false` when it should not.
    ///
    /// # Errors
    ///
    /// Returns `LogicError::InvalidConstraint` when the examples or configured
    /// variables are insufficient for the chosen template.
    pub fn synthesize_from_template(
        &self,
        template: ConstraintTemplate,
        examples: &[(Vec<f32>, bool)],
    ) -> LogicResult<SymbolicExpr> {
        match template {
            ConstraintTemplate::Linear => self.synthesize_linear(examples),
            ConstraintTemplate::Box => self.synthesize_box(examples),
            ConstraintTemplate::Quadratic => self.synthesize_quadratic(examples),
        }
    }

    /// Placeholder: always returns `variables[0] <= 10`; `examples` is ignored.
    // TODO: fit the constraint from the examples.
    fn synthesize_linear(&self, _examples: &[(Vec<f32>, bool)]) -> LogicResult<SymbolicExpr> {
        Ok(SymbolicExpr::var(self.first_variable()?).le(SymbolicExpr::constant(10.0)))
    }

    /// Tightest axis-aligned box around the positive examples.
    ///
    /// For every coordinate `i` present in the positive examples, this emits
    /// `min_i <= variables[i] <= max_i`. All bounds are combined with
    /// [`Self::and_expr`]. Negative examples are not consulted.
    ///
    /// # Errors
    ///
    /// Fails when `examples` is empty or has no positive entry. It also fails
    /// with "Invalid dimensions" when the positive points are empty, have
    /// differing lengths, or have more coordinates than there are variables.
    fn synthesize_box(&self, examples: &[(Vec<f32>, bool)]) -> LogicResult<SymbolicExpr> {
        if examples.is_empty() {
            return Err(LogicError::InvalidConstraint("No examples provided".into()));
        }
        let positive: Vec<&[f32]> = examples
            .iter()
            .filter(|(_, sat)| *sat)
            .map(|(point, _)| point.as_slice())
            .collect();
        if positive.is_empty() {
            return Err(LogicError::InvalidConstraint("No positive examples".into()));
        }
        let dim = positive[0].len();
        // Checking every point (not just the first) prevents index panics below.
        let consistent = positive.iter().all(|point| point.len() == dim);
        if dim == 0 || dim > self.variables.len() || !consistent {
            return Err(LogicError::InvalidConstraint("Invalid dimensions".into()));
        }
        self.variables
            .iter()
            .take(dim)
            .enumerate()
            .map(|(i, name)| {
                let (min_val, max_val) = positive
                    .iter()
                    .map(|point| point[i])
                    .fold((f32::MAX, f32::MIN), |(lo, hi), v| (lo.min(v), hi.max(v)));
                let var = SymbolicExpr::var(name.as_str());
                let lower = var.clone().ge(SymbolicExpr::constant(min_val));
                let upper = var.le(SymbolicExpr::constant(max_val));
                Self::and_expr(lower, upper)
            })
            .reduce(Self::and_expr)
            .ok_or_else(|| LogicError::InvalidConstraint("Invalid dimensions".into()))
    }

    /// Placeholder: always returns `variables[0] <= 1`; `examples` is ignored.
    // TODO: fit the constraint from the examples.
    fn synthesize_quadratic(&self, _examples: &[(Vec<f32>, bool)]) -> LogicResult<SymbolicExpr> {
        Ok(SymbolicExpr::var(self.first_variable()?).le(SymbolicExpr::constant(1.0)))
    }

    /// Name of the first configured variable.
    ///
    /// Returns an error when none are configured; previously this case panicked on `variables[0]`.
    fn first_variable(&self) -> LogicResult<&str> {
        self.variables
            .first()
            .map(String::as_str)
            .ok_or_else(|| LogicError::InvalidConstraint("No variables provided".into()))
    }

    /// Logical AND of two 0/1 indicator expressions, encoded as their product.
    ///
    /// The product evaluates to `1.0` only when both operands evaluate to
    /// `1.0`. This is only a conjunction when both operands are comparisons
    /// (or other 0/1-valued expressions). The previous version silently
    /// dropped `right`.
    fn and_expr(left: SymbolicExpr, right: SymbolicExpr) -> SymbolicExpr {
        SymbolicExpr::Mul(Box::new(left), Box::new(right))
    }
}
/// Shape of constraint requested from [`ConstraintSynthesizer::synthesize_from_template`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConstraintTemplate {
    /// Linear inequality template.
    Linear,
    /// Interval (box) bounds derived from the positive examples.
    Box,
    /// Quadratic template.
    Quadratic,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing computed `f32` results.
    const TOL: f32 = 1e-5;

    #[test]
    fn test_symbolic_evaluation() {
        let mut bindings = HashMap::new();
        bindings.insert("x".to_string(), 5.0);
        bindings.insert("y".to_string(), 3.0);
        let expr = SymbolicExpr::var("x").add(SymbolicExpr::var("y"));
        assert_eq!(expr.evaluate(&bindings).unwrap(), 8.0);
        let expr = SymbolicExpr::var("x").mul(SymbolicExpr::constant(2.0));
        assert_eq!(expr.evaluate(&bindings).unwrap(), 10.0);
        // Comparisons evaluate to 0/1 indicators.
        let expr = SymbolicExpr::var("x").le(SymbolicExpr::var("y"));
        assert_eq!(expr.evaluate(&bindings).unwrap(), 0.0);
        let expr = SymbolicExpr::var("x").ge(SymbolicExpr::var("y"));
        assert_eq!(expr.evaluate(&bindings).unwrap(), 1.0);
        // Unbound variables and division by zero are reported as errors.
        assert!(SymbolicExpr::var("z").evaluate(&bindings).is_err());
        let expr = SymbolicExpr::Div(
            Box::new(SymbolicExpr::var("x")),
            Box::new(SymbolicExpr::constant(0.0)),
        );
        assert!(expr.evaluate(&bindings).is_err());
    }

    #[test]
    fn test_symbolic_simplification() {
        let expr = SymbolicExpr::var("x").add(SymbolicExpr::constant(0.0));
        let simplified = expr.simplify();
        assert!(matches!(simplified, SymbolicExpr::Var(_)));
        let expr = SymbolicExpr::var("x").sub(SymbolicExpr::constant(0.0));
        let simplified = expr.simplify();
        assert!(matches!(simplified, SymbolicExpr::Var(_)));
        let expr = SymbolicExpr::var("x").mul(SymbolicExpr::constant(1.0));
        let simplified = expr.simplify();
        assert!(matches!(simplified, SymbolicExpr::Var(_)));
        let expr = SymbolicExpr::var("x").mul(SymbolicExpr::constant(0.0));
        let simplified = expr.simplify();
        assert!(matches!(simplified, SymbolicExpr::Const(v) if v == 0.0));
    }

    #[test]
    fn test_constraint_learning() {
        let mut learner = ConstraintLearner::new();
        learner.add_positive(vec![3.0]);
        learner.add_positive(vec![5.0]);
        learner.add_positive(vec![7.0]);
        learner.add_negative(vec![0.0]);
        learner.add_negative(vec![10.0]);
        let (min, max) = learner.learn_box_constraints(0).unwrap();
        // Observed range [3, 7] widened by 10% of its width (0.4) on each side.
        assert!((min - 2.6).abs() < TOL, "min = {}", min);
        assert!((max - 7.4).abs() < TOL, "max = {}", max);
        assert!(min > 0.0);
        assert!(max < 10.0);
    }

    #[test]
    fn test_linear_separator() {
        let mut learner = ConstraintLearner::new();
        learner.add_positive(vec![1.0]);
        learner.add_positive(vec![2.0]);
        learner.add_positive(vec![3.0]);
        learner.add_negative(vec![7.0]);
        learner.add_negative(vec![8.0]);
        learner.add_negative(vec![9.0]);
        let separator = learner.learn_linear_separator().unwrap();
        assert_eq!(separator.len(), 1);
        // Positives lie below negatives, so the unit direction points toward -x.
        assert!((separator[0] + 1.0).abs() < TOL, "separator = {:?}", separator);
    }

    #[test]
    fn test_constraint_synthesis() {
        let vars = vec!["x".to_string()];
        let synthesizer = ConstraintSynthesizer::new(vars);
        let examples = vec![
            (vec![3.0], true),
            (vec![5.0], true),
            (vec![7.0], true),
            (vec![15.0], false),
        ];
        let constraint = synthesizer
            .synthesize_from_template(ConstraintTemplate::Box, &examples)
            .unwrap();
        let rendered = constraint.to_string();
        assert!(rendered.contains("x >= 3"), "unexpected constraint: {}", rendered);
        // Inside the observed range is satisfied; below it is not.
        let at = |x: f32| {
            let mut bindings = HashMap::new();
            bindings.insert("x".to_string(), x);
            constraint.evaluate(&bindings).unwrap()
        };
        assert_eq!(at(5.0), 1.0);
        assert_eq!(at(1.0), 0.0);
    }
}