use super::poly_utils::{as_integer, is_free_of_subexpr};
use crate::kernel::{ExprData, ExprId, ExprPool};
/// An element `a + b·s` of a quadratic extension, where `s` is a fixed
/// square-root expression whose square is some expression `p`.
///
/// `p` is not stored here; arithmetic that needs it (`mul`, `inv`, `powi`)
/// takes it as an argument.
#[derive(Debug, Clone, Copy)]
pub struct FieldElem {
    /// The component without the square root.
    pub a: ExprId,
    /// The coefficient multiplying the square root.
    pub b: ExprId,
}
impl FieldElem {
    /// Builds `a + 0·√p`, an element with no square-root component.
    pub fn pure_rational(a: ExprId, pool: &ExprPool) -> Self {
        FieldElem {
            a,
            b: pool.integer(0_i32),
        }
    }

    /// Builds `0 + b·√p`, a pure multiple of the square root.
    pub fn pure_sqrt(b: ExprId, pool: &ExprPool) -> Self {
        FieldElem {
            a: pool.integer(0_i32),
            b,
        }
    }

    /// The multiplicative identity `1 + 0·√p`.
    pub fn one(pool: &ExprPool) -> Self {
        FieldElem {
            a: pool.integer(1_i32),
            b: pool.integer(0_i32),
        }
    }

    /// The additive identity `0 + 0·√p`.
    pub fn zero(pool: &ExprPool) -> Self {
        FieldElem {
            a: pool.integer(0_i32),
            b: pool.integer(0_i32),
        }
    }

    /// Component-wise sum: `(a + c) + (b + d)·√p`.
    pub fn add(self, other: FieldElem, pool: &ExprPool) -> FieldElem {
        let a = pool.add(vec![self.a, other.a]);
        let b = pool.add(vec![self.b, other.b]);
        FieldElem { a, b }
    }

    /// Negation: `-a + (-b)·√p`.
    // No caller in this file uses `neg`; it is kept so the arithmetic API is
    // complete, hence the dead-code allowance.
    #[allow(dead_code)]
    pub fn neg(self, pool: &ExprPool) -> FieldElem {
        let neg1 = pool.integer(-1_i32);
        let a = pool.mul(vec![neg1, self.a]);
        let b = pool.mul(vec![neg1, self.b]);
        FieldElem { a, b }
    }

    /// Product `(a + b·√p)(c + d·√p) = (ac + bd·p) + (ad + bc)·√p`.
    ///
    /// `p` is the expression whose square root this extension adjoins.
    pub fn mul(self, other: FieldElem, p: ExprId, pool: &ExprPool) -> FieldElem {
        let ac = pool.mul(vec![self.a, other.a]);
        // √p · √p = p
        let bd_p = pool.mul(vec![self.b, other.b, p]);
        let new_a = pool.add(vec![ac, bd_p]);
        let ad = pool.mul(vec![self.a, other.b]);
        let bc = pool.mul(vec![self.b, other.a]);
        let new_b = pool.add(vec![ad, bc]);
        FieldElem { a: new_a, b: new_b }
    }

    /// Multiplicative inverse of `a + b·√p`.
    ///
    /// When `is_zero_expr` recognises `a` as zero this uses
    /// `(b·√p)⁻¹ = (1 / (b·p))·√p`. Otherwise it multiplies by the conjugate:
    /// `(a + b·√p)⁻¹ = (a − b·√p) / (a² − b²·p)`.
    ///
    /// No check is made that the element is nonzero. For a zero element the
    /// returned expressions take `pow(.., -1)` of a zero-valued expression.
    pub fn inv(self, p: ExprId, pool: &ExprPool) -> FieldElem {
        use super::poly_utils::is_zero_expr;
        if is_zero_expr(self.a, pool) {
            let bp = pool.mul(vec![self.b, p]);
            let new_b = pool.pow(bp, pool.integer(-1_i32));
            return FieldElem {
                a: pool.integer(0_i32),
                b: new_b,
            };
        }
        let a2 = pool.pow(self.a, pool.integer(2_i32));
        let b2_p = pool.mul(vec![pool.pow(self.b, pool.integer(2_i32)), p]);
        let neg1 = pool.integer(-1_i32);
        // Norm of the element: a² − b²·p.
        let norm = pool.add(vec![a2, pool.mul(vec![neg1, b2_p])]);
        let norm_inv = pool.pow(norm, pool.integer(-1_i32));
        let new_a = pool.mul(vec![self.a, norm_inv]);
        let new_b = pool.mul(vec![neg1, self.b, norm_inv]);
        FieldElem { a: new_a, b: new_b }
    }

    /// Raises `self` to the integer power `n` by binary exponentiation.
    ///
    /// `n == 0` yields `one` (even for a zero element). A negative `n` inverts
    /// first (see `inv`) and then raises the inverse to `|n|`. The magnitude is
    /// taken with `unsigned_abs`, so `n == i64::MIN` does not overflow.
    pub fn powi(self, n: i64, p: ExprId, pool: &ExprPool) -> FieldElem {
        if n < 0 {
            // `-n` would overflow for i64::MIN; `unsigned_abs` cannot.
            return self.inv(p, pool).powu(n.unsigned_abs(), p, pool);
        }
        self.powu(n.unsigned_abs(), p, pool)
    }

    /// Square-and-multiply for a non-negative exponent. This is the worker behind `powi`.
    fn powu(self, n: u64, p: ExprId, pool: &ExprPool) -> FieldElem {
        match n {
            0 => FieldElem::one(pool),
            1 => self,
            _ => {
                let half = self.powu(n / 2, p, pool);
                let sq = half.mul(half, p, pool);
                if n % 2 == 0 {
                    sq
                } else {
                    sq.mul(self, p, pool)
                }
            }
        }
    }
}
pub fn decompose_sqrt(
expr: ExprId,
sqrt_id: ExprId,
p_expr: ExprId,
pool: &ExprPool,
) -> Option<(ExprId, ExprId)> {
let elem = decompose_elem(expr, sqrt_id, p_expr, pool)?;
Some((elem.a, elem.b))
}
/// Recursively rewrites `expr` as an element `a + b·s`, where `s = sqrt_id`
/// and `s² = p_expr`.
///
/// Sums and products are decomposed term by term. Integer powers go through
/// `FieldElem::powi`, except powers of `s` itself, which are built directly as
/// `p^k` or `p^k · s`. Any subexpression free of `sqrt_id` becomes the
/// square-root-free part unchanged.
///
/// Returns `None` when `sqrt_id` appears in a position this cannot handle:
/// - under a power whose exponent `as_integer` does not recognise;
/// - inside any node other than `Add`, `Mul` or `Pow`.
fn decompose_elem(
    expr: ExprId,
    sqrt_id: ExprId,
    p_expr: ExprId,
    pool: &ExprPool,
) -> Option<FieldElem> {
    if expr == sqrt_id {
        return Some(FieldElem::pure_sqrt(pool.integer(1_i32), pool));
    }
    if is_free_of_subexpr(expr, sqrt_id, pool) {
        return Some(FieldElem::pure_rational(expr, pool));
    }
    // From here on, `expr != sqrt_id` and `expr` does contain `sqrt_id`.
    match pool.get(expr) {
        ExprData::Add(args) => {
            let mut acc = FieldElem::zero(pool);
            for a in &args {
                let elem = decompose_elem(*a, sqrt_id, p_expr, pool)?;
                acc = acc.add(elem, pool);
            }
            Some(acc)
        }
        ExprData::Mul(args) => {
            let mut acc = FieldElem::one(pool);
            for a in &args {
                let elem = decompose_elem(*a, sqrt_id, p_expr, pool)?;
                acc = acc.mul(elem, p_expr, pool);
            }
            Some(acc)
        }
        ExprData::Pow { base, exp } if base == sqrt_id => {
            let n = as_integer(exp, pool)?;
            if n == 0 {
                Some(FieldElem::one(pool))
            } else if (1..=i64::from(u32::MAX)).contains(&n) {
                // Range-checked above, so this cast is lossless (a bare
                // `as u32` would silently truncate larger exponents).
                let k = n as u32;
                // s^(2j) = p^j and s^(2j+1) = p^j · s; for odd k, k / 2 == (k - 1) / 2.
                let p_pow = pool.pow(p_expr, pool.integer(k / 2));
                if k % 2 == 0 {
                    Some(FieldElem::pure_rational(p_pow, pool))
                } else {
                    Some(FieldElem::pure_sqrt(p_pow, pool))
                }
            } else {
                // Negative exponents, and positive ones too large for the
                // direct form, go through generic exponentiation of `s`.
                let unit = FieldElem::pure_sqrt(pool.integer(1_i32), pool);
                Some(unit.powi(n, p_expr, pool))
            }
        }
        ExprData::Pow { base, exp } => {
            let n = as_integer(exp, pool)?;
            let base_elem = decompose_elem(base, sqrt_id, p_expr, pool)?;
            Some(base_elem.powi(n, p_expr, pool))
        }
        // `sqrt_id` occurs inside some other node kind (e.g. a function call,
        // including a different `sqrt`). That cannot be written as a + b·s.
        _ => None,
    }
}