use super::{
activation::State,
constraint::Constraint,
errors::{NoSuchConstraint, NoSuchVariable},
filtered_callback::FilteredCallback,
generation_id::GenerationId,
method::Method,
undo::{NoMoreRedo, NoMoreUndo, UndoLimit},
variable::Variable,
variables::Variables,
};
use crate::{
event::{Event, EventWithLocation},
model::activation::Activation,
planner::{
hierarchical_planner, priority_adjuster::adjust_priorities, ComponentSpec, ConstraintSpec,
MethodSpec, OwnedEnforcedConstraint, PlanError, Vertex,
},
solver::SolveError,
thread::{DummyPool, ThreadPool},
variable_ranking::{SortRanker, VariableRanker},
};
use itertools::Itertools;
use std::{
collections::{HashMap, HashSet},
fmt::{self, Debug, Write},
ops::{Index, IndexMut},
sync::{Arc, Mutex},
};
/// A named collection of variables and the constraints between them, together with the
/// bookkeeping (ranking, undo/redo history, generation counters and subscriber callbacks)
/// needed to plan and solve it.
#[derive(derivative::Derivative)]
#[derivative(Clone(bound = ""), Debug, Default(bound = ""))]
pub struct Component<T> {
    /// Name of the component; also used as the graph name in the DOT renderings.
    name: String,
    /// Maps variable names to their index in `variables` / `callbacks`.
    /// Left empty by `new`, `new_empty` and `new_with_undo_limit`; populated by `new_with_map`.
    name_to_index: HashMap<String, usize>,
    /// One filtered callback per variable, indexed like `variables`. Kept behind
    /// `Arc<Mutex<..>>` so the closure handed to the solver in `solve` can notify subscribers.
    callbacks: Arc<Mutex<Vec<FilteredCallback<T, SolveError>>>>,
    /// Variable values together with their undo/redo history.
    variables: Variables<Activation<T>>,
    /// All constraints, including the single-method constraints added by `pin`.
    constraints: Vec<Constraint<T>>,
    /// Variable priorities: touched by `set_variable`, recomputed by `adjust_priorities` in `solve`.
    ranker: SortRanker,
    /// Indices of variables set since the last solve (every variable starts out marked);
    /// cleared in `solve`.
    updated_since_last_solve: HashSet<usize>,
    /// Initialised to the number of variables by the constructors.
    /// NOTE(review): not read anywhere in the code shown — confirm whether it is still needed.
    n_ready: usize,
    /// Generation of the currently visible state: incremented by `solve`/`redo`,
    /// decremented by `undo`.
    current_generation: usize,
    /// Count of all state transitions: incremented by every `solve`, `undo` and `redo`.
    total_generation: usize,
}
impl<T> Component<T> {
    /// Creates a component called `name` with no variables and no constraints.
    pub fn new_empty(name: String) -> Self {
        Self {
            name,
            ..Default::default()
        }
    }

    /// Subscribes `callback` to changes of `variable`.
    ///
    /// The callback is first invoked with the variable's current state (pending, ready or
    /// error) and is then registered with the variable's [`FilteredCallback`].
    ///
    /// # Errors
    /// Returns [`NoSuchVariable`] if no variable is called `variable`.
    pub fn subscribe<'s>(
        &mut self,
        variable: &'s str,
        callback: impl Fn(Event<'_, T, SolveError>) + Send + 'static,
    ) -> Result<(), NoSuchVariable<'s>>
    where
        T: 'static,
    {
        let index = self.variable_index(variable)?;
        let activation = &self.variables[index];
        // The activation lock stays held until the callback is registered — presumably so
        // no update can slip in between reporting the current state and subscribing.
        let inner = activation.inner().lock().unwrap();
        match inner.state() {
            State::Pending => callback(Event::Pending),
            State::Ready(value) => callback(Event::Ready(value)),
            State::Error(errors) => callback(Event::Error(errors)),
        }
        self.callbacks.lock().unwrap()[index].subscribe(callback);
        Ok(())
    }

    /// Removes the subscription on `variable`.
    ///
    /// # Errors
    /// Returns [`NoSuchVariable`] if no variable is called `variable`.
    pub fn unsubscribe<'s>(&mut self, variable: &'s str) -> Result<(), NoSuchVariable<'s>> {
        let index = self.variable_index(variable)?;
        self.callbacks.lock().unwrap()[index].unsubscribe();
        Ok(())
    }

    /// Sets `variable` to `value`.
    ///
    /// The variable is marked as updated and touched in the ranker, and its subscribers
    /// receive a `Ready` event tagged with the current generation before the value is stored.
    ///
    /// # Errors
    /// Returns [`NoSuchVariable`] if no variable is called `variable`.
    pub fn set_variable<'s>(
        &mut self,
        variable: &'s str,
        value: impl Into<T>,
    ) -> Result<(), NoSuchVariable<'s>> {
        let idx = self.variable_index(variable)?;
        self.updated_since_last_solve.insert(idx);
        self.ranker.touch(idx);
        let value = value.into();
        self.callbacks.lock().unwrap()[idx].call(EventWithLocation::new(
            idx,
            GenerationId::new(self.current_generation, self.total_generation),
            Event::Ready(&value),
        ));
        self.variables.set(idx, Activation::from(value));
        Ok(())
    }

    /// Returns the variable called `variable`.
    ///
    /// # Errors
    /// Returns [`NoSuchVariable`] if the name is unknown or its index is out of range.
    pub fn variable<'a>(
        &self,
        variable: &'a str,
    ) -> Result<&Variable<Activation<T>>, NoSuchVariable<'a>> {
        let idx = self.variable_index(variable)?;
        self.variables.get(idx).ok_or(NoSuchVariable(variable))
    }

    /// Returns a clone of the current activation of `variable`.
    ///
    /// # Errors
    /// Returns [`NoSuchVariable`] if the name is unknown or its index is out of range.
    pub fn value<'a>(&self, variable: &'a str) -> Result<Activation<T>, NoSuchVariable<'a>> {
        let idx = self.variable_index(variable)?;
        self.variables
            .get(idx)
            .map(Variable::get)
            .cloned()
            .ok_or(NoSuchVariable(variable))
    }

    /// Resolves a variable name to its index.
    fn variable_index<'s>(&self, variable: &'s str) -> Result<usize, NoSuchVariable<'s>> {
        self.name_to_index
            .get(variable)
            .copied()
            .ok_or(NoSuchVariable(variable))
    }

    /// Returns the component's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Renames the component.
    pub fn set_name<S: Into<String>>(&mut self, name: S) {
        self.name = name.into();
    }

    /// Returns the names of all named variables, in arbitrary (hash-map) order.
    pub fn variable_names(&self) -> Vec<&str> {
        self.name_to_index.keys().map(String::as_str).collect()
    }

    /// Returns all variables.
    pub fn variables(&self) -> &[Variable<Activation<T>>] {
        self.variables.variables()
    }

    /// Returns references to the current value of every variable.
    pub fn values(&self) -> Vec<&Activation<T>> {
        self.variables.values()
    }

    /// Returns the first constraint called `name`.
    ///
    /// # Errors
    /// Returns [`NoSuchConstraint`] if there is no such constraint.
    pub fn constraint<'a>(&self, name: &'a str) -> Result<&Constraint<T>, NoSuchConstraint<'a>> {
        self.constraints
            .iter()
            .find(|c| c.name() == name)
            .ok_or(NoSuchConstraint(name))
    }

    /// Returns the first constraint called `name`, mutably.
    ///
    /// # Errors
    /// Returns [`NoSuchConstraint`] if there is no such constraint.
    pub fn constraint_mut<'a>(
        &mut self,
        name: &'a str,
    ) -> Result<&mut Constraint<T>, NoSuchConstraint<'a>> {
        self.constraints
            .iter_mut()
            .find(|c| c.name() == name)
            .ok_or(NoSuchConstraint(name))
    }

    /// Creates a component like [`ComponentSpec::new`] and installs `name_to_idx` as the
    /// variable-name lookup table.
    ///
    /// The indices in `name_to_idx` are expected to be valid positions in `values`; this is
    /// not checked.
    pub fn new_with_map(
        name: String,
        name_to_idx: HashMap<String, usize>,
        values: Vec<T>,
        constraints: Vec<Constraint<T>>,
    ) -> Self {
        let mut component = Component::new(name, values, constraints);
        component.name_to_index = name_to_idx;
        component
    }

    /// Plans and solves the component on the current thread.
    ///
    /// # Errors
    /// Returns [`PlanError`] if no valid plan exists.
    pub fn update(&mut self) -> Result<(), PlanError>
    where
        T: Send + Sync + 'static + Debug,
    {
        self.par_update(&mut DummyPool)
    }

    /// Plans the component with the hierarchical planner and solves the plan on `pool`.
    ///
    /// # Errors
    /// Returns [`PlanError`] if no valid plan exists.
    pub fn par_update(&mut self, pool: &mut impl ThreadPool) -> Result<(), PlanError>
    where
        T: Send + Sync + 'static + Debug,
    {
        let plan = hierarchical_planner(self)?;
        self.solve(pool, plan);
        Ok(())
    }

    /// Executes `plan`.
    ///
    /// Adjusts variable priorities for the plan, clears the updated set, starts a new
    /// generation (retargeting every callback to it), runs the solver and commits the results.
    fn solve(&mut self, pool: &mut impl ThreadPool, plan: Vec<OwnedEnforcedConstraint<Method<T>>>)
    where
        T: Send + Sync + 'static + Debug,
    {
        self.ranker = adjust_priorities(&plan, &self.ranker);
        self.updated_since_last_solve.clear();
        let component_name = self.name().to_owned();
        let variable_information_clone = self.callbacks.clone();
        self.current_generation += 1;
        self.total_generation += 1;
        let generation = GenerationId::new(self.current_generation, self.total_generation);
        for fcb in &mut *self.callbacks.lock().unwrap() {
            fcb.set_target(generation);
        }
        crate::solver::par_solve(
            &plan,
            &mut self.variables,
            component_name,
            generation,
            pool,
            // Forward each solver event to the callback of the variable it concerns.
            move |ge| {
                let mut lock = variable_information_clone.lock().unwrap();
                let fcb = &mut lock[ge.variable()];
                fcb.call(ge);
            },
        );
        self.variables.commit();
    }

    /// Name of the identity method `pin` adds for variable `idx`; `unpin` matches on it.
    fn pin_method_name(idx: usize) -> String {
        format!("pin{}", idx)
    }

    /// Pins `variable` by adding a constraint whose single method outputs the variable's
    /// current value unchanged.
    ///
    /// # Errors
    /// Returns [`NoSuchVariable`] if no variable is called `variable`.
    pub fn pin<'s>(&mut self, variable: &'s str) -> Result<(), NoSuchVariable<'s>>
    where
        T: 'static,
    {
        let idx = self.variable_index(variable)?;
        self.constraints.push(Constraint::new(vec![Method::new(
            Self::pin_method_name(idx),
            vec![idx],
            vec![idx],
            Arc::new(Ok),
        )]));
        Ok(())
    }

    /// Removes every pin constraint previously added for `variable` by [`Component::pin`].
    ///
    /// # Errors
    /// Returns [`NoSuchVariable`] if no variable is called `variable`.
    pub fn unpin<'s>(&mut self, variable: &'s str) -> Result<(), NoSuchVariable<'s>>
    where
        T: 'static,
    {
        let idx = self.variable_index(variable)?;
        let pin_name = Self::pin_method_name(idx);
        // `retain` removes eagerly on stable Rust; the former `drain_filter` call only worked
        // because dropping its (discarded) iterator happened to exhaust it.
        self.constraints.retain(|c| {
            !c.methods()
                .get(0)
                .map(|m| m.name() == pin_name)
                .unwrap_or(false)
        });
        Ok(())
    }

    /// Returns `true` if any variable has been set since the last solve.
    pub fn is_modified(&self) -> bool {
        !self.updated_since_last_solve.is_empty()
    }

    /// Returns the current variable ranking as computed by the ranker.
    pub fn ranking(&self) -> Vec<usize> {
        self.ranker.ranking()
    }

    /// Inverts `name_to_index`, mapping each named variable's index to its name.
    fn index_to_name(&self) -> HashMap<usize, &str> {
        self.name_to_index
            .iter()
            .map(|(name, &index)| (index, name.as_str()))
            .collect()
    }

    /// DOT identifier for variable `index`: its name, or the bare index (a valid DOT numeral
    /// ID) when the variable has no name.
    fn dot_id(index_to_name: &HashMap<usize, &str>, index: usize) -> String {
        match index_to_name.get(&index) {
            Some(name) => (*name).to_owned(),
            None => index.to_string(),
        }
    }

    /// Renders the component as a detailed Graphviz DOT digraph.
    ///
    /// Named variables are declared as boxes; each constraint becomes a subgraph with one node
    /// per method (ranked together). Inputs are connected to methods with dotted edges, methods
    /// to outputs with solid edges. Unnamed variables are referred to by their index.
    ///
    /// # Errors
    /// Returns [`fmt::Error`] if writing to the output buffer fails.
    pub fn to_dot_detailed(&self) -> Result<String, fmt::Error> {
        let index_to_name = self.index_to_name();
        let mut buffer = String::new();
        writeln!(buffer, "strict digraph {} {{", self.name())?;
        writeln!(buffer, " rankdir=LR;")?;
        for vi in 0..self.n_variables() {
            if let Some(name) = index_to_name.get(&vi) {
                writeln!(buffer, " {} [shape=box];", name)?;
            }
        }
        for c in self.constraints() {
            writeln!(buffer, " subgraph {} {{", c.name())?;
            writeln!(buffer, " color=gray;")?;
            writeln!(buffer, " style=filled;")?;
            writeln!(buffer, " style=rounded;")?;
            writeln!(buffer, " label={};", c.name())?;
            for m in c.methods() {
                writeln!(
                    buffer,
                    " {}_{} [label={}];",
                    c.name(),
                    m.name(),
                    m.name()
                )?;
            }
            write!(buffer, " {{ rank = same; ")?;
            for m in c.methods() {
                write!(buffer, "{}_{}; ", c.name(), m.name())?;
            }
            writeln!(buffer, "}}")?;
            writeln!(buffer, " }}")?;
        }
        for c in self.constraints() {
            for m in c.methods() {
                for i in m.inputs() {
                    writeln!(
                        buffer,
                        " {} -> {}_{} [style=dotted];",
                        Self::dot_id(&index_to_name, *i),
                        c.name(),
                        m.name()
                    )?;
                }
                for o in m.outputs() {
                    writeln!(
                        buffer,
                        " {}_{} -> {};",
                        c.name(),
                        m.name(),
                        Self::dot_id(&index_to_name, *o)
                    )?;
                }
            }
        }
        writeln!(buffer, "}}")?;
        Ok(buffer)
    }

    /// Renders the component as a simple undirected Graphviz DOT graph connecting each
    /// constraint to the variables it involves. Unnamed variables are referred to by their index.
    ///
    /// # Errors
    /// Returns [`fmt::Error`] if writing to the output buffer fails.
    pub fn to_dot_simple(&self) -> Result<String, fmt::Error> {
        let index_to_name = self.index_to_name();
        let mut buffer = String::new();
        writeln!(buffer, "strict graph {} {{", self.name())?;
        for vi in 0..self.n_variables() {
            if let Some(name) = index_to_name.get(&vi) {
                writeln!(buffer, " {} [shape=box];", name)?;
            }
        }
        for c in self.constraints() {
            for v in c.variables() {
                writeln!(
                    buffer,
                    " {} -- {};",
                    c.name(),
                    Self::dot_id(&index_to_name, *v)
                )?;
            }
        }
        writeln!(buffer, "}}")?;
        Ok(buffer)
    }

    /// Sends every callback an event carrying its variable's current state, tagged with the
    /// current generation.
    fn notify(&self, callbacks: &[FilteredCallback<T, SolveError>]) {
        for (vi, v) in callbacks.iter().enumerate() {
            let va = &self.variables[vi];
            let inner = va.inner().lock().unwrap();
            let event = match inner.state() {
                State::Ready(value) => Event::Ready(value.as_ref()),
                State::Error(errors) => Event::Error(errors),
                State::Pending => Event::Pending,
            };
            v.call(EventWithLocation::new(
                vi,
                GenerationId::new(self.current_generation, self.total_generation),
                event,
            ));
        }
    }

    /// Reverts the variables to the previous committed state, retargets all callbacks to the
    /// resulting generation and notifies every subscriber.
    ///
    /// # Errors
    /// Returns [`NoMoreUndo`] if there is no earlier state.
    pub fn undo(&mut self) -> Result<(), NoMoreUndo> {
        let mut callbacks = self.callbacks.lock().unwrap();
        self.variables.undo()?;
        self.current_generation -= 1;
        self.total_generation += 1;
        for fcb in callbacks.iter_mut() {
            fcb.set_target(GenerationId::new(
                self.current_generation,
                self.total_generation,
            ));
        }
        self.notify(&callbacks);
        Ok(())
    }

    /// Re-applies the most recently undone state, retargets all callbacks to the resulting
    /// generation and notifies every subscriber.
    ///
    /// # Errors
    /// Returns [`NoMoreRedo`] if there is nothing to redo.
    pub fn redo(&mut self) -> Result<(), NoMoreRedo> {
        let mut callbacks = self.callbacks.lock().unwrap();
        self.variables.redo()?;
        self.current_generation += 1;
        self.total_generation += 1;
        for fcb in callbacks.iter_mut() {
            fcb.set_target(GenerationId::new(
                self.current_generation,
                self.total_generation,
            ));
        }
        self.notify(&callbacks);
        Ok(())
    }

    /// Like [`ComponentSpec::new`], but with an undo history limited to `limit` states.
    pub fn new_with_undo_limit(
        name: String,
        values: Vec<T>,
        constraints: Vec<Constraint<T>>,
        limit: usize,
    ) -> Self {
        let n_variables = values.len();
        let values = Variables::new_with_limit(values.into_iter().map_into().collect(), limit);
        Self {
            name,
            variables: values,
            callbacks: Arc::new(Mutex::new(vec![FilteredCallback::new(); n_variables])),
            constraints,
            ranker: VariableRanker::of_size(n_variables),
            updated_since_last_solve: (0..n_variables).collect(),
            n_ready: n_variables,
            ..Default::default()
        }
    }

    /// Changes the undo history limit.
    pub fn set_undo_limit(&mut self, limit: UndoLimit) {
        self.variables.set_limit(limit);
    }

    /// Activates the first constraint called `name`.
    ///
    /// # Errors
    /// Returns [`NoSuchConstraint`] if there is no such constraint.
    pub fn enable_constraint<'a>(&mut self, name: &'a str) -> Result<(), NoSuchConstraint<'a>> {
        self.constraint_mut(name).map(|c| c.set_active(true))
    }

    /// Deactivates the first constraint called `name`.
    ///
    /// # Errors
    /// Returns [`NoSuchConstraint`] if there is no such constraint.
    pub fn disable_constraint<'a>(&mut self, name: &'a str) -> Result<(), NoSuchConstraint<'a>> {
        self.constraint_mut(name).map(|c| c.set_active(false))
    }
}
impl<T> ComponentSpec for Component<T> {
    type Value = Activation<T>;
    type Constraint = Constraint<T>;

    /// Builds a component from initial `values` and `constraints`.
    ///
    /// Every variable starts out marked as updated and gets its own (empty) filtered callback;
    /// the ranker is sized to the number of variables.
    fn new(
        name: String,
        values: Vec<impl Into<Self::Value>>,
        constraints: Vec<Self::Constraint>,
    ) -> Self {
        let count = values.len();
        Self {
            name,
            variables: Variables::new(values.into_iter().map(Into::into).collect()),
            callbacks: Arc::new(Mutex::new(vec![FilteredCallback::new(); count])),
            constraints,
            ranker: VariableRanker::of_size(count),
            updated_since_last_solve: (0..count).collect(),
            n_ready: count,
            ..Default::default()
        }
    }

    /// Number of variables in the component.
    fn n_variables(&self) -> usize {
        self.variables.n_variables()
    }

    /// All constraints, in insertion order.
    fn constraints(&self) -> &[Self::Constraint] {
        self.constraints.as_slice()
    }

    /// Mutable access to the constraint list.
    fn constraints_mut(&mut self) -> &mut Vec<Self::Constraint> {
        &mut self.constraints
    }

    /// Appends `constraint`.
    fn add_constraint(&mut self, constraint: Self::Constraint) {
        self.constraints.push(constraint);
    }

    /// Removes and returns the most recently added constraint, if any.
    fn pop_constraint(&mut self) -> Option<Self::Constraint> {
        self.constraints.pop()
    }

    /// Removes and returns the constraint at `idx` (panics if out of bounds, like `Vec::remove`).
    fn remove_constraint(&mut self, idx: usize) -> Self::Constraint {
        self.constraints.remove(idx)
    }

    /// Current variable ranking; delegates straight to the ranker, exactly like the inherent
    /// `Component::ranking`.
    fn ranking(&self) -> Vec<usize> {
        self.ranker.ranking()
    }
}
impl<T> Index<&str> for Component<T> {
    type Output = Constraint<T>;

    /// Returns the first constraint whose name equals `index`.
    ///
    /// # Panics
    /// Panics if no constraint has that name.
    fn index(&self, index: &str) -> &Self::Output {
        self.constraints
            .iter()
            .find(|candidate| candidate.name() == index)
            .unwrap_or_else(|| panic!("No constraint named {}", index))
    }
}
impl<T> IndexMut<&str> for Component<T> {
    /// Returns the first constraint whose name equals `index`, mutably.
    ///
    /// # Panics
    /// Panics if no constraint has that name.
    fn index_mut(&mut self, index: &str) -> &mut Self::Output {
        self.constraints
            .iter_mut()
            .find(|candidate| candidate.name() == index)
            .unwrap_or_else(|| panic!("No constraint named {}", index))
    }
}
impl<T: PartialEq> PartialEq for Component<T> {
    /// Two components are equal when their name, name map, variables, constraints, ranker and
    /// pending-update set match. Callbacks, `n_ready` and the generation counters are ignored.
    fn eq(&self, other: &Self) -> bool {
        // Tuple equality compares element by element, left to right, and stops at the first
        // mismatch — the same order and short-circuiting as a chain of `&&`.
        let mine = (
            &self.name,
            &self.name_to_index,
            &self.variables,
            &self.constraints,
            &self.ranker,
            &self.updated_since_last_solve,
        );
        let theirs = (
            &other.name,
            &other.name_to_index,
            &other.variables,
            &other.constraints,
            &other.ranker,
            &other.updated_since_last_solve,
        );
        mine == theirs
    }
}
#[cfg(test)]
mod tests {
    use super::Component;
    use crate::{
        component, examples::components::numbers::sum, model::activation::Activation, ret,
        thread::DummyPool,
    };

    /// Asserts that `actual` equals the activations built from `expected`, in order.
    #[track_caller]
    fn assert_values(actual: Vec<&Activation<i32>>, expected: &[i32]) {
        let expected: Vec<Activation<i32>> =
            expected.iter().copied().map(Activation::from).collect();
        assert_eq!(actual, expected.iter().collect::<Vec<_>>());
    }

    #[test]
    fn solve_sum() {
        let mut component: Component<i32> = sum();
        assert_values(component.values(), &[0, 0, 0]);

        component.set_variable("a", 3).unwrap();
        component.par_update(&mut DummyPool).unwrap();
        assert_values(component.values(), &[3, 0, 3]);

        component.set_variable("c", 2).unwrap();
        component.update().unwrap();
        assert_values(component.values(), &[3, -1, 2]);
    }

    #[test]
    fn pin_unpin() {
        let mut component: Component<i32> = sum();

        let first = 3;
        component.pin("c").unwrap();
        component.set_variable("a", first).unwrap();
        component.update().unwrap();
        assert_values(component.values(), &[first, -first, 0]);

        let second = 5;
        component.unpin("c").unwrap();
        component.set_variable("a", second).unwrap();
        component.update().unwrap();
        assert_values(component.values(), &[second, -first, second - first]);
    }

    #[test]
    fn undo_redo_works() {
        let mut component: Component<i32> = sum();
        component.set_variable("a", 3).unwrap();
        component.update().unwrap();
        assert_values(component.values(), &[3, 0, 3]);

        assert_eq!(component.undo(), Ok(()));
        assert_values(component.values(), &[0, 0, 0]);

        assert_eq!(component.redo(), Ok(()));
        assert_values(component.values(), &[3, 0, 3]);
    }

    #[test]
    fn enable_disable_constraint() {
        let mut component = component! {
            component A {
                let a: i32, b: i32, c: i32, d: i32;
                constraint Ab {
                    right(a: &i32) -> [b] = ret![*a];
                    left(b: &i32) -> [a] = ret![*b];
                }
                constraint Bc {
                    right(b: &i32) -> [c] = ret![*b];
                    left(c: &i32) -> [b] = ret![*c];
                }
                constraint Cd {
                    right(c: &i32) -> [d] = ret![*c];
                    left(d: &i32) -> [c] = ret![*d];
                }
            }
        };

        component.set_variable("a", 1).unwrap();
        component.update().unwrap();
        assert_values(component.values(), &[1, 1, 1, 1]);

        // With Bc disabled the chain splits into {a, b} and {c, d}.
        component.disable_constraint("Bc").unwrap();
        component.set_variable("b", 2).unwrap();
        component.set_variable("d", 3).unwrap();
        component.update().unwrap();
        assert_values(component.values(), &[2, 2, 3, 3]);

        // Re-enabling Bc joins them again; the most recent edit (d = 3) wins.
        component.enable_constraint("Bc").unwrap();
        component.update().unwrap();
        assert_values(component.values(), &[3, 3, 3, 3]);
    }
}