Skip to main content

prolog2/
top_prog.rs

1use std::{
2    collections::{HashMap, HashSet},
3    io::{self, Write},
4    process::ExitCode,
5    sync::{
6        atomic::{AtomicUsize, Ordering},
7        mpsc::{self, Sender},
8        Arc,
9    },
10    thread, usize,
11};
12
13use lazy_static::lazy_static;
14use rayon;
15use smallvec::SmallVec;
16
17use crate::{
18    heap::{
19        heap::{Cell, Heap, Tag},
20        query_heap::QueryHeap,
21    },
22    parser::{
23        build_tree::{TokenStream, TreeClause},
24        execute_tree::build_clause,
25        tokeniser::tokenise,
26    },
27    program::{clause::Clause, hypothesis::Hypothesis, predicate_table::PredicateTable},
28    resolution::proof::Proof,
29    Config, Examples,
30};
31
lazy_static! {
    /// CPU count reported by `num_cpus::get()`, evaluated once on first access.
    /// Used to size the rayon worker pools in this module.
    static ref CPU_COUNT: usize = num_cpus::get();
}
35
/// Message sent from a proof (worker) thread to the collector thread in `generalise`:
/// one hypothesis, copied out of the worker's proof heap.
struct HypothesisMsg {
    /// Cells holding the copied hypothesis terms. Addresses inside them are local to
    /// this vector (it starts empty on the worker), not to the program heap; the
    /// collector rebases them.
    cells: Vec<Cell>,
    /// The hypothesis clauses; their literal addresses index into `cells`.
    h: Vec<Clause>,
}
41
42/// Top Program Construction entry point
43pub fn run(
44    examples: Examples,
45    predicate_table: &PredicateTable,
46    mut heap: Vec<Cell>,
47    config: Config,
48    no_reduce: bool,
49) -> ExitCode {
50    println!("=== Top Program Construction ===");
51    println!(
52        "Positive examples: {}, Negative examples: {}",
53        examples.pos.len(),
54        examples.neg.len()
55    );
56
57    // Step 1: Generalise
58    let (cells, hypotheses) = generalise(&examples.pos, predicate_table, &heap, config);
59    heap.extend_from_slice(&cells);
60
61    println!(
62        "\n=== Generalisation Results ===\n{} unique hypotheses, {} heap cells",
63        hypotheses.len(),
64        cells.len()
65    );
66
67    // Step 2: Specialise
68    let retained = specialise(&examples.neg, &hypotheses, &heap, predicate_table, config);
69
70    let surviving_count = retained.iter().filter(|&&b| b).count();
71    let rejected_count = hypotheses.len() - surviving_count;
72    println!(
73        "\n=== Specialisation Results ===\n{} hypotheses survived, {} rejected",
74        surviving_count, rejected_count
75    );
76
77    // Build final top program from surviving hypotheses
78    let mut seen = HashSet::new();
79    let mut top_program: Vec<&Clause> = Vec::new();
80
81    if no_reduce {
82        // Simple union without reduction
83        for (hypothesis, &alive) in hypotheses.iter().zip(retained.iter()) {
84            if alive {
85                for clause in hypothesis {
86                    let key = clause.to_string(&heap);
87                    if seen.insert(key) {
88                        top_program.push(clause);
89                    }
90                }
91            }
92        }
93
94        println!("\n=== Top Program ({} clauses) ===", top_program.len());
95        for clause in &top_program {
96            println!("  {}", clause.to_string(&heap));
97        }
98    } else {
99        // Step 2b: Per-hypothesis reduction — remove redundant clauses within each
100        // sub-hypothesis before union, so specific clauses don't drown out general ones.
101        let surviving: Vec<&Vec<Clause>> = hypotheses
102            .iter()
103            .zip(retained.iter())
104            .filter_map(|(h, &alive)| if alive { Some(h) } else { None })
105            .collect();
106        let sub_total = surviving.len();
107
108        // Reduce each sub-hypothesis and score by coverage (number of positives entailed)
109        let mut scored: Vec<(usize, Vec<&Clause>)> = Vec::new();
110        for (idx, hypothesis) in surviving.iter().enumerate() {
111            eprint!("\rSub-reduce: {}/{sub_total}    ", idx + 1);
112            let _ = io::stderr().flush();
113            let clauses: Vec<&Clause> = hypothesis.iter().collect();
114            let reduced = reduce(
115                &examples.pos,
116                clauses,
117                &heap,
118                predicate_table,
119                config,
120                false,
121            );
122            let coverage = count_coverage(&examples.pos, &reduced, &heap, predicate_table, config);
123            scored.push((coverage, reduced));
124        }
125        eprintln!("\rSub-reduce: {sub_total}/{sub_total} ...done    ");
126
127        // Sort by coverage ascending — specific hypotheses first in the union.
128        // Plotkin's reduction checks clauses front-to-back: specific clauses get
129        // checked first and removed (the general ones behind them cover the same
130        // examples). By the time we reach the general clauses, the specific ones
131        // are gone and the general ones become essential.
132        scored.sort_by(|a, b| a.0.cmp(&b.0));
133
134        // Union all reduced sub-hypothesis clauses, deduplicated
135        for (_coverage, reduced_h) in &scored {
136            for &clause in reduced_h {
137                let key = clause.to_string(&heap);
138                if seen.insert(key) {
139                    top_program.push(clause);
140                }
141            }
142        }
143
144        println!("\n=== Top Program ({} clauses) ===", top_program.len());
145        for clause in &top_program {
146            println!("  {}", clause.to_string(&heap));
147        }
148
149        // Step 3: Final reduction on the union
150        let reduced = reduce(
151            &examples.pos,
152            top_program,
153            &heap,
154            predicate_table,
155            config,
156            true,
157        );
158
159        println!("\n=== Reduced Program ({} clauses) ===", reduced.len());
160        for clause in &reduced {
161            println!("  {}", clause.to_string(&heap));
162        }
163    }
164
165    ExitCode::SUCCESS
166}
167
168/// Parse a single example string into a goal on the given query heap.
169fn parse_example(example: &str, query_heap: &mut QueryHeap) -> Result<usize, String> {
170    let query = format!(":-{example}.");
171    let literals = match TokenStream::new(tokenise(query)?).parse_clause()? {
172        Some(TreeClause::Directive(literals)) => literals,
173        _ => return Err(format!("Example '{example}' incorrectly formatted")),
174    };
175    let clause = build_clause(literals, None, None, query_heap, true);
176    Ok(clause[0])
177}
178
179/// Minimal work on the worker thread — just the copy.
180fn extract_hypothesis_local(proof: &Proof) -> (Vec<Cell>, Vec<Clause>) {
181    let mut local_cells: Vec<Cell> = Vec::new();
182    let mut ref_map = HashMap::new();
183    let mut clauses = Vec::new();
184
185    for clause in proof.hypothesis.iter() {
186        let new_literals: Vec<usize> = clause
187            .iter()
188            .map(|&lit_addr| {
189                local_cells.copy_term_with_ref_map(&proof.heap, lit_addr, &mut ref_map)
190            })
191            .collect();
192        clauses.push(Clause::new(new_literals, None, None));
193    }
194
195    (local_cells, clauses)
196}
197
198fn generalise(
199    pos_examples: &[String],
200    predicate_table: &PredicateTable,
201    heap: &[Cell],
202    config: Config,
203) -> (Vec<Cell>, Vec<Vec<Clause>>) {
204    let pool = rayon::ThreadPoolBuilder::new()
205        .num_threads(*CPU_COUNT - 1)
206        .build()
207        .unwrap();
208
209    let (tx, rx) = mpsc::channel::<HypothesisMsg>();
210    let total = pos_examples.len();
211    let completed = Arc::new(AtomicUsize::new(0));
212    let heap_len = heap.len();
213
214    // Collector runs on its own OS thread, processing results as they arrive
215    let collector = thread::spawn(move || {
216        let mut hypothesis_cells = Vec::new();
217        let mut hypotheses = Vec::new();
218        let mut seen = HashSet::new();
219        let mut offset = heap_len;
220
221        for HypothesisMsg { cells, mut h } in rx {
222            // Build canonical key before offset adjustment, using local cells
223            let mut clause_strings: Vec<String> =
224                h.iter().map(|clause| clause.to_string(&cells)).collect();
225            clause_strings.sort_unstable();
226            let key = clause_strings.join("|");
227
228            if !seen.insert(key) {
229                continue; // Duplicate hypothesis, skip
230            }
231
232            let len = cells.len();
233            for cell in cells {
234                let adjusted = match cell {
235                    (Tag::Str, addr) => (Tag::Str, addr + offset),
236                    (Tag::Lis, addr) => (Tag::Lis, addr + offset),
237                    (Tag::Ref, addr) => (Tag::Ref, addr + offset),
238                    other => other,
239                };
240                hypothesis_cells.push(adjusted);
241            }
242            for clause in h.iter_mut() {
243                for literal in clause.iter_mut() {
244                    *literal += offset;
245                }
246            }
247            hypotheses.push(h);
248            offset += len;
249        }
250
251        (hypothesis_cells, hypotheses)
252    });
253
254    // Workers — scope blocks until all are done, then drops tx clones
255    pool.scope(|s| {
256        for example in pos_examples {
257            let tx = tx.clone();
258            let completed = completed.clone();
259            s.spawn(move |_| {
260                generalise_thread(example, predicate_table, &heap, config, tx);
261                let done = completed.fetch_add(1, Ordering::Relaxed) + 1;
262                eprint!("\rGeneralise: {done}/{total} examples");
263                let _ = io::stderr().flush();
264            });
265        }
266    });
267    drop(tx); // drop the original sender so the collector's rx iterator ends
268    eprintln!(" ...done");
269
270    collector.join().unwrap()
271}
272
273fn generalise_thread(
274    example: &str,
275    predicate_table: &PredicateTable,
276    prog_heap: &[Cell],
277    config: Config,
278    tx: Sender<HypothesisMsg>,
279) {
280    let mut query_heap = QueryHeap::new(prog_heap, None);
281    let goal = match parse_example(&example, &mut query_heap) {
282        Ok(g) => g,
283        Err(e) => {
284            eprintln!("Failed to parse example '{}': {}", example, e);
285            return;
286        }
287    };
288    let mut proof = Proof::new(query_heap, &[goal]);
289
290    while proof.prove(predicate_table, config) {
291        for clause in proof.hypothesis.iter() {
292            clause.normalise_clause_vars(&mut proof.heap);
293            let (cells, h) = extract_hypothesis_local(&proof);
294            if tx.send(HypothesisMsg { cells, h }).is_err() {
295                break; // Receiver dropped
296            }
297        }
298    }
299}
300
301fn specialise(
302    neg_examples: &[String],
303    hypotheses: &[Vec<Clause>],
304    heap: &[Cell],
305    predicate_table: &PredicateTable,
306    config: Config,
307) -> Vec<bool> {
308    let pool = rayon::ThreadPoolBuilder::new()
309        .num_threads(*CPU_COUNT - 1)
310        .build()
311        .unwrap();
312
313    let (tx, rx) = mpsc::channel::<(usize, bool)>();
314    let total = hypotheses.len();
315    let completed = Arc::new(AtomicUsize::new(0));
316
317    // Collector: build the retain mask as results arrive
318    let collector = thread::spawn(move || {
319        let mut retained = vec![true; total];
320        for (idx, keep) in rx {
321            retained[idx] = keep;
322        }
323        retained
324    });
325
326    // One worker per hypothesis
327    pool.scope(|s| {
328        for (idx, hypothesis) in hypotheses.iter().enumerate() {
329            let tx = tx.clone();
330            let completed = completed.clone();
331            s.spawn(move |_| {
332                let keep =
333                    specialise_thread(neg_examples, hypothesis, heap, predicate_table, config);
334                let _ = tx.send((idx, keep));
335                let done = completed.fetch_add(1, Ordering::Relaxed) + 1;
336                eprint!("\rSpecialise: {done}/{total} hypotheses tested");
337                let _ = io::stderr().flush();
338            });
339        }
340    });
341    drop(tx);
342    eprintln!(" ...done");
343
344    collector.join().unwrap()
345}
346
347/// Test one hypothesis against all negative examples.
348/// Returns `true` if the hypothesis should be retained (no negative is provable).
349fn specialise_thread(
350    neg_examples: &[String],
351    hypothesis: &[Clause],
352    heap: &[Cell],
353    predicate_table: &PredicateTable,
354    config: Config,
355) -> bool {
356    // Use the original max_depth to bound recursive hypotheses, but disable learning
357    let config = Config {
358        max_depth: config.max_depth,
359        max_clause: 0,
360        max_pred: 0,
361        debug: false,
362    };
363
364    // Build a Hypothesis from the clauses so we can use Proof::with_hypothesis
365    let mut h = Hypothesis::new();
366    for clause in hypothesis {
367        h.push_clause(clause.clone(), SmallVec::new());
368    }
369
370    for example in neg_examples {
371        let mut query_heap = QueryHeap::new(heap, None);
372        let goal = match parse_example(example, &mut query_heap) {
373            Ok(g) => g,
374            Err(e) => {
375                eprintln!("Failed to parse negative example '{}': {}", example, e);
376                continue;
377            }
378        };
379        let mut proof = Proof::with_hypothesis(query_heap, &[goal], h);
380        // If any negative example is provable, reject this hypothesis
381        if proof.prove(predicate_table, config) {
382            return false;
383        }
384        // Reclaim the hypothesis — it was never mutated since max_clause is 0
385        h = std::mem::replace(&mut proof.hypothesis, Hypothesis::new());
386    }
387    true
388}
389
390/// Count how many positive examples a set of clauses can prove.
391fn count_coverage(
392    pos_examples: &[String],
393    clauses: &[&Clause],
394    heap: &[Cell],
395    predicate_table: &PredicateTable,
396    config: Config,
397) -> usize {
398    let config = Config {
399        max_depth: config.max_depth,
400        max_clause: 0,
401        max_pred: 0,
402        debug: false,
403    };
404
405    let mut h = Hypothesis::new();
406    for clause in clauses {
407        h.push_clause((*clause).clone(), SmallVec::new());
408    }
409
410    pos_examples
411        .iter()
412        .filter(|example| {
413            let mut query_heap = QueryHeap::new(heap, None);
414            let goal = match parse_example(example, &mut query_heap) {
415                Ok(g) => g,
416                Err(_) => return false,
417            };
418            let mut proof = Proof::with_hypothesis(query_heap, &[goal], h.clone());
419            proof.prove(predicate_table, config)
420        })
421        .count()
422}
423
424/// Plotkin's program reduction (Algorithm 3).
425/// Sequentially tries removing each clause; if all positives are still provable
426/// without it, the clause is redundant and permanently removed.
427fn reduce<'a>(
428    pos_examples: &[String],
429    mut top_program: Vec<&'a Clause>,
430    heap: &[Cell],
431    predicate_table: &PredicateTable,
432    config: Config,
433    verbose: bool,
434) -> Vec<&'a Clause> {
435    let config = Config {
436        max_depth: config.max_depth,
437        max_clause: 0,
438        max_pred: 0,
439        debug: false,
440    };
441
442    let total = top_program.len();
443    let mut removed = 0usize;
444    let mut i = 0;
445    while i < top_program.len() {
446        if verbose {
447            eprint!(
448                "\rReduce: {}/{total} checked, {removed} removed    ",
449                i + removed + 1
450            );
451            let _ = io::stderr().flush();
452        }
453
454        // Build hypothesis from all clauses except the one at index i
455        let mut h = Hypothesis::new();
456        for (j, clause) in top_program.iter().enumerate() {
457            if j != i {
458                h.push_clause((*clause).clone(), SmallVec::new());
459            }
460        }
461
462        // Check if all positive examples are still provable without clause i
463        let redundant = pos_examples.iter().all(|example| {
464            let mut query_heap = QueryHeap::new(heap, None);
465            let goal = match parse_example(example, &mut query_heap) {
466                Ok(g) => g,
467                Err(_) => return true, // skip unparseable examples
468            };
469            let mut proof = Proof::with_hypothesis(query_heap, &[goal], h.clone());
470            proof.prove(predicate_table, config)
471        });
472
473        if redundant {
474            top_program.remove(i);
475            removed += 1;
476            // Don't increment i — next clause slides into this position
477        } else {
478            i += 1;
479        }
480    }
481    if verbose {
482        eprintln!(" ...done");
483    }
484
485    top_program
486}