use rug::Rational;
use super::number_field::{KPoly, NumberField};
/// Dense polynomial over Q: `p[i]` is the coefficient of `x^i` (ascending
/// degree order). The empty vector is the zero polynomial; trailing zeros are
/// tolerated, but most functions in this module strip them with `trim`.
pub type QPoly = Vec<Rational>;
/// Removes trailing zero coefficients so that the last entry, if any, is the
/// leading coefficient. The zero polynomial becomes the empty vector.
pub fn trim(mut p: QPoly) -> QPoly {
    let keep = p.iter().rposition(|c| *c != 0).map_or(0, |idx| idx + 1);
    p.truncate(keep);
    p
}
/// Degree of `p`, ignoring trailing zero coefficients; `-1` for the zero
/// polynomial (including an all-zero or empty vector).
pub fn degree(p: &QPoly) -> i64 {
    p.iter()
        .rposition(|c| *c != 0)
        .map_or(-1, |idx| idx as i64)
}
/// The zero polynomial, represented by an empty coefficient vector.
pub fn poly_zero() -> QPoly {
    Vec::new()
}
/// The constant polynomial `1`.
// Not referenced elsewhere in this file; the lint allowance keeps it available.
#[allow(dead_code)]
pub fn poly_one() -> QPoly {
    std::iter::once(Rational::from(1)).collect()
}
/// Returns `a + b` with trailing zeros removed.
pub fn poly_add(a: &QPoly, b: &QPoly) -> QPoly {
    // Start from a copy of the longer operand and fold the shorter one into it.
    let (longer, shorter) = if a.len() >= b.len() { (a, b) } else { (b, a) };
    let mut sum = longer.clone();
    for (slot, c) in sum.iter_mut().zip(shorter.iter()) {
        *slot += c;
    }
    trim(sum)
}
/// Returns `a * b` computed by schoolbook convolution, with trailing zeros
/// removed. An empty operand yields the zero polynomial.
pub fn poly_mul(a: &QPoly, b: &QPoly) -> QPoly {
    if a.is_empty() || b.is_empty() {
        return poly_zero();
    }
    let mut prod = vec![Rational::from(0); a.len() + b.len() - 1];
    for (shift, x) in a.iter().enumerate() {
        // `prod[shift..]` is always at least `b.len()` long, so every `b` term lands.
        for (slot, y) in prod[shift..].iter_mut().zip(b.iter()) {
            *slot += x.clone() * y.clone();
        }
    }
    trim(prod)
}
/// Multiplies every coefficient of `p` by `s`. Scaling by zero, or scaling an
/// empty polynomial, yields the zero polynomial.
pub fn poly_scale(p: &QPoly, s: &Rational) -> QPoly {
    if p.is_empty() || *s == 0 {
        poly_zero()
    } else {
        let scaled: QPoly = p.iter().map(|c| c.clone() * s.clone()).collect();
        trim(scaled)
    }
}
pub fn poly_deriv(p: &QPoly) -> QPoly {
if p.len() <= 1 {
return poly_zero();
}
trim(
p[1..]
.iter()
.enumerate()
.map(|(i, c)| c.clone() * Rational::from(i as i64 + 1))
.collect(),
)
}
/// Antiderivative with zero constant term: each `p[i]·x^i` becomes
/// `p[i]/(i+1)·x^(i+1)`. The zero polynomial integrates to zero.
pub fn poly_integrate(p: &QPoly) -> QPoly {
    let p = trim(p.clone());
    if p.is_empty() {
        return poly_zero();
    }
    let integrated: QPoly = std::iter::once(Rational::from(0))
        .chain(
            p.iter()
                .enumerate()
                .map(|(i, c)| c.clone() / Rational::from(i as i64 + 1)),
        )
        .collect();
    trim(integrated)
}
/// Finds a polynomial `y` over Q with `y' + k * deta * y = h`, if one exists.
///
/// `deta` and `h` are coefficient slices in ascending degree order.
/// - If `h` is zero, the result is `Some(0)`.
/// - If `deta` is zero, the equation is `y' = h`, and the zero-constant
///   antiderivative of `h` is returned.
///
/// Otherwise `deg y` is forced to `deg h - deg deta`, and `None` is returned
/// when that is negative. The coefficients of `y` are solved from the top
/// down using the equations for degrees `deg deta ..= deg h`. The remaining
/// low-degree equations are then checked by substituting `y` back into the
/// equation, and `None` is returned if they fail.
///
/// # Panics
/// Panics if `k == 0` while both `h` and `deta` are nonzero.
pub fn solve_poly_rde(k: i64, deta: &[Rational], h: &[Rational]) -> Option<QPoly> {
    let deta = trim(deta.to_vec());
    let h = trim(h.to_vec());
    if h.is_empty() {
        return Some(poly_zero());
    }
    let d = degree(&deta);
    if d < 0 {
        return Some(poly_integrate(&h));
    }
    assert!(k != 0, "solve_poly_rde called with k=0: caller bug");
    let Ok(m) = usize::try_from(degree(&h) - d) else {
        return None;
    };
    let d_us = d as usize;
    let k_rat = Rational::from(k);
    // Coefficient multiplying y[j] in the degree-(j + d) equation.
    let pivot = k_rat.clone() * deta[d_us].clone();
    let mut y = vec![Rational::from(0); m + 1];
    for j in (0..=m).rev() {
        let t = j + d_us;
        let mut rhs = h.get(t).cloned().unwrap_or_else(|| Rational::from(0));
        // y' contributes (t + 1) * y[t + 1] at degree t (already solved).
        if let Some(next) = y.get(t + 1) {
            rhs -= Rational::from(t as i64 + 1) * next.clone();
        }
        // Non-leading deta terms pair with y[t - i], all above j and already solved.
        for (i, di) in deta[..d_us].iter().enumerate() {
            let l = t - i;
            if l <= m && l != j {
                rhs -= k_rat.clone() * di.clone() * y[l].clone();
            }
        }
        y[j] = rhs / pivot.clone();
    }
    let y = trim(y);
    let lhs = poly_add(&poly_deriv(&y), &poly_scale(&poly_mul(&deta, &y), &k_rat));
    polys_equal(&lhs, &h).then_some(y)
}
/// Number-field counterpart of `solve_poly_rde`: finds `y` in `K[x]`
/// (with `K` given by `field`) such that `y' + k * deta * y = h`.
///
/// Returns `None` in two cases: when no polynomial solution exists, or when
/// `field.inv` cannot invert `k * lc(deta)`. A zero `h` yields the empty
/// polynomial. A zero `deta` yields `field.kpoly_integrate(h)`.
///
/// # Panics
/// Panics if `k == 0` while both `h` and `deta` are nonzero.
pub fn solve_poly_rde_k(field: &NumberField, k: i64, deta: &KPoly, h: &KPoly) -> Option<KPoly> {
    let deta = NumberField::kpoly_trim(deta.clone());
    let h = NumberField::kpoly_trim(h.clone());
    if h.is_empty() {
        return Some(Vec::new());
    }
    let d = NumberField::kdeg(&deta);
    if d < 0 {
        return Some(field.kpoly_integrate(&h));
    }
    assert!(k != 0, "solve_poly_rde_k called with k=0: caller bug");
    let gap = NumberField::kdeg(&h) - d;
    if gap < 0 {
        return None;
    }
    let (m, d_us) = (gap as usize, d as usize);
    let kk = field.from_int(k);
    let pivot_inv = field.inv(&field.mul(&kk, &deta[d_us]))?;
    let mut y: KPoly = vec![NumberField::k_zero(); m + 1];
    for j in (0..=m).rev() {
        let t = j + d_us;
        let mut rhs = h.get(t).cloned().unwrap_or_else(NumberField::k_zero);
        // Derivative contribution (t + 1) * y[t + 1], present only while t + 1 <= m.
        if t < m {
            let scaled = field.mul(&field.from_int(t as i64 + 1), &y[t + 1]);
            rhs = field.sub(&rhs, &scaled);
        }
        // Lower-degree deta terms pair with already-solved y[t - i].
        for (i, di) in deta.iter().enumerate().take(d_us) {
            let l = t - i;
            if l <= m && l != j {
                let term = field.mul(&field.mul(&kk, di), &y[l]);
                rhs = field.sub(&rhs, &term);
            }
        }
        y[j] = field.mul(&rhs, &pivot_inv);
    }
    let y = NumberField::kpoly_trim(y);
    let lhs = field.kpoly_add(
        &field.kpoly_deriv(&y),
        &field.kpoly_scale(&field.kpoly_mul(&deta, &y), &kk),
    );
    kpoly_eq(&lhs, &h).then_some(y)
}
/// Equality of two `K[x]` polynomials, ignoring trailing zeros in the outer
/// coefficient list and comparing each coefficient after `trim`.
fn kpoly_eq(a: &KPoly, b: &KPoly) -> bool {
    let lhs = NumberField::kpoly_trim(a.clone());
    let rhs = NumberField::kpoly_trim(b.clone());
    lhs.len() == rhs.len()
        && lhs
            .iter()
            .zip(&rhs)
            .all(|(p, q)| trim(p.clone()) == trim(q.clone()))
}
/// Equality of two rational polynomials, ignoring trailing zero coefficients.
fn polys_equal(a: &QPoly, b: &QPoly) -> bool {
    // Vec equality checks length first, then each coefficient pairwise.
    trim(a.clone()) == trim(b.clone())
}
use crate::kernel::{ExprData, ExprId, ExprPool};
pub fn expr_to_qpoly(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<QPoly> {
let mut coeffs: Vec<Rational> = Vec::new();
if collect_qpoly(expr, var, pool, &mut coeffs, 1) {
Some(trim(coeffs))
} else {
None
}
}
fn collect_qpoly(
expr: ExprId,
var: ExprId,
pool: &ExprPool,
coeffs: &mut Vec<Rational>,
factor: i64,
) -> bool {
let ensure_len = |coeffs: &mut Vec<Rational>, n: usize| {
while coeffs.len() < n {
coeffs.push(Rational::from(0));
}
};
if expr == var {
ensure_len(coeffs, 2);
coeffs[1] += Rational::from(factor);
return true;
}
match pool.get(expr) {
ExprData::Integer(n) => {
let Some(val) = n.0.to_i64() else {
return false; };
ensure_len(coeffs, 1);
coeffs[0] += Rational::from(factor * val);
true
}
ExprData::Rational(r) => {
let rat = r.0.clone();
ensure_len(coeffs, 1);
coeffs[0] += rat * Rational::from(factor);
true
}
ExprData::Add(args) => {
for a in &args {
if !collect_qpoly(*a, var, pool, coeffs, factor) {
return false;
}
}
true
}
ExprData::Mul(args) => {
let mut rat_factor = Rational::from(factor);
let mut var_parts: Vec<ExprId> = Vec::new();
for &a in &args {
if is_free_of_var(a, var, pool) {
match to_rational_const(a, pool) {
Some(r) => rat_factor *= r,
None => return false, }
} else {
var_parts.push(a);
}
}
if var_parts.is_empty() {
ensure_len(coeffs, 1);
coeffs[0] += rat_factor;
return true;
}
if var_parts.len() == 1 {
let mut sub = Vec::new();
if !collect_qpoly(var_parts[0], var, pool, &mut sub, 1) {
return false;
}
let scale = rat_factor;
ensure_len(coeffs, sub.len());
for (i, c) in sub.iter().enumerate() {
if i >= coeffs.len() {
coeffs.push(Rational::from(0));
}
coeffs[i] += c.clone() * scale.clone();
}
return true;
}
false
}
ExprData::Pow { base, exp } => {
if base == var {
if let ExprData::Integer(n) = pool.get(exp) {
if let Some(n_u) = n.0.to_u32() {
ensure_len(coeffs, n_u as usize + 1);
coeffs[n_u as usize] += Rational::from(factor);
return true;
}
}
}
if is_free_of_var(expr, var, pool) {
if let Some(r) = to_rational_const(expr, pool) {
ensure_len(coeffs, 1);
coeffs[0] += r * Rational::from(factor);
return true;
}
}
false
}
_ => {
if is_free_of_var(expr, var, pool) {
if let Some(r) = to_rational_const(expr, pool) {
ensure_len(coeffs, 1);
coeffs[0] += r * Rational::from(factor);
return true;
}
}
false
}
}
}
/// Returns `false` when `expr` is `var` or `var` occurs beneath it through
/// `Add`/`Mul`/`Pow`/`Func` children. Every other node kind is treated as
/// `var`-free.
///
/// Each call uses a fresh memo table, so a subexpression reachable along
/// several paths is examined only once.
pub fn is_free_of_var(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    is_free_of_var_memo(expr, var, pool, &mut std::collections::HashMap::new())
}
/// Memoised worker for `is_free_of_var`. `cache` records the result for every
/// visited node other than `var` itself.
fn is_free_of_var_memo(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    cache: &mut std::collections::HashMap<ExprId, bool>,
) -> bool {
    if expr == var {
        return false;
    }
    if let Some(&known) = cache.get(&expr) {
        return known;
    }
    let free = match pool.get(expr) {
        ExprData::Pow { base, exp } => {
            is_free_of_var_memo(base, var, pool, cache)
                && is_free_of_var_memo(exp, var, pool, cache)
        }
        ExprData::Add(args) | ExprData::Mul(args) => args
            .iter()
            .all(|&child| is_free_of_var_memo(child, var, pool, cache)),
        ExprData::Func { ref args, .. } => args
            .iter()
            .all(|&child| is_free_of_var_memo(child, var, pool, cache)),
        // Leaves and any other node kind are never searched further.
        _ => true,
    };
    cache.insert(expr, free);
    free
}
/// Evaluates a constant made of integer/rational literals, products and
/// integer-literal powers to an exact rational.
///
/// Returns `None` in these cases:
/// * any other node kind;
/// * an integer that does not fit in `i64`;
/// * a zero base raised to a negative power.
fn to_rational_const(expr: ExprId, pool: &ExprPool) -> Option<Rational> {
    match pool.get(expr) {
        ExprData::Integer(n) => n.0.to_i64().map(Rational::from),
        ExprData::Rational(r) => Some(r.0.clone()),
        ExprData::Mul(args) => {
            let mut acc = Rational::from(1);
            for &a in &args {
                acc *= to_rational_const(a, pool)?;
            }
            Some(acc)
        }
        ExprData::Pow { base, exp } => {
            let ExprData::Integer(n) = pool.get(exp) else {
                return None;
            };
            let e = n.0.to_i64()?;
            let b = to_rational_const(base, pool)?;
            if e < 0 && b == 0 {
                return None;
            }
            // Repeated multiplication by |e|, then invert for negative powers.
            let reps = if e < 0 { -e } else { e };
            let mut magnitude = Rational::from(1);
            for _ in 0..reps {
                magnitude *= b.clone();
            }
            Some(if e < 0 {
                Rational::from(1) / magnitude
            } else {
                magnitude
            })
        }
        _ => None,
    }
}
/// Builds an expression for `poly` in `var`.
///
/// Terms are emitted in ascending degree, zero coefficients are skipped, and a
/// unit coefficient on `var`/`var^n` is omitted. The zero polynomial becomes
/// the integer `0`, and a single term is returned without an `Add` wrapper.
pub fn qpoly_to_expr(poly: &QPoly, var: ExprId, pool: &ExprPool) -> ExprId {
    let poly = trim(poly.clone());
    let mut terms: Vec<ExprId> = Vec::with_capacity(poly.len());
    for (deg, coeff) in poly.iter().enumerate() {
        if *coeff == 0 {
            continue;
        }
        // Created before the power node, even for unit coefficients, to keep
        // the order of pool insertions unchanged.
        let coeff_expr = rational_to_expr(coeff, pool);
        let unit = *coeff == 1;
        let term = match deg {
            0 => coeff_expr,
            1 if unit => var,
            1 => pool.mul(vec![coeff_expr, var]),
            _ => {
                let power = pool.pow(var, pool.integer(deg as i32));
                if unit {
                    power
                } else {
                    pool.mul(vec![coeff_expr, power])
                }
            }
        };
        terms.push(term);
    }
    match terms.len() {
        0 => pool.integer(0_i32),
        1 => terms[0],
        _ => pool.add(terms),
    }
}
pub fn contains_subexpr(expr: ExprId, target: ExprId, pool: &ExprPool) -> bool {
if expr == target {
return true;
}
match pool.get(expr) {
ExprData::Add(args) | ExprData::Mul(args) => {
args.iter().any(|&a| contains_subexpr(a, target, pool))
}
ExprData::Pow { base, exp } => {
contains_subexpr(base, target, pool) || contains_subexpr(exp, target, pool)
}
ExprData::Func { ref args, .. } => args.iter().any(|&a| contains_subexpr(a, target, pool)),
_ => false,
}
}
/// Converts an exact rational into an expression node. A denominator of 1
/// gives an integer node; anything else gives a rational node built from the
/// numerator and denominator.
pub fn rational_to_expr(r: &Rational, pool: &ExprPool) -> ExprId {
    let (num, den) = (r.numer(), r.denom());
    if *den != 1 {
        return pool.rational(num.clone(), den.clone());
    }
    pool.integer(num.clone())
}
/// Splits `c` into `(constant, rest)` with respect to `var`.
///
/// For a product, the `var`-free factors go into `constant` and the others
/// into `rest`. An empty side becomes `1`, and a single factor is returned
/// unwrapped. A non-product `c` goes entirely to `constant` when it is
/// `var`-free, and entirely to `rest` otherwise.
pub fn split_const_factor(c: ExprId, var: ExprId, pool: &ExprPool) -> (ExprId, ExprId) {
    let one = pool.integer(1_i32);
    let ExprData::Mul(args) = pool.get(c) else {
        return if is_free_of_var(c, var, pool) {
            (c, one)
        } else {
            (one, c)
        };
    };
    let (consts, vars): (Vec<ExprId>, Vec<ExprId>) = args
        .iter()
        .copied()
        .partition(|&a| is_free_of_var(a, var, pool));
    if consts.is_empty() {
        return (one, c);
    }
    let k = if consts.len() == 1 {
        consts[0]
    } else {
        pool.mul(consts)
    };
    let rest = match vars.len() {
        0 => one,
        1 => vars[0],
        _ => pool.mul(vars),
    };
    (k, rest)
}
/// Returns `k_const * core`. When `k_const` is the integer literal `1`, it
/// returns `core` unchanged.
pub fn apply_const(k_const: ExprId, core: ExprId, pool: &ExprPool) -> ExprId {
    let is_unit = matches!(pool.get(k_const), ExprData::Integer(n) if n.0 == 1);
    if !is_unit {
        return pool.mul(vec![k_const, core]);
    }
    core
}
#[cfg(test)]
mod tests {
    use super::*;

    fn rat(n: i64) -> Rational {
        Rational::from(n)
    }

    fn rat_frac(n: i64, d: i64) -> Rational {
        Rational::from((n, d))
    }

    /// `y' + 2x·y = 1` has no polynomial solution.
    #[test]
    fn rde_exp_x2_nonelementary() {
        let deta = vec![rat(0), rat(2)];
        let h = vec![rat(1)];
        assert!(solve_poly_rde(1, &deta, &h).is_none());
    }

    /// `y' + 2x·y = x` is solved by the constant `y = 1/2`.
    #[test]
    fn rde_x_exp_x2_elementary() {
        let deta = vec![rat(0), rat(2)];
        let h = vec![rat(0), rat(1)];
        let sol = solve_poly_rde(1, &deta, &h);
        assert!(sol.is_some(), "expected a polynomial solution");
        let y = sol.unwrap();
        assert_eq!(degree(&y), 0, "solution should be a constant");
        assert_eq!(y[0], rat_frac(1, 2), "y = 1/2");
    }

    /// `y' + 2x·y = 2x² + 1` is solved by `y = x`.
    #[test]
    fn rde_2x2plus1_exp_x2_elementary() {
        let deta = vec![rat(0), rat(2)];
        let h = vec![rat(1), rat(0), rat(2)];
        let sol = solve_poly_rde(1, &deta, &h);
        assert!(sol.is_some(), "expected a polynomial solution");
        let y = sol.unwrap();
        assert_eq!(degree(&y), 1, "solution should be linear");
        assert_eq!(y[0], rat(0));
        assert_eq!(y[1], rat(1));
    }

    /// `y' + y = x²` is solved by `y = x² − 2x + 2`.
    #[test]
    fn rde_x2_exp_x_elementary() {
        let deta = vec![rat(1)];
        let h = vec![rat(0), rat(0), rat(1)];
        let sol = solve_poly_rde(1, &deta, &h);
        assert!(sol.is_some(), "expected a polynomial solution");
        let y = sol.unwrap();
        assert_eq!(degree(&y), 2);
        assert_eq!(y[0], rat(2));
        assert_eq!(y[1], rat(-2));
        assert_eq!(y[2], rat(1));
    }

    /// `y' + y = x` is solved by `y = x − 1`.
    #[test]
    fn rde_x_exp_x_elementary() {
        let deta = vec![rat(1)];
        let h = vec![rat(0), rat(1)];
        let sol = solve_poly_rde(1, &deta, &h);
        assert!(sol.is_some());
        let y = sol.unwrap();
        assert_eq!(y[0], rat(-1));
        assert_eq!(y[1], rat(1));
    }

    /// Every case below has a polynomial solution. The solver must find it,
    /// and substituting it back must reproduce `h`.
    #[test]
    fn rde_verify_consistency() {
        let cases = vec![
            (1i64, vec![rat(1)], vec![rat(0), rat(0), rat(1)]),
            (1, vec![rat(0), rat(2)], vec![rat(0), rat(1)]),
            (2, vec![rat(1)], vec![rat(0), rat(1), rat(0), rat(1)]),
        ];
        for (k, deta, h) in cases {
            // A `None` here is a solver failure, not something to skip.
            let y = solve_poly_rde(k, &deta, &h).unwrap_or_else(|| {
                panic!("no solution found for k={k}, deta={deta:?}, h={h:?}")
            });
            let k_rat = Rational::from(k);
            let lhs = trim(poly_add(
                &poly_deriv(&y),
                &poly_scale(&poly_mul(&deta, &y), &k_rat),
            ));
            assert!(
                polys_equal(&lhs, &h),
                "verification failed for k={k}: lhs={lhs:?}, h={h:?}"
            );
        }
    }

    /// Q(α) with α² − 2 = 0 (coefficients of the defining polynomial, ascending).
    fn field_sqrt2() -> NumberField {
        NumberField::new(vec![rat(-2), rat(0), rat(1)])
    }

    /// Over Q(α), α² = 2: `y' + y = x + α` is solved by `y = x + (α − 1)`.
    #[test]
    fn rde_k_linear_sqrt2() {
        let field = field_sqrt2();
        let deta: KPoly = vec![vec![rat(1)]];
        let h: KPoly = vec![vec![rat(0), rat(1)], vec![rat(1)]];
        let y = solve_poly_rde_k(&field, 1, &deta, &h).expect("elementary");
        assert_eq!(trim(y[0].clone()), vec![rat(-1), rat(1)]);
        assert_eq!(trim(y[1].clone()), vec![rat(1)]);
    }

    /// Over Q(α), α² = 3: `y' + y = x + α·x²`, checked by substitution.
    #[test]
    fn rde_k_quadratic_sqrt3() {
        let field = NumberField::new(vec![rat(-3), rat(0), rat(1)]);
        let deta: KPoly = vec![vec![rat(1)]];
        let h: KPoly = vec![NumberField::k_zero(), vec![rat(1)], vec![rat(0), rat(1)]];
        let y = solve_poly_rde_k(&field, 1, &deta, &h).expect("elementary");
        let yp = field.kpoly_deriv(&y);
        let lhs = NumberField::kpoly_trim(field.kpoly_add(&yp, &y));
        let h = NumberField::kpoly_trim(h);
        assert_eq!(lhs.len(), h.len());
        for (a, b) in lhs.iter().zip(h.iter()) {
            assert_eq!(trim(a.clone()), trim(b.clone()));
        }
    }

    /// Over Q(α), α² = 2: `y' + 2x·y = x + α` has no polynomial solution.
    #[test]
    fn rde_k_nonelementary_sqrt2_gaussian() {
        let field = field_sqrt2();
        let deta: KPoly = vec![NumberField::k_zero(), vec![rat(2)]];
        let h: KPoly = vec![vec![rat(0), rat(1)], vec![rat(1)]];
        assert!(solve_poly_rde_k(&field, 1, &deta, &h).is_none());
    }

    /// Rational data over Q(α) must reproduce the Q answer `y = x² − 2x + 2`.
    #[test]
    fn rde_k_reduces_to_rational_case() {
        let field = field_sqrt2();
        let deta: KPoly = vec![vec![rat(1)]];
        let h: KPoly = vec![NumberField::k_zero(), NumberField::k_zero(), vec![rat(1)]];
        let y = solve_poly_rde_k(&field, 1, &deta, &h).expect("elementary");
        assert_eq!(trim(y[0].clone()), vec![rat(2)]);
        assert_eq!(trim(y[1].clone()), vec![rat(-2)]);
        assert_eq!(trim(y[2].clone()), vec![rat(1)]);
    }
}