1use crate::components::*;
2use crate::imports::*;
3
/// System made of three thermal masses (`m1`, `m2`, `m3`) and two
/// conductances (`h12`, `h23`), together with the solver selection, the
/// reporting time grid, and the system-level state and history.
#[pyo3_api(
    // Constructor exposed to Python via `#[new]`.
    // `solver_type` arrives as a string and is parsed with
    // `SolverTypes::from_json`; the `unwrap` panics if parsing fails.
    // `state` and `history` always start from their `Default` values.
    #[new]
    fn __new__(
        solver_type: String,
        m1: ThermalMass,
        m2: ThermalMass,
        h12: Conductance,
        m3: ThermalMass,
        h23: Conductance,
        t_report: Vec<f64>,
    ) -> Self {
        Self{
            solver_type: SolverTypes::from_json(&solver_type).unwrap(),
            m1,
            m2,
            h12,
            m3,
            h23,
            t_report,
            state: Default::default(),
            history: Default::default(),
        }
    }

    // Alternate constructor: wraps `sol` in `SolverTypes::RK45CashKarp`.
    // `state` and `history` start from their `Default` values.
    #[classmethod]
    #[allow(clippy::too_many_arguments)]
    fn new_rk45_cash_karp(
        _cls: &PyType,
        sol: AdaptiveSolverConfig,
        m1: ThermalMass,
        m2: ThermalMass,
        h12: Conductance,
        m3: ThermalMass,
        h23: Conductance,
        t_report: Vec<f64>,
    ) -> Self {
        Self{
            solver_type: SolverTypes::RK45CashKarp(Box::new(sol)),
            m1,
            m2,
            h12,
            m3,
            h23,
            t_report,
            state: Default::default(),
            history: Default::default(),
        }
    }

    // Alternate constructor: wraps `sol` in `SolverTypes::RK23BogackiShampine`.
    // `state` and `history` start from their `Default` values.
    #[classmethod]
    #[allow(clippy::too_many_arguments)]
    fn new_rk23_bogacki_shampine(
        _cls: &PyType,
        sol: AdaptiveSolverConfig,
        m1: ThermalMass,
        m2: ThermalMass,
        h12: Conductance,
        m3: ThermalMass,
        h23: Conductance,
        t_report: Vec<f64>,
    ) -> Self {
        Self{
            solver_type: SolverTypes::RK23BogackiShampine(Box::new(sol)),
            m1,
            m2,
            h12,
            m3,
            h23,
            t_report,
            state: Default::default(),
            history: Default::default(),
        }
    }

    // Returns a clone of the solver configuration when the solver is one of
    // the two config-carrying variants (`RK45CashKarp`, `RK23BogackiShampine`);
    // returns `None` for every other variant.
    #[getter]
    fn get_solver_conf(&self) -> Option<AdaptiveSolverConfig> {
        match &self.solver_type {
            SolverTypes::RK45CashKarp(sc) => Some(*sc.clone()),
            SolverTypes::RK23BogackiShampine(sc) => Some(*sc.clone()),
            _ => None,
        }
    }

    // Returns the solver type serialized with `to_json`.
    #[getter]
    fn get_solver_type(&self) -> String {
        self.solver_type.to_json()
    }

    // Exposed to Python under the name `walk`; forwards to `self.walk()`.
    #[pyo3(name = "walk")]
    fn walk_py(&mut self) {
        self.walk();
    }
)]
#[solver(
    // Resets the derivatives, then runs `connect_states!` and `update_derivs!`
    // over the triples (m1, m2, h12) and (m2, m3, h23).
    // NOTE(review): presumably each triple means "conductance h links the two
    // masses" -- confirm against the macro definitions.
    fn update_derivs(&mut self) {
        self.reset_derivs();
        connect_states!(self, (m1, m2, h12), (m2, m3, h23));
        update_derivs!(self, (m1, m2, h12), (m2, m3, h23));
    }
)]
#[derive(HistoryMethods, BareClone)]
#[common_derives]
pub struct System3TM {
    // Solver selection; tagged `#[skip_get]` because the getters for it are
    // written by hand in the `pyo3_api` block above.
    #[skip_get]
    pub solver_type: SolverTypes,
    // Thermal masses are tagged `#[use_state]` and conductances
    // `#[save_state]`.
    // NOTE(review): the exact meaning of these attributes is defined by the
    // solver/derive macros, which are not visible here -- confirm there.
    #[use_state]
    pub m1: ThermalMass,
    #[use_state]
    pub m2: ThermalMass,
    #[save_state]
    pub h12: Conductance,
    #[use_state]
    pub m3: ThermalMass,
    #[save_state]
    pub h23: Conductance,
    // Time values used for reporting; the mock constructors in this file fill
    // it with `Vec::linspace(0.0, 1.0, n)`.
    pub t_report: Vec<f64>,
    // System-level state and its recorded history.
    pub state: SystemState,
    pub history: SystemStateHistoryVec,
}
130
131impl Default for System3TM {
132 fn default() -> Self {
133 Self {
134 solver_type: SolverTypes::EulerFixed { dt: 5e-3 },
135 m1: ThermalMass {
136 c: 1.0,
137 state: ThermalMassState {
138 temp: 0.0,
139 dtemp: Default::default(),
140 },
141 history: Default::default(),
142 },
143 m2: ThermalMass {
144 c: 2.0,
145 state: ThermalMassState {
146 temp: 10.,
147 dtemp: Default::default(),
148 },
149 history: Default::default(),
150 },
151 h12: Conductance {
152 h: 5.0,
153 state: ConductanceState {
154 q: Default::default(),
155 },
156 history: ConductanceStateHistoryVec {
157 q: Default::default(),
158 },
159 },
160 m3: ThermalMass {
161 c: 1.5,
162 state: ThermalMassState {
163 temp: 12.,
164 dtemp: Default::default(),
165 },
166 history: Default::default(),
167 },
168 h23: Conductance {
169 h: 5.0,
170 state: ConductanceState {
171 q: Default::default(),
172 },
173 history: ConductanceStateHistoryVec {
174 q: Default::default(),
175 },
176 },
177 t_report: Vec::linspace(0.0, 1.0, 201),
178 state: Default::default(),
179 history: Default::default(),
180 }
181 }
182}
183
184pub fn mock_euler_sys() -> System3TM {
185 System3TM::default()
186}
187
188pub fn mock_heuns_sys() -> System3TM {
189 let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
190
191 System3TM {
192 solver_type: SolverTypes::HeunsMethod { dt: 5e-3 },
193 t_report,
194 ..mock_euler_sys()
195 }
196}
197
198pub fn mock_midpoint_sys() -> System3TM {
199 let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
200
201 System3TM {
202 solver_type: SolverTypes::MidpointMethod { dt: 5e-3 },
203 t_report,
204 ..mock_euler_sys()
205 }
206}
207
208pub fn mock_ralstons_sys() -> System3TM {
209 let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
210
211 System3TM {
212 solver_type: SolverTypes::RalstonsMethod { dt: 5e-3 },
213 t_report,
214 ..mock_euler_sys()
215 }
216}
217
218pub fn mock_rk23_sys() -> System3TM {
219 let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 11);
220
221 System3TM {
222 solver_type: SolverTypes::RK23BogackiShampine(Box::default()),
223 t_report,
224 ..mock_euler_sys()
225 }
226}
227
228pub fn mock_rk4fixed_sys() -> System3TM {
229 let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
230
231 System3TM {
232 solver_type: Default::default(),
233 t_report,
234 ..mock_euler_sys()
235 }
236}
237
238pub fn mock_rk45_sys() -> System3TM {
239 let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 11);
240
241 System3TM {
242 solver_type: SolverTypes::RK45CashKarp(Box::default()),
243 t_report,
244 ..mock_euler_sys()
245 }
246}
247
248pub fn run_three_tm_sys(overwrite_benchmarks: bool) {
249 let mut sys_euler = mock_euler_sys();
251
252 let t_euler = time_it!(sys_euler.walk());
253
254 let dt = sys_euler.t_report[1] - sys_euler.t_report.first().unwrap();
255
256 println!(
257 "Euler {} s time step elapsed time: {} μs",
258 dt,
259 t_euler.as_micros()
260 );
261
262 let overwrite_euler_benchmark: bool = overwrite_benchmarks;
263 if overwrite_euler_benchmark {
264 let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
265 .parent()
266 .unwrap()
267 .to_path_buf()
268 .join("dess-examples/src/tests/fixtures/euler benchmark.yaml");
269
270 sys_euler
271 .to_file(benchmark_file.as_os_str().to_str().unwrap())
272 .unwrap();
273 }
274
275 let mut sys_heuns = mock_heuns_sys();
277
278 let t_heuns = time_it!(sys_heuns.walk());
279
280 let dt = sys_heuns.t_report[1] - sys_heuns.t_report.first().unwrap();
281
282 println!(
283 "Heuns {} s time step elapsed time: {} μs",
284 dt,
285 t_heuns.as_micros()
286 );
287
288 let overwrite_heuns_benchmark: bool = overwrite_benchmarks;
289 if overwrite_heuns_benchmark {
290 let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
291 .parent()
292 .unwrap()
293 .to_path_buf()
294 .join("dess-examples/src/tests/fixtures/heuns benchmark.yaml");
295
296 sys_heuns
297 .to_file(benchmark_file.as_os_str().to_str().unwrap())
298 .unwrap();
299 }
300 let mut sys_midpoint = mock_midpoint_sys();
302
303 let t_midpoint = time_it!(sys_midpoint.walk());
304
305 let dt = sys_midpoint.t_report[1] - sys_midpoint.t_report.first().unwrap();
306
307 println!(
308 "Midpoint {} s time step elapsed time: {} μs",
309 dt,
310 t_midpoint.as_micros()
311 );
312
313 let overwrite_midpoint_benchmark: bool = overwrite_benchmarks;
314 if overwrite_midpoint_benchmark {
315 let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
316 .parent()
317 .unwrap()
318 .to_path_buf()
319 .join("dess-examples/src/tests/fixtures/midpoint benchmark.yaml");
320
321 sys_midpoint
322 .to_file(benchmark_file.as_os_str().to_str().unwrap())
323 .unwrap();
324 }
325 let mut sys_ralstons = mock_ralstons_sys();
327
328 let t_ralstons = time_it!(sys_ralstons.walk());
329
330 let dt = sys_ralstons.t_report[1] - sys_ralstons.t_report.first().unwrap();
331
332 println!(
333 "Ralstons {} s time step elapsed time: {} μs",
334 dt,
335 t_ralstons.as_micros()
336 );
337
338 let overwrite_ralstons_benchmark: bool = overwrite_benchmarks;
339 if overwrite_ralstons_benchmark {
340 let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
341 .parent()
342 .unwrap()
343 .to_path_buf()
344 .join("dess-examples/src/tests/fixtures/ralstons benchmark.yaml");
345
346 sys_ralstons
347 .to_file(benchmark_file.as_os_str().to_str().unwrap())
348 .unwrap();
349 }
350 let mut sys_rk23 = mock_rk23_sys();
352 let t_rk23 = time_it!(sys_rk23.walk());
353
354 let dt = sys_rk23.t_report[1] - sys_rk23.t_report.first().unwrap();
355
356 println!(
357 "RK23 Adaptive {} s init time step elapsed time: {} μs",
358 dt,
359 t_rk23.as_micros()
360 );
361
362 let overwrite_rk23_benchmark: bool = overwrite_benchmarks;
363 if overwrite_rk23_benchmark {
364 let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
365 .parent()
366 .unwrap()
367 .to_path_buf()
368 .join("dess-examples/src/tests/fixtures/rk23 benchmark.yaml");
369 sys_rk23
370 .to_file(benchmark_file.as_os_str().to_str().unwrap())
371 .unwrap();
372 }
373 let mut sys_rk4 = mock_rk4fixed_sys();
375
376 let t_rk4 = time_it!(sys_rk4.walk());
377
378 let dt = sys_rk4.t_report[1] - sys_rk4.t_report.first().unwrap();
379
380 println!(
381 "RK4 {} s time step elapsed time: {} μs",
382 dt,
383 t_rk4.as_micros()
384 );
385
386 let overwrite_rk4_benchmark: bool = overwrite_benchmarks;
387 if overwrite_rk4_benchmark {
388 let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
389 .parent()
390 .unwrap()
391 .to_path_buf()
392 .join("dess-examples/src/tests/fixtures/rk4 benchmark.yaml");
393
394 sys_rk4
395 .to_file(benchmark_file.as_os_str().to_str().unwrap())
396 .unwrap();
397 }
398
399 let mut sys_rk45 = mock_rk45_sys();
401 let t_rk45 = time_it!(sys_rk45.walk());
402
403 let dt = sys_rk45.t_report[1] - sys_rk45.t_report.first().unwrap();
404
405 println!(
406 "RK45 Adaptive {} s init time step elapsed time: {} μs",
407 dt,
408 t_rk45.as_micros()
409 );
410
411 let overwrite_rk45_benchmark: bool = overwrite_benchmarks;
412 if overwrite_rk45_benchmark {
413 let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
414 .parent()
415 .unwrap()
416 .to_path_buf()
417 .join("dess-examples/src/tests/fixtures/rk45 benchmark.yaml");
418 sys_rk45
419 .to_file(benchmark_file.as_os_str().to_str().unwrap())
420 .unwrap();
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the stored benchmark system from
    /// `<parent of CARGO_MANIFEST_DIR>/dess-examples/src/tests/fixtures/<file_name>`.
    ///
    /// Panics if the environment variable is missing, the path is not valid
    /// UTF-8, or the file cannot be read into a `System3TM`.
    fn load_benchmark(file_name: &str) -> System3TM {
        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
            .parent()
            .unwrap()
            .to_path_buf()
            .join(format!("dess-examples/src/tests/fixtures/{}", file_name));

        System3TM::from_file(benchmark_file.as_os_str().to_str().unwrap()).unwrap()
    }

    /// Walks `sys` and asserts that the result equals the stored benchmark.
    ///
    /// `#[track_caller]` makes a failed assertion report the calling test's
    /// location instead of this helper's.
    #[track_caller]
    fn assert_matches_benchmark(mut sys: System3TM, file_name: &str) {
        sys.walk();
        assert_eq!(sys, load_benchmark(file_name));
    }

    /// `bare_clone` must yield a copy whose system and component histories
    /// are empty even after the source has saved a state.
    #[test]
    fn test_bare_clone() {
        let mut sys = mock_euler_sys();
        assert!(sys.history.is_empty());
        assert!(sys.m1.history.is_empty());
        assert!(sys.h12.history.is_empty());
        sys.save_state();
        // Saving once adds one entry to the system and component histories.
        assert_eq!(sys.history.len(), 1);
        assert_eq!(sys.m1.history.len(), 1);
        assert_eq!(sys.h12.history.len(), 1);
        let bare_sys = sys.bare_clone();
        assert!(bare_sys.history.is_empty());
        assert!(bare_sys.m1.history.is_empty());
        assert!(bare_sys.h12.history.is_empty());
    }

    #[test]
    fn test_euler_against_benchmark() {
        assert_matches_benchmark(mock_euler_sys(), "euler benchmark.yaml");
    }

    #[test]
    fn test_heuns_against_benchmark() {
        assert_matches_benchmark(mock_heuns_sys(), "heuns benchmark.yaml");
    }

    #[test]
    fn test_midpoint_against_benchmark() {
        assert_matches_benchmark(mock_midpoint_sys(), "midpoint benchmark.yaml");
    }

    #[test]
    fn test_ralstons_against_benchmark() {
        assert_matches_benchmark(mock_ralstons_sys(), "ralstons benchmark.yaml");
    }

    #[test]
    fn test_rk23_against_benchmark() {
        assert_matches_benchmark(mock_rk23_sys(), "rk23 benchmark.yaml");
    }

    #[test]
    fn test_rk4_against_benchmark() {
        assert_matches_benchmark(mock_rk4fixed_sys(), "rk4 benchmark.yaml");
    }

    /// Compares RK4 runs whose `dt` is below and above the report interval.
    ///
    /// Asserted: the `m1` history differs between two `dt` values below the
    /// report interval, and is identical for two `dt` values above it.
    /// NOTE(review): presumably the solver limits its step to the report
    /// interval, which is why the two larger `dt` runs agree -- confirm.
    #[test]
    fn test_rk4_dt_behavior() {
        let base_sys = mock_rk4fixed_sys();
        // Spacing between the first two report times.
        let dt_report = base_sys.t_report[1] - base_sys.t_report[0];

        let mut sys_dt_smaller_than_t_report = System3TM {
            solver_type: SolverTypes::RK4Fixed { dt: 1e-3 },
            ..base_sys.clone()
        };
        sys_dt_smaller_than_t_report.walk();

        let mut sys_dt_slightly_less_than_t_report = System3TM {
            solver_type: SolverTypes::RK4Fixed {
                dt: dt_report * 0.9,
            },
            ..base_sys.clone()
        };
        sys_dt_slightly_less_than_t_report.walk();

        assert!(
            sys_dt_smaller_than_t_report.m1.history
                != sys_dt_slightly_less_than_t_report.m1.history
        );

        let mut sys_dt_slightly_larger_than_t_report = System3TM {
            solver_type: SolverTypes::RK4Fixed {
                dt: dt_report * 1.1,
            },
            ..base_sys.clone()
        };
        sys_dt_slightly_larger_than_t_report.walk();

        let mut sys_dt_larger_than_t_report = System3TM {
            solver_type: SolverTypes::RK4Fixed {
                dt: dt_report * 10.0,
            },
            ..base_sys.clone()
        };
        sys_dt_larger_than_t_report.walk();

        assert!(
            sys_dt_larger_than_t_report.m1.history != sys_dt_slightly_less_than_t_report.m1.history
        );
        assert!(
            sys_dt_larger_than_t_report.m1.history
                == sys_dt_slightly_larger_than_t_report.m1.history
        );
    }

    #[test]
    fn test_rk45_against_benchmark() {
        assert_matches_benchmark(mock_rk45_sys(), "rk45 benchmark.yaml");
    }
}