1use std::{
2 collections::{HashMap, HashSet},
3 io::{self, Write},
4 process::ExitCode,
5 sync::{
6 atomic::{AtomicUsize, Ordering},
7 mpsc::{self, Sender},
8 Arc,
9 },
10 thread, usize,
11};
12
13use lazy_static::lazy_static;
14use rayon;
15use smallvec::SmallVec;
16
17use crate::{
18 heap::{
19 heap::{Cell, Heap, Tag},
20 query_heap::QueryHeap,
21 },
22 parser::{
23 build_tree::{TokenStream, TreeClause},
24 execute_tree::build_clause,
25 tokeniser::tokenise,
26 },
27 program::{clause::Clause, hypothesis::Hypothesis, predicate_table::PredicateTable},
28 resolution::proof::Proof,
29 Config, Examples,
30};
31
32lazy_static! {
33 static ref CPU_COUNT: usize = num_cpus::get();
34}
35
36struct HypothesisMsg {
38 cells: Vec<Cell>,
39 h: Vec<Clause>,
40}
41
42pub fn run(
44 examples: Examples,
45 predicate_table: &PredicateTable,
46 mut heap: Vec<Cell>,
47 config: Config,
48 no_reduce: bool,
49) -> ExitCode {
50 println!("=== Top Program Construction ===");
51 println!(
52 "Positive examples: {}, Negative examples: {}",
53 examples.pos.len(),
54 examples.neg.len()
55 );
56
57 let (cells, hypotheses) = generalise(&examples.pos, predicate_table, &heap, config);
59 heap.extend_from_slice(&cells);
60
61 println!(
62 "\n=== Generalisation Results ===\n{} unique hypotheses, {} heap cells",
63 hypotheses.len(),
64 cells.len()
65 );
66
67 let retained = specialise(&examples.neg, &hypotheses, &heap, predicate_table, config);
69
70 let surviving_count = retained.iter().filter(|&&b| b).count();
71 let rejected_count = hypotheses.len() - surviving_count;
72 println!(
73 "\n=== Specialisation Results ===\n{} hypotheses survived, {} rejected",
74 surviving_count, rejected_count
75 );
76
77 let mut seen = HashSet::new();
79 let mut top_program: Vec<&Clause> = Vec::new();
80
81 if no_reduce {
82 for (hypothesis, &alive) in hypotheses.iter().zip(retained.iter()) {
84 if alive {
85 for clause in hypothesis {
86 let key = clause.to_string(&heap);
87 if seen.insert(key) {
88 top_program.push(clause);
89 }
90 }
91 }
92 }
93
94 println!("\n=== Top Program ({} clauses) ===", top_program.len());
95 for clause in &top_program {
96 println!(" {}", clause.to_string(&heap));
97 }
98 } else {
99 let surviving: Vec<&Vec<Clause>> = hypotheses
102 .iter()
103 .zip(retained.iter())
104 .filter_map(|(h, &alive)| if alive { Some(h) } else { None })
105 .collect();
106 let sub_total = surviving.len();
107
108 let mut scored: Vec<(usize, Vec<&Clause>)> = Vec::new();
110 for (idx, hypothesis) in surviving.iter().enumerate() {
111 eprint!("\rSub-reduce: {}/{sub_total} ", idx + 1);
112 let _ = io::stderr().flush();
113 let clauses: Vec<&Clause> = hypothesis.iter().collect();
114 let reduced = reduce(
115 &examples.pos,
116 clauses,
117 &heap,
118 predicate_table,
119 config,
120 false,
121 );
122 let coverage = count_coverage(&examples.pos, &reduced, &heap, predicate_table, config);
123 scored.push((coverage, reduced));
124 }
125 eprintln!("\rSub-reduce: {sub_total}/{sub_total} ...done ");
126
127 scored.sort_by(|a, b| a.0.cmp(&b.0));
133
134 for (_coverage, reduced_h) in &scored {
136 for &clause in reduced_h {
137 let key = clause.to_string(&heap);
138 if seen.insert(key) {
139 top_program.push(clause);
140 }
141 }
142 }
143
144 println!("\n=== Top Program ({} clauses) ===", top_program.len());
145 for clause in &top_program {
146 println!(" {}", clause.to_string(&heap));
147 }
148
149 let reduced = reduce(
151 &examples.pos,
152 top_program,
153 &heap,
154 predicate_table,
155 config,
156 true,
157 );
158
159 println!("\n=== Reduced Program ({} clauses) ===", reduced.len());
160 for clause in &reduced {
161 println!(" {}", clause.to_string(&heap));
162 }
163 }
164
165 ExitCode::SUCCESS
166}
167
168fn parse_example(example: &str, query_heap: &mut QueryHeap) -> Result<usize, String> {
170 let query = format!(":-{example}.");
171 let literals = match TokenStream::new(tokenise(query)?).parse_clause()? {
172 Some(TreeClause::Directive(literals)) => literals,
173 _ => return Err(format!("Example '{example}' incorrectly formatted")),
174 };
175 let clause = build_clause(literals, None, None, query_heap, true);
176 Ok(clause[0])
177}
178
179fn extract_hypothesis_local(proof: &Proof) -> (Vec<Cell>, Vec<Clause>) {
181 let mut local_cells: Vec<Cell> = Vec::new();
182 let mut ref_map = HashMap::new();
183 let mut clauses = Vec::new();
184
185 for clause in proof.hypothesis.iter() {
186 let new_literals: Vec<usize> = clause
187 .iter()
188 .map(|&lit_addr| {
189 local_cells.copy_term_with_ref_map(&proof.heap, lit_addr, &mut ref_map)
190 })
191 .collect();
192 clauses.push(Clause::new(new_literals, None, None));
193 }
194
195 (local_cells, clauses)
196}
197
198fn generalise(
199 pos_examples: &[String],
200 predicate_table: &PredicateTable,
201 heap: &[Cell],
202 config: Config,
203) -> (Vec<Cell>, Vec<Vec<Clause>>) {
204 let pool = rayon::ThreadPoolBuilder::new()
205 .num_threads(*CPU_COUNT - 1)
206 .build()
207 .unwrap();
208
209 let (tx, rx) = mpsc::channel::<HypothesisMsg>();
210 let total = pos_examples.len();
211 let completed = Arc::new(AtomicUsize::new(0));
212 let heap_len = heap.len();
213
214 let collector = thread::spawn(move || {
216 let mut hypothesis_cells = Vec::new();
217 let mut hypotheses = Vec::new();
218 let mut seen = HashSet::new();
219 let mut offset = heap_len;
220
221 for HypothesisMsg { cells, mut h } in rx {
222 let mut clause_strings: Vec<String> =
224 h.iter().map(|clause| clause.to_string(&cells)).collect();
225 clause_strings.sort_unstable();
226 let key = clause_strings.join("|");
227
228 if !seen.insert(key) {
229 continue; }
231
232 let len = cells.len();
233 for cell in cells {
234 let adjusted = match cell {
235 (Tag::Str, addr) => (Tag::Str, addr + offset),
236 (Tag::Lis, addr) => (Tag::Lis, addr + offset),
237 (Tag::Ref, addr) => (Tag::Ref, addr + offset),
238 other => other,
239 };
240 hypothesis_cells.push(adjusted);
241 }
242 for clause in h.iter_mut() {
243 for literal in clause.iter_mut() {
244 *literal += offset;
245 }
246 }
247 hypotheses.push(h);
248 offset += len;
249 }
250
251 (hypothesis_cells, hypotheses)
252 });
253
254 pool.scope(|s| {
256 for example in pos_examples {
257 let tx = tx.clone();
258 let completed = completed.clone();
259 s.spawn(move |_| {
260 generalise_thread(example, predicate_table, &heap, config, tx);
261 let done = completed.fetch_add(1, Ordering::Relaxed) + 1;
262 eprint!("\rGeneralise: {done}/{total} examples");
263 let _ = io::stderr().flush();
264 });
265 }
266 });
267 drop(tx); eprintln!(" ...done");
269
270 collector.join().unwrap()
271}
272
273fn generalise_thread(
274 example: &str,
275 predicate_table: &PredicateTable,
276 prog_heap: &[Cell],
277 config: Config,
278 tx: Sender<HypothesisMsg>,
279) {
280 let mut query_heap = QueryHeap::new(prog_heap, None);
281 let goal = match parse_example(&example, &mut query_heap) {
282 Ok(g) => g,
283 Err(e) => {
284 eprintln!("Failed to parse example '{}': {}", example, e);
285 return;
286 }
287 };
288 let mut proof = Proof::new(query_heap, &[goal]);
289
290 while proof.prove(predicate_table, config) {
291 for clause in proof.hypothesis.iter() {
292 clause.normalise_clause_vars(&mut proof.heap);
293 let (cells, h) = extract_hypothesis_local(&proof);
294 if tx.send(HypothesisMsg { cells, h }).is_err() {
295 break; }
297 }
298 }
299}
300
301fn specialise(
302 neg_examples: &[String],
303 hypotheses: &[Vec<Clause>],
304 heap: &[Cell],
305 predicate_table: &PredicateTable,
306 config: Config,
307) -> Vec<bool> {
308 let pool = rayon::ThreadPoolBuilder::new()
309 .num_threads(*CPU_COUNT - 1)
310 .build()
311 .unwrap();
312
313 let (tx, rx) = mpsc::channel::<(usize, bool)>();
314 let total = hypotheses.len();
315 let completed = Arc::new(AtomicUsize::new(0));
316
317 let collector = thread::spawn(move || {
319 let mut retained = vec![true; total];
320 for (idx, keep) in rx {
321 retained[idx] = keep;
322 }
323 retained
324 });
325
326 pool.scope(|s| {
328 for (idx, hypothesis) in hypotheses.iter().enumerate() {
329 let tx = tx.clone();
330 let completed = completed.clone();
331 s.spawn(move |_| {
332 let keep =
333 specialise_thread(neg_examples, hypothesis, heap, predicate_table, config);
334 let _ = tx.send((idx, keep));
335 let done = completed.fetch_add(1, Ordering::Relaxed) + 1;
336 eprint!("\rSpecialise: {done}/{total} hypotheses tested");
337 let _ = io::stderr().flush();
338 });
339 }
340 });
341 drop(tx);
342 eprintln!(" ...done");
343
344 collector.join().unwrap()
345}
346
347fn specialise_thread(
350 neg_examples: &[String],
351 hypothesis: &[Clause],
352 heap: &[Cell],
353 predicate_table: &PredicateTable,
354 config: Config,
355) -> bool {
356 let config = Config {
358 max_depth: config.max_depth,
359 max_clause: 0,
360 max_pred: 0,
361 debug: false,
362 };
363
364 let mut h = Hypothesis::new();
366 for clause in hypothesis {
367 h.push_clause(clause.clone(), SmallVec::new());
368 }
369
370 for example in neg_examples {
371 let mut query_heap = QueryHeap::new(heap, None);
372 let goal = match parse_example(example, &mut query_heap) {
373 Ok(g) => g,
374 Err(e) => {
375 eprintln!("Failed to parse negative example '{}': {}", example, e);
376 continue;
377 }
378 };
379 let mut proof = Proof::with_hypothesis(query_heap, &[goal], h);
380 if proof.prove(predicate_table, config) {
382 return false;
383 }
384 h = std::mem::replace(&mut proof.hypothesis, Hypothesis::new());
386 }
387 true
388}
389
390fn count_coverage(
392 pos_examples: &[String],
393 clauses: &[&Clause],
394 heap: &[Cell],
395 predicate_table: &PredicateTable,
396 config: Config,
397) -> usize {
398 let config = Config {
399 max_depth: config.max_depth,
400 max_clause: 0,
401 max_pred: 0,
402 debug: false,
403 };
404
405 let mut h = Hypothesis::new();
406 for clause in clauses {
407 h.push_clause((*clause).clone(), SmallVec::new());
408 }
409
410 pos_examples
411 .iter()
412 .filter(|example| {
413 let mut query_heap = QueryHeap::new(heap, None);
414 let goal = match parse_example(example, &mut query_heap) {
415 Ok(g) => g,
416 Err(_) => return false,
417 };
418 let mut proof = Proof::with_hypothesis(query_heap, &[goal], h.clone());
419 proof.prove(predicate_table, config)
420 })
421 .count()
422}
423
424fn reduce<'a>(
428 pos_examples: &[String],
429 mut top_program: Vec<&'a Clause>,
430 heap: &[Cell],
431 predicate_table: &PredicateTable,
432 config: Config,
433 verbose: bool,
434) -> Vec<&'a Clause> {
435 let config = Config {
436 max_depth: config.max_depth,
437 max_clause: 0,
438 max_pred: 0,
439 debug: false,
440 };
441
442 let total = top_program.len();
443 let mut removed = 0usize;
444 let mut i = 0;
445 while i < top_program.len() {
446 if verbose {
447 eprint!(
448 "\rReduce: {}/{total} checked, {removed} removed ",
449 i + removed + 1
450 );
451 let _ = io::stderr().flush();
452 }
453
454 let mut h = Hypothesis::new();
456 for (j, clause) in top_program.iter().enumerate() {
457 if j != i {
458 h.push_clause((*clause).clone(), SmallVec::new());
459 }
460 }
461
462 let redundant = pos_examples.iter().all(|example| {
464 let mut query_heap = QueryHeap::new(heap, None);
465 let goal = match parse_example(example, &mut query_heap) {
466 Ok(g) => g,
467 Err(_) => return true, };
469 let mut proof = Proof::with_hypothesis(query_heap, &[goal], h.clone());
470 proof.prove(predicate_table, config)
471 });
472
473 if redundant {
474 top_program.remove(i);
475 removed += 1;
476 } else {
478 i += 1;
479 }
480 }
481 if verbose {
482 eprintln!(" ...done");
483 }
484
485 top_program
486}