1use serde::Deserialize;
2use sha2::{Digest, Sha256};
3
4use crate::{Error, Gtfs, RawGtfs};
5use std::collections::HashMap;
6use std::convert::TryFrom;
7use std::fs::File;
8use std::io::Read;
9use std::path::Path;
10use web_time::Instant;
11
/// Options controlling how a GTFS feed is read.
///
/// `GtfsReader::default()` enables `read_stop_times`, `read_shapes` and
/// `trim_fields`, and disables `unkown_enum_as_default`.
pub struct GtfsReader {
    /// Whether `stop_times.txt` is parsed. When `false` the file is skipped
    /// and the stop time list of the result is empty. Defaults to `true`.
    pub read_stop_times: bool,
    /// Whether `shapes.txt` is parsed. When `false`, reading a zip archive
    /// skips the file and yields an empty shape list. Defaults to `true`.
    pub read_shapes: bool,
    /// When `true`, `unknown_to_default()` is called on the raw feed once all
    /// files have been read. Defaults to `false`.
    ///
    /// The spelling of the name (sic) is part of the public API.
    pub unkown_enum_as_default: bool,
    /// Whether CSV field values are trimmed while parsing (`csv::Trim::Fields`
    /// when `true`, `csv::Trim::None` otherwise). Header names are trimmed
    /// regardless of this option. Defaults to `true`.
    pub trim_fields: bool,
}

impl Default for GtfsReader {
    /// Builds a reader with every option at its default value.
    fn default() -> Self {
        Self {
            read_stop_times: true,
            read_shapes: true,
            unkown_enum_as_default: false,
            trim_fields: true,
        }
    }
}
53
54impl GtfsReader {
55 pub fn read_stop_times(mut self, read_stop_times: bool) -> Self {
60 self.read_stop_times = read_stop_times;
61 self
62 }
63
64 pub fn read_shapes(mut self, read_shapes: bool) -> Self {
67 self.read_shapes = read_shapes;
68 self
69 }
70
71 pub fn unkown_enum_as_default(mut self, unkown_enum_as_default: bool) -> Self {
78 self.unkown_enum_as_default = unkown_enum_as_default;
79 self
80 }
81
82 pub fn trim_fields(mut self, trim_fields: bool) -> Self {
87 self.trim_fields = trim_fields;
88 self
89 }
90
91 #[cfg(not(target_arch = "wasm32"))]
96 pub fn read(self, gtfs: &str) -> Result<Gtfs, Error> {
97 self.raw().read(gtfs).and_then(Gtfs::try_from)
98 }
99
100 pub fn read_from_path<P>(self, path: P) -> Result<Gtfs, Error>
102 where
103 P: AsRef<Path>,
104 {
105 self.raw().read_from_path(path).and_then(Gtfs::try_from)
106 }
107
108 #[cfg(all(feature = "read-url", not(target_arch = "wasm32")))]
112 pub fn read_from_url<U: reqwest::IntoUrl>(self, url: U) -> Result<Gtfs, Error> {
113 self.raw().read_from_url(url).and_then(Gtfs::try_from)
114 }
115
116 #[cfg(feature = "read-url")]
120 pub async fn read_from_url_async<U: reqwest::IntoUrl>(self, url: U) -> Result<Gtfs, Error> {
121 self.raw()
122 .read_from_url_async(url)
123 .await
124 .and_then(Gtfs::try_from)
125 }
126
127 pub fn raw(self) -> RawGtfsReader {
139 RawGtfsReader { reader: self }
140 }
141}
142
143pub struct RawGtfsReader {
147 reader: GtfsReader,
148}
149
150impl RawGtfsReader {
151 fn read_from_directory(&self, p: &std::path::Path) -> Result<RawGtfs, Error> {
152 let start_of_read_instant = Instant::now();
153 let files = std::fs::read_dir(p)?
156 .filter_map(|d| {
157 d.ok().and_then(|e| {
158 e.path()
159 .strip_prefix(p)
160 .ok()
161 .and_then(|f| f.to_str().map(|s| s.to_owned()))
162 })
163 })
164 .collect();
165
166 let mut result = RawGtfs {
167 trips: self.read_objs_from_path(p.join("trips.txt")),
168 calendar: self.read_objs_from_optional_path(p, "calendar.txt"),
169 calendar_dates: self.read_objs_from_optional_path(p, "calendar_dates.txt"),
170 stops: self.read_objs_from_path(p.join("stops.txt")),
171 routes: self.read_objs_from_path(p.join("routes.txt")),
172 stop_times: if self.reader.read_stop_times {
173 self.read_objs_from_path(p.join("stop_times.txt"))
174 } else {
175 Ok(Vec::new())
176 },
177 agencies: self.read_objs_from_path(p.join("agency.txt")),
178 shapes: self.read_objs_from_optional_path(p, "shapes.txt"),
179 fare_attributes: self.read_objs_from_optional_path(p, "fare_attributes.txt"),
180 fare_rules: self.read_objs_from_optional_path(p, "fare_rules.txt"),
181 fare_products: self.read_objs_from_optional_path(p, "fare_products.txt"),
182 fare_media: self.read_objs_from_optional_path(p, "fare_media.txt"),
183 rider_categories: self.read_objs_from_optional_path(p, "rider_categories.txt"),
184 frequencies: self.read_objs_from_optional_path(p, "frequencies.txt"),
185 transfers: self.read_objs_from_optional_path(p, "transfers.txt"),
186 pathways: self.read_objs_from_optional_path(p, "pathways.txt"),
187 feed_info: self.read_objs_from_optional_path(p, "feed_info.txt"),
188 read_duration: start_of_read_instant.elapsed(),
189 translations: self.read_objs_from_optional_path(p, "translations.txt"),
190 files,
191 source_format: crate::SourceFormat::Directory,
192 sha256: None,
193 };
194
195 if self.reader.unkown_enum_as_default {
196 result.unknown_to_default();
197 }
198 Ok(result)
199 }
200
201 #[cfg(not(target_arch = "wasm32"))]
204 pub fn read(self, gtfs: &str) -> Result<RawGtfs, Error> {
205 #[cfg(feature = "read-url")]
206 if gtfs.starts_with("http") {
207 return self.read_from_url(gtfs);
208 }
209 self.read_from_path(gtfs)
210 }
211
212 #[cfg(all(feature = "read-url", not(target_arch = "wasm32")))]
214 pub fn read_from_url<U: reqwest::IntoUrl>(self, url: U) -> Result<RawGtfs, Error> {
215 let mut res = reqwest::blocking::get(url)?;
216 let mut body = Vec::new();
217 res.read_to_end(&mut body)?;
218 let cursor = std::io::Cursor::new(body);
219 self.read_from_reader(cursor)
220 }
221
222 #[cfg(feature = "read-url")]
224 pub async fn read_from_url_async<U: reqwest::IntoUrl>(self, url: U) -> Result<RawGtfs, Error> {
225 let res = reqwest::get(url).await?.bytes().await?;
226 let reader = std::io::Cursor::new(res);
227 self.read_from_reader(reader)
228 }
229
230 pub fn read_from_path<P>(&self, path: P) -> Result<RawGtfs, Error>
232 where
233 P: AsRef<Path>,
234 {
235 let p = path.as_ref();
236 if p.is_file() {
237 let reader = File::open(p)?;
238 self.read_from_reader(reader)
239 } else if p.is_dir() {
240 self.read_from_directory(p)
241 } else {
242 Err(Error::NotFileNorDirectory(format!("{}", p.display())))
243 }
244 }
245
246 pub fn read_from_reader<T: std::io::Read + std::io::Seek>(
247 &self,
248 reader: T,
249 ) -> Result<RawGtfs, Error> {
250 let start_of_read_instant = Instant::now();
251 let mut hasher = Sha256::new();
252 let mut buf_reader = std::io::BufReader::new(reader);
253 let _n = std::io::copy(&mut buf_reader, &mut hasher)?;
254 let hash = hasher.finalize();
255 let mut archive = zip::ZipArchive::new(buf_reader)?;
256 let mut file_mapping = HashMap::new();
257 let mut files = Vec::new();
258
259 for i in 0..archive.len() {
260 let archive_file = archive.by_index(i)?;
261 files.push(archive_file.name().to_owned());
262
263 for gtfs_file in &[
264 "agency.txt",
265 "calendar.txt",
266 "calendar_dates.txt",
267 "routes.txt",
268 "stops.txt",
269 "stop_times.txt",
270 "trips.txt",
271 "fare_attributes.txt",
272 "fare_rules.txt",
273 "fare_products.txt",
274 "fare_media.txt",
275 "rider_categories.txt",
276 "frequencies.txt",
277 "transfers.txt",
278 "pathways.txt",
279 "feed_info.txt",
280 "shapes.txt",
281 ] {
282 let path = std::path::Path::new(archive_file.name());
283 if path.file_name() == Some(std::ffi::OsStr::new(gtfs_file)) {
284 file_mapping.insert(gtfs_file, i);
285 break;
286 }
287 }
288 }
289
290 let mut result = RawGtfs {
291 agencies: self.read_file(&file_mapping, &mut archive, "agency.txt"),
292 calendar: self.read_optional_file(&file_mapping, &mut archive, "calendar.txt"),
293 calendar_dates: self.read_optional_file(
294 &file_mapping,
295 &mut archive,
296 "calendar_dates.txt",
297 ),
298 routes: self.read_file(&file_mapping, &mut archive, "routes.txt"),
299 stops: self.read_file(&file_mapping, &mut archive, "stops.txt"),
300 stop_times: if self.reader.read_stop_times {
301 self.read_file(&file_mapping, &mut archive, "stop_times.txt")
302 } else {
303 Ok(Vec::new())
304 },
305 trips: self.read_file(&file_mapping, &mut archive, "trips.txt"),
306 fare_attributes: self.read_optional_file(
307 &file_mapping,
308 &mut archive,
309 "fare_attributes.txt",
310 ),
311 fare_rules: self.read_optional_file(&file_mapping, &mut archive, "fare_rules.txt"),
312 fare_products: self.read_optional_file(
313 &file_mapping,
314 &mut archive,
315 "fare_products.txt",
316 ),
317 fare_media: self.read_optional_file(&file_mapping, &mut archive, "fare_media.txt"),
318 rider_categories: self.read_optional_file(
319 &file_mapping,
320 &mut archive,
321 "rider_categories.txt",
322 ),
323 frequencies: self.read_optional_file(&file_mapping, &mut archive, "frequencies.txt"),
324 transfers: self.read_optional_file(&file_mapping, &mut archive, "transfers.txt"),
325 pathways: self.read_optional_file(&file_mapping, &mut archive, "pathways.txt"),
326 feed_info: self.read_optional_file(&file_mapping, &mut archive, "feed_info.txt"),
327 shapes: if self.reader.read_shapes {
328 self.read_optional_file(&file_mapping, &mut archive, "shapes.txt")
329 } else {
330 Some(Ok(Vec::new()))
331 },
332 translations: self.read_optional_file(&file_mapping, &mut archive, "translations.txt"),
333 read_duration: start_of_read_instant.elapsed(),
334 files,
335 source_format: crate::SourceFormat::Zip,
336 sha256: Some(format!("{hash:x}")),
337 };
338
339 if self.reader.unkown_enum_as_default {
340 result.unknown_to_default();
341 }
342 Ok(result)
343 }
344
345 fn read_objs<T, O>(&self, mut reader: T, file_name: &str) -> Result<Vec<O>, Error>
346 where
347 for<'de> O: Deserialize<'de>,
348 T: std::io::Read,
349 {
350 let mut bom = [0; 3];
351 reader
352 .read_exact(&mut bom)
353 .map_err(|e| Error::NamedFileIO {
354 file_name: file_name.to_owned(),
355 source: Box::new(e),
356 })?;
357
358 let chained = if bom != [0xefu8, 0xbbu8, 0xbfu8] {
359 bom.chain(reader)
360 } else {
361 [].chain(reader)
362 };
363
364 let mut reader = csv::ReaderBuilder::new()
365 .flexible(true)
366 .trim(if self.reader.trim_fields {
367 csv::Trim::Fields
368 } else {
369 csv::Trim::None
370 })
371 .from_reader(chained);
372 let headers = reader
374 .headers()
375 .map_err(|e| Error::CSVError {
376 file_name: file_name.to_owned(),
377 source: e,
378 line_in_error: None,
379 })?
380 .clone()
381 .into_iter()
382 .map(|x| x.trim())
383 .collect::<csv::StringRecord>();
384
385 let mut rec = csv::StringRecord::new();
387 let mut objs = Vec::new();
388
389 while reader.read_record(&mut rec).map_err(|e| Error::CSVError {
391 file_name: file_name.to_owned(),
392 source: e,
393 line_in_error: None,
394 })? {
395 let obj = rec
396 .deserialize(Some(&headers))
397 .map_err(|e| Error::CSVError {
398 file_name: file_name.to_owned(),
399 source: e,
400 line_in_error: Some(crate::error::LineError {
401 headers: headers.into_iter().map(String::from).collect(),
402 values: rec.into_iter().map(String::from).collect(),
403 }),
404 })?;
405 objs.push(obj);
406 }
407 Ok(objs)
408 }
409
410 fn read_objs_from_path<O>(&self, path: std::path::PathBuf) -> Result<Vec<O>, Error>
411 where
412 for<'de> O: Deserialize<'de>,
413 {
414 let file_name = path
415 .file_name()
416 .and_then(|f| f.to_str())
417 .unwrap_or("invalid_file_name")
418 .to_string();
419 if path.exists() {
420 File::open(path)
421 .map_err(|e| Error::NamedFileIO {
422 file_name: file_name.to_owned(),
423 source: Box::new(e),
424 })
425 .and_then(|r| self.read_objs(r, &file_name))
426 } else {
427 Err(Error::MissingFile(file_name))
428 }
429 }
430
431 fn read_objs_from_optional_path<O>(
432 &self,
433 dir_path: &std::path::Path,
434 file_name: &str,
435 ) -> Option<Result<Vec<O>, Error>>
436 where
437 for<'de> O: Deserialize<'de>,
438 {
439 File::open(dir_path.join(file_name))
440 .ok()
441 .map(|r| self.read_objs(r, file_name))
442 }
443
444 fn read_file<O, T>(
445 &self,
446 file_mapping: &HashMap<&&str, usize>,
447 archive: &mut zip::ZipArchive<T>,
448 file_name: &str,
449 ) -> Result<Vec<O>, Error>
450 where
451 for<'de> O: Deserialize<'de>,
452 T: std::io::Read + std::io::Seek,
453 {
454 self.read_optional_file(file_mapping, archive, file_name)
455 .unwrap_or_else(|| Err(Error::MissingFile(file_name.to_owned())))
456 }
457
458 fn read_optional_file<O, T>(
459 &self,
460 file_mapping: &HashMap<&&str, usize>,
461 archive: &mut zip::ZipArchive<T>,
462 file_name: &str,
463 ) -> Option<Result<Vec<O>, Error>>
464 where
465 for<'de> O: Deserialize<'de>,
466 T: std::io::Read + std::io::Seek,
467 {
468 file_mapping.get(&file_name).map(|i| {
469 self.read_objs(
470 archive.by_index(*i).map_err(|e| Error::NamedFileIO {
471 file_name: file_name.to_owned(),
472 source: Box::new(e),
473 })?,
474 file_name,
475 )
476 })
477 }
478}