mathhook_core/core/polynomial/
classification.rs1use crate::core::expression::ExpressionClass;
8use crate::core::polynomial::poly::IntPoly;
9use crate::core::{Expression, Number, Symbol};
10
11pub trait PolynomialClassification {
38 fn is_polynomial(&self) -> bool;
64
65 fn is_polynomial_in(&self, vars: &[Symbol]) -> bool;
89
90 fn polynomial_variables(&self) -> Vec<Symbol>;
108
109 fn classify(&self) -> ExpressionClass;
135
136 fn is_intpoly_compatible(&self) -> bool;
153
154 fn try_as_intpoly(&self) -> Option<(IntPoly, Symbol)>;
174}
175
176impl PolynomialClassification for Expression {
177 fn is_polynomial(&self) -> bool {
178 is_polynomial_impl(self)
179 }
180
181 fn is_polynomial_in(&self, vars: &[Symbol]) -> bool {
182 is_polynomial_in_impl(self, vars)
183 }
184
185 fn polynomial_variables(&self) -> Vec<Symbol> {
186 collect_polynomial_variables(self)
187 }
188
189 fn classify(&self) -> ExpressionClass {
190 classify_impl(self)
191 }
192
193 fn is_intpoly_compatible(&self) -> bool {
194 let vars = self.polynomial_variables();
195 if vars.len() != 1 {
196 return false;
197 }
198 has_only_integer_coefficients(self)
199 }
200
201 fn try_as_intpoly(&self) -> Option<(IntPoly, Symbol)> {
202 let vars = self.polynomial_variables();
203 if vars.len() != 1 {
204 return None;
205 }
206 let var = &vars[0];
207 IntPoly::try_from_expression(self, var).map(|poly| (poly, var.clone()))
208 }
209}
210
211fn extract_integer(expr: &Expression) -> Option<i64> {
213 match expr {
214 Expression::Number(Number::Integer(n)) => Some(*n),
215 _ => None,
216 }
217}
218
219fn is_rational(expr: &Expression) -> bool {
221 matches!(expr, Expression::Number(Number::Rational(_)))
222}
223
224fn is_polynomial_impl(expr: &Expression) -> bool {
226 match expr {
227 Expression::Number(_) => true,
228 Expression::Symbol(_) => true,
229 Expression::Add(terms) | Expression::Mul(terms) => terms.iter().all(is_polynomial_impl),
230 Expression::Pow(base, exp) => {
231 if !is_polynomial_impl(base) {
232 return false;
233 }
234 if let Some(n) = extract_integer(exp) {
235 n >= 0
236 } else {
237 false
238 }
239 }
240 Expression::Function { .. } => false,
241 _ => false,
242 }
243}
244
245fn is_polynomial_in_impl(expr: &Expression, vars: &[Symbol]) -> bool {
247 match expr {
248 Expression::Number(_) => true,
249 Expression::Symbol(_s) => true,
250 Expression::Add(terms) | Expression::Mul(terms) => {
251 terms.iter().all(|t| is_polynomial_in_impl(t, vars))
252 }
253 Expression::Pow(base, exp) => {
254 if !is_polynomial_in_impl(base, vars) {
255 return false;
256 }
257 if let Some(n) = extract_integer(exp) {
258 n >= 0
259 } else {
260 let exp_vars = collect_polynomial_variables(exp);
261 !exp_vars.iter().any(|v| vars.contains(v))
262 }
263 }
264 Expression::Function { .. } => false,
265 _ => false,
266 }
267}
268
269fn collect_polynomial_variables(expr: &Expression) -> Vec<Symbol> {
271 use std::collections::HashSet;
272 let mut vars = HashSet::new();
273 collect_vars_impl(expr, &mut vars);
274 vars.into_iter().collect()
275}
276
277fn collect_vars_impl(expr: &Expression, vars: &mut std::collections::HashSet<Symbol>) {
278 match expr {
279 Expression::Symbol(s) => {
280 vars.insert(s.clone());
281 }
282 Expression::Add(terms) | Expression::Mul(terms) => {
283 for term in terms.iter() {
284 collect_vars_impl(term, vars);
285 }
286 }
287 Expression::Pow(base, exp) => {
288 collect_vars_impl(base, vars);
289 collect_vars_impl(exp, vars);
290 }
291 _ => {}
292 }
293}
294
295fn classify_impl(expr: &Expression) -> ExpressionClass {
297 if extract_integer(expr).is_some() {
298 return ExpressionClass::Integer;
299 }
300
301 if !is_polynomial_impl(expr) {
302 if contains_transcendental(expr) {
303 return ExpressionClass::Transcendental;
304 }
305 return ExpressionClass::Symbolic;
306 }
307
308 let vars = collect_polynomial_variables(expr);
309
310 match vars.len() {
311 0 => {
312 if is_rational(expr) {
313 ExpressionClass::Rational
314 } else {
315 ExpressionClass::Integer
316 }
317 }
318 1 => {
319 let var = vars.into_iter().next().unwrap();
320 let degree = compute_degree(expr, &var).unwrap_or(0);
321 ExpressionClass::UnivariatePolynomial { var, degree }
322 }
323 _ => {
324 let total_degree = vars.iter().filter_map(|v| compute_degree(expr, v)).sum();
325 ExpressionClass::MultivariatePolynomial { vars, total_degree }
326 }
327 }
328}
329
330fn contains_transcendental(expr: &Expression) -> bool {
332 match expr {
333 Expression::Function { name, .. } => {
334 let transcendental_fns = [
335 "sin", "cos", "tan", "cot", "sec", "csc", "sinh", "cosh", "tanh", "exp", "log",
336 "ln", "arcsin", "arccos", "arctan",
337 ];
338 transcendental_fns.contains(&name.as_ref())
339 }
340 Expression::Add(terms) | Expression::Mul(terms) => {
341 terms.iter().any(contains_transcendental)
342 }
343 Expression::Pow(base, exp) => contains_transcendental(base) || contains_transcendental(exp),
344 _ => false,
345 }
346}
347
348fn compute_degree(expr: &Expression, var: &Symbol) -> Option<i64> {
350 match expr {
351 Expression::Number(_) => Some(0),
352 Expression::Symbol(s) => {
353 if s == var {
354 Some(1)
355 } else {
356 Some(0)
357 }
358 }
359 Expression::Add(terms) => terms.iter().filter_map(|t| compute_degree(t, var)).max(),
360 Expression::Mul(terms) => {
361 let degrees: Option<Vec<i64>> = terms.iter().map(|t| compute_degree(t, var)).collect();
362 degrees.map(|ds| ds.into_iter().sum())
363 }
364 Expression::Pow(base, exp) => {
365 let base_deg = compute_degree(base, var)?;
366 let exp_val = extract_integer(exp)?;
367 Some(base_deg * exp_val)
368 }
369 _ => None,
370 }
371}
372
373fn has_only_integer_coefficients(expr: &Expression) -> bool {
375 match expr {
376 Expression::Number(Number::Integer(_)) => true,
377 Expression::Symbol(_) => true,
378 Expression::Add(terms) | Expression::Mul(terms) => {
379 terms.iter().all(has_only_integer_coefficients)
380 }
381 Expression::Pow(base, exp) => {
382 has_only_integer_coefficients(base)
383 && matches!(exp.as_ref(), Expression::Number(Number::Integer(n)) if *n >= 0)
384 }
385 _ => false,
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    #[test]
    fn test_is_polynomial() {
        let x = symbol!(x);

        let polynomials = [
            Expression::integer(5),
            Expression::symbol(x.clone()),
            expr!(x + 1),
            expr!(x ^ 2),
        ];
        for polynomial in &polynomials {
            assert!(polynomial.is_polynomial());
        }
    }

    #[test]
    fn test_classify_integer() {
        assert_eq!(Expression::integer(5).classify(), ExpressionClass::Integer);
    }

    #[test]
    fn test_classify_univariate() {
        let x = symbol!(x);

        match expr!(x ^ 2).classify() {
            ExpressionClass::UnivariatePolynomial { var, degree } => {
                assert_eq!(degree, 2);
                assert_eq!(var, x);
            }
            other => panic!("Expected UnivariatePolynomial, got {:?}", other),
        }
    }

    #[test]
    fn test_polynomial_variables() {
        let (x, y) = (symbol!(x), symbol!(y));
        let sum = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        let found = sum.polynomial_variables();
        assert_eq!(found.len(), 2);
        for expected in [&x, &y] {
            assert!(found.contains(expected));
        }
    }

    #[test]
    fn test_is_polynomial_in() {
        let x = symbol!(x);
        let y = symbol!(y);
        let product = expr!(x * y);

        // x*y is polynomial in {x}, in {y}, and in {x, y}.
        let both = [x.clone(), y.clone()];
        let variable_sets: [&[Symbol]; 3] =
            [std::slice::from_ref(&x), std::slice::from_ref(&y), &both];
        for vars in variable_sets {
            assert!(product.is_polynomial_in(vars));
        }
    }

    #[test]
    fn test_classify_multivariate() {
        let x = symbol!(x);
        let y = symbol!(y);
        let sum = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        match sum.classify() {
            ExpressionClass::MultivariatePolynomial { vars, .. } => {
                assert_eq!(vars.len(), 2);
                assert!(vars.contains(&x) && vars.contains(&y));
            }
            other => panic!("Expected MultivariatePolynomial, got {:?}", other),
        }
    }

    #[test]
    fn test_classify_transcendental() {
        let sine = Expression::function("sin", vec![Expression::symbol(symbol!(x))]);

        assert_eq!(sine.classify(), ExpressionClass::Transcendental);
    }

    #[test]
    fn test_is_intpoly_compatible() {
        // Single variable, integer coefficients, natural exponents.
        for accepted in [expr!(2 * x + 3), expr!(x ^ 2 + 2 * x + 1)] {
            assert!(accepted.is_intpoly_compatible());
        }

        // Rejected: two variables, a non-integer coefficient, a negative exponent.
        for rejected in [expr!(x + y), expr!(1.5 * x + 2), expr!(x ^ (-1))] {
            assert!(!rejected.is_intpoly_compatible());
        }
    }

    #[test]
    fn test_try_as_intpoly() {
        let x = symbol!(x);

        let converted = expr!(x ^ 2 + 2 * x + 3).try_as_intpoly();
        assert!(converted.is_some());

        let (intpoly, var) = converted.unwrap();
        assert_eq!(var, x);
        assert_eq!(intpoly.degree(), Some(2));
        assert_eq!(intpoly.coefficients(), &[3, 2, 1]);
    }
}