use crate::expr::{Expr, Names, VarRange};
use crate::flag::Flag;
use crate::operator::*;
use ndarray::{Array1, ScalarOperand};
use num::{FromPrimitive, Integer, Num};
use sprs::{CsMat, CsMatView, CsVec, TriMat};
use std::fmt::*;
use std::ops::*;
/// A vector of flag coefficients over `basis`, stored with an explicit
/// denominator: the represented value is `data / scale` (this is how
/// `PartialEq` compares two values and what `no_scale` normalises to).
#[derive(Clone, Debug)]
pub struct QFlag<N, F> {
/// Basis indexing the entries of `data`.
pub basis: Basis<F>,
/// Scaled coefficients, one per flag of `basis`.
pub data: Array1<N>,
/// Common denominator of `data`.
pub scale: u64,
/// Symbolic expression describing how this value was built (updated by the
/// arithmetic operators; used by `Ineq::check` through `Expr::eval`).
pub expr: Expr<N, F>,
}
impl<N, F> PartialEq for QFlag<N, F>
where
    N: Num + FromPrimitive + Clone,
    F: Flag,
{
    /// Compares the represented values `data / scale` without dividing:
    /// `a / s_a == b / s_b` is tested coefficient-wise as `a * s_b == b * s_a`.
    ///
    /// # Panics
    ///
    /// Panics if the operands use different bases, have coefficient vectors of
    /// different lengths, or if a scale cannot be converted into `N`.
    fn eq(&self, other: &Self) -> bool {
        assert_eq!(self.basis, other.basis);
        assert_eq!(self.data.len(), other.data.len());
        let own_scale = N::from_u64(self.scale).unwrap();
        let other_scale = N::from_u64(other.scale).unwrap();
        for (mine, theirs) in self.data.iter().zip(other.data.iter()) {
            let same = mine.clone() * other_scale.clone() == theirs.clone() * own_scale.clone();
            if !same {
                return false;
            }
        }
        true
    }
}
/// Brings two scale denominators to their least common multiple.
///
/// A `QFlag` with coefficients `d` and scale `s` stands for `d / s`. For the
/// scales `scale1` and `scale2` this returns `(c1, c2, scale)` where
/// `scale = lcm(scale1, scale2)`, `c1 = scale / scale1` and
/// `c2 = scale / scale2`, so that `d1 / scale1 + d2 / scale2`
/// equals `(d1 * c1 + d2 * c2) / scale`.
///
/// # Panics
///
/// Panics in any of these cases:
/// - either scale is zero;
/// - the common scale does not fit in a `u64` (plain `*` would silently wrap
///   in release builds and produce a wrong denominator);
/// - a multiplier cannot be represented in `N`.
fn matching_scales<N>(scale1: u64, scale2: u64) -> (N, N, u64)
where
    N: FromPrimitive,
{
    assert!(
        scale1 != 0 && scale2 != 0,
        "scales must be positive (got {} and {})",
        scale1,
        scale2
    );
    let gcd = scale1.gcd(&scale2);
    // lcm = scale1 / gcd * scale2; dividing first keeps the product as small as possible.
    let scale = (scale1 / gcd)
        .checked_mul(scale2)
        .expect("common scale overflows u64");
    let c1 = N::from_u64(scale2 / gcd).expect("scale multiplier not representable in N");
    let c2 = N::from_u64(scale1 / gcd).expect("scale multiplier not representable in N");
    (c1, c2, scale)
}
impl<N, F> Add<&Self> for QFlag<N, F>
where
    N: Clone + FromPrimitive + Num + ScalarOperand,
    F: Flag,
{
    type Output = Self;
    /// Adds two values over the same basis, rescaling both to the least common
    /// scale first. The resulting expression is `self.expr + other.expr`.
    ///
    /// # Panics
    ///
    /// Panics if the bases or coefficient lengths differ.
    fn add(self, other: &Self) -> Self::Output {
        assert_eq!(self.basis, other.basis);
        assert_eq!(self.data.len(), other.data.len());
        let (self_factor, other_factor, common_scale) =
            matching_scales::<N>(self.scale, other.scale);
        let summed = self.data * self_factor + &other.data * other_factor;
        let expr = Expr::add(self.expr, other.expr.clone());
        QFlag {
            basis: self.basis,
            data: summed,
            scale: common_scale,
            expr,
        }
    }
}
impl<N, F> Add for QFlag<N, F>
where
    N: Clone + Num + FromPrimitive + ScalarOperand,
    F: Flag,
{
    type Output = Self;
    /// Owned addition; forwards to the `Add<&Self>` implementation.
    fn add(self, other: Self) -> Self::Output {
        let rhs = &other;
        self + rhs
    }
}
impl<'a, N, F> Sub for &'a QFlag<N, F>
where
    N: Clone + Num + FromPrimitive + ScalarOperand,
    F: Flag,
{
    type Output = QFlag<N, F>;
    /// Subtracts two borrowed values over the same basis after bringing them to
    /// their least common scale. The resulting expression is `self - other`.
    ///
    /// # Panics
    ///
    /// Panics if the bases or coefficient lengths differ.
    fn sub(self, other: Self) -> Self::Output {
        assert_eq!(self.basis, other.basis);
        assert_eq!(self.data.len(), other.data.len());
        let (self_factor, other_factor, common_scale) =
            matching_scales::<N>(self.scale, other.scale);
        let difference = &self.data * self_factor - &other.data * other_factor;
        let expr = Expr::sub(self.expr.clone(), other.expr.clone());
        QFlag {
            basis: self.basis,
            data: difference,
            scale: common_scale,
            expr,
        }
    }
}
impl<N, F> Sub for QFlag<N, F>
where
    N: Clone + Num + FromPrimitive + ScalarOperand,
    F: Flag,
{
    type Output = Self;
    /// Owned subtraction; forwards to the implementation on references.
    fn sub(self, other: Self) -> Self::Output {
        let (lhs, rhs) = (&self, &other);
        lhs - rhs
    }
}
impl<N, F> Neg for QFlag<N, F>
where
N: Clone + Neg<Output = N>,
{
type Output = Self;
fn neg(self) -> Self::Output {
Self {
basis: self.basis,
data: -self.data,
scale: self.scale,
expr: self.expr.neg(),
}
}
}
impl<'a, N, F> Neg for &'a QFlag<N, F>
where
    N: Clone + Neg<Output = N>,
    F: Clone,
{
    type Output = QFlag<N, F>;
    /// Returns the negation of a borrowed value; data and expression are cloned.
    fn neg(self) -> Self::Output {
        let negated_data = -self.data.clone();
        let negated_expr = self.expr.clone().neg();
        QFlag {
            basis: self.basis,
            scale: self.scale,
            data: negated_data,
            expr: negated_expr,
        }
    }
}
impl<N, F> Mul<N> for QFlag<N, F>
where
    N: Num + ScalarOperand + Display,
    F: Flag,
{
    type Output = Self;
    /// Multiplies every coefficient by the scalar `rhs`. The scale is unchanged,
    /// and the new expression is `rhs * expr`.
    fn mul(self, rhs: N) -> Self::Output {
        // `self` is owned, so its expression is moved instead of cloned.
        let expr = Expr::mul(Expr::num(&rhs), self.expr);
        Self {
            basis: self.basis,
            data: self.data * rhs,
            scale: self.scale,
            expr,
        }
    }
}
impl<N, F> Display for IneqMeta<N, F>
where
    N: Display,
{
    /// Writes `flag_expr<TAB><rel> bound_expr`, where `<rel>` is `=` for
    /// equalities and `≥` otherwise.
    fn fmt(&self, f: &mut Formatter) -> Result {
        let relation = if self.equality { '=' } else { '≥' };
        write!(f, "{}\t{} {}", self.flag_expr, relation, self.bound_expr)
    }
}
impl<N, F> Display for Ineq<N, F>
where
N: Display,
{
fn fmt(&self, f: &mut Formatter) -> Result {
self.meta.fmt(f)
}
}
/// Evaluates the bilinear form `lhsᵀ · matrix · rhs`. The sparse `u32` entries
/// of `matrix` are converted to `N`.
///
/// # Panics
///
/// Panics if the vector lengths do not match the matrix shape, or if an entry
/// cannot be converted to `N`.
fn quadratic_form<N>(lhs: &Array1<N>, matrix: &CsMat<u32>, rhs: &Array1<N>) -> N
where
    N: Num + Clone + FromPrimitive,
{
    assert_eq!(lhs.len(), matrix.rows());
    assert_eq!(rhs.len(), matrix.cols());
    matrix
        .iter()
        .fold(N::zero(), |acc, (&weight, (row, col))| {
            acc + N::from_u32(weight).unwrap() * lhs[row].clone() * rhs[col].clone()
        })
}
/// Computes the dense product `matrix · vec` for a sparse `u32` matrix.
///
/// # Panics
///
/// Panics if `vec.len()` differs from the number of matrix columns, or if an
/// entry cannot be converted to `N`.
fn vector_matrix_mul<N>(matrix: &CsMatView<u32>, vec: &Array1<N>) -> Array1<N>
where
    N: Num + Clone + FromPrimitive,
{
    assert_eq!(vec.len(), matrix.cols());
    let mut acc: Array1<N> = Array1::zeros(matrix.rows());
    for (&weight, (row, col)) in matrix.iter() {
        let term = N::from_u32(weight).unwrap() * vec[col].clone();
        acc[row] = acc[row].clone() + term;
    }
    acc
}
/// Evaluates one bilinear form per matrix of `table` on `(lhs, rhs)`. The
/// result has one entry per table matrix.
fn multiply<N>(lhs: &Array1<N>, table: &[CsMat<u32>], rhs: &Array1<N>) -> Array1<N>
where
    N: Num + Clone + FromPrimitive,
{
    let entries: Vec<N> = table
        .iter()
        .map(|matrix| quadratic_form(lhs, matrix, rhs))
        .collect();
    Array1::from(entries)
}
/// Converts a dense vector into a sparse one, keeping only the non-zero
/// entries (in increasing index order).
fn csvec_from_array<N>(array: &Array1<N>) -> CsVec<N>
where
    N: Num + Clone,
{
    let zero = N::zero();
    let mut sparse = CsVec::empty(array.len());
    for (index, value) in array.iter().enumerate().filter(|(_, v)| **v != zero) {
        sparse.append(index, value.clone());
    }
    sparse
}
/// Expands a sparse vector into a dense one; missing entries become zero.
fn array_from_csvec<N>(csvec: &CsVec<N>) -> Array1<N>
where
    N: Num + Clone,
{
    let mut dense: Vec<N> = std::iter::repeat(N::zero()).take(csvec.dim()).collect();
    csvec.scatter(&mut dense);
    dense.into()
}
impl<N, F> QFlag<N, F>
where
    N: Num + Clone + FromPrimitive,
    F: Flag,
{
    /// Multiplies scale factors together. It panics with a clear message when
    /// the product does not fit in a `u64`; plain `*` on `u64` silently wraps
    /// in release builds, which would give a wrong denominator.
    fn checked_scale(factors: &[u64]) -> u64 {
        factors
            .iter()
            .try_fold(1u64, |acc, &factor| acc.checked_mul(factor))
            .expect("QFlag scale overflows u64")
    }

    /// Applies the `u32` matrix `operator` to the coefficients, producing a
    /// value over `outbasis`. `denom` is folded into the scale, and the
    /// expression is kept as is.
    fn raw_expand(&self, operator: &CsMat<u32>, outbasis: Basis<F>, denom: u32) -> Self {
        Self {
            basis: outbasis,
            data: vector_matrix_mul(&operator.view(), &self.data),
            scale: Self::checked_scale(&[self.scale, u64::from(denom)]),
            expr: self.expr.clone(),
        }
    }

    /// Multiplies two values of the same type using one bilinear-form matrix
    /// per output coefficient (see `multiply`). The scales and `denom` are
    /// multiplied together.
    ///
    /// # Panics
    ///
    /// Panics if the two values have different types (`basis.t`), or if the
    /// combined scale overflows `u64`.
    fn raw_multiply(&self, table: &[CsMat<u32>], other: &Self, denom: u32) -> Self {
        assert_eq!(self.basis.t, other.basis.t);
        Self {
            basis: self.basis * other.basis,
            data: multiply(&self.data, table, &other.data),
            scale: Self::checked_scale(&[self.scale, u64::from(denom), other.scale]),
            expr: Expr::mul(self.expr.clone(), other.expr.clone()),
        }
    }

    /// Accumulates coefficient `i` into output slot `untype_flag[i]`, weighted
    /// by `untype_count[i]`. `denom` is folded into the scale.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the two mapping slices differ in length;
    /// - the mapping slices do not have exactly one entry per coefficient;
    /// - a target index is `>= outbasis_size`;
    /// - the combined scale overflows `u64`.
    fn raw_untype(
        &self,
        untype_flag: &[usize],
        untype_count: &[u32],
        outbasis: Basis<F>,
        outbasis_size: usize,
        denom: u32,
    ) -> Self {
        assert_eq!(untype_flag.len(), untype_count.len());
        // Without this check a short `data` silently ignores trailing mapping
        // entries and a long one fails with an opaque out-of-bounds index.
        assert_eq!(self.data.len(), untype_flag.len());
        let mut data = Array1::<N>::zeros(outbasis_size);
        for ((value, &target), &count) in self.data.iter().zip(untype_flag).zip(untype_count) {
            data[target] = data[target].clone() + value.clone() * N::from_u32(count).unwrap();
        }
        Self {
            basis: outbasis,
            data,
            scale: Self::checked_scale(&[self.scale, u64::from(denom)]),
            expr: self.expr.clone().unlab(),
        }
    }
}
/// Builds the sparse `outbasis_size × untype_flag.len()` matrix whose column
/// `i` has the single entry `untype_count[i]`, placed in row `untype_flag[i]`.
///
/// # Panics
///
/// Panics if `untype_flag` and `untype_count` differ in length (previously a
/// longer `untype_count` was silently truncated), or if a count cannot be
/// converted to `N`.
fn untype_matrix<N>(untype_flag: &[usize], untype_count: &[u32], outbasis_size: usize) -> CsMat<N>
where
    N: Num + FromPrimitive + Clone,
{
    assert_eq!(untype_flag.len(), untype_count.len());
    let inbasis_size = untype_flag.len();
    let shape = (outbasis_size, inbasis_size);
    let mut trimat = TriMat::with_capacity(shape, inbasis_size);
    for (col, (&row, &count)) in untype_flag.iter().zip(untype_count).enumerate() {
        trimat.add_triplet(row, col, N::from_u32(count).unwrap());
    }
    trimat.to_csr()
}
impl<N, F> QFlag<N, F>
where
    N: Num + Clone + FromPrimitive,
    F: Flag,
{
    /// Re-expresses this value in the larger basis `outbasis` using the
    /// `SubflagCount` operator.
    pub fn expand(&self, outbasis: Basis<F>) -> Self {
        let subflag = SubflagCount::from_to(self.basis, outbasis);
        let operator = subflag.get();
        let denom = subflag.denom();
        self.raw_expand(&operator, outbasis, denom)
    }

    /// Fully unlabels this value, producing an untyped value over the same
    /// flag size.
    pub fn untype(&self) -> Self {
        let unlabel = Unlabel {
            unlabeling: Unlabeling::<F>::total(self.basis.t),
            size: self.basis.size,
        };
        let outbasis = self.basis.with_type(Type::empty());
        let (targets, counts) = unlabel.get();
        let outbasis_size = outbasis.get().len();
        let denom = unlabel.denom();
        self.raw_untype(&targets, &counts, outbasis, outbasis_size, denom)
    }
}
impl<'a, N, F> Mul for &'a QFlag<N, F>
where
    N: Num + Clone + FromPrimitive,
    F: Flag,
{
    type Output = QFlag<N, F>;
    /// Flag-algebra product of two borrowed values, computed with the
    /// `SplitCount` operator for their bases.
    fn mul(self, other: Self) -> QFlag<N, F> {
        let split = SplitCount::from_input(&self.basis, &other.basis);
        let table = split.get();
        let denom = split.denom();
        self.raw_multiply(&table, other, denom)
    }
}
impl<N, F> Mul for QFlag<N, F>
where
    N: Num + Clone + FromPrimitive,
    F: Flag,
{
    type Output = Self;
    /// Owned product; forwards to the implementation on references.
    fn mul(self, other: Self) -> Self {
        let (lhs, rhs) = (&self, &other);
        lhs * rhs
    }
}
impl<N, F> QFlag<N, F> {
    /// Replaces the symbolic expression attached to this value.
    pub fn with_expr(self, expr: Expr<N, F>) -> Self {
        QFlag { expr, ..self }
    }

    /// Attaches `name` to the current expression (via `Expr::named`).
    pub fn named(self, name: String) -> Self {
        let expr = self.expr.named(name);
        QFlag { expr, ..self }
    }

    /// Divides the coefficients by `scale` and resets `scale` to 1.
    ///
    /// NOTE(review): the division uses `N`'s own `DivAssign`. For primitive
    /// integer types this truncates, so precision may be lost.
    pub fn no_scale(self) -> Self
    where
        N: FromPrimitive + DivAssign<N> + ScalarOperand,
    {
        let divisor = N::from_u64(self.scale).unwrap();
        let mut data = self.data;
        data /= divisor;
        QFlag {
            data,
            scale: 1,
            ..self
        }
    }

    /// Applies `g` to every coefficient and to every number in the expression,
    /// keeping the basis and scale.
    pub fn map<G, M>(&self, g: G) -> QFlag<M, F>
    where
        G: Fn(&N) -> M,
    {
        let data = self.data.map(&g);
        let expr = self.expr.map(&g);
        QFlag {
            basis: self.basis,
            data,
            scale: self.scale,
            expr,
        }
    }
}
impl<N, F> QFlag<N, F>
where
    N: Num + FromPrimitive + Clone + Display,
    F: Flag,
{
    /// Builds the single inequality `self ≥ x`. The numeric bound is stored
    /// multiplied by `self.scale`, matching the scaled coefficients.
    pub fn at_least(&self, x: N) -> Ineq<N, F> {
        let meta = IneqMeta {
            basis: self.basis,
            flag_expr: self.expr.clone(),
            bound_expr: Expr::num(&x),
            equality: false,
            forall: None,
            scale: self.scale,
        };
        let row = IneqData {
            flag: csvec_from_array(&self.data),
            bound: x * N::from_u64(self.scale).unwrap(),
        };
        Ineq {
            meta,
            data: vec![row],
        }
    }

    /// Builds `self ≤ x`, expressed as `-self ≥ -x`.
    pub fn at_most(&self, x: N) -> Ineq<N, F>
    where
        N: Clone + Neg<Output = N>,
    {
        let negated = -self.clone();
        negated.at_least(-x)
    }

    /// Builds `self ≥ 0`.
    pub fn non_negative(&self) -> Ineq<N, F>
    where
        N: Num,
    {
        let zero = N::zero();
        self.at_least(zero)
    }

    /// Builds the equality `self = n`.
    pub fn equal(self, n: N) -> Ineq<N, F>
    where
        N: Clone + Neg<Output = N>,
    {
        let lower = self.at_least(n);
        lower.equality()
    }
}
/// Equality stating that `basis.one()` equals 1.
pub fn total_sum_is_one<N, F>(basis: Basis<F>) -> Ineq<N, F>
where
    F: Flag,
    N: Num + Clone + Neg<Output = N> + FromPrimitive + Display,
{
    let one_flag = basis.one();
    one_flag.equal(N::one())
}
pub fn flags_are_nonnegative<N, F>(basis: Basis<F>) -> Ineq<N, F>
where
F: Flag,
N: Num + Clone + Neg<Output = N>,
{
let n = basis.get().len();
let mut data = Vec::with_capacity(n);
for i in 0..n {
let mut flag = CsVec::empty(n);
flag.append(i, N::one());
data.push(IneqData {
flag,
bound: N::zero(),
})
}
let meta = IneqMeta {
basis,
flag_expr: Expr::Var(0).named(format!("flag(:{})", basis.print_concise())),
bound_expr: Expr::Zero,
equality: false,
forall: Some(VarRange::InBasis(basis)),
scale: 1,
};
Ineq { meta, data }
}
/// Descriptive part of an inequality `flag_expr ≥ bound_expr` (or `=` when
/// `equality` is set), shared by every data row of an `Ineq`. It is used by
/// the `Display` and `latex` renderings.
#[derive(Clone, Debug)]
pub struct IneqMeta<N, F> {
/// Basis in which the rows' coefficient vectors are expressed.
pub basis: Basis<F>,
/// Symbolic left-hand side.
pub flag_expr: Expr<N, F>,
/// When `Some`, the inequality is a family quantified over this range;
/// `Ineq::lhs(i)` substitutes index `i` into `flag_expr` through it.
forall: Option<VarRange<F>>,
/// Symbolic right-hand side.
pub bound_expr: Expr<N, F>,
/// `true` for an equality, `false` for `≥`.
pub equality: bool,
/// Denominator shared by the rows' coefficients and bounds (`at_least`
/// stores `x * scale` as the bound).
scale: u64,
}
impl<N, F: Flag> IneqMeta<N, F> {
fn opposite(self) -> Self {
Self {
basis: self.basis,
flag_expr: self.flag_expr.neg(),
bound_expr: self.bound_expr.neg(),
equality: self.equality,
forall: self.forall,
scale: self.scale,
}
}
fn one_sided_expr(&self) -> Expr<N, F> {
Expr::sub(self.flag_expr.clone(), self.bound_expr.clone())
}
fn multiply(&self, rhs_basis: &Basis<F>, rhs_expr: Expr<N, F>) -> Self {
let forall = if let Expr::Var(_) = rhs_expr {
match self.forall {
None => Some(VarRange::InBasis(rhs_basis.clone())),
Some(_) => unimplemented!(),
}
} else {
self.forall.clone()
};
Self {
basis: self.basis * *rhs_basis,
flag_expr: Expr::mul(self.one_sided_expr(), rhs_expr),
bound_expr: Expr::Zero,
equality: self.equality,
forall,
scale: self.scale * SplitCount::from_input(&self.basis, rhs_basis).denom() as u64,
}
}
fn untype(&self) -> Self {
Self {
basis: self.basis.with_type(Type::empty()),
flag_expr: Expr::unlab(self.flag_expr.clone()),
bound_expr: self.bound_expr.clone(),
equality: self.equality,
forall: self.forall.clone(),
scale: self.scale * Unlabel::total(self.basis).denom() as u64,
}
}
pub(crate) fn latex(&self, names: &mut Names<N, F>) -> String
where
N: Display,
{
format!(
"{}{} {} {}",
if let Some(ref range) = self.forall {
range.latex(names)
} else {
"".into()
},
self.flag_expr.latex(names),
if self.equality { "=" } else { "\\geq" },
self.bound_expr.latex(names),
)
}
}
/// One numeric row of an inequality: a sparse coefficient vector and its
/// right-hand side, both in the scaled units of the owning `IneqMeta`
/// (`QFlag::at_least` stores `x * scale` as the bound).
#[derive(Clone, Debug)]
pub struct IneqData<N> {
/// Sparse left-hand-side coefficients over the metadata's basis.
pub flag: CsVec<N>,
/// Scaled right-hand side.
pub bound: N,
}
impl<N> IneqData<N>
where
N: Num + Clone,
{
fn opposite(self) -> Self
where
N: Neg<Output = N>,
{
let mut flag = self.flag;
flag.map_inplace(|x| -x.clone());
Self {
flag,
bound: -self.bound,
}
}
fn one_sided(self) -> Self
where
N: Neg<Output = N>,
{
if self.bound == N::zero() {
self
} else {
let n = self.flag.dim();
let mut flag = CsVec::empty(n);
flag.reserve(n);
let mut next_j = 0;
for (i, val) in self.flag.iter() {
for j in next_j..i {
flag.append(j, -self.bound.clone())
}
flag.append(i, val.clone() - self.bound.clone());
next_j = i + 1;
}
for j in next_j..n {
flag.append(j, -self.bound.clone())
}
Self {
flag,
bound: N::zero(),
}
}
}
fn untype(&self, untype_matrix: &CsMat<N>, denom: u32) -> Self
where
N: Copy + Num + Default + std::iter::Sum + AddAssign + Send + Sync + FromPrimitive,
{
Self {
flag: untype_matrix * &self.flag,
bound: self.bound.clone() * N::from_u32(denom).unwrap(),
}
}
fn multiply_by_all(self, table: &[CsMat<N>], acc: &mut Vec<Self>)
where
N: Num + Copy + Send + Sync + std::iter::Sum + AddAssign + Default + Neg<Output = N>,
{
if let Some(other_size) = table.first().map(|mat| mat.cols()) {
let one_sided = self.one_sided();
let mut flags: Vec<CsVec<N>> = vec![CsVec::empty(table.len()); other_size];
for (i, mat) in table.iter().enumerate() {
let vec: CsVec<N> = &mat.transpose_view() * &one_sided.flag.view();
for (j, val) in vec.iter() {
flags[j].append(i, val.clone())
}
}
for flag in flags.into_iter() {
let ineq_data = Self {
flag,
bound: N::zero(),
};
acc.push(ineq_data)
}
}
}
}
/// An inequality, or a family of inequalities: shared metadata plus one
/// numeric row per instance.
#[derive(Clone, Debug)]
pub struct Ineq<N, F> {
/// Description shared by all rows.
pub meta: IneqMeta<N, F>,
/// Numeric rows; index `i` corresponds to `lhs(i)`.
pub data: Vec<IneqData<N>>,
}
impl<N, F> Ineq<N, F>
where
N: Clone + Num,
F: Flag,
{
pub fn opposite(self) -> Self
where
N: Neg<Output = N>,
{
Self {
meta: self.meta.opposite(),
data: self.data.into_iter().map(|x| x.opposite()).collect(),
}
}
pub fn equality(mut self) -> Self {
self.meta.equality = true;
self
}
pub fn relaxed(mut self, eps: N) -> Self
where
N: SubAssign,
{
for ineq in &mut self.data {
ineq.bound -= eps.clone()
}
self
}
pub fn lhs(&self, i: usize) -> QFlag<N, F> {
assert!(i < self.data.len());
QFlag {
basis: self.meta.basis,
data: array_from_csvec(&self.data[i].flag),
scale: self.meta.scale,
expr: self.meta.flag_expr.substitute_option(&self.meta.forall, i),
}
}
pub fn check(&self)
where
N: Debug + Neg<Output = N> + Clone + FromPrimitive + ScalarOperand + Display,
{
for i in 0..self.data.len() {
let x = self.lhs(i);
assert_eq!(x, x.expr.eval());
}
}
}
impl<N, F> Ineq<N, F>
where
N: Num + Copy + Send + Sync + Default + FromPrimitive + AddAssign + std::iter::Sum,
F: Flag,
{
pub fn untype(&self) -> Self {
let unlabeling = Unlabeling::<F>::total(self.meta.basis.t);
let size = self.meta.basis.size;
let unlabel = Unlabel { unlabeling, size };
let (unlab_f, unlab_c) = unlabel.get();
let outbasis_size = unlabel.output_basis().get().len();
let unlab_matrix = untype_matrix(&unlab_f, &unlab_c, outbasis_size);
let denom = unlabel.denom();
let mut data = Vec::new();
for i in &self.data {
let f = i.untype(&unlab_matrix, denom);
data.push(f)
}
Self {
meta: self.meta.untype(),
data,
}
}
pub fn multiply_by_all(self, outbasis: Basis<F>) -> Self
where
N: Neg<Output = N>,
{
let b = outbasis / self.meta.basis;
let splitcount = SplitCount::from_input(&self.meta.basis, &b);
let table: Vec<CsMat<N>> = splitcount
.get()
.iter()
.map(|m| m.map(|&x| N::from_u32(x).unwrap()))
.collect();
let mut data = Vec::new();
for ineq in self.data {
ineq.multiply_by_all(&table, &mut data)
}
Self {
data,
meta: self.meta.multiply(&b, Expr::Var(0)),
}
}
pub fn multiply_and_unlabel(self, outbasis: Basis<F>) -> Self
where
N: Neg<Output = N>,
{
assert_eq!(outbasis.t, Type::empty());
let unlabeling = Unlabeling::total(self.meta.basis.t);
let other = outbasis.with_type(self.meta.basis.t) / self.meta.basis;
let splitcount = SplitCount::from_input(&self.meta.basis, &other);
let operator = MulAndUnlabel {
split: splitcount,
unlabeling,
};
let table: Vec<CsMat<N>> = operator
.get()
.iter()
.map(|m| m.map(|&x| N::from_i64(x).unwrap()))
.collect();
let mut data = Vec::new();
for ineq in self.data {
ineq.multiply_by_all(&table, &mut data)
}
Self {
data,
meta: self.meta.multiply(&other, Expr::Var(0)).untype(),
}
}
}
pub fn flag<N, F>(f: &F) -> QFlag<N, F>
where
N: Num + Clone,
F: Flag,
{
Basis::new(f.size()).flag(f)
}
pub fn flag_typed<N, F>(f: &F, type_size: usize) -> QFlag<N, F>
where
N: Num + Clone,
F: Flag,
{
let flag = f.canonical_typed(type_size);
let type_flag = flag.induce(&(0..type_size).collect::<Vec<_>>());
let t = Type::from_flag(&type_flag);
let basis = Basis::new(f.size()).with_type(t);
basis.flag(&flag)
}
// Unit tests for the private scale helper and for `QFlag` equality.
#[cfg(test)]
mod tests {
use super::*;
use crate::flags::Graph;
use ndarray::array;
#[test]
fn test_internals() {
// `matching_scales` returns (multiplier for scale1, multiplier for scale2, lcm).
assert_eq!(matching_scales(15, 12), (4, 5, 60));
assert_eq!(matching_scales(2, 24), (12, 1, 24));
let (c1, c2, scale): (u64, u64, _) = matching_scales(1788, 2444);
// `big` is divisible by both scales, so rescaling to the common scale must
// agree with dividing by each original scale.
let big = 1788 * 2444 * 1048;
assert_eq!((big * c1) / scale, big / 1788);
assert_eq!((big * c2) / scale, big / 2444);
}
#[test]
fn test_qflags() {
let qflag = QFlag {
basis: Basis::<Graph>::new(1),
data: array![3., 2., -5., 3.14],
scale: 42,
expr: Expr::Zero,
};
// `no_scale` divides `data` by `scale`; `PartialEq` compares `data / scale`,
// so both representations should be equal. NOTE(review): with `f64` this
// relies on dividing by 42 and then cross-multiplying by 42 round-tripping
// exactly for these values.
assert_eq!(qflag.clone().no_scale(), qflag)
}
}