mathhook_core/calculus/ode/
classifier.rs

//! ODE Classification Module
//!
//! Detects the type of an ODE so that the caller can dispatch it to the matching
//! specialized solver. Classifying first avoids attempting every solution method
//! in turn.
6
7use crate::core::{Expression, Symbol};
8
/// ODE classification types covering all implemented solvers.
///
/// The enum is fieldless, so it is `Copy`: classifiers can return it and callers
/// can compare or store it without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ODEType {
    /// First-order separable: `dy/dx = g(x) h(y)`
    Separable,
    /// First-order linear: `dy/dx + p(x) y = q(x)`
    LinearFirstOrder,
    /// First-order exact: `M(x,y) dx + N(x,y) dy = 0` with `∂M/∂y = ∂N/∂x`
    Exact,
    /// First-order Bernoulli: `dy/dx + p(x) y = q(x) y^n`
    Bernoulli,
    /// First-order homogeneous: `dy/dx = f(y/x)`
    Homogeneous,
    /// Second-order constant coefficients: `a y'' + b y' + c y = f(x)`
    ConstantCoefficients,
    /// Second-order variable coefficients
    VariableCoefficients,
    /// Unknown or unsupported type
    Unknown,
}
29
/// Stateless ODE classifier; all functionality is exposed through associated functions.
///
/// Deriving `Debug`/`Default`/`Clone`/`Copy` lets the (zero-sized) value be printed,
/// embedded in other `Debug` types, and constructed generically.
#[derive(Debug, Clone, Copy, Default)]
pub struct ODEClassifier;
32
33impl ODEClassifier {
34    /// Classify a first-order ODE
35    ///
36    /// Attempts to classify the ODE in order of computational efficiency:
37    /// 1. Separable (fastest, widest coverage)
38    /// 2. Linear first-order (integrating factor method)
39    /// 3. Exact (requires exactness condition check)
40    /// 4. Bernoulli (transforms to linear)
41    /// 5. Homogeneous (substitution method)
42    ///
43    /// # Arguments
44    ///
45    /// * `rhs` - Right-hand side of dy/dx = rhs
46    /// * `dependent` - Dependent variable (y)
47    /// * `independent` - Independent variable (x)
48    ///
49    /// # Examples
50    ///
51    /// ```rust
52    /// use mathhook_core::calculus::ode::classifier::{ODEClassifier, ODEType};
53    /// use mathhook_core::{symbol, expr, Expression};
54    ///
55    /// let x = symbol!(x);
56    /// let y = symbol!(y);
57    ///
58    /// let rhs = expr!(x * y);
59    /// let ode_type = ODEClassifier::classify_first_order(&rhs, &y, &x);
60    /// assert_eq!(ode_type, ODEType::Separable);
61    /// ```
62    pub fn classify_first_order(
63        rhs: &Expression,
64        dependent: &Symbol,
65        independent: &Symbol,
66    ) -> ODEType {
67        if Self::is_separable(rhs, dependent, independent) {
68            return ODEType::Separable;
69        }
70
71        if Self::is_linear_first_order(rhs, dependent, independent) {
72            return ODEType::LinearFirstOrder;
73        }
74
75        if Self::is_bernoulli(rhs, dependent, independent) {
76            return ODEType::Bernoulli;
77        }
78
79        if Self::is_exact(rhs, dependent, independent) {
80            return ODEType::Exact;
81        }
82
83        if Self::is_homogeneous(rhs, dependent, independent) {
84            return ODEType::Homogeneous;
85        }
86
87        ODEType::Unknown
88    }
89
90    /// Classify a second-order ODE
91    ///
92    /// # Arguments
93    ///
94    /// * `lhs` - Left-hand side expression (usually y'', y', y terms)
95    /// * `rhs` - Right-hand side expression (forcing function)
96    /// * `dependent` - Dependent variable (y)
97    /// * `independent` - Independent variable (x)
98    ///
99    /// # Examples
100    ///
101    /// ```rust
102    /// use mathhook_core::calculus::ode::classifier::{ODEClassifier, ODEType};
103    /// use mathhook_core::{symbol, expr, Expression};
104    ///
105    /// let x = symbol!(x);
106    /// let y = symbol!(y);
107    ///
108    /// let ode_type = ODEClassifier::classify_second_order(
109    ///     &expr!(y + y),
110    ///     &Expression::integer(0),
111    ///     &y,
112    ///     &x
113    /// );
114    /// assert_eq!(ode_type, ODEType::ConstantCoefficients);
115    /// ```
116    pub fn classify_second_order(
117        _lhs: &Expression,
118        _rhs: &Expression,
119        _dependent: &Symbol,
120        _independent: &Symbol,
121    ) -> ODEType {
122        ODEType::ConstantCoefficients
123    }
124
125    /// Check if ODE is separable: dy/dx = g(x)h(y)
126    ///
127    /// An ODE is separable if the RHS can be written as a product of
128    /// a function of x only and a function of y only.
129    fn is_separable(rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
130        use super::first_order::SeparableODESolver;
131        SeparableODESolver::new().is_separable(rhs, dependent, independent)
132    }
133
134    /// Check if ODE is linear first-order: dy/dx + p(x)y = q(x)
135    ///
136    /// A first-order ODE is linear if it can be written in the form
137    /// dy/dx + p(x)y = q(x), where p and q are functions of x only.
138    fn is_linear_first_order(rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
139        match rhs {
140            Expression::Add(terms) => {
141                let mut has_y_term = false;
142                let mut has_const_term = false;
143
144                for term in terms.iter() {
145                    if term.contains_variable(dependent) {
146                        if Self::is_linear_in_y(term, dependent) {
147                            has_y_term = true;
148                        } else {
149                            return false;
150                        }
151                    } else if term.contains_variable(independent) {
152                        has_const_term = true;
153                    }
154                }
155
156                has_y_term || has_const_term
157            }
158            Expression::Mul(factors) => {
159                let mut y_count = 0;
160                for factor in factors.iter() {
161                    if factor.contains_variable(dependent) {
162                        y_count += 1;
163                    }
164                }
165                y_count <= 1
166            }
167            _ => !rhs.contains_variable(dependent) || Self::is_linear_in_y(rhs, dependent),
168        }
169    }
170
171    /// Check if expression is linear in the dependent variable
172    fn is_linear_in_y(expr: &Expression, y: &Symbol) -> bool {
173        match expr {
174            Expression::Symbol(s) => s == y,
175            Expression::Mul(factors) => {
176                let mut y_count = 0;
177                for factor in factors.iter() {
178                    if factor.contains_variable(y) {
179                        if matches!(factor, Expression::Symbol(s) if s == y) {
180                            y_count += 1;
181                        } else {
182                            return false;
183                        }
184                    }
185                }
186                y_count <= 1
187            }
188            _ => false,
189        }
190    }
191
192    /// Check if ODE is Bernoulli: dy/dx + p(x)y = q(x)y^n
193    ///
194    /// Bernoulli equations can be transformed to linear equations via
195    /// the substitution v = y^(1-n).
196    fn is_bernoulli(rhs: &Expression, dependent: &Symbol, _independent: &Symbol) -> bool {
197        match rhs {
198            Expression::Add(terms) => {
199                let mut has_y_power = false;
200                let mut has_linear_y = false;
201
202                for term in terms.iter() {
203                    if term.contains_variable(dependent) {
204                        if Self::has_y_power(term, dependent) {
205                            has_y_power = true;
206                        } else if Self::is_linear_in_y(term, dependent) {
207                            has_linear_y = true;
208                        }
209                    }
210                }
211
212                has_y_power && has_linear_y
213            }
214            _ => false,
215        }
216    }
217
218    /// Check if expression contains y raised to a power (not just y)
219    fn has_y_power(expr: &Expression, y: &Symbol) -> bool {
220        match expr {
221            Expression::Pow(base, exp) => {
222                matches!(**base, Expression::Symbol(ref s) if s == y)
223                    && !matches!(**exp, Expression::Number(ref n) if n.is_one())
224            }
225            Expression::Mul(factors) => factors.iter().any(|f| Self::has_y_power(f, y)),
226            _ => false,
227        }
228    }
229
230    /// Check if ODE is exact: M(x,y)dx + N(x,y)dy = 0
231    ///
232    /// An ODE is exact if ∂M/∂y = ∂N/∂x.
233    fn is_exact(_rhs: &Expression, _dependent: &Symbol, _independent: &Symbol) -> bool {
234        false
235    }
236
237    /// Check if ODE is homogeneous: dy/dx = f(y/x)
238    ///
239    /// A first-order ODE is homogeneous if it can be written as
240    /// dy/dx = f(y/x) for some function f.
241    fn is_homogeneous(_rhs: &Expression, _dependent: &Symbol, _independent: &Symbol) -> bool {
242        false
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    /// Fresh `(x, y)` symbol pair: `x` is the independent variable, `y` the dependent one.
    fn xy() -> (Symbol, Symbol) {
        (symbol!(x), symbol!(y))
    }

    /// Builds `sym^2`.
    fn squared(sym: &Symbol) -> Expression {
        Expression::pow(Expression::symbol(sym.clone()), Expression::integer(2))
    }

    #[test]
    fn test_classify_separable_product() {
        let (x, y) = xy();
        let rhs = expr!(x * y);
        let kind = ODEClassifier::classify_first_order(&rhs, &y, &x);
        assert_eq!(kind, ODEType::Separable);
    }

    #[test]
    fn test_classify_separable_quotient() {
        let (x, y) = xy();
        let rhs = expr!(x / y);
        let kind = ODEClassifier::classify_first_order(&rhs, &y, &x);
        assert_eq!(kind, ODEType::Separable);
    }

    #[test]
    fn test_classify_linear_simple() {
        let (x, y) = xy();

        // dy/dx = -y + x
        let minus_y = Expression::mul(vec![Expression::integer(-1), Expression::symbol(y.clone())]);
        let rhs = Expression::add(vec![minus_y, Expression::symbol(x.clone())]);

        let kind = ODEClassifier::classify_first_order(&rhs, &y, &x);
        assert_eq!(kind, ODEType::LinearFirstOrder);
    }

    #[test]
    fn test_classify_linear_with_coefficient() {
        let (x, y) = xy();
        // `x*y` is also linear, but the separable check runs first.
        let rhs = expr!(x * y);
        let kind = ODEClassifier::classify_first_order(&rhs, &y, &x);
        assert_eq!(kind, ODEType::Separable);
    }

    #[test]
    fn test_classify_bernoulli() {
        let (x, y) = xy();

        // dy/dx = y + x*y^2
        let power_term = Expression::mul(vec![Expression::symbol(x.clone()), squared(&y)]);
        let rhs = Expression::add(vec![Expression::symbol(y.clone()), power_term]);

        let kind = ODEClassifier::classify_first_order(&rhs, &y, &x);
        assert_eq!(kind, ODEType::Bernoulli);
    }

    #[test]
    fn test_classify_unknown() {
        let (x, y) = xy();

        // dy/dx = sin(x*y)
        let product = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        let rhs = Expression::function("sin", vec![product]);

        let kind = ODEClassifier::classify_first_order(&rhs, &y, &x);
        assert_eq!(kind, ODEType::Unknown);
    }

    #[test]
    fn test_is_linear_in_y_symbol() {
        let (_, y) = xy();
        let bare_y = Expression::symbol(y.clone());
        assert!(ODEClassifier::is_linear_in_y(&bare_y, &y));
    }

    #[test]
    fn test_is_linear_in_y_product() {
        let (_, y) = xy();
        let product = expr!(x * y);
        assert!(ODEClassifier::is_linear_in_y(&product, &y));
    }

    #[test]
    fn test_is_linear_in_y_nonlinear() {
        let (_, y) = xy();
        let y_sq = squared(&y);
        assert!(!ODEClassifier::is_linear_in_y(&y_sq, &y));
    }

    #[test]
    fn test_has_y_power_true() {
        let (_, y) = xy();
        let y_sq = squared(&y);
        assert!(ODEClassifier::has_y_power(&y_sq, &y));
    }

    #[test]
    fn test_has_y_power_false_linear() {
        let (_, y) = xy();
        let bare_y = Expression::symbol(y.clone());
        assert!(!ODEClassifier::has_y_power(&bare_y, &y));
    }

    #[test]
    fn test_classify_second_order_constant_coeff() {
        let (x, y) = xy();
        let lhs = expr!(y + y);
        let rhs = Expression::integer(0);

        let kind = ODEClassifier::classify_second_order(&lhs, &rhs, &y, &x);
        assert_eq!(kind, ODEType::ConstantCoefficients);
    }
}