1use chrono::{Datelike, NaiveDate, NaiveTime, TimeDelta, Timelike};
2use gtfs_structures::Gtfs;
3use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyResult};
4use std::io::Read;
5
6#[pyclass]
7pub struct Renfe {
8 gtfs: Gtfs,
9 schedules: Vec<Schedule>,
10}
11
12#[pyclass]
14pub struct Schedule {
15 train_type: String,
16 departure_time: NaiveTime,
19 arrival_time: NaiveTime,
20 duration: TimeDelta,
21}
22
23#[pyclass]
25#[derive(Debug, Clone)]
26pub struct Station {
27 pub name: String,
28 pub id: String,
29}
30
31#[pymethods]
32impl Renfe {
33 #[new]
34 pub fn new(cercanias: bool) -> PyResult<Self> {
35 let mut res = reqwest::blocking::get(
36 match cercanias {
37 false => {
38 println!("Loading default GTFS data from Renfe web - Alta velocidad, Larga distancia y Media distancia");
39 "https://ssl.renfe.com/gtransit/Fichero_AV_LD/google_transit.zip"
40 },
41 true => {
42 println!("Loading CercanÃas GTFS data from Renfe web - long load time");
43 "https://ssl.renfe.com/ftransit/Fichero_CER_FOMENTO/fomento_transit.zip"
44 },
45 },
46 )
47 .expect("Error downloading GTFS zip file");
48 let mut body = Vec::new();
49 res.read_to_end(&mut body)?;
50 let cursor = std::io::Cursor::new(body);
51
52 let gtfs = Gtfs::from_reader(cursor).expect("Error parsing GTFS zip");
53
54 gtfs.print_stats();
55
56 Ok(Renfe {
57 gtfs,
58 schedules: Vec::new(),
59 })
60 }
61
62 pub fn all_stations(&self) -> PyResult<Vec<Station>> {
63 let stations: Vec<Station> = self
64 .gtfs
65 .stops
66 .iter()
67 .map(|s| Station {
68 name: s.1.name.clone().unwrap(),
69 id: s.1.id.clone(),
70 })
71 .collect();
72 Ok(stations)
73 }
74
75 pub fn stations_match(&self, station: String) -> PyResult<Vec<Station>> {
76 let found: Vec<Station> = self
77 .gtfs
78 .stops
79 .iter()
80 .filter(|s| {
81 s.1.name
82 .clone()
83 .unwrap()
84 .to_lowercase()
85 .contains(&station.to_lowercase())
86 })
87 .map(|s| Station {
88 name: s.1.name.clone().unwrap(),
89 id: s.1.id.clone(),
90 })
91 .collect();
92 Ok(found)
93 }
94
95 pub fn filter_station(&self, station: String) -> PyResult<Station> {
96 match self.stations_match(station.clone()) {
97 Ok(v) if v.len() == 1 => {
98 println!(
99 "Provided input '{}' does a match with '{:?}'",
100 station, v[0]
101 );
102 Ok(v[0].clone())
103 }
104 Ok(v) => Err(PyValueError::new_err(format!(
105 "Provided input '{station}' does match with '{v:?}' -> There must be ONLY one match"
106 ))),
107 Err(e) => Err(e),
108 }
109 }
110
111 pub fn set_train_schedules(
113 &mut self,
114 origin_station_id: &str,
115 destination_station_id: &str,
116 day: u32,
117 month: u32,
118 year: i32,
119 sorted: bool,
120 ) -> PyResult<()> {
121 let gtfs = &self.gtfs;
122 let date = match NaiveDate::from_ymd_opt(year, month, day) {
124 Some(date) => date,
125 None => {
126 return Err(PyValueError::new_err(format!(
127 "Provided date '{year}-{month}-{day}' does not exist"
128 )))
129 }
130 };
131
132 let mut schedules = Vec::new();
133
134 for trip in gtfs.trips.values() {
136 if is_service_active(gtfs, &trip.service_id, date) {
138 let stop_times: Vec<_> = trip.stop_times.clone();
140
141 let origin_stop = stop_times.iter().find(|st| st.stop.id == origin_station_id);
143 let destination_stop = stop_times
144 .iter()
145 .find(|st| st.stop.id == destination_station_id);
146
147 if let (Some(origin), Some(destination)) = (origin_stop, destination_stop) {
149 if origin.stop_sequence < destination.stop_sequence {
150 let time_origin = origin.departure_time.unwrap();
151 let time_destination = destination.arrival_time.unwrap();
152 let departure_time = NaiveTime::from_hms_opt(
153 (time_origin / 3600) % 24,
154 time_origin % 3600 / 60,
155 time_origin % 60,
156 )
157 .unwrap();
158 let arrival_time = NaiveTime::from_hms_opt(
159 (time_destination / 3600) % 24,
160 time_destination % 3600 / 60,
161 time_destination % 60,
162 )
163 .unwrap();
164
165 let mut duration = arrival_time.signed_duration_since(departure_time);
166 if time_destination >= 86400 {
167 duration = duration.checked_add(&TimeDelta::seconds(86400)).unwrap();
168 }
169
170 schedules.push(Schedule {
171 train_type: gtfs
172 .get_route(&trip.route_id)
173 .unwrap()
174 .short_name
175 .clone()
176 .unwrap(),
177 departure_time,
183 arrival_time,
184 duration,
185 });
186 }
187 }
188 }
189 }
190
191 schedules.sort_by_key(|schedule| schedule.departure_time);
193
194 if sorted {
195 println!("sorting timetable by duration");
196 schedules.sort_by(|a, b| a.duration.cmp(&b.duration));
197 }
198
199 self.schedules = schedules;
200
201 Ok(())
202 }
203
204 pub fn print_timetable(&self) {
205 if self.schedules.is_empty() {
206 println!("\nNo schedules available...won't print timetable.");
207 } else {
208 println!("\n=========================TIMETABLE=========================");
209 println!(
210 " {0: <12} | {1: <10} | {2: <10} | {3: <12}",
211 "Train", "Departure", "Arrival", "Duration"
212 );
213 for track in &self.schedules {
214 println!("-----------------------------------------------------------");
215 println!(
216 " {0: <11} | {1: <9} | {2: <9} | {3: <10}",
217 track.train_type,
218 format!(
219 "{:02}:{:02}",
220 track.departure_time.hour(),
221 track.departure_time.minute() % 60
222 ),
223 format!(
224 "{:02}:{:02}",
225 track.arrival_time.hour(),
226 track.arrival_time.minute() % 60
227 ),
228 format!(
229 "{:02}:{:02}",
230 track.duration.num_hours(),
231 track.duration.num_minutes() % 60
232 )
233 );
234 }
235 println!("===========================================================");
236 }
237 }
238}
239
240fn is_service_active(gtfs: &Gtfs, service_id: &str, date: NaiveDate) -> bool {
242 if let Some(calendar) = gtfs.calendar.get(service_id) {
244 let weekday = match date.weekday() {
245 chrono::Weekday::Mon => calendar.monday,
246 chrono::Weekday::Tue => calendar.tuesday,
247 chrono::Weekday::Wed => calendar.wednesday,
248 chrono::Weekday::Thu => calendar.thursday,
249 chrono::Weekday::Fri => calendar.friday,
250 chrono::Weekday::Sat => calendar.saturday,
251 chrono::Weekday::Sun => calendar.sunday,
252 };
253
254 if weekday && date >= calendar.start_date && date <= calendar.end_date {
255 if let Some(calendar_dates) = gtfs.calendar_dates.get(service_id) {
257 for date_override in calendar_dates {
258 if date_override.date == date {
259 return !(date_override.exception_type
260 == gtfs_structures::Exception::Deleted);
261 }
262 }
263 }
264 return true;
265 }
266 }
267
268 if let Some(calendar_dates) = gtfs.calendar_dates.get(service_id) {
270 for date_override in calendar_dates {
271 if date_override.date == date {
272 return date_override.exception_type == gtfs_structures::Exception::Added;
273 }
274 }
275 }
276
277 false
278}