kitti_dataset/tracking/
label.rs

1use crate::{
2    serde::{occlusion, tracking_truncation},
3    Error,
4};
5use measurements::{Angle, Length};
6use noisy_float::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::{
9    borrow::Borrow,
10    fs::File,
11    io::{self, prelude::*, BufReader, BufWriter, Cursor},
12    path::Path,
13};
14
15pub use crate::object::{BoundingBox, Extents, Location, Occlusion};
16
17pub type LabelFromReaderIter<R> = csv::DeserializeRecordsIntoIter<R, Label>;
18pub type LabelFromPathIter = LabelFromReaderIter<BufReader<File>>;
19pub type LabelFromStrIter<'a> = LabelFromReaderIter<Cursor<&'a str>>;
20
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
22#[serde(from = "SerializedLabel", into = "SerializedLabel")]
23pub struct Label {
24    pub frame: u32,
25    pub track_id: Option<u32>,
26    pub class: String,
27    pub truncation: Option<Truncation>,
28    pub occlusion: Option<Occlusion>,
29    pub alpha: Angle,
30    pub bbox: BoundingBox,
31    pub extents: Extents,
32    pub location: Location,
33    pub rotation_y: Angle,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37struct SerializedLabel {
38    pub frame: u32,
39    pub track_id: i32,
40    pub class: String,
41    #[serde(with = "tracking_truncation")]
42    pub truncation: Option<Truncation>,
43    #[serde(with = "occlusion")]
44    pub occlusion: Option<Occlusion>,
45    pub alpha: f64,
46    pub xmin: f64,
47    pub ymin: f64,
48    pub xmax: f64,
49    pub ymax: f64,
50    pub height: f64,
51    pub width: f64,
52    pub length: f64,
53    pub x: f64,
54    pub y: f64,
55    pub z: f64,
56    pub rotation_y: f64,
57}
58
59impl From<SerializedLabel> for Label {
60    fn from(from: SerializedLabel) -> Self {
61        let SerializedLabel {
62            frame,
63            track_id,
64            class,
65            truncation,
66            occlusion,
67            alpha,
68            xmin,
69            ymin,
70            xmax,
71            ymax,
72            height,
73            width,
74            length,
75            x,
76            y,
77            z,
78            rotation_y,
79        } = from;
80
81        Self {
82            frame,
83            track_id: if track_id >= 0 {
84                Some(track_id as u32)
85            } else {
86                None
87            },
88            class,
89            truncation,
90            occlusion,
91            alpha: Angle::from_radians(alpha),
92            bbox: BoundingBox {
93                xmin,
94                ymin,
95                xmax,
96                ymax,
97            },
98            extents: Extents {
99                height: Length::from_meters(height),
100                width: Length::from_meters(width),
101                length: Length::from_meters(length),
102            },
103            location: Location {
104                x: Length::from_meters(x),
105                y: Length::from_meters(y),
106                z: Length::from_meters(z),
107            },
108            rotation_y: Angle::from_radians(rotation_y),
109        }
110    }
111}
112
113impl From<Label> for SerializedLabel {
114    fn from(from: Label) -> Self {
115        let Label {
116            frame,
117            track_id,
118            class,
119            truncation,
120            occlusion,
121            alpha,
122            bbox:
123                BoundingBox {
124                    xmin,
125                    ymin,
126                    xmax,
127                    ymax,
128                },
129            extents:
130                Extents {
131                    height,
132                    width,
133                    length,
134                },
135            location: Location { x, y, z },
136            rotation_y,
137        } = from;
138
139        SerializedLabel {
140            frame,
141            track_id: match track_id {
142                Some(track_id) => track_id as i32,
143                None => -1,
144            },
145            class,
146            truncation,
147            occlusion,
148            alpha: alpha.as_radians(),
149            xmin,
150            ymin,
151            xmax,
152            ymax,
153            height: height.as_meters(),
154            width: width.as_meters(),
155            length: length.as_meters(),
156            x: x.as_meters(),
157            y: y.as_meters(),
158            z: z.as_meters(),
159            rotation_y: rotation_y.as_radians(),
160        }
161    }
162}
163
164impl Label {
165    pub fn iter_from_reader<R>(reader: R) -> LabelFromReaderIter<R>
166    where
167        R: Read,
168    {
169        let reader = csv::ReaderBuilder::new()
170            .has_headers(false)
171            .delimiter(b' ')
172            .from_reader(reader);
173        reader.into_deserialize()
174    }
175
176    pub fn iter_from_path<P>(path: P) -> io::Result<LabelFromPathIter>
177    where
178        P: AsRef<Path>,
179    {
180        let reader = BufReader::new(File::open(path)?);
181        Ok(Self::iter_from_reader(reader))
182    }
183
184    pub fn iter_from_str(text: &str) -> LabelFromStrIter<'_> {
185        let reader = Cursor::new(text);
186        Self::iter_from_reader(reader)
187    }
188
189    pub fn vec_from_reader<R>(reader: R) -> Result<Vec<Label>, Error>
190    where
191        R: Read,
192    {
193        let reader = csv::ReaderBuilder::new()
194            .has_headers(false)
195            .delimiter(b' ')
196            .from_reader(reader);
197        let result: Result<Vec<Label>, _> = reader.into_deserialize().collect();
198        Ok(result?)
199    }
200
201    pub fn vec_from_path<P>(path: P) -> Result<Vec<Label>, Error>
202    where
203        P: AsRef<Path>,
204    {
205        let reader = BufReader::new(File::open(path)?);
206        Self::vec_from_reader(reader)
207    }
208
209    pub fn vec_from_str(text: &str) -> Result<Vec<Label>, Error> {
210        let reader = Cursor::new(text);
211        Self::vec_from_reader(reader)
212    }
213
214    pub fn write_to_writer<W, I, A>(writer: W, labels: I) -> io::Result<()>
215    where
216        I: IntoIterator<Item = A>,
217        W: Write,
218        A: Borrow<Label>,
219    {
220        let mut writer = csv::WriterBuilder::new()
221            .has_headers(false)
222            .delimiter(b' ')
223            .from_writer(writer);
224
225        for record in labels {
226            writer.serialize(record.borrow())?;
227        }
228
229        writer.flush()?;
230
231        Ok(())
232    }
233
234    pub fn write_to_path<P, I, A>(path: P, labels: I) -> io::Result<()>
235    where
236        I: IntoIterator<Item = A>,
237        P: AsRef<Path>,
238        A: Borrow<Label>,
239    {
240        let writer = BufWriter::new(File::create(path)?);
241        Self::write_to_writer(writer, labels)
242    }
243
244    pub fn write_to_string<I, A>(labels: I) -> io::Result<String>
245    where
246        I: IntoIterator<Item = A>,
247        A: Borrow<Label>,
248    {
249        let mut buf = vec![];
250        Self::write_to_writer(&mut buf, labels)?;
251        Ok(String::from_utf8(buf).unwrap())
252    }
253}
254
255#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
256pub enum Truncation {
257    Ignored,
258    Labeled(R64),
259}
260
261impl Truncation {
262    pub fn as_f64(&self) -> Option<f64> {
263        match self {
264            Truncation::Ignored => None,
265            Truncation::Labeled(value) => Some(value.raw()),
266        }
267    }
268
269    pub fn as_f32(&self) -> Option<f32> {
270        Some(self.as_f64()? as f32)
271    }
272
273    pub fn from_f64(value: f64) -> Result<Self, Error> {
274        value.try_into()
275    }
276
277    pub fn from_f32(value: f32) -> Result<Self, Error> {
278        (value as f64).try_into()
279    }
280}
281
282impl TryFrom<f64> for Truncation {
283    type Error = Error;
284
285    fn try_from(fval: f64) -> Result<Self, Self::Error> {
286        let error = || Error::InvalidTruncationValue(fval);
287
288        let rval = R64::try_from(fval).map_err(|_| error())?;
289
290        if !(r64(0.0)..=r64(1.0)).contains(&rval) {
291            return Err(error());
292        }
293
294        Ok(Truncation::Labeled(rval))
295    }
296}