// renfe_cli/renfe.rs

use chrono::{Datelike, NaiveDate, NaiveTime, TimeDelta, Timelike};
use gtfs_structures::Gtfs;
use pyo3::{
    exceptions::{PyIOError, PyValueError},
    pyclass, pymethods, PyResult,
};
use std::io::Read;

6#[pyclass]
7pub struct Renfe {
8    gtfs: Gtfs,
9    schedules: Vec<Schedule>,
10}
11
12// Struct to hold the schedule details
13#[pyclass]
14pub struct Schedule {
15    train_type: String,
16    // origin_stop_name: String,
17    // destination_stop_name: String,
18    departure_time: NaiveTime,
19    arrival_time: NaiveTime,
20    duration: TimeDelta,
21}
22
23// Struct to hold the station name and ID
24#[pyclass]
25#[derive(Debug, Clone)]
26pub struct Station {
27    pub name: String,
28    pub id: String,
29}
30
31#[pymethods]
32impl Renfe {
33    #[new]
34    pub fn new(cercanias: bool) -> PyResult<Self> {
35        let mut res = reqwest::blocking::get(
36            match cercanias {
37                false => {
38                    println!("Loading default GTFS data from Renfe web - Alta velocidad, Larga distancia y Media distancia");
39                    "https://ssl.renfe.com/gtransit/Fichero_AV_LD/google_transit.zip"
40                },
41                true => {
42                    println!("Loading Cercanías GTFS data from Renfe web - long load time");
43                    "https://ssl.renfe.com/ftransit/Fichero_CER_FOMENTO/fomento_transit.zip"
44                },
45            },
46        )
47        .expect("Error downloading GTFS zip file");
48        let mut body = Vec::new();
49        res.read_to_end(&mut body)?;
50        let cursor = std::io::Cursor::new(body);
51
52        let gtfs = Gtfs::from_reader(cursor).expect("Error parsing GTFS zip");
53
54        gtfs.print_stats();
55
56        Ok(Renfe {
57            gtfs,
58            schedules: Vec::new(),
59        })
60    }
61
62    pub fn all_stations(&self) -> PyResult<Vec<Station>> {
63        let stations: Vec<Station> = self
64            .gtfs
65            .stops
66            .iter()
67            .map(|s| Station {
68                name: s.1.name.clone().unwrap(),
69                id: s.1.id.clone(),
70            })
71            .collect();
72        Ok(stations)
73    }
74
75    pub fn stations_match(&self, station: String) -> PyResult<Vec<Station>> {
76        let found: Vec<Station> = self
77            .gtfs
78            .stops
79            .iter()
80            .filter(|s| {
81                s.1.name
82                    .clone()
83                    .unwrap()
84                    .to_lowercase()
85                    .contains(&station.to_lowercase())
86            })
87            .map(|s| Station {
88                name: s.1.name.clone().unwrap(),
89                id: s.1.id.clone(),
90            })
91            .collect();
92        Ok(found)
93    }
94
95    pub fn filter_station(&self, station: String) -> PyResult<Station> {
96        match self.stations_match(station.clone()) {
97            Ok(v) if v.len() == 1 => {
98                println!(
99                    "Provided input '{}' does a match with '{:?}'",
100                    station, v[0]
101                );
102                Ok(v[0].clone())
103            }
104            Ok(v) => Err(PyValueError::new_err(format!(
105                "Provided input '{station}' does match with '{v:?}' -> There must be ONLY one match"
106            ))),
107            Err(e) => Err(e),
108        }
109    }
110
111    // Function to get train schedules between an origin and a destination on a given date
112    pub fn set_train_schedules(
113        &mut self,
114        origin_station_id: &str,
115        destination_station_id: &str,
116        day: u32,
117        month: u32,
118        year: i32,
119        sorted: bool,
120    ) -> PyResult<()> {
121        let gtfs = &self.gtfs;
122        // the date for which schedules are needed
123        let date = match NaiveDate::from_ymd_opt(year, month, day) {
124            Some(date) => date,
125            None => {
126                return Err(PyValueError::new_err(format!(
127                    "Provided date '{year}-{month}-{day}' does not exist"
128                )))
129            }
130        };
131
132        let mut schedules = Vec::new();
133
134        // Loop through each trip to find ones active on the given date
135        for trip in gtfs.trips.values() {
136            // Check if the trip's service is active on the given date
137            if is_service_active(gtfs, &trip.service_id, date) {
138                // Filter stop times for the trip
139                let stop_times: Vec<_> = trip.stop_times.clone();
140
141                // Find the origin and destination stops in the trip's stop times
142                let origin_stop = stop_times.iter().find(|st| st.stop.id == origin_station_id);
143                let destination_stop = stop_times
144                    .iter()
145                    .find(|st| st.stop.id == destination_station_id);
146
147                // If the trip includes both origin and destination, and origin is before destination
148                if let (Some(origin), Some(destination)) = (origin_stop, destination_stop) {
149                    if origin.stop_sequence < destination.stop_sequence {
150                        let time_origin = origin.departure_time.unwrap();
151                        let time_destination = destination.arrival_time.unwrap();
152                        let departure_time = NaiveTime::from_hms_opt(
153                            (time_origin / 3600) % 24,
154                            time_origin % 3600 / 60,
155                            time_origin % 60,
156                        )
157                        .unwrap();
158                        let arrival_time = NaiveTime::from_hms_opt(
159                            (time_destination / 3600) % 24,
160                            time_destination % 3600 / 60,
161                            time_destination % 60,
162                        )
163                        .unwrap();
164
165                        let mut duration = arrival_time.signed_duration_since(departure_time);
166                        if time_destination >= 86400 {
167                            duration = duration.checked_add(&TimeDelta::seconds(86400)).unwrap();
168                        }
169
170                        schedules.push(Schedule {
171                            train_type: gtfs
172                                .get_route(&trip.route_id)
173                                .unwrap()
174                                .short_name
175                                .clone()
176                                .unwrap(),
177                            // origin_stop_name: gtfs.stops[&origin.stop.id].name.clone().unwrap(),
178                            // destination_stop_name: gtfs.stops[&destination.stop.id]
179                            //     .name
180                            //     .clone()
181                            //     .unwrap(),
182                            departure_time,
183                            arrival_time,
184                            duration,
185                        });
186                    }
187                }
188            }
189        }
190
191        // Sort schedules by departure_time
192        schedules.sort_by_key(|schedule| schedule.departure_time);
193
194        if sorted {
195            println!("sorting timetable by duration");
196            schedules.sort_by(|a, b| a.duration.cmp(&b.duration));
197        }
198
199        self.schedules = schedules;
200
201        Ok(())
202    }
203
204    pub fn print_timetable(&self) {
205        if self.schedules.is_empty() {
206            println!("\nNo schedules available...won't print timetable.");
207        } else {
208            println!("\n=========================TIMETABLE=========================");
209            println!(
210                "  {0: <12} |   {1: <10} |   {2: <10} |   {3: <12}",
211                "Train", "Departure", "Arrival", "Duration"
212            );
213            for track in &self.schedules {
214                println!("-----------------------------------------------------------");
215                println!(
216                    "   {0: <11} |    {1: <9} |    {2: <9} |    {3: <10}",
217                    track.train_type,
218                    format!(
219                        "{:02}:{:02}",
220                        track.departure_time.hour(),
221                        track.departure_time.minute() % 60
222                    ),
223                    format!(
224                        "{:02}:{:02}",
225                        track.arrival_time.hour(),
226                        track.arrival_time.minute() % 60
227                    ),
228                    format!(
229                        "{:02}:{:02}",
230                        track.duration.num_hours(),
231                        track.duration.num_minutes() % 60
232                    )
233                );
234            }
235            println!("===========================================================");
236        }
237    }
238}
239
240// Helper function to check if a service is active on a given date
241fn is_service_active(gtfs: &Gtfs, service_id: &str, date: NaiveDate) -> bool {
242    // First check the `calendar.txt`
243    if let Some(calendar) = gtfs.calendar.get(service_id) {
244        let weekday = match date.weekday() {
245            chrono::Weekday::Mon => calendar.monday,
246            chrono::Weekday::Tue => calendar.tuesday,
247            chrono::Weekday::Wed => calendar.wednesday,
248            chrono::Weekday::Thu => calendar.thursday,
249            chrono::Weekday::Fri => calendar.friday,
250            chrono::Weekday::Sat => calendar.saturday,
251            chrono::Weekday::Sun => calendar.sunday,
252        };
253
254        if weekday && date >= calendar.start_date && date <= calendar.end_date {
255            // this should never happen - but a check is for free
256            if let Some(calendar_dates) = gtfs.calendar_dates.get(service_id) {
257                for date_override in calendar_dates {
258                    if date_override.date == date {
259                        return !(date_override.exception_type
260                            == gtfs_structures::Exception::Deleted);
261                    }
262                }
263            }
264            return true;
265        }
266    }
267
268    // Then check the `calendar_dates.txt` for exceptions
269    if let Some(calendar_dates) = gtfs.calendar_dates.get(service_id) {
270        for date_override in calendar_dates {
271            if date_override.date == date {
272                return date_override.exception_type == gtfs_structures::Exception::Added;
273            }
274        }
275    }
276
277    false
278}