differential_equations/
traits.rs

//! Defines the generic traits used throughout the library, including the trait for floating-point numbers.
2
3use nalgebra::{RealField, SMatrix};
4use num_complex::Complex;
5use std::{
6    fmt::Debug,
7    ops::{Add, AddAssign, Div, Mul, Sub},
8};
9
10/// Real Number Trait
11///
12/// This trait specifies the acceptable types for real numbers.
13/// Currently implemented for:
14/// * `f32` - 32-bit floating point
15/// * `f64` - 64-bit floating point
16///
17/// Provides additional functionality required for ODE solvers beyond
18/// what's provided by nalgebra's RealField trait.
19///
20pub trait Real: Copy + RealField {
21    fn infinity() -> Self;
22}
23
24impl Real for f32 {
25    fn infinity() -> Self {
26        f32::INFINITY
27    }
28}
29
30impl Real for f64 {
31    fn infinity() -> Self {
32        f64::INFINITY
33    }
34}
35
36/// State vector trait
37///
38/// Represents the state of the system being solved.
39///
40/// Implements for the following types:
41/// * `f32` - 32-bit floating point
42/// * `f64` - 64-bit floating point
43/// * `SMatrix` - Matrix type from nalgebra
44/// * `Complex` - Complex number type from num-complex
45/// * `Struct<T>` - Any struct with all fields of type T using #[derive(State)] from the `derive` module
46///
47pub trait State<T: Real>:
48    Clone
49    + Copy
50    + Debug
51    + Add<Output = Self>
52    + Sub<Output = Self>
53    + AddAssign
54    + Mul<T, Output = Self>
55    + Div<T, Output = Self>
56{
57    fn len(&self) -> usize;
58
59    fn get(&self, i: usize) -> T;
60
61    fn set(&mut self, i: usize, value: T);
62
63    fn zeros() -> Self;
64}
65
66impl<T: Real> State<T> for T {
67    fn len(&self) -> usize {
68        1
69    }
70
71    fn get(&self, i: usize) -> T {
72        if i == 0 {
73            *self
74        } else {
75            panic!("Index out of bounds")
76        }
77    }
78
79    fn set(&mut self, i: usize, value: T) {
80        if i == 0 {
81            *self = value;
82        } else {
83            panic!("Index out of bounds")
84        }
85    }
86
87    fn zeros() -> Self {
88        T::zero()
89    }
90}
91
92impl<T, const R: usize, const C: usize> State<T> for SMatrix<T, R, C>
93where
94    T: Real,
95{
96    fn len(&self) -> usize {
97        R * C
98    }
99
100    fn get(&self, i: usize) -> T {
101        self[(i / C, i % C)]
102    }
103
104    fn set(&mut self, i: usize, value: T) {
105        self[(i / C, i % C)] = value;
106    }
107
108    fn zeros() -> Self {
109        SMatrix::<T, R, C>::zeros()
110    }
111}
112
113impl<T> State<T> for Complex<T>
114where
115    T: Real,
116{
117    fn len(&self) -> usize {
118        2
119    }
120
121    fn get(&self, i: usize) -> T {
122        if i == 0 {
123            self.re
124        } else if i == 1 {
125            self.im
126        } else {
127            panic!("Index out of bounds")
128        }
129    }
130
131    fn set(&mut self, i: usize, value: T) {
132        if i == 0 {
133            self.re = value;
134        } else if i == 1 {
135            self.im = value;
136        } else {
137            panic!("Index out of bounds")
138        }
139    }
140
141    fn zeros() -> Self {
142        Complex::new(T::zero(), T::zero())
143    }
144}
145
/// Callback data trait.
///
/// Marks the data that can be returned from the functions used to
/// control the solver's execution flow. `Clone` and `Debug` are
/// required for internal use; beyond that, any type satisfying them
/// can serve as callback data, for example a string or a number.
pub trait CallBackData: Clone + Debug {}

// Blanket implementation: every type that is `Clone + Debug` qualifies automatically.
impl<T> CallBackData for T where T: Clone + Debug {}