altrios_core/train/
train_config.rs

1use super::environment::TemperatureTrace;
2use super::resistance::kind as res_kind;
3use super::resistance::method as res_method;
4use super::speed_limit_train_sim::TimedLinkPath;
5#[cfg(feature = "pyo3")]
6use super::TrainResWrapper;
7use crate::consist::locomotive::locomotive_model::PowertrainType;
8
9use super::{
10    friction_brakes::*, rail_vehicle::RailVehicle, train_imports::*, InitTrainState,
11    SetSpeedTrainSim, SetSpeedTrainSimBuilder, SpeedLimitTrainSim, SpeedLimitTrainSimBuilder,
12    SpeedTrace, TrainState,
13};
14use crate::track::link::link_idx::LinkPath;
15use crate::track::link::network::Network;
16use crate::track::LocationMap;
17
18use polars::prelude::*;
19use polars_lazy::dsl::max_horizontal;
20#[allow(unused_imports)]
21use polars_lazy::prelude::*;
22use pyo3_polars::PyDataFrame;
23
24#[serde_api]
25#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
26#[cfg_attr(feature = "pyo3", pyclass(module = "altrios", subclass, eq))]
27/// User-defined train configuration used to generate
28/// [crate::prelude::TrainParams]. Any optional fields will be populated later
29/// in [TrainSimBuilder::make_train_sim_parts]
30pub struct TrainConfig {
31    /// Types of rail vehicle composing the train
32    pub rail_vehicles: Vec<RailVehicle>,
33    /// Number of railcars by type on the train
34    pub n_cars_by_type: HashMap<String, u32>,
35    /// Train type matching one of the PTC types
36    pub train_type: TrainType,
37    /// Train length that overrides the railcar specific value, if provided
38    pub train_length: Option<si::Length>,
39    /// Total train mass that overrides the railcar specific values, if provided
40    pub train_mass: Option<si::Mass>,
41
42    #[serde(default)]
43    /// Optional vector of drag areas (i.e. drag coeff. times frontal area)
44    /// for each car.  If provided, the total drag area (drag coefficient
45    /// times frontal area) calculated from this vector is the sum of these
46    /// coefficients. Otherwise, each rail car's drag contribution based on its
47    /// drag coefficient and frontal area will be summed across the train.
48    pub cd_area_vec: Option<Vec<si::Area>>,
49}
50
51#[pyo3_api]
52impl TrainConfig {
53    #[new]
54    #[pyo3(signature = (
55        rail_vehicles,
56        n_cars_by_type,
57        train_type=None,
58        train_length_meters=None,
59        train_mass_kilograms=None,
60        cd_area_vec=None,
61    ))]
62    fn __new__(
63        rail_vehicles: Vec<RailVehicle>,
64        n_cars_by_type: HashMap<String, u32>,
65        train_type: Option<TrainType>,
66        train_length_meters: Option<f64>,
67        train_mass_kilograms: Option<f64>,
68        cd_area_vec: Option<Vec<f64>>,
69    ) -> anyhow::Result<Self> {
70        Self::new(
71            rail_vehicles,
72            n_cars_by_type,
73            train_type.unwrap_or_default(),
74            train_length_meters.map(|v| v * uc::M),
75            train_mass_kilograms.map(|v| v * uc::KG),
76            cd_area_vec.map(|dcv| dcv.iter().map(|dc| *dc * uc::M2).collect()),
77        )
78    }
79
80    #[pyo3(name = "make_train_params")]
81    /// - `rail_vehicles` - list of `RailVehicle` objects with 1 element for each _type_ of rail vehicle
82    fn make_train_params_py(&self) -> anyhow::Result<TrainParams> {
83        self.make_train_params()
84    }
85
86    #[getter]
87    fn get_train_length_meters(&self) -> Option<f64> {
88        self.train_length.map(|l| l.get::<si::meter>())
89    }
90
91    #[getter]
92    fn get_train_mass_kilograms(&self) -> Option<f64> {
93        self.train_mass.map(|l| l.get::<si::kilogram>())
94    }
95
96    #[getter]
97    fn get_cd_area_vec_meters_squared(&self) -> Option<Vec<f64>> {
98        self.cd_area_vec.as_ref().map(|dcv| {
99            dcv.iter()
100                .cloned()
101                .map(|x| x.get::<si::square_meter>())
102                .collect()
103        })
104    }
105
106    fn set_cd_area_vec(&mut self, new_val: Vec<f64>) -> anyhow::Result<()> {
107        self.cd_area_vec = Some(new_val.iter().map(|x| *x * uc::M2).collect());
108        Ok(())
109    }
110}
111
112impl Init for TrainConfig {
113    fn init(&mut self) -> Result<(), Error> {
114        if let Some(dcv) = &self.cd_area_vec {
115            // TODO: account for locomotive drag here, too
116            if dcv.len() as u32 != self.cars_total() {
117                return Err(Error::InitError(
118                    "`cd_area_vec` len and `cars_total()` do not match".into(),
119                ));
120            }
121        };
122        Ok(())
123    }
124}
125impl SerdeAPI for TrainConfig {}
126
127impl TrainConfig {
128    pub fn new(
129        rail_vehicles: Vec<RailVehicle>,
130        n_cars_by_type: HashMap<String, u32>,
131        train_type: TrainType,
132        train_length: Option<si::Length>,
133        train_mass: Option<si::Mass>,
134        cd_area_vec: Option<Vec<si::Area>>,
135    ) -> anyhow::Result<Self> {
136        let mut train_config = Self {
137            rail_vehicles,
138            n_cars_by_type,
139            train_type,
140            train_length,
141            train_mass,
142            cd_area_vec,
143        };
144        train_config.init()?;
145        Ok(train_config)
146    }
147
148    pub fn cars_total(&self) -> u32 {
149        self.n_cars_by_type.values().fold(0, |acc, n| *n + acc)
150    }
151
152    /// # Arguments
153    /// - `rail_vehicles` - slice of `RailVehicle` objects with 1 element for each _type_ of rail vehicle
154    /// # Important
155    /// This method assumes that any calling method has already checked that
156    /// all the `car_type` fields in `rail_vehicles` have matching keys in
157    /// `self.n_cars_by_type`.
158    pub fn make_train_params(&self) -> anyhow::Result<TrainParams> {
159        // total towed mass of rail vehicles
160        let towed_mass_static = self.train_mass.unwrap_or({
161            self.rail_vehicles.iter().try_fold(
162                0. * uc::KG,
163                |acc, rv| -> anyhow::Result<si::Mass> {
164                    Ok(acc
165                        + rv.mass()
166                            .with_context(|| format_dbg!())?
167                            .with_context(|| "`make_train_params` failed")?
168                            * *self
169                                .n_cars_by_type
170                                .get(&rv.car_type)
171                                .with_context(|| format_dbg!())?
172                                as f64
173                            * uc::R)
174                },
175            )?
176        });
177
178        let length: si::Length = match self.train_length {
179            Some(tl) => tl,
180            None => self
181                .rail_vehicles
182                .iter()
183                .fold(0. * uc::M, |acc, rv| -> si::Length {
184                    acc + rv.length * *self.n_cars_by_type.get(&rv.car_type).unwrap() as f64
185                }),
186        };
187
188        let train_params = TrainParams {
189            length,
190            speed_max: self.rail_vehicles.iter().fold(
191                f64::INFINITY * uc::MPS,
192                |acc, rv| -> si::Velocity {
193                    if *self.n_cars_by_type.get(&rv.car_type).unwrap() > 0 {
194                        acc.min(rv.speed_max)
195                    } else {
196                        acc
197                    }
198                },
199            ),
200            towed_mass_static,
201            mass_per_brake: (towed_mass_static + {
202                let mass_rot = self
203                    .rail_vehicles
204                    .iter()
205                    .fold(0. * uc::KG, |acc, rv| -> si::Mass {
206                        acc + rv.mass_rot_per_axle
207                            * *self.n_cars_by_type.get(&rv.car_type).unwrap() as f64
208                            * rv.axle_count as f64
209                    });
210                mass_rot
211            }) / self.rail_vehicles.iter().fold(0, |acc, rv| -> u32 {
212                acc + rv.brake_count as u32 * *self.n_cars_by_type.get(&rv.car_type).unwrap()
213            }) as f64,
214            axle_count: self.rail_vehicles.iter().fold(0, |acc, rv| -> u32 {
215                acc + rv.axle_count as u32 * *self.n_cars_by_type.get(&rv.car_type).unwrap()
216            }),
217            train_type: self.train_type,
218            // TODO: change it so that curve coefficient is specified at the train level, and replace `unwrap` function calls
219            // with proper result handling, and relpace `first().unwrap()` with real code.
220            curve_coeff_0: self.rail_vehicles.first().unwrap().curve_coeff_0,
221            curve_coeff_1: self.rail_vehicles.first().unwrap().curve_coeff_1,
222            curve_coeff_2: self.rail_vehicles.first().unwrap().curve_coeff_2,
223        };
224        Ok(train_params)
225    }
226}
227
228impl Valid for TrainConfig {
229    fn valid() -> Self {
230        Self {
231            rail_vehicles: vec![RailVehicle::default()],
232            n_cars_by_type: HashMap::from([("Bulk".into(), 100_u32)]),
233            train_type: TrainType::Freight,
234            train_length: None,
235            train_mass: None,
236            cd_area_vec: None,
237        }
238    }
239}
240
241#[serde_api]
242#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq)]
243#[cfg_attr(feature = "pyo3", pyclass(module = "altrios", subclass, eq))]
244pub struct TrainSimBuilder {
245    /// Unique, user-defined identifier for the train
246    pub train_id: String,
247    pub train_config: TrainConfig,
248    pub loco_con: Consist,
249    /// Origin_ID from train planner to map to track network locations.  Only needed if
250    /// [Self::make_speed_limit_train_sim] will be called.
251    pub origin_id: Option<String>,
252    /// Destination_ID from train planner to map to track network locations.  Only needed if
253    /// [Self::make_speed_limit_train_sim] will be called.
254    pub destination_id: Option<String>,
255
256    init_train_state: Option<InitTrainState>,
257}
258
259#[pyo3_api]
260impl TrainSimBuilder {
261    #[new]
262    #[pyo3(signature = (
263        train_id,
264        train_config,
265        loco_con,
266        origin_id=None,
267        destination_id=None,
268        init_train_state=None,
269    ))]
270    fn __new__(
271        train_id: String,
272        train_config: TrainConfig,
273        loco_con: Consist,
274        origin_id: Option<String>,
275        destination_id: Option<String>,
276        init_train_state: Option<InitTrainState>,
277    ) -> Self {
278        Self::new(
279            train_id,
280            train_config,
281            loco_con,
282            origin_id,
283            destination_id,
284            init_train_state,
285        )
286    }
287
288    #[pyo3(
289        name = "make_set_speed_train_sim",
290        signature = (
291            network,
292            link_path,
293            speed_trace,
294            save_interval=None,
295            temp_trace=None,
296        )
297    )]
298    fn make_set_speed_train_sim_py(
299        &self,
300        network: &Bound<PyAny>,
301        link_path: &Bound<PyAny>,
302        speed_trace: SpeedTrace,
303        save_interval: Option<usize>,
304        temp_trace: Option<TemperatureTrace>,
305    ) -> anyhow::Result<SetSpeedTrainSim> {
306        let network = match network.extract::<Network>() {
307            Ok(n) => n,
308            Err(_) => {
309                let n = network
310                    .extract::<Vec<Link>>()
311                    .map_err(|_| anyhow!("{}", format_dbg!()))?;
312                Network(Default::default(), n)
313            }
314        };
315
316        let link_path = match link_path.extract::<LinkPath>() {
317            Ok(lp) => lp,
318            Err(_) => {
319                let lp = link_path
320                    .extract::<Vec<LinkIdx>>()
321                    .map_err(|_| anyhow!("{}", format_dbg!()))?;
322                LinkPath(lp)
323            }
324        };
325
326        self.make_set_speed_train_sim(network, link_path, speed_trace, save_interval, temp_trace)
327    }
328
329    #[pyo3(
330        name = "make_set_speed_train_sim_and_parts",
331        signature = (
332            network,
333            link_path,
334            speed_trace,
335            save_interval=None,
336            temp_trace=None,
337        )
338    )]
339    fn make_set_speed_train_sim_and_parts_py(
340        &self,
341        network: &Bound<PyAny>,
342        link_path: &Bound<PyAny>,
343        speed_trace: SpeedTrace,
344        save_interval: Option<usize>,
345        temp_trace: Option<TemperatureTrace>,
346    ) -> anyhow::Result<(
347        SetSpeedTrainSim,
348        TrainParams,
349        PathTpc,
350        TrainResWrapper,
351        FricBrake,
352    )> {
353        let network = match network.extract::<Network>() {
354            Ok(n) => n,
355            Err(_) => {
356                let n = network
357                    .extract::<Vec<Link>>()
358                    .map_err(|_| anyhow!("{}", format_dbg!()))?;
359                Network(Default::default(), n)
360            }
361        };
362
363        let link_path = match link_path.extract::<LinkPath>() {
364            Ok(lp) => lp,
365            Err(_) => {
366                let lp = link_path
367                    .extract::<Vec<LinkIdx>>()
368                    .map_err(|_| anyhow!("{}", format_dbg!()))?;
369                LinkPath(lp)
370            }
371        };
372
373        let (train_sim, train_params, path_tpc, tr, fb) = self
374            .make_set_speed_train_sim_and_parts(
375                network,
376                link_path,
377                speed_trace,
378                save_interval,
379                temp_trace,
380            )
381            .with_context(|| format_dbg!())?;
382
383        let trw = TrainResWrapper(tr);
384        Ok((train_sim, train_params, path_tpc, trw, fb))
385    }
386
387    #[pyo3(
388        name = "make_speed_limit_train_sim",
389        signature = (
390            location_map,
391            save_interval=None,
392            simulation_days=None,
393            scenario_year=None,
394            temp_trace=None,
395        )
396    )]
397    fn make_speed_limit_train_sim_py(
398        &self,
399        location_map: LocationMap,
400        save_interval: Option<usize>,
401        simulation_days: Option<i32>,
402        scenario_year: Option<i32>,
403        temp_trace: Option<TemperatureTrace>,
404    ) -> anyhow::Result<SpeedLimitTrainSim> {
405        self.make_speed_limit_train_sim(
406            &location_map,
407            save_interval,
408            simulation_days,
409            scenario_year,
410            temp_trace,
411        )
412    }
413
414    #[pyo3(
415        name = "make_speed_limit_train_sim_and_parts",
416        signature = (
417            location_map,
418            save_interval=None,
419            simulation_days=None,
420            scenario_year=None,
421            temp_trace=None,
422        )
423    )]
424    fn make_speed_limit_train_sim_and_parts_py(
425        &self,
426        location_map: LocationMap,
427        save_interval: Option<usize>,
428        simulation_days: Option<i32>,
429        scenario_year: Option<i32>,
430        temp_trace: Option<TemperatureTrace>,
431    ) -> anyhow::Result<(SpeedLimitTrainSim, PathTpc, TrainResWrapper, FricBrake)> {
432        let (ts, path_tpc, tr, fb) = self.make_speed_limit_train_sim_and_parts(
433            &location_map,
434            save_interval,
435            simulation_days,
436            scenario_year,
437            temp_trace,
438        )?;
439
440        let trw = TrainResWrapper(tr);
441        Ok((ts, path_tpc, trw, fb))
442    }
443}
444
445impl Init for TrainSimBuilder {}
446impl SerdeAPI for TrainSimBuilder {}
447
448impl TrainSimBuilder {
449    pub fn new(
450        train_id: String,
451        train_config: TrainConfig,
452        loco_con: Consist,
453        origin_id: Option<String>,
454        destination_id: Option<String>,
455        init_train_state: Option<InitTrainState>,
456    ) -> Self {
457        Self {
458            train_id,
459            train_config,
460            loco_con,
461            origin_id,
462            destination_id,
463            init_train_state,
464        }
465    }
466
467    fn make_train_sim_parts(
468        &self,
469        save_interval: Option<usize>,
470    ) -> anyhow::Result<(TrainParams, TrainState, PathTpc, TrainRes, FricBrake)> {
471        let rvs = &self.train_config.rail_vehicles;
472        // check that `self.train_config.n_cars_by_type` has keys matching `rail_vehicles`
473        self.check_rv_keys()?;
474        let train_params = self
475            .train_config
476            .make_train_params()
477            .with_context(|| format_dbg!())?;
478
479        let length = train_params.length;
480        // total train weight including locomotives, baseline rail vehicle masses, and freight mass
481        let train_mass_static = train_params.towed_mass_static
482            + self
483                .loco_con
484                .mass()
485                .with_context(|| format_dbg!())?
486                .unwrap_or_else(|| 0. * uc::KG);
487
488        let mass_rot = rvs.iter().fold(0. * uc::KG, |acc, rv| -> si::Mass {
489            acc + rv.mass_rot_per_axle
490                * *self.train_config.n_cars_by_type.get(&rv.car_type).unwrap() as f64
491                * rv.axle_count as f64
492        });
493        let mass_freight = rvs.iter().fold(0. * uc::KG, |acc, rv| -> si::Mass {
494            acc + rv.mass_freight
495                * *self.train_config.n_cars_by_type.get(&rv.car_type).unwrap() as f64
496        });
497        let max_fric_braking = uc::ACC_GRAV
498            * train_params.towed_mass_static
499            * rvs.iter().fold(0. * uc::R, |acc, rv| -> si::Ratio {
500                acc + rv.braking_ratio
501                    * *self.train_config.n_cars_by_type.get(&rv.car_type).unwrap() as f64
502            })
503            / self.train_config.cars_total() as f64;
504
505        let state = TrainState::new(
506            length,
507            train_mass_static,
508            mass_rot,
509            mass_freight,
510            self.init_train_state.clone(),
511        );
512
513        let path_tpc = PathTpc::new(train_params);
514
515        let train_res = {
516            let res_bearing = res_kind::bearing::Basic::new(rvs.iter().fold(
517                0. * uc::N,
518                |acc, rv| -> si::Force {
519                    acc + rv.bearing_res_per_axle
520                        * rv.axle_count as f64
521                        * *self.train_config.n_cars_by_type.get(&rv.car_type).unwrap() as f64
522                },
523            ));
524            // Sum of mass-averaged rolling resistances across all railcars
525            let res_rolling = res_kind::rolling::Basic::new(rvs.iter().try_fold(
526                0.0 * uc::R,
527                |acc, rv| -> anyhow::Result<si::Ratio> {
528                    Ok(acc
529                        + rv.rolling_ratio
530                            * rv.mass()
531                                .with_context(|| format_dbg!())?
532                                .with_context(|| format!("{}\nExpected `Some`", format_dbg!()))?
533                            / train_params.towed_mass_static // does not include locomotive consist mass -- TODO: fix this, carefully                            
534                            * *self.train_config.n_cars_by_type.get(&rv.car_type).unwrap() as f64
535                            * uc::R)
536                },
537            )?);
538            let davis_b = res_kind::davis_b::Basic::new(rvs.iter().try_fold(
539                0.0 * uc::S / uc::M,
540                |acc, rv| -> anyhow::Result<si::InverseVelocity> {
541                    Ok(acc
542                        + rv.davis_b
543                            * rv.mass()
544                                .with_context(|| format_dbg!())?
545                                .with_context(|| format!("{}\nExpected `Some`", format_dbg!()))?
546                            / train_params.towed_mass_static // does not include locomotive consist mass -- TODO: fix this, carefully
547                            * *self.train_config.n_cars_by_type.get(&rv.car_type).unwrap() as f64
548                            * uc::R)
549                },
550            )?);
551            let res_aero =
552                res_kind::aerodynamic::Basic::new(match &self.train_config.cd_area_vec {
553                    Some(dave) => dave.iter().fold(0. * uc::M2, |acc, dc| *dc + acc),
554                    None => rvs.iter().fold(0.0 * uc::M2, |acc, rv| -> si::Area {
555                        acc + rv.cd_area
556                            * *self.train_config.n_cars_by_type.get(&rv.car_type).unwrap() as f64
557                    }),
558                });
559            let res_grade = res_kind::path_res::Strap::new(path_tpc.grades(), &state)?;
560            let res_curve = res_kind::path_res::Strap::new(path_tpc.curves(), &state)?;
561            TrainRes::Strap(res_method::Strap::new(
562                res_bearing,
563                res_rolling,
564                davis_b,
565                res_aero,
566                res_grade,
567                res_curve,
568            ))
569        };
570
571        let fric_brake = FricBrake::new(max_fric_braking, None, None, None, save_interval);
572
573        Ok((train_params, state, path_tpc, train_res, fric_brake))
574    }
575
576    fn check_rv_keys(&self) -> anyhow::Result<()> {
577        let rv_car_type_set = HashSet::<String>::from_iter(
578            self.train_config
579                .rail_vehicles
580                .iter()
581                .map(|rv| rv.car_type.clone()),
582        );
583        let n_cars_type_set =
584            HashSet::<String>::from_iter(self.train_config.n_cars_by_type.keys().cloned());
585        let extra_keys_in_rv = rv_car_type_set
586            .difference(&n_cars_type_set)
587            .collect::<Vec<&String>>();
588        let extra_keys_in_n_cars = n_cars_type_set
589            .difference(&rv_car_type_set)
590            .collect::<Vec<&String>>();
591        if !extra_keys_in_rv.is_empty() {
592            bail!(
593                "Extra values in `car_type` for `rail_vehicles` that are not in `n_cars_by_type`: {:?}",
594                extra_keys_in_rv
595            );
596        }
597        if !extra_keys_in_n_cars.is_empty() {
598            bail!(
599                "Extra values in `n_cars_by_type` that are not in `car_type` for `rail_vehicles`: {:?}",
600                extra_keys_in_n_cars
601            );
602        }
603        Ok(())
604    }
605
606    pub fn make_set_speed_train_sim<Q: AsRef<[Link]>, R: AsRef<[LinkIdx]>>(
607        &self,
608        network: Q,
609        link_path: R,
610        speed_trace: SpeedTrace,
611        save_interval: Option<usize>,
612        temp_trace: Option<TemperatureTrace>,
613    ) -> anyhow::Result<SetSpeedTrainSim> {
614        ensure!(
615            self.origin_id.is_none() & self.destination_id.is_none(),
616            "{}\n`origin_id` and `destination_id` must both be `None` when calling `make_set_speed_train_sim`.",
617            format_dbg!()
618        );
619
620        let (_, state, mut path_tpc, train_res, _fric_brake) = self
621            .make_train_sim_parts(save_interval)
622            .with_context(|| format_dbg!())?;
623
624        path_tpc.extend(network, link_path)?;
625        Ok(SetSpeedTrainSimBuilder {
626            loco_con: self.loco_con.clone(),
627            n_cars_by_type: self.train_config.n_cars_by_type.clone(),
628            state,
629            speed_trace,
630            train_res,
631            path_tpc,
632            save_interval,
633            temp_trace,
634        }
635        .into())
636    }
637
638    pub fn make_set_speed_train_sim_and_parts<Q: AsRef<[Link]>, R: AsRef<[LinkIdx]>>(
639        &self,
640        network: Q,
641        link_path: R,
642        speed_trace: SpeedTrace,
643        save_interval: Option<usize>,
644        temp_trace: Option<TemperatureTrace>,
645    ) -> anyhow::Result<(SetSpeedTrainSim, TrainParams, PathTpc, TrainRes, FricBrake)> {
646        ensure!(
647            self.origin_id.is_none() & self.destination_id.is_none(),
648            "{}\n`origin_id` and `destination_id` must both be `None` when calling `make_set_speed_train_sim`.",
649            format_dbg!()
650        );
651
652        let (train_params, state, mut path_tpc, train_res, fric_brake) = self
653            .make_train_sim_parts(save_interval)
654            .with_context(|| format_dbg!())?;
655
656        path_tpc.extend(network, link_path)?;
657        Ok((
658            SetSpeedTrainSimBuilder {
659                loco_con: self.loco_con.clone(),
660                n_cars_by_type: self.train_config.n_cars_by_type.clone(),
661                state,
662                speed_trace,
663                train_res: train_res.clone(),
664                path_tpc: path_tpc.clone(),
665                save_interval,
666                temp_trace,
667            }
668            .into(),
669            train_params,
670            path_tpc,
671            train_res,
672            fric_brake,
673        ))
674    }
675
676    pub fn make_speed_limit_train_sim(
677        &self,
678        location_map: &LocationMap,
679        save_interval: Option<usize>,
680        simulation_days: Option<i32>,
681        scenario_year: Option<i32>,
682        temp_trace: Option<TemperatureTrace>,
683    ) -> anyhow::Result<SpeedLimitTrainSim> {
684        let (_, state, path_tpc, train_res, fric_brake) = self
685            .make_train_sim_parts(save_interval)
686            .with_context(|| format_dbg!())?;
687
688        ensure!(
689            self.origin_id.is_some() & self.destination_id.is_some(),
690            "{}\nBoth `origin_id` and `destination_id` must be provided when initializing{} ",
691            format_dbg!(),
692            "`TrainSimBuilder` for `make_speed_limit_train_sim` to work."
693        );
694
695        Ok(SpeedLimitTrainSimBuilder {
696            train_id: self.train_id.clone(),
697            // `self.origin_id` verified to be `Some` earlier
698            origs: location_map
699                .get(self.origin_id.as_ref().unwrap())
700                .with_context(|| {
701                    anyhow!(format!(
702                        "{}\n`origin_id`: \"{}\" not found in `location_map` keys: {:?}",
703                        format_dbg!(),
704                        self.origin_id.as_ref().unwrap(),
705                        location_map.keys(),
706                    ))
707                })?
708                .to_vec(),
709            // `self.destination_id` verified to be `Some` earlier
710            dests: location_map
711                .get(self.destination_id.as_ref().unwrap())
712                .with_context(|| {
713                    anyhow!(format!(
714                        "{}\n`destination_id`: \"{}\" not found in `location_map` keys: {:?}",
715                        format_dbg!(),
716                        self.destination_id.as_ref().unwrap(),
717                        location_map.keys(),
718                    ))
719                })?
720                .to_vec(),
721            loco_con: self.loco_con.clone(),
722            n_cars_by_type: self.train_config.n_cars_by_type.clone(),
723            state,
724            train_res,
725            path_tpc,
726            fric_brake,
727            save_interval,
728            simulation_days,
729            scenario_year,
730            temp_trace,
731        }
732        .into())
733    }
734
735    pub fn make_speed_limit_train_sim_and_parts(
736        &self,
737        location_map: &LocationMap,
738        save_interval: Option<usize>,
739        simulation_days: Option<i32>,
740        scenario_year: Option<i32>,
741        temp_trace: Option<TemperatureTrace>,
742    ) -> anyhow::Result<(SpeedLimitTrainSim, PathTpc, TrainRes, FricBrake)> {
743        let (_, state, path_tpc, train_res, fric_brake) = self
744            .make_train_sim_parts(save_interval)
745            .with_context(|| format_dbg!())?;
746
747        ensure!(
748            self.origin_id.is_some() & self.destination_id.is_some(),
749            "{}\nBoth `origin_id` and `destination_id` must be provided when initializing{} ",
750            format_dbg!(),
751            "`TrainSimBuilder` for `make_speed_limit_train_sim` to work."
752        );
753
754        let ts = SpeedLimitTrainSimBuilder {
755            train_id: self.train_id.clone(),
756            // `self.origin_id` verified to be `Some` earlier
757            origs: location_map
758                .get(self.origin_id.as_ref().unwrap())
759                .with_context(|| {
760                    anyhow!(format!(
761                        "{}\n`origin_id`: \"{}\" not found in `location_map` keys: {:?}",
762                        format_dbg!(),
763                        self.origin_id.as_ref().unwrap(),
764                        location_map.keys(),
765                    ))
766                })?
767                .to_vec(),
768            // `self.destination_id` verified to be `Some` earlier
769            dests: location_map
770                .get(self.destination_id.as_ref().unwrap())
771                .with_context(|| {
772                    anyhow!(format!(
773                        "{}\n`destination_id`: \"{}\" not found in `location_map` keys: {:?}",
774                        format_dbg!(),
775                        self.destination_id.as_ref().unwrap(),
776                        location_map.keys(),
777                    ))
778                })?
779                .to_vec(),
780            loco_con: self.loco_con.clone(),
781            n_cars_by_type: self.train_config.n_cars_by_type.clone(),
782            state,
783            train_res: train_res.clone(),
784            path_tpc: path_tpc.clone(),
785            fric_brake: fric_brake.clone(),
786            save_interval,
787            simulation_days,
788            scenario_year,
789            temp_trace,
790        };
791        Ok((ts.into(), path_tpc, train_res, fric_brake))
792    }
793}
794
795/// Converts either `Column::Series` or `Column::Scalar` to `Series`
796fn to_series(col: Column) -> anyhow::Result<Series> {
797    match col.clone() {
798        Column::Series(s) => Ok(s.take()),
799        Column::Scalar(s) => Ok(s.to_series()),
800        Column::Partitioned(_) => bail!("{}\nPartitioned column!", format_dbg!()),
801    }
802}
803
/// Runs the train sims in `speed_limit_train_sims` along their timed paths while
/// simulating locomotive refueling/recharging at the nodes where trains arrive.
///
/// This is an event-driven loop. `current_time` starts at the earliest
/// `Arrival_Time_Actual_Hr` in the consist plan. After each iteration it advances
/// to the earliest remaining arrival or to the earliest `Ready_Time_Est` of a
/// locomotive whose status is `"Refueling"` or `"Dispatched"`. At each event
/// time the loop does the following:
/// 1. Locomotives arriving at `current_time` are marked `"Queued"` at their
///    `Destination_ID` node, and the refueler columns are re-joined from
///    `refuel_facilities_py`. Each train sim arriving at this time is then run
///    via `walk_timed_path`. Each locomotive's SOC is copied in from `loco_pool`
///    before the run, and `SOC_J` / `Trip_Energy_J` are written back afterwards.
/// 2. `"Refueling"` locomotives whose `Ready_Time_Est` equals `current_time` are
///    marked `"Ready"`.
/// 3. Queued locomotives that have a free port start a refueling session. The
///    session lasts until the locomotive reaches its SOC target or its next
///    departure, whichever is earlier. Each session is appended to the returned
///    DataFrame.
///
/// # Arguments
/// - `speed_limit_train_sims`: train sims, indexed by the `TrainSimVec_Index`
///   column of `train_consist_plan_py`.
/// - `network`: either a `Network` or a list of `Link`s.
/// - `train_consist_plan_py`: columns read here are `Arrival_Time_Actual_Hr`,
///   `Departure_Time_Actual_Hr`, `Locomotive_ID` (u32), `Destination_ID` and
///   `TrainSimVec_Index` (u32).
/// - `loco_pool_py`: columns read here include `Locomotive_ID`, `SOC_Max_J`,
///   `SOC_Min_J`, `Capacity_J`, `Node`, `Locomotive_Type` and `Fuel_Type`, plus
///   the refueler columns that are dropped and re-joined on each arrival.
/// - `refuel_facilities_py`: keyed by `Node`, `Locomotive_Type` and
///   `Fuel_Type`; it provides `Refueler_J_Per_Hr`, `Refueler_Efficiency`,
///   `Port_Count` and `Battery_Headroom_J`.
/// - `timed_paths`: one path per train sim, indexed like `speed_limit_train_sims`.
///
/// # Returns
/// The sims after their paths have been walked, plus a DataFrame with one row
/// per refueling session. It is empty (`DataFrame::default()`) if no session
/// was started.
///
/// # Errors
/// Returns an error in these cases:
/// - `network` cannot be extracted as a `Network` or as a `Vec<Link>`;
/// - an expected column is missing or has an unexpected dtype;
/// - a polars operation fails.
///
/// Errors from `walk_timed_path` are NOT propagated (see the note in the body).
///
/// # Panics
/// - No non-null arrival time exists.
/// - A `TrainSimVec_Index` is null or out of range for the sims or `timed_paths`.
/// - A `BatteryElectricLoco` has no reversible energy storage.
// `unused_variables` is allowed because some locals (e.g. `future_times_mask`,
// `soc_target_series`) are computed but never read.
#[allow(unused_variables)]
#[cfg(feature = "pyo3")]
#[pyfunction]
pub fn run_speed_limit_train_sims(
    mut speed_limit_train_sims: SpeedLimitTrainSimVec,
    network: &Bound<PyAny>,
    train_consist_plan_py: PyDataFrame,
    loco_pool_py: PyDataFrame,
    refuel_facilities_py: PyDataFrame,
    timed_paths: Vec<TimedLinkPath>,
) -> anyhow::Result<(SpeedLimitTrainSimVec, PyDataFrame)> {
    // Accept either a full `Network` or a bare list of `Link`s from Python.
    let network = match network.extract::<Network>() {
        Ok(n) => n,
        Err(_) => {
            let n = network
                .extract::<Vec<Link>>()
                .map_err(|_| anyhow!("{}", format_dbg!()))?;
            Network(Default::default(), n)
        }
    };

    let train_consist_plan: DataFrame = train_consist_plan_py.clone().into();
    let mut loco_pool: DataFrame = loco_pool_py.clone().into();
    let refuel_facilities: DataFrame = refuel_facilities_py.clone().into();

    // Initial per-locomotive simulation state. Every loco starts "Ready" and
    // full (SOC_J = SOC_Max_J), with zero trip energy and an infinite
    // ready-time estimate.
    loco_pool = loco_pool
        .lazy()
        .with_columns(vec![
            lit(f64::ZERO).alias("Trip_Energy_J").to_owned(),
            lit(f64::INFINITY).alias("Ready_Time_Min").to_owned(),
            lit(f64::INFINITY).alias("Ready_Time_Est").to_owned(),
            lit("Ready").alias("Status").to_owned(),
            col("SOC_Max_J").alias("SOC_J").to_owned(),
        ])
        .collect()
        .with_context(|| format_dbg!())?;

    // Pending arrival events, sorted by time. Rows at or before `current_time`
    // are discarded at the end of each loop iteration.
    let mut arrival_times = train_consist_plan
        .clone()
        .lazy()
        .select(vec![
            col("Arrival_Time_Actual_Hr"),
            col("Locomotive_ID"),
            col("Destination_ID"),
            col("TrainSimVec_Index"),
        ])
        .sort_by_exprs(
            vec![col("Arrival_Time_Actual_Hr"), col("Locomotive_ID")],
            SortMultipleOptions::default(),
        )
        .collect()
        .with_context(|| format_dbg!())?;

    // Departure times per locomotive. They are used to cap each refueling
    // session at that locomotive's next departure.
    let departure_times = train_consist_plan
        .clone()
        .lazy()
        .select(vec![col("Departure_Time_Actual_Hr"), col("Locomotive_ID")])
        .sort_by_exprs(
            vec![col("Locomotive_ID"), col("Departure_Time_Actual_Hr")],
            SortMultipleOptions::default(),
            // vec![false, false],
            // false,
            // false,
        )
        .collect()
        .with_context(|| format_dbg!())?;

    // Accumulates one row per started refueling session; returned to the caller.
    let mut refuel_sessions = DataFrame::default();

    // Statuses whose `Ready_Time_Est` is a future event the loop must visit.
    // Note that "Dispatched" is never assigned within this function.
    let active_loco_statuses =
        Series::from_iter(vec!["Refueling".to_string(), "Dispatched".to_string()]);
    // The first event is the earliest arrival; `unwrap` panics if there is none.
    let mut current_time: f64 = arrival_times
        .column("Arrival_Time_Actual_Hr")
        .with_context(|| format_dbg!())?
        .f64()
        .with_context(|| format_dbg!())?
        .min()
        .unwrap();

    let mut done = false;
    while !done {
        // ---- Arrival events at `current_time` ----
        // Exact float equality is used. `current_time` is always the minimum of
        // values read from these same columns, so it matches exactly.
        let arrivals_mask = arrival_times
            .column("Arrival_Time_Actual_Hr")?
            .equal(&Column::new(
                "current_time_const".into(),
                vec![current_time; arrival_times.height()],
            ))
            .with_context(|| format_dbg!())?;
        let arrivals = arrival_times
            .clone()
            .filter(&arrivals_mask)
            .with_context(|| format_dbg!())?;
        // `Destination_ID` is null for locos not arriving now. Row alignment with
        // `loco_pool` presumably requires each loco to arrive at most once per
        // event time and the left join to preserve left row order
        // (NOTE(review): confirm).
        let arrivals_merged = loco_pool
            .clone()
            .left_join(&arrivals, ["Locomotive_ID"], ["Locomotive_ID"])
            .with_context(|| format_dbg!())?;
        let arrival_locations = arrivals_merged.column("Destination_ID")?;
        if arrivals.height() > 0 {
            let arrival_ids = arrivals
                .column("Locomotive_ID")
                .with_context(|| format_dbg!())?;
            // Arriving locos join the queue at their destination node as of now.
            loco_pool = loco_pool
                .lazy()
                .with_columns(vec![
                    when(col("Locomotive_ID").is_in(lit(
                        to_series(arrival_ids.clone()).with_context(|| format_dbg!())?,
                    )))
                    .then(lit("Queued"))
                    .otherwise(col("Status"))
                    .alias("Status"),
                    when(col("Locomotive_ID").is_in(lit(
                        to_series(arrival_ids.clone()).with_context(|| format_dbg!())?,
                    )))
                    .then(lit(current_time))
                    .otherwise(col("Ready_Time_Est"))
                    .alias("Ready_Time_Est"),
                    when(col("Locomotive_ID").is_in(lit(
                        to_series(arrival_ids.clone()).with_context(|| format_dbg!())?,
                    )))
                    .then(lit(arrival_locations
                        .clone()
                        .as_series()
                        .with_context(|| format_dbg!())?
                        .clone()))
                    .otherwise(col("Node"))
                    .alias("Node"),
                ])
                // Replace the refueler columns with those of the facility at each
                // loco's (possibly updated) node.
                .drop(vec![
                    "Refueler_J_Per_Hr",
                    "Refueler_Efficiency",
                    "Port_Count",
                    "Battery_Headroom_J",
                ])
                .join(
                    refuel_facilities.clone().lazy().select(&[
                        col("Node"),
                        col("Locomotive_Type"),
                        col("Fuel_Type"),
                        col("Refueler_J_Per_Hr"),
                        col("Refueler_Efficiency"),
                        col("Port_Count"),
                        col("Battery_Headroom_J"),
                    ]),
                    [col("Node"), col("Locomotive_Type"), col("Fuel_Type")],
                    [col("Node"), col("Locomotive_Type"), col("Fuel_Type")],
                    JoinArgs::new(JoinType::Left),
                )
                // A missing facility match (or a null headroom) reserves no headroom.
                .with_columns(vec![col("Battery_Headroom_J").fill_null(0)])
                // Charge target: SOC_Max_J minus headroom, never below SOC_Min_J.
                .with_columns(vec![max_horizontal([
                    col("SOC_Max_J") - col("Battery_Headroom_J"),
                    col("SOC_Min_J"),
                ])
                .with_context(|| format_dbg!())?
                .alias("SOC_Target_J")])
                // The per-train SOC write-back below indexes `loco_pool` rows and
                // relies on this `Locomotive_ID` ordering.
                .sort(["Locomotive_ID"], SortMultipleOptions::default())
                .collect()
                .with_context(|| format_dbg!())?;

            // Run each train sim that arrives at `current_time`, once per index.
            let indices = arrivals
                .column("TrainSimVec_Index")
                .with_context(|| format_dbg!())?
                .u32()
                .with_context(|| format_dbg!())?
                .unique()
                .with_context(|| format_dbg!())?;
            for index in indices.into_iter() {
                // Panics on a null `TrainSimVec_Index`.
                let idx = index.unwrap() as usize;
                // SOC as a fraction of `Capacity_J` for this train's locos,
                // ordered by `Locomotive_ID`.
                let departing_soc_pct_iter = train_consist_plan
                    .clone()
                    .lazy()
                    // retain rows in which "TrainSimVec_Index" equals current `index`
                    .filter(col("TrainSimVec_Index").eq(index.unwrap()))
                    // Select "Locomotive_ID" column
                    .select(vec![col("Locomotive_ID")])
                    // find unique locomotive IDs
                    .unique(None, UniqueKeepStrategy::First)
                    .join(
                        loco_pool.clone().lazy(),
                        [col("Locomotive_ID")],
                        [col("Locomotive_ID")],
                        JoinArgs::new(JoinType::Left),
                    )
                    .sort(["Locomotive_ID"], SortMultipleOptions::default())
                    .with_columns(vec![(col("SOC_J") / col("Capacity_J")).alias("SOC_Pct")])
                    .collect()
                    .with_context(|| format_dbg!())?;

                let departing_soc_pct = to_series(
                    departing_soc_pct_iter
                        .column("SOC_Pct")
                        .with_context(|| format_dbg!())?
                        .clone(),
                )
                .with_context(|| format_dbg!())?;

                let departing_soc_pct_vec: Vec<f64> = departing_soc_pct
                    .f64()
                    .with_context(|| format_dbg!())?
                    .into_no_null_iter()
                    .collect();
                let sim = &mut speed_limit_train_sims.0[idx];
                // Seed each loco's RES SOC before the run. This zips `loco_vec`
                // (consist order) with SOCs sorted by `Locomotive_ID`.
                // NOTE(review): assumes the two orders agree; confirm.
                sim.loco_con
                    .loco_vec
                    .iter_mut()
                    .zip(departing_soc_pct_vec)
                    .try_for_each(|(loco, soc)| {
                        if let Some(res) = &mut loco.reversible_energy_storage_mut() {
                            res.state
                                .soc
                                .update_unchecked(soc * uc::R, || format_dbg!())
                        } else {
                            Ok(())
                        }
                    })?;
                // NOTE(review): any error from `walk_timed_path` is given context
                // and then discarded, so a failed sim is silently treated as
                // finished. Consider propagating or at least logging it.
                let _ = sim
                    .walk_timed_path(&network, &timed_paths[idx])
                    .map_err(|err| err.context(format!("train sim idx: {}", idx)));

                // Stored energy after the run (SOC * energy_capacity, in J) per
                // loco; 0 for locos that are not battery-electric.
                let mut new_soc_vec: Vec<f64> = vec![];
                sim.loco_con
                    .loco_vec
                    .iter()
                    .try_for_each(|loco| -> anyhow::Result<()> {
                        match loco.loco_type {
                            PowertrainType::BatteryElectricLoco(_) => {
                                new_soc_vec.push(
                                    (*loco
                                        .reversible_energy_storage()
                                        .unwrap()
                                        .state
                                        .soc
                                        .get_fresh(|| format_dbg!())?
                                        * loco
                                            .reversible_energy_storage()
                                            .unwrap()
                                            .energy_capacity)
                                        .get::<si::joule>(),
                                );
                            }
                            _ => new_soc_vec.push(f64::ZERO),
                        }
                        Ok(())
                    })
                    .with_context(|| format_dbg!())?;
                // RES `energy_out_chemical` after the run (J) per loco, stored as
                // `Trip_Energy_J`; 0 for locos that are not battery-electric.
                let mut new_energy_j_vec: Vec<f64> = vec![];
                sim.loco_con
                    .loco_vec
                    .iter()
                    .try_for_each(|loco| -> anyhow::Result<()> {
                        new_energy_j_vec.push(match loco.loco_type {
                            PowertrainType::BatteryElectricLoco(_) => loco
                                .reversible_energy_storage()
                                .unwrap()
                                .state
                                .energy_out_chemical
                                .get_fresh(|| format_dbg!())?
                                .get::<si::joule>(),
                            _ => f64::ZERO,
                        });
                        Ok(())
                    })
                    .with_context(|| format_dbg!())?;
                // Full-height buffers aligned with `loco_pool` rows. Only the rows
                // of this train's locos are overwritten below.
                let mut all_current_socs: Vec<f64> = loco_pool
                    .column("SOC_J")
                    .with_context(|| format_dbg!())?
                    .f64()
                    .with_context(|| format_dbg!())?
                    .into_no_null_iter()
                    .collect();
                let mut all_energy_j: Vec<f64> = (loco_pool
                    .column("SOC_J")
                    .with_context(|| format_dbg!())?
                    .f64()?
                    * 0.0)
                    .into_no_null_iter()
                    .collect();
                // Locos belonging to train `idx` among the not-yet-discarded
                // arrival rows.
                let idx_mask = arrival_times
                    .column("TrainSimVec_Index")
                    .with_context(|| format_dbg!())?
                    .equal(&Column::new(
                        "idx_const".into(),
                        vec![idx as u32; arrival_times.height()],
                    ))
                    .with_context(|| format_dbg!())?;
                let arrival_locos = arrival_times
                    .filter(&idx_mask)
                    .with_context(|| format_dbg!())?;
                let arrival_loco_ids = arrival_locos
                    .column("Locomotive_ID")
                    .with_context(|| format_dbg!())?
                    .u32()
                    .with_context(|| format_dbg!())?;
                // Row mask over `loco_pool` (currently sorted by `Locomotive_ID`)
                // selecting this train's locos.
                let arrival_loco_mask: ChunkedArray<BooleanType> = is_in(
                    loco_pool
                        .column("Locomotive_ID")
                        .with_context(|| format_dbg!())?
                        .as_series()
                        .with_context(|| format_dbg!())?,
                    &Series::from(arrival_loco_ids.clone()),
                )
                .with_context(|| format_dbg!())?;

                // Get the indices of true values in the boolean ChunkedArray
                let arrival_loco_indices: Vec<usize> = arrival_loco_mask
                    .into_iter()
                    .enumerate()
                    .filter(|(_, val)| val.unwrap_or_default())
                    .map(|(i, _)| i)
                    .collect();

                // TODO: rewrite this a little so it doesn't depend on the previous sort
                // The indices are in ascending `Locomotive_ID` order, while the
                // `new_*_vec` values follow `loco_vec` order. The zips below
                // assume these orders agree.
                for (index, value) in arrival_loco_indices.iter().zip(new_soc_vec) {
                    all_current_socs[*index] = value;
                }
                for (index, value) in arrival_loco_indices.iter().zip(new_energy_j_vec) {
                    all_energy_j[*index] = value;
                }
                // Write the post-trip values back for this train's locos only.
                loco_pool = loco_pool
                    .lazy()
                    .with_columns(vec![
                        when(lit(arrival_loco_mask.clone().into_series()))
                            .then(lit(Series::new("SOC_J".into(), all_current_socs)))
                            .otherwise(col("SOC_J"))
                            .alias("SOC_J"),
                        when(lit(arrival_loco_mask.into_series()))
                            .then(lit(Series::new("Trip_Energy_J".into(), all_energy_j)))
                            .otherwise(col("Trip_Energy_J"))
                            .alias("Trip_Energy_J"),
                    ])
                    .collect()
                    .with_context(|| format_dbg!())?;
            }
            loco_pool = loco_pool
                .lazy()
                .sort(["Ready_Time_Est"], SortMultipleOptions::default())
                .collect()
                .with_context(|| format_dbg!())?;
        }

        // ---- Refueling sessions that end at `current_time` ----
        let refueling_mask = (loco_pool)
            .column("Status")
            .with_context(|| format_dbg!())?
            .equal(&Column::new(
                "refueling_const".into(),
                vec!["Refueling"; loco_pool.height()],
            ))
            .with_context(|| format_dbg!())?;
        let refueling_finished_mask = refueling_mask.clone()
            & (loco_pool)
                .column("Ready_Time_Est")
                .with_context(|| format_dbg!())?
                .equal(&Column::new(
                    "current_time_const".into(),
                    vec![current_time; refueling_mask.len()],
                ))
                .with_context(|| format_dbg!())?;
        let refueling_finished = loco_pool
            .clone()
            .filter(&refueling_finished_mask)
            .with_context(|| format_dbg!())?;
        if refueling_finished_mask.sum().unwrap_or_default() > 0 {
            loco_pool = loco_pool
                .lazy()
                .with_columns(vec![when(lit(refueling_finished_mask.into_series()))
                    .then(lit("Ready"))
                    .otherwise(col("Status"))
                    .alias("Status")])
                .collect()
                .with_context(|| format_dbg!())?;
        }

        // ---- Start new refueling sessions when the queue may have changed ----
        if (arrivals.height() > 0) || (refueling_finished.height() > 0) {
            // update queue
            // Per (Node, Locomotive_Type, Fuel_Type) group, this is the number of
            // locos currently refueling plus 1 if this loco is queued. The
            // "Queued" term is elementwise, not a running count.
            // NOTE(review): every queued loco in a group therefore gets the same
            // value. If more locos are queued than there are free ports, they all
            // start below. If FIFO ordering was intended, a cumulative sum over
            // the queued locos would be needed; confirm.
            let place_in_queue_iter = loco_pool
                .clone()
                .lazy()
                .select(&[((col("Status").eq(lit("Refueling")).sum().over([
                    "Node",
                    "Locomotive_Type",
                    "Fuel_Type",
                ])) + (col("Status").eq(lit("Queued")).over([
                    "Node",
                    "Locomotive_Type",
                    "Fuel_Type",
                ])))
                .alias("place_in_queue")])
                .collect()?;
            let place_in_queue = place_in_queue_iter
                .column("place_in_queue")?
                .as_series()
                .with_context(|| format_dbg!())?;
            // Currently unused.
            let future_times_mask = departure_times
                .column("Departure_Time_Actual_Hr")?
                .f64()?
                .gt(current_time);

            // Earliest departure strictly after `current_time` for each loco.
            let next_departure_time = departure_times
                .clone()
                .lazy()
                .filter(col("Departure_Time_Actual_Hr").gt(lit(current_time)))
                .group_by(["Locomotive_ID"])
                .agg([col("Departure_Time_Actual_Hr").min()])
                .collect()
                .with_context(|| format_dbg!())?;

            let departures_merged = loco_pool.clone().left_join(
                &next_departure_time,
                ["Locomotive_ID"],
                ["Locomotive_ID"],
            )?;
            // Shadows the outer `departure_times` within this block only. It is
            // null for locos with no future departure.
            let departure_times = departures_merged
                .column("Departure_Time_Actual_Hr")?
                .f64()?;

            let target_j = loco_pool
                .clone()
                .lazy()
                .select(&[(col("SOC_Max_J") - col("Battery_Headroom_J")).alias("Target_J")])
                .collect()?
                .column("Target_J")?
                .clone();
            let target_j_f64 = target_j.f64()?;
            let current_j = loco_pool.column("SOC_J")?.f64()?;

            // Elementwise max(Target_J, SOC_J) with nulls treated as 0.
            // `soc_target_series` is currently unused.
            let soc_target: Vec<f64> = target_j_f64
                .into_iter()
                .zip(current_j.into_iter())
                .map(|(b, v)| b.unwrap_or(f64::ZERO).max(v.unwrap_or(f64::ZERO)))
                .collect::<Vec<_>>();
            let soc_target_series = Series::new("soc_target".into(), soc_target);

            // Time at which charging at `Refueler_J_Per_Hr` would reach
            // `SOC_Target_J`. This equals `current_time` if the loco is already at
            // or above target.
            let refuel_end_time_ideal_iter = loco_pool
                .clone()
                .lazy()
                .select(&[(lit(current_time)
                    + (max_horizontal([col("SOC_J"), col("SOC_Target_J")])? - col("SOC_J"))
                        / col("Refueler_J_Per_Hr"))
                .alias("Refuel_End_Time")])
                .collect()?;
            let refuel_end_time_ideal = refuel_end_time_ideal_iter
                .column("Refuel_End_Time")?
                .as_series()
                .with_context(|| format_dbg!())?;

            // A session ends at the earlier of the next departure and the ideal
            // end time. Null values are treated as +infinity.
            let refuel_end_time: Vec<f64> = departure_times
                .into_iter()
                .zip(refuel_end_time_ideal.f64()?.into_iter())
                .map(|(b, v)| b.unwrap_or(f64::INFINITY).min(v.unwrap_or(f64::INFINITY)))
                .collect::<Vec<_>>();

            let mut refuel_duration: Vec<f64> = refuel_end_time.clone();
            for element in refuel_duration.iter_mut() {
                *element -= current_time;
            }

            let refuel_duration_series = Series::new("refuel_duration".into(), refuel_duration);
            let refuel_end_series = Series::new("refuel_end_time".into(), refuel_end_time);

            // Temporary columns, dropped again at the end of this block.
            loco_pool = loco_pool
                .lazy()
                .with_columns(vec![
                    lit(place_in_queue.clone()),
                    lit(refuel_duration_series.clone()),
                    lit(refuel_end_series.clone()),
                ])
                .collect()
                .with_context(|| format_dbg!())?;

            // store the filter as an Expr
            // Queued locos whose group has a free port start refueling now.
            let refuel_starting = loco_pool
                .clone()
                .lazy()
                .filter(
                    col("Status")
                        .eq(lit("Queued"))
                        .and(col("Port_Count").gt_eq(col("place_in_queue"))),
                )
                .collect()
                .with_context(|| format_dbg!())?;

            // Session records. Refuel_Energy_J is the delivered energy divided by
            // the refueler efficiency. `SOC_J` here is the value at session start,
            // because it is taken before the SOC update below.
            let these_refuel_sessions = refuel_starting
                .clone()
                .lazy()
                .with_columns(vec![
                    (col("Refueler_J_Per_Hr") * col("refuel_duration")
                        / col("Refueler_Efficiency"))
                    .alias("Refuel_Energy_J"),
                    (col("refuel_end_time") - col("refuel_duration")).alias("Refuel_Start_Time_Hr"),
                ])
                .rename(
                    ["refuel_end_time", "refuel_duration"],
                    ["Refuel_End_Time_Hr", "Refuel_Duration_Hr"],
                    true,
                )
                .select(vec![
                    col("Node"),
                    col("Locomotive_Type"),
                    col("Fuel_Type"),
                    col("Locomotive_ID"),
                    col("Refueler_J_Per_Hr"),
                    col("Refueler_Efficiency"),
                    col("Trip_Energy_J"),
                    col("SOC_J"),
                    col("Refuel_Energy_J"),
                    col("Refuel_Duration_Hr"),
                    col("Refuel_Start_Time_Hr"),
                    col("Refuel_End_Time_Hr"),
                ])
                .collect()
                .with_context(|| format_dbg!())?;
            refuel_sessions.vstack_mut(&these_refuel_sessions)?;
            // set finishedCharging times to min(max soc OR departure time)
            // The SOC gained over the whole session is credited up front, and the
            // loco becomes an active "Refueling" event at `refuel_end_time`.
            loco_pool = loco_pool
                .clone()
                .lazy()
                .with_columns(vec![
                    when(
                        col("Status")
                            .eq(lit("Queued"))
                            .and(col("Port_Count").gt_eq(col("place_in_queue"))),
                    )
                    .then(col("SOC_J") + col("refuel_duration") * col("Refueler_J_Per_Hr"))
                    .otherwise(col("SOC_J"))
                    .alias("SOC_J"),
                    when(
                        col("Status")
                            .eq(lit("Queued"))
                            .and(col("Port_Count").gt_eq(col("place_in_queue"))),
                    )
                    .then(col("refuel_end_time"))
                    .otherwise(col("Ready_Time_Est"))
                    .alias("Ready_Time_Est"),
                    when(
                        col("Status")
                            .eq(lit("Queued"))
                            .and(col("Port_Count").gt_eq(col("place_in_queue"))),
                    )
                    .then(lit("Refueling"))
                    .otherwise(col("Status"))
                    .alias("Status"),
                ])
                .collect()
                .with_context(|| format_dbg!())?;

            loco_pool = loco_pool.drop("place_in_queue")?;
            loco_pool = loco_pool.drop("refuel_duration")?;
            loco_pool = loco_pool.drop("refuel_end_time")?;
        }

        // ---- Advance to the next event time ----
        // Pending ready times of locos in an active status.
        let active_loco_ready_times_iter = loco_pool
            .clone()
            .lazy()
            .filter(col("Status").is_in(lit(active_loco_statuses.clone())))
            .select(vec![col("Ready_Time_Est")])
            .collect()?;
        let active_loco_ready_times = active_loco_ready_times_iter
            .column("Ready_Time_Est")
            .with_context(|| format_dbg!(active_loco_ready_times_iter))?;
        // Discard arrival events that have now been processed.
        arrival_times = arrival_times
            .lazy()
            .filter(col("Arrival_Time_Actual_Hr").gt(current_time))
            .collect()?;
        let arrival_times_remaining_iter = arrival_times
            .clone()
            .lazy()
            .select(vec![
                col("Arrival_Time_Actual_Hr").alias("Arrival_Time_Actual_Hr")
            ])
            .collect()?;
        let arrival_times_remaining = arrival_times_remaining_iter
            .column("Arrival_Time_Actual_Hr")
            .with_context(|| format_dbg!())?;

        // Stop when no arrivals remain and no loco is refueling/dispatched.
        if (arrival_times_remaining.is_empty()) & (active_loco_ready_times.is_empty()) {
            done = true;
        } else {
            let min1 = active_loco_ready_times
                .f64()?
                .min()
                .unwrap_or(f64::INFINITY);
            let min2 = arrival_times_remaining
                .f64()?
                .min()
                .unwrap_or(f64::INFINITY);
            current_time = f64::min(min1, min2);
        }
    }

    Ok((speed_limit_train_sims, PyDataFrame(refuel_sessions)))
}
1394
// This MUST remain a single-field tuple (newtype) struct to trigger correct
// tolist() behavior
#[serde_api]
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3", pyclass(module = "altrios", subclass, eq))]
/// Collection of [SpeedLimitTrainSim]s, e.g. as taken and returned by
/// `run_speed_limit_train_sims`. The inner `Vec` is exposed as field `.0`.
pub struct SpeedLimitTrainSimVec(pub Vec<SpeedLimitTrainSim>);
1400
1401#[pyo3_api]
1402impl SpeedLimitTrainSimVec {
1403    #![allow(non_snake_case)]
1404    #[pyo3(name = "get_energy_fuel_joules")]
1405    pub fn get_energy_fuel_py(&self, annualize: bool) -> anyhow::Result<f64> {
1406        Ok(self.get_energy_fuel(annualize)?.get::<si::joule>())
1407    }
1408
1409    #[pyo3(name = "get_net_energy_res_joules")]
1410    pub fn get_net_energy_res_py(&self, annualize: bool) -> anyhow::Result<f64> {
1411        Ok(self.get_net_energy_res(annualize)?.get::<si::joule>())
1412    }
1413
1414    #[pyo3(name = "get_kilometers")]
1415    pub fn get_kilometers_py(&self, annualize: bool) -> anyhow::Result<f64> {
1416        self.get_kilometers(annualize)
1417    }
1418
1419    #[pyo3(name = "get_megagram_kilometers")]
1420    pub fn get_megagram_kilometers_py(&self, annualize: bool) -> anyhow::Result<f64> {
1421        self.get_megagram_kilometers(annualize)
1422    }
1423
1424    #[pyo3(name = "get_car_kilometers")]
1425    pub fn get_car_kilometers_py(&self, annualize: bool) -> anyhow::Result<f64> {
1426        self.get_car_kilometers(annualize)
1427    }
1428
1429    #[pyo3(name = "get_cars_moved")]
1430    pub fn get_cars_moved_py(&self, annualize: bool) -> f64 {
1431        self.get_cars_moved(annualize)
1432    }
1433
1434    #[pyo3(name = "get_res_kilometers")]
1435    pub fn get_res_kilometers_py(&mut self, annualize: bool) -> anyhow::Result<f64> {
1436        self.get_res_kilometers(annualize)
1437    }
1438
1439    #[pyo3(name = "get_non_res_kilometers")]
1440    pub fn get_non_res_kilometers_py(&mut self, annualize: bool) -> anyhow::Result<f64> {
1441        self.get_non_res_kilometers(annualize)
1442    }
1443
1444    #[pyo3(name = "set_save_interval")]
1445    #[pyo3(signature = (save_interval=None))]
1446    pub fn set_save_interval_py(&mut self, save_interval: Option<usize>) {
1447        self.set_save_interval(save_interval);
1448    }
1449
1450    #[new]
1451    /// Rust-defined `__new__` magic method for Python used exposed via PyO3.
1452    fn __new__(v: Vec<SpeedLimitTrainSim>) -> Self {
1453        Self(v)
1454    }
1455}
1456
1457impl SpeedLimitTrainSimVec {
1458    pub fn new(value: Vec<SpeedLimitTrainSim>) -> Self {
1459        Self(value)
1460    }
1461
1462    pub fn get_energy_fuel(&self, annualize: bool) -> anyhow::Result<si::Energy> {
1463        self.0.iter().try_fold(si::Energy::ZERO, |acc, sim| {
1464            Ok(acc + sim.get_energy_fuel(annualize)?)
1465        })
1466    }
1467
1468    pub fn get_net_energy_res(&self, annualize: bool) -> anyhow::Result<si::Energy> {
1469        self.0.iter().try_fold(si::Energy::ZERO, |acc, sim| {
1470            Ok(acc + sim.get_net_energy_res(annualize)?)
1471        })
1472    }
1473
1474    pub fn get_kilometers(&self, annualize: bool) -> anyhow::Result<f64> {
1475        self.0
1476            .iter()
1477            .try_fold(0.0, |acc, sim| Ok(acc + sim.get_kilometers(annualize)?))
1478    }
1479
1480    pub fn get_megagram_kilometers(&self, annualize: bool) -> anyhow::Result<f64> {
1481        self.0.iter().try_fold(0.0, |acc, sim| {
1482            Ok(acc + sim.get_megagram_kilometers(annualize)?)
1483        })
1484    }
1485
1486    pub fn get_car_kilometers(&self, annualize: bool) -> anyhow::Result<f64> {
1487        self.0
1488            .iter()
1489            .try_fold(0.0, |acc, sim| Ok(acc + sim.get_car_kilometers(annualize)?))
1490    }
1491
1492    pub fn get_cars_moved(&self, annualize: bool) -> f64 {
1493        self.0.iter().map(|sim| sim.get_cars_moved(annualize)).sum()
1494    }
1495
1496    pub fn get_res_kilometers(&mut self, annualize: bool) -> anyhow::Result<f64> {
1497        self.0
1498            .iter_mut()
1499            .try_fold(0.0, |acc, sim| Ok(acc + sim.get_res_kilometers(annualize)?))
1500    }
1501
1502    pub fn get_non_res_kilometers(&mut self, annualize: bool) -> anyhow::Result<f64> {
1503        self.0.iter_mut().try_fold(0.0, |acc, sim| {
1504            Ok(acc + sim.get_non_res_kilometers(annualize)?)
1505        })
1506    }
1507
1508    pub fn set_save_interval(&mut self, save_interval: Option<usize>) {
1509        self.0
1510            .iter_mut()
1511            .for_each(|slts| slts.set_save_interval(save_interval));
1512    }
1513}
1514
1515impl Init for SpeedLimitTrainSimVec {
1516    fn init(&mut self) -> Result<(), Error> {
1517        self.0.iter_mut().try_for_each(|ts| ts.init())?;
1518        Ok(())
1519    }
1520}
1521impl SerdeAPI for SpeedLimitTrainSimVec {}