use crate::error::{LogicError, LogicResult};
use serde::{Deserialize, Serialize};
/// Boolean connectives available for combining constraints.
///
/// Names the same connectives that [`ComposedConstraint`] provides as variants.
/// NOTE(review): this enum is not referenced elsewhere in the visible code —
/// confirm whether callers rely on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LogicalOperator {
    /// Conjunction: both operands must hold.
    And,
    /// Disjunction: at least one operand must hold.
    Or,
    /// Negation of a single operand.
    Not,
    /// Material implication `a -> b`, i.e. `!a || b`.
    Implies,
}
/// A boolean formula whose leaves are scalar [`Constraint`]s.
///
/// Build trees with [`ComposedConstraint::single`] and the `and` / `or` /
/// `negate` / `implies` combinators. Evaluate them with `check`, `check_all`,
/// `violation`, and `project`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ComposedConstraint {
    /// Leaf node wrapping one constraint.
    Single(Constraint),
    /// Both sub-formulas must hold.
    And(Box<ComposedConstraint>, Box<ComposedConstraint>),
    /// At least one sub-formula must hold.
    Or(Box<ComposedConstraint>, Box<ComposedConstraint>),
    /// The sub-formula must not hold.
    Not(Box<ComposedConstraint>),
    /// If the first sub-formula holds, the second must hold as well (`!a || b`).
    Implies(Box<ComposedConstraint>, Box<ComposedConstraint>),
}
impl ComposedConstraint {
pub fn single(constraint: Constraint) -> Self {
Self::Single(constraint)
}
pub fn and(self, other: ComposedConstraint) -> Self {
Self::And(Box::new(self), Box::new(other))
}
pub fn or(self, other: ComposedConstraint) -> Self {
Self::Or(Box::new(self), Box::new(other))
}
pub fn negate(self) -> Self {
Self::Not(Box::new(self))
}
pub fn implies(self, other: ComposedConstraint) -> Self {
Self::Implies(Box::new(self), Box::new(other))
}
pub fn check(&self, value: f32) -> bool {
match self {
Self::Single(c) => c.check(value),
Self::And(a, b) => a.check(value) && b.check(value),
Self::Or(a, b) => a.check(value) || b.check(value),
Self::Not(c) => !c.check(value),
Self::Implies(a, b) => !a.check(value) || b.check(value),
}
}
pub fn check_all(&self, values: &[f32]) -> bool {
match self {
Self::Single(c) => {
if let Some(dim) = c.dimension() {
values.get(dim).is_some_and(|&v| c.check(v))
} else {
values.iter().all(|&v| c.check(v))
}
}
Self::And(a, b) => a.check_all(values) && b.check_all(values),
Self::Or(a, b) => a.check_all(values) || b.check_all(values),
Self::Not(c) => !c.check_all(values),
Self::Implies(a, b) => !a.check_all(values) || b.check_all(values),
}
}
pub fn violation(&self, value: f32) -> f32 {
match self {
Self::Single(c) => c.violation(value),
Self::And(a, b) => a.violation(value) + b.violation(value),
Self::Or(a, b) => a.violation(value).min(b.violation(value)),
Self::Not(c) => {
if c.check(value) {
1.0
} else {
0.0
}
}
Self::Implies(a, b) => {
if a.check(value) && !b.check(value) {
b.violation(value)
} else {
0.0
}
}
}
}
pub fn project(&self, value: f32) -> f32 {
match self {
Self::Single(c) => c.project(value),
Self::And(a, b) => {
let v1 = a.project(value);
b.project(v1)
}
Self::Or(a, b) => {
let proj_a = a.project(value);
let proj_b = b.project(value);
let dist_a = (value - proj_a).abs();
let dist_b = (value - proj_b).abs();
if dist_a <= dist_b {
proj_a
} else {
proj_b
}
}
Self::Not(_) => {
value
}
Self::Implies(a, b) => {
if a.check(value) {
b.project(value)
} else {
value
}
}
}
}
}
/// The comparison a [`Constraint`] applies to a scalar value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BoundType {
    /// `value < bound` (strict).
    LessThan(f32),
    /// `value <= bound`.
    LessEq(f32),
    /// `value > bound` (strict).
    GreaterThan(f32),
    /// `value >= bound`.
    GreaterEq(f32),
    /// `|value - target| <= tolerance`, stored as `(target, tolerance)`.
    Equal(f32, f32),
    /// `lo <= value <= hi` (both ends inclusive), stored as `(lo, hi)`.
    InRange(f32, f32),
}
/// A named scalar bound, optionally tied to one component of a value vector.
///
/// Normally produced by [`ConstraintBuilder::build`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Constraint {
    /// Identifier supplied to the builder.
    name: String,
    /// Index of the component this constraint reads in vector checks.
    /// `None` applies the constraint to every component.
    dimension: Option<usize>,
    /// The comparison applied to the value.
    bound: BoundType,
    /// Weight carried with the constraint (the builder defaults it to `1.0`).
    /// It is not read by `check`, `violation`, or `project`.
    weight: f32,
}
impl Constraint {
    /// Returns `true` when `value` satisfies this constraint's bound.
    ///
    /// Every comparison involving NaN is false, so a NaN `value` never passes.
    pub fn check(&self, value: f32) -> bool {
        match &self.bound {
            BoundType::LessThan(b) => value < *b,
            BoundType::LessEq(b) => value <= *b,
            BoundType::GreaterThan(b) => value > *b,
            BoundType::GreaterEq(b) => value >= *b,
            BoundType::Equal(target, tol) => (value - target).abs() <= *tol,
            BoundType::InRange(lo, hi) => value >= *lo && value <= *hi,
        }
    }

    /// Distance from `value` to the feasible set, or `0.0` when it is inside.
    ///
    /// Strict bounds are measured against their closure. A value sitting
    /// exactly on a strict bound therefore reports `0.0`, even though `check`
    /// rejects it. For `Equal`, only the distance beyond the tolerance band
    /// counts, so every value that `check` accepts reports `0.0`.
    pub fn violation(&self, value: f32) -> f32 {
        match &self.bound {
            BoundType::LessThan(b) | BoundType::LessEq(b) => (value - b).max(0.0),
            BoundType::GreaterThan(b) | BoundType::GreaterEq(b) => (b - value).max(0.0),
            BoundType::Equal(target, tol) => {
                let excess = (value - target).abs() - tol;
                // Use a comparison rather than `max(0.0)` so a NaN input still
                // yields NaN instead of looking satisfied.
                if excess <= 0.0 {
                    0.0
                } else {
                    excess
                }
            }
            BoundType::InRange(lo, hi) => {
                if value < *lo {
                    lo - value
                } else if value > *hi {
                    value - hi
                } else {
                    0.0
                }
            }
        }
    }

    /// Maps `value` into the set this bound accepts.
    ///
    /// - Values beyond a closed bound are moved onto the bound.
    /// - Values beyond a strict bound are moved to the nearest representable
    ///   `f32` on the feasible side, so the result always passes `check`.
    /// - `Equal` always snaps to the target.
    /// - `InRange` clamps to `[lo, hi]`. An empty range (`lo > hi`) or a NaN
    ///   endpoint has no feasible point, so `value` is returned unchanged
    ///   instead of letting `f32::clamp` panic.
    pub fn project(&self, value: f32) -> f32 {
        match &self.bound {
            // `f32::EPSILON` is the float spacing only near 1.0. Subtracting it
            // from a larger bound such as 100.0 rounds back to the bound, which
            // leaves the result infeasible, so step by one ULP instead.
            BoundType::LessThan(b) => value.min(Self::next_down(*b)),
            BoundType::LessEq(b) => value.min(*b),
            BoundType::GreaterThan(b) => value.max(Self::next_up(*b)),
            BoundType::GreaterEq(b) => value.max(*b),
            BoundType::Equal(target, _) => *target,
            BoundType::InRange(lo, hi) if lo <= hi => value.clamp(*lo, *hi),
            BoundType::InRange(..) => value,
        }
    }

    /// Smallest `f32` strictly greater than `x`. NaN and `+inf` are returned unchanged.
    fn next_up(x: f32) -> f32 {
        if x.is_nan() || x == f32::INFINITY {
            x
        } else if x == 0.0 {
            f32::from_bits(1)
        } else if x > 0.0 {
            f32::from_bits(x.to_bits() + 1)
        } else {
            f32::from_bits(x.to_bits() - 1)
        }
    }

    /// Largest `f32` strictly less than `x`. NaN and `-inf` are returned unchanged.
    fn next_down(x: f32) -> f32 {
        if x.is_nan() || x == f32::NEG_INFINITY {
            x
        } else if x == 0.0 {
            -f32::from_bits(1)
        } else if x > 0.0 {
            f32::from_bits(x.to_bits() - 1)
        } else {
            f32::from_bits(x.to_bits() + 1)
        }
    }

    /// Identifier supplied to the builder.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Component index used by vector checks. `None` means every component.
    pub fn dimension(&self) -> Option<usize> {
        self.dimension
    }

    /// Weight stored with this constraint.
    pub fn weight(&self) -> f32 {
        self.weight
    }
}
/// Rule applied to the per-step rate of change `(current - prev) / dt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RateType {
    /// `|rate| <= max`.
    MaxRate(f32),
    /// `min_rate <= rate <= max_rate`.
    RateRange { min_rate: f32, max_rate: f32 },
    /// `rate >= 0`: the value never decreases.
    MonotonicIncreasing,
    /// `rate <= 0`: the value never increases.
    MonotonicDecreasing,
}
/// A named limit on how fast a value may change between consecutive samples.
///
/// Normally produced by [`TemporalConstraintBuilder::build`], which rejects a
/// non-positive `dt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalConstraint {
    /// Identifier reported by `TemporalChecker::check`.
    name: String,
    /// Index of the tracked component. `None` applies the rule to every component.
    dimension: Option<usize>,
    /// The rate rule being enforced.
    rate_type: RateType,
    /// Time between samples. Rates are computed as `(current - prev) / dt`.
    dt: f32,
    /// Multiplier applied to this constraint's violation in
    /// `TemporalChecker::total_violation`.
    weight: f32,
}
impl TemporalConstraint {
    /// Identifier given at construction.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Component index this constraint tracks. `None` means every component.
    pub fn dimension(&self) -> Option<usize> {
        self.dimension
    }

    /// Time step that converts value differences into rates.
    pub fn dt(&self) -> f32 {
        self.dt
    }

    /// Multiplier applied to this constraint's violation by
    /// `TemporalChecker::total_violation`.
    pub fn weight(&self) -> f32 {
        self.weight
    }

    /// Rate of change `(current - prev) / dt` between two consecutive samples.
    fn rate_between(&self, prev_value: f32, current_value: f32) -> f32 {
        (current_value - prev_value) / self.dt
    }

    /// Returns `true` when the step from `prev_value` to `current_value` obeys the rate rule.
    pub fn check(&self, prev_value: f32, current_value: f32) -> bool {
        let rate = self.rate_between(prev_value, current_value);
        match &self.rate_type {
            RateType::MaxRate(max) => rate.abs() <= *max,
            RateType::RateRange { min_rate, max_rate } => rate >= *min_rate && rate <= *max_rate,
            RateType::MonotonicIncreasing => rate >= 0.0,
            RateType::MonotonicDecreasing => rate <= 0.0,
        }
    }

    /// How far the step's rate lies outside the rule, or `0.0` when it is obeyed.
    ///
    /// The result is in rate units (value change per unit of time), not value units.
    pub fn violation(&self, prev_value: f32, current_value: f32) -> f32 {
        let rate = self.rate_between(prev_value, current_value);
        match &self.rate_type {
            RateType::MaxRate(max) => (rate.abs() - max).max(0.0),
            RateType::RateRange { min_rate, max_rate } => {
                if rate < *min_rate {
                    min_rate - rate
                } else if rate > *max_rate {
                    rate - max_rate
                } else {
                    0.0
                }
            }
            RateType::MonotonicIncreasing => (-rate).max(0.0),
            RateType::MonotonicDecreasing => rate.max(0.0),
        }
    }

    /// Returns the value closest to `current_value` whose step from `prev_value` obeys the rule.
    ///
    /// Some rules cannot be satisfied by any step: a negative or NaN `MaxRate`
    /// limit, or a `RateRange` that is inverted or has a NaN bound. These
    /// return `current_value` unchanged. For `RateRange` this also avoids the
    /// panic `f32::clamp` raises on such bounds.
    pub fn project(&self, prev_value: f32, current_value: f32) -> f32 {
        let rate = self.rate_between(prev_value, current_value);
        match &self.rate_type {
            RateType::MaxRate(max) => {
                if rate.abs() <= *max {
                    current_value
                } else if *max >= 0.0 {
                    // Keep the direction of travel but move at exactly the maximum rate.
                    prev_value + rate.signum() * max * self.dt
                } else {
                    current_value
                }
            }
            RateType::RateRange { min_rate, max_rate } => {
                if min_rate <= max_rate {
                    let clamped_rate = rate.clamp(*min_rate, *max_rate);
                    prev_value + clamped_rate * self.dt
                } else {
                    current_value
                }
            }
            // A step in the forbidden direction is replaced by "no change".
            RateType::MonotonicIncreasing => {
                if rate >= 0.0 {
                    current_value
                } else {
                    prev_value
                }
            }
            RateType::MonotonicDecreasing => {
                if rate <= 0.0 {
                    current_value
                } else {
                    prev_value
                }
            }
        }
    }
}
/// Fluent builder for [`TemporalConstraint`].
///
/// `build` requires a name, a rate rule, and `dt`. `dimension` is optional and
/// `weight` starts at `1.0`.
pub struct TemporalConstraintBuilder {
    name: Option<String>,
    dimension: Option<usize>,
    rate_type: Option<RateType>,
    dt: Option<f32>,
    weight: f32,
}

// Hand-written rather than derived. A derived `Default` would start `weight`
// at `0.0`, unlike `TemporalConstraintBuilder::new()` and
// `ConstraintBuilder::default()`, and would silently zero the constraint's
// contribution to weighted violation totals.
impl Default for TemporalConstraintBuilder {
    fn default() -> Self {
        Self {
            name: None,
            dimension: None,
            rate_type: None,
            dt: None,
            weight: 1.0,
        }
    }
}
impl TemporalConstraintBuilder {
pub fn new() -> Self {
Self {
weight: 1.0,
..Default::default()
}
}
pub fn name(mut self, name: &str) -> Self {
self.name = Some(name.to_string());
self
}
pub fn dimension(mut self, dim: usize) -> Self {
self.dimension = Some(dim);
self
}
pub fn max_rate(mut self, max_rate: f32) -> Self {
self.rate_type = Some(RateType::MaxRate(max_rate));
self
}
pub fn rate_range(mut self, min_rate: f32, max_rate: f32) -> Self {
self.rate_type = Some(RateType::RateRange { min_rate, max_rate });
self
}
pub fn monotonic_increasing(mut self) -> Self {
self.rate_type = Some(RateType::MonotonicIncreasing);
self
}
pub fn monotonic_decreasing(mut self) -> Self {
self.rate_type = Some(RateType::MonotonicDecreasing);
self
}
pub fn dt(mut self, dt: f32) -> Self {
self.dt = Some(dt);
self
}
pub fn weight(mut self, w: f32) -> Self {
self.weight = w;
self
}
pub fn build(self) -> LogicResult<TemporalConstraint> {
let name = self
.name
.ok_or_else(|| LogicError::InvalidConstraint("name is required".into()))?;
let rate_type = self
.rate_type
.ok_or_else(|| LogicError::InvalidConstraint("rate_type is required".into()))?;
let dt = self
.dt
.ok_or_else(|| LogicError::InvalidConstraint("dt (time step) is required".into()))?;
if dt <= 0.0 {
return Err(LogicError::InvalidConstraint("dt must be positive".into()));
}
Ok(TemporalConstraint {
name,
dimension: self.dimension,
rate_type,
dt,
weight: self.weight,
})
}
}
/// Stateful evaluator that applies a set of [`TemporalConstraint`]s to a stream
/// of value vectors, remembering the previous sample between calls.
#[derive(Debug, Clone)]
pub struct TemporalChecker {
    /// Constraints evaluated on every call.
    constraints: Vec<TemporalConstraint>,
    /// Reference sample stored by the most recent `check`, `total_violation`,
    /// or `project` call. `project` stores its adjusted output.
    prev_values: Vec<f32>,
    /// `false` until the first sample has been stored, and again after `reset`.
    initialized: bool,
}
impl TemporalChecker {
pub fn new(constraints: Vec<TemporalConstraint>) -> Self {
Self {
constraints,
prev_values: Vec::new(),
initialized: false,
}
}
pub fn reset(&mut self) {
self.prev_values.clear();
self.initialized = false;
}
pub fn check(&mut self, values: &[f32]) -> Vec<(String, bool)> {
if !self.initialized {
self.prev_values = values.to_vec();
self.initialized = true;
return self
.constraints
.iter()
.map(|c| (c.name.clone(), true))
.collect();
}
let results: Vec<(String, bool)> = self
.constraints
.iter()
.map(|c| {
let result = if let Some(dim) = c.dimension() {
if dim < values.len() && dim < self.prev_values.len() {
c.check(self.prev_values[dim], values[dim])
} else {
true }
} else {
values
.iter()
.zip(self.prev_values.iter())
.all(|(&curr, &prev)| c.check(prev, curr))
};
(c.name.clone(), result)
})
.collect();
self.prev_values = values.to_vec();
results
}
pub fn total_violation(&mut self, values: &[f32]) -> f32 {
if !self.initialized {
self.prev_values = values.to_vec();
self.initialized = true;
return 0.0;
}
let violation: f32 = self
.constraints
.iter()
.map(|c| {
let v = if let Some(dim) = c.dimension() {
if dim < values.len() && dim < self.prev_values.len() {
c.violation(self.prev_values[dim], values[dim])
} else {
0.0
}
} else {
values
.iter()
.zip(self.prev_values.iter())
.map(|(&curr, &prev)| c.violation(prev, curr))
.sum()
};
v * c.weight()
})
.sum();
self.prev_values = values.to_vec();
violation
}
pub fn project(&mut self, values: &[f32]) -> Vec<f32> {
if !self.initialized {
self.prev_values = values.to_vec();
self.initialized = true;
return values.to_vec();
}
let mut projected = values.to_vec();
for c in &self.constraints {
if let Some(dim) = c.dimension() {
if dim < projected.len() && dim < self.prev_values.len() {
projected[dim] = c.project(self.prev_values[dim], projected[dim]);
}
} else {
for i in 0..projected.len().min(self.prev_values.len()) {
projected[i] = c.project(self.prev_values[i], projected[i]);
}
}
}
self.prev_values = projected.clone();
projected
}
pub fn all_satisfied(&mut self, values: &[f32]) -> bool {
self.check(values).iter().all(|(_, sat)| *sat)
}
}
/// Fluent builder for [`Constraint`]. Set a name and a bound, optionally a
/// dimension and a weight, then call `build`.
pub struct ConstraintBuilder {
    /// Required by `build`.
    name: Option<String>,
    /// Optional component index. Stays `None` unless `dimension` is called.
    dimension: Option<usize>,
    /// Required by `build`. Each bound setter overwrites the previous choice.
    bound: Option<BoundType>,
    /// Set to `1.0` by `new`.
    weight: f32,
}
impl Default for ConstraintBuilder {
fn default() -> Self {
Self::new()
}
}
impl ConstraintBuilder {
pub fn new() -> Self {
Self {
name: None,
dimension: None,
bound: None,
weight: 1.0,
}
}
pub fn name(mut self, name: &str) -> Self {
self.name = Some(name.to_string());
self
}
pub fn dimension(mut self, dim: usize) -> Self {
self.dimension = Some(dim);
self
}
pub fn less_than(mut self, value: f32) -> Self {
self.bound = Some(BoundType::LessThan(value));
self
}
pub fn less_eq(mut self, value: f32) -> Self {
self.bound = Some(BoundType::LessEq(value));
self
}
pub fn greater_than(mut self, value: f32) -> Self {
self.bound = Some(BoundType::GreaterThan(value));
self
}
pub fn greater_eq(mut self, value: f32) -> Self {
self.bound = Some(BoundType::GreaterEq(value));
self
}
pub fn equal(mut self, value: f32, tolerance: f32) -> Self {
self.bound = Some(BoundType::Equal(value, tolerance));
self
}
pub fn in_range(mut self, lo: f32, hi: f32) -> Self {
self.bound = Some(BoundType::InRange(lo, hi));
self
}
pub fn weight(mut self, w: f32) -> Self {
self.weight = w;
self
}
pub fn build(self) -> LogicResult<Constraint> {
let name = self
.name
.ok_or_else(|| LogicError::InvalidConstraint("name is required".into()))?;
let bound = self
.bound
.ok_or_else(|| LogicError::InvalidConstraint("bound is required".into()))?;
Ok(Constraint {
name,
dimension: self.dimension,
bound,
weight: self.weight,
})
}
}