programinduction/domains/
circuits.rs

1//! The Boolean circuit domain.
2//!
3//! As in the paper "Bootstrap Learning for Modular Concept Discovery" (2013).
4//!
5//! # Examples
6//!
7//! ```
8//! use programinduction::domains::circuits;
9//! use programinduction::{ECParams, EC};
10//! use rand::{rngs::SmallRng, SeedableRng};
11//!
12//! let dsl = circuits::dsl();
13//! let rng = &mut SmallRng::from_seed([1u8; 32]);
14//! let tasks = circuits::make_tasks(rng, 250);
15//! let ec_params = ECParams {
16//!     frontier_limit: 100,
17//!     search_limit_timeout: None,
18//!     search_limit_description_length: Some(9.0),
19//! };
20//!
21//! let frontiers = dsl.explore(&ec_params, &tasks);
22//! let hits = frontiers.iter().filter_map(|f| f.best_solution()).count();
23//! assert!(40 < hits && hits < 80, "hits = {}", hits);
24//! ```
25
26use itertools::Itertools;
27use polytype::{ptp, tp, Type, TypeScheme};
28use rand::{
29    distributions::{Distribution, WeightedIndex},
30    Rng,
31};
32use std::iter;
33use std::sync::Arc;
34
35use crate::lambda::{Evaluator as EvaluatorT, Expression, Language};
36use crate::Task;
37
38/// The circuit representation, a [`lambda::Language`], only defines the binary `nand` operation.
39///
40/// ```compile_fails
41/// "nand": ptp!(@arrow[tp!(bool), tp!(bool), tp!(bool)])
42/// ```
43///
44/// [`lambda::Language`]: ../../lambda/struct.Language.html
45pub fn dsl() -> Language {
46    Language::uniform(vec![(
47        "nand",
48        ptp!(@arrow[tp!(bool), tp!(bool), tp!(bool)]),
49    )])
50}
51
52/// All values in the circuits domain can be represented in this `Space`.
53pub type Space = bool;
54
55/// An [`Evaluator`] for the circuits domain.
56///
57/// # Examples
58///
59/// ```
60/// use polytype::{ptp, tp};
61/// use programinduction::domains::circuits;
62/// use programinduction::{lambda, ECParams, EC};
63///
64/// let dsl = circuits::dsl();
65///
66/// let examples = vec![ // NOT
67///     (vec![false], true),
68///     (vec![true], false),
69/// ];
70/// let task = lambda::task_by_evaluation(
71///     circuits::Evaluator,
72///     ptp!(@arrow[tp!(bool), tp!(bool)]),
73///     &examples,
74/// );
75/// let ec_params = ECParams {
76///     frontier_limit: 1,
77///     search_limit_timeout: None,
78///     search_limit_description_length: Some(5.0),
79/// };
80///
81/// let frontiers = dsl.explore(&ec_params, &[task]);
82/// let (expr, _logprior, _loglikelihood) = frontiers[0].best_solution().unwrap();
83/// assert_eq!(dsl.display(expr), "(λ (nand $0 $0))");
84/// ```
85///
86/// [`Evaluator`]: ../../lambda/trait.Evaluator.html
87#[derive(Copy, Clone)]
88pub struct Evaluator;
89impl EvaluatorT for Evaluator {
90    type Space = Space;
91    type Error = ();
92    fn evaluate(&self, primitive: &str, inp: &[Self::Space]) -> Result<Self::Space, Self::Error> {
93        match primitive {
94            "nand" => Ok(!(inp[0] & inp[1])),
95            _ => unreachable!(),
96        }
97    }
98}
99
100/// Randomly sample a number of circuits into [`Task`]s.
101///
102/// For a circuit, the number of inputs is sampled from 1 to 6 with weights 1, 2, 3, 4, 4, and 4
103/// respectively. The number of gates is sampled from 1 to 3 with weights 1, 2, and 2 respectively.
104/// The gates themselves are sampled from NOT, AND, OR, and MUX2 with weights 1, 2, 2, and 4,
105/// respectively. All circuits are connected: every input is used and every gate's output is either
106/// wired to another gate or to the circuits final output.
107///
108/// The task observations are outputs of the truth table in sequence, for example
109///
110/// ```text
111///    --- INPUTS ---      OUTPUTS
112/// false, false, false => false
113/// false, false, true  => false
114/// false, true,  false => false
115/// false, true,  true  => false
116/// true,  false, false => false
117/// true,  false, true  => false
118/// true,  true,  false => false
119/// true,  true,  true  => true
120/// ```
121///
122/// [`Task`]: ../../struct.Task.html
123pub fn make_tasks<R: Rng>(
124    rng: &mut R,
125    count: u32,
126) -> Vec<impl Task<[bool], Representation = Language, Expression = Expression>> {
127    make_tasks_advanced(
128        rng,
129        count,
130        [1, 2, 3, 4, 4, 4, 0, 0],
131        [1, 2, 2, 0, 0, 0, 0, 0],
132        1,
133        2,
134        2,
135        4,
136        0,
137    )
138}
139
140/// Like [`make_tasks`], but with a configurable circuit distribution.
141///
142/// This works by randomly constructing Boolean circuits, consisting of some number of inputs and
143/// some number of gates transforming those inputs into final output. The resulting truth table
144/// serves the task's oracle, providing a log-likelihood of `0.0` for success
145/// and `f64::NEG_INFINITY` for failure.
146///
147/// The `n_input_weights` and `n_gate_weights` arguments specify the relative distributions for the
148/// number of inputs/gates respectively from 1 to 8. The `gate_` arguments are relative
149/// weights for sampling the respective logic gate.
150/// Sample circuits which are invalid (i.e. not connected) are rejected.
151///
152/// [`make_tasks`]: fn.make_tasks.html
153#[allow(clippy::too_many_arguments)]
154pub fn make_tasks_advanced<R: Rng>(
155    rng: &mut R,
156    count: u32,
157    n_input_weights: [u32; 8],
158    n_gate_weights: [u32; 8],
159    gate_not: u32,
160    gate_and: u32,
161    gate_or: u32,
162    gate_mux2: u32,
163    gate_mux4: u32,
164) -> Vec<impl Task<[bool], Representation = Language, Expression = Expression>> {
165    let n_input_distribution =
166        WeightedIndex::new(n_input_weights).expect("invalid weights for number of circuit inputs");
167    let n_gate_distribution =
168        WeightedIndex::new(n_gate_weights).expect("invalid weights for number of circuit gates");
169    let gate_weights = WeightedIndex::new([gate_not, gate_and, gate_or, gate_mux2, gate_mux4])
170        .expect("invalid weights for circuit gates");
171
172    (0..count)
173        .map(move |_| {
174            let mut n_inputs = 1 + n_input_distribution.sample(rng);
175            let mut n_gates = 1 + n_gate_distribution.sample(rng);
176            while n_inputs / n_gates >= 3 {
177                n_inputs = 1 + n_input_distribution.sample(rng);
178                n_gates = 1 + n_gate_distribution.sample(rng);
179            }
180            let circuit = gates::Circuit::new(rng, &gate_weights, n_inputs as u32, n_gates);
181            let outputs: Vec<_> = iter::repeat(vec![false, true])
182                .take(n_inputs)
183                .multi_cartesian_product()
184                .map(|ins| circuit.eval(&ins))
185                .collect();
186            CircuitTask::new(n_inputs, outputs)
187        })
188        .collect()
189}
190
191struct CircuitTask {
192    n_inputs: usize,
193    expected_outputs: Vec<bool>,
194    tp: TypeScheme,
195}
196impl CircuitTask {
197    fn new(n_inputs: usize, expected_outputs: Vec<bool>) -> Self {
198        let tp = TypeScheme::Monotype(Type::from(vec![tp!(bool); n_inputs + 1]));
199        CircuitTask {
200            n_inputs,
201            expected_outputs,
202            tp,
203        }
204    }
205}
206impl Task<[bool]> for CircuitTask {
207    type Representation = Language;
208    type Expression = Expression;
209
210    fn oracle(&self, dsl: &Self::Representation, expr: &Self::Expression) -> f64 {
211        let evaluator = Arc::new(Evaluator);
212        let success = iter::repeat(vec![false, true])
213            .take(self.n_inputs)
214            .multi_cartesian_product()
215            .zip(&self.expected_outputs)
216            .all(|(inps, out)| {
217                if let Ok(o) = dsl.eval_arc(expr, &evaluator, &inps) {
218                    o == *out
219                } else {
220                    false
221                }
222            });
223        if success {
224            0f64
225        } else {
226            f64::NEG_INFINITY
227        }
228    }
229    fn tp(&self) -> &TypeScheme {
230        &self.tp
231    }
232    fn observation(&self) -> &[bool] {
233        &self.expected_outputs
234    }
235}
236
237mod gates {
238    use rand::{
239        distributions::{Distribution, WeightedIndex},
240        seq::index::sample,
241        Rng,
242    };
243
244    const GATE_CHOICES: [Gate; 5] = [Gate::Not, Gate::And, Gate::Or, Gate::Mux2, Gate::Mux4];
245
246    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
247    pub enum Gate {
248        Not,
249        And,
250        Or,
251        Mux2,
252        Mux4,
253    }
254    impl Gate {
255        fn n_inputs(self) -> u32 {
256            match self {
257                Gate::Not => 1,
258                Gate::And | Gate::Or => 2,
259                Gate::Mux2 => 3,
260                Gate::Mux4 => 6,
261            }
262        }
263        fn eval(self, inp: &[bool]) -> bool {
264            match self {
265                Gate::Not => !inp[0],
266                Gate::And => inp[0] & inp[1],
267                Gate::Or => inp[0] | inp[1],
268                Gate::Mux2 => [inp[0], inp[1]][inp[2] as usize],
269                Gate::Mux4 => {
270                    [inp[0], inp[1], inp[2], inp[3]][((inp[5] as usize) << 1) + inp[4] as usize]
271                }
272            }
273        }
274    }
275
276    #[derive(Debug, PartialEq, Eq)]
277    pub struct Circuit {
278        n_inputs: u32,
279        operations: Vec<(Gate, Vec<u32>)>,
280    }
281    impl Circuit {
282        pub fn new<T: Rng>(
283            rng: &mut T,
284            gate_distribution: &WeightedIndex<u32>,
285            n_inputs: u32,
286            n_gates: usize,
287        ) -> Self {
288            loop {
289                let mut operations = Vec::with_capacity(n_gates);
290                while operations.len() < n_gates {
291                    let gate = GATE_CHOICES[gate_distribution.sample(rng)];
292                    let n_lanes = n_inputs + (operations.len() as u32);
293                    if gate.n_inputs() > n_lanes {
294                        continue;
295                    }
296                    let args = sample(rng, n_lanes as usize, gate.n_inputs() as usize)
297                        .into_iter()
298                        .map(|x| x as u32)
299                        .collect();
300                    operations.push((gate, args));
301                }
302                let circuit = Circuit {
303                    n_inputs,
304                    operations,
305                };
306                if circuit.is_connected() {
307                    break circuit;
308                }
309            }
310        }
311        /// A circuit is connected if every output except for the last one is an input for some
312        /// other gate.
313        fn is_connected(&self) -> bool {
314            let n_lanes = self.n_inputs as usize + self.operations.len();
315            let mut is_used = vec![false; n_lanes];
316            for (_, args) in &self.operations {
317                for i in args {
318                    is_used[*i as usize] = true;
319                }
320            }
321            is_used.pop();
322            is_used.into_iter().all(|x| x)
323        }
324        pub fn eval(&self, inp: &[bool]) -> bool {
325            let mut lanes = inp.to_vec();
326            for (gate, args) in &self.operations {
327                let gate_inp: Vec<bool> = args.iter().map(|a| lanes[*a as usize]).collect();
328                lanes.push(gate.eval(&gate_inp));
329            }
330            lanes.pop().unwrap()
331        }
332    }
333}