differential_equations/
traits.rs

1//! Defines Generics for the library. Includes generics for the floating point numbers.
2
3use nalgebra::{RealField, SMatrix};
4use num_complex::Complex;
5use std::{
6    fmt::Debug,
7    ops::{Add, AddAssign, Div, Mul, Neg, Sub},
8};
9
10/// Real Number Trait
11///
12/// This trait specifies the acceptable types for real numbers.
13/// Currently implemented for:
14/// * `f32` - 32-bit floating point
15/// * `f64` - 64-bit floating point
16///
17/// Provides additional functionality required for ODE solvers beyond
18/// what's provided by nalgebra's RealField trait.
19///
20pub trait Real: Copy + RealField {
21    fn infinity() -> Self;
22    fn to_f64(&self) -> f64;
23    fn to_f32(&self) -> f32;
24}
25
26impl Real for f32 {
27    fn infinity() -> Self {
28        f32::INFINITY
29    }
30
31    fn to_f64(&self) -> f64 {
32        *self as f64
33    }
34
35    fn to_f32(&self) -> f32 {
36        *self
37    }
38}
39
40impl Real for f64 {
41    fn infinity() -> Self {
42        f64::INFINITY
43    }
44
45    fn to_f64(&self) -> f64 {
46        *self
47    }
48
49    fn to_f32(&self) -> f32 {
50        *self as f32
51    }
52}
53
54/// State vector trait
55///
56/// Represents the state of the system being solved.
57///
58/// Implements for the following types:
59/// * `f32` - 32-bit floating point
60/// * `f64` - 64-bit floating point
61/// * `SMatrix` - Matrix type from nalgebra
62/// * `Complex` - Complex number type from num-complex
63/// * `Struct<T>` - Any struct with all fields of type T using #[derive(State)] from the `derive` module
64///
65pub trait State<T: Real>:
66    Clone
67    + Copy
68    + Debug
69    + Add<Output = Self>
70    + Sub<Output = Self>
71    + AddAssign
72    + Mul<T, Output = Self>
73    + Div<T, Output = Self>
74    + Neg<Output = Self>
75{
76    fn len(&self) -> usize;
77
78    fn get(&self, i: usize) -> T;
79
80    fn set(&mut self, i: usize, value: T);
81
82    fn zeros() -> Self;
83}
84
85impl<T: Real> State<T> for T {
86    fn len(&self) -> usize {
87        1
88    }
89
90    fn get(&self, i: usize) -> T {
91        if i == 0 {
92            *self
93        } else {
94            panic!("Index out of bounds")
95        }
96    }
97
98    fn set(&mut self, i: usize, value: T) {
99        if i == 0 {
100            *self = value;
101        } else {
102            panic!("Index out of bounds")
103        }
104    }
105
106    fn zeros() -> Self {
107        T::zero()
108    }
109}
110
111impl<T, const R: usize, const C: usize> State<T> for SMatrix<T, R, C>
112where
113    T: Real,
114{
115    fn len(&self) -> usize {
116        R * C
117    }
118
119    fn get(&self, i: usize) -> T {
120        self[(i / C, i % C)]
121    }
122
123    fn set(&mut self, i: usize, value: T) {
124        self[(i / C, i % C)] = value;
125    }
126
127    fn zeros() -> Self {
128        SMatrix::<T, R, C>::zeros()
129    }
130}
131
132impl<T> State<T> for Complex<T>
133where
134    T: Real,
135{
136    fn len(&self) -> usize {
137        2
138    }
139
140    fn get(&self, i: usize) -> T {
141        if i == 0 {
142            self.re
143        } else if i == 1 {
144            self.im
145        } else {
146            panic!("Index out of bounds")
147        }
148    }
149
150    fn set(&mut self, i: usize, value: T) {
151        if i == 0 {
152            self.re = value;
153        } else if i == 1 {
154            self.im = value;
155        } else {
156            panic!("Index out of bounds")
157        }
158    }
159
160    fn zeros() -> Self {
161        Complex::new(T::zero(), T::zero())
162    }
163}