use std::collections::HashMap;
use std::f64;
use std::io::{self, BufRead, BufReader, Write};
use std::process::{ChildStdin, ChildStdout, Command, Stdio};
use std::sync::mpsc::channel;
use itertools::Itertools;
use polytype::Type;
use workerpool::{Pool, Worker};
use super::super::Task;
use super::{Expression, Language};
/// Error type returned by [`LispEvaluator`] operations.
///
/// Alias for [`io::Error`]: both failures talking to a Racket subprocess and
/// unexpected evaluation output are reported as I/O errors.
pub type LispError = io::Error;
/// Evaluates DSL expressions by rendering them as Lisp source and running
/// them in a pool of `racket` subprocesses.
pub struct LispEvaluator {
/// Primitive name -> Lisp definition table, handed to `Language::lispify`
/// when an expression is rendered.
conversions: HashMap<String, String>,
/// Worker pool; each `Racket` worker owns the pipes of one `racket` process.
pool: Pool<Racket>,
}
impl LispEvaluator {
    /// Creates an evaluator whose conversion table is built from `prims`.
    ///
    /// Each `(name, definition)` pair is stored as owned strings. The table is
    /// passed to `Language::lispify` whenever an expression is rendered.
    pub fn new(prims: Vec<(&str, &str)>) -> Self {
        let conversions = prims
            .into_iter()
            .map(|(name, definition)| (String::from(name), String::from(definition)))
            .collect();
        let pool = Pool::<Racket>::default();
        LispEvaluator { conversions, pool }
    }

    /// Checks whether `expr` produces `output`.
    ///
    /// The expression is rendered with `Language::lispify`. With `Some(input)`,
    /// Racket evaluates `(equal? (<expr> <input>) <output>)`. With `None`, it
    /// evaluates `(equal? <expr> <output>)`. `input` and `output` are spliced in
    /// verbatim, so they must already be valid Racket source text.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// - writing to or reading from the Racket process fails;
    /// - the worker finished without producing a result;
    /// - Racket printed something other than `#t`/`#f`, such as `ERROR` when
    ///   evaluation raised.
    pub fn check(
        &self,
        dsl: &Language,
        expr: &Expression,
        input: Option<&str>,
        output: &str,
    ) -> Result<bool, LispError> {
        let cmd = dsl.lispify(expr, &self.conversions);
        let op = match input {
            Some(inp) => format!("(equal? ({} {}) {})", cmd, inp, output),
            None => format!("(equal? {} {})", cmd, output),
        };
        self.eval_bool(op)
    }

    /// Checks whether `expr` maps every `(input, output)` example correctly.
    ///
    /// All examples are combined into a single `(and (equal? (<expr> i) o) ...)`
    /// query, so one round trip to Racket covers the whole set. With no
    /// examples, the query is `(and )`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`LispEvaluator::check`].
    pub fn check_many(
        &self,
        dsl: &Language,
        expr: &Expression,
        examples: &[(&str, &str)],
    ) -> Result<bool, LispError> {
        let cmd = dsl.lispify(expr, &self.conversions);
        let op = format!(
            "(and {})",
            examples
                .iter()
                .map(|&(i, o)| format!("(equal? ({} {}) {})", cmd, i, o))
                .join(" ")
        );
        self.eval_bool(op)
    }

    /// Runs `op` on a pooled Racket worker and reads the printed line as a boolean.
    ///
    /// A response of `#t` or `#f` is accepted; trailing whitespace, including
    /// `\r\n`, is ignored. Any other response is returned as the error message.
    fn eval_bool(&self, op: String) -> Result<bool, LispError> {
        let (tx, rx) = channel();
        self.pool.execute_to(tx, op);
        // If the worker dies before replying, the sender is dropped without a
        // message. Report that as an error instead of panicking the caller.
        let reply = rx.recv().map_err(|_| {
            io::Error::new(
                io::ErrorKind::Other,
                "racket worker finished without producing a result",
            )
        })?;
        let response = reply?;
        let verdict = match response.trim_end() {
            "#t" => Some(true),
            "#f" => Some(false),
            _ => None,
        };
        verdict.ok_or_else(|| io::Error::new(io::ErrorKind::Other, response))
    }

    /// Builds a task whose oracle accepts an expression iff `check_many`
    /// succeeds on all of `examples`.
    ///
    /// The oracle returns `0.0` on success and `f64::NEG_INFINITY` otherwise.
    /// Evaluation errors count as failure. The observation is an owned copy of
    /// `examples`. The oracle borrows `self`, so the task cannot outlive the
    /// evaluator.
    pub fn make_task<'a>(
        &'a self,
        tp: Type,
        examples: &[(&'a str, &'a str)],
    ) -> Task<'a, Language, Expression, Vec<(String, String)>> {
        let examples: Vec<_> = examples.to_vec();
        let observation: Vec<_> = examples
            .iter()
            .map(|&(inp, out)| (String::from(inp), String::from(out)))
            .collect();
        let oracle = Box::new(move |dsl: &Language, expr: &Expression| -> f64 {
            if self.check_many(dsl, expr, &examples).unwrap_or(false) {
                0f64
            } else {
                f64::NEG_INFINITY
            }
        });
        Task {
            oracle,
            observation,
            tp,
        }
    }

    /// Builds a task whose oracle accepts an expression iff it evaluates,
    /// with no input applied, to something `equal?` to `output`.
    ///
    /// The oracle returns `0.0` on success and `f64::NEG_INFINITY` otherwise.
    /// Evaluation errors count as failure. The observation is an owned copy of
    /// `output`.
    pub fn make_task_output_only<'a>(
        &'a self,
        tp: Type,
        output: &'a str,
    ) -> Task<'a, Language, Expression, String> {
        let observation = String::from(output);
        let oracle = Box::new(move |dsl: &Language, expr: &Expression| -> f64 {
            if self.check(dsl, expr, None, output).unwrap_or(false) {
                0f64
            } else {
                f64::NEG_INFINITY
            }
        });
        Task {
            oracle,
            observation,
            tp,
        }
    }
}
impl Default for LispEvaluator {
fn default() -> Self {
let conversions = HashMap::new();
let pool = Pool::<Racket>::default();
LispEvaluator { conversions, pool }
}
}
/// Pool worker that owns the stdin/stdout pipes of a single `racket`
/// subprocess, which is spawned in this type's `Default` impl.
struct Racket {
/// Pipe used to send one expression (plus a newline) per request.
stdin: ChildStdin,
/// Buffered pipe from which one line of output is read per request.
stdout: BufReader<ChildStdout>,
}
impl Default for Racket {
    /// Spawns a `racket` process that runs a read-eval-print loop over its stdin.
    ///
    /// For each datum read, the loop prints the evaluated result followed by a
    /// newline, then flushes. If reading or evaluation raises `exn:fail`, it
    /// prints `ERROR` instead. When stdin reaches end-of-file, the loop calls
    /// `(exit 0)`. This lets the child terminate once the worker closes its end
    /// of the pipe, instead of evaluating and printing the eof object. The
    /// child's stderr is discarded.
    ///
    /// # Panics
    ///
    /// Panics if `racket` cannot be spawned or its pipes cannot be obtained.
    /// `Default` offers no way to report the failure.
    fn default() -> Self {
        let child = Command::new("racket")
            .arg("-e")
            .arg(
                "(let lp ()
  (with-handlers ([exn:fail? (λ (exn) (displayln \"ERROR\"))])
    (let ([e (read)])
      (when (eof-object? e) (exit 0))
      (displayln (eval e))))
  (flush-output)
  (lp))",
            )
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::null())
            .spawn()
            .expect("could not spawn racket process");
        let stdin = child.stdin.expect("connect to racket stdin");
        let stdout = child.stdout.expect("connect to racket stdout");
        let stdout = BufReader::new(stdout);
        Racket { stdin, stdout }
    }
}
impl Worker for Racket {
type Input = String;
type Output = Result<String, LispError>;
fn execute(&mut self, op: Self::Input) -> Self::Output {
self.stdin.write_all(op.as_bytes())?;
self.stdin.write_all(b"\n")?;
self.stdin.flush()?;
let mut s = String::new();
self.stdout.read_line(&mut s)?;
Ok(s)
}
}