1use crate::core::{Expression, Number, Symbol};
34use crate::expr;
35use crate::simplify::Simplify;
36
37pub mod educational;
38
39pub trait Summation {
41 fn finite_sum(&self, variable: &Symbol, start: &Expression, end: &Expression) -> Expression;
55
56 fn infinite_sum(&self, variable: &Symbol, start: &Expression) -> Expression;
70
71 fn finite_product(&self, variable: &Symbol, start: &Expression, end: &Expression)
85 -> Expression;
86
87 fn infinite_product(&self, variable: &Symbol, start: &Expression) -> Expression;
101}
102
/// Namespace for closed-form summation helpers (arithmetic, geometric and power
/// sums) and a simple p-series convergence test.
#[derive(Debug, Clone, Copy, Default)]
pub struct SummationMethods;
110
111struct PowerSumFormula {
116 power: i64,
117 compute: fn(&Expression) -> Expression,
118}
119
120impl PowerSumFormula {
121 const FORMULAS: &'static [PowerSumFormula] = &[
122 PowerSumFormula {
123 power: 0,
124 compute: |n| n.clone(),
125 },
126 PowerSumFormula {
127 power: 1,
128 compute: |n| {
129 let n_plus_1 = Expression::add(vec![n.clone(), expr!(1)]).simplify();
130 Expression::mul(vec![n.clone(), n_plus_1, expr!(1 / 2)]).simplify()
131 },
132 },
133 PowerSumFormula {
134 power: 2,
135 compute: |n| {
136 let n_plus_1 = Expression::add(vec![n.clone(), expr!(1)]).simplify();
137 let two_n = Expression::mul(vec![expr!(2), n.clone()]).simplify();
138 let two_n_plus_1 = Expression::add(vec![two_n, expr!(1)]).simplify();
139 Expression::mul(vec![n.clone(), n_plus_1, two_n_plus_1, expr!(1 / 6)]).simplify()
140 },
141 },
142 PowerSumFormula {
143 power: 3,
144 compute: |n| {
145 let n_plus_1 = Expression::add(vec![n.clone(), expr!(1)]).simplify();
146 let base = Expression::mul(vec![n.clone(), n_plus_1, expr!(1 / 2)]).simplify();
147 Expression::pow(base, expr!(2)).simplify()
148 },
149 },
150 ];
151
152 fn lookup(power: i64) -> Option<fn(&Expression) -> Expression> {
153 Self::FORMULAS
154 .iter()
155 .find(|formula| formula.power == power)
156 .map(|formula| formula.compute)
157 }
158}
159
160impl SummationMethods {
161 pub fn arithmetic_series(
169 first_term: &Expression,
170 common_difference: &Expression,
171 num_terms: &Expression,
172 ) -> Expression {
173 let n_over_2 = Expression::mul(vec![num_terms.clone(), expr!(1 / 2)]);
174 let two_a = Expression::mul(vec![expr!(2), first_term.clone()]);
175 let n_minus_1 = Expression::add(vec![num_terms.clone(), expr!(-1)]);
176 let n_minus_1_times_d = Expression::mul(vec![n_minus_1, common_difference.clone()]);
177 let inner_sum = Expression::add(vec![two_a, n_minus_1_times_d]);
178
179 Expression::mul(vec![n_over_2, inner_sum]).simplify()
180 }
181
182 pub fn geometric_series(
191 first_term: &Expression,
192 common_ratio: &Expression,
193 num_terms: &Expression,
194 ) -> Expression {
195 let simplified_ratio = common_ratio.simplify();
196 let ratio_power = Expression::pow(simplified_ratio.clone(), num_terms.clone()).simplify();
197 let one_minus_ratio_power = Expression::add(vec![
198 expr!(1),
199 Expression::mul(vec![expr!(-1), ratio_power]),
200 ])
201 .simplify();
202
203 let numerator = Expression::mul(vec![first_term.clone(), one_minus_ratio_power]).simplify();
204 let denominator = Expression::add(vec![
205 expr!(1),
206 Expression::mul(vec![expr!(-1), simplified_ratio]),
207 ])
208 .simplify();
209
210 Expression::mul(vec![numerator, Expression::pow(denominator, expr!(-1))]).simplify()
211 }
212
213 pub fn infinite_geometric_series(
225 first_term: &Expression,
226 common_ratio: &Expression,
227 ) -> Expression {
228 let one_minus_r = Expression::add(vec![
229 expr!(1),
230 Expression::mul(vec![expr!(-1), common_ratio.clone()]),
231 ])
232 .simplify();
233
234 Expression::mul(vec![
235 first_term.clone(),
236 Expression::pow(one_minus_r, expr!(-1)),
237 ])
238 .simplify()
239 }
240
241 pub fn power_sum(power: &Expression, upper_limit: &Expression) -> Expression {
253 if let Expression::Number(Number::Integer(k_val)) = power {
254 if let Some(compute_fn) = PowerSumFormula::lookup(*k_val) {
255 return compute_fn(upper_limit);
256 }
257 }
258
259 Expression::function("power_sum", vec![power.clone(), upper_limit.clone()])
260 }
261
262 pub fn convergence_test(expr: &Expression, variable: &Symbol) -> ConvergenceResult {
270 if let Expression::Pow(base, exp) = expr {
271 if matches!(
272 (base.as_ref(), exp.as_ref()),
273 (Expression::Symbol(sym), Expression::Number(Number::Float(exp_val)))
274 if sym == variable && *exp_val < -1.0
275 ) {
276 return ConvergenceResult::Convergent;
277 }
278
279 if matches!(
280 (base.as_ref(), exp.as_ref()),
281 (Expression::Symbol(sym), Expression::Number(Number::Float(exp_val)))
282 if sym == variable && *exp_val >= -1.0
283 ) {
284 return ConvergenceResult::Divergent;
285 }
286 }
287
288 ConvergenceResult::Unknown
289 }
290}
291
/// Outcome of a series convergence test.
///
/// Note: `SummationMethods::convergence_test` only ever yields `Convergent`,
/// `Divergent` or `Unknown`; `ConditionallyConvergent` is reserved for other tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvergenceResult {
    /// The series converges.
    Convergent,
    /// The series diverges.
    Divergent,
    /// The series converges, but not absolutely.
    ConditionallyConvergent,
    /// The test could not decide.
    Unknown,
}
300
301impl Summation for Expression {
302 fn finite_sum(&self, variable: &Symbol, start: &Expression, end: &Expression) -> Expression {
303 if let Expression::Symbol(sym) = self {
304 if sym == variable {
305 let n = Expression::add(vec![
306 end.clone(),
307 Expression::mul(vec![expr!(-1), start.clone()]),
308 expr!(1),
309 ]);
310
311 let first = start.clone();
312 let last = end.clone();
313
314 return Expression::mul(vec![n, Expression::add(vec![first, last]), expr!(1 / 2)])
315 .simplify();
316 }
317 }
318
319 if matches!(
320 self,
321 Expression::Pow(base, _) if matches!(base.as_ref(), Expression::Symbol(sym) if sym == variable)
322 ) {
323 if let Expression::Pow(_, exp) = self {
324 return SummationMethods::power_sum(exp, end);
325 }
326 }
327
328 Expression::function(
329 "finite_sum",
330 vec![
331 self.clone(),
332 variable.clone().into(),
333 start.clone(),
334 end.clone(),
335 ],
336 )
337 }
338
339 fn infinite_sum(&self, variable: &Symbol, start: &Expression) -> Expression {
340 if let Expression::Pow(base, exp) = self {
341 if matches!(
342 (base.as_ref(), exp.as_ref()),
343 (Expression::Number(Number::Float(r_val)), Expression::Symbol(sym))
344 if sym == variable && r_val.abs() < 1.0
345 ) {
346 let one_minus_r = Expression::add(vec![
347 expr!(1),
348 Expression::mul(vec![expr!(-1), base.as_ref().clone()]),
349 ]);
350
351 return Expression::mul(vec![
352 Expression::pow(base.as_ref().clone(), start.clone()),
353 Expression::pow(one_minus_r, expr!(-1)),
354 ])
355 .simplify();
356 }
357 }
358
359 Expression::function(
360 "infinite_sum",
361 vec![self.clone(), variable.clone().into(), start.clone()],
362 )
363 }
364
365 fn finite_product(
366 &self,
367 variable: &Symbol,
368 start: &Expression,
369 end: &Expression,
370 ) -> Expression {
371 Expression::function(
372 "finite_product",
373 vec![
374 self.clone(),
375 variable.clone().into(),
376 start.clone(),
377 end.clone(),
378 ],
379 )
380 }
381
382 fn infinite_product(&self, variable: &Symbol, start: &Expression) -> Expression {
383 Expression::function(
384 "infinite_product",
385 vec![self.clone(), variable.clone().into(), start.clone()],
386 )
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_arithmetic_series() {
        let first = expr!(1);
        let diff = expr!(1);
        let n = expr!(10);

        let result = SummationMethods::arithmetic_series(&first, &diff, &n);
        assert_eq!(result, expr!(55));
    }

    #[test]
    fn test_geometric_series() {
        let first = expr!(1);
        let ratio = expr!(1 / 2);
        let n = expr!(3);

        let result = SummationMethods::geometric_series(&first, &ratio, &n);
        assert_eq!(result.simplify(), Expression::rational(7, 4));
    }

    #[test]
    fn test_power_sum_zero() {
        // Σ_{i=1}^{7} i^0 = 7; the power-0 formula returns the upper limit itself.
        let power = expr!(0);
        let n = expr!(7);

        let result = SummationMethods::power_sum(&power, &n);
        assert_eq!(result, expr!(7));
    }

    #[test]
    fn test_power_sum_linear() {
        let power = expr!(1);
        let n = expr!(5);

        let result = SummationMethods::power_sum(&power, &n);
        assert_eq!(result.simplify(), expr!(15));
    }

    #[test]
    fn test_power_sum_quadratic() {
        let power = expr!(2);
        let n = expr!(3);

        let result = SummationMethods::power_sum(&power, &n);
        assert_eq!(result.simplify(), expr!(14));
    }

    #[test]
    fn test_power_sum_cubic() {
        // 1 + 8 + 27 = 36 = (3·4/2)²
        let power = expr!(3);
        let n = expr!(3);

        let result = SummationMethods::power_sum(&power, &n);
        assert_eq!(result.simplify(), expr!(36));
    }

    #[test]
    fn test_power_sum_unknown_power_is_symbolic() {
        let power = expr!(7);
        let n = expr!(5);

        let result = SummationMethods::power_sum(&power, &n);
        assert_eq!(
            result,
            Expression::function("power_sum", vec![expr!(7), expr!(5)])
        );
    }

    #[test]
    fn test_convergence_non_power_is_unknown() {
        let n = symbol!(n);
        let result = SummationMethods::convergence_test(&expr!(5), &n);
        assert_eq!(result, ConvergenceResult::Unknown);
    }

    #[test]
    fn test_finite_sum_linear() {
        let i = symbol!(i);
        let start = expr!(1);
        let end = expr!(4);

        let expr_i: Expression = i.clone().into();
        let result = expr_i.finite_sum(&i, &start, &end);
        assert_eq!(result.simplify(), expr!(10));
    }

    #[test]
    fn test_infinite_geometric_series() {
        let first = expr!(1);
        let ratio = Expression::rational(1, 3);

        let result = SummationMethods::infinite_geometric_series(&first, &ratio);
        assert_eq!(result, Expression::rational(3, 2));
    }
}