use std::collections::hash_map::Entry;
use crate as casuarius;
use ordered_float::OrderedFloat;
mod operators;
mod solver_impl;
pub use operators::Constrainable;
use rustc_hash::FxHashMap;
pub use strength::{MEDIUM, REQUIRED, STRONG, WEAK};
#[macro_use]
pub mod derive_syntax;
// Compiled only when rustdoc collects doctests (`cfg(doctest)`): README.md is attached
// as the documentation of `ReadmeDoctests`, so the README's code examples are run as
// doctests. Plain `//` comments are used here so no extra text is prepended to that doc.
#[cfg(doctest)]
pub mod doctest {
#[doc = include_str!("../README.md")]
pub struct ReadmeDoctests;
}
/// A solver variable identified by a static string name, e.g. `Variable("box1.left")`.
///
/// Equality and hashing are derived from the name alone, so two `Variable`s with the
/// same name are the same variable as far as the solver is concerned.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Variable(pub &'static str);
// NOTE(review): presumably generates the operator / constraint-building syntax for
// `Variable` (see the `derive_syntax` module) — TODO confirm.
derive_syntax_for!(Variable);
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Term<T> {
pub variable: T,
pub coefficient: OrderedFloat<f64>,
}
impl<T> Term<T> {
pub fn new(variable: T, coefficient: f64) -> Term<T> {
Term {
variable: variable,
coefficient: coefficient.into(),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct Expression<T> {
pub terms: Vec<Term<T>>,
pub constant: OrderedFloat<f64>,
}
impl<T: Clone> Expression<T> {
pub fn from_constant(v: f64) -> Expression<T> {
Expression {
terms: Vec::new(),
constant: v.into(),
}
}
pub fn from_term(term: Term<T>) -> Expression<T> {
Expression {
terms: vec![term],
constant: 0.0.into(),
}
}
pub fn new(terms: Vec<Term<T>>, constant: f64) -> Expression<T> {
Expression {
terms: terms,
constant: constant.into(),
}
}
pub fn negate(&mut self) {
self.constant = (-(self.constant.into_inner())).into();
for t in &mut self.terms {
let t2 = t.clone();
*t = -t2;
}
}
}
impl<Var: Clone> Constrainable<Var> for Expression<Var> {
fn equal_to<X>(self, x: X) -> Constraint<Var>
where
X: Into<Expression<Var>> + Clone,
{
let lhs = PartialConstraint(self.into(), WeightedRelation::EQ(strength::REQUIRED));
let rhs: Expression<Var> = x.into();
let (op, s) = lhs.1.into();
Constraint::new(lhs.0 - rhs, op, s)
}
fn greater_than_or_equal_to<X>(self, x: X) -> Constraint<Var>
where
X: Into<Expression<Var>> + Clone,
{
let lhs = PartialConstraint(self.into(), WeightedRelation::GE(strength::REQUIRED));
let rhs: Expression<Var> = x.into();
let (op, s) = lhs.1.into();
Constraint::new(lhs.0 - rhs, op, s)
}
fn less_than_or_equal_to<X>(self, x: X) -> Constraint<Var>
where
X: Into<Expression<Var>> + Clone,
{
let lhs = PartialConstraint(self.into(), WeightedRelation::LE(strength::REQUIRED));
let rhs: Expression<Var> = x.into();
let (op, s) = lhs.1.into();
Constraint::new(lhs.0 - rhs, op, s)
}
}
impl<T: Clone> From<f64> for Expression<T> {
fn from(v: f64) -> Expression<T> {
Expression::from_constant(v)
}
}
impl<T: Clone> From<i32> for Expression<T> {
fn from(v: i32) -> Expression<T> {
Expression::from_constant(v as f64)
}
}
impl<T: Clone> From<u32> for Expression<T> {
fn from(v: u32) -> Expression<T> {
Expression::from_constant(v as f64)
}
}
impl<T: Clone> From<Term<T>> for Expression<T> {
fn from(t: Term<T>) -> Expression<T> {
Expression::from_term(t)
}
}
/// Constraint strengths and helpers for building them.
pub mod strength {
    /// Clamps one weighted component `x * w` into `[0, 1000]`.
    ///
    /// `max`/`min` are used deliberately instead of `f64::clamp`. With this form a NaN
    /// product collapses to `0.0`, whereas `clamp` would propagate the NaN.
    fn component(x: f64, w: f64) -> f64 {
        (x * w).max(0.0).min(1000.0)
    }

    /// Packs three strength tiers into a single `f64`.
    ///
    /// Each of `a` (strongest), `b` and `c` (weakest) is multiplied by `w` and clamped to
    /// `[0, 1000]`. The clamped values are then combined as `a' * 1_000_000 + b' * 1_000 + c'`.
    pub fn create(a: f64, b: f64, c: f64, w: f64) -> f64 {
        let high = component(a, w) * 1_000_000.0;
        let mid = component(b, w) * 1000.0;
        let low = component(c, w);
        high + mid + low
    }

    /// The strongest strength, equal to `create(1000.0, 1000.0, 1000.0, 1.0)`.
    pub const REQUIRED: f64 = 1_001_001_000.0;
    /// Equal to `create(1.0, 0.0, 0.0, 1.0)`.
    pub const STRONG: f64 = 1_000_000.0;
    /// Equal to `create(0.0, 1.0, 0.0, 1.0)`.
    pub const MEDIUM: f64 = 1_000.0;
    /// Equal to `create(0.0, 0.0, 1.0, 1.0)`.
    pub const WEAK: f64 = 1.0;

    /// Restricts `s` to the range `[0, REQUIRED]`.
    ///
    /// A NaN input yields `REQUIRED`, because `f64::min` returns its non-NaN operand.
    pub fn clip(s: f64) -> f64 {
        let capped = s.min(REQUIRED);
        capped.max(0.0)
    }
}
/// The relation a constraint imposes between its expression and zero.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
pub enum RelationalOperator {
    /// `<=`
    LessOrEqual,
    /// `==`
    Equal,
    /// `>=`
    GreaterOrEqual,
}
impl std::fmt::Display for RelationalOperator {
    /// Writes the operator's symbol (`<=`, `==` or `>=`) verbatim.
    /// Width and fill flags are not applied.
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        let symbol = match self {
            RelationalOperator::LessOrEqual => "<=",
            RelationalOperator::Equal => "==",
            RelationalOperator::GreaterOrEqual => ">=",
        };
        fmt.write_str(symbol)
    }
}
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Constraint<T> {
expression: Expression<T>,
strength: OrderedFloat<f64>,
op: RelationalOperator,
}
impl<T> Constraint<T> {
pub fn new(e: Expression<T>, op: RelationalOperator, strength: f64) -> Constraint<T> {
Constraint {
expression: e,
op: op,
strength: strength.into(),
}
}
pub fn expr(&self) -> &Expression<T> {
&self.expression
}
pub fn op(&self) -> RelationalOperator {
self.op
}
pub fn strength(&self) -> f64 {
self.strength.into_inner()
}
pub fn with_strength(self, s: f64) -> Self {
let mut c = self;
c.strength = s.into();
c
}
}
/// A relational operator paired with the strength it should be enforced at.
#[derive(Debug)]
pub enum WeightedRelation {
    /// Equality at the given strength.
    EQ(f64),
    /// Less-than-or-equal at the given strength.
    LE(f64),
    /// Greater-than-or-equal at the given strength.
    GE(f64),
}
impl From<WeightedRelation> for (RelationalOperator, f64) {
    /// Splits a weighted relation into its operator and its strength.
    fn from(r: WeightedRelation) -> (RelationalOperator, f64) {
        let op = match r {
            WeightedRelation::EQ(_) => RelationalOperator::Equal,
            WeightedRelation::LE(_) => RelationalOperator::LessOrEqual,
            WeightedRelation::GE(_) => RelationalOperator::GreaterOrEqual,
        };
        let strength = match r {
            WeightedRelation::EQ(s) | WeightedRelation::LE(s) | WeightedRelation::GE(s) => s,
        };
        (op, strength)
    }
}
/// An expression paired with a weighted relation (operator plus strength) whose
/// right-hand side has not been supplied yet.
///
/// NOTE(review): presumably completed into a `Constraint` by operator impls in the
/// `operators` module — TODO confirm.
#[derive(Debug)]
pub struct PartialConstraint<T>(pub Expression<T>, pub WeightedRelation);
/// The kind of a tableau symbol.
///
/// NOTE: variant order is significant because it defines the derived `PartialOrd`/`Ord`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
enum SymbolType {
/// Kind of the sentinel returned by `Symbol::invalid()` ("no symbol found").
Invalid,
/// Preferred first by `Symbol::choose_subject`. Presumably represents a user
/// variable — TODO confirm.
External,
/// Treated as pivotable by `Row::any_pivotable_symbol`.
Slack,
/// Treated as pivotable by `Row::any_pivotable_symbol`.
Error,
/// Never returned by `Row::get_entering_symbol`.
Dummy,
}
/// A tableau symbol: `(id, kind)`.
///
/// The derived ordering compares the id first, then the kind. `Symbol::invalid()` uses
/// id 0. NOTE(review): ids are presumably handed out uniquely by the solver — TODO confirm.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
struct Symbol(usize, SymbolType);
impl Symbol {
fn choose_subject(row: &Row, tag: &Tag) -> Symbol {
for s in row.cells.keys() {
if s.type_() == SymbolType::External {
return *s;
}
}
if tag.marker.type_() == SymbolType::Slack || tag.marker.type_() == SymbolType::Error {
if row.coefficient_for(tag.marker) < 0.0 {
return tag.marker;
}
}
if tag.other.type_() == SymbolType::Slack || tag.other.type_() == SymbolType::Error {
if row.coefficient_for(tag.other) < 0.0 {
return tag.other;
}
}
Symbol::invalid()
}
fn invalid() -> Symbol {
Symbol(0, SymbolType::Invalid)
}
fn type_(&self) -> SymbolType {
self.1
}
}
/// A pair of symbols that the solver keeps together, presumably one `Tag` per added
/// constraint — TODO confirm.
///
/// `Symbol::choose_subject` considers `marker` first and then `other` as a row's subject
/// when they are `Slack`/`Error` symbols with a negative coefficient.
#[derive(Copy, Clone, Debug)]
struct Tag {
marker: Symbol,
other: Symbol,
}
/// A linear row `constant + Σ cells[sym] * sym`, manipulated by the solver's
/// pivoting routines.
#[derive(Clone, Debug)]
struct Row {
// Coefficient of each symbol present in the row. `Row::insert_symbol` drops entries
// that `near_zero` treats as zero. Iteration order is the hash map's, not sorted.
cells: FxHashMap<Symbol, OrderedFloat<f64>>,
// The row's constant term.
constant: OrderedFloat<f64>,
}
/// Returns `true` when `value` is within `1e-8` of zero (exclusive).
///
/// NaN and infinities are never near zero.
fn near_zero(value: f64) -> bool {
    const EPS: f64 = 1E-8;
    value.abs() < EPS
}
impl Row {
pub fn new(constant: f64) -> Row {
Row {
cells: FxHashMap::default(),
constant: constant.into(),
}
}
fn add(&mut self, v: f64) -> f64 {
*(self.constant.as_mut()) += v;
self.constant.into_inner()
}
fn insert_symbol(&mut self, s: Symbol, coefficient: f64) {
match self.cells.entry(s) {
Entry::Vacant(entry) => {
if !near_zero(coefficient) {
entry.insert(coefficient.into());
}
}
Entry::Occupied(mut entry) => {
let ofloat = entry.get_mut();
let float = ofloat.as_mut();
*float += coefficient;
if near_zero(*float) {
entry.remove();
}
}
}
}
fn insert_row(&mut self, other: &Row, coefficient: f64) -> bool {
let constant_diff = other.constant.as_ref() * coefficient;
*self.constant.as_mut() += constant_diff;
for (s, v) in &other.cells {
self.insert_symbol(*s, v.into_inner() * coefficient);
}
constant_diff != 0.0
}
fn remove(&mut self, s: Symbol) {
self.cells.remove(&s);
}
fn reverse_sign(&mut self) {
*self.constant.as_mut() *= -1.0;
for (_, v) in &mut self.cells {
*v.as_mut() *= -1.0;
}
}
fn solve_for_symbol(&mut self, s: Symbol) {
let coeff = -1.0
/ match self.cells.entry(s) {
Entry::Occupied(entry) => entry.remove().into_inner(),
Entry::Vacant(_) => unreachable!(),
};
*self.constant.as_mut() *= coeff;
for (_, v) in &mut self.cells {
*v.as_mut() *= coeff;
}
}
fn solve_for_symbols(&mut self, lhs: Symbol, rhs: Symbol) {
self.insert_symbol(lhs, -1.0);
self.solve_for_symbol(rhs);
}
fn coefficient_for(&self, s: Symbol) -> f64 {
self.cells
.get(&s)
.cloned()
.map(|o| o.into_inner())
.unwrap_or(0.0)
}
fn substitute(&mut self, s: Symbol, row: &Row) -> bool {
if let Some(coeff) = self.cells.remove(&s) {
self.insert_row(row, coeff.into())
} else {
false
}
}
fn all_dummies(&self) -> bool {
for symbol in self.cells.keys() {
if symbol.type_() != SymbolType::Dummy {
return false;
}
}
true
}
fn any_pivotable_symbol(&self) -> Symbol {
for symbol in self.cells.keys() {
if symbol.type_() == SymbolType::Slack || symbol.type_() == SymbolType::Error {
return *symbol;
}
}
Symbol::invalid()
}
fn get_entering_symbol(&self) -> Symbol {
for (symbol, value) in &self.cells {
if symbol.type_() != SymbolType::Dummy && *value.as_ref() < 0.0 {
return symbol.clone();
}
}
Symbol::invalid()
}
}
// Error types returned by the solver API. Each implements `Display` and
// `std::error::Error`, so it can be printed, boxed as `Box<dyn Error>`, or propagated
// with `?` into error-erasing contexts. NOTE(review): the conditions that raise each
// variant live in `solver_impl`; the messages below follow the variant names.

/// An error carrying a static description. NOTE(review): presumably raised by
/// edit-related solver operations — TODO confirm.
#[derive(Debug, Copy, Clone)]
pub struct EditConstraintError(&'static str);

impl std::fmt::Display for EditConstraintError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "edit constraint error: {}", self.0)
    }
}

impl std::error::Error for EditConstraintError {}

/// Reasons why adding a constraint can fail.
#[derive(Debug, Copy, Clone)]
pub enum AddConstraintError {
    /// The constraint was already added.
    DuplicateConstraint,
    /// The constraint could not be satisfied.
    UnsatisfiableConstraint,
    /// An internal invariant of the solver was violated; carries a description.
    InternalSolverError(&'static str),
}

impl std::fmt::Display for AddConstraintError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            AddConstraintError::DuplicateConstraint => {
                f.write_str("the constraint has already been added to the solver")
            }
            AddConstraintError::UnsatisfiableConstraint => {
                f.write_str("the constraint cannot be satisfied")
            }
            AddConstraintError::InternalSolverError(msg) => {
                write!(f, "internal solver error: {}", msg)
            }
        }
    }
}

impl std::error::Error for AddConstraintError {}

/// Reasons why removing a constraint can fail.
#[derive(Debug, Copy, Clone)]
pub enum RemoveConstraintError {
    /// The constraint is not known to the solver.
    UnknownConstraint,
    /// An internal invariant of the solver was violated; carries a description.
    InternalSolverError(&'static str),
}

impl std::fmt::Display for RemoveConstraintError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            RemoveConstraintError::UnknownConstraint => {
                f.write_str("the constraint has not been added to the solver")
            }
            RemoveConstraintError::InternalSolverError(msg) => {
                write!(f, "internal solver error: {}", msg)
            }
        }
    }
}

impl std::error::Error for RemoveConstraintError {}

/// Reasons why registering an edit variable can fail.
#[derive(Debug, Copy, Clone)]
pub enum AddEditVariableError {
    /// The variable is already an edit variable.
    DuplicateEditVariable,
    /// The requested strength is not allowed for an edit variable.
    BadRequiredStrength,
}

impl std::fmt::Display for AddEditVariableError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            AddEditVariableError::DuplicateEditVariable => {
                f.write_str("the edit variable has already been added to the solver")
            }
            AddEditVariableError::BadRequiredStrength => {
                f.write_str("a required strength cannot be used for an edit variable")
            }
        }
    }
}

impl std::error::Error for AddEditVariableError {}

/// Reasons why removing an edit variable can fail.
#[derive(Debug, Copy, Clone)]
pub enum RemoveEditVariableError {
    /// The variable is not a registered edit variable.
    UnknownEditVariable,
    /// An internal invariant of the solver was violated; carries a description.
    InternalSolverError(&'static str),
}

impl std::fmt::Display for RemoveEditVariableError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            RemoveEditVariableError::UnknownEditVariable => {
                f.write_str("the edit variable has not been added to the solver")
            }
            RemoveEditVariableError::InternalSolverError(msg) => {
                write!(f, "internal solver error: {}", msg)
            }
        }
    }
}

impl std::error::Error for RemoveEditVariableError {}

/// Reasons why suggesting a value for an edit variable can fail.
#[derive(Debug, Copy, Clone)]
pub enum SuggestValueError {
    /// The variable is not a registered edit variable.
    UnknownEditVariable,
    /// An internal invariant of the solver was violated; carries a description.
    InternalSolverError(&'static str),
}

impl std::fmt::Display for SuggestValueError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            SuggestValueError::UnknownEditVariable => {
                f.write_str("the edit variable has not been added to the solver")
            }
            SuggestValueError::InternalSolverError(msg) => {
                write!(f, "internal solver error: {}", msg)
            }
        }
    }
}

impl std::error::Error for SuggestValueError {}

/// A violated internal solver invariant, carrying a static description.
#[derive(Debug, Copy, Clone)]
pub struct InternalSolverError(&'static str);

impl std::fmt::Display for InternalSolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "internal solver error: {}", self.0)
    }
}

impl std::error::Error for InternalSolverError {}
pub use solver_impl::Solver;
// Unit tests for the solver front-end. Most tests use `Var`, a numeric id type defined
// below; `lib_doctest_part_one` uses the string-named `Variable` instead.
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): presumably needed because `derive_syntax_for!` expands to paths rooted
// at `casuarius::` — TODO confirm.
use crate as casuarius;
use std::{
collections::HashMap,
sync::atomic::{AtomicUsize, Ordering},
};
// Source of fresh ids for `Var::new()`, shared by every test in this module.
static NEXT_K: AtomicUsize = AtomicUsize::new(0);
// Test-only variable type: an opaque numeric id.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct Var(usize);
derive_syntax_for!(Var);
impl Var {
// Returns a variable whose id has not been handed out before in this process.
// `Relaxed` is enough because only the uniqueness of the fetched value matters.
pub fn new() -> Var {
Var(NEXT_K.fetch_add(1, Ordering::Relaxed))
}
}
// Smoke test: lays out two boxes inside a resizable window and prints the changes the
// solver reports after each step. It makes no assertions on values. It fails only if a
// call's `expect`/`unwrap` fails or a printed variable has no entry in `names`.
#[test]
fn example() {
let mut names = HashMap::new();
// Prints every changed variable under its human-readable name.
fn print_changes(names: &HashMap<Var, &'static str>, changes: &[(Var, f64)]) {
println!("Changes:");
for &(ref var, ref val) in changes {
println!("{}: {}", names[var], val);
}
}
let window_width = Var::new();
names.insert(window_width, "window_width");
struct Element {
left: Var,
right: Var,
}
let box1 = Element {
left: Var::new(),
right: Var::new(),
};
names.insert(box1.left, "box1.left");
names.insert(box1.right, "box1.right");
let box2 = Element {
left: Var::new(),
right: Var::new(),
};
names.insert(box2.left, "box2.left");
names.insert(box2.right, "box2.right");
let mut solver = Solver::default();
// Window width between 0 and 1000.
solver
.add_constraint(window_width.is_ge(0.0))
.expect("Could not add window width >= 0");
solver
.add_constraint(window_width.is_le(1000.0))
.expect("Could not add window width <= 1000.0");
// box1 hugs the left edge, box2 hugs the right edge, and box2 starts after box1 ends.
solver
.add_constraint(box1.left.is(0.0))
.expect("Could not add left align constraint");
solver
.add_constraint(box2.right.is(window_width))
.expect("Could not add right align constraint");
solver
.add_constraint(box2.left.is_ge(box1.right))
.expect("Could not add no overlap constraint");
// Preferred widths (50 and 100) are only WEAK, so they may give way.
solver
.add_constraint(box1.right.is(box1.left + 50.0).with_strength(WEAK))
.expect("Could not add box1 width constraint");
solver
.add_constraint(box2.right.is(box2.left + 100.0).with_strength(WEAK))
.expect("Could not add box2 width constraint");
solver
.add_constraint(box1.left.is_le(box1.right))
.expect("Could not add box1 positive width constraint");
solver
.add_constraint(box2.left.is_le(box2.right))
.expect("Could not add box2 positive width constraint");
// Drive the window width as a STRONG edit variable: first wide (1000), then narrow
// (75, too small for both preferred widths).
solver
.add_edit_variable(window_width, STRONG)
.expect("Could not add window width edit var");
solver
.suggest_value(window_width, 1000.0)
.expect("Could not suggest window width = 1000");
print_changes(&names, solver.fetch_changes());
solver
.suggest_value(window_width, 75.0)
.expect("Could not suggest window width = 75");
print_changes(&names, solver.fetch_changes());
// MEDIUM-strength request that the two widths keep a 50:100 ratio.
solver
.add_constraint(
((box1.right - box1.left) / 50.0f64)
.is((box2.right - box2.left) / 100.0)
.with_strength(MEDIUM),
)
.unwrap();
print_changes(&names, solver.fetch_changes());
}
// Quadrilateral demo: four corners weakly pulled to their starting positions, plus the
// midpoints of the four edges, minimum separations and bounds. The test asserts the
// midpoints, then drags corner 2 with edit variables and asserts the new positions.
#[test]
fn test_quadrilateral() {
struct Point {
x: Var,
y: Var,
}
impl Point {
fn new() -> Point {
Point {
x: Var::new(),
y: Var::new(),
}
}
}
let points = [Point::new(), Point::new(), Point::new(), Point::new()];
let point_starts = [(10.0, 10.0), (10.0, 200.0), (200.0, 200.0), (200.0, 10.0)];
let midpoints = [Point::new(), Point::new(), Point::new(), Point::new()];
let mut solver = Solver::default();
let mut weight = 1.0;
let multiplier = 2.0;
// NOTE(review): the constraints up to `commit_edit` are added between `begin_edit` and
// `commit_edit`; presumably this batches them — see `Solver` to confirm the semantics.
solver.begin_edit();
// Each corner is pulled toward its start position at WEAK * weight, where weight
// doubles per corner (1, 2, 4, 8).
for i in 0..4 {
solver
.add_constraints(vec![
(points[i].x)
.is(point_starts[i].0)
.with_strength(WEAK * weight),
(points[i].y)
.is(point_starts[i].1)
.with_strength(WEAK * weight),
])
.expect("Could not add initial quad points");
weight *= multiplier;
}
// midpoints[i] is the midpoint of the edge from corner i to corner (i + 1) % 4.
for (start, end) in vec![(0, 1), (1, 2), (2, 3), (3, 0)] {
solver
.add_constraints(vec![
(midpoints[start].x).is((points[start].x + points[end].x) / 2.0),
(midpoints[start].y).is((points[start].y + points[end].y) / 2.0),
])
.expect("Could not add quad midpoints");
}
// Corners 0 and 1 stay at least 20 left of corners 2 and 3 in x; corners 0 and 3 stay
// at least 20 below corners 1 and 2 in y.
solver
.add_constraints(vec![
(points[0].x + 20.0f64).is_le(points[2].x),
(points[0].x + 20.0f64).is_le(points[3].x),
(points[1].x + 20.0f64).is_le(points[2].x),
(points[1].x + 20.0f64).is_le(points[3].x),
(points[0].y + 20.0f64).is_le(points[1].y),
(points[0].y + 20.0f64).is_le(points[2].y),
(points[3].y + 20.0f64).is_le(points[1].y),
(points[3].y + 20.0f64).is_le(points[2].y),
])
.expect("Could not add quad midpoint constraints");
// Every corner lies within [0, 500] x [0, 500].
for point in &points {
solver
.add_constraints(vec![
point.x.is_ge(0.0),
point.y.is_ge(0.0),
point.x.is_le(500.0),
point.y.is_le(500.0),
])
.expect("Could not add required bounds on quad");
}
solver
.commit_edit()
.expect("Could not commit constraint edit");
// Corners stay at their start positions, so the midpoints are those of the
// original square's edges.
assert_eq!(
[
(
solver.get_value(midpoints[0].x),
solver.get_value(midpoints[0].y)
),
(
solver.get_value(midpoints[1].x),
solver.get_value(midpoints[1].y)
),
(
solver.get_value(midpoints[2].x),
solver.get_value(midpoints[2].y)
),
(
solver.get_value(midpoints[3].x),
solver.get_value(midpoints[3].y)
)
],
[(10.0, 105.0), (105.0, 200.0), (200.0, 105.0), (105.0, 10.0)]
);
// Drag corner 2 (index 2) to (300, 400) through STRONG edit variables.
solver
.add_edit_variable(points[2].x, STRONG)
.expect("Could not add x edit variable for 2nd point");
solver
.add_edit_variable(points[2].y, STRONG)
.expect("Could not add y edit variable for 2nd point");
solver
.suggest_value(points[2].x, 300.0)
.expect("Could not suggest value for x edit variable for 2nd point");
solver
.suggest_value(points[2].y, 400.0)
.expect("Could not suggest value for y edit variable for 2nd point");
// Only corner 2 moves.
assert_eq!(
[
(solver.get_value(points[0].x), solver.get_value(points[0].y)),
(solver.get_value(points[1].x), solver.get_value(points[1].y)),
(solver.get_value(points[2].x), solver.get_value(points[2].y)),
(solver.get_value(points[3].x), solver.get_value(points[3].y))
],
[(10.0, 10.0), (10.0, 200.0), (300.0, 400.0), (200.0, 10.0)]
);
// Only the midpoints of the two edges touching corner 2 (indices 1 and 2) move.
assert_eq!(
[
(
solver.get_value(midpoints[0].x),
solver.get_value(midpoints[0].y)
),
(
solver.get_value(midpoints[1].x),
solver.get_value(midpoints[1].y)
),
(
solver.get_value(midpoints[2].x),
solver.get_value(midpoints[2].y)
),
(
solver.get_value(midpoints[3].x),
solver.get_value(midpoints[3].y)
)
],
[(10.0, 105.0), (155.0, 300.0), (250.0, 205.0), (105.0, 10.0)]
);
}
// Removing a constraint must release the variable. After `var == 100` is removed,
// adding `var == 0` succeeds and the new value is observed.
#[test]
fn can_add_and_remove_constraints() {
let mut solver = Solver::default();
let var = Var(0);
let constraint: Constraint<Var> = var.is(100.0);
solver.add_constraint(constraint.clone()).unwrap();
assert_eq!(solver.get_value(var), 100.0);
solver.remove_constraint(&constraint).unwrap();
solver.add_constraint(var.is(0.0)).unwrap();
assert_eq!(solver.get_value(var), 0.0);
}
// The two-box layout of `example`, but with string-named `Variable`s, width
// constraints without an explicit strength, and WEAK positive-width constraints. The
// solution is asserted after suggesting a window width of 300. NOTE(review): the name
// suggests it mirrors a crate-level doc example — TODO confirm.
#[test]
fn lib_doctest_part_one() {
struct Element {
left: Variable,
right: Variable,
}
let box1 = Element {
left: Variable("box1.left"),
right: Variable("box1.right"),
};
let window_width = Variable("window_width");
let box2 = Element {
left: Variable("box2.left"),
right: Variable("box2.right"),
};
let mut solver = Solver::<Variable>::default();
solver
.add_constraints(vec![
window_width.is_ge(0.0), box1.left.is(0.0), box2.right.is(window_width), box2.left.is_ge(box1.right), (box1.right - box1.left).is(50.0),
(box2.right - box2.left).is(100.0),
box1.left.is_le(box1.right).with_strength(WEAK),
box2.left.is_le(box2.right).with_strength(WEAK),
])
.unwrap();
solver.add_edit_variable(window_width, STRONG).unwrap();
solver.suggest_value(window_width, 300.0).unwrap();
// The closure mutably borrows `solver` because `fetch_changes` is called through it.
let mut print_changes = || {
println!("Changes:");
solver
.fetch_changes()
.iter()
.for_each(|(var, val)| println!("{}: {}", var.0, val));
};
print_changes();
let ww = solver.get_value(window_width);
let b1l = solver.get_value(box1.left);
let b1r = solver.get_value(box1.right);
let b2l = solver.get_value(box2.left);
let b2r = solver.get_value(box2.right);
println!("window_width: {}", ww);
println!("box1.left {}", b1l);
println!("box1.right {}", b1r);
println!("box2.left {}", b2l);
println!("box2.right {}", b2r);
assert!(ww >= 0.0, "window_width >= 0.0");
assert_eq!(0.0, b1l, "box1.left ({}) == 0", b1l);
assert_eq!(ww, b2r, "box2.right ({}) != ww ({})", b2r, ww);
assert!(b2l >= b1r, "box2.left >= box1.right");
assert!(b1l <= b1r, "box1.left <= box1.right");
assert!(b2l <= b2r, "box2.left <= box2.right");
assert_eq!(50.0, b1r - b1l, "box1 width");
assert_eq!(100.0, b2r - b2l, "box2 width");
}
// Pins the IEEE-754 behaviour of negative zero: it equals 0.0 and is not less than
// 0.0, so sign checks such as `< 0.0` treat -0.0 as non-negative.
#[test]
fn neg_zero_sanity() {
let nzero: f64 = -0.0;
let zero: f64 = 0.0;
assert!(nzero == zero);
assert!(!(nzero < zero));
assert!(nzero >= zero);
}
}