altrios_core/train/
set_speed_train_sim.rs

1use super::environment::TemperatureTrace;
2use super::train_imports::*;
3
#[serde_api]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "pyo3", pyclass(module = "altrios", subclass, eq))]
/// Prescribed speed as a function of time, stored as parallel vectors.
///
/// Element `i` of each vector describes the same trace point.  Nothing in this
/// struct enforces that the vectors have equal length; [`SpeedTrace::len`]
/// reports the length of `time` only.
pub struct SpeedTrace {
    /// simulation time at each trace point
    pub time: Vec<si::Time>,
    /// prescribed speed at each trace point
    pub speed: Vec<si::Velocity>,
    /// Whether engine is on at each trace point; `None` when the trace carries
    /// no engine on/off information
    pub engine_on: Option<Vec<bool>>,
}
15
// Python-facing methods for `SpeedTrace`; each one forwards to the inherent
// Rust method of the corresponding name.
#[pyo3_api]
impl SpeedTrace {
    #[new]
    #[pyo3(signature = (
        time_seconds,
        speed_meters_per_second,
        engine_on=None
    ))]
    /// Python constructor; forwards to [`SpeedTrace::new`].
    ///
    /// `engine_on` is optional on the Python side and defaults to `None`.
    fn __new__(
        time_seconds: Vec<f64>,
        speed_meters_per_second: Vec<f64>,
        engine_on: Option<Vec<bool>>,
    ) -> anyhow::Result<Self> {
        Ok(Self::new(time_seconds, speed_meters_per_second, engine_on))
    }

    #[staticmethod]
    #[pyo3(name = "from_csv_file")]
    /// Python `from_csv_file`: extracts a `PathBuf` from `filepath` and forwards
    /// to [`SpeedTrace::from_csv_file`].
    fn from_csv_file_py(filepath: &Bound<PyAny>) -> anyhow::Result<Self> {
        Self::from_csv_file(PathBuf::extract_bound(filepath)?)
    }

    /// Python `len()` support; forwards to [`SpeedTrace::len`].
    fn __len__(&self) -> usize {
        self.len()
    }

    #[pyo3(name = "to_csv_file")]
    /// Python `to_csv_file`: extracts a `PathBuf` from `filepath` and forwards
    /// to [`SpeedTrace::to_csv_file`].
    fn to_csv_file_py(&self, filepath: &Bound<PyAny>) -> anyhow::Result<()> {
        self.to_csv_file(PathBuf::extract_bound(filepath)?)
    }

    #[staticmethod]
    #[pyo3(name = "default")]
    /// Python `default`: returns the `Default` implementation of `SpeedTrace`.
    fn default_py() -> Self {
        Self::default()
    }
}
53
54impl SpeedTrace {
55    pub fn new(time_s: Vec<f64>, speed_mps: Vec<f64>, engine_on: Option<Vec<bool>>) -> Self {
56        SpeedTrace {
57            time: time_s.iter().map(|x| uc::S * (*x)).collect(),
58            speed: speed_mps.iter().map(|x| uc::MPS * (*x)).collect(),
59            engine_on,
60        }
61    }
62
63    pub fn trim(&mut self, start_idx: Option<usize>, end_idx: Option<usize>) -> anyhow::Result<()> {
64        let start_idx = start_idx.unwrap_or(0);
65        let end_idx = end_idx.unwrap_or_else(|| self.len());
66        ensure!(end_idx <= self.len(), format_dbg!(end_idx <= self.len()));
67
68        self.time = self.time[start_idx..end_idx].to_vec();
69        self.speed = self.speed[start_idx..end_idx].to_vec();
70        self.engine_on = self
71            .engine_on
72            .as_ref()
73            .map(|eo| eo[start_idx..end_idx].to_vec());
74        Ok(())
75    }
76
77    pub fn dt(&self, i: usize) -> si::Time {
78        self.time[i] - self.time[i - 1]
79    }
80
81    pub fn mean(&self, i: usize) -> si::Velocity {
82        0.5 * (self.speed[i] + self.speed[i - 1])
83    }
84
85    pub fn acc(&self, i: usize) -> si::Acceleration {
86        (self.speed[i] - self.speed[i - 1]) / self.dt(i)
87    }
88
    /// Number of trace points, taken from the length of `time`.
    ///
    /// `speed` and `engine_on` are not consulted.
    pub fn len(&self) -> usize {
        self.time.len()
    }
92
93    /// method to prevent rust-analyzer from complaining
94    pub fn is_empty(&self) -> bool {
95        self.time.is_empty() && self.speed.is_empty() && self.engine_on.is_none()
96    }
97
98    pub fn push(&mut self, speed_element: SpeedTraceElement) -> anyhow::Result<()> {
99        self.time.push(speed_element.time);
100        self.speed.push(speed_element.speed);
101        self.engine_on
102            .as_mut()
103            .map(|eo| match speed_element.engine_on {
104                Some(seeeo) => {
105                    eo.push(seeeo);
106                    Ok(())
107                }
108                None => bail!(
109                    "`engine_one` in `SpeedTraceElement` and `SpeedTrace` must both have same option variant."),
110            });
111        Ok(())
112    }
113
114    pub fn empty() -> Self {
115        Self {
116            time: Vec::new(),
117            speed: Vec::new(),
118            engine_on: None,
119        }
120    }
121
122    /// Load speed trace from csv file
123    pub fn from_csv_file<P: AsRef<Path>>(filepath: P) -> anyhow::Result<Self> {
124        let filepath = filepath.as_ref();
125
126        // create empty SpeedTrace to be populated
127        let mut st = Self::empty();
128
129        let file = File::open(filepath)?;
130        let mut rdr = csv::ReaderBuilder::new()
131            .has_headers(true)
132            .from_reader(file);
133        for result in rdr.deserialize() {
134            let st_elem: SpeedTraceElement = result?;
135            st.push(st_elem)?;
136        }
137        ensure!(
138            !st.is_empty(),
139            "Invalid SpeedTrace file {:?}; SpeedTrace is empty",
140            filepath
141        );
142        Ok(st)
143    }
144
145    /// Save speed trace to csv file
146    pub fn to_csv_file<P: AsRef<Path>>(&self, filepath: P) -> anyhow::Result<()> {
147        let file = std::fs::OpenOptions::new()
148            .write(true)
149            .create(true)
150            .truncate(true)
151            .open(filepath)?;
152        let mut wrtr = csv::WriterBuilder::new()
153            .has_headers(true)
154            .from_writer(file);
155        let engine_on: Vec<Option<bool>> = match &self.engine_on {
156            Some(eo_vec) => eo_vec
157                .iter()
158                .map(|eo| Some(*eo))
159                .collect::<Vec<Option<bool>>>(),
160            None => vec![None; self.len()],
161        };
162        for ((time, speed), engine_on) in self.time.iter().zip(&self.speed).zip(engine_on) {
163            wrtr.serialize(SpeedTraceElement {
164                time: *time,
165                speed: *speed,
166                engine_on,
167            })?;
168        }
169        wrtr.flush()?;
170        Ok(())
171    }
172}
173
// `SpeedTrace` overrides nothing in `Init` or `SerdeAPI`; both impls rely
// entirely on whatever the traits provide by default.
impl Init for SpeedTrace {}
impl SerdeAPI for SpeedTrace {}
176
177impl Default for SpeedTrace {
178    fn default() -> Self {
179        let mut speed_mps: Vec<f64> = Vec::linspace(0.0, 20.0, 800);
180        speed_mps.append(&mut [20.0; 100].to_vec());
181        speed_mps.append(&mut Vec::linspace(20.0, 0.0, 200));
182        speed_mps.push(0.0);
183        let time_s: Vec<f64> = (0..speed_mps.len()).map(|x| x as f64).collect();
184        Self::new(time_s, speed_mps, None)
185    }
186}
187
/// Element of [SpeedTrace].  Used for vec-like operations.
///
/// This is also the record type that [SpeedTrace] serializes to and
/// deserializes from when writing and reading CSV files.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq)]
pub struct SpeedTraceElement {
    /// simulation time; also accepted under the name `time_seconds` when
    /// deserializing
    #[serde(alias = "time_seconds")]
    time: si::Time,
    /// prescribed speed; also accepted under the name
    /// `speed_meters_per_second` when deserializing
    #[serde(alias = "speed_meters_per_second")]
    speed: si::Velocity,
    /// whether engine is on; `None` when not specified
    engine_on: Option<bool>,
}
200
#[serde_api]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3", pyclass(module = "altrios", subclass, eq))]
/// Train simulation in which speed is prescribed.  Note that this is not guaranteed to
/// produce identical results to [super::SpeedLimitTrainSim] because of differences in braking
/// controls but should generally be very close (i.e. error in cumulative fuel/battery energy
/// should be less than 0.1%)
pub struct SetSpeedTrainSim {
    /// Locomotive consist that supplies tractive and braking power
    pub loco_con: Consist,
    /// Number of railcars by type on the train
    pub n_cars_by_type: HashMap<String, u32>,
    /// Current train state; falls back to its `Default` when absent from
    /// serialized input
    #[serde(default)]
    pub state: TrainState,
    /// Prescribed speed versus time that the train follows
    pub speed_trace: SpeedTrace,

    /// train resistance calculation
    pub train_res: TrainRes,

    /// Path data passed to the resistance and link/offset calculations
    path_tpc: PathTpc,
    /// Custom vector of [Self::state]
    #[serde(default)]
    pub history: TrainStateHistoryVec,

    /// Interval, in time steps, at which state is pushed to `history`; `None`
    /// disables saving
    save_interval: Option<usize>,
    /// Time-dependent temperature at sea level that can be corrected for
    /// altitude using a standard model
    temp_trace: Option<TemperatureTrace>,
}
228
229#[pyo3_api]
230impl SetSpeedTrainSim {
231    #[getter]
232    pub fn get_res_strap(&self) -> anyhow::Result<Option<method::Strap>> {
233        match &self.train_res {
234            TrainRes::Strap(strap) => Ok(Some(strap.clone())),
235            _ => Ok(None),
236        }
237    }
238
239    #[getter]
240    pub fn get_res_point(&self) -> anyhow::Result<Option<method::Point>> {
241        match &self.train_res {
242            TrainRes::Point(point) => Ok(Some(point.clone())),
243            _ => Ok(None),
244        }
245    }
246
247    #[pyo3(name = "walk")]
248    /// Exposes `walk` to Python.
249    fn walk_py(&mut self) -> anyhow::Result<()> {
250        self.walk()
251    }
252
253    #[pyo3(name = "step")]
254    fn step_py(&mut self) -> anyhow::Result<()> {
255        self.step(|| format_dbg!())
256    }
257
258    #[pyo3(name = "set_save_interval")]
259    #[pyo3(signature = (save_interval=None))]
260    /// Set save interval and cascade to nested components.
261    fn set_save_interval_py(&mut self, save_interval: Option<usize>) {
262        self.set_save_interval(save_interval);
263    }
264
265    #[pyo3(name = "get_save_interval")]
266    fn get_save_interval_py(&self) -> anyhow::Result<Option<usize>> {
267        Ok(self.get_save_interval())
268    }
269
270    #[pyo3(name = "trim_failed_steps")]
271    fn trim_failed_steps_py(&mut self) -> anyhow::Result<()> {
272        self.trim_failed_steps()?;
273        Ok(())
274    }
275}
276
/// Plain collection of the inputs needed to construct a [SetSpeedTrainSim]
/// through its `From<SetSpeedTrainSimBuilder>` implementation.  The resulting
/// simulation starts with a default (empty) `history`.
pub struct SetSpeedTrainSimBuilder {
    /// Locomotive consist for the train
    pub loco_con: Consist,
    /// Number of railcars by type on the train
    pub n_cars_by_type: HashMap<String, u32>,
    /// Initial train state
    pub state: TrainState,
    /// Prescribed speed versus time
    pub speed_trace: SpeedTrace,
    /// Train resistance calculation
    pub train_res: TrainRes,
    /// Path data for the simulation
    pub path_tpc: PathTpc,
    /// Interval, in time steps, at which state is saved; `None` disables saving
    pub save_interval: Option<usize>,
    /// Time-dependent temperature at sea level that can be corrected for altitude using a standard model
    pub temp_trace: Option<TemperatureTrace>,
}
289
290impl From<SetSpeedTrainSimBuilder> for SetSpeedTrainSim {
291    fn from(value: SetSpeedTrainSimBuilder) -> Self {
292        SetSpeedTrainSim {
293            loco_con: value.loco_con,
294            n_cars_by_type: value.n_cars_by_type,
295            state: value.state,
296            speed_trace: value.speed_trace,
297            train_res: value.train_res,
298            path_tpc: value.path_tpc,
299            history: Default::default(),
300            save_interval: value.save_interval,
301            temp_trace: value.temp_trace,
302        }
303    }
304}
305
306impl SetSpeedTrainSim {
307    /// Trims off any portion of the trip that failed to run
308    pub fn trim_failed_steps(&mut self) -> anyhow::Result<()> {
309        if *self.state.i.get_fresh(|| format_dbg!())? <= 1 {
310            bail!("`walk` method has not proceeded past first time step.")
311        }
312        self.speed_trace
313            .trim(None, Some(*self.state.i.get_fresh(|| format_dbg!())?))?;
314
315        Ok(())
316    }
317
318    /// Sets `save_interval` for self and nested `loco_con`.
319    pub fn set_save_interval(&mut self, save_interval: Option<usize>) {
320        self.save_interval = save_interval;
321        self.loco_con.set_save_interval(save_interval);
322    }
323
324    /// Returns `self.save_interval` and asserts that this is equal
325    /// to `self.loco_con.get_save_interval()`.
326    pub fn get_save_interval(&self) -> Option<usize> {
327        // this ensures that save interval has been propagated
328        assert_eq!(self.save_interval, self.loco_con.get_save_interval());
329        self.save_interval
330    }
331
    /// Solves time step.
    ///
    /// For the trace point at index `state.i`, in order: computes the step
    /// duration, checks that the prescribed speed is non-negative, marks state
    /// variables this simulation does not compute as fresh, sets consist power
    /// limits, updates train resistance, solves required wheel power and
    /// consist energy consumption, and finally advances time, speed, position
    /// and cumulative quantities.
    ///
    /// # Errors
    /// Returns an error if the prescribed speed is negative or if any tracked
    /// state access or nested calculation fails.
    ///
    /// # Panics
    /// Indexes `speed_trace` directly with `state.i`, so it panics if that
    /// index is out of bounds.
    pub fn solve_step(&mut self) -> anyhow::Result<()> {
        // step duration: trace time at index `i` minus the previously stored
        // simulation time (read via `get_stale`)
        let dt = self.speed_trace.time[*self.state.i.get_fresh(|| format_dbg!())?]
            - *self.state.time.get_stale(|| format_dbg!())?;
        self.state.dt.update(dt, || format_dbg!())?;

        // the prescribed speed must be stopped or moving forward (no backwards)
        ensure!(
            self.speed_trace.speed[*self.state.i.get_fresh(|| format_dbg!())?]
                >= si::Velocity::ZERO,
            format_dbg!(
                self.speed_trace.speed[*self.state.i.get_fresh(|| format_dbg!())?]
                    >= si::Velocity::ZERO
            )
        );
        // catenary power limit is not recomputed here (see commented-out call below)
        self.loco_con
            .state
            .pwr_cat_lim
            .mark_fresh(|| format_dbg!())?;
        // not used in set_speed_train_sim
        self.state.speed_limit.mark_fresh(|| format_dbg!())?;
        // not used in set_speed_train_sim
        self.state.speed_target.mark_fresh(|| format_dbg!())?;
        // not used in set_speed_train_sim
        self.state.mass_static.mark_fresh(|| format_dbg!())?;
        // not used in set_speed_train_sim
        self.state.mass_rot.mark_fresh(|| format_dbg!())?;
        // not used in set_speed_train_sim
        self.state.mass_freight.mark_fresh(|| format_dbg!())?;
        // TODO: update this if length ever becomes dynamic
        self.state.length.mark_fresh(|| format_dbg!())?;
        // set the catenary power limit.  I'm assuming it is 0 at this point.
        // self.loco_con.set_cat_power_limit(
        //     &self.path_tpc,
        //     *self.state.offset.get_fresh(|| format_dbg!())?,
        // )?;
        // set aux power loads.  this will be calculated in the locomotive model and be loco type dependent.
        self.loco_con.set_pwr_aux(Some(true))?;
        let train_mass = Some(self.state.mass_compound().with_context(|| format_dbg!())?);

        // front elevation and the temperature at that elevation and the current
        // time; only available when a temperature trace was provided
        let elev_and_temp: Option<(si::Length, si::ThermodynamicTemperature)> =
            if let Some(tt) = &self.temp_trace {
                Some((
                    *self.state.elev_front.get_fresh(|| format_dbg!())?,
                    tt.get_temp_at_time_and_elev(
                        *self.state.time.get_fresh(|| format_dbg!())?,
                        *self.state.elev_front.get_fresh(|| format_dbg!())?,
                    )
                    .with_context(|| format_dbg!())?,
                ))
            } else {
                None
            };

        // set the max power out for the consist based on calculation of each loco state
        // NOTE(review): this uses `speed_trace.dt(i)` while `state.dt` above is
        // measured from the stored simulation time -- confirm the two always agree
        self.loco_con.set_curr_pwr_max_out(
            None,
            elev_and_temp,
            train_mass,
            Some(*self.state.speed.get_stale(|| format_dbg!())?),
            self.speed_trace
                .dt(*self.state.i.get_fresh(|| format_dbg!())?),
        )?;
        // calculate the train resistance for current time steps.  Based on train config and calculated in train model.
        self.train_res
            .update_res(&mut self.state, &self.path_tpc, &Dir::Fwd)?;
        // figure out how much power is needed to pull train with current speed trace.
        self.solve_required_pwr(
            self.speed_trace
                .dt(*self.state.i.get_fresh(|| format_dbg!())?),
        )?;
        // have the consist solve its energy consumption for the wheel power
        // found above
        self.loco_con.solve_energy_consumption(
            *self.state.pwr_whl_out.get_fresh(|| format_dbg!())?,
            train_mass,
            Some(self.speed_trace.speed[*self.state.i.get_fresh(|| format_dbg!())?]),
            self.speed_trace
                .dt(*self.state.i.get_fresh(|| format_dbg!())?),
            Some(true),
        )?;
        // advance time
        self.state.time.increment(dt, || format_dbg!())?;
        // update speed
        self.state.speed.update(
            self.speed_trace.speed[*self.state.i.get_fresh(|| format_dbg!())?],
            || format_dbg!(),
        )?;
        set_link_and_offset(&mut self.state, &self.path_tpc)?;
        // update offset by the distance covered: mean speed over the step times `dt`
        self.state.offset.increment(
            self.speed_trace
                .mean(*self.state.i.get_fresh(|| format_dbg!())?)
                * *self.state.dt.get_fresh(|| format_dbg!())?,
            || format_dbg!(),
        )?;
        // update total distance with the absolute value of the same distance
        self.state.total_dist.increment(
            (self
                .speed_trace
                .mean(*self.state.i.get_fresh(|| format_dbg!())?)
                * *self.state.dt.get_fresh(|| format_dbg!())?)
            .abs(),
            || format_dbg!(),
        )?;
        // update cumulative quantities for `state` and `loco_con`
        self.set_cumulative(
            *self.state.dt.get_fresh(|| format_dbg!())?,
            || format_dbg!(),
        )?;
        Ok(())
    }
441
442    /// Iterates `save_state` and `step` through all time steps.
443    pub fn walk(&mut self) -> anyhow::Result<()> {
444        self.save_state(|| format_dbg!())?;
445        loop {
446            if *self.state.i.get_fresh(|| format_dbg!())? > self.speed_trace.len() - 2 {
447                break;
448            }
449            self.step(|| format_dbg!()).with_context(|| format_dbg!())?;
450        }
451        Ok(())
452    }
453
    /// Sets the wheel power needed to follow the speed trace over the current
    /// step, limited to what the consist can deliver.
    ///
    /// Updates, in order:
    /// - `state.pwr_res`: net train resistance (`res_net`) times mean step speed
    /// - `state.pwr_accel`: change in kinetic energy over the step divided by
    ///   the step duration
    /// - `state.pwr_whl_out`: sum of the two, clipped to the consist's positive
    ///   and dynamic-braking limits
    /// - `state.energy_whl_out_pos` / `state.energy_whl_out_neg`: wheel energy
    ///   accumulated with the sign of `pwr_whl_out`
    ///
    /// # Arguments
    /// - `dt`: step duration used for the wheel energy increments
    ///
    /// NOTE(review): three step durations appear here -- the `dt` argument,
    /// `state.dt`, and `speed_trace.dt(i)`.  Presumably they are always equal;
    /// confirm against callers.
    ///
    /// # Panics
    /// Indexes `speed_trace` at `state.i` and `state.i - 1`, so it panics if
    /// either is out of bounds.
    pub fn solve_required_pwr(&mut self, dt: si::Time) -> anyhow::Result<()> {
        // Maximum positive power: the consist's `pwr_out_max`, further limited by
        // the previous wheel power plus `pwr_rate_out_max * state.dt`, with that
        // ramp-limited value floored at zero.
        let pwr_pos_max = self
            .loco_con
            .state
            .pwr_out_max
            .get_fresh(|| format_dbg!())?
            .min(
                si::Power::ZERO.max(
                    *self.state.pwr_whl_out.get_stale(|| format_dbg!())?
                        + *self
                            .loco_con
                            .state
                            .pwr_rate_out_max
                            .get_fresh(|| format_dbg!())?
                            * *self.state.dt.get_fresh(|| format_dbg!())?,
                ),
            );

        // find max dynamic braking power as positive value
        let pwr_neg_max = self
            .loco_con
            .state
            .pwr_dyn_brake_max
            .get_fresh(|| format_dbg!())?
            .max(si::Power::ZERO);

        // The ramp-limited term above is floored at zero, but `pwr_out_max` itself
        // is not, so this still rejects a negative `pwr_out_max`.
        ensure!(
            pwr_pos_max >= si::Power::ZERO,
            format_dbg!(pwr_pos_max >= si::Power::ZERO)
        );

        // `res` here means train resistance (not reversible energy storage).
        // This is the power needed to overcome net resistance at the mean
        // speed of the step.
        self.state.pwr_res.update(
            self.state.res_net().with_context(|| format_dbg!())?
                * self
                    .speed_trace
                    .mean(*self.state.i.get_fresh(|| format_dbg!())?),
            || format_dbg!(),
        )?;
        // find power to accelerate the train mass from an energy perspective:
        // mass_compound / (2 * dt) * (v[i]^2 - v[i - 1]^2)
        self.state.pwr_accel.update(
            self.state.mass_compound().with_context(|| format_dbg!())?
                / (2.0
                    * self
                        .speed_trace
                        .dt(*self.state.i.get_fresh(|| format_dbg!())?))
                * (self.speed_trace.speed[*self.state.i.get_fresh(|| format_dbg!())?]
                    .powi(typenum::P2::new())
                    - self.speed_trace.speed[*self.state.i.get_fresh(|| format_dbg!())? - 1]
                        .powi(typenum::P2::new())),
            || format_dbg!(),
        )?;

        // total power exerted by the consist to move the train, without limits applied
        let pwr_whl_out_unclipped = *self.state.pwr_accel.get_fresh(|| format_dbg!())?
            + *self.state.pwr_res.get_fresh(|| format_dbg!())?;

        // limit power to within the consist capability
        self.state.pwr_whl_out.update(
            pwr_whl_out_unclipped.max(-pwr_neg_max).min(pwr_pos_max),
            || format_dbg!(),
        )?;

        // add to positive or negative wheel energy tracking.  The other counter
        // is incremented by zero so that both are touched on every step.
        if *self.state.pwr_whl_out.get_fresh(|| format_dbg!())? >= 0. * uc::W {
            self.state.energy_whl_out_pos.increment(
                *self.state.pwr_whl_out.get_fresh(|| format_dbg!())? * dt,
                || format_dbg!(),
            )?;
            self.state
                .energy_whl_out_neg
                .increment(si::Energy::ZERO, || format_dbg!())?;
        } else {
            // negative wheel power is accumulated as a positive energy magnitude
            self.state.energy_whl_out_neg.increment(
                -*self.state.pwr_whl_out.get_fresh(|| format_dbg!())? * dt,
                || format_dbg!(),
            )?;
            self.state
                .energy_whl_out_pos
                .increment(si::Energy::ZERO, || format_dbg!())?;
        }
        Ok(())
    }
544}
545
// `StateMethods` is implemented with an empty body; nothing is overridden here.
impl StateMethods for SetSpeedTrainSim {}
547impl CheckAndResetState for SetSpeedTrainSim {
548    fn check_and_reset<F: Fn() -> String>(&mut self, loc: F) -> anyhow::Result<()> {
549        // self.state.speed_limit.mark_fresh(|| format_dbg!())?;
550        self.state
551            .check_and_reset(|| format!("{}\n{}", loc(), format_dbg!()))?;
552        self.loco_con
553            .check_and_reset(|| format!("{}\n{}", loc(), format_dbg!()))?;
554        Ok(())
555    }
556}
557impl SetCumulative for SetSpeedTrainSim {
558    fn set_cumulative<F: Fn() -> String>(&mut self, dt: si::Time, loc: F) -> anyhow::Result<()> {
559        self.state
560            .set_cumulative(dt, || format!("{}\n{}", loc(), format_dbg!()))?;
561        self.loco_con
562            .set_cumulative(dt, || format!("{}\n{}", loc(), format_dbg!()))?;
563        Ok(())
564    }
565}
566
567impl Step for SetSpeedTrainSim {
568    /// Solves step, saves state, steps nested `loco_con`, and increments `self.i`.
569    fn step<F: Fn() -> String>(&mut self, loc: F) -> anyhow::Result<()> {
570        let i = *self.state.i.get_fresh(|| format_dbg!())?;
571        self.check_and_reset(|| format_dbg!())?;
572        self.state
573            .i
574            .increment(1, || format!("{}\n{}", loc(), format_dbg!()))?;
575        self.loco_con.step(|| format_dbg!())?;
576        self.solve_step()
577            .with_context(|| format!("{}\ntime step: {}", loc(), i))?;
578
579        self.save_state(|| format_dbg!())?;
580        Ok(())
581    }
582}
583impl SaveState for SetSpeedTrainSim {
584    /// Saves current time step for self and nested `loco_con`.
585    fn save_state<F: Fn() -> String>(&mut self, _loc: F) -> anyhow::Result<()> {
586        if let Some(interval) = self.save_interval {
587            if self.state.i.get_fresh(|| format_dbg!())? % interval == 0 {
588                self.history.push(self.state.clone());
589                self.loco_con.save_state(|| format_dbg!())?;
590            }
591        }
592        Ok(())
593    }
594}
595impl Init for SetSpeedTrainSim {
596    fn init(&mut self) -> Result<(), Error> {
597        self.loco_con.init()?;
598        self.speed_trace.init()?;
599        self.train_res.init()?;
600        self.path_tpc.init()?;
601        self.state.init()?;
602        self.history.init()?;
603        Ok(())
604    }
605}
// `SerdeAPI` is implemented with an empty body; nothing is overridden here.
impl SerdeAPI for SetSpeedTrainSim {}
607
608impl Default for SetSpeedTrainSim {
609    fn default() -> Self {
610        Self {
611            loco_con: Consist::default(),
612            n_cars_by_type: Default::default(),
613            state: TrainState::valid(),
614            train_res: TrainRes::valid(),
615            path_tpc: PathTpc::valid(),
616            speed_trace: SpeedTrace::default(),
617            history: TrainStateHistoryVec::default(),
618            save_interval: None,
619            temp_trace: Default::default(),
620        }
621    }
622}
623
#[cfg(test)]
mod tests {
    use super::SetSpeedTrainSim;

    /// Walking the default simulation must succeed and leave the consist's step
    /// index past the first time step.
    #[test]
    fn test_set_speed_train_sim() {
        let mut sim = SetSpeedTrainSim::default();
        sim.walk().unwrap();

        let consist_step_idx = *sim.loco_con.state.i.get_fresh(|| format_dbg!()).unwrap();
        assert!(consist_step_idx > 1);
    }
}