dess_examples/
three_thrml_mass_w_bc_sys.rs

1use crate::components::*;
2use crate::imports::*;
3
4/// System of connected components
5#[derive(HistoryMethods, BareClone)]
6#[pyo3_api(
7    #[new]
8    fn __new__(
9        solver_type: String,
10        m1: ThermalReservoir,
11        m2: ThermalMass,
12        h12: Conductance,
13        m3: ThermalMass,
14        h23: Conductance,
15        t_report: Vec<f64>,
16    ) -> Self {
17        Self{
18            solver_type: SolverTypes::from_json(&solver_type).unwrap(),
19            m1,
20            m2,
21            h12,
22            m3,
23            h23,
24            t_report,
25            state: Default::default(),
26            history: Default::default(),
27        }
28    }
29
30    #[classmethod]
31    #[allow(clippy::too_many_arguments)]
32    fn new_rk45_cash_karp(
33        _cls: &PyType,
34        sol: AdaptiveSolverConfig,
35        m1: ThermalReservoir,
36        m2: ThermalMass,
37        h12: Conductance,
38        m3: ThermalMass,
39        h23: Conductance,
40        t_report: Vec<f64>,
41    ) -> Self {
42        Self{
43            solver_type: SolverTypes::RK45CashKarp(Box::new(sol)),
44            m1,
45            m2,
46            h12,
47            m3,
48            h23,
49            t_report,
50            state: Default::default(),
51            history: Default::default(),
52        }
53    }
54
55    #[classmethod]
56    #[allow(clippy::too_many_arguments)]
57    fn new_rk23_bogacki_shampine(
58        _cls: &PyType,
59        sol: AdaptiveSolverConfig,
60        m1: ThermalReservoir,
61        m2: ThermalMass,
62        h12: Conductance,
63        m3: ThermalMass,
64        h23: Conductance,
65        t_report: Vec<f64>,
66    ) -> Self {
67        Self{
68            solver_type: SolverTypes::RK23BogackiShampine(Box::new(sol)),
69            m1,
70            m2,
71            h12,
72            m3,
73            h23,
74            t_report,
75            state: Default::default(),
76            history: Default::default(),
77        }
78    }
79
80    #[getter]
81    fn get_solver_conf(&self) -> Option<AdaptiveSolverConfig> {
82        match &self.solver_type {
83            SolverTypes::RK45CashKarp(sc) => Some(*sc.clone()),
84            SolverTypes::RK23BogackiShampine(sc) => Some(*sc.clone()),
85            _ => None,
86        }
87    }
88
89    #[getter]
90    fn get_solver_type(&self) -> String {
91        self.solver_type.to_json()
92    }
93
94    #[pyo3(name = "walk")]
95    fn walk_py(&mut self) {
96        self.walk();
97    }
98)]
99#[solver(
100    /// Updates time derivatives of states.
101    /// This method must be user defined in `solver` macro args.
102    fn update_derivs(&mut self) {
103        self.reset_derivs();
104        // change in temperature to test what the solver does
105        let offset = 10.0;
106        let freq = 100. * self.state.time;
107        let lag = 0.25;
108        self.m1.state.temp = offset + 3. * f64::sin(freq * self.state.time) * f64::exp(-self.state.time / lag);
109        //create system of equations/relationships
110        connect_states!(self, (m1, m2, h12), (m2, m3, h23));
111        //update derivs based on system of equations created
112        update_derivs!(self, (m1, m2, h12), (m2, m3, h23));
113    }
114)]
115#[common_derives]
116pub struct System3TMWithBC {
117    #[skip_get]
118    pub solver_type: SolverTypes,
119    // components
120    #[use_state]
121    pub m1: ThermalReservoir,
122    #[use_state]
123    pub m2: ThermalMass,
124    /// h12 connects m1 to m2
125    #[save_state]
126    pub h12: Conductance,
127    #[use_state]
128    pub m3: ThermalMass,
129    #[save_state]
130    pub h23: Conductance,
131    // fields needed by `solver` procedural macro
132    pub t_report: Vec<f64>,
133    pub state: SystemState,
134    pub history: SystemStateHistoryVec,
135}
136
137impl Default for System3TMWithBC {
138    fn default() -> Self {
139        Self {
140            solver_type: SolverTypes::EulerFixed { dt: 5e-3 },
141            m1: ThermalReservoir {
142                state: ThermalMassState {
143                    temp: -1.0,
144                    dtemp: Default::default(),
145                },
146                history: Default::default(),
147            },
148            m2: ThermalMass {
149                c: 2.0,
150                state: ThermalMassState {
151                    temp: 10.,
152                    dtemp: Default::default(),
153                },
154                history: Default::default(),
155            },
156            h12: Conductance {
157                h: 5.0,
158                state: ConductanceState {
159                    q: Default::default(),
160                },
161                history: ConductanceStateHistoryVec {
162                    q: Default::default(),
163                },
164            },
165            m3: ThermalMass {
166                c: 1.5,
167                state: ThermalMassState {
168                    temp: 12.,
169                    dtemp: Default::default(),
170                },
171                history: Default::default(),
172            },
173            h23: Conductance {
174                h: 5.0,
175                state: ConductanceState {
176                    q: Default::default(),
177                },
178                history: ConductanceStateHistoryVec {
179                    q: Default::default(),
180                },
181            },
182            t_report: Vec::linspace(0.0, 1.0, 201),
183            state: Default::default(),
184            history: Default::default(),
185        }
186    }
187}
188
/// Solver bookkeeping state holding a report index and the current time.
///
/// NOTE(review): nothing in this file references this type; `System3TMWithBC`
/// declares its `state` field as `SystemState` instead. Confirm whether this
/// struct is still needed.
189#[derive(Copy, HistoryVec, Default)]
190#[common_derives]
191#[pyo3_api]
192pub struct SystemState3TM {
193    // Current index into `t_report`.
194    i: usize,
195    // Current simulation time.
196    time: f64,
197}
198pub fn mock_euler_sys() -> System3TMWithBC {
199    System3TMWithBC::default()
200}
201
202pub fn mock_heuns_sys() -> System3TMWithBC {
203    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
204
205    System3TMWithBC {
206        solver_type: SolverTypes::HeunsMethod { dt: 5e-3 },
207        t_report,
208        ..mock_euler_sys()
209    }
210}
211
212pub fn mock_midpoint_sys() -> System3TMWithBC {
213    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
214
215    System3TMWithBC {
216        solver_type: SolverTypes::MidpointMethod { dt: 5e-3 },
217        t_report,
218        ..mock_euler_sys()
219    }
220}
221
222pub fn mock_ralstons_sys() -> System3TMWithBC {
223    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
224
225    System3TMWithBC {
226        solver_type: SolverTypes::RalstonsMethod { dt: 5e-3 },
227        t_report,
228        ..mock_euler_sys()
229    }
230}
231
232pub fn mock_rk23_sys() -> System3TMWithBC {
233    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 11);
234
235    System3TMWithBC {
236        solver_type: SolverTypes::RK23BogackiShampine(Box::default()),
237        t_report,
238        ..mock_euler_sys()
239    }
240}
241
242pub fn mock_rk4fixed_sys() -> System3TMWithBC {
243    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 51);
244
245    System3TMWithBC {
246        solver_type: Default::default(),
247        t_report,
248        ..mock_euler_sys()
249    }
250}
251
252pub fn mock_rk45_sys() -> System3TMWithBC {
253    let t_report: Vec<f64> = Vec::linspace(0.0, 1.0, 11);
254
255    System3TMWithBC {
256        solver_type: SolverTypes::RK45CashKarp(Box::default()),
257        t_report,
258        ..mock_euler_sys()
259    }
260}
261
262pub fn run_three_tm_w_bc_sys() {
263    //let overwrite_benchmarks = false;
264    // build and run prescribed-step Euler system
265    let mut sys_euler = mock_euler_sys();
266
267    let t_euler = time_it!(sys_euler.walk());
268
269    let dt = sys_euler.t_report[1] - sys_euler.t_report.first().unwrap();
270
271    println!(
272        "Euler {} s time step elapsed time: {} μs",
273        dt,
274        t_euler.as_micros()
275    );
276
277    /* let overwrite_euler_benchmark: bool = overwrite_benchmarks;
278    if overwrite_euler_benchmark {
279        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
280            .parent()
281            .unwrap()
282            .to_path_buf()
283            .join("dess-examples/src/tests/fixtures/euler_benchmark_w_bc.yaml");
284
285        sys_euler
286            .to_file(benchmark_file.as_os_str().to_str().unwrap())
287            .unwrap();
288    } */
289
290    // build and run prescribed-step Heuns system
291    let mut sys_heuns = mock_heuns_sys();
292
293    let t_heuns = time_it!(sys_heuns.walk());
294
295    let dt = sys_heuns.t_report[1] - sys_heuns.t_report.first().unwrap();
296
297    println!(
298        "Heuns {} s time step elapsed time: {} μs",
299        dt,
300        t_heuns.as_micros()
301    );
302
303    /* let overwrite_heuns_benchmark: bool = overwrite_benchmarks;
304    if overwrite_heuns_benchmark {
305        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
306            .parent()
307            .unwrap()
308            .to_path_buf()
309            .join("dess-examples/src/tests/fixtures/heuns_benchmark_w_bc.yaml");
310
311        sys_heuns
312            .to_file(benchmark_file.as_os_str().to_str().unwrap())
313            .unwrap();
314    } */
315
316    // build and run prescribed-step Midpoint system
317    let mut sys_midpoint = mock_midpoint_sys();
318
319    let t_midpoint = time_it!(sys_midpoint.walk());
320
321    let dt = sys_midpoint.t_report[1] - sys_midpoint.t_report.first().unwrap();
322
323    println!(
324        "Midpoint {} s time step elapsed time: {} μs",
325        dt,
326        t_midpoint.as_micros()
327    );
328
329    /* let overwrite_midpoint_benchmark: bool = overwrite_benchmarks;
330    if overwrite_midpoint_benchmark {
331        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
332            .parent()
333            .unwrap()
334            .to_path_buf()
335            .join("dess-examples/src/tests/fixtures/midpoint_benchmark_w_bc.yaml");
336
337        sys_midpoint
338            .to_file(benchmark_file.as_os_str().to_str().unwrap())
339            .unwrap();
340    } */
341
342    // build and run prescribed-step Ralston's system
343    let mut sys_ralstons = mock_ralstons_sys();
344
345    let t_ralstons = time_it!(sys_ralstons.walk());
346
347    let dt = sys_ralstons.t_report[1] - sys_ralstons.t_report.first().unwrap();
348
349    println!(
350        "Ralstons {} s time step elapsed time: {} μs",
351        dt,
352        t_ralstons.as_micros()
353    );
354
355    /* let overwrite_ralstons_benchmark: bool = overwrite_benchmarks;
356    if overwrite_ralstons_benchmark {
357        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
358            .parent()
359            .unwrap()
360            .to_path_buf()
361            .join("dess-examples/src/tests/fixtures/ralstons_benchmark_w_bc.yaml");
362
363        sys_ralstons
364            .to_file(benchmark_file.as_os_str().to_str().unwrap())
365            .unwrap();
366    } */
367
368    // build and run adaptive RK23
369    let mut sys_rk23 = mock_rk23_sys();
370
371    let t_rk23 = time_it!(sys_rk23.walk());
372
373    let dt = sys_rk23.t_report[1] - sys_rk23.t_report.first().unwrap();
374
375    println!(
376        "RK23 Adaptive {} s init time step elapsed time: {} μs",
377        dt,
378        t_rk23.as_micros()
379    );
380
381    // let overwrite_rk23_benchmark: bool = overwrite_benchmarks;
382    // if overwrite_rk23_benchmark {
383    //     let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
384    //         .parent()
385    //         .unwrap()
386    //         .to_path_buf()
387    //         .join("dess-examples/tests/fixtures/rk23 benchmark.yaml");
388
389    //     sys_rk23
390    //         .to_file(benchmark_file.as_os_str().to_str().unwrap())
391    //         .unwrap();
392    // }
393
394    // build and run prescribed-step 4th-order Runge-Kutta system
395    let mut sys_rk4 = mock_rk4fixed_sys();
396
397    let t_rk4 = time_it!(sys_rk4.walk());
398
399    let dt = sys_rk4.t_report[1] - sys_rk4.t_report.first().unwrap();
400
401    println!(
402        "RK4 {} s time step elapsed time: {} μs",
403        dt,
404        t_rk4.as_micros()
405    );
406
407    /* let overwrite_rk4_benchmark: bool = overwrite_benchmarks;
408    if overwrite_rk4_benchmark {
409        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
410            .parent()
411            .unwrap()
412            .to_path_buf()
413            .join("dess-examples/src/tests/fixtures/rk4_benchmark_w_bc.yaml");
414
415        sys_rk4
416            .to_file(benchmark_file.as_os_str().to_str().unwrap())
417            .unwrap();
418    } */
419
420    // build and run adaptive RK45
421    let mut sys_rk45 = mock_rk45_sys();
422
423    let t_rk45 = time_it!(sys_rk45.walk());
424
425    let dt = sys_rk45.t_report[1] - sys_rk45.t_report.first().unwrap();
426
427    println!(
428        "RK45 Adaptive {} s init time step elapsed time: {} μs",
429        dt,
430        t_rk45.as_micros()
431    );
432
433    /* let overwrite_rk45_benchmark: bool = overwrite_benchmarks;
434    if overwrite_rk45_benchmark {
435        let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
436            .parent()
437            .unwrap()
438            .to_path_buf()
439            .join("dess-examples/src/tests/fixtures/rk45_benchmark_w_bc.yaml");
440
441        sys_rk45
442            .to_file(benchmark_file.as_os_str().to_str().unwrap())
443            .unwrap();
444    } */
445}
446
447// #[cfg(test)]
448// mod tests {
449//     use super::*;
450
451//     #[test]
452//     fn test_bare_clone() {
453//         let mut sys = mock_euler_sys();
454//         assert!(sys.history.is_empty());
455//         assert!(sys.m1.history.is_empty());
456//         assert!(sys.h12.history.is_empty());
457//         sys.save_state();
458//         // verify that at least a couple of the expected changes happened
459//         assert!(sys.history.len() == 1);
460//         assert!(sys.m1.history.len() == 1);
461//         assert!(sys.h12.history.len() == 1);
462//         let bare_sys = sys.bare_clone();
463//         assert!(bare_sys.history.is_empty());
464//         assert!(bare_sys.m1.history.is_empty());
465//         assert!(bare_sys.h12.history.is_empty());
466//     }
467
468//     #[test]
469//     fn test_euler_against_benchmark() {
470//         let mut sys = mock_euler_sys();
471//         sys.walk();
472
473//         let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
474//             .parent()
475//             .unwrap()
476//             .to_path_buf()
477//             .join("dess-examples/tests/fixtures/euler benchmark.yaml");
478
479//         let benchmark_sys =
480//             System3TM::from_file(benchmark_file.as_os_str().to_str().unwrap()).unwrap();
481//         assert_eq!(sys, benchmark_sys);
482//     }
483
484//     #[test]
485//     fn test_rk4_against_benchmark() {
486//         let mut sys = mock_rk4fixed_sys();
487//         sys.walk();
488//         let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
489//             .parent()
490//             .unwrap()
491//             .to_path_buf()
492//             .join("dess-examples/tests/fixtures/rk4 benchmark.yaml");
493
494//         let benchmark_sys =
495//             System3TM::from_file(benchmark_file.as_os_str().to_str().unwrap()).unwrap();
496//         assert_eq!(sys, benchmark_sys);
497//     }
498
499//     #[test]
500//     fn test_rk4_dt_behavior() {
501//         let base_sys = mock_rk4fixed_sys();
502
503//         // system for checking if small dt results in relatively higher accuracy
504//         let mut sys_dt_smaller_than_t_report = System3TM {
505//             solver_type: SolverTypes::RK4Fixed { dt: 1e-3 },
506//             ..base_sys.clone()
507//         };
508//         sys_dt_smaller_than_t_report.walk();
509
510//         // system for checking if dt slightly less than t_report works ok
511//         let mut sys_dt_slightly_less_than_t_report = System3TM {
512//             solver_type: SolverTypes::RK4Fixed {
513//                 dt: (base_sys.t_report[1].clone() - base_sys.t_report[0].clone()) * 0.9,
514//             },
515//             ..base_sys.clone()
516//         };
517//         sys_dt_slightly_less_than_t_report.walk();
518
519//         assert!(
520//             sys_dt_smaller_than_t_report.m1.history
521//                 != sys_dt_slightly_less_than_t_report.m1.history
522//         );
523
524//         // system for checking that t_report overrides dt when dt is slightly larger than t_report
525//         let mut sys_dt_slightly_larger_than_t_report = System3TM {
526//             solver_type: SolverTypes::RK4Fixed {
527//                 dt: (base_sys.t_report[1].clone() - base_sys.t_report[0].clone()) * 1.1,
528//             },
529//             ..base_sys.clone()
530//         };
531
532//         sys_dt_slightly_larger_than_t_report.walk();
533
534//         // system for checking that t_report overrides dt when dt is large
535//         let mut sys_dt_larger_than_t_report = System3TM {
536//             solver_type: SolverTypes::RK4Fixed {
537//                 dt: (base_sys.t_report[1].clone() - base_sys.t_report[0].clone()) * 10.0,
538//             },
539//             ..base_sys.clone()
540//         };
541//         sys_dt_larger_than_t_report.walk();
542
543//         assert!(
544//             sys_dt_larger_than_t_report.m1.history != sys_dt_slightly_less_than_t_report.m1.history
545//         );
546//         assert!(
547//             sys_dt_larger_than_t_report.m1.history
548//                 == sys_dt_slightly_larger_than_t_report.m1.history
549//         );
550//     }
551
552//     #[test]
553//     fn test_rk45_against_benchmark() {
554//         let mut sys = mock_rk45_sys();
555//         sys.walk();
556//         let benchmark_file = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
557//             .parent()
558//             .unwrap()
559//             .to_path_buf()
560//             .join("dess-examples/tests/fixtures/rk45 benchmark.yaml");
561
562//         let benchmark_sys =
563//             System3TM::from_file(benchmark_file.as_os_str().to_str().unwrap()).unwrap();
564//         assert_eq!(sys, benchmark_sys);
565//     }
566// }