use std::env;
use std::{collections::HashSet, time::SystemTime};
use crate::{apl, ops};
use crate::base::Base;
use crate::nid::NID;
use crate::vid::VID;
use crate::ops::Ops;
use crate::reg::Reg;
use crate::{GraphViz, ast::{ASTBase, RawASTBase}, int::{GBASE,BInt,BaseBit}};
/// A destination structure that can be refined one substitution at a time.
pub trait SubSolver {
    /// Prepare the solver for a problem whose top node is `top`.
    /// The default performs no setup and simply returns the NID for `top`.
    fn init(&mut self, top: VID)->NID {
        NID::from_vid(top)
    }

    /// Substitute the definition described by `ops` for the variable `vid`
    /// inside `ctx`, returning the resulting node.
    fn subst(&mut self, ctx:NID, vid:VID, ops:&Ops)->NID;

    /// Return one solution of `ctx`, if any exists.
    ///
    /// The default implementation is inefficient. It prints a warning, enumerates
    /// every solution via `get_all`, and returns whichever element the set yields
    /// first.
    fn get_one(&self, ctx:NID, nvars:usize)->Option<Reg> {
        println!("Warning: default SubSolver::get_one() calls get_all(). Override this!");
        let every = self.get_all(ctx, nvars);
        every.into_iter().next()
    }

    /// Return every solution of `ctx` over `nvars` variables.
    fn get_all(&self, ctx:NID, nvars:usize)->HashSet<Reg>;

    /// Human-readable status text; empty by default.
    fn status(&self)->String {
        String::new()
    }

    /// Debugging hook taking a step number and a node; does nothing by default.
    fn dump(&self, _step: usize, _nid: NID) {}

    /// Hook for starting statistics collection; does nothing by default.
    fn init_stats(&mut self) {}

    /// Hook for printing collected statistics; does nothing by default.
    fn print_stats(&mut self) {}
}
/// Blanket implementation: any `Base` can act as a `SubSolver`.
///
/// Only three-token RPN definitions `[a, b, f]` are supported, where `f` is
/// AND, XOR or VEL (or). The definition node is built in the base itself and
/// then substituted for `v` with the base's `sub`.
impl<B:Base> SubSolver for B {
    fn subst(&mut self, ctx:NID, v:VID, ops:&Ops) ->NID {
        let def = match ops {
            Ops::RPN(rpn) => {
                if rpn.len() != 3 {
                    todo!("SubSolver impl for Base can only handle simple dyadic ops for now.")
                }
                let op = rpn[2].to_fun().unwrap();
                let (lhs, rhs) = (rpn[0], rpn[1]);
                match op {
                    ops::AND => self.and(lhs, rhs),
                    ops::XOR => self.xor(lhs, rhs),
                    ops::VEL => self.or(lhs, rhs),
                    _ => panic!("don't know how to translate {:?}", ops),
                }
            }
        };
        self.sub(v, def, ctx)
    }

    /// All solutions, delegated to the base's `solution_set`.
    fn get_all(&self, ctx:NID, nvars:usize)->HashSet<Reg> {
        self.solution_set(ctx, nvars)
    }

    // Both traits define these hooks; route them explicitly to the `Base` versions.
    fn init_stats(&mut self) {
        Base::init_stats(self)
    }

    fn print_stats(&mut self) {
        Base::print_stats(self)
    }
}
/// Observer hooks that can be notified as a destination `S` is refined.
pub trait Progress<S:SubSolver> {
    /// Called once before any substitution; does nothing by default.
    fn on_start(&mut self, _ctx:&DstNid) {}

    /// Called after each substitution step with the step number, the time the
    /// step took in milliseconds, and the top node before and after the step.
    fn on_step(&mut self, src:&RawASTBase, dest: &mut S, step:usize, millis:u128, oldtop:DstNid, newtop:DstNid);

    /// Called once after the final step with the resulting top node.
    fn on_done(&mut self, src:&RawASTBase, dest: &mut S, newtop:DstNid);
}
/// A simple `Progress` reporter that tracks the wall-clock start of a solve.
pub struct ProgressReport<'a> {
    /// Wall-clock time at which the solve started.
    pub start: SystemTime,
    /// Millisecond counter. NOTE(review): not read or updated anywhere in this file.
    pub millis: u128,
    /// NOTE(review): presumably "save GraphViz output per step"; unused in this file.
    pub save_dot: bool,
    /// NOTE(review): presumably "save the destination per step"; unused in this file.
    pub save_dest: bool,
    /// NOTE(review): presumably a filename prefix for saved output; unused in this file.
    pub prefix: &'a str,
}
/// A node id that refers to the source AST.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SrcNid {
    pub n: NID,
}

/// A node id that refers to the destination `SubSolver`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DstNid {
    pub n: NID,
}
/// `ProgressReport` ignores individual steps and prints the total elapsed time.
impl<S:SubSolver> Progress<S> for ProgressReport<'_> {
    /// Record the wall-clock time at which the solve begins.
    fn on_start(&mut self, _ctx:&DstNid) {
        self.start = std::time::SystemTime::now();
    }

    /// Per-step hook; this reporter does nothing here.
    fn on_step(&mut self, _src:&RawASTBase, _dest: &mut S, _step:usize, _millis:u128, _oldtop:DstNid, _newtop:DstNid) {}

    /// Print the total wall-clock time since `start`.
    ///
    /// `SystemTime::elapsed` returns an error if the system clock was set backwards
    /// during the solve. Aborting a finished run over a timing line is not useful,
    /// so that case is reported as 0 ms instead of panicking.
    fn on_done(&mut self, _src:&RawASTBase, _dest: &mut S, _newtop:DstNid) {
        let millis = self.start.elapsed().map(|d| d.as_millis()).unwrap_or(0);
        println!("total time: {} ms", millis)
    }
}
/// Bitmask function used when ranking nodes by cost: each variable maps to its own bitmask.
fn default_bitmask(_src:&RawASTBase, v:VID) -> u64 {
    v.bitmask()
}

/// Repack `src` around `top` and reorder the result by node cost.
///
/// `src` is repacked keeping `top` (presumably together with the nodes it depends
/// on). The survivors are then permuted by `apl::gradeup` of the costs from
/// `masks_and_costs`, presumably ascending cost. Returns the reordered AST and
/// the new location of `top`, recovered through a temporary tag.
pub fn sort_by_cost(src:&RawASTBase, top:SrcNid)->(RawASTBase,SrcNid) {
    const TOP_TAG: &str = "-top-";
    let (mut packed, kept) = src.repack(vec![top.n]);
    // Tag where `top` landed so it can be found again after the permutation.
    packed.tag(kept[0], TOP_TAG.to_string());
    let (_masks, costs) = packed.masks_and_costs(default_bitmask);
    let order = apl::gradeup(&costs);
    let sorted = packed.permute(&order);
    let n = sorted.get(TOP_TAG).expect("what? I just put it there.");
    (sorted, SrcNid { n })
}
/// Translate a source-AST NID into the destination's namespace.
///
/// * Constants pass through unchanged.
/// * Variables keep their identity.
/// * AST node `i` becomes virtual variable `i`.
///
/// For non-constants, inversion is stripped and re-applied.
///
/// # Panics
/// On a virtual-variable NID (never expected in the source), and via `todo!` on
/// any other kind of NID.
pub fn convert_nid(sn:SrcNid)->DstNid {
    let n = sn.n;
    if n.is_const() {
        return DstNid { n };
    }
    let base =
        if n.is_vir() { panic!("what? should never be a VIR in the source.") }
        else if n.is_var() { n.raw() }
        else if n.is_ixn() { NID::vir(n.idx() as u32) }
        else { todo!("convert_nid({:?})", n) };
    DstNid { n: if n.is_inv() { !base } else { base } }
}
/// Perform a single substitution step.
///
/// Fetches the definition of virtual variable `v` (AST node `v.vir_ix()` in `src`),
/// translates its operand tokens into destination NIDs, and asks `dst` to substitute
/// that definition for `v` within the current context `d`.
fn refine_one(dst: &mut dyn SubSolver, v:VID, src:&RawASTBase, d:DstNid)->DstNid {
    let src_ops = src.get_ops(NID::ixn(v.vir_ix()));
    // Function tokens pass through as-is; everything else is an operand to translate.
    let translate = |tok:&NID|->NID {
        if tok.is_fun() {
            return *tok;
        }
        convert_nid(SrcNid { n: *tok }).n
    };
    let def = Ops::RPN(src_ops.to_rpn().map(translate).collect());
    let refined = dst.subst(d.n, v, &def);
    DstNid { n: refined }
}
pub fn solve<S:SubSolver>(dst:&mut S, src0:&RawASTBase, sn:NID)->DstNid {
if sn.is_lit() { DstNid{n:sn} }
else {
dst.init(sn.vid());
let (src, top) = sort_by_cost(src0, SrcNid{n:sn});
let mut step:usize = top.n.idx();
let mut v = VID::vir(step as u32);
let mut ctx = DstNid{n: dst.init(v)};
let mut pr = ProgressReport{ start: SystemTime::now(), save_dot: false, save_dest: false, prefix:"x", millis: 0 };
<dyn Progress<S>>::on_start(&mut pr, &ctx);
while !(ctx.n.is_var() || ctx.n.is_const()) {
let now = std::time::SystemTime::now();
let old = ctx; ctx = refine_one(dst, v, &src, ctx);
let millis = now.elapsed().expect("elapsed?").as_millis();
pr.on_step(&src, dst, step, millis, old, ctx);
if step == 0 { break } else { step -= 1; v=VID::vir(step as u32) }}
pr.on_done(&src, dst, ctx);
ctx}}
/// Build the two constraint bits for factoring `k`.
///
/// * `lt` is `x < y`.
/// * `eq` is `x * y == k`, with `x`, `y` of type `T0` and the product taken at type `T1`.
///
/// The thread-local `GBASE` is first replaced with an empty `ASTBase`, discarding
/// its previous contents. NOTE(review): the new nodes presumably accumulate in
/// `GBASE`, which `find_factors` reads afterwards.
fn multiplication_bits<T0:BInt, T1:BInt>(k:usize)->(BaseBit, BaseBit) {
    GBASE.with(|gb| gb.replace(ASTBase::empty()));
    // Definition order matters for node numbering: `y` first, then `x`.
    // Presumably the second argument is the first input-variable index of each.
    let y = T0::def("y", 0);
    let x = T0::def("x", T0::n());
    let lt = x.lt(&y);
    let product:T1 = x.times(&y);
    let target = T1::new(k);
    let eq = product.eq(&target);
    (lt, eq)
}
/// Factor `k` by solving "x < y and x * y == k" with `dest`, then assert that
/// the set of `(x, y)` solutions equals `expected`.
///
/// `x` and `y` have type `T0`; the product is compared at type `T1`. Passing `-a`
/// on the command line prints the `lt`, `eq` and combined ASTs.
///
/// # Panics
/// * if the combined constraint is already a literal (not yet supported);
/// * if the solutions found differ from `expected`.
pub fn find_factors<T0:BInt, T1:BInt, S:SubSolver>(dest:&mut S, k:usize, expected:Vec<(u64,u64)>) {
    let (lt, eq) = multiplication_bits::<T0,T1>(k);
    let show_ast = env::args().any(|arg| arg == "-a");
    if show_ast {
        // `.n` is a plain Copy NID; no need to clone the whole bit to read it.
        GBASE.with(|gb| { gb.borrow().show_named(lt.n, "lt") });
        GBASE.with(|gb| { gb.borrow().show_named(eq.n, "eq") });
    }
    let top:BaseBit = lt & eq;
    assert!(top.n.is_ixn(), "top nid seems to be a literal. (TODO: handle these already solved cases)");
    // Take the finished AST out of the global slot, leaving an empty one behind.
    let gb = GBASE.with(|gb| gb.replace(ASTBase::empty()));
    let src = gb.raw_ast();
    if show_ast { src.show_named(top.n, "ast"); }
    dest.init_stats();
    let answer:DstNid = solve(dest, src, top.n);
    type Factors = (u64,u64);
    // `multiplication_bits` defines `y` at offset 0 and `x` at offset `T0::n()`.
    // Assuming register bit i holds input variable i, the low half of each solution
    // is `y` and the high half is `x`. They are reported as `(x, y)`.
    let to_factors = |r:&Reg|->Factors {
        let bits = r.as_usize();
        let y = bits & ((1<<T0::n())-1);
        let x = bits >> T0::n();
        (x as u64, y as u64)
    };
    let actual_regs:HashSet<Reg> = dest.get_all(answer.n, 2*T0::n() as usize);
    let actual:HashSet<Factors> = actual_regs.iter().map(to_factors).collect();
    let expect:HashSet<Factors> = expected.into_iter().collect();
    assert_eq!(actual, expect);
    dest.print_stats();
}
// 2-bit factors, 4-bit product: 6 = 2 * 3 is the only x < y factorisation.
#[test]
pub fn test_nano_bdd() {
    use crate::{bdd::BddBase, int::{X2, X4}};
    let mut solver = BddBase::new();
    find_factors::<X2, X4, BddBase>(&mut solver, 6, vec![(2, 3)]);
}

#[test]
pub fn test_nano_anf() {
    use crate::{anf::ANFBase, int::{X2, X4}};
    let mut solver = ANFBase::new();
    find_factors::<X2, X4, ANFBase>(&mut solver, 6, vec![(2, 3)]);
}

#[test]
pub fn test_nano_swap() {
    use crate::{swap::SwapSolver, int::{X2, X4}};
    let mut solver = SwapSolver::new();
    find_factors::<X2, X4, SwapSolver>(&mut solver, 6, vec![(2, 3)]);
}
// 4-bit factors, 8-bit product: 210 = 14 * 15 is the only x < y pair below 16.
#[test]
pub fn test_tiny_bdd() {
    use crate::{bdd::BddBase, int::{X4, X8}};
    let mut solver = BddBase::new();
    find_factors::<X4, X8, BddBase>(&mut solver, 210, vec![(14, 15)]);
}

#[test]
pub fn test_tiny_anf() {
    use crate::{anf::ANFBase, int::{X4, X8}};
    let mut solver = ANFBase::new();
    find_factors::<X4, X8, ANFBase>(&mut solver, 210, vec![(14, 15)]);
}

#[test]
pub fn test_tiny_swap() {
    use crate::{swap::SwapSolver, int::{X4, X8}};
    let mut solver = SwapSolver::new();
    find_factors::<X4, X8, SwapSolver>(&mut solver, 210, vec![(14, 15)]);
}
// 4-bit factors of 30 with x < y: several solutions must all be found.
#[test]
pub fn test_multi_bdd() {
    use crate::{bdd::BddBase, int::{X4, X8}};
    let mut solver = BddBase::new();
    let expected = vec![(2, 15), (3, 10), (5, 6)];
    find_factors::<X4, X8, BddBase>(&mut solver, 30, expected);
}

#[test]
pub fn test_multi_anf() {
    use crate::{anf::ANFBase, int::{X4, X8}};
    let mut solver = ANFBase::new();
    let expected = vec![(2, 15), (3, 10), (5, 6)];
    find_factors::<X4, X8, ANFBase>(&mut solver, 30, expected);
}
// 8-bit factors of 210 with x < y (slow; only built with the `slowtests` feature).
// Uses a `crate::`-rooted import like the other tests; a bare `bdd::...` path does
// not resolve under 2018+ edition path rules.
#[cfg(feature="slowtests")]
#[test]
pub fn test_small_bdd() {
    use crate::{bdd::BddBase, int::{X8, X16}};
    let expected = vec![(1,210), (2,105), ( 3,70), ( 5,42),
                        (6, 35), (7, 30), (10,21), (14,15)];
    find_factors::<X8, X16, BddBase>(&mut BddBase::new(), 210, expected);
}
// 8-bit factors of 210 with x < y (slow; only built with the `slowtests` feature).
// Uses a `crate::`-rooted import like the other tests; a bare `swap::...` path does
// not resolve under 2018+ edition path rules.
#[cfg(feature="slowtests")]
#[test]
pub fn test_small_swap() {
    use crate::{swap::SwapSolver, int::{X8, X16}};
    let expected = vec![(1,210), (2,105), ( 3,70), ( 5,42),
                        (6, 35), (7, 30), (10,21), (14,15)];
    find_factors::<X8, X16, SwapSolver>(&mut SwapSolver::new(), 210, expected);
}