use serde::{Deserialize, Serialize};
/// A boolean constraint with hysteresis: a latched state that only changes when
/// a sample leaves the `[low_threshold, high_threshold]` band in the direction
/// that flips it.
///
/// See the `impl` block for the exact switching rules.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct HysteresisConstraint {
    /// Identifier reported alongside the state by `HysteresisChecker`.
    name: String,
    /// Threshold below which an active constraint switches off.
    low_threshold: f32,
    /// Threshold above which an inactive constraint switches on.
    high_threshold: f32,
    /// Current latched state.
    state: bool,
    /// Weight of this constraint (`1.0` unless overridden). Nothing in this file
    /// consumes it; presumably read by callers via `weight()`.
    weight: f32,
}
impl HysteresisConstraint {
    /// Creates an inactive constraint with the given thresholds and a weight of `1.0`.
    ///
    /// # Panics
    ///
    /// Panics unless `low_threshold < high_threshold`; because the comparison is
    /// strict, this also rejects NaN thresholds.
    pub fn new(name: impl Into<String>, low_threshold: f32, high_threshold: f32) -> Self {
        assert!(
            low_threshold < high_threshold,
            "Low threshold must be less than high threshold"
        );
        Self {
            name: name.into(),
            low_threshold,
            high_threshold,
            state: false,
            weight: 1.0,
        }
    }

    /// Builder-style setter for the starting latched state.
    pub fn with_initial_state(mut self, state: bool) -> Self {
        self.state = state;
        self
    }

    /// Builder-style setter for the weight. The value is stored as-is (not validated).
    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }

    /// Feeds one sample through the hysteresis latch and returns the resulting state.
    ///
    /// - While inactive, the state switches on only if `value > high_threshold`.
    /// - While active, the state switches off only if `value < low_threshold`.
    /// - Any other sample, including NaN, leaves the state unchanged.
    pub fn update_and_check(&mut self, value: f32) -> bool {
        // Delegating to `would_transition` keeps this method and the read-only
        // probe in exact agreement (e.g. a NaN sample never flips the state).
        if self.would_transition(value) {
            self.state = !self.state;
        }
        self.state
    }

    /// Returns the current latched state.
    pub fn state(&self) -> bool {
        self.state
    }

    /// Reports whether feeding `value` to [`update_and_check`](Self::update_and_check)
    /// would flip the current state, without mutating anything.
    ///
    /// Both comparisons are strict and false for NaN, so a NaN sample never
    /// triggers a transition.
    pub fn would_transition(&self, value: f32) -> bool {
        if self.state {
            value < self.low_threshold
        } else {
            value > self.high_threshold
        }
    }

    /// Returns the constraint's identifier.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the configured weight.
    pub fn weight(&self) -> f32 {
        self.weight
    }

    /// Overwrites the latched state. Thresholds, name and weight are untouched.
    pub fn reset(&mut self, initial_state: bool) {
        self.state = initial_state;
    }
}
/// An ordered collection of hysteresis constraints that are all driven by the
/// same stream of samples.
#[derive(Clone, Debug)]
pub struct HysteresisChecker {
    /// Constraints in insertion order; per-constraint results are reported in this order.
    constraints: Vec<HysteresisConstraint>,
}
impl HysteresisChecker {
    /// Creates a checker with no constraints.
    pub fn new() -> Self {
        Self {
            constraints: Vec::new(),
        }
    }

    /// Appends a constraint; it is reported after every previously added one.
    pub fn add(&mut self, constraint: HysteresisConstraint) {
        self.constraints.push(constraint);
    }

    /// Feeds `value` to every constraint and returns each constraint's name paired
    /// with its latched state *after* the update, in insertion order.
    pub fn update(&mut self, value: f32) -> Vec<(String, bool)> {
        self.constraints
            .iter_mut()
            .map(|c| {
                c.update_and_check(value);
                // Read the latched state directly instead of relying on the
                // return value of `update_and_check`, so `update` always agrees
                // with `states()`.
                (c.name().to_string(), c.state())
            })
            .collect()
    }

    /// Returns each constraint's name paired with its current state, in insertion order.
    pub fn states(&self) -> Vec<(String, bool)> {
        self.constraints
            .iter()
            .map(|c| (c.name().to_string(), c.state()))
            .collect()
    }

    /// Sets every constraint's state to `false`.
    ///
    /// Note: this does not restore a state configured with
    /// `HysteresisConstraint::with_initial_state`; every constraint becomes inactive.
    pub fn reset(&mut self) {
        for c in &mut self.constraints {
            c.reset(false);
        }
    }
}
impl Default for HysteresisChecker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // Most tests exercise constraint types imported from `super::super`; the
    // hysteresis tests at the bottom cover the types defined in this module.
    use super::super::{
        AffineEquality, ComposedConstraint, ConstraintBuilder, ConstraintSet, LinearConstraint,
        LinearConstraintSet, PenaltyFunction, QuadraticConstraint, QuadraticConstraintSet,
        SlidingWindowChecker, SlidingWindowConstraint, SlidingWindowFn, SoftHardConstraint,
        TemporalChecker, TemporalConstraintBuilder,
    };
    use super::{HysteresisChecker, HysteresisConstraint};

    #[test]
    fn test_less_than_constraint() {
        let c = ConstraintBuilder::new()
            .name("max_vel")
            .less_than(10.0)
            .build()
            .unwrap();
        assert!(c.check(5.0));
        assert!(!c.check(15.0));
        assert_eq!(c.violation(5.0), 0.0);
        assert_eq!(c.violation(15.0), 5.0);
        // A strict bound projects just inside the boundary.
        assert_eq!(c.project(15.0), 10.0 - f32::EPSILON);
    }

    #[test]
    fn test_range_constraint() {
        let c = ConstraintBuilder::new()
            .name("temp_range")
            .in_range(-10.0, 100.0)
            .build()
            .unwrap();
        assert!(c.check(50.0));
        assert!(!c.check(-20.0));
        assert!(!c.check(150.0));
        assert_eq!(c.project(-20.0), -10.0);
        assert_eq!(c.project(150.0), 100.0);
    }

    #[test]
    fn test_composed_and() {
        let c1 = ConstraintBuilder::new()
            .name("lower_bound")
            .greater_eq(0.0)
            .build()
            .unwrap();
        let c2 = ConstraintBuilder::new()
            .name("upper_bound")
            .less_eq(10.0)
            .build()
            .unwrap();
        let composed = ComposedConstraint::single(c1).and(ComposedConstraint::single(c2));
        assert!(composed.check(5.0));
        assert!(!composed.check(-1.0));
        assert!(!composed.check(11.0));
    }

    #[test]
    fn test_composed_or() {
        let c1 = ConstraintBuilder::new()
            .name("negative")
            .less_than(0.0)
            .build()
            .unwrap();
        let c2 = ConstraintBuilder::new()
            .name("large")
            .greater_than(100.0)
            .build()
            .unwrap();
        let composed = ComposedConstraint::single(c1).or(ComposedConstraint::single(c2));
        assert!(composed.check(-5.0));
        assert!(composed.check(150.0));
        assert!(!composed.check(50.0));
    }

    #[test]
    fn test_composed_not() {
        let c = ConstraintBuilder::new()
            .name("positive")
            .greater_eq(0.0)
            .build()
            .unwrap();
        let composed = ComposedConstraint::single(c).negate();
        assert!(composed.check(-1.0));
        assert!(!composed.check(1.0));
    }

    #[test]
    fn test_composed_implies() {
        let premise = ConstraintBuilder::new()
            .name("large")
            .greater_than(50.0)
            .build()
            .unwrap();
        let conclusion = ConstraintBuilder::new()
            .name("bounded")
            .less_than(100.0)
            .build()
            .unwrap();
        let composed =
            ComposedConstraint::single(premise).implies(ComposedConstraint::single(conclusion));
        // A false premise makes the implication hold vacuously.
        assert!(composed.check(30.0));
        assert!(composed.check(75.0));
        assert!(!composed.check(150.0));
    }

    #[test]
    fn test_composed_projection_and() {
        let c1 = ConstraintBuilder::new()
            .name("lower")
            .greater_eq(0.0)
            .build()
            .unwrap();
        let c2 = ConstraintBuilder::new()
            .name("upper")
            .less_eq(10.0)
            .build()
            .unwrap();
        let composed = ComposedConstraint::single(c1).and(ComposedConstraint::single(c2));
        assert_eq!(composed.project(-5.0), 0.0);
        assert_eq!(composed.project(15.0), 10.0);
        assert_eq!(composed.project(5.0), 5.0);
    }

    #[test]
    fn test_composed_projection_or() {
        let c1 = ConstraintBuilder::new()
            .name("small")
            .less_eq(0.0)
            .build()
            .unwrap();
        let c2 = ConstraintBuilder::new()
            .name("large")
            .greater_eq(10.0)
            .build()
            .unwrap();
        let composed = ComposedConstraint::single(c1).or(ComposedConstraint::single(c2));
        // A disjunction projects onto whichever feasible region is nearer.
        assert_eq!(composed.project(6.0), 10.0);
        assert_eq!(composed.project(3.0), 0.0);
    }

    #[test]
    fn test_check_all_dimensions() {
        let c = ConstraintBuilder::new()
            .name("bounded")
            .in_range(-1.0, 1.0)
            .build()
            .unwrap();
        let composed = ComposedConstraint::single(c);
        assert!(composed.check_all(&[0.0, 0.5, -0.5]));
        assert!(!composed.check_all(&[0.0, 2.0, 0.0]));
    }

    #[test]
    fn test_temporal_max_rate() {
        // max_rate 10.0 with dt 0.1 allows a per-step change of at most 1.0.
        let c = TemporalConstraintBuilder::new()
            .name("max_velocity")
            .max_rate(10.0)
            .dt(0.1)
            .build()
            .unwrap();
        assert!(!c.check(0.0, 5.0));
        assert!(c.check(0.0, 0.5));
        assert!(c.check(0.5, 0.0));
    }

    #[test]
    fn test_temporal_rate_range() {
        let c = TemporalConstraintBuilder::new()
            .name("rate_bounded")
            .rate_range(-5.0, 10.0)
            .dt(1.0)
            .build()
            .unwrap();
        assert!(c.check(0.0, 3.0));
        assert!(c.check(3.0, 0.0));
        assert!(!c.check(0.0, 15.0));
        assert!(!c.check(10.0, 0.0));
    }

    #[test]
    fn test_temporal_monotonic() {
        let inc = TemporalConstraintBuilder::new()
            .name("monotonic_inc")
            .monotonic_increasing()
            .dt(1.0)
            .build()
            .unwrap();
        let dec = TemporalConstraintBuilder::new()
            .name("monotonic_dec")
            .monotonic_decreasing()
            .dt(1.0)
            .build()
            .unwrap();
        // Equal consecutive values satisfy both (non-strict) monotonicity checks.
        assert!(inc.check(0.0, 1.0));
        assert!(!inc.check(1.0, 0.0));
        assert!(inc.check(0.0, 0.0));
        assert!(!dec.check(0.0, 1.0));
        assert!(dec.check(1.0, 0.0));
        assert!(dec.check(0.0, 0.0));
    }

    #[test]
    fn test_temporal_projection() {
        let c = TemporalConstraintBuilder::new()
            .name("max_rate")
            .max_rate(10.0)
            .dt(0.1)
            .build()
            .unwrap();
        let projected = c.project(0.0, 5.0);
        assert!((projected - 1.0).abs() < 1e-5);
        let projected = c.project(0.0, 0.5);
        assert!((projected - 0.5).abs() < 1e-5);
        let projected = c.project(5.0, 0.0);
        assert!((projected - 4.0).abs() < 1e-5);
    }

    #[test]
    fn test_temporal_checker() {
        let c = TemporalConstraintBuilder::new()
            .name("velocity_limit")
            .max_rate(10.0)
            .dt(1.0)
            .build()
            .unwrap();
        let mut checker = TemporalChecker::new(vec![c]);
        let result = checker.check(&[0.0]);
        assert!(result[0].1);
        let result = checker.check(&[5.0]);
        assert!(result[0].1);
        let result = checker.check(&[55.0]);
        assert!(!result[0].1);
    }

    #[test]
    fn test_temporal_checker_project() {
        let c = TemporalConstraintBuilder::new()
            .name("rate_limit")
            .max_rate(10.0)
            .dt(1.0)
            .build()
            .unwrap();
        let mut checker = TemporalChecker::new(vec![c]);
        let _ = checker.project(&[0.0]);
        let projected = checker.project(&[50.0]);
        assert!((projected[0] - 10.0).abs() < 1e-5);
        // The next step is measured from the projected value (10.0), not from 50.0.
        let projected = checker.project(&[5.0]);
        assert!((projected[0] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_temporal_violation() {
        let c = TemporalConstraintBuilder::new()
            .name("rate_limit")
            .max_rate(10.0)
            .dt(1.0)
            .build()
            .unwrap();
        assert!((c.violation(0.0, 15.0) - 5.0).abs() < 1e-5);
        assert_eq!(c.violation(0.0, 5.0), 0.0);
    }

    #[test]
    fn test_temporal_builder_validation() {
        // Missing name.
        let result = TemporalConstraintBuilder::new()
            .max_rate(10.0)
            .dt(1.0)
            .build();
        assert!(result.is_err());
        // Missing rate specification.
        let result = TemporalConstraintBuilder::new()
            .name("test")
            .dt(1.0)
            .build();
        assert!(result.is_err());
        // Missing dt.
        let result = TemporalConstraintBuilder::new()
            .name("test")
            .max_rate(10.0)
            .build();
        assert!(result.is_err());
        // Non-positive dt.
        let result = TemporalConstraintBuilder::new()
            .name("test")
            .max_rate(10.0)
            .dt(0.0)
            .build();
        assert!(result.is_err());
        let result = TemporalConstraintBuilder::new()
            .name("test")
            .max_rate(10.0)
            .dt(-1.0)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_temporal_multi_dimension() {
        let c = TemporalConstraintBuilder::new()
            .name("dim1_rate")
            .dimension(1)
            .max_rate(10.0)
            .dt(1.0)
            .build()
            .unwrap();
        let mut checker = TemporalChecker::new(vec![c]);
        let _ = checker.check(&[0.0, 0.0, 0.0]);
        // Only dimension 1 is constrained, so large jumps elsewhere are ignored.
        let result = checker.check(&[100.0, 5.0, 100.0]);
        assert!(result[0].1);
        let result = checker.check(&[0.0, 55.0, 0.0]);
        assert!(!result[0].1);
    }

    #[test]
    fn test_linear_equality_constraint() {
        // 2x + 3y = 7 within a tolerance of 0.01.
        let c = LinearConstraint::equality(vec![2.0, 3.0], 7.0, 0.01);
        assert!(c.check(&[2.0, 1.0]));
        assert!(c.check(&[0.5, 2.0]));
        assert!(!c.check(&[1.0, 1.0]));
    }

    #[test]
    fn test_linear_inequality_constraint() {
        let c = LinearConstraint::less_eq(vec![1.0, 1.0], 10.0);
        assert!(c.check(&[3.0, 5.0]));
        assert!(c.check(&[5.0, 5.0]));
        assert!(!c.check(&[6.0, 5.0]));
    }

    #[test]
    fn test_linear_constraint_projection() {
        let c = LinearConstraint::less_eq(vec![1.0, 1.0], 5.0);
        let proj = c.project(&[2.0, 2.0]);
        assert!((proj[0] - 2.0).abs() < 0.01);
        assert!((proj[1] - 2.0).abs() < 0.01);
        let proj = c.project(&[4.0, 4.0]);
        let sum: f32 = proj.iter().sum();
        assert!((sum - 5.0).abs() < 0.1);
    }

    #[test]
    fn test_linear_constraint_set() {
        // The box [0, 5] x [0, 5].
        let constraints = vec![
            LinearConstraint::less_eq(vec![1.0, 0.0], 5.0),
            LinearConstraint::less_eq(vec![0.0, 1.0], 5.0),
            LinearConstraint::greater_eq(vec![1.0, 0.0], 0.0),
            LinearConstraint::greater_eq(vec![0.0, 1.0], 0.0),
        ];
        let set = LinearConstraintSet::new(constraints);
        assert!(set.check_all(&[2.0, 3.0]));
        assert!(!set.check_all(&[6.0, 3.0]));
        assert!(!set.check_all(&[-1.0, 3.0]));
    }

    #[test]
    fn test_affine_equality() {
        // x + y + z = 6 and x - y = 0.
        let a = vec![vec![1.0, 1.0, 1.0], vec![1.0, -1.0, 0.0]];
        let b = vec![6.0, 0.0];
        let c = AffineEquality::new(a, b, 0.01);
        assert!(c.check(&[2.0, 2.0, 2.0]));
        assert!(!c.check(&[1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_quadratic_ball_constraint() {
        let c = QuadraticConstraint::ball(vec![0.0, 0.0], 2.0);
        assert!(c.check(&[0.0, 0.0]));
        assert!(c.check(&[1.0, 1.0]));
        // Points on the boundary are feasible.
        assert!(c.check(&[2.0, 0.0]));
        assert!(!c.check(&[2.0, 2.0]));
    }

    #[test]
    fn test_quadratic_ball_with_center() {
        let c = QuadraticConstraint::ball(vec![1.0, 1.0], 1.0);
        assert!(c.check(&[1.0, 1.0]));
        assert!(c.check(&[1.5, 1.0]));
        assert!(!c.check(&[3.0, 3.0]));
    }

    #[test]
    fn test_quadratic_constraint_violation() {
        let c = QuadraticConstraint::ball(vec![0.0, 0.0], 1.0);
        assert_eq!(c.violation(&[0.5, 0.0]), 0.0);
        // Violation is measured in squared norm: |(1, 1)|^2 - 1 = 1.
        assert!((c.violation(&[1.0, 1.0]) - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_quadratic_constraint_gradient() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let c = QuadraticConstraint::less_eq(q, vec![0.0, 0.0], 1.0);
        // Gradient of x^T Q x with Q = I is 2x.
        let grad = c.gradient(&[1.0, 2.0]);
        assert!((grad[0] - 2.0).abs() < 0.01);
        assert!((grad[1] - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_quadratic_constraint_projection() {
        let c = QuadraticConstraint::ball(vec![0.0, 0.0], 1.0);
        let proj = c.project(&[0.5, 0.0], 100, 0.1);
        assert!((proj[0] - 0.5).abs() < 0.1);
        let proj = c.project(&[2.0, 0.0], 100, 0.1);
        let norm: f32 = proj.iter().map(|x| x * x).sum::<f32>().sqrt();
        // Iterative projection only approximately reaches the boundary.
        assert!((norm - 1.0).abs() < 0.2);
    }

    #[test]
    fn test_quadratic_ellipsoid() {
        // 4x^2 + y^2 <= 1.
        let a = vec![vec![4.0, 0.0], vec![0.0, 1.0]];
        let c = QuadraticConstraint::ellipsoid(a, vec![0.0, 0.0]);
        assert!(c.check(&[0.0, 0.0]));
        assert!(c.check(&[0.25, 0.5]));
        assert!(!c.check(&[0.5, 0.5]));
    }

    #[test]
    fn test_quadratic_constraint_set() {
        let c1 = QuadraticConstraint::ball(vec![0.0, 0.0], 2.0);
        let c2 = QuadraticConstraint::ball(vec![1.0, 0.0], 2.0);
        let set = QuadraticConstraintSet::new(vec![c1, c2]);
        assert!(set.check_all(&[0.5, 0.0]));
        assert!(!set.check_all(&[-2.0, 0.0]));
        assert!(!set.check_all(&[3.0, 0.0]));
    }

    #[test]
    fn test_sliding_window_mean() {
        let mut c = SlidingWindowConstraint::new(
            "mean_check",
            3,
            SlidingWindowFn::MeanInRange { lo: 0.0, hi: 10.0 },
        );
        let _ = c.push_and_check(1.0);
        let _ = c.push_and_check(2.0);
        let (sat, _) = c.push_and_check(3.0);
        assert!(sat);
        // Window becomes [2, 3, 30] with mean above 10.
        let (sat, _) = c.push_and_check(30.0);
        assert!(!sat);
    }

    #[test]
    fn test_sliding_window_variance() {
        let mut c = SlidingWindowConstraint::new(
            "var_check",
            3,
            SlidingWindowFn::MaxVariance { max_var: 1.0 },
        );
        let _ = c.push_and_check(1.0);
        let _ = c.push_and_check(1.0);
        let (sat, _) = c.push_and_check(1.0);
        assert!(sat);
        let (sat, _) = c.push_and_check(10.0);
        assert!(!sat);
    }

    #[test]
    fn test_sliding_window_max_range() {
        let mut c = SlidingWindowConstraint::new(
            "range_check",
            4,
            SlidingWindowFn::MaxRange { max_range: 5.0 },
        );
        for val in [1.0, 2.0, 3.0, 4.0] {
            let _ = c.push_and_check(val);
        }
        assert!(c.is_ready());
        let (sat, _) = c.check_window();
        assert!(sat);
        let _ = c.push_and_check(10.0);
        let (sat, _) = c.check_window();
        assert!(!sat);
    }

    #[test]
    fn test_sliding_window_bounded_variation() {
        let mut c = SlidingWindowConstraint::new(
            "bv_check",
            4,
            SlidingWindowFn::BoundedVariation { max_variation: 3.0 },
        );
        for val in [0.0, 1.0, 2.0, 3.0] {
            let _ = c.push_and_check(val);
        }
        let (sat, _) = c.check_window();
        assert!(sat);
        let _ = c.push_and_check(10.0);
        let (sat, _) = c.check_window();
        assert!(!sat);
    }

    #[test]
    fn test_sliding_window_trend() {
        let mut c = SlidingWindowConstraint::new(
            "trend_check",
            4,
            SlidingWindowFn::TrendInRange {
                min_slope: -0.5,
                max_slope: 2.0,
            },
        );
        for val in [0.0, 1.0, 2.0, 3.0] {
            let _ = c.push_and_check(val);
        }
        let (sat, _) = c.check_window();
        assert!(sat);
        let _ = c.push_and_check(20.0);
        let (sat, _) = c.check_window();
        assert!(!sat);
    }

    #[test]
    fn test_sliding_window_checker() {
        let mut checker = SlidingWindowChecker::new();
        checker.add(SlidingWindowConstraint::new(
            "mean",
            3,
            SlidingWindowFn::MeanInRange { lo: 0.0, hi: 10.0 },
        ));
        checker.add(SlidingWindowConstraint::new(
            "range",
            3,
            SlidingWindowFn::MaxRange { max_range: 5.0 },
        ));
        for val in [1.0, 2.0, 3.0] {
            let _ = checker.push_and_check(val);
        }
        assert!(checker.all_satisfied());
        let _ = checker.push_and_check(20.0);
        assert!(!checker.all_satisfied());
    }

    #[test]
    fn test_sliding_window_all_in_range() {
        let mut c = SlidingWindowConstraint::new(
            "all_check",
            3,
            SlidingWindowFn::AllInRange { lo: 0.0, hi: 5.0 },
        );
        for val in [1.0, 2.0, 3.0] {
            let _ = c.push_and_check(val);
        }
        let (sat, _) = c.check_window();
        assert!(sat);
        let _ = c.push_and_check(10.0);
        let (sat, viol) = c.check_window();
        assert!(!sat);
        // 10.0 exceeds the upper bound 5.0 by 5.0.
        assert!((viol - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_sliding_window_any_in_range() {
        let mut c = SlidingWindowConstraint::new(
            "any_check",
            3,
            SlidingWindowFn::AnyInRange { lo: 0.0, hi: 5.0 },
        );
        for val in [10.0, 20.0, 30.0] {
            let _ = c.push_and_check(val);
        }
        let (sat, _) = c.check_window();
        assert!(!sat);
        let _ = c.push_and_check(3.0);
        let (sat, _) = c.check_window();
        assert!(sat);
    }

    #[test]
    fn test_penalty_function_l1() {
        let penalty = PenaltyFunction::L1;
        assert_eq!(penalty.compute(0.0, 1.0), 0.0);
        assert_eq!(penalty.compute(2.0, 1.0), 2.0);
        assert_eq!(penalty.compute(3.0, 2.0), 6.0);
    }

    #[test]
    fn test_penalty_function_l2() {
        let penalty = PenaltyFunction::L2;
        assert_eq!(penalty.compute(0.0, 1.0), 0.0);
        assert_eq!(penalty.compute(2.0, 1.0), 4.0);
        assert_eq!(penalty.compute(3.0, 2.0), 18.0);
    }

    #[test]
    fn test_penalty_function_huber() {
        let penalty = PenaltyFunction::Huber { delta: 1.0 };
        // Quadratic inside delta, linear outside.
        assert!((penalty.compute(0.5, 1.0) - 0.125).abs() < 0.01);
        assert!((penalty.compute(2.0, 1.0) - 1.5).abs() < 0.01);
    }

    #[test]
    fn test_soft_hard_constraint_modes() {
        let c_hard = SoftHardConstraint::hard(LinearConstraint::less_eq(vec![1.0], 5.0));
        assert!(c_hard.is_hard());
        assert!(!c_hard.is_soft());
        let c_soft = SoftHardConstraint::soft(LinearConstraint::less_eq(vec![1.0], 5.0));
        assert!(c_soft.is_soft());
        assert!(!c_soft.is_hard());
    }

    #[test]
    fn test_soft_constraint_loss() {
        let c = SoftHardConstraint::soft(LinearConstraint::less_eq(vec![1.0], 5.0))
            .with_penalty(PenaltyFunction::L2)
            .with_weight(2.0);
        assert_eq!(c.loss(&[3.0]), 0.0);
        // Violation 2.0, squared and weighted by 2.0.
        assert!((c.loss(&[7.0]) - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_hard_constraint_loss() {
        let c = SoftHardConstraint::hard(LinearConstraint::less_eq(vec![1.0], 5.0));
        assert_eq!(c.loss(&[3.0]), 0.0);
        assert_eq!(c.loss(&[7.0]), f32::MAX);
    }

    #[test]
    fn test_constraint_set_mixed() {
        let mut set: ConstraintSet<LinearConstraint> = ConstraintSet::new();
        set.add_hard(LinearConstraint::greater_eq(vec![1.0], 0.0));
        set.add_soft(
            LinearConstraint::less_eq(vec![1.0], 10.0),
            PenaltyFunction::L2,
            1.0,
        );
        assert!(set.all_satisfied(&[5.0]));
        assert!(set.all_hard_satisfied(&[5.0]));
        assert_eq!(set.soft_loss(&[5.0]), 0.0);
        // Only the soft bound is violated.
        assert!(!set.all_satisfied(&[15.0]));
        assert!(set.all_hard_satisfied(&[15.0]));
        assert!((set.soft_loss(&[15.0]) - 25.0).abs() < 0.01);
        // A hard violation dominates the total loss.
        assert!(!set.all_satisfied(&[-1.0]));
        assert!(!set.all_hard_satisfied(&[-1.0]));
        assert_eq!(set.total_loss(&[-1.0]), f32::MAX);
    }

    #[test]
    fn test_constraint_priority() {
        let mut set: ConstraintSet<LinearConstraint> = ConstraintSet::new();
        set.add(
            SoftHardConstraint::soft(LinearConstraint::less_eq(vec![1.0], 5.0)).with_priority(1),
        );
        set.add(
            SoftHardConstraint::soft(LinearConstraint::less_eq(vec![1.0], 10.0)).with_priority(3),
        );
        set.add(
            SoftHardConstraint::soft(LinearConstraint::less_eq(vec![1.0], 8.0)).with_priority(2),
        );
        // Highest priority first.
        let sorted = set.by_priority();
        assert_eq!(sorted[0].priority(), 3);
        assert_eq!(sorted[1].priority(), 2);
        assert_eq!(sorted[2].priority(), 1);
    }

    #[test]
    fn test_penalty_gradient() {
        let penalty = PenaltyFunction::L2;
        assert_eq!(penalty.gradient(0.0, 1.0), 0.0);
        assert!((penalty.gradient(2.0, 1.0) - 4.0).abs() < 0.01);
        let penalty = PenaltyFunction::L1;
        assert_eq!(penalty.gradient(0.0, 1.0), 0.0);
        assert_eq!(penalty.gradient(5.0, 2.0), 2.0);
    }

    #[test]
    fn test_soft_constraint_with_quadratic() {
        let ball = QuadraticConstraint::ball(vec![0.0, 0.0], 1.0);
        let c = SoftHardConstraint::soft(ball)
            .with_penalty(PenaltyFunction::L2)
            .with_weight(1.0);
        assert!(c.check(&[0.5, 0.0]));
        assert_eq!(c.loss(&[0.5, 0.0]), 0.0);
        assert!(!c.check(&[2.0, 0.0]));
        // Squared-norm violation 4 - 1 = 3, then L2 penalty 3^2 = 9.
        assert!((c.loss(&[2.0, 0.0]) - 9.0).abs() < 0.1);
    }

    // ---- Hysteresis types defined in this module ----------------------------

    #[test]
    fn test_hysteresis_latches_between_thresholds() {
        let mut c = HysteresisConstraint::new("latch", 1.0, 2.0);
        assert!(!c.state());
        // Samples inside the band (including the high threshold itself) keep it off.
        c.update_and_check(1.5);
        assert!(!c.state());
        c.update_and_check(2.0);
        assert!(!c.state());
        // Rising strictly above the high threshold switches it on...
        c.update_and_check(2.5);
        assert!(c.state());
        // ...and it stays on across the band, including at the low threshold.
        c.update_and_check(1.5);
        assert!(c.state());
        c.update_and_check(1.0);
        assert!(c.state());
        // Falling strictly below the low threshold switches it off.
        c.update_and_check(0.5);
        assert!(!c.state());
    }

    #[test]
    fn test_hysteresis_would_transition_does_not_mutate() {
        let c = HysteresisConstraint::new("probe", 1.0, 2.0);
        assert!(c.would_transition(3.0));
        assert!(!c.would_transition(1.5));
        assert!(!c.state());
        let c = c.with_initial_state(true);
        assert!(c.would_transition(0.5));
        assert!(!c.would_transition(1.0));
        assert!(c.state());
    }

    #[test]
    fn test_hysteresis_builders_and_reset() {
        let mut c = HysteresisConstraint::new("cfg", -1.0, 1.0)
            .with_initial_state(true)
            .with_weight(2.5);
        assert_eq!(c.name(), "cfg");
        assert_eq!(c.weight(), 2.5);
        assert!(c.state());
        c.reset(false);
        assert!(!c.state());
    }

    #[test]
    #[should_panic(expected = "Low threshold must be less than high threshold")]
    fn test_hysteresis_rejects_inverted_thresholds() {
        let _ = HysteresisConstraint::new("bad", 2.0, 1.0);
    }

    #[test]
    fn test_hysteresis_checker_tracks_each_constraint() {
        let mut checker = HysteresisChecker::default();
        checker.add(HysteresisConstraint::new("low_band", 1.0, 2.0));
        checker.add(HysteresisConstraint::new("high_band", 5.0, 6.0));
        let _ = checker.update(3.0);
        assert_eq!(
            checker.states(),
            vec![("low_band".to_string(), true), ("high_band".to_string(), false)]
        );
        let _ = checker.update(7.0);
        assert_eq!(
            checker.states(),
            vec![("low_band".to_string(), true), ("high_band".to_string(), true)]
        );
        checker.reset();
        assert_eq!(
            checker.states(),
            vec![("low_band".to_string(), false), ("high_band".to_string(), false)]
        );
    }
}