dess_examples/
three_thermal_mass_sys.rs

1use crate::components::*;
2use crate::imports::*;
3
4/// System of connected components
5#[pyo3_api(
6    #[new]
7    fn __new__(
8        solver_type: String,
9        m1: ThermalMass,
10        m2: ThermalMass,
11        h12: Conductance,
12        m3: ThermalMass,
13        h23: Conductance,
14        t_report: Vec<f64>,
15    ) -> Self { 
16        Self{
17            solver_type: SolverTypes::from_json(&solver_type).unwrap(), 
18            m1, 
19            m2, 
20            h12, 
21            m3, 
22            h23, 
23            t_report,
24            state: Default::default(),
25            history: Default::default(),
26        }
27    }
28
29    #[classmethod]
30    #[allow(clippy::too_many_arguments)]
31    fn new_rk45_cash_karp(
32        _cls: &PyType,
33        sol: AdaptiveSolverConfig,
34        m1: ThermalMass,
35        m2: ThermalMass,
36        h12: Conductance,
37        m3: ThermalMass,
38        h23: Conductance,
39        t_report: Vec<f64>,
40    ) -> Self {
41        Self{
42            solver_type: SolverTypes::RK45CashKarp(Box::new(sol)),
43            m1, 
44            m2, 
45            h12, 
46            m3, 
47            h23, 
48            t_report,
49            state: Default::default(),
50            history: Default::default(),
51        }
52    }
53
54    #[classmethod]
55    #[allow(clippy::too_many_arguments)]
56    fn new_rk23_bogacki_shampine(
57        _cls: &PyType,
58        sol: AdaptiveSolverConfig,
59        m1: ThermalMass,
60        m2: ThermalMass,
61        h12: Conductance,
62        m3: ThermalMass,
63        h23: Conductance,
64        t_report: Vec<f64>,
65    ) -> Self {
66        Self{
67            solver_type: SolverTypes::RK23BogackiShampine(Box::new(sol)),
68            m1, 
69            m2, 
70            h12, 
71            m3, 
72            h23, 
73            t_report,
74            state: Default::default(),
75            history: Default::default(),
76        }
77    }
78
79    #[getter]
80    fn get_solver_conf(&self) -> Option<AdaptiveSolverConfig> {
81        match &self.solver_type {
82            SolverTypes::RK45CashKarp(sc) => Some(*sc.clone()),
83            SolverTypes::RK23BogackiShampine(sc) => Some(*sc.clone()),
84            _ => None,
85        }
86    }
87
88    #[getter]
89    fn get_solver_type(&self) -> String {
90        self.solver_type.to_json()
91    }
92
93    #[pyo3(name = "walk")]
94    fn walk_py(&mut self) {
95        self.walk();
96    }
97)]
98#[solver(
99    /// Updates time derivatives of states.
100    /// This method must be user defined in `solver` macro args.
101    fn update_derivs(&mut self) {
102        self.reset_derivs();
103        connect_states!(self, (m1, m2, h12), (m2, m3, h23));
104        update_derivs!(self, (m1, m2, h12), (m2, m3, h23));
105    }
106)]
107#[derive(HistoryMethods, BareClone)]
108#[common_derives]
109pub struct System3TM {
110    #[skip_get]
111    pub solver_type: SolverTypes,
112    // components
113    // the `use_state` attribute tells the SystemSolver TODO: finish this thought
114    #[use_state]
115    pub m1: ThermalMass,
116    #[use_state]
117    pub m2: ThermalMass,
118    /// h12 connects m1 to m2
119    #[save_state]
120    pub h12: Conductance,
121    #[use_state]
122    pub m3: ThermalMass,
123    #[save_state]
124    pub h23: Conductance,
125    // fields needed by `solver` procedural macro
126    pub t_report: Vec<f64>,
127    pub state: SystemState,
128    pub history: SystemStateHistoryVec,
129}
130
131impl Default for System3TM {
132    fn default() -> Self {
133        Self {
134            solver_type: SolverTypes::EulerFixed { dt: 5e-3 },
135            m1: ThermalMass {
136                c: 1.0,
137                state: ThermalMassState {
138                    temp: 0.0,
139                    dtemp: Default::default(),
140                },
141                history: Default::default(),
142            },
143            m2: ThermalMass {
144                c: 2.0,
145                state: ThermalMassState {
146                    temp: 10.,
147                    dtemp: Default::default(),
148                },
149                history: Default::default(),
150            },
151            h12: Conductance {
152                h: 5.0,
153                state: ConductanceState {
154                    q: Default::default(),
155                },
156                history: ConductanceStateHistoryVec {
157                    q: Default::default(),
158                },
159            },
160            m3: ThermalMass {
161                c: 1.5,
162                state: ThermalMassState {
163                    temp: 12.,
164                    dtemp: Default::default(),
165                },
166                history: Default::default(),
167            },
168            h23: Conductance {
169                h: 5.0,
170                state: ConductanceState {
171                    q: Default::default(),
172                },
173                history: ConductanceStateHistoryVec {
174                    q: Default::default(),
175                },
176            },
177            t_report: Vec::linspace(0.0, 1.0, 201),
178            state: Default::default(),
179            history: Default::default(),
180        }
181    }
182}
183
184pub fn mock_euler_sys() -> System3TM {
185    System3TM::default()
186}
187
188pub fn mock_heuns_sys() -> System3TM {
189    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
190
191    System3TM {
192        solver_type: SolverTypes::HeunsMethod { dt: 5e-3 },
193        t_report,
194        ..mock_euler_sys()
195    }
196}
197
198pub fn mock_midpoint_sys() -> System3TM {
199    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
200
201    System3TM {
202        solver_type: SolverTypes::MidpointMethod { dt: 5e-3 },
203        t_report,
204        ..mock_euler_sys()
205    }
206}
207
208pub fn mock_ralstons_sys() -> System3TM {
209    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
210
211    System3TM {
212        solver_type: SolverTypes::RalstonsMethod { dt: 5e-3 },
213        t_report,
214        ..mock_euler_sys()
215    }
216}
217
218pub fn mock_rk23_sys() -> System3TM {
219    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 11);
220
221    System3TM {
222        solver_type: SolverTypes::RK23BogackiShampine(Box::default()),
223        t_report,
224        ..mock_euler_sys()
225    }
226}
227
228pub fn mock_rk4fixed_sys() -> System3TM {
229    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
230
231    System3TM {
232        solver_type: Default::default(),
233        t_report,
234        ..mock_euler_sys()
235    }
236}
237
238pub fn mock_rk45_sys() -> System3TM {
239    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 11);
240
241    System3TM {
242        solver_type: SolverTypes::RK45CashKarp(Box::default()),
243        t_report,
244        ..mock_euler_sys()
245    }
246}
247
248pub fn run_three_tm_sys(overwrite_benchmarks: bool) {
249    // build and run prescribed-step Euler system
250    let mut sys_euler = mock_euler_sys();
251
252    let t_euler = time_it!(sys_euler.walk());
253
254    let dt = sys_euler.t_report[1] - sys_euler.t_report.first().unwrap();
255
256    println!(
257        "Euler {} s time step elapsed time: {} μs",
258        dt,
259        t_euler.as_micros()
260    );
261
262    let overwrite_euler_benchmark: bool = overwrite_benchmarks;
263    if overwrite_euler_benchmark {
264        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
265            .parent()
266            .unwrap()
267            .to_path_buf()
268            .join("dess-examples/src/tests/fixtures/euler benchmark.yaml");
269
270        sys_euler
271            .to_file(benchmark_file.as_os_str().to_str().unwrap())
272            .unwrap();
273    }
274
275    // build and run prescribed-step Heuns system
276    let mut sys_heuns = mock_heuns_sys();
277
278    let t_heuns = time_it!(sys_heuns.walk());
279
280    let dt = sys_heuns.t_report[1] - sys_heuns.t_report.first().unwrap();
281
282    println!(
283        "Heuns {} s time step elapsed time: {} μs",
284        dt,
285        t_heuns.as_micros()
286    );
287
288    let overwrite_heuns_benchmark: bool = overwrite_benchmarks;
289    if overwrite_heuns_benchmark {
290        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
291            .parent()
292            .unwrap()
293            .to_path_buf()
294            .join("dess-examples/src/tests/fixtures/heuns benchmark.yaml");
295
296        sys_heuns
297            .to_file(benchmark_file.as_os_str().to_str().unwrap())
298            .unwrap();
299    }
300    // build and run prescribed-step midpoint system
301    let mut sys_midpoint = mock_midpoint_sys();
302
303    let t_midpoint = time_it!(sys_midpoint.walk());
304
305    let dt = sys_midpoint.t_report[1] - sys_midpoint.t_report.first().unwrap();
306
307    println!(
308        "Midpoint {} s time step elapsed time: {} μs",
309        dt,
310        t_midpoint.as_micros()
311    );
312
313    let overwrite_midpoint_benchmark: bool = overwrite_benchmarks;
314    if overwrite_midpoint_benchmark {
315        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
316            .parent()
317            .unwrap()
318            .to_path_buf()
319            .join("dess-examples/src/tests/fixtures/midpoint benchmark.yaml");
320
321        sys_midpoint
322            .to_file(benchmark_file.as_os_str().to_str().unwrap())
323            .unwrap();
324    }
325    // build and run prescribed-step Ralston's system
326    let mut sys_ralstons = mock_ralstons_sys();
327
328    let t_ralstons = time_it!(sys_ralstons.walk());
329
330    let dt = sys_ralstons.t_report[1] - sys_ralstons.t_report.first().unwrap();
331
332    println!(
333        "Ralstons {} s time step elapsed time: {} μs",
334        dt,
335        t_ralstons.as_micros()
336    );
337
338    let overwrite_ralstons_benchmark: bool = overwrite_benchmarks;
339    if overwrite_ralstons_benchmark {
340        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
341            .parent()
342            .unwrap()
343            .to_path_buf()
344            .join("dess-examples/src/tests/fixtures/ralstons benchmark.yaml");
345
346        sys_ralstons
347            .to_file(benchmark_file.as_os_str().to_str().unwrap())
348            .unwrap();
349    }
350    // build and run adaptive RK23
351    let mut sys_rk23 = mock_rk23_sys();
352    let t_rk23 = time_it!(sys_rk23.walk());
353
354    let dt = sys_rk23.t_report[1] - sys_rk23.t_report.first().unwrap();
355
356    println!(
357        "RK23 Adaptive {} s init time step elapsed time: {} μs",
358        dt,
359        t_rk23.as_micros()
360    );
361
362    let overwrite_rk23_benchmark: bool = overwrite_benchmarks;
363    if overwrite_rk23_benchmark {
364        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
365            .parent()
366            .unwrap()
367            .to_path_buf()
368            .join("dess-examples/src/tests/fixtures/rk23 benchmark.yaml");
369        sys_rk23
370            .to_file(benchmark_file.as_os_str().to_str().unwrap())
371            .unwrap();
372    }
373    // build and run prescribed-step 4th-order Runge-Kutta system
374    let mut sys_rk4 = mock_rk4fixed_sys();
375
376    let t_rk4 = time_it!(sys_rk4.walk());
377
378    let dt = sys_rk4.t_report[1] - sys_rk4.t_report.first().unwrap();
379
380    println!(
381        "RK4 {} s time step elapsed time: {} μs",
382        dt,
383        t_rk4.as_micros()
384    );
385
386    let overwrite_rk4_benchmark: bool = overwrite_benchmarks;
387    if overwrite_rk4_benchmark {
388        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
389            .parent()
390            .unwrap()
391            .to_path_buf()
392            .join("dess-examples/src/tests/fixtures/rk4 benchmark.yaml");
393
394        sys_rk4
395            .to_file(benchmark_file.as_os_str().to_str().unwrap())
396            .unwrap();
397    }
398
399    // build and run adaptive RK45
400    let mut sys_rk45 = mock_rk45_sys();
401    let t_rk45 = time_it!(sys_rk45.walk());
402
403    let dt = sys_rk45.t_report[1] - sys_rk45.t_report.first().unwrap();
404
405    println!(
406        "RK45 Adaptive {} s init time step elapsed time: {} μs",
407        dt,
408        t_rk45.as_micros()
409    );
410
411    let overwrite_rk45_benchmark: bool = overwrite_benchmarks;
412    if overwrite_rk45_benchmark {
413        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
414            .parent()
415            .unwrap()
416            .to_path_buf()
417            .join("dess-examples/src/tests/fixtures/rk45 benchmark.yaml");
418        sys_rk45
419            .to_file(benchmark_file.as_os_str().to_str().unwrap())
420            .unwrap();
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Path of the stored benchmark fixture `"{name} benchmark.yaml"`.
    fn fixture_path(name: &str) -> PathBuf {
        PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").expect("`CARGO_MANIFEST_DIR` must be set"))
            .parent()
            .expect("`CARGO_MANIFEST_DIR` must have a parent directory")
            .join(format!(
                "dess-examples/src/tests/fixtures/{} benchmark.yaml",
                name
            ))
    }

    /// Walks `sys` and asserts the result equals the stored benchmark `name`.
    fn assert_matches_benchmark(mut sys: System3TM, name: &str) {
        sys.walk();

        let path = fixture_path(name);
        let benchmark_sys = System3TM::from_file(
            path.to_str().expect("benchmark path must be valid UTF-8"),
        )
        .unwrap_or_else(|e| panic!("failed to load benchmark {}: {:?}", path.display(), e));
        assert_eq!(sys, benchmark_sys);
    }

    #[test]
    fn test_bare_clone() {
        let mut sys = mock_euler_sys();
        assert!(sys.history.is_empty());
        assert!(sys.m1.history.is_empty());
        assert!(sys.h12.history.is_empty());
        sys.save_state();
        // verify that at least a couple of the expected changes happened
        assert_eq!(sys.history.len(), 1);
        assert_eq!(sys.m1.history.len(), 1);
        assert_eq!(sys.h12.history.len(), 1);
        let bare_sys = sys.bare_clone();
        assert!(bare_sys.history.is_empty());
        assert!(bare_sys.m1.history.is_empty());
        assert!(bare_sys.h12.history.is_empty());
    }

    #[test]
    fn test_euler_against_benchmark() {
        assert_matches_benchmark(mock_euler_sys(), "euler");
    }

    #[test]
    fn test_heuns_against_benchmark() {
        assert_matches_benchmark(mock_heuns_sys(), "heuns");
    }

    #[test]
    fn test_midpoint_against_benchmark() {
        assert_matches_benchmark(mock_midpoint_sys(), "midpoint");
    }

    #[test]
    fn test_ralstons_against_benchmark() {
        assert_matches_benchmark(mock_ralstons_sys(), "ralstons");
    }

    #[test]
    fn test_rk23_against_benchmark() {
        assert_matches_benchmark(mock_rk23_sys(), "rk23");
    }

    #[test]
    fn test_rk4_against_benchmark() {
        assert_matches_benchmark(mock_rk4fixed_sys(), "rk4");
    }

    #[test]
    fn test_rk4_dt_behavior() {
        let base_sys = mock_rk4fixed_sys();
        let report_dt = base_sys.t_report[1] - base_sys.t_report[0];

        // Walks a copy of `base_sys` using fixed-step RK4 with the given `dt`.
        let walk_with_dt = |dt: f64| {
            let mut sys = System3TM {
                solver_type: SolverTypes::RK4Fixed { dt },
                ..base_sys.clone()
            };
            sys.walk();
            sys
        };

        // small dt should give a (relatively more accurate) different result
        // than a dt slightly below the report spacing
        let sys_dt_smaller_than_t_report = walk_with_dt(1e-3);
        let sys_dt_slightly_less_than_t_report = walk_with_dt(report_dt * 0.9);
        assert_ne!(
            sys_dt_smaller_than_t_report.m1.history,
            sys_dt_slightly_less_than_t_report.m1.history
        );

        // once dt exceeds the report spacing, t_report overrides dt, so slightly
        // larger and much larger dt must give identical results
        let sys_dt_slightly_larger_than_t_report = walk_with_dt(report_dt * 1.1);
        let sys_dt_larger_than_t_report = walk_with_dt(report_dt * 10.0);

        assert_ne!(
            sys_dt_larger_than_t_report.m1.history,
            sys_dt_slightly_less_than_t_report.m1.history
        );
        assert_eq!(
            sys_dt_larger_than_t_report.m1.history,
            sys_dt_slightly_larger_than_t_report.m1.history
        );
    }

    #[test]
    fn test_rk45_against_benchmark() {
        assert_matches_benchmark(mock_rk45_sys(), "rk45");
    }
}