1use crate::{
2 serde::{occlusion, tracking_truncation},
3 Error,
4};
5use measurements::{Angle, Length};
6use noisy_float::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::{
9 borrow::Borrow,
10 fs::File,
11 io::{self, prelude::*, BufReader, BufWriter, Cursor},
12 path::Path,
13};
14
15pub use crate::object::{BoundingBox, Extents, Location, Occlusion};
16
17pub type LabelFromReaderIter<R> = csv::DeserializeRecordsIntoIter<R, Label>;
18pub type LabelFromPathIter = LabelFromReaderIter<BufReader<File>>;
19pub type LabelFromStrIter<'a> = LabelFromReaderIter<Cursor<&'a str>>;
20
/// One parsed label record.
///
/// Serde (de)serialization is routed through the private `SerializedLabel`
/// row type (`#[serde(from = ..., into = ...)]`), whose `From` conversions
/// perform the unit handling noted on the fields below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(from = "SerializedLabel", into = "SerializedLabel")]
pub struct Label {
    /// Frame number of the record (first column of the row).
    pub frame: u32,
    /// Track id; `None` when the serialized id is negative, and written back as `-1`.
    pub track_id: Option<u32>,
    /// Class name, copied verbatim from the row.
    pub class: String,
    /// Truncation annotation. How `None` maps to/from text is decided by the
    /// `tracking_truncation` serde helper (defined elsewhere) — TODO confirm.
    pub truncation: Option<Truncation>,
    /// Occlusion annotation. How `None` maps to/from text is decided by the
    /// `occlusion` serde helper (defined elsewhere) — TODO confirm.
    pub occlusion: Option<Occlusion>,
    /// `alpha` angle; the row stores it in radians.
    pub alpha: Angle,
    /// 2D box; coordinates are passed through without unit conversion
    /// (presumably image pixels — TODO confirm).
    pub bbox: BoundingBox,
    /// Height/width/length; the row stores them in meters.
    pub extents: Extents,
    /// x/y/z position; the row stores it in meters.
    pub location: Location,
    /// `rotation_y` angle; the row stores it in radians.
    pub rotation_y: Angle,
}
35
/// Flat row type that `Label` is (de)serialized through.
///
/// Records are read and written without a header row, so the csv/serde layer
/// maps columns to fields by position: the field order below IS the on-disk
/// column order and must not be rearranged.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializedLabel {
    pub frame: u32,
    /// Signed so that negative values can mean "no track id"; the `From`
    /// conversions map negatives to `None` and `None` to `-1`.
    pub track_id: i32,
    pub class: String,
    #[serde(with = "tracking_truncation")]
    pub truncation: Option<Truncation>,
    #[serde(with = "occlusion")]
    pub occlusion: Option<Occlusion>,
    /// Radians.
    pub alpha: f64,
    // 2D bounding box corners, passed through unconverted.
    pub xmin: f64,
    pub ymin: f64,
    pub xmax: f64,
    pub ymax: f64,
    // Extents, in meters.
    pub height: f64,
    pub width: f64,
    pub length: f64,
    // Location, in meters.
    pub x: f64,
    pub y: f64,
    pub z: f64,
    /// Radians.
    pub rotation_y: f64,
}
58
59impl From<SerializedLabel> for Label {
60 fn from(from: SerializedLabel) -> Self {
61 let SerializedLabel {
62 frame,
63 track_id,
64 class,
65 truncation,
66 occlusion,
67 alpha,
68 xmin,
69 ymin,
70 xmax,
71 ymax,
72 height,
73 width,
74 length,
75 x,
76 y,
77 z,
78 rotation_y,
79 } = from;
80
81 Self {
82 frame,
83 track_id: if track_id >= 0 {
84 Some(track_id as u32)
85 } else {
86 None
87 },
88 class,
89 truncation,
90 occlusion,
91 alpha: Angle::from_radians(alpha),
92 bbox: BoundingBox {
93 xmin,
94 ymin,
95 xmax,
96 ymax,
97 },
98 extents: Extents {
99 height: Length::from_meters(height),
100 width: Length::from_meters(width),
101 length: Length::from_meters(length),
102 },
103 location: Location {
104 x: Length::from_meters(x),
105 y: Length::from_meters(y),
106 z: Length::from_meters(z),
107 },
108 rotation_y: Angle::from_radians(rotation_y),
109 }
110 }
111}
112
113impl From<Label> for SerializedLabel {
114 fn from(from: Label) -> Self {
115 let Label {
116 frame,
117 track_id,
118 class,
119 truncation,
120 occlusion,
121 alpha,
122 bbox:
123 BoundingBox {
124 xmin,
125 ymin,
126 xmax,
127 ymax,
128 },
129 extents:
130 Extents {
131 height,
132 width,
133 length,
134 },
135 location: Location { x, y, z },
136 rotation_y,
137 } = from;
138
139 SerializedLabel {
140 frame,
141 track_id: match track_id {
142 Some(track_id) => track_id as i32,
143 None => -1,
144 },
145 class,
146 truncation,
147 occlusion,
148 alpha: alpha.as_radians(),
149 xmin,
150 ymin,
151 xmax,
152 ymax,
153 height: height.as_meters(),
154 width: width.as_meters(),
155 length: length.as_meters(),
156 x: x.as_meters(),
157 y: y.as_meters(),
158 z: z.as_meters(),
159 rotation_y: rotation_y.as_radians(),
160 }
161 }
162}
163
164impl Label {
165 pub fn iter_from_reader<R>(reader: R) -> LabelFromReaderIter<R>
166 where
167 R: Read,
168 {
169 let reader = csv::ReaderBuilder::new()
170 .has_headers(false)
171 .delimiter(b' ')
172 .from_reader(reader);
173 reader.into_deserialize()
174 }
175
176 pub fn iter_from_path<P>(path: P) -> io::Result<LabelFromPathIter>
177 where
178 P: AsRef<Path>,
179 {
180 let reader = BufReader::new(File::open(path)?);
181 Ok(Self::iter_from_reader(reader))
182 }
183
184 pub fn iter_from_str(text: &str) -> LabelFromStrIter<'_> {
185 let reader = Cursor::new(text);
186 Self::iter_from_reader(reader)
187 }
188
189 pub fn vec_from_reader<R>(reader: R) -> Result<Vec<Label>, Error>
190 where
191 R: Read,
192 {
193 let reader = csv::ReaderBuilder::new()
194 .has_headers(false)
195 .delimiter(b' ')
196 .from_reader(reader);
197 let result: Result<Vec<Label>, _> = reader.into_deserialize().collect();
198 Ok(result?)
199 }
200
201 pub fn vec_from_path<P>(path: P) -> Result<Vec<Label>, Error>
202 where
203 P: AsRef<Path>,
204 {
205 let reader = BufReader::new(File::open(path)?);
206 Self::vec_from_reader(reader)
207 }
208
209 pub fn vec_from_str(text: &str) -> Result<Vec<Label>, Error> {
210 let reader = Cursor::new(text);
211 Self::vec_from_reader(reader)
212 }
213
214 pub fn write_to_writer<W, I, A>(writer: W, labels: I) -> io::Result<()>
215 where
216 I: IntoIterator<Item = A>,
217 W: Write,
218 A: Borrow<Label>,
219 {
220 let mut writer = csv::WriterBuilder::new()
221 .has_headers(false)
222 .delimiter(b' ')
223 .from_writer(writer);
224
225 for record in labels {
226 writer.serialize(record.borrow())?;
227 }
228
229 writer.flush()?;
230
231 Ok(())
232 }
233
234 pub fn write_to_path<P, I, A>(path: P, labels: I) -> io::Result<()>
235 where
236 I: IntoIterator<Item = A>,
237 P: AsRef<Path>,
238 A: Borrow<Label>,
239 {
240 let writer = BufWriter::new(File::create(path)?);
241 Self::write_to_writer(writer, labels)
242 }
243
244 pub fn write_to_string<I, A>(labels: I) -> io::Result<String>
245 where
246 I: IntoIterator<Item = A>,
247 A: Borrow<Label>,
248 {
249 let mut buf = vec![];
250 Self::write_to_writer(&mut buf, labels)?;
251 Ok(String::from_utf8(buf).unwrap())
252 }
253}
254
/// Truncation annotation of a label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Truncation {
    /// No truncation value is carried; `as_f64`/`as_f32` return `None`.
    Ignored,
    /// A truncation value. When built via `TryFrom<f64>`, `from_f64` or
    /// `from_f32`, it is guaranteed to lie within `[0.0, 1.0]`.
    Labeled(R64),
}
260
261impl Truncation {
262 pub fn as_f64(&self) -> Option<f64> {
263 match self {
264 Truncation::Ignored => None,
265 Truncation::Labeled(value) => Some(value.raw()),
266 }
267 }
268
269 pub fn as_f32(&self) -> Option<f32> {
270 Some(self.as_f64()? as f32)
271 }
272
273 pub fn from_f64(value: f64) -> Result<Self, Error> {
274 value.try_into()
275 }
276
277 pub fn from_f32(value: f32) -> Result<Self, Error> {
278 (value as f64).try_into()
279 }
280}
281
282impl TryFrom<f64> for Truncation {
283 type Error = Error;
284
285 fn try_from(fval: f64) -> Result<Self, Self::Error> {
286 let error = || Error::InvalidTruncationValue(fval);
287
288 let rval = R64::try_from(fval).map_err(|_| error())?;
289
290 if !(r64(0.0)..=r64(1.0)).contains(&rval) {
291 return Err(error());
292 }
293
294 Ok(Truncation::Labeled(rval))
295 }
296}