use crate::{
    dataset::{Dataset, DatasetInner, InstanceInner, SampleInner, SceneInner},
    error::{Error, Result},
    serializable::{
        Attribute, CalibratedSensor, Category, EgoPose, Instance, Log, Map, Sample,
        SampleAnnotation, SampleData, Scene, Sensor, Token, Visibility, VisibilityToken,
    },
    utils::{ParallelIteratorExt, WithToken},
};
use chrono::NaiveDateTime;
use itertools::Itertools;
use rayon::prelude::*;
use serde::Deserialize;
use std::{
    collections::HashMap,
    fs::File,
    io::BufReader,
    path::{Path, PathBuf},
};

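/// Returns early from the enclosing function with
/// [`Error::CorruptedDataset`], formatting the message like `format!`.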
macro_rules! bail_corrupted {
    ($($arg:expr),*) => {
        {
            let msg = format!($($arg),*);
            return Err(Error::CorruptedDataset(msg));
        }
    };
}

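/// Checks a condition and, if it does not hold, bails out of the enclosing
/// function with [`Error::CorruptedDataset`] via `bail_corrupted!`.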
macro_rules! ensure_corrupted {
    ($cond:expr, $($arg:expr),*) => {
        {
            if !$cond {
                bail_corrupted!($($arg),*);
            }
        }
    };
}

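/// Loads a dataset directory that follows the nuScenes-style metadata
/// layout (`attribute.json`, `scene.json`, `sample.json`, ...).
///
/// Set `check` to `false` to skip the referential-integrity checks and
/// speed up loading of a dataset that is already known to be valid.
///
/// A minimal usage sketch (the version string and path are hypothetical):
///
/// ```ignore
/// let loader = DatasetLoader { check: true };
/// let dataset = loader.load("v1.0-mini", "/data/nuscenes")?;
/// ```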
#[derive(Debug, Clone)]
pub struct DatasetLoader {
    pub check: bool,
}

impl DatasetLoader {
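    /// Loads the dataset rooted at `dir`, reading the metadata JSON files
    /// from the `<dir>/<version>` subdirectory. When `check` is set, the
    /// loaded records are validated for referential integrity before
    /// indexing.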
    pub fn load<P>(&self, version: &str, dir: P) -> Result<Dataset>
    where
        P: AsRef<Path>,
    {
        let Self { check } = *self;
        let dataset_dir = dir.as_ref();
        let meta_dir = dataset_dir.join(version);

        let load_json = load_json_files(&meta_dir)?;

        if check {
            check_loaded_json(&load_json)?;
        }

        let inner = index_records(version.to_string(), dataset_dir.to_owned(), load_json)?;

        Ok(Dataset::from_inner(inner))
    }
}

impl Default for DatasetLoader {
    fn default() -> Self {
        Self { check: true }
    }
}

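/// Raw deserialized contents of every metadata JSON file, keyed by token.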
struct LoadJson {
    pub attribute_map: HashMap<Token, Attribute>,
    pub calibrated_sensor_map: HashMap<Token, CalibratedSensor>,
    pub category_map: HashMap<Token, Category>,
    pub ego_pose_map: HashMap<Token, EgoPose>,
    pub instance_map: HashMap<Token, Instance>,
    pub log_map: HashMap<Token, Log>,
    pub map_map: HashMap<Token, Map>,
    pub scene_map: HashMap<Token, Scene>,
    pub sample_map: HashMap<Token, Sample>,
    pub sample_annotation_map: HashMap<Token, SampleAnnotation>,
    pub sample_data_map: HashMap<Token, SampleData>,
    pub sensor_map: HashMap<Token, Sensor>,
    pub visibility_map: HashMap<VisibilityToken, Visibility>,
}

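/// Reads every metadata JSON file under `dir` in parallel and bundles the
/// resulting token-indexed maps into a [`LoadJson`].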
fn load_json_files(dir: &Path) -> Result<LoadJson> {
    let mut attribute_map: Result<HashMap<Token, Attribute>> = Ok(Default::default());
    let mut calibrated_sensor_map: Result<HashMap<Token, CalibratedSensor>> =
        Ok(Default::default());
    let mut category_map: Result<HashMap<Token, Category>> = Ok(Default::default());
    let mut ego_pose_map: Result<HashMap<Token, EgoPose>> = Ok(Default::default());
    let mut instance_map: Result<HashMap<Token, Instance>> = Ok(Default::default());
    let mut log_map: Result<HashMap<Token, Log>> = Ok(Default::default());
    let mut map_map: Result<HashMap<Token, Map>> = Ok(Default::default());
    let mut sample_annotation_map: Result<HashMap<Token, SampleAnnotation>> =
        Ok(Default::default());
    let mut sample_data_map: Result<HashMap<Token, SampleData>> = Ok(Default::default());
    let mut sample_map: Result<HashMap<Token, Sample>> = Ok(Default::default());
    let mut scene_map: Result<HashMap<Token, Scene>> = Ok(Default::default());
    let mut sensor_map: Result<HashMap<Token, Sensor>> = Ok(Default::default());
    let mut visibility_map: Result<HashMap<VisibilityToken, Visibility>> = Ok(Default::default());

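    // Parse each file on its own Rayon task. Every task writes its Result
    // into the slot declared above; the errors are propagated with `?`
    // after the scope has joined all tasks.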
    rayon::scope(|scope| {
        scope.spawn(|_| {
            attribute_map = load_map(dir.join("attribute.json"));
        });
        scope.spawn(|_| {
            calibrated_sensor_map = load_map(dir.join("calibrated_sensor.json"));
        });
        scope.spawn(|_| {
            category_map = load_map(dir.join("category.json"));
        });
        scope.spawn(|_| {
            ego_pose_map = load_map(dir.join("ego_pose.json"));
        });
        scope.spawn(|_| {
            instance_map = load_map(dir.join("instance.json"));
        });
        scope.spawn(|_| {
            log_map = load_map(dir.join("log.json"));
        });
        scope.spawn(|_| {
            map_map = load_map(dir.join("map.json"));
        });
        scope.spawn(|_| {
            sample_annotation_map = load_map(dir.join("sample_annotation.json"));
        });
        scope.spawn(|_| {
            sample_data_map = load_map(dir.join("sample_data.json"));
        });
        scope.spawn(|_| {
            sample_map = load_map(dir.join("sample.json"));
        });
        scope.spawn(|_| {
            scene_map = load_map(dir.join("scene.json"));
        });
        scope.spawn(|_| {
            sensor_map = load_map(dir.join("sensor.json"));
        });
        scope.spawn(|_| {
            // visibility.json is keyed by its own token type, so it cannot
            // go through the generic `load_map` helper.
            visibility_map = (|| {
                let vec: Vec<Visibility> = load_json(dir.join("visibility.json"))?;
                let map: HashMap<VisibilityToken, Visibility> =
                    vec.into_iter().map(|item| (item.token, item)).collect();
                Ok(map)
            })();
        });
    });

    let attribute_map = attribute_map?;
    let calibrated_sensor_map = calibrated_sensor_map?;
    let category_map = category_map?;
    let ego_pose_map = ego_pose_map?;
    let instance_map = instance_map?;
    let log_map = log_map?;
    let map_map = map_map?;
    let sample_annotation_map = sample_annotation_map?;
    let sample_data_map = sample_data_map?;
    let sample_map = sample_map?;
    let scene_map = scene_map?;
    let sensor_map = sensor_map?;
    let visibility_map = visibility_map?;

    Ok(LoadJson {
        attribute_map,
        calibrated_sensor_map,
        category_map,
        ego_pose_map,
        instance_map,
        log_map,
        map_map,
        scene_map,
        sample_map,
        sample_annotation_map,
        sample_data_map,
        sensor_map,
        visibility_map,
    })
}

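/// Validates referential integrity of the loaded tables: every token a
/// record stores must resolve in the corresponding table, and the
/// `prev`/`next` chains of annotations, samples, and sample data must form
/// consistent doubly-linked lists.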
fn check_loaded_json(load_json: &LoadJson) -> Result<()> {
    let LoadJson {
        attribute_map,
        calibrated_sensor_map,
        category_map,
        ego_pose_map,
        instance_map,
        log_map,
        map_map,
        scene_map,
        sample_map,
        sample_annotation_map,
        sample_data_map,
        sensor_map,
        visibility_map,
    } = load_json;

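    // Every calibrated sensor must refer to a known sensor.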
    calibrated_sensor_map
        .par_iter()
        .try_for_each(|(_, calibrated_sensor)| {
            ensure_corrupted!(
                sensor_map.contains_key(&calibrated_sensor.sensor_token),
                "the token {} does not refer to any sensor",
                calibrated_sensor.sensor_token
            );
            Ok(())
        })?;

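    // Every sample annotation must refer to existing sample, instance,
    // attribute, visibility, and neighboring annotation records.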
    sample_annotation_map
        .par_iter()
        .try_for_each(|(_, sample_annotation)| {
            ensure_corrupted!(
                sample_map.contains_key(&sample_annotation.sample_token),
                "the token {} does not refer to any sample",
                sample_annotation.sample_token
            );

            ensure_corrupted!(
                instance_map.contains_key(&sample_annotation.instance_token),
                "the token {} does not refer to any instance",
                sample_annotation.instance_token
            );

            sample_annotation
                .attribute_tokens
                .par_iter()
                .try_for_each(|token| {
                    ensure_corrupted!(
                        attribute_map.contains_key(token),
                        "the token {} does not refer to any attribute",
                        token
                    );
                    Ok(())
                })?;

            if let Some(token) = &sample_annotation.visibility_token {
                ensure_corrupted!(
                    visibility_map.contains_key(token),
                    "the token {} does not refer to any visibility",
                    token
                );
            }

            if let Some(token) = &sample_annotation.prev {
                ensure_corrupted!(
                    sample_annotation_map.contains_key(token),
                    "the token {} does not refer to any sample annotation",
                    token
                );
            }

            if let Some(token) = &sample_annotation.next {
                ensure_corrupted!(
                    sample_annotation_map.contains_key(token),
                    "the token {} does not refer to any sample annotation",
                    token
                );
            }

            Ok(())
        })?;

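    // Check that the `prev`/`next` pointers of annotations are mutually
    // consistent: every edge seen as (annotation.prev, annotation) must
    // also be seen as (annotation, annotation.next). Comparing the two
    // sorted edge lists verifies this without walking each chain.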
    {
        let mut prev_edges: Vec<_> = sample_annotation_map
            .par_iter()
            .filter_map(|(&curr_token, annotation)| Some((annotation.prev?, curr_token)))
            .collect();
        prev_edges.par_sort_unstable();

        let mut next_edges: Vec<_> = sample_annotation_map
            .par_iter()
            .filter_map(|(&curr_token, annotation)| Some((curr_token, annotation.next?)))
            .collect();
        next_edges.par_sort_unstable();

        ensure_corrupted!(
            prev_edges.len() == next_edges.len(),
            "the number of non-null sample_annotation.next does not match the number of sample_annotation.prev"
        );

        prev_edges
            .par_iter()
            .zip(next_edges.par_iter())
            .try_for_each(|(e1, e2)| {
                ensure_corrupted!(
                    e1 == e2,
                    "the prev and next fields of sample_annotation.json are corrupted"
                );
                Ok(())
            })?;
    }

    instance_map.par_iter().try_for_each(|(_, instance)| {
        ensure_corrupted!(
            sample_annotation_map.contains_key(&instance.first_annotation_token),
            "the token {} does not refer to any sample annotation",
            instance.first_annotation_token
        );

        ensure_corrupted!(
            sample_annotation_map.contains_key(&instance.last_annotation_token),
            "the token {} does not refer to any sample annotation",
            instance.last_annotation_token
        );

        ensure_corrupted!(
            category_map.contains_key(&instance.category_token),
            "the token {} does not refer to any category",
            instance.category_token
        );

        Ok(())
    })?;

    {
        let mut lhs: Vec<_> = sample_annotation_map
            .par_iter()
            .filter_map(|(&token, annotation)| annotation.prev.is_none().then_some(token))
            .collect();
        let mut rhs: Vec<_> = instance_map
            .par_iter()
            .map(|(_, instance)| instance.first_annotation_token)
            .collect();

        ensure_corrupted!(
            lhs.len() == rhs.len(),
            "the number of chain-head sample annotations does not match the number of instances"
        );

        lhs.par_sort_unstable();
        rhs.par_sort_unstable();
        lhs.par_iter()
            .zip(rhs.par_iter())
            .try_for_each(|(lhs, rhs)| {
                ensure_corrupted!(lhs == rhs, "instance.first_annotation_token is corrupted");
                Ok(())
            })?;
    }

    {
        let mut lhs: Vec<_> = sample_annotation_map
            .par_iter()
            .filter_map(|(&token, annotation)| annotation.next.is_none().then_some(token))
            .collect();
        let mut rhs: Vec<_> = instance_map
            .par_iter()
            .map(|(_, instance)| instance.last_annotation_token)
            .collect();

        ensure_corrupted!(
            lhs.len() == rhs.len(),
            "the number of chain-tail sample annotations does not match the number of instances"
        );

        lhs.par_sort_unstable();
        rhs.par_sort_unstable();

        lhs.par_iter()
            .zip(rhs.par_iter())
            .try_for_each(|(lhs, rhs)| {
                ensure_corrupted!(lhs == rhs, "instance.last_annotation_token is corrupted");
                Ok(())
            })?;
    }

    map_map
        .par_iter()
        .flat_map(|(map_token, map)| {
            map.log_tokens
                .par_iter()
                .map(move |log_token| (map_token, log_token))
        })
        .try_for_each(|(map_token, log_token)| {
            ensure_corrupted!(
                log_map.contains_key(log_token),
                "in the map {map_token}, the log_token {log_token} does not refer to any valid log"
            );
            Ok(())
        })?;

    sample_map.par_iter().try_for_each(|(_, sample)| {
        ensure_corrupted!(
            scene_map.contains_key(&sample.scene_token),
            "the token {} does not refer to any scene",
            sample.scene_token
        );

        if let Some(token) = &sample.prev {
            ensure_corrupted!(
                sample_map.contains_key(token),
                "the token {} does not refer to any sample",
                token
            );
        }

        if let Some(token) = &sample.next {
            ensure_corrupted!(
                sample_map.contains_key(token),
                "the token {} does not refer to any sample",
                token
            );
        }

        Ok(())
    })?;

    {
        let mut prev_edges: Vec<_> = sample_map
            .par_iter()
            .filter_map(|(&curr_token, sample)| Some((sample.prev?, curr_token)))
            .collect();
        prev_edges.par_sort_unstable();

        let mut next_edges: Vec<_> = sample_map
            .par_iter()
            .filter_map(|(&curr_token, sample)| Some((curr_token, sample.next?)))
            .collect();
        next_edges.par_sort_unstable();

        ensure_corrupted!(
            prev_edges.len() == next_edges.len(),
            "the number of non-null sample.next does not match the number of sample.prev"
        );

        prev_edges
            .par_iter()
            .zip(next_edges.par_iter())
            .try_for_each(|(e1, e2)| {
                ensure_corrupted!(
                    e1 == e2,
                    "the prev and next fields of sample.json are corrupted"
                );
                Ok(())
            })?;
    }

    scene_map.par_iter().try_for_each(|(_, scene)| {
        ensure_corrupted!(
            log_map.contains_key(&scene.log_token),
            "the token {} does not refer to any log",
            scene.log_token
        );

        ensure_corrupted!(
            sample_map.contains_key(&scene.first_sample_token),
            "the token {} does not refer to any sample",
            scene.first_sample_token
        );

        ensure_corrupted!(
            sample_map.contains_key(&scene.last_sample_token),
            "the token {} does not refer to any sample",
            scene.last_sample_token
        );

        Ok(())
    })?;

    {
        let mut lhs: Vec<_> = sample_map
            .par_iter()
            .filter_map(|(&token, sample)| sample.prev.is_none().then_some(token))
            .collect();
        let mut rhs: Vec<_> = scene_map
            .par_iter()
            .map(|(_, scene)| scene.first_sample_token)
            .collect();

        ensure_corrupted!(
            lhs.len() == rhs.len(),
            "the number of chain-head samples does not match the number of scenes"
        );

        lhs.par_sort_unstable();
        rhs.par_sort_unstable();
        lhs.par_iter()
            .zip(rhs.par_iter())
            .try_for_each(|(lhs, rhs)| {
                ensure_corrupted!(lhs == rhs, "scene.first_sample_token is corrupted");
                Ok(())
            })?;
    }

    {
        let mut lhs: Vec<_> = sample_map
            .par_iter()
            .filter_map(|(&token, sample)| sample.next.is_none().then_some(token))
            .collect();
        let mut rhs: Vec<_> = scene_map
            .par_iter()
            .map(|(_, scene)| scene.last_sample_token)
            .collect();

        ensure_corrupted!(
            lhs.len() == rhs.len(),
            "the number of chain-tail samples does not match the number of scenes"
        );

        lhs.par_sort_unstable();
        rhs.par_sort_unstable();

        lhs.par_iter()
            .zip(rhs.par_iter())
            .try_for_each(|(lhs, rhs)| {
                ensure_corrupted!(lhs == rhs, "scene.last_sample_token is corrupted");
                Ok(())
            })?;
    }

    sample_data_map
        .par_iter()
        .try_for_each(|(_, sample_data)| {
            ensure_corrupted!(
                sample_map.contains_key(&sample_data.sample_token),
                "the token {} does not refer to any sample",
                sample_data.sample_token
            );

            ensure_corrupted!(
                ego_pose_map.contains_key(&sample_data.ego_pose_token),
                "the token {} does not refer to any ego pose",
                sample_data.ego_pose_token
            );

            ensure_corrupted!(
                calibrated_sensor_map.contains_key(&sample_data.calibrated_sensor_token),
                "the token {} does not refer to any calibrated sensor",
                sample_data.calibrated_sensor_token
            );

            if let Some(token) = &sample_data.prev {
                ensure_corrupted!(
                    sample_data_map.contains_key(token),
                    "the token {} does not refer to any sample data",
                    token
                );
            }

            if let Some(token) = &sample_data.next {
                ensure_corrupted!(
                    sample_data_map.contains_key(token),
                    "the token {} does not refer to any sample data",
                    token
                );
            }

            Ok(())
        })?;

    {
        let mut prev_edges: Vec<_> = sample_data_map
            .par_iter()
            .filter_map(|(&curr_token, data)| Some((data.prev?, curr_token)))
            .collect();
        prev_edges.par_sort_unstable();

        let mut next_edges: Vec<_> = sample_data_map
            .par_iter()
            .filter_map(|(&curr_token, data)| Some((curr_token, data.next?)))
            .collect();
        next_edges.par_sort_unstable();

        ensure_corrupted!(
            prev_edges.len() == next_edges.len(),
            "the number of non-null sample_data.next does not match the number of sample_data.prev"
        );

        prev_edges
            .par_iter()
            .zip(next_edges.par_iter())
            .try_for_each(|(e1, e2)| {
                ensure_corrupted!(
                    e1 == e2,
                    "the prev and next fields of sample_data.json are corrupted"
                );
                Ok(())
            })?;
    }

    Ok(())
}

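/// Builds the [`DatasetInner`] index: groups annotations and sample data by
/// their parent sample, converts instances, scenes, and samples into their
/// internal representations, and precomputes token lists sorted by
/// timestamp.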
fn index_records(
    version: String,
    dataset_dir: PathBuf,
    load_json: LoadJson,
) -> Result<DatasetInner> {
    let LoadJson {
        attribute_map,
        calibrated_sensor_map,
        category_map,
        ego_pose_map,
        instance_map,
        log_map,
        map_map,
        scene_map,
        sample_map,
        sample_annotation_map,
        sample_data_map,
        sensor_map,
        visibility_map,
    } = load_json;

    // Group annotation and sample-data tokens by the sample they belong to.
    let mut sample_to_annotation_groups = sample_annotation_map
        .iter()
        .map(|(sample_annotation_token, sample_annotation)| {
            (sample_annotation.sample_token, *sample_annotation_token)
        })
        .into_group_map();

    let mut sample_to_sample_data_groups = sample_data_map
        .iter()
        .map(|(sample_data_token, sample_data)| (sample_data.sample_token, *sample_data_token))
        .into_group_map();

    let instance_internal_map: HashMap<Token, InstanceInner> = instance_map
        .into_par_iter()
        .map(|(instance_token, instance)| -> Result<_> {
            let ret = InstanceInner::from(instance, &sample_annotation_map)?;
            Ok((instance_token, ret))
        })
        .par_try_collect()?;

    let scene_internal_map: HashMap<_, _> = scene_map
        .into_par_iter()
        .map(|(scene_token, scene)| -> Result<_> {
            let internal = SceneInner::from(scene, &sample_map)?;
            Ok((scene_token, internal))
        })
        .par_try_collect()?;

    let sample_internal_map: HashMap<_, _> = sample_map
        .into_iter()
        .map(|(sample_token, sample)| -> Result<_> {
            let sample_data_tokens = sample_to_sample_data_groups
                .remove(&sample_token)
                .unwrap_or_default();
            let annotation_tokens = sample_to_annotation_groups
                .remove(&sample_token)
                .unwrap_or_default();
            let internal = SampleInner::from(sample, annotation_tokens, sample_data_tokens);
            Ok((sample_token, internal))
        })
        .try_collect()?;

    let sorted_ego_pose_tokens: Vec<_> = {
        let mut sorted_pairs: Vec<(&Token, NaiveDateTime)> = ego_pose_map
            .par_iter()
            .map(|(ego_pose_token, ego_pose)| (ego_pose_token, ego_pose.timestamp))
            .collect();
        sorted_pairs.par_sort_unstable_by_key(|(_, timestamp)| *timestamp);
        sorted_pairs
            .into_par_iter()
            .map(|(token, _)| *token)
            .collect()
    };

    let sorted_sample_tokens: Vec<_> = {
        let mut sorted_pairs: Vec<(&Token, NaiveDateTime)> = sample_internal_map
            .par_iter()
            .map(|(sample_token, sample)| (sample_token, sample.timestamp))
            .collect();
        sorted_pairs.par_sort_unstable_by_key(|(_, timestamp)| *timestamp);
        sorted_pairs
            .into_par_iter()
            .map(|(token, _)| *token)
            .collect()
    };

    let sorted_sample_data_tokens: Vec<_> = {
        let mut sorted_pairs: Vec<(&Token, NaiveDateTime)> = sample_data_map
            .par_iter()
            .map(|(sample_data_token, sample_data)| (sample_data_token, sample_data.timestamp))
            .collect();
        sorted_pairs.par_sort_unstable_by_key(|(_, timestamp)| *timestamp);
        sorted_pairs
            .into_par_iter()
            .map(|(token, _)| *token)
            .collect()
    };

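    // Order scenes by the timestamp of their earliest sample.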
    let sorted_scene_tokens: Vec<_> = {
        let mut sorted_pairs: Vec<_> = scene_internal_map
            .par_iter()
            .map(|(scene_token, scene)| {
                let timestamps: Vec<NaiveDateTime> = scene
                    .sample_tokens
                    .par_iter()
                    .map(|sample_token| {
                        let sample = sample_internal_map
                            .get(sample_token)
                            .expect("internal error: invalid sample_token");
                        sample.timestamp
                    })
                    .collect();

                let timestamp = timestamps
                    .into_par_iter()
                    .min()
                    .expect("scene.sample_tokens must not be empty");

                (scene_token, timestamp)
            })
            .collect();
        sorted_pairs.par_sort_unstable_by_key(|(_, timestamp)| *timestamp);

        sorted_pairs
            .into_par_iter()
            .map(|(token, _)| *token)
            .collect()
    };

    let inner = DatasetInner {
        version,
        dataset_dir,
        attribute_map,
        calibrated_sensor_map,
        category_map,
        ego_pose_map,
        instance_map: instance_internal_map,
        log_map,
        map_map,
        sample_map: sample_internal_map,
        sample_annotation_map,
        sample_data_map,
        scene_map: scene_internal_map,
        sensor_map,
        visibility_map,
        sorted_ego_pose_tokens,
        sorted_scene_tokens,
        sorted_sample_tokens,
        sorted_sample_data_tokens,
    };

    Ok(inner)
}

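/// Deserializes a JSON array of records from `path` and indexes the records
/// by token.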
fn load_map<T, P>(path: P) -> Result<HashMap<Token, T>>
where
    P: AsRef<Path>,
    T: for<'a> Deserialize<'a> + WithToken + Send,
    Vec<T>: rayon::iter::IntoParallelIterator<Item = T>,
{
    let vec: Vec<T> = load_json(path)?;
    let map = vec
        .into_par_iter()
        .map(|item| (item.token(), item))
        .collect();
    Ok(map)
}

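/// Deserializes a JSON value from `path`, wrapping any parse failure in
/// [`Error::CorruptedDataset`].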
fn load_json<T, P>(path: P) -> Result<T>
where
    P: AsRef<Path>,
    T: for<'a> Deserialize<'a>,
{
    let reader = BufReader::new(File::open(path.as_ref())?);
    let value = serde_json::from_reader(reader).map_err(|err| {
        let msg = format!("failed to load file {}: {:?}", path.as_ref().display(), err);
        Error::CorruptedDataset(msg)
    })?;
    Ok(value)
}