dynamics_spatial/
configuration.rs1use approx::{AbsDiffEq, RelativeEq};
4use nalgebra::DVector;
5
6use rand::{Rng, rngs::ThreadRng};
7use std::{
8 f64::consts::PI,
9 fmt::Display,
10 ops::{Add, Index, IndexMut, Mul, Sub, SubAssign},
11};
12
13#[cfg(feature = "python")]
14use numpy::PyReadonlyArrayDyn;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17
18#[derive(Clone, Debug, PartialEq)]
19pub struct Configuration(pub(crate) DVector<f64>);
21
22impl Configuration {
23 #[must_use]
29 pub fn zeros(size: usize) -> Self {
30 Configuration(DVector::zeros(size))
31 }
32
33 #[must_use]
39 pub fn ones(size: usize) -> Self {
40 Configuration(DVector::from_element(size, 1.0))
41 }
42
43 #[must_use]
50 pub fn from_element(size: usize, value: f64) -> Self {
51 Configuration(DVector::from_element(size, value))
52 }
53
54 #[must_use]
56 pub fn len(&self) -> usize {
57 self.0.len()
58 }
59
60 #[must_use]
62 pub fn is_empty(&self) -> bool {
63 self.0.is_empty()
64 }
65
66 #[must_use]
78 pub fn rows(&self, start: usize, nrows: usize) -> Configuration {
79 Configuration(self.0.rows(start, nrows).into_owned())
80 }
81
82 pub fn update_rows(
94 &mut self,
95 start: usize,
96 values: &Configuration,
97 ) -> Result<(), ConfigurationError> {
98 if self.0.rows(start, values.len()).len() != values.0.len() {
99 Err(ConfigurationError::MismatchedUpdateSize(
100 self.0.rows(start, values.len()).len(),
101 values.0.len(),
102 ))
103 } else {
104 self.0.rows_mut(start, values.len()).copy_from(&values.0);
105 Ok(())
106 }
107 }
108
109 #[must_use]
115 pub fn from_row_slice(data: &[f64]) -> Self {
116 Configuration(DVector::from_row_slice(data))
117 }
118
119 #[cfg(feature = "python")]
125 pub fn from_pyarray(array: &PyReadonlyArrayDyn<f64>) -> Result<Configuration, PyErr> {
126 let array = array.as_array();
127 let flat: Vec<f64> = array.iter().copied().collect();
128 Ok(Configuration::from_row_slice(&flat))
129 }
130
131 #[must_use]
137 pub fn concat(configs: &[Configuration]) -> Configuration {
138 let mut all_values = Vec::new();
139 for config in configs {
140 all_values.extend_from_slice(config.0.as_slice());
141 }
142 Configuration::from_row_slice(&all_values)
143 }
144
145 pub fn random(
155 nq: usize,
156 rng: &mut ThreadRng,
157 min: &Configuration,
158 max: &Configuration,
159 ) -> Self {
160 let mut values = Vec::with_capacity(nq);
161 for i in 0..nq {
162 let min_i = if min[i].is_infinite() && min[i] < 0.0 {
164 -2.0 * PI
165 } else {
166 min[i]
167 };
168 let max_i = if max[i].is_infinite() && max[i] > 0.0 {
170 2.0 * PI
171 } else {
172 max[i]
173 };
174 values.push(rng.random_range(min_i..=max_i));
175 }
176 Configuration::from_row_slice(&values)
177 }
178
179 pub fn check_size(&self, name: &str, expected_size: usize) -> Result<(), ConfigurationError> {
181 if self.len() != expected_size {
182 Err(ConfigurationError::InvalidParameterSize(
183 name.to_string(),
184 expected_size,
185 self.len(),
186 ))
187 } else {
188 Ok(())
189 }
190 }
191}
192
193impl Index<usize> for Configuration {
194 type Output = f64;
195
196 fn index(&self, index: usize) -> &Self::Output {
197 &self.0[index]
198 }
199}
200
201impl IndexMut<usize> for Configuration {
202 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
203 &mut self.0[index]
204 }
205}
206
207impl AbsDiffEq for Configuration {
208 type Epsilon = f64;
209
210 fn default_epsilon() -> Self::Epsilon {
211 f64::default_epsilon()
212 }
213
214 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
215 self.0.abs_diff_eq(&other.0, epsilon)
216 }
217}
218
219impl RelativeEq for Configuration {
220 fn default_max_relative() -> f64 {
221 f64::default_max_relative()
222 }
223
224 fn relative_eq(&self, other: &Self, epsilon: f64, max_relative: f64) -> bool {
225 self.0.relative_eq(&other.0, epsilon, max_relative)
226 }
227}
228
229impl Add for Configuration {
230 type Output = Configuration;
231
232 fn add(self, rhs: Self) -> Self::Output {
233 Configuration(self.0 + rhs.0)
234 }
235}
236
237impl Add for &Configuration {
238 type Output = Configuration;
239
240 fn add(self, rhs: Self) -> Self::Output {
241 Configuration(&self.0 + &rhs.0)
242 }
243}
244
245impl Mul<f64> for &Configuration {
246 type Output = DVector<f64>;
247
248 fn mul(self, rhs: f64) -> Self::Output {
249 &self.0 * rhs
250 }
251}
252
253impl Sub for &Configuration {
254 type Output = Configuration;
255
256 fn sub(self, rhs: Self) -> Self::Output {
257 Configuration(&self.0 - &rhs.0)
258 }
259}
260
261impl SubAssign<&Configuration> for Configuration {
262 fn sub_assign(&mut self, rhs: &Configuration) {
263 self.0 -= &rhs.0;
264 }
265}
266
267impl Display for Configuration {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 write!(f, "Configuration(")?;
270 for i in 0..self.len() {
271 write!(f, "{:.5}", self[i])?;
272 if i < self.len() - 1 {
273 write!(f, ", ")?;
274 }
275 }
276 write!(f, ")")?;
277 Ok(())
278 }
279}
280
/// Errors reported by the size-checked operations on a configuration.
pub enum ConfigurationError {
    /// A named parameter does not have the expected number of elements.
    ///
    /// Fields, in order: parameter name, expected size, actual size.
    InvalidParameterSize(String, usize, usize),
    /// The sizes involved in a row update do not match.
    ///
    /// Fields, in order: expected size, actual size.
    MismatchedUpdateSize(usize, usize),
}

impl std::fmt::Display for ConfigurationError {
    /// Writes the user-facing message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::InvalidParameterSize(name, expected, actual) => f.write_fmt(format_args!(
                "Parameter '{name}' expected configuration size {expected}, but got {actual}"
            )),
            Self::MismatchedUpdateSize(expected, actual) => f.write_fmt(format_args!(
                "Mismatched sizes when updating configuration rows. Expected size {expected}, got {actual}."
            )),
        }
    }
}

impl std::fmt::Debug for ConfigurationError {
    /// Deliberately mirrors the `Display` output, so `{:?}` (and therefore
    /// `unwrap`/`expect` panics) shows the readable message.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}

impl std::error::Error for ConfigurationError {}