programinduction/domains/
circuits.rs1use itertools::Itertools;
27use polytype::{ptp, tp, Type, TypeScheme};
28use rand::{
29 distributions::{Distribution, WeightedIndex},
30 Rng,
31};
32use std::iter;
33use std::sync::Arc;
34
35use crate::lambda::{Evaluator as EvaluatorT, Expression, Language};
36use crate::Task;
37
38pub fn dsl() -> Language {
46 Language::uniform(vec![(
47 "nand",
48 ptp!(@arrow[tp!(bool), tp!(bool), tp!(bool)]),
49 )])
50}
51
52pub type Space = bool;
54
55#[derive(Copy, Clone)]
88pub struct Evaluator;
89impl EvaluatorT for Evaluator {
90 type Space = Space;
91 type Error = ();
92 fn evaluate(&self, primitive: &str, inp: &[Self::Space]) -> Result<Self::Space, Self::Error> {
93 match primitive {
94 "nand" => Ok(!(inp[0] & inp[1])),
95 _ => unreachable!(),
96 }
97 }
98}
99
100pub fn make_tasks<R: Rng>(
124 rng: &mut R,
125 count: u32,
126) -> Vec<impl Task<[bool], Representation = Language, Expression = Expression>> {
127 make_tasks_advanced(
128 rng,
129 count,
130 [1, 2, 3, 4, 4, 4, 0, 0],
131 [1, 2, 2, 0, 0, 0, 0, 0],
132 1,
133 2,
134 2,
135 4,
136 0,
137 )
138}
139
140#[allow(clippy::too_many_arguments)]
154pub fn make_tasks_advanced<R: Rng>(
155 rng: &mut R,
156 count: u32,
157 n_input_weights: [u32; 8],
158 n_gate_weights: [u32; 8],
159 gate_not: u32,
160 gate_and: u32,
161 gate_or: u32,
162 gate_mux2: u32,
163 gate_mux4: u32,
164) -> Vec<impl Task<[bool], Representation = Language, Expression = Expression>> {
165 let n_input_distribution =
166 WeightedIndex::new(n_input_weights).expect("invalid weights for number of circuit inputs");
167 let n_gate_distribution =
168 WeightedIndex::new(n_gate_weights).expect("invalid weights for number of circuit gates");
169 let gate_weights = WeightedIndex::new([gate_not, gate_and, gate_or, gate_mux2, gate_mux4])
170 .expect("invalid weights for circuit gates");
171
172 (0..count)
173 .map(move |_| {
174 let mut n_inputs = 1 + n_input_distribution.sample(rng);
175 let mut n_gates = 1 + n_gate_distribution.sample(rng);
176 while n_inputs / n_gates >= 3 {
177 n_inputs = 1 + n_input_distribution.sample(rng);
178 n_gates = 1 + n_gate_distribution.sample(rng);
179 }
180 let circuit = gates::Circuit::new(rng, &gate_weights, n_inputs as u32, n_gates);
181 let outputs: Vec<_> = iter::repeat(vec![false, true])
182 .take(n_inputs)
183 .multi_cartesian_product()
184 .map(|ins| circuit.eval(&ins))
185 .collect();
186 CircuitTask::new(n_inputs, outputs)
187 })
188 .collect()
189}
190
191struct CircuitTask {
192 n_inputs: usize,
193 expected_outputs: Vec<bool>,
194 tp: TypeScheme,
195}
196impl CircuitTask {
197 fn new(n_inputs: usize, expected_outputs: Vec<bool>) -> Self {
198 let tp = TypeScheme::Monotype(Type::from(vec![tp!(bool); n_inputs + 1]));
199 CircuitTask {
200 n_inputs,
201 expected_outputs,
202 tp,
203 }
204 }
205}
206impl Task<[bool]> for CircuitTask {
207 type Representation = Language;
208 type Expression = Expression;
209
210 fn oracle(&self, dsl: &Self::Representation, expr: &Self::Expression) -> f64 {
211 let evaluator = Arc::new(Evaluator);
212 let success = iter::repeat(vec![false, true])
213 .take(self.n_inputs)
214 .multi_cartesian_product()
215 .zip(&self.expected_outputs)
216 .all(|(inps, out)| {
217 if let Ok(o) = dsl.eval_arc(expr, &evaluator, &inps) {
218 o == *out
219 } else {
220 false
221 }
222 });
223 if success {
224 0f64
225 } else {
226 f64::NEG_INFINITY
227 }
228 }
229 fn tp(&self) -> &TypeScheme {
230 &self.tp
231 }
232 fn observation(&self) -> &[bool] {
233 &self.expected_outputs
234 }
235}
236
237mod gates {
238 use rand::{
239 distributions::{Distribution, WeightedIndex},
240 seq::index::sample,
241 Rng,
242 };
243
/// Gate kinds, in the order that indices sampled from the gate distribution
/// are mapped onto them.
const GATE_CHOICES: [Gate; 5] = [Gate::Not, Gate::And, Gate::Or, Gate::Mux2, Gate::Mux4];

/// The logic gates a random circuit may be assembled from.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Gate {
    Not,
    And,
    Or,
    Mux2,
    Mux4,
}

impl Gate {
    /// Number of input lanes this gate consumes.
    fn n_inputs(self) -> u32 {
        match self {
            Self::Not => 1,
            Self::And | Self::Or => 2,
            Self::Mux2 => 3,
            Self::Mux4 => 6,
        }
    }

    /// Computes the gate's output for the inputs in `inp`.
    ///
    /// For the multiplexers the leading entries are the data lines and the
    /// trailing entries select among them: `Mux2` picks `inp[0]` or `inp[1]`
    /// by `inp[2]`; `Mux4` picks one of `inp[0..4]` using `inp[4]` as the low
    /// select bit and `inp[5]` as the high select bit.
    ///
    /// # Panics
    ///
    /// Panics if `inp` is shorter than the gate requires.
    fn eval(self, inp: &[bool]) -> bool {
        match self {
            Self::Not => !inp[0],
            Self::And => inp[0] & inp[1],
            Self::Or => inp[0] | inp[1],
            Self::Mux2 => {
                let data = [inp[0], inp[1]];
                data[usize::from(inp[2])]
            }
            Self::Mux4 => {
                let data = [inp[0], inp[1], inp[2], inp[3]];
                let select = (usize::from(inp[5]) << 1) + usize::from(inp[4]);
                data[select]
            }
        }
    }
}
275
276 #[derive(Debug, PartialEq, Eq)]
277 pub struct Circuit {
278 n_inputs: u32,
279 operations: Vec<(Gate, Vec<u32>)>,
280 }
281 impl Circuit {
282 pub fn new<T: Rng>(
283 rng: &mut T,
284 gate_distribution: &WeightedIndex<u32>,
285 n_inputs: u32,
286 n_gates: usize,
287 ) -> Self {
288 loop {
289 let mut operations = Vec::with_capacity(n_gates);
290 while operations.len() < n_gates {
291 let gate = GATE_CHOICES[gate_distribution.sample(rng)];
292 let n_lanes = n_inputs + (operations.len() as u32);
293 if gate.n_inputs() > n_lanes {
294 continue;
295 }
296 let args = sample(rng, n_lanes as usize, gate.n_inputs() as usize)
297 .into_iter()
298 .map(|x| x as u32)
299 .collect();
300 operations.push((gate, args));
301 }
302 let circuit = Circuit {
303 n_inputs,
304 operations,
305 };
306 if circuit.is_connected() {
307 break circuit;
308 }
309 }
310 }
311 fn is_connected(&self) -> bool {
314 let n_lanes = self.n_inputs as usize + self.operations.len();
315 let mut is_used = vec![false; n_lanes];
316 for (_, args) in &self.operations {
317 for i in args {
318 is_used[*i as usize] = true;
319 }
320 }
321 is_used.pop();
322 is_used.into_iter().all(|x| x)
323 }
324 pub fn eval(&self, inp: &[bool]) -> bool {
325 let mut lanes = inp.to_vec();
326 for (gate, args) in &self.operations {
327 let gate_inp: Vec<bool> = args.iter().map(|a| lanes[*a as usize]).collect();
328 lanes.push(gate.eval(&gate_inp));
329 }
330 lanes.pop().unwrap()
331 }
332 }
333}