differential_equations/
traits.rs

//! Generic traits used throughout the library: the real-number scalar type (`Real`), the solver state type (`State`), and callback data (`CallBackData`).
2
3use nalgebra::{RealField, SMatrix};
4use num_complex::Complex;
5use std::{fmt::Debug, ops::{Add, AddAssign, Div, Mul, Sub}};
6
7/// Real Number Trait
8///
9/// This trait specifies the acceptable types for real numbers.
10/// Currently implemented for:
11/// * `f32` - 32-bit floating point
12/// * `f64` - 64-bit floating point
13///
14/// Provides additional functionality required for ODE solvers beyond
15/// what's provided by nalgebra's RealField trait.
16///
17pub trait Real: Copy + RealField {
18    fn infinity() -> Self;
19    fn to_f64(self) -> f64;
20}
21
22impl Real for f32 {
23    fn infinity() -> Self {
24        std::f32::INFINITY
25    }
26
27    fn to_f64(self) -> f64 {
28        f64::from(self)
29    }
30}
31
32impl Real for f64 {
33    fn infinity() -> Self {
34        std::f64::INFINITY
35    }
36
37    fn to_f64(self) -> f64 {
38        self
39    }
40}
41
/// The state vector a solver integrates.
///
/// A state is a small, copyable value that supports the arithmetic the
/// solvers rely on: `+`, `-`, `+=`, and scaling by a scalar `T` with `*`
/// and `/`. Its individual scalar components are reachable by flat index.
///
/// Implemented for:
/// * `f32` and `f64` (every [`Real`] scalar is a one-component state)
/// * `SMatrix` - statically sized matrix from nalgebra
/// * `Complex` - complex number from num-complex
/// * `Struct<T>` - any struct whose fields are all of type `T`, via
///   `#[derive(State)]` from the `derive` module
pub trait State<T>:
    Clone
    + Copy
    + Debug
    + Add<Output = Self>
    + Sub<Output = Self>
    + AddAssign
    + Mul<T, Output = Self>
    + Div<T, Output = Self>
{
    /// Number of scalar components in the state.
    fn len(&self) -> usize;

    /// Returns component `i`.
    ///
    /// # Panics
    /// The implementations in this module panic when `i >= self.len()`.
    fn get(&self, i: usize) -> T;

    /// Overwrites component `i` with `value`.
    ///
    /// # Panics
    /// The implementations in this module panic when `i >= self.len()`.
    fn set(&mut self, i: usize, value: T);

    /// The state whose every component is zero.
    fn zeros() -> Self;
}
71
72impl<T: Real> State<T> for T {
73    fn len(&self) -> usize {
74        1
75    }
76
77    fn get(&self, i: usize) -> T {
78        if i == 0 {
79            *self
80        } else {
81            panic!("Index out of bounds")
82        }
83    }
84
85    fn set(&mut self, i: usize, value: T) {
86        if i == 0 {
87            *self = value;
88        } else {
89            panic!("Index out of bounds")
90        }
91    }
92
93    fn zeros() -> Self {
94        T::zero()
95    }
96}
97
98impl<T, const R: usize, const C: usize> State<T> for SMatrix<T, R, C>
99where
100    T: Real,
101{
102    fn len(&self) -> usize {
103        R * C
104    }
105
106    fn get(&self, i: usize) -> T {
107        if i < self.len() {
108            self[(i / C, i % C)]
109        } else {
110            panic!("Index out of bounds")
111        }
112    }
113
114    fn set(&mut self, i: usize, value: T) {
115        if i < self.len() {
116            self[(i / C, i % C)] = value;
117        } else {
118            panic!("Index out of bounds")
119        }
120    }
121
122    fn zeros() -> Self {
123        SMatrix::<T, R, C>::zeros()
124    }
125}
126
127impl<T> State<T> for Complex<T>
128where
129    T: Real,
130{
131    fn len(&self) -> usize {
132        2
133    }
134
135    fn get(&self, i: usize) -> T {
136        if i == 0 {
137            self.re
138        } else if i == 1 {
139            self.im
140        } else {
141            panic!("Index out of bounds")
142        }
143    }
144
145    fn set(&mut self, i: usize, value: T) {
146        if i == 0 {
147            self.re = value;
148        } else if i == 1 {
149            self.im = value;
150        } else {
151            panic!("Index out of bounds")
152        }
153    }
154
155    fn zeros() -> Self {
156        Complex::new(T::zero(), T::zero())
157    }
158}
159
/// Data that user callbacks return to steer the solver's execution flow.
///
/// Internally the solver only needs to clone this data and print it for
/// debugging, so any `Clone + Debug` type qualifies: a string, a number,
/// a custom enum, and so on.
pub trait CallBackData: Clone + Debug {}

// Blanket impl: every `Clone + Debug` type is callback data automatically,
// so users never implement this trait by hand.
impl<T> CallBackData for T where T: Clone + Debug {}