use rug::Rational;
/// Dense univariate polynomial with exact rational coefficients.
///
/// Index `i` holds the coefficient of `x^i`, so `[c0, c1, c2]` is `c0 + c1*x + c2*x^2`.
/// The empty vector is the zero polynomial; functions in this module normalise their
/// results with [`trim`] so the last stored coefficient is nonzero.
pub type QPoly = Vec<Rational>;
/// Drops trailing zero coefficients so the highest stored entry (if any) is nonzero.
///
/// This gives every polynomial a canonical representation; the zero polynomial
/// becomes the empty vector.
pub fn trim(mut p: QPoly) -> QPoly {
    let keep = p.iter().rposition(|c| *c != 0).map_or(0, |idx| idx + 1);
    p.truncate(keep);
    p
}
/// Returns the exponent of the highest nonzero coefficient of `p`, or `-1` for the
/// zero polynomial. Trailing zeros are ignored, so `p` need not be trimmed.
pub fn degree(p: &QPoly) -> i64 {
    match p.iter().rposition(|c| *c != 0) {
        Some(top) => top as i64,
        None => -1,
    }
}
/// The zero polynomial, represented by an empty coefficient vector.
pub fn poly_zero() -> QPoly {
    Vec::new()
}
/// The constant polynomial `1`.
// Suppresses the dead-code warning when nothing in the crate calls this helper.
#[allow(dead_code)]
pub fn poly_one() -> QPoly {
    std::iter::once(Rational::from(1)).collect()
}
/// Coefficient-wise sum `a + b`, returned trimmed.
pub fn poly_add(a: &QPoly, b: &QPoly) -> QPoly {
    // Start from a copy of the longer operand and fold the shorter one into it.
    let (longer, shorter) = if a.len() >= b.len() { (a, b) } else { (b, a) };
    let mut sum = longer.clone();
    for (slot, c) in sum.iter_mut().zip(shorter.iter()) {
        *slot += c;
    }
    trim(sum)
}
/// Product `a * b` by schoolbook convolution, returned trimmed.
/// An empty operand (the zero polynomial) yields zero.
pub fn poly_mul(a: &QPoly, b: &QPoly) -> QPoly {
    if a.is_empty() || b.is_empty() {
        return poly_zero();
    }
    let mut prod = vec![Rational::from(0); a.len() + b.len() - 1];
    for (shift, lhs) in a.iter().enumerate() {
        // `lhs * x^shift` contributes to the window prod[shift .. shift + b.len()].
        let window = &mut prod[shift..shift + b.len()];
        for (slot, rhs) in window.iter_mut().zip(b.iter()) {
            *slot += lhs.clone() * rhs.clone();
        }
    }
    trim(prod)
}
/// Multiplies each coefficient of `p` by the scalar `s`, returning a trimmed result.
/// A zero scalar or an empty `p` short-circuits to the zero polynomial.
pub fn poly_scale(p: &QPoly, s: &Rational) -> QPoly {
    if p.is_empty() || *s == 0 {
        return poly_zero();
    }
    let mut scaled = p.clone();
    for coeff in scaled.iter_mut() {
        *coeff *= s.clone();
    }
    trim(scaled)
}
/// Formal derivative: the term `c_i * x^i` becomes `i * c_i * x^(i-1)`.
/// Constants and the zero polynomial differentiate to zero; the result is trimmed.
pub fn poly_deriv(p: &QPoly) -> QPoly {
    let mut out: QPoly = Vec::with_capacity(p.len().saturating_sub(1));
    // The constant term (power 0) vanishes, so start from power 1.
    for (power, c) in p.iter().enumerate().skip(1) {
        out.push(c.clone() * Rational::from(power as i64));
    }
    trim(out)
}
/// Antiderivative with zero constant term: `c_i * x^i` becomes `c_i/(i+1) * x^(i+1)`.
/// The result is trimmed, so the zero polynomial integrates to zero.
pub fn poly_integrate(p: &QPoly) -> QPoly {
    // Trailing zeros in `p` map to trailing zeros here and are removed by `trim`.
    let integral: QPoly = std::iter::once(Rational::from(0))
        .chain(
            p.iter()
                .zip(1i64..)
                .map(|(c, new_power)| c.clone() / Rational::from(new_power)),
        )
        .collect();
    trim(integral)
}
/// Finds a polynomial `y` with `y' + k * deta * y = h`, if one exists.
///
/// * `deta`, `h` - coefficient slices (index `i` is the coefficient of `x^i`);
///   trailing zeros are ignored.
/// * `k` - integer multiplier on the `deta * y` term.
///
/// Returns `Some(y)` (trimmed) when a polynomial solution exists, `None` otherwise.
/// When the `k * deta * y` term vanishes (`deta == 0` or `k == 0`) the equation
/// reduces to `y' = h`; the antiderivative of `h` with zero constant term is returned.
pub fn solve_poly_rde(k: i64, deta: &[Rational], h: &[Rational]) -> Option<QPoly> {
    let deta = trim(deta.to_vec());
    let h = trim(h.to_vec());
    if h.is_empty() {
        return Some(poly_zero());
    }
    let deg_deta = degree(&deta);
    // Without a `k * deta * y` term the equation is simply `y' = h`.
    if deg_deta < 0 || k == 0 {
        return Some(poly_integrate(&h));
    }
    // For nonzero y, deg(y' + k*deta*y) = deg(y) + deg(deta) because k*lc(deta) != 0
    // and deg(y') < deg(y). So deg(y) = deg(h) - deg(deta), and a negative value
    // means no polynomial solution.
    let m = usize::try_from(degree(&h) - deg_deta).ok()?;
    let d = deg_deta as usize;
    let k_rat = Rational::from(k);
    // Coefficient multiplying y[j] in the equation for degree j + d; nonzero here.
    let divisor = k_rat.clone() * deta[d].clone();

    let mut y = vec![Rational::from(0); m + 1];
    // Match coefficients from the top degree downwards. Every other unknown that
    // appears in the degree-(j + d) equation has index > j and is already solved.
    for j in (0..=m).rev() {
        let t = j + d;
        let mut rhs = h.get(t).cloned().unwrap_or_else(|| Rational::from(0));
        // Contribution of y' to x^t is (t + 1) * y[t + 1].
        if t < m {
            rhs -= Rational::from(t as i64 + 1) * y[t + 1].clone();
        }
        // Lower-order terms of k*deta*y: k * deta[i] * y[t - i] for i < d (so t - i > j).
        for (i, deta_i) in deta[..d].iter().enumerate() {
            let l = t - i;
            if l <= m {
                rhs -= k_rat.clone() * deta_i.clone() * y[l].clone();
            }
        }
        y[j] = rhs / divisor.clone();
    }
    let y = trim(y);

    // Only the coefficients of degree >= d were matched above; check the rest by
    // substituting y back into the equation.
    let lhs = poly_add(&poly_deriv(&y), &poly_scale(&poly_mul(&deta, &y), &k_rat));
    polys_equal(&lhs, &h).then_some(y)
}
/// Compares two polynomials as mathematical objects, ignoring trailing zero coefficients.
fn polys_equal(a: &QPoly, b: &QPoly) -> bool {
    trim(a.clone()) == trim(b.clone())
}
use crate::kernel::{ExprData, ExprId, ExprPool};
pub fn expr_to_qpoly(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<QPoly> {
let mut coeffs: Vec<Rational> = Vec::new();
if collect_qpoly(expr, var, pool, &mut coeffs, 1) {
Some(trim(coeffs))
} else {
None
}
}
fn collect_qpoly(
expr: ExprId,
var: ExprId,
pool: &ExprPool,
coeffs: &mut Vec<Rational>,
factor: i64,
) -> bool {
let ensure_len = |coeffs: &mut Vec<Rational>, n: usize| {
while coeffs.len() < n {
coeffs.push(Rational::from(0));
}
};
if expr == var {
ensure_len(coeffs, 2);
coeffs[1] += Rational::from(factor);
return true;
}
match pool.get(expr) {
ExprData::Integer(n) => {
let Some(val) = n.0.to_i64() else {
return false; };
ensure_len(coeffs, 1);
coeffs[0] += Rational::from(factor * val);
true
}
ExprData::Rational(r) => {
let rat = r.0.clone();
ensure_len(coeffs, 1);
coeffs[0] += rat * Rational::from(factor);
true
}
ExprData::Add(args) => {
for a in &args {
if !collect_qpoly(*a, var, pool, coeffs, factor) {
return false;
}
}
true
}
ExprData::Mul(args) => {
let mut rat_factor = Rational::from(factor);
let mut var_parts: Vec<ExprId> = Vec::new();
for &a in &args {
if is_free_of_var(a, var, pool) {
match to_rational_const(a, pool) {
Some(r) => rat_factor *= r,
None => return false, }
} else {
var_parts.push(a);
}
}
if var_parts.is_empty() {
ensure_len(coeffs, 1);
coeffs[0] += rat_factor;
return true;
}
if var_parts.len() == 1 {
let mut sub = Vec::new();
if !collect_qpoly(var_parts[0], var, pool, &mut sub, 1) {
return false;
}
let scale = rat_factor;
ensure_len(coeffs, sub.len());
for (i, c) in sub.iter().enumerate() {
if i >= coeffs.len() {
coeffs.push(Rational::from(0));
}
coeffs[i] += c.clone() * scale.clone();
}
return true;
}
false
}
ExprData::Pow { base, exp } => {
if base == var {
if let ExprData::Integer(n) = pool.get(exp) {
if let Some(n_u) = n.0.to_u32() {
ensure_len(coeffs, n_u as usize + 1);
coeffs[n_u as usize] += Rational::from(factor);
return true;
}
}
}
if is_free_of_var(expr, var, pool) {
if let Some(r) = to_rational_const(expr, pool) {
ensure_len(coeffs, 1);
coeffs[0] += r * Rational::from(factor);
return true;
}
}
false
}
_ => {
if is_free_of_var(expr, var, pool) {
if let Some(r) = to_rational_const(expr, pool) {
ensure_len(coeffs, 1);
coeffs[0] += r * Rational::from(factor);
return true;
}
}
false
}
}
}
/// Reports whether `var` occurs nowhere inside `expr`.
///
/// Recurses through sums, products, powers (base and exponent) and function
/// arguments; every other node kind is treated as free of `var`.
pub fn is_free_of_var(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    if expr == var {
        return false;
    }
    let free = |child: &ExprId| is_free_of_var(*child, var, pool);
    match pool.get(expr) {
        ExprData::Add(args) | ExprData::Mul(args) => args.iter().all(free),
        ExprData::Pow { base, exp } => free(&base) && free(&exp),
        ExprData::Func { ref args, .. } => args.iter().all(free),
        _ => true,
    }
}
/// Evaluates a variable-free `expr` to an exact rational.
///
/// The recognised node kinds are:
/// * integer literals that fit in `i64`;
/// * rational literals;
/// * sums and products of recognised constants;
/// * recognised constants raised to an integer exponent that fits in `i64`.
///
/// Returns `None` for anything else. It also returns `None` for a negative power of zero.
fn to_rational_const(expr: ExprId, pool: &ExprPool) -> Option<Rational> {
    match pool.get(expr) {
        ExprData::Integer(n) => n.0.to_i64().map(Rational::from),
        ExprData::Rational(r) => Some(r.0.clone()),
        ExprData::Add(args) => {
            let mut sum = Rational::from(0);
            for &a in &args {
                sum += to_rational_const(a, pool)?;
            }
            Some(sum)
        }
        ExprData::Mul(args) => {
            let mut result = Rational::from(1);
            for &a in &args {
                result *= to_rational_const(a, pool)?;
            }
            Some(result)
        }
        ExprData::Pow { base, exp } => {
            let ExprData::Integer(n) = pool.get(exp) else {
                return None;
            };
            let n_i = n.0.to_i64()?;
            let b_r = to_rational_const(base, pool)?;
            if n_i < 0 && b_r == 0 {
                // 0^(negative) is a division by zero.
                return None;
            }
            // `unsigned_abs` avoids the overflow that `-n_i` hits at i64::MIN.
            let magnitude = rational_pow_u64(&b_r, n_i.unsigned_abs());
            Some(if n_i >= 0 {
                magnitude
            } else {
                Rational::from(1) / magnitude
            })
        }
        _ => None,
    }
}

/// Computes `base^exp` by binary exponentiation (O(log exp) multiplications).
/// `x^0` is `1` for every `x`, including zero.
/// The result can still be enormous when `|base| > 1` and `exp` is large.
fn rational_pow_u64(base: &Rational, mut exp: u64) -> Rational {
    let mut result = Rational::from(1);
    let mut square = base.clone();
    while exp > 0 {
        if exp & 1 == 1 {
            result *= square.clone();
        }
        exp >>= 1;
        if exp > 0 {
            let copy = square.clone();
            square *= copy;
        }
    }
    result
}
/// Rebuilds the expression `sum_i c_i * var^i` from `poly`.
///
/// Zero coefficients are skipped, and a unit coefficient is not multiplied in.
/// `var^1` is emitted as `var` itself. The zero polynomial becomes the integer `0`,
/// and a single surviving term is returned without an enclosing sum.
pub fn qpoly_to_expr(poly: &QPoly, var: ExprId, pool: &ExprPool) -> ExprId {
    let poly = trim(poly.clone());
    if poly.is_empty() {
        return pool.integer(0_i32);
    }
    let terms: Vec<ExprId> = poly
        .iter()
        .enumerate()
        .filter(|(_, c)| **c != 0)
        .map(|(deg, c)| {
            // Built for every surviving term (even unit coefficients), before any
            // power node, so pool insertions happen in a fixed order.
            let coeff_expr = rational_to_expr(c, pool);
            if deg == 0 {
                return coeff_expr;
            }
            let monomial = if deg == 1 {
                var
            } else {
                pool.pow(var, pool.integer(deg as i32))
            };
            if *c == 1 {
                monomial
            } else {
                pool.mul(vec![coeff_expr, monomial])
            }
        })
        .collect();
    if terms.len() == 1 {
        terms[0]
    } else if terms.is_empty() {
        pool.integer(0_i32)
    } else {
        pool.add(terms)
    }
}
/// Converts an exact rational into an expression node.
///
/// Produces an integer node when the denominator is 1, and otherwise a rational
/// node built from the numerator and denominator.
pub fn rational_to_expr(r: &Rational, pool: &ExprPool) -> ExprId {
    let numer = r.numer().clone();
    let denom = r.denom();
    if *denom != 1 {
        return pool.rational(numer, denom.clone());
    }
    pool.integer(numer)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the integer `n` as a `Rational`.
    fn rat(n: i64) -> Rational {
        Rational::from(n)
    }

    /// Shorthand for the fraction `n / d` as a `Rational`.
    fn rat_frac(n: i64, d: i64) -> Rational {
        Rational::from((n, d))
    }

    /// y' + 2x*y = 1 has no polynomial solution.
    #[test]
    fn rde_exp_x2_nonelementary() {
        let deta = vec![rat(0), rat(2)];
        let h = vec![rat(1)];
        let result = solve_poly_rde(1, &deta, &h);
        assert!(result.is_none());
    }

    /// y' + 2x*y = x is solved by the constant y = 1/2.
    #[test]
    fn rde_x_exp_x2_elementary() {
        let deta = vec![rat(0), rat(2)];
        let h = vec![rat(0), rat(1)];
        let Some(y) = solve_poly_rde(1, &deta, &h) else {
            panic!("expected a polynomial solution");
        };
        assert_eq!(degree(&y), 0, "solution should be a constant");
        assert_eq!(y[0], rat_frac(1, 2), "y = 1/2");
    }

    /// y' + 2x*y = 2x^2 + 1 is solved by y = x.
    #[test]
    fn rde_2x2plus1_exp_x2_elementary() {
        let deta = vec![rat(0), rat(2)];
        let h = vec![rat(1), rat(0), rat(2)];
        let Some(y) = solve_poly_rde(1, &deta, &h) else {
            panic!("expected a polynomial solution");
        };
        assert_eq!(degree(&y), 1, "solution should be linear");
        assert_eq!(y[0], rat(0));
        assert_eq!(y[1], rat(1));
    }

    /// y' + y = x^2 is solved by y = x^2 - 2x + 2.
    #[test]
    fn rde_x2_exp_x_elementary() {
        let deta = vec![rat(1)];
        let h = vec![rat(0), rat(0), rat(1)];
        let Some(y) = solve_poly_rde(1, &deta, &h) else {
            panic!("expected a polynomial solution");
        };
        assert_eq!(degree(&y), 2);
        assert_eq!(y[0], rat(2));
        assert_eq!(y[1], rat(-2));
        assert_eq!(y[2], rat(1));
    }

    /// y' + y = x is solved by y = x - 1.
    #[test]
    fn rde_x_exp_x_elementary() {
        let deta = vec![rat(1)];
        let h = vec![rat(0), rat(1)];
        let Some(y) = solve_poly_rde(1, &deta, &h) else {
            panic!("assertion failed: sol.is_some()");
        };
        assert_eq!(y[0], rat(-1));
        assert_eq!(y[1], rat(1));
    }

    /// Any returned solution must satisfy y' + k*deta*y = h when substituted back.
    #[test]
    fn rde_verify_consistency() {
        let cases = vec![
            (1i64, vec![rat(1)], vec![rat(0), rat(0), rat(1)]),
            (1, vec![rat(0), rat(2)], vec![rat(0), rat(1)]),
            (2, vec![rat(1)], vec![rat(0), rat(1), rat(0), rat(1)]),
        ];
        for (k, deta, h) in cases {
            let Some(y) = solve_poly_rde(k, &deta, &h) else {
                continue;
            };
            let y_prime = poly_deriv(&y);
            let coupled = poly_scale(&poly_mul(&deta, &y), &Rational::from(k));
            let lhs = trim(poly_add(&y_prime, &coupled));
            assert!(
                polys_equal(&lhs, &h),
                "verification failed for k={k}: lhs={lhs:?}, h={h:?}"
            );
        }
    }
}