use itertools::Itertools;
use polytype::{ptp, tp, Type, TypeScheme};
use rand::{
distributions::{Distribution, WeightedIndex},
Rng,
};
use std::iter;
use std::sync::Arc;
use crate::lambda::{Evaluator as EvaluatorT, Expression, Language};
use crate::Task;
/// Builds the DSL used for circuit synthesis.
///
/// The language is a uniform [`Language`] with exactly one primitive, `nand`, of type
/// `bool -> bool -> bool`. Every target circuit must be expressed in terms of it.
pub fn dsl() -> Language {
    let nand_type = ptp!(@arrow[tp!(bool), tp!(bool), tp!(bool)]);
    let primitives = vec![("nand", nand_type)];
    Language::uniform(primitives)
}
/// Value domain of circuit programs: every value is a single boolean.
pub type Space = bool;

/// Primitive evaluator for the language returned by `dsl`; it only implements `nand`.
#[derive(Copy, Clone)]
pub struct Evaluator;

impl EvaluatorT for Evaluator {
    type Space = Space;
    /// Evaluation never returns `Err`; an unknown primitive name panics instead.
    type Error = ();

    /// Applies `primitive` to its already-evaluated arguments `inp`.
    ///
    /// # Panics
    ///
    /// Panics if `primitive` is not `"nand"`, or if `inp` has fewer than two elements.
    fn evaluate(&self, primitive: &str, inp: &[Self::Space]) -> Result<Self::Space, Self::Error> {
        if primitive == "nand" {
            // Both operands are read up front, mirroring a strict (non-short-circuit) NAND.
            let (lhs, rhs) = (inp[0], inp[1]);
            Ok(!(lhs & rhs))
        } else {
            unreachable!()
        }
    }
}
/// Samples `count` circuit tasks using the default shape and gate distributions.
///
/// This is a convenience wrapper around `make_tasks_advanced`. Index `i` of each
/// weight array is the relative weight of drawing `i + 1` inputs (or gates). The defaults
/// give circuits with 1–6 inputs and 1–3 gates, built from NOT/AND/OR/MUX2 (MUX4 disabled).
pub fn make_tasks<R: Rng>(
    rng: &mut R,
    count: u32,
) -> Vec<impl Task<[bool], Representation = Language, Expression = Expression>> {
    const INPUT_COUNT_WEIGHTS: [u32; 8] = [1, 2, 3, 4, 4, 4, 0, 0];
    const GATE_COUNT_WEIGHTS: [u32; 8] = [1, 2, 2, 0, 0, 0, 0, 0];
    // Relative weights of the gate kinds: NOT, AND, OR, MUX2, MUX4.
    let (not_w, and_w, or_w, mux2_w, mux4_w) = (1, 2, 2, 4, 0);
    make_tasks_advanced(
        rng,
        count,
        INPUT_COUNT_WEIGHTS,
        GATE_COUNT_WEIGHTS,
        not_w,
        and_w,
        or_w,
        mux2_w,
        mux4_w,
    )
}
/// Samples `count` random circuit tasks with fully configurable distributions.
///
/// Each task is produced by:
/// 1. Drawing a number of inputs and a number of gates.
///    - `1 + i` inputs are drawn with weight `n_input_weights[i]`.
///    - `1 + g` gates are drawn with weight `n_gate_weights[g]`.
///    - Both are redrawn while `n_inputs / n_gates >= 3` (integer division).
/// 2. Sampling a connected random circuit of that shape. Gate kinds are chosen with the
///    relative weights `gate_not`, `gate_and`, `gate_or`, `gate_mux2`, `gate_mux4`.
/// 3. Evaluating the circuit on every input assignment. The resulting truth table becomes
///    the task's expected outputs.
///
/// # Panics
///
/// - If any weight set is rejected by [`WeightedIndex::new`] (e.g. all weights zero).
/// - If `count > 0` and no positively weighted (inputs, gates) pair satisfies
///   `n_inputs / n_gates < 3`. In that case step 1 could never terminate.
// Each gate kind has its own weight parameter; bundling them would break the public signature.
#[allow(clippy::too_many_arguments)]
pub fn make_tasks_advanced<R: Rng>(
    rng: &mut R,
    count: u32,
    n_input_weights: [u32; 8],
    n_gate_weights: [u32; 8],
    gate_not: u32,
    gate_and: u32,
    gate_or: u32,
    gate_mux2: u32,
    gate_mux4: u32,
) -> Vec<impl Task<[bool], Representation = Language, Expression = Expression>> {
    let n_input_distribution =
        WeightedIndex::new(n_input_weights).expect("invalid weights for number of circuit inputs");
    let n_gate_distribution =
        WeightedIndex::new(n_gate_weights).expect("invalid weights for number of circuit gates");
    let gate_weights = WeightedIndex::new([gate_not, gate_and, gate_or, gate_mux2, gate_mux4])
        .expect("invalid weights for circuit gates");

    // The rejection loop below redraws until `n_inputs / n_gates < 3`. Only indices with a
    // positive weight can ever be drawn, so if none of those combinations is admissible the
    // loop would spin forever; fail loudly instead. This check consumes no randomness.
    let admissible_shape_exists = n_input_weights
        .iter()
        .enumerate()
        .filter(|&(_, &w)| w > 0)
        .any(|(i, _)| {
            n_gate_weights
                .iter()
                .enumerate()
                .any(|(g, &w)| w > 0 && (1 + i) / (1 + g) < 3)
        });
    assert!(
        count == 0 || admissible_shape_exists,
        "no positively weighted (inputs, gates) pair satisfies n_inputs / n_gates < 3"
    );

    (0..count)
        .map(move |_| {
            let mut n_inputs = 1 + n_input_distribution.sample(rng);
            let mut n_gates = 1 + n_gate_distribution.sample(rng);
            while n_inputs / n_gates >= 3 {
                n_inputs = 1 + n_input_distribution.sample(rng);
                n_gates = 1 + n_gate_distribution.sample(rng);
            }
            // NOTE(review): `gates::Circuit::new` retries until it finds a connected circuit;
            // gate weights that cannot produce one for this shape would still loop indefinitely.
            // `n_inputs` is at most 8 (the weight array length), so the cast cannot truncate.
            let circuit = gates::Circuit::new(rng, &gate_weights, n_inputs as u32, n_gates);
            // One output per input assignment, in `multi_cartesian_product` order; the oracle
            // enumerates assignments the same way.
            let outputs: Vec<_> = iter::repeat(vec![false, true])
                .take(n_inputs)
                .multi_cartesian_product()
                .map(|ins| circuit.eval(&ins))
                .collect();
            CircuitTask::new(n_inputs, outputs)
        })
        .collect()
}
/// A synthesis task: find a program reproducing a given boolean truth table.
struct CircuitTask {
    /// Number of boolean arguments the target function takes.
    n_inputs: usize,
    /// Target output for every input assignment.
    ///
    /// The order is the one produced by `multi_cartesian_product` over `[false, true]`
    /// for each input.
    expected_outputs: Vec<bool>,
    /// Request type built from `n_inputs + 1` copies of `bool`: the arguments, then the return type.
    tp: TypeScheme,
}

impl CircuitTask {
    /// Creates a task for an `n_inputs`-ary boolean function with the given truth table.
    fn new(n_inputs: usize, expected_outputs: Vec<bool>) -> Self {
        // Argument types followed by the return type, all `bool`.
        let signature = vec![tp!(bool); n_inputs + 1];
        Self {
            tp: TypeScheme::Monotype(Type::from(signature)),
            n_inputs,
            expected_outputs,
        }
    }
}
impl Task<[bool]> for CircuitTask {
    type Representation = Language;
    type Expression = Expression;

    /// Scores `expr` against the truth table.
    ///
    /// Returns `0.0` if `expr` evaluates without error to the expected output on every input
    /// assignment. Otherwise returns negative infinity, stopping at the first mismatch or
    /// evaluation error.
    fn oracle(&self, dsl: &Self::Representation, expr: &Self::Expression) -> f64 {
        let evaluator = Arc::new(Evaluator);
        // Same row order that was used to build `expected_outputs`.
        let assignments = iter::repeat(vec![false, true])
            .take(self.n_inputs)
            .multi_cartesian_product();
        for (inps, &expected) in assignments.zip(&self.expected_outputs) {
            match dsl.eval_arc(expr, &evaluator, &inps) {
                Ok(actual) if actual == expected => {}
                _ => return f64::NEG_INFINITY,
            }
        }
        0f64
    }

    /// The request type constructed in `CircuitTask::new`.
    fn tp(&self) -> &TypeScheme {
        &self.tp
    }

    /// The expected outputs column of the truth table.
    fn observation(&self) -> &[bool] {
        self.expected_outputs.as_slice()
    }
}
/// Random feed-forward boolean circuits used to generate target truth tables.
mod gates {
    use rand::{
        distributions::{Distribution, WeightedIndex},
        seq::index::sample,
        Rng,
    };

    /// Gate kinds indexed by the gate distribution.
    ///
    /// Index `i` drawn from the distribution selects `GATE_CHOICES[i]`. This matches the
    /// order of the weights `[not, and, or, mux2, mux4]`.
    const GATE_CHOICES: [Gate; 5] = [Gate::Not, Gate::And, Gate::Or, Gate::Mux2, Gate::Mux4];

    /// A primitive gate in a random circuit.
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    pub enum Gate {
        Not,
        And,
        Or,
        Mux2,
        Mux4,
    }

    impl Gate {
        /// Number of lanes this gate reads.
        fn n_inputs(self) -> u32 {
            match self {
                Gate::Not => 1,
                Gate::And | Gate::Or => 2,
                Gate::Mux2 => 3,
                Gate::Mux4 => 6,
            }
        }

        /// Computes the gate's output from its input values (`inp.len() == self.n_inputs()`).
        fn eval(self, inp: &[bool]) -> bool {
            match self {
                Gate::Not => !inp[0],
                Gate::And => inp[0] & inp[1],
                Gate::Or => inp[0] | inp[1],
                // `inp[2]` selects data line 1 when true, data line 0 when false.
                Gate::Mux2 => {
                    let data = [inp[0], inp[1]];
                    data[usize::from(inp[2])]
                }
                // `inp[5]` is the high select bit and `inp[4]` the low one.
                Gate::Mux4 => {
                    let data = [inp[0], inp[1], inp[2], inp[3]];
                    let select = (usize::from(inp[5]) << 1) | usize::from(inp[4]);
                    data[select]
                }
            }
        }
    }

    /// A circuit over lanes.
    ///
    /// Lanes `0..n_inputs` hold the inputs. Each operation appends one lane holding its
    /// gate's output. An operation's argument list names the lanes it reads.
    #[derive(Debug, PartialEq, Eq)]
    pub struct Circuit {
        n_inputs: u32,
        operations: Vec<(Gate, Vec<u32>)>,
    }

    impl Circuit {
        /// Samples random circuits with `n_gates` gates until one is connected, and returns it.
        ///
        /// "Connected" means every lane except the output feeds some gate.
        pub fn new<T: Rng>(
            rng: &mut T,
            gate_distribution: &WeightedIndex<u32>,
            n_inputs: u32,
            n_gates: usize,
        ) -> Self {
            loop {
                let candidate = Circuit {
                    n_inputs,
                    operations: Self::random_operations(rng, gate_distribution, n_inputs, n_gates),
                };
                if candidate.is_connected() {
                    return candidate;
                }
            }
        }

        /// Draws `n_gates` gates, each wired to distinct, already-existing lanes.
        fn random_operations<T: Rng>(
            rng: &mut T,
            gate_distribution: &WeightedIndex<u32>,
            n_inputs: u32,
            n_gates: usize,
        ) -> Vec<(Gate, Vec<u32>)> {
            let mut operations: Vec<(Gate, Vec<u32>)> = Vec::with_capacity(n_gates);
            while operations.len() < n_gates {
                let gate = GATE_CHOICES[gate_distribution.sample(rng)];
                let arity = gate.n_inputs();
                let available = n_inputs + (operations.len() as u32);
                // A gate needing more distinct lanes than exist is discarded and re-drawn.
                if arity <= available {
                    let args = sample(rng, available as usize, arity as usize)
                        .into_iter()
                        .map(|lane| lane as u32)
                        .collect();
                    operations.push((gate, args));
                }
            }
            operations
        }

        /// Returns whether every lane except the last (the circuit output) is read by a gate.
        fn is_connected(&self) -> bool {
            let n_lanes = self.n_inputs as usize + self.operations.len();
            let mut used = vec![false; n_lanes];
            self.operations
                .iter()
                .flat_map(|(_, args)| args)
                .for_each(|&lane| used[lane as usize] = true);
            match used.split_last() {
                // The final lane is what `eval` returns, so it need not feed another gate.
                Some((_output, internal)) => internal.iter().all(|&u| u),
                None => true,
            }
        }

        /// Evaluates the circuit on one input assignment and returns the value of the last lane.
        pub fn eval(&self, inp: &[bool]) -> bool {
            let mut lanes = Vec::from(inp);
            for (gate, args) in &self.operations {
                let wires: Vec<bool> = args.iter().map(|&lane| lanes[lane as usize]).collect();
                let out = gate.eval(&wires);
                lanes.push(out);
            }
            *lanes.last().unwrap()
        }
    }
}