use crate::epi::channel::{MessageReceiver, MessageSender};
use crate::epi::solver::SolverMessage;
use crate::problem::ProblemSpec;
use crate::registry::FactoryRegistry;
use crate::rng::{ArcRng, Rng as _};
use crate::solver::{Solver, SolverFactory, SolverRecipe, SolverSpec};
use crate::trial::{EvaluatedTrial, IdGen, NextTrial};
use crate::{Error, ErrorKind, Result};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::PathBuf;
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
use std::sync::atomic::{self, AtomicU64};
use std::sync::{Arc, Mutex};
use std::thread_local;
use structopt::StructOpt;
thread_local! {
    /// Per-thread cache of external-program factories, keyed by the bytes returned from
    /// `ExternalProgramSolverRecipe::cache_key`.
    static FACTORIES: RefCell<HashMap<Vec<u8>, ExternalProgramSolverFactory>> =
        RefCell::default();
}
// Recipe describing a solver implemented by an external program.
//
// The program at `path` is spawned with `args` and communicates with this process
// through its stdin/stdout pipes.
//
// NOTE: plain `//` comments are used deliberately here; with `StructOpt`, `///` doc
// comments would be turned into CLI help text.
#[derive(Debug, Clone, PartialEq, Eq, Hash, StructOpt, Serialize, Deserialize)]
#[structopt(rename_all = "kebab-case")]
pub struct ExternalProgramSolverRecipe {
    // Path of the executable to spawn.
    pub path: PathBuf,
    // Command-line arguments passed to the executable.
    pub args: Vec<String>,
}
impl ExternalProgramSolverRecipe {
fn create_new_factory(
&self,
_registry: &FactoryRegistry,
) -> Result<ExternalProgramSolverFactory> {
let mut child = track!(Command::new(&self.path)
.args(&self.args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.map_err(Error::from))?;
let stdin = track_assert_some!(child.stdin.take(), ErrorKind::IoError);
let stdout = track_assert_some!(child.stdout.take(), ErrorKind::IoError);
let tx = MessageSender::new(stdin);
let mut rx = MessageReceiver::new(stdout);
let spec = match track!(rx.recv())? {
SolverMessage::SolverSpecCast { spec } => spec,
m => track_panic!(ErrorKind::InvalidInput, "Unexpected message: {:?}", m),
};
Ok(ExternalProgramSolverFactory(Arc::new(
ExternalProgramSolverFactoryInner {
spec,
child,
tx: Arc::new(Mutex::new(tx)),
rx: Arc::new(Mutex::new(rx)),
next_solver_id: AtomicU64::new(0),
},
)))
}
fn cache_key(&self) -> Vec<u8> {
let mut hasher = Sha256::new();
hasher.update(&*self.path.to_string_lossy());
for arg in &self.args {
hasher.update(arg.as_bytes());
}
hasher.finalize().to_vec()
}
}
impl SolverRecipe for ExternalProgramSolverRecipe {
    type Factory = ExternalProgramSolverFactory;

    /// Returns the factory for this recipe, spawning the external program on first use.
    ///
    /// Factories are cached per thread in `FACTORIES`, keyed by `cache_key()`. Repeated
    /// calls with an equal recipe on the same thread therefore share one child process.
    ///
    /// # Errors
    ///
    /// Propagates any error from spawning the program or from its handshake. Nothing
    /// is cached in that case.
    fn create_factory(&self, registry: &FactoryRegistry) -> Result<Self::Factory> {
        let key = self.cache_key();
        if let Some(factory) = FACTORIES.with(|f| f.borrow().get(&key).cloned()) {
            return Ok(factory);
        }
        // Spawn without holding the `RefCell` borrow, so the cache is never mutably
        // borrowed across the child-process handshake.
        let factory = track!(self.create_new_factory(registry))?;
        Ok(FACTORIES.with(|f| f.borrow_mut().entry(key).or_insert(factory).clone()))
    }
}
/// Handle to a spawned external solver program.
///
/// Cloning only bumps the `Arc` reference count. All clones share the same child
/// process and message channels.
#[derive(Debug, Clone)]
pub struct ExternalProgramSolverFactory(Arc<ExternalProgramSolverFactoryInner>);
impl SolverFactory for ExternalProgramSolverFactory {
type Solver = ExternalProgramSolver;
fn specification(&self) -> Result<SolverSpec> {
track!(self.0.specification())
}
fn create_solver(&self, rng: ArcRng, problem: &ProblemSpec) -> Result<Self::Solver> {
track!(self.0.create_solver(rng, problem))
}
}
/// Shared state for one spawned external solver program.
#[derive(Debug)]
struct ExternalProgramSolverFactoryInner {
    /// Specification taken from the program's initial `SolverSpecCast` message.
    spec: SolverSpec,
    /// The spawned process. A kill is attempted when this struct is dropped.
    child: Child,
    /// Message sender over the child's stdin, shared with every solver created here.
    tx: Arc<Mutex<MessageSender<SolverMessage, ChildStdin>>>,
    /// Message receiver over the child's stdout, shared with every solver created here.
    rx: Arc<Mutex<MessageReceiver<SolverMessage, ChildStdout>>>,
    /// Counter that hands out a distinct id to each solver created by this factory.
    next_solver_id: AtomicU64,
}
impl SolverFactory for ExternalProgramSolverFactoryInner {
type Solver = ExternalProgramSolver;
fn specification(&self) -> Result<SolverSpec> {
Ok(self.spec.clone())
}
fn create_solver(&self, mut rng: ArcRng, problem: &ProblemSpec) -> Result<Self::Solver> {
let solver_id = self.next_solver_id.fetch_add(1, atomic::Ordering::SeqCst);
let m = SolverMessage::CreateSolverCast {
solver_id,
random_seed: rng.gen(),
problem: problem.clone(),
};
let mut tx = track!(self.tx.lock().map_err(Error::from))?;
track!(tx.send(&m))?;
Ok(ExternalProgramSolver {
solver_id,
tx: Arc::clone(&self.tx),
rx: Arc::clone(&self.rx),
})
}
}
impl Drop for ExternalProgramSolverFactoryInner {
    /// Terminates the external program.
    ///
    /// The child is reaped with `wait` only after `kill` succeeds. This is presumably
    /// so that `drop` cannot block on a process that is still running.
    fn drop(&mut self) {
        if self.child.kill().is_err() {
            return;
        }
        let _ = self.child.wait();
    }
}
/// One solver instance living inside a (possibly shared) external program.
///
/// Every message this handle sends carries `solver_id`, which tells the external
/// program which of its solver instances a request is for.
#[derive(Debug)]
pub struct ExternalProgramSolver {
    /// Id assigned by the factory when this solver was created.
    solver_id: u64,
    /// Sender over the child's stdin, shared with sibling solvers of the same factory.
    tx: Arc<Mutex<MessageSender<SolverMessage, ChildStdin>>>,
    /// Receiver over the child's stdout, shared with sibling solvers of the same factory.
    rx: Arc<Mutex<MessageReceiver<SolverMessage, ChildStdout>>>,
}
impl Solver for ExternalProgramSolver {
fn ask(&mut self, idg: &mut IdGen) -> Result<NextTrial> {
let m = SolverMessage::AskCall {
solver_id: self.solver_id,
next_trial_id: idg.peek_id().get(),
};
let mut tx = track!(self.tx.lock().map_err(Error::from))?;
track!(tx.send(&m))?;
let mut rx = track!(self.rx.lock().map_err(Error::from))?;
match track!(rx.recv())? {
SolverMessage::AskReply {
trial,
next_trial_id,
} => {
track_assert!(
idg.peek_id().get() <= next_trial_id,
ErrorKind::InvalidInput; idg.peek_id().get(), next_trial_id
);
while idg.peek_id().get() < next_trial_id {
idg.generate();
}
Ok(trial)
}
SolverMessage::ErrorReply { kind, message } => {
if let Some(message) = message {
track_panic!(kind, "{}", message);
} else {
track_panic!(kind);
}
}
m => {
track_panic!(ErrorKind::Other, "Unexpected message: {:?}", m);
}
}
}
fn tell(&mut self, trial: EvaluatedTrial) -> Result<()> {
let m = SolverMessage::TellCall {
solver_id: self.solver_id,
trial,
};
let mut tx = track!(self.tx.lock().map_err(Error::from))?;
track!(tx.send(&m))?;
let mut rx = track!(self.rx.lock().map_err(Error::from))?;
match track!(rx.recv())? {
SolverMessage::TellReply => Ok(()),
SolverMessage::ErrorReply { kind, message } => {
if let Some(message) = message {
track_panic!(kind, "{}", message);
} else {
track_panic!(kind);
}
}
m => {
track_panic!(ErrorKind::Other, "Unexpected message: {:?}", m);
}
}
}
}
impl Drop for ExternalProgramSolver {
    /// Tells the external program to discard this solver instance (best effort).
    ///
    /// A poisoned lock or a failed send is silently ignored, and no reply is read.
    fn drop(&mut self) {
        let notice = SolverMessage::DropSolverCast {
            solver_id: self.solver_id,
        };
        let _ = self.tx.lock().map(|mut sender| sender.send(&notice));
    }
}