//! Biunification of constraints between polar types, together with the
//! substitution helpers it relies on.
use std::collections::HashSet;
use std::ops;

use im::Vector;
use proptest::strategy::Strategy;

use crate::polar::Ty;
use crate::tests::{arb_polar_ty, Constructed};
use crate::Polarity;

/// A pair of types constrained against each other.
///
/// Throughout this file the first field is handled at positive polarity and
/// the second at negative polarity (see `arb_constraint` and
/// `Constraint::bisubst`). The `Eq`/`Hash` derives allow constraints to be
/// stored in a `HashSet`, which `biunify_all` uses to skip constraints it has
/// already processed.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(in crate::biunify) struct Constraint(
    /// Left-hand side (positive polarity).
    pub(in crate::biunify) Ty<Constructed, char>,
    /// Right-hand side (negative polarity).
    pub(in crate::biunify) Ty<Constructed, char>,
);

/// Proptest strategy producing arbitrary constraints: the left-hand side is
/// drawn from the positive-polarity type strategy and the right-hand side
/// from the negative-polarity one.
pub(in crate::biunify) fn arb_constraint() -> impl Strategy<Value = Constraint> {
    let positive = arb_polar_ty(Polarity::Pos);
    let negative = arb_polar_ty(Polarity::Neg);
    (positive, negative).prop_map(|(lhs, rhs)| Constraint(lhs, rhs))
}

impl Constraint {
    fn bisubst(self, sub: &Bisubst) -> Self {
        Constraint(
            sub.apply(self.0, Polarity::Pos),
            sub.apply(self.1, Polarity::Neg),
        )
    }
}

/// Decomposes a non-atomic constraint into the list of smaller constraints
/// that must hold for it to hold.
///
/// Returns `Err(())` when the two sides cannot be related: mismatched
/// constructors, a right-hand record with a field the left-hand record
/// lacks, or any shape not covered by the arms below.
///
/// The arms are order-sensitive: recursive types are unrolled before sums
/// are split, and sums are split before the `Zero` cases are considered.
fn subi(con: Constraint) -> Result<Vec<Constraint>, ()> {
    match con {
        // Functions: the domains are related in the reverse direction, the
        // ranges in the same direction.
        Constraint(
            Ty::Constructed(Constructed::Fun(dom_l, rng_l)),
            Ty::Constructed(Constructed::Fun(dom_r, rng_r)),
        ) => {
            let domains = Constraint(*dom_r, *dom_l);
            let ranges = Constraint(*rng_l, *rng_r);
            Ok(vec![domains, ranges])
        }
        Constraint(Ty::Constructed(Constructed::Bool), Ty::Constructed(Constructed::Bool)) => {
            Ok(Vec::new())
        }
        // Records: every field on the right must exist on the left; the
        // shared fields are then constrained pairwise.
        Constraint(
            Ty::Constructed(Constructed::Record(left)),
            Ty::Constructed(Constructed::Record(right)),
        ) => {
            let has_missing_field = iter_set::difference(right.keys(), left.keys())
                .next()
                .is_some();
            if has_missing_field {
                return Err(());
            }
            let pairwise: Vec<Constraint> = iter_set::intersection(left.keys(), right.keys())
                .map(|key| Constraint(*left[key].clone(), *right[key].clone()))
                .collect();
            Ok(pairwise)
        }
        // Recursive types are unrolled one step by substituting the type
        // itself for bound variable 0 in its body.
        Constraint(Ty::Recursive(body), rhs) => {
            let unrolled = subst((*body).clone(), 0, Ty::Recursive(body));
            Ok(vec![Constraint(unrolled, rhs)])
        }
        Constraint(lhs, Ty::Recursive(body)) => {
            let unrolled = subst((*body).clone(), 0, Ty::Recursive(body));
            Ok(vec![Constraint(lhs, unrolled)])
        }
        // A sum on either side yields one constraint per summand.
        Constraint(Ty::Add(first, second), rhs) => {
            let with_first = Constraint(*first, rhs.clone());
            let with_second = Constraint(*second, rhs);
            Ok(vec![with_first, with_second])
        }
        Constraint(lhs, Ty::Add(first, second)) => {
            let with_first = Constraint(lhs.clone(), *first);
            let with_second = Constraint(lhs, *second);
            Ok(vec![with_first, with_second])
        }
        // `Zero` on either side imposes no further constraints.
        Constraint(Ty::Zero, _) | Constraint(_, Ty::Zero) => Ok(Vec::new()),
        _ => Err(()),
    }
}

/// Tries to solve an atomic constraint (one with an unbound variable on at
/// least one side) by producing a single-entry bisubstitution.
///
/// * variable on the left, constructed type or variable on the right: the
///   variable's negative occurrences are replaced, using the right-hand side
///   as the bound;
/// * constructed type on the left, variable on the right: the variable's
///   positive occurrences are replaced, using the left-hand side as the bound.
///
/// In both cases the replacement is `fixpoint(v + bound')`, where `bound'` is
/// the bound with the variable's own occurrences (at that polarity) turned
/// into `BoundVar(0)`. Any other shape yields `Err(())`.
fn atomic(con: &Constraint) -> Result<Bisubst, ()> {
    let Constraint(lhs, rhs) = con;
    let (var, pol, bound) = match (lhs, rhs) {
        (Ty::UnboundVar(v), Ty::Constructed(_)) | (Ty::UnboundVar(v), Ty::UnboundVar(_)) => {
            (*v, Polarity::Neg, rhs)
        }
        (Ty::Constructed(_), Ty::UnboundVar(v)) => (*v, Polarity::Pos, lhs),
        _ => return Err(()),
    };

    let tied_bound = bisubst(bound.clone(), pol, (pol, var), Ty::BoundVar(0));
    let body = Ty::Add(Box::new(Ty::UnboundVar(var)), Box::new(tied_bound));
    Ok(Bisubst::unit(var, pol, fixpoint(body)))
}

/// An ordered list of polarity-aware substitutions.
///
/// Each entry maps a `(polarity, variable)` pair to the type that replaces
/// occurrences of that variable at that polarity. `Bisubst::apply` runs the
/// entries in order, and `*=` composes two bisubstitutions by appending the
/// right-hand entries after the left-hand ones.
#[derive(Debug, Clone)]
pub(in crate::biunify) struct Bisubst {
    // Entries in application order.
    sub: Vector<((Polarity, char), Ty<Constructed, char>)>,
}

impl Bisubst {
    /// Creates a bisubstitution with no entries; applying it leaves a type
    /// unchanged.
    fn new() -> Self {
        Self { sub: Vector::new() }
    }

    /// Creates a bisubstitution with a single entry replacing occurrences of
    /// `v` at polarity `pol` with `ty`.
    fn unit(v: char, pol: Polarity, ty: Ty<Constructed, char>) -> Self {
        let entry = ((pol, v), ty);
        Self {
            sub: Vector::unit(entry),
        }
    }

    /// Applies every entry to `ty`, in order, starting at polarity `pol`.
    /// Each entry sees the result of the previous one.
    fn apply(&self, ty: Ty<Constructed, char>, pol: Polarity) -> Ty<Constructed, char> {
        self.sub.iter().fold(ty, |current, entry| {
            bisubst(current, pol, entry.0, entry.1.clone())
        })
    }
}

impl ops::MulAssign for Bisubst {
    fn mul_assign(&mut self, other: Self) {
        self.sub.append(other.sub)
    }
}

/// Replaces the bound variable with index `var` in `ty` by `sub`.
///
/// The index is incremented when descending under a `Ty::Recursive` binder.
/// `sub` itself is inserted unchanged (it is not shifted), and every other
/// leaf is returned as is.
fn subst(
    ty: Ty<Constructed, char>,
    var: usize,
    sub: Ty<Constructed, char>,
) -> Ty<Constructed, char> {
    match ty {
        Ty::BoundVar(idx) if idx == var => sub,
        Ty::Recursive(body) => {
            let body = subst(*body, var + 1, sub);
            Ty::Recursive(Box::new(body))
        }
        Ty::Add(lhs, rhs) => {
            let lhs = subst(*lhs, var, sub.clone());
            let rhs = subst(*rhs, var, sub);
            Ty::Add(Box::new(lhs), Box::new(rhs))
        }
        Ty::Constructed(Constructed::Fun(dom, rng)) => {
            let dom = subst(*dom, var, sub.clone());
            let rng = subst(*rng, var, sub);
            Ty::Constructed(Constructed::Fun(Box::new(dom), Box::new(rng)))
        }
        Ty::Constructed(Constructed::Record(fields)) => {
            let rewrite = |(label, field_ty)| (label, Box::new(subst(*field_ty, var, sub.clone())));
            Ty::Constructed(Constructed::Record(
                fields.into_iter().map(rewrite).collect(),
            ))
        }
        _ => ty,
    }
}

/// Replaces occurrences of the unbound variable `var.1` in `ty` by `sub`,
/// but only where the occurrence's polarity equals `var.0`.
///
/// `pol` is the polarity of `ty` itself. It is negated when descending into a
/// function's domain and kept everywhere else. When descending under a
/// `Ty::Recursive` binder the replacement is passed through `shift(_, 1)`.
fn bisubst(
    ty: Ty<Constructed, char>,
    pol: Polarity,
    var: (Polarity, char),
    sub: Ty<Constructed, char>,
) -> Ty<Constructed, char> {
    match ty {
        Ty::UnboundVar(name) if (pol, name) == var => sub,
        Ty::Add(lhs, rhs) => {
            let lhs = bisubst(*lhs, pol, var, sub.clone());
            let rhs = bisubst(*rhs, pol, var, sub);
            Ty::Add(Box::new(lhs), Box::new(rhs))
        }
        Ty::Recursive(body) => {
            let mut lifted = sub;
            shift(&mut lifted, 1);
            let body = bisubst(*body, pol, var, lifted);
            Ty::Recursive(Box::new(body))
        }
        Ty::Constructed(Constructed::Fun(dom, rng)) => {
            // The domain is contravariant, so its polarity is flipped.
            let dom = bisubst(*dom, -pol, var, sub.clone());
            let rng = bisubst(*rng, pol, var, sub);
            Ty::Constructed(Constructed::Fun(Box::new(dom), Box::new(rng)))
        }
        Ty::Constructed(Constructed::Record(fields)) => {
            let rewrite =
                |(label, field_ty)| (label, Box::new(bisubst(*field_ty, pol, var, sub.clone())));
            Ty::Constructed(Constructed::Record(
                fields.into_iter().map(rewrite).collect(),
            ))
        }
        _ => ty,
    }
}

/// Splits `ty` into a pair of types with respect to the bound variable `var`.
///
/// The first component collects the bare occurrences of `BoundVar(var)` that
/// are reachable through `Add` (and `Recursive`) nodes only; the second
/// component collects everything else. `fixpoint` keeps only the second
/// component.
///
/// NOTE(review): the `a`/`g` suffixes on the locals presumably stand for the
/// two parts of the split (something like "unguarded" and "guarded"
/// occurrences) — confirm the intended terminology.
fn split(ty: Ty<Constructed, char>, var: usize) -> (Ty<Constructed, char>, Ty<Constructed, char>) {
    match ty {
        // A bare occurrence of the variable goes entirely to the first part.
        Ty::BoundVar(idx) if idx == var => (ty, Ty::Zero),
        Ty::Zero => (Ty::Zero, Ty::Zero),
        // Sums are split component-wise.
        Ty::Add(l, r) => {
            let (la, lg) = split(*l, var);
            let (ra, rg) = split(*r, var);
            (
                Ty::Add(Box::new(la), Box::new(ra)),
                Ty::Add(Box::new(lg), Box::new(rg)),
            )
        }
        // Other variables and constructed types go entirely to the second
        // part; constructed types are not inspected further.
        Ty::BoundVar(_) | Ty::UnboundVar(_) | Ty::Constructed(_) => (Ty::Zero, ty),
        // Under a binder the index of interest is one higher. The first part
        // of the body is returned as is; in the second part, the bound
        // variable `var + 1` is replaced by the whole recursive type.
        Ty::Recursive(ref t) => {
            let (ta, tg) = split((**t).clone(), var + 1);
            (ta, subst(tg, var + 1, ty))
        }
    }
}

/// Builds a recursive type from `ty`, in which `BoundVar(0)` stands for the
/// type being defined.
///
/// `ty` is split with respect to bound variable 0; only the second component
/// of the split is kept and wrapped in `Ty::Recursive`.
pub(crate) fn fixpoint(ty: Ty<Constructed, char>) -> Ty<Constructed, char> {
    let (_, body) = split(ty, 0);
    Ty::Recursive(Box::new(body))
}

/// Increments, in place, every *free* bound-variable index in `ty` by `n`.
///
/// An index is free when it refers past all `Ty::Recursive` binders that
/// enclose it inside `ty`. Indices captured by a binder within `ty` are left
/// untouched, so that recursive types nested in `ty` keep referring to their
/// own binder after the shift.
fn shift(ty: &mut Ty<Constructed, char>, n: usize) {
    shift_from(ty, n, 0)
}

/// Shifts by `n` the bound-variable indices in `ty` that are `>= cutoff`.
/// `cutoff` is the number of `Ty::Recursive` binders crossed so far.
fn shift_from(ty: &mut Ty<Constructed, char>, n: usize, cutoff: usize) {
    match ty {
        Ty::BoundVar(idx) => {
            // Indices below the cutoff are bound inside the shifted type.
            if *idx >= cutoff {
                *idx += n;
            }
        }
        Ty::Add(l, r) => {
            shift_from(l, n, cutoff);
            shift_from(r, n, cutoff);
        }
        Ty::Constructed(Constructed::Fun(d, r)) => {
            shift_from(d, n, cutoff);
            shift_from(r, n, cutoff);
        }
        Ty::Constructed(Constructed::Record(fields)) => {
            fields.values_mut().for_each(|t| shift_from(t, n, cutoff))
        }
        // Crossing a binder: index 0 inside the body refers to this binder.
        Ty::Recursive(t) => shift_from(t, n, cutoff + 1),
        _ => (),
    }
}

/// Solves a single constraint; shorthand for `biunify_all` on a one-element
/// list.
pub(in crate::biunify) fn biunify(constraint: Constraint) -> Result<Bisubst, ()> {
    let worklist = vec![constraint];
    biunify_all(worklist)
}

/// Solves a list of constraints, returning the accumulated bisubstitution or
/// `Err(())` if some constraint can be neither solved nor decomposed.
///
/// Constraints are taken from the end of `cons`. Each one is:
/// * skipped if it has already been recorded as seen;
/// * solved by `atomic` if possible, in which case the resulting
///   substitution is applied to all pending and all seen constraints and
///   appended to the result;
/// * otherwise decomposed by `subi`, with the pieces pushed back onto the
///   worklist.
pub(in crate::biunify) fn biunify_all(mut cons: Vec<Constraint>) -> Result<Bisubst, ()> {
    let mut seen = HashSet::new();
    let mut solution = Bisubst::new();

    while let Some(current) = cons.pop() {
        if seen.contains(&current) {
            continue;
        }

        if let Ok(step) = atomic(&current) {
            // The constraint is recorded before the substitution is applied,
            // so it is itself rewritten along with the rest of the set.
            seen.insert(current);
            cons = cons.into_iter().map(|c| c.bisubst(&step)).collect();
            seen = seen.into_iter().map(|c| c.bisubst(&step)).collect();
            solution *= step;
            continue;
        }

        let pieces = subi(current.clone())?;
        seen.insert(current);
        cons.extend(pieces);
    }

    Ok(solution)
}