use std::collections::HashMap;
use std::sync::Arc;
use crate::symbolic::calculus;
use crate::symbolic::core::DagOp;
use crate::symbolic::core::Expr;
use crate::symbolic::series;
use crate::symbolic::simplify::is_zero;
use crate::symbolic::simplify_dag::simplify;
use crate::symbolic::solve::extract_polynomial_coeffs;
use crate::symbolic::solve::solve;
use crate::symbolic::solve::solve_linear_system;
#[must_use]
/// Rewrites `(a + b)^n` as the binomial sum `Σ_{k=0}^{n} C(n, k) · a^(n-k) · b^k`.
///
/// The result is an `Expr::Summation` over the index variable `k`, running from `0` to the
/// original exponent. Any expression that is not a power whose base is a two-term sum is
/// returned unchanged (cloned).
///
/// NOTE(review): the summation index is always named `k`. If `a` or `b` already contain a
/// free variable `k`, the resulting sum captures it incorrectly.
#[must_use]
pub fn expand_binomial(expr: &Expr) -> Expr {
    if expr.op() != DagOp::Power {
        return expr.clone();
    }
    let children = expr.children();
    if children.len() != 2 || children[0].op() != DagOp::Add {
        return expr.clone();
    }
    let addends = children[0].children();
    if addends.len() != 2 {
        return expr.clone();
    }
    let (a, b) = (&addends[0], &addends[1]);
    let n = children[1].clone();
    let index_name = "k";
    let k = Expr::Variable(index_name.to_string());
    let binomial = combinations(n.as_ref(), k.clone());
    let a_power = Expr::new_pow(a.clone(), Expr::new_sub(n.clone(), k.clone()));
    let b_power = Expr::new_pow(b.clone(), k);
    let summand = Expr::new_mul(binomial, Expr::new_mul(a_power, b_power));
    Expr::Summation(
        Arc::new(summand),
        index_name.to_string(),
        Arc::new(Expr::Constant(0.0)),
        Arc::new(n),
    )
}
#[must_use]
/// Number of ordered selections of `k` items from `n`, i.e. `n! / (n - k)!`.
///
/// The quotient is built symbolically and passed through `simplify`, so constant inputs
/// may collapse to a number while symbolic inputs stay as factorial expressions.
#[must_use]
pub fn permutations(
    n: Expr,
    k: Expr,
) -> Expr {
    let remaining = Expr::new_sub(n.clone(), k);
    let numerator = Expr::Factorial(Arc::new(n));
    let denominator = Expr::Factorial(Arc::new(remaining));
    simplify(&Expr::new_div(numerator, denominator))
}
#[must_use]
/// Binomial coefficient `C(n, k) = P(n, k) / k!`, simplified.
///
/// Delegates the `n! / (n - k)!` part to [`permutations`] and divides by `k!`.
#[must_use]
pub fn combinations(
    n: &Expr,
    k: Expr,
) -> Expr {
    let k_factorial = Expr::Factorial(Arc::new(k.clone()));
    let ordered = permutations(n.clone(), k);
    simplify(&Expr::new_div(ordered, k_factorial))
}
#[must_use]
/// Attempts to solve a linear recurrence given as an `Expr::Eq`.
///
/// The steps are:
/// 1. Build the characteristic equation in `r` and solve it. Repeated roots are counted as
///    multiplicities.
/// 2. Build the homogeneous solution, which introduces constants `C0, C1, …`.
/// 3. Add a particular solution for the forcing term.
///
/// If initial conditions are given (and constants were introduced), the constants are
/// solved for and substituted.
///
/// Returns `Expr::Solve(equation, term)` unchanged when `equation` is not an `Expr::Eq`, or
/// when the constants cannot be determined from `initial_conditions`.
#[must_use]
pub fn solve_recurrence(
    equation: Expr,
    initial_conditions: &[(Expr, Expr)],
    term: &str,
) -> Expr {
    let solved = match &equation {
        | Expr::Eq(lhs, rhs) => solve_recurrence_from_sides(lhs, rhs, initial_conditions, term),
        | _ => None,
    };
    solved.unwrap_or_else(|| Expr::Solve(Arc::new(equation), term.to_string()))
}

/// Core of [`solve_recurrence`] for an equation already split into `lhs = rhs`.
///
/// Returns `None` only when the initial conditions cannot fix the constants.
fn solve_recurrence_from_sides(
    lhs: &Expr,
    rhs: &Expr,
    initial_conditions: &[(Expr, Expr)],
    term: &str,
) -> Option<Expr> {
    let (homogeneous_coeffs, f_n) = deconstruct_recurrence_eq(lhs, rhs, term);
    let char_eq = build_characteristic_equation(&homogeneous_coeffs);
    let mut multiplicities: HashMap<Expr, usize> = HashMap::new();
    for root in solve(&char_eq, "r") {
        *multiplicities.entry(root).or_default() += 1;
    }
    let (homogeneous_solution, const_vars) = build_homogeneous_solution(&multiplicities);
    let particular_solution =
        solve_particular_solution(&f_n, &multiplicities, &homogeneous_coeffs, term);
    let general_solution = simplify(&Expr::new_add(homogeneous_solution, particular_solution));
    if initial_conditions.is_empty() || const_vars.is_empty() {
        // Nothing to pin down: the general solution is the answer.
        return Some(general_solution);
    }
    solve_for_constants(&general_solution, &const_vars, initial_conditions)
}
/// Splits a recurrence equation `lhs = rhs` into homogeneous coefficients and a forcing term.
///
/// Returns `(coeffs, f_n)`. `coeffs` is in the order consumed by
/// `build_characteristic_equation`, where `coeffs[i]` multiplies `r^i`. `f_n` is the
/// forcing term.
///
/// NOTE(review): this is a placeholder. The left-hand side is not inspected at all. The
/// coefficients are always `[-2, 1]`, i.e. the characteristic polynomial `r - 2`. Under the
/// `coeffs[i] ↔ r^i` convention this presumably stands for `a(n) - 2·a(n-1)`; confirm
/// against the intended representation. `rhs` is returned unchanged as `f_n`. Until real
/// coefficient extraction is implemented, every recurrence is treated as this first-order one.
pub(crate) fn deconstruct_recurrence_eq(
    _lhs: &Expr,
    rhs: &Expr,
    _term: &str,
) -> (Vec<Expr>, Expr) {
    // `_lhs` is deliberately unused: simplifying it would be wasted work because nothing
    // below depends on it yet.
    let coeffs = vec![Expr::Constant(-2.0), Expr::Constant(1.0)];
    (coeffs, rhs.clone())
}
/// Builds the characteristic polynomial `Σ_i coeffs[i] · r^i` in the variable `r`.
///
/// The highest-power term is placed first and the remaining terms are added to it in
/// ascending order of power. Returns the constant `0` when `coeffs` is empty.
pub(crate) fn build_characteristic_equation(coeffs: &[Expr]) -> Expr {
    let r = Expr::Variable("r".to_string());
    let mut monomials: Vec<Expr> = coeffs
        .iter()
        .enumerate()
        .map(|(power, coeff)| {
            Expr::new_mul(
                coeff.clone(),
                Expr::new_pow(r.clone(), Expr::Constant(power as f64)),
            )
        })
        .collect();
    let Some(highest) = monomials.pop() else {
        return Expr::Constant(0.0);
    };
    monomials
        .into_iter()
        .fold(highest, |acc, monomial| Expr::new_add(acc, monomial))
}
/// Builds the homogeneous solution `Σ_roots (C_a + C_{a+1}·n + … ) · root^n`.
///
/// A root with multiplicity `m` contributes a polynomial in `n` of degree `m - 1`, with a
/// fresh constant per term. Constants are named `C0, C1, …` in the order they are created.
/// That order follows the iteration order of `root_counts`.
///
/// Returns the (simplified) solution together with the list of constant names.
pub(crate) fn build_homogeneous_solution(
    root_counts: &HashMap<Expr, usize>
) -> (Expr, Vec<String>) {
    let n = Expr::Variable("n".to_string());
    let mut solution = Expr::Constant(0.0);
    let mut names: Vec<String> = Vec::new();
    for (root, &multiplicity) in root_counts {
        let mut coefficient_poly = Expr::Constant(0.0);
        for power in 0..multiplicity {
            // The next constant's index is simply how many constants exist so far.
            let name = format!("C{}", names.len());
            let constant = Expr::Variable(name.clone());
            names.push(name);
            let n_to_power = Expr::new_pow(n.clone(), Expr::Constant(power as f64));
            coefficient_poly = simplify(&Expr::new_add(
                coefficient_poly,
                Expr::new_mul(constant, n_to_power),
            ));
        }
        let exponential = Expr::new_pow(root.clone(), n.clone());
        solution = simplify(&Expr::new_add(
            solution,
            Expr::new_mul(coefficient_poly, exponential),
        ));
    }
    (solution, names)
}
/// Finds a particular solution of a linear recurrence using the method of undetermined
/// coefficients.
///
/// * `f_n` – forcing term (right-hand side), expressed in the variable `n`.
/// * `char_roots` – characteristic roots mapped to their multiplicities. This is passed to
///   `guess_particular_form` to choose the trial form.
/// * `homogeneous_coeffs` – coefficients in the order `build_characteristic_equation` uses.
///   `homogeneous_coeffs[i]` multiplies `r^i`, so with `d = homogeneous_coeffs.len() - 1`
///   it is the coefficient of `a(n - (d - i))`.
/// * `_term` – currently unused.
///
/// Returns the trial form with every unknown coefficient substituted and simplified. Returns
/// the constant `0` when any of these hold:
/// - `f_n` is zero;
/// - no coefficients are given;
/// - no trial form is known;
/// - the coefficient equations cannot be extracted or solved.
pub(crate) fn solve_particular_solution(
    f_n: &Expr,
    char_roots: &HashMap<Expr, usize>,
    homogeneous_coeffs: &[Expr],
    _term: &str,
) -> Expr {
    if is_zero(f_n) || homogeneous_coeffs.is_empty() {
        return Expr::Constant(0.0);
    }
    let (particular_form, unknown_coeffs) = guess_particular_form(f_n, char_roots);
    if unknown_coeffs.is_empty() {
        return Expr::Constant(0.0);
    }
    // Apply the recurrence operator to the trial form P:
    //     L[P](n) = Σ_i c_i · P(n - (d - i)).
    // The shift must mirror the `c_i ↔ r^i` convention of the characteristic equation.
    // Otherwise the particular solution would satisfy a different recurrence than the
    // homogeneous part. Starting from 0 (not from P) avoids counting P(n) twice.
    let order = homogeneous_coeffs.len() - 1;
    let n_var = Expr::Variable("n".to_string());
    let lhs_substituted = homogeneous_coeffs.iter().enumerate().fold(
        Expr::Constant(0.0),
        |acc, (i, coeff)| {
            let shift = order - i;
            let shifted_form = if shift == 0 {
                particular_form.clone()
            } else {
                let shifted_n = Expr::new_sub(n_var.clone(), Expr::Constant(shift as f64));
                calculus::substitute(&particular_form, "n", &shifted_n)
            };
            Expr::new_add(acc, Expr::new_mul(coeff.clone(), shifted_form))
        },
    );
    // L[P](n) - f(n) must vanish identically in n. Each coefficient of that polynomial in n
    // therefore gives one linear equation in the unknowns.
    let equation_to_solve = simplify(&Expr::new_sub(lhs_substituted, f_n.clone()));
    let Some(poly_coeffs) = extract_polynomial_coeffs(&equation_to_solve, "n") else {
        return Expr::Constant(0.0);
    };
    let system_eqs: Vec<Expr> = poly_coeffs
        .into_iter()
        .filter(|coeff_eq| !is_zero(coeff_eq))
        .map(|coeff_eq| Expr::Eq(Arc::new(coeff_eq), Arc::new(Expr::Constant(0.0))))
        .collect();
    match solve_linear_system(&Expr::System(system_eqs), &unknown_coeffs) {
        | Ok(solutions) => {
            let solved = unknown_coeffs
                .iter()
                .zip(solutions.iter())
                .fold(particular_form, |acc, (var, val)| {
                    calculus::substitute(&acc, var, val)
                });
            simplify(&solved)
        },
        | Err(_) => Expr::Constant(0.0),
    }
}
/// Guesses a trial form for a particular solution with forcing term `f_n`, using the method
/// of undetermined coefficients.
///
/// Returns the trial expression in `n` and the names of the unknown coefficients it uses. A
/// `0` form with an empty name list means no trial form is known for `f_n`.
///
/// Recognised shapes:
/// - `Expr::Polynomial` / `Expr::Constant`: `A0 + A1·n + … + Ad·n^d`, where `d` is the degree
///   reported by `extract_polynomial_coeffs` (0 if unavailable).
/// - `b^n`: `A0·b^n`.
/// - `poly(n)·b^n`, with the factors in either order: `(A0 + … + Ad·n^d)·b^n`.
/// - `sin(arg)` / `cos(arg)`: `A·cos(arg) + B·sin(arg)`.
///
/// For the polynomial and exponential shapes, the form is multiplied by `n^s`. Here `s` is
/// the multiplicity in `char_roots` of the root `1` (polynomials) or `b` (exponentials). This
/// keeps the trial form from overlapping the homogeneous solution. The trigonometric shape
/// does not apply this correction.
pub(crate) fn guess_particular_form(
    f_n: &Expr,
    char_roots: &HashMap<Expr, usize>,
) -> (Expr, Vec<String>) {
    let n_var = Expr::Variable("n".to_string());
    let is_n = |e: &Expr| matches!(e, Expr::Variable(v) if v == "n");
    let is_exponential_in_n = |e: &Expr| matches!(e, Expr::Power(_, exp) if is_n(&**exp));
    // Degree of `e` as a polynomial in n. `saturating_sub` avoids the usize underflow that an
    // empty coefficient list would cause with `len() - 1`.
    let poly_degree = |e: &Expr| {
        extract_polynomial_coeffs(e, "n").map_or(0, |c| c.len().saturating_sub(1))
    };
    // `prefix0 + prefix1·n + … + prefix{degree}·n^degree`, plus the coefficient names.
    let create_poly_form = |degree: usize, prefix: &str| -> (Expr, Vec<String>) {
        let mut unknown_coeffs = Vec::new();
        let mut form = Expr::Constant(0.0);
        for i in 0..=degree {
            let coeff_name = format!("{prefix}{i}");
            unknown_coeffs.push(coeff_name.clone());
            form = Expr::new_add(
                form,
                Expr::new_mul(
                    Expr::Variable(coeff_name),
                    Expr::new_pow(n_var.clone(), Expr::Constant(i as f64)),
                ),
            );
        }
        (form, unknown_coeffs)
    };
    // Multiply by n^s when the forcing term resonates with a characteristic root.
    let with_resonance = |form: Expr, s: usize| -> Expr {
        if s > 0 {
            Expr::new_mul(Expr::new_pow(n_var.clone(), Expr::Constant(s as f64)), form)
        } else {
            form
        }
    };
    match f_n {
        | Expr::Polynomial(_) | Expr::Constant(_) => {
            let s = *char_roots.get(&Expr::Constant(1.0)).unwrap_or(&0);
            let (form, coeffs) = create_poly_form(poly_degree(f_n), "A");
            (with_resonance(form, s), coeffs)
        },
        | Expr::Power(base, exp) if is_n(&**exp) => {
            let b = base.clone();
            let s = *char_roots.get(&b).unwrap_or(&0);
            let coeff_name = "A0".to_string();
            let form = Expr::new_mul(Expr::Variable(coeff_name.clone()), f_n.clone());
            (with_resonance(form, s), vec![coeff_name])
        },
        | Expr::Mul(left, right) => {
            // Accept both `poly(n) * b^n` and `b^n * poly(n)`.
            let split = if is_exponential_in_n(&**right) {
                Some((left, right))
            } else if is_exponential_in_n(&**left) {
                Some((right, left))
            } else {
                None
            };
            if let Some((poly_expr, exp_expr)) = split {
                if let Expr::Power(base, _) = &**exp_expr {
                    let b = base.clone();
                    let s = *char_roots.get(&b).unwrap_or(&0);
                    let (poly_form, poly_coeffs) =
                        create_poly_form(poly_degree(&**poly_expr), "A");
                    let form = Expr::new_mul(poly_form, exp_expr.clone());
                    return (with_resonance(form, s), poly_coeffs);
                }
            }
            (Expr::Constant(0.0), vec![])
        },
        | Expr::Sin(arg) | Expr::Cos(arg) => {
            let k_n = arg.clone();
            let coeff_a_name = "A".to_string();
            let coeff_b_name = "B".to_string();
            let unknown_coeffs = vec![coeff_a_name.clone(), coeff_b_name.clone()];
            let form = Expr::new_add(
                Expr::new_mul(Expr::Variable(coeff_a_name), Expr::new_cos(k_n.clone())),
                Expr::new_mul(Expr::Variable(coeff_b_name), Expr::new_sin(k_n)),
            );
            (form, unknown_coeffs)
        },
        | _ => (Expr::Constant(0.0), vec![]),
    }
}
/// Determines the free constants of a general solution from initial conditions.
///
/// Each `(n_val, y_n_val)` pair becomes the equation `general_solution[n := n_val] = y_n_val`.
/// The resulting linear system is solved for `const_vars`, and the values are substituted
/// back into `general_solution`.
///
/// Returns the simplified specific solution, or `None` if the linear system cannot be solved.
pub(crate) fn solve_for_constants(
    general_solution: &Expr,
    const_vars: &[String],
    initial_conditions: &[(Expr, Expr)],
) -> Option<Expr> {
    let equations: Vec<Expr> = initial_conditions
        .iter()
        .map(|(n_val, y_n_val)| {
            let evaluated = calculus::substitute(general_solution, "n", n_val);
            Expr::Eq(Arc::new(evaluated), Arc::new(y_n_val.clone()))
        })
        .collect();
    let const_vals = solve_linear_system(&Expr::System(equations), const_vars).ok()?;
    let resolved = const_vars
        .iter()
        .zip(const_vals.iter())
        .fold(general_solution.clone(), |acc, (c_name, c_val)| {
            calculus::substitute(&acc, c_name, c_val)
        });
    Some(simplify(&resolved))
}
#[must_use]
/// Reads off the first coefficients of a generating function.
///
/// Expands `expr` as a Taylor series in `var` around `0` up to `max_order`. It then extracts
/// the coefficients of that polynomial in `var`. Returns an empty vector if extraction fails.
#[must_use]
pub fn get_sequence_from_gf(
    expr: &Expr,
    var: &str,
    max_order: usize,
) -> Vec<Expr> {
    let origin = Expr::Constant(0.0);
    let truncated = series::taylor_series(expr, var, &origin, max_order);
    // The series is wrapped as `series = 0` before handing it to the coefficient extractor.
    let wrapped = Expr::Eq(Arc::new(truncated), Arc::new(Expr::Constant(0.0)));
    match extract_polynomial_coeffs(&wrapped, var) {
        | Some(coeffs) => coeffs,
        | None => Vec::new(),
    }
}
#[must_use]
/// Size of a union via the inclusion–exclusion principle.
///
/// `intersections[m]` holds the sizes of all intersections of `m + 1` sets. Each level is
/// summed, and the level sums are added at even `m` and subtracted at odd `m`. The total is
/// simplified before returning.
#[must_use]
pub fn apply_inclusion_exclusion(intersections: &[Vec<Expr>]) -> Expr {
    let union_size = intersections.iter().enumerate().fold(
        Expr::Constant(0.0),
        |running, (level, sizes)| {
            let level_sum = sizes
                .iter()
                .fold(Expr::Constant(0.0), |acc, size| Expr::new_add(acc, size.clone()));
            if level % 2 == 0 {
                Expr::new_add(running, level_sum)
            } else {
                Expr::new_sub(running, level_sum)
            }
        },
    );
    simplify(&union_size)
}
#[must_use]
/// Finds the smallest period `p` of a finite sequence.
///
/// A candidate `p` must satisfy both conditions:
/// - `sequence[i] == sequence[i + p]` for every valid `i`;
/// - `p` divides the sequence length, so the sequence is made of whole repetitions.
///
/// Only `p <= len / 2` is considered, so an empty or non-repeating sequence yields `None`.
#[must_use]
pub fn find_period(sequence: &[Expr]) -> Option<usize> {
    let len = sequence.len();
    (1..=len / 2).find(|&p| {
        len.is_multiple_of(p)
            && sequence
                .iter()
                .zip(&sequence[p..])
                .all(|(earlier, later)| earlier == later)
    })
}
#[must_use]
/// The `n`-th Catalan number, `C(2n, n) / (n + 1)`, simplified.
#[must_use]
pub fn catalan_number(n: usize) -> Expr {
    let n_const = Expr::Constant(n as f64);
    let central_binomial = combinations(&Expr::Constant((2 * n) as f64), n_const.clone());
    let n_plus_one = Expr::new_add(n_const, Expr::Constant(1.0));
    simplify(&Expr::new_div(central_binomial, n_plus_one))
}
#[must_use]
/// Stirling number of the second kind, computed with the explicit formula
/// `S(n, k) = (1 / k!) · Σ_{j=0}^{k} (-1)^(k-j) · C(k, j) · j^n`.
#[must_use]
pub fn stirling_number_second_kind(
    n: usize,
    k: usize,
) -> Expr {
    let k_const = Expr::Constant(k as f64);
    let n_const = Expr::Constant(n as f64);
    let alternating_sum = (0..=k).fold(Expr::Constant(0.0), |acc, j| {
        let j_const = Expr::Constant(j as f64);
        // (-1)^(k - j)
        let sign = Expr::Constant(if (k - j) % 2 == 0 { 1.0 } else { -1.0 });
        let binomial = combinations(&k_const, j_const.clone());
        let power = Expr::new_pow(j_const, n_const.clone());
        Expr::new_add(acc, Expr::new_mul(sign, Expr::new_mul(binomial, power)))
    });
    simplify(&Expr::new_div(alternating_sum, Expr::Factorial(Arc::new(k_const))))
}
#[must_use]
/// The `n`-th Bell number, `Σ_{k=0}^{n} S(n, k)`, where `S` is the Stirling number of the
/// second kind. The sum is simplified before returning.
#[must_use]
pub fn bell_number(n: usize) -> Expr {
    let total = (0..=n)
        .map(|k| stirling_number_second_kind(n, k))
        .fold(Expr::Constant(0.0), |acc, term| Expr::new_add(acc, term));
    simplify(&total)
}