use super::generation_id::GenerationId;
use crate::{
event::{Event, EventWithLocation},
model::activation::{Activation, ActivationInner},
planner::{MethodFailure, MethodFunction, MethodResult, MethodSpec, Vertex},
solver::{Reason, SolveError},
thread::ThreadPool,
};
use std::{
fmt::Debug,
sync::{Arc, Mutex},
};
/// A named method that computes new values for the indices in `outputs`
/// from the values at the indices in `inputs`.
///
/// Indices presumably refer to variables of the enclosing component —
/// NOTE(review): confirm against the planner.
pub struct Method<T> {
    /// Set to `true` by `Vertex::stay`; `MethodSpec::new` leaves it `false`.
    is_stay: bool,
    /// Name used in `Debug` output and in error reports.
    name: String,
    /// Indices this method reads.
    inputs: Vec<usize>,
    /// Indices this method writes; results are paired with them by position.
    outputs: Vec<usize>,
    /// The function that produces the outputs from the inputs.
    apply: MethodFunction<T>,
}
// Written by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, which the method itself does not need.
impl<T> Clone for Method<T> {
    /// Copies every field; `apply` is cloned through its own `Clone` impl.
    fn clone(&self) -> Self {
        let Self {
            is_stay,
            name,
            inputs,
            outputs,
            apply,
        } = self;
        Self {
            is_stay: *is_stay,
            name: name.clone(),
            inputs: inputs.clone(),
            outputs: outputs.clone(),
            apply: apply.clone(),
        }
    }
}
/// Methods compare equal when their stay flag, name, input indices and
/// output indices all match.
///
/// The `apply` function is intentionally ignored, because function values
/// have no meaningful equality. This impl is hand-written so that no
/// `T: PartialEq` bound is required.
impl<T> PartialEq for Method<T> {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.is_stay, &self.name, &self.inputs, &self.outputs);
        let rhs = (other.is_stay, &other.name, &other.inputs, &other.outputs);
        lhs == rhs
    }
}

/// Every compared field has total equality, so `Eq` holds.
impl<T> Eq for Method<T> {}
impl<T> Debug for Method<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}({:?} -> {:?})", self.name, self.inputs, self.outputs)
}
}
impl<T> MethodSpec for Method<T> {
    type Arg = T;

    /// Creates an ordinary (non-stay) method from its name, the indices it
    /// reads and writes, and the function that computes it.
    fn new(
        name: String,
        inputs: Vec<usize>,
        outputs: Vec<usize>,
        apply: MethodFunction<T>,
    ) -> Self {
        Self {
            name,
            inputs,
            outputs,
            apply,
            is_stay: false,
        }
    }

    /// Runs the method function on `input` and validates the argument and
    /// result counts.
    ///
    /// Returns `WrongInputCount(expected, actual)` without calling the
    /// function when `input` has the wrong length. Errors from the function
    /// itself are propagated with `?`. Returns
    /// `WrongOutputCount(expected, actual)` when the function yields the
    /// wrong number of values.
    fn apply(&self, input: Vec<Arc<T>>) -> MethodResult<Arc<T>> {
        let expected_in = self.n_inputs();
        let actual_in = input.len();
        if expected_in != actual_in {
            return Err(MethodFailure::WrongInputCount(expected_in, actual_in));
        }

        let produced = (self.apply)(input)?;

        let (expected_out, actual_out) = (self.n_outputs(), produced.len());
        if expected_out == actual_out {
            Ok(produced)
        } else {
            Err(MethodFailure::WrongOutputCount(expected_out, actual_out))
        }
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Lockable state shared between a method's worker job and the
/// [`Activation`]s handed back to callers of `Method::activate`.
pub type SharedVariableActivationInner<T> = Arc<Mutex<ActivationInner<T>>>;

/// Propagates `errors` to every output of a failed method activation.
///
/// The function works in three steps:
///
/// 1. It logs the errors at error level.
/// 2. It reports them to `general_callback` as [`Event::Error`], once per
///    index in `output_indices`.
/// 3. It stores a copy of the errors in every entry of `shared_states`.
///
/// `shared_states` is taken as a slice, so callers holding an
/// `Arc<Vec<_>>` can pass `&arc` and let deref coercion do the rest.
///
/// # Panics
///
/// Panics if one of the `shared_states` mutexes is poisoned.
fn handle_error<T>(
    output_indices: &[usize],
    shared_states: &[SharedVariableActivationInner<T>],
    general_callback: &(impl Fn(EventWithLocation<'_, T, SolveError>) + Send + 'static),
    generation: GenerationId,
    errors: Vec<SolveError>,
) {
    log::error!("{:?}", errors);
    for &o in output_indices {
        general_callback(EventWithLocation::new(o, generation, Event::Error(&errors)));
    }
    for shared_state in shared_states {
        shared_state.lock().unwrap().set_error(errors.clone());
    }
}
impl<T> Method<T> {
pub(crate) fn activate(
&self,
inputs: Vec<impl Into<Activation<T>>>,
shared_states: Vec<SharedVariableActivationInner<T>>,
location: (String, String),
generation: GenerationId,
pool: &mut impl ThreadPool,
general_callback: impl Fn(EventWithLocation<'_, T, SolveError>) + Send + 'static,
) -> Vec<Activation<T>>
where
T: Send + Sync + 'static + Debug,
Method<T>: Vertex,
{
let inputs: Vec<Activation<T>> = inputs.into_iter().map(|v| v.into()).collect();
let n_inputs = self.n_inputs();
let n_outputs = self.n_outputs();
let output_indices = self.outputs().to_vec();
let m_name = self.name().to_string();
let (component, constraint) = location;
log::trace!("Activating {}", &m_name);
let shared_states = Arc::new(shared_states);
let shared_states_clone = shared_states.clone();
for &o in &output_indices {
general_callback(EventWithLocation::new(o, generation, Event::Pending));
}
let f = self.apply.clone();
let handle = pool
.execute(move || {
let joined_inputs = futures::future::join_all(inputs);
let input_results = futures::executor::block_on(joined_inputs);
let formatted_inputs = format!("{:?}", &input_results);
let mut inputs = Vec::new();
let mut errors = Vec::new();
for state in input_results {
match state {
Ok(value) => inputs.push(value),
Err(es) => errors.extend(es),
}
}
if !errors.is_empty() {
handle_error(
&output_indices,
&shared_states_clone,
&general_callback,
generation,
errors,
);
return;
}
if inputs.len() != n_inputs {
let error = SolveError::new(
component.to_owned(),
constraint.to_owned(),
m_name.clone(),
Reason::MethodFailure(MethodFailure::WrongInputCount(
n_inputs,
inputs.len(),
)),
);
handle_error(
&output_indices,
&shared_states_clone,
&general_callback,
generation,
vec![error],
);
return;
}
let result = f(inputs);
log::trace!("{}({}) = {:?}", m_name, formatted_inputs, result);
match result {
Ok(outputs) => {
if outputs.len() != n_outputs {
let error = SolveError::new(
component.to_owned(),
constraint.to_owned(),
m_name.clone(),
Reason::MethodFailure(MethodFailure::WrongOutputCount(
n_outputs,
outputs.len(),
)),
);
handle_error(
&output_indices,
&shared_states_clone,
&general_callback,
generation,
vec![error],
);
return;
}
for ((st, res), &o) in
shared_states_clone.iter().zip(outputs).zip(&output_indices)
{
general_callback(EventWithLocation::new(
o,
generation,
Event::Ready(&res),
));
let mut shared_state = st.lock().unwrap();
shared_state.set_value_arc(res);
}
}
Err(e) => {
let error = SolveError::new(
component.to_owned(),
constraint.to_owned(),
m_name.clone(),
Reason::MethodFailure(e),
);
handle_error(
&output_indices,
&shared_states_clone,
&general_callback,
generation,
vec![error],
);
}
}
})
.expect("Could not spawn worker");
let values = shared_states
.iter()
.map(|st| Activation {
inner: st.clone(),
producer: Some(handle.clone()),
})
.collect();
values
}
}
impl<T> Vertex for Method<T> {
    /// Indices this method reads.
    fn inputs(&self) -> &[usize] {
        self.inputs.as_slice()
    }

    /// Indices this method writes.
    fn outputs(&self) -> &[usize] {
        self.outputs.as_slice()
    }

    /// Builds the placeholder stay method for `index`, named `_stay_<index>`.
    ///
    /// It reads and writes only `index`. Its function must never be
    /// invoked: calling it panics.
    fn stay(index: usize) -> Self {
        let name = format!("_stay_{}", index);
        let inputs = vec![index];
        let outputs = inputs.clone();
        Self {
            is_stay: true,
            name,
            inputs,
            outputs,
            apply: Arc::new(|_| panic!("stay constraints should not be run")),
        }
    }

    /// Whether this method was built by [`Vertex::stay`].
    fn is_stay(&self) -> bool {
        self.is_stay
    }
}