// mathhook_core/calculus/ode/classifier.rs
use crate::core::{Expression, Symbol};

/// Family an ordinary differential equation has been classified into.
///
/// First-order families are reported by `ODEClassifier::classify_first_order`;
/// the second-order families by `ODEClassifier::classify_second_order`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ODEType {
    /// First order, `dy/dx = g(x)·h(y)`.
    Separable,
    /// First order, `dy/dx = p(x)·y + q(x)`.
    LinearFirstOrder,
    /// First order exact equation, `M(x, y) dx + N(x, y) dy = 0` with `∂M/∂y = ∂N/∂x`.
    Exact,
    /// First order, `dy/dx = p(x)·y + q(x)·yⁿ` with `n ≠ 1`.
    Bernoulli,
    /// First order homogeneous equation, `dy/dx = f(y/x)`.
    Homogeneous,
    /// Second order with constant coefficients.
    ConstantCoefficients,
    /// Second order with coefficients that vary with the independent variable.
    VariableCoefficients,
    /// None of the recognised patterns matched.
    Unknown,
}

/// Stateless entry point for classifying ODEs; all functionality is exposed
/// through associated functions, so no instance needs to be constructed.
pub struct ODEClassifier;

33impl ODEClassifier {
34 pub fn classify_first_order(
63 rhs: &Expression,
64 dependent: &Symbol,
65 independent: &Symbol,
66 ) -> ODEType {
67 if Self::is_separable(rhs, dependent, independent) {
68 return ODEType::Separable;
69 }
70
71 if Self::is_linear_first_order(rhs, dependent, independent) {
72 return ODEType::LinearFirstOrder;
73 }
74
75 if Self::is_bernoulli(rhs, dependent, independent) {
76 return ODEType::Bernoulli;
77 }
78
79 if Self::is_exact(rhs, dependent, independent) {
80 return ODEType::Exact;
81 }
82
83 if Self::is_homogeneous(rhs, dependent, independent) {
84 return ODEType::Homogeneous;
85 }
86
87 ODEType::Unknown
88 }
89
90 pub fn classify_second_order(
117 _lhs: &Expression,
118 _rhs: &Expression,
119 _dependent: &Symbol,
120 _independent: &Symbol,
121 ) -> ODEType {
122 ODEType::ConstantCoefficients
123 }
124
125 fn is_separable(rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
130 use super::first_order::SeparableODESolver;
131 SeparableODESolver::new().is_separable(rhs, dependent, independent)
132 }
133
134 fn is_linear_first_order(rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
139 match rhs {
140 Expression::Add(terms) => {
141 let mut has_y_term = false;
142 let mut has_const_term = false;
143
144 for term in terms.iter() {
145 if term.contains_variable(dependent) {
146 if Self::is_linear_in_y(term, dependent) {
147 has_y_term = true;
148 } else {
149 return false;
150 }
151 } else if term.contains_variable(independent) {
152 has_const_term = true;
153 }
154 }
155
156 has_y_term || has_const_term
157 }
158 Expression::Mul(factors) => {
159 let mut y_count = 0;
160 for factor in factors.iter() {
161 if factor.contains_variable(dependent) {
162 y_count += 1;
163 }
164 }
165 y_count <= 1
166 }
167 _ => !rhs.contains_variable(dependent) || Self::is_linear_in_y(rhs, dependent),
168 }
169 }
170
171 fn is_linear_in_y(expr: &Expression, y: &Symbol) -> bool {
173 match expr {
174 Expression::Symbol(s) => s == y,
175 Expression::Mul(factors) => {
176 let mut y_count = 0;
177 for factor in factors.iter() {
178 if factor.contains_variable(y) {
179 if matches!(factor, Expression::Symbol(s) if s == y) {
180 y_count += 1;
181 } else {
182 return false;
183 }
184 }
185 }
186 y_count <= 1
187 }
188 _ => false,
189 }
190 }
191
192 fn is_bernoulli(rhs: &Expression, dependent: &Symbol, _independent: &Symbol) -> bool {
197 match rhs {
198 Expression::Add(terms) => {
199 let mut has_y_power = false;
200 let mut has_linear_y = false;
201
202 for term in terms.iter() {
203 if term.contains_variable(dependent) {
204 if Self::has_y_power(term, dependent) {
205 has_y_power = true;
206 } else if Self::is_linear_in_y(term, dependent) {
207 has_linear_y = true;
208 }
209 }
210 }
211
212 has_y_power && has_linear_y
213 }
214 _ => false,
215 }
216 }
217
218 fn has_y_power(expr: &Expression, y: &Symbol) -> bool {
220 match expr {
221 Expression::Pow(base, exp) => {
222 matches!(**base, Expression::Symbol(ref s) if s == y)
223 && !matches!(**exp, Expression::Number(ref n) if n.is_one())
224 }
225 Expression::Mul(factors) => factors.iter().any(|f| Self::has_y_power(f, y)),
226 _ => false,
227 }
228 }
229
230 fn is_exact(_rhs: &Expression, _dependent: &Symbol, _independent: &Symbol) -> bool {
234 false
235 }
236
237 fn is_homogeneous(_rhs: &Expression, _dependent: &Symbol, _independent: &Symbol) -> bool {
242 false
243 }
244}
245
#[cfg(test)]
mod tests {
    //! Unit tests for `ODEClassifier` and its private helpers.

    use super::*;
    use crate::{expr, symbol};

    #[test]
    fn test_classify_separable_product() {
        let x = symbol!(x);
        let y = symbol!(y);

        let rhs = expr!(x * y);
        assert_eq!(
            ODEClassifier::classify_first_order(&rhs, &y, &x),
            ODEType::Separable
        );
    }

    #[test]
    fn test_classify_separable_quotient() {
        let x = symbol!(x);
        let y = symbol!(y);

        let rhs = expr!(x / y);
        assert_eq!(
            ODEClassifier::classify_first_order(&rhs, &y, &x),
            ODEType::Separable
        );
    }

    #[test]
    fn test_classify_linear_simple() {
        let x = symbol!(x);
        let y = symbol!(y);

        let rhs = Expression::add(vec![
            Expression::mul(vec![Expression::integer(-1), Expression::symbol(y.clone())]),
            Expression::symbol(x.clone()),
        ]);
        assert_eq!(
            ODEClassifier::classify_first_order(&rhs, &y, &x),
            ODEType::LinearFirstOrder
        );
    }

    #[test]
    fn test_classify_linear_with_coefficient() {
        let x = symbol!(x);
        let y = symbol!(y);

        // dy/dx = x·y + 1: linear with the variable coefficient p(x) = x, and not
        // separable because of the constant forcing term. (This test previously
        // duplicated `test_classify_separable_product` verbatim.)
        let rhs = Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
            Expression::integer(1),
        ]);
        assert_eq!(
            ODEClassifier::classify_first_order(&rhs, &y, &x),
            ODEType::LinearFirstOrder
        );
    }

    #[test]
    fn test_classify_bernoulli() {
        let x = symbol!(x);
        let y = symbol!(y);

        let rhs = Expression::add(vec![
            Expression::symbol(y.clone()),
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
            ]),
        ]);
        assert_eq!(
            ODEClassifier::classify_first_order(&rhs, &y, &x),
            ODEType::Bernoulli
        );
    }

    #[test]
    fn test_classify_unknown() {
        let x = symbol!(x);
        let y = symbol!(y);

        let rhs = Expression::function(
            "sin",
            vec![Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ])],
        );
        assert_eq!(
            ODEClassifier::classify_first_order(&rhs, &y, &x),
            ODEType::Unknown
        );
    }

    #[test]
    fn test_is_linear_in_y_symbol() {
        let y = symbol!(y);
        assert!(ODEClassifier::is_linear_in_y(
            &Expression::symbol(y.clone()),
            &y
        ));
    }

    #[test]
    fn test_is_linear_in_y_product() {
        let y = symbol!(y);

        let expr = expr!(x * y);
        assert!(ODEClassifier::is_linear_in_y(&expr, &y));
    }

    #[test]
    fn test_is_linear_in_y_nonlinear() {
        let y = symbol!(y);

        let expr = Expression::pow(Expression::symbol(y.clone()), Expression::integer(2));
        assert!(!ODEClassifier::is_linear_in_y(&expr, &y));
    }

    #[test]
    fn test_has_y_power_true() {
        let y = symbol!(y);

        let expr = Expression::pow(Expression::symbol(y.clone()), Expression::integer(2));
        assert!(ODEClassifier::has_y_power(&expr, &y));
    }

    #[test]
    fn test_has_y_power_false_linear() {
        let y = symbol!(y);

        let expr = Expression::symbol(y.clone());
        assert!(!ODEClassifier::has_y_power(&expr, &y));
    }

    #[test]
    fn test_classify_second_order_constant_coeff() {
        let x = symbol!(x);
        let y = symbol!(y);

        // `classify_second_order` is currently a placeholder that always
        // reports constant coefficients; this pins that behavior.
        let lhs = expr!(y + y);
        let rhs = Expression::integer(0);

        assert_eq!(
            ODEClassifier::classify_second_order(&lhs, &rhs, &y, &x),
            ODEType::ConstantCoefficients
        );
    }
}