oar_ocr_core/domain/predictions.rs
1//! Prediction result types for the OCR pipeline.
2//!
3//! This module defines various types and traits for representing and working with
4//! prediction results in the OCR pipeline. It includes enums for different types
5//! of predictions (detection, recognition, classification, rectification) and
6//! traits for converting between different representations.
7
8use image::RgbImage;
9use serde::{Deserialize, Serialize};
10use std::borrow::Cow;
11use std::sync::Arc;
12
13/// Enum representing different types of prediction results.
14///
15/// This enum is used to represent the results of different types of predictions
16/// in the OCR pipeline, such as text detection, text recognition, image classification,
17/// and image rectification.
18///
19/// # Type Parameters
20///
21/// * `'a` - The lifetime of the borrowed data.
22/// * `I` - The type of the input images.
23#[derive(Debug, Clone)]
24pub enum PredictionResult<'a, I = Arc<RgbImage>> {
25 /// Results from text detection.
26 Detection {
27 /// The input paths of the images.
28 input_path: Vec<Cow<'a, str>>,
29 /// The indices of the images in the batch.
30 index: Vec<usize>,
31 /// The input images.
32 input_img: Vec<I>,
33 /// The detected polygons.
34 dt_polys: Vec<Vec<crate::processors::BoundingBox>>,
35 /// The scores for the detected polygons.
36 dt_scores: Vec<Vec<f32>>,
37 },
38 /// Results from text recognition.
39 Recognition {
40 /// The input paths of the images.
41 input_path: Vec<Cow<'a, str>>,
42 /// The indices of the images in the batch.
43 index: Vec<usize>,
44 /// The input images.
45 input_img: Vec<I>,
46 /// The recognized text.
47 rec_text: Vec<Cow<'a, str>>,
48 /// The scores for the recognized text.
49 rec_score: Vec<f32>,
50 },
51 /// Results from image classification.
52 Classification {
53 /// The input paths of the images.
54 input_path: Vec<Cow<'a, str>>,
55 /// The indices of the images in the batch.
56 index: Vec<usize>,
57 /// The input images.
58 input_img: Vec<I>,
59 /// The class IDs for the classifications.
60 class_ids: Vec<Vec<usize>>,
61 /// The scores for the classifications.
62 scores: Vec<Vec<f32>>,
63 /// The label names for the classifications.
64 label_names: Vec<Vec<Cow<'a, str>>>,
65 },
66 /// Results from image rectification.
67 Rectification {
68 /// The input paths of the images.
69 input_path: Vec<Cow<'a, str>>,
70 /// The indices of the images in the batch.
71 index: Vec<usize>,
72 /// The input images.
73 input_img: Vec<I>,
74 /// The rectified images.
75 rectified_img: Vec<I>,
76 },
77}
78
79/// Enum representing owned prediction results.
80///
81/// This enum is similar to PredictionResult, but uses owned String values instead
82/// of borrowed Cow values. It also implements Serialize and Deserialize traits
83/// for easy serialization and deserialization.
84///
85/// # Type Parameters
86///
87/// * `I` - The type of the input images.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub enum OwnedPredictionResult<I = Arc<RgbImage>> {
90 /// Results from text detection.
91 Detection {
92 /// The input paths of the images.
93 input_path: Vec<String>,
94 /// The indices of the images in the batch.
95 index: Vec<usize>,
96 /// The input images.
97 #[serde(skip)]
98 input_img: Vec<I>,
99 /// The detected polygons.
100 dt_polys: Vec<Vec<crate::processors::BoundingBox>>,
101 /// The scores for the detected polygons.
102 dt_scores: Vec<Vec<f32>>,
103 },
104 /// Results from text recognition.
105 Recognition {
106 /// The input paths of the images.
107 input_path: Vec<String>,
108 /// The indices of the images in the batch.
109 index: Vec<usize>,
110 /// The input images.
111 #[serde(skip)]
112 input_img: Vec<I>,
113 /// The recognized text.
114 rec_text: Vec<String>,
115 /// The scores for the recognized text.
116 rec_score: Vec<f32>,
117 },
118 /// Results from image classification.
119 Classification {
120 /// The input paths of the images.
121 input_path: Vec<String>,
122 /// The indices of the images in the batch.
123 index: Vec<usize>,
124 /// The input images.
125 #[serde(skip)]
126 input_img: Vec<I>,
127 /// The class IDs for the classifications.
128 class_ids: Vec<Vec<usize>>,
129 /// The scores for the classifications.
130 scores: Vec<Vec<f32>>,
131 /// The label names for the classifications.
132 label_names: Vec<Vec<String>>,
133 },
134 /// Results from image rectification.
135 Rectification {
136 /// The input paths of the images.
137 input_path: Vec<String>,
138 /// The indices of the images in the batch.
139 index: Vec<usize>,
140 /// The input images.
141 #[serde(skip)]
142 input_img: Vec<I>,
143 /// The rectified images.
144 #[serde(skip)]
145 rectified_img: Vec<I>,
146 },
147}
148
149/// Implementation of methods for PredictionResult.
150impl<'a, I> PredictionResult<'a, I> {
151 /// Gets the input paths of the images.
152 ///
153 /// # Returns
154 ///
155 /// A slice of the input paths.
156 pub fn input_paths(&self) -> &[Cow<'a, str>] {
157 match self {
158 PredictionResult::Detection { input_path, .. } => input_path,
159 PredictionResult::Recognition { input_path, .. } => input_path,
160 PredictionResult::Classification { input_path, .. } => input_path,
161 PredictionResult::Rectification { input_path, .. } => input_path,
162 }
163 }
164
165 /// Gets the indices of the images in the batch.
166 ///
167 /// # Returns
168 ///
169 /// A slice of the indices.
170 pub fn indices(&self) -> &[usize] {
171 match self {
172 PredictionResult::Detection { index, .. } => index,
173 PredictionResult::Recognition { index, .. } => index,
174 PredictionResult::Classification { index, .. } => index,
175 PredictionResult::Rectification { index, .. } => index,
176 }
177 }
178
179 /// Gets the input images.
180 ///
181 /// # Returns
182 ///
183 /// A slice of the input images.
184 pub fn input_images(&self) -> &[I] {
185 match self {
186 PredictionResult::Detection { input_img, .. } => input_img,
187 PredictionResult::Recognition { input_img, .. } => input_img,
188 PredictionResult::Classification { input_img, .. } => input_img,
189 PredictionResult::Rectification { input_img, .. } => input_img,
190 }
191 }
192
193 /// Checks if the prediction result is a detection result.
194 ///
195 /// # Returns
196 ///
197 /// True if the prediction result is a detection result, false otherwise.
198 pub fn is_detection(&self) -> bool {
199 matches!(self, PredictionResult::Detection { .. })
200 }
201
202 /// Checks if the prediction result is a recognition result.
203 ///
204 /// # Returns
205 ///
206 /// True if the prediction result is a recognition result, false otherwise.
207 pub fn is_recognition(&self) -> bool {
208 matches!(self, PredictionResult::Recognition { .. })
209 }
210
211 /// Checks if the prediction result is a classification result.
212 ///
213 /// # Returns
214 ///
215 /// True if the prediction result is a classification result, false otherwise.
216 pub fn is_classification(&self) -> bool {
217 matches!(self, PredictionResult::Classification { .. })
218 }
219
220 /// Checks if the prediction result is a rectification result.
221 ///
222 /// # Returns
223 ///
224 /// True if the prediction result is a rectification result, false otherwise.
225 pub fn is_rectification(&self) -> bool {
226 matches!(self, PredictionResult::Rectification { .. })
227 }
228
229 /// Converts the prediction result to an owned prediction result.
230 ///
231 /// # Returns
232 ///
233 /// An OwnedPredictionResult with the same data.
234 pub fn into_owned(self) -> OwnedPredictionResult<I> {
235 match self {
236 PredictionResult::Detection {
237 input_path,
238 index,
239 input_img,
240 dt_polys,
241 dt_scores,
242 } => OwnedPredictionResult::Detection {
243 input_path: input_path.into_iter().map(|cow| cow.into_owned()).collect(),
244 index,
245 input_img,
246 dt_polys,
247 dt_scores,
248 },
249 PredictionResult::Recognition {
250 input_path,
251 index,
252 input_img,
253 rec_text,
254 rec_score,
255 } => OwnedPredictionResult::Recognition {
256 input_path: input_path.into_iter().map(|cow| cow.into_owned()).collect(),
257 index,
258 input_img,
259 rec_text: rec_text.into_iter().map(|cow| cow.into_owned()).collect(),
260 rec_score,
261 },
262 PredictionResult::Classification {
263 input_path,
264 index,
265 input_img,
266 class_ids,
267 scores,
268 label_names,
269 } => OwnedPredictionResult::Classification {
270 input_path: input_path.into_iter().map(|cow| cow.into_owned()).collect(),
271 index,
272 input_img,
273 class_ids,
274 scores,
275 label_names: label_names
276 .into_iter()
277 .map(|vec| vec.into_iter().map(|cow| cow.into_owned()).collect())
278 .collect(),
279 },
280 PredictionResult::Rectification {
281 input_path,
282 index,
283 input_img,
284 rectified_img,
285 } => OwnedPredictionResult::Rectification {
286 input_path: input_path.into_iter().map(|cow| cow.into_owned()).collect(),
287 index,
288 input_img,
289 rectified_img,
290 },
291 }
292 }
293}
294
295/// Implementation of methods for OwnedPredictionResult.
296impl<I> OwnedPredictionResult<I> {
297 /// Gets the input paths of the images.
298 ///
299 /// # Returns
300 ///
301 /// A slice of the input paths.
302 pub fn input_paths(&self) -> &[String] {
303 match self {
304 OwnedPredictionResult::Detection { input_path, .. } => input_path,
305 OwnedPredictionResult::Recognition { input_path, .. } => input_path,
306 OwnedPredictionResult::Classification { input_path, .. } => input_path,
307 OwnedPredictionResult::Rectification { input_path, .. } => input_path,
308 }
309 }
310
311 /// Gets the indices of the images in the batch.
312 ///
313 /// # Returns
314 ///
315 /// A slice of the indices.
316 pub fn indices(&self) -> &[usize] {
317 match self {
318 OwnedPredictionResult::Detection { index, .. } => index,
319 OwnedPredictionResult::Recognition { index, .. } => index,
320 OwnedPredictionResult::Classification { index, .. } => index,
321 OwnedPredictionResult::Rectification { index, .. } => index,
322 }
323 }
324
325 /// Gets the input images.
326 ///
327 /// # Returns
328 ///
329 /// A slice of the input images.
330 pub fn input_images(&self) -> &[I] {
331 match self {
332 OwnedPredictionResult::Detection { input_img, .. } => input_img,
333 OwnedPredictionResult::Recognition { input_img, .. } => input_img,
334 OwnedPredictionResult::Classification { input_img, .. } => input_img,
335 OwnedPredictionResult::Rectification { input_img, .. } => input_img,
336 }
337 }
338
339 /// Checks if the prediction result is a detection result.
340 ///
341 /// # Returns
342 ///
343 /// True if the prediction result is a detection result, false otherwise.
344 pub fn is_detection(&self) -> bool {
345 matches!(self, OwnedPredictionResult::Detection { .. })
346 }
347
348 /// Checks if the prediction result is a recognition result.
349 ///
350 /// # Returns
351 ///
352 /// True if the prediction result is a recognition result, false otherwise.
353 pub fn is_recognition(&self) -> bool {
354 matches!(self, OwnedPredictionResult::Recognition { .. })
355 }
356
357 /// Checks if the prediction result is a classification result.
358 ///
359 /// # Returns
360 ///
361 /// True if the prediction result is a classification result, false otherwise.
362 pub fn is_classification(&self) -> bool {
363 matches!(self, OwnedPredictionResult::Classification { .. })
364 }
365
366 /// Checks if the prediction result is a rectification result.
367 ///
368 /// # Returns
369 ///
370 /// True if the prediction result is a rectification result, false otherwise.
371 pub fn is_rectification(&self) -> bool {
372 matches!(self, OwnedPredictionResult::Rectification { .. })
373 }
374
375 /// Converts the owned prediction result to a borrowed prediction result.
376 ///
377 /// # Returns
378 ///
379 /// A PredictionResult with borrowed data.
380 pub fn as_prediction_result(&self) -> PredictionResult<'_, &I> {
381 match self {
382 OwnedPredictionResult::Detection {
383 input_path,
384 index,
385 input_img,
386 dt_polys,
387 dt_scores,
388 } => PredictionResult::Detection {
389 input_path: input_path
390 .iter()
391 .map(|s| Cow::Borrowed(s.as_str()))
392 .collect(),
393 index: index.clone(),
394 input_img: input_img.iter().collect(),
395 dt_polys: dt_polys.clone(),
396 dt_scores: dt_scores.clone(),
397 },
398 OwnedPredictionResult::Recognition {
399 input_path,
400 index,
401 input_img,
402 rec_text,
403 rec_score,
404 } => PredictionResult::Recognition {
405 input_path: input_path
406 .iter()
407 .map(|s| Cow::Borrowed(s.as_str()))
408 .collect(),
409 index: index.clone(),
410 input_img: input_img.iter().collect(),
411 rec_text: rec_text.iter().map(|s| Cow::Borrowed(s.as_str())).collect(),
412 rec_score: rec_score.clone(),
413 },
414 OwnedPredictionResult::Classification {
415 input_path,
416 index,
417 input_img,
418 class_ids,
419 scores,
420 label_names,
421 } => PredictionResult::Classification {
422 input_path: input_path
423 .iter()
424 .map(|s| Cow::Borrowed(s.as_str()))
425 .collect(),
426 index: index.clone(),
427 input_img: input_img.iter().collect(),
428 class_ids: class_ids.clone(),
429 scores: scores.clone(),
430 label_names: label_names
431 .iter()
432 .map(|vec| vec.iter().map(|s| Cow::Borrowed(s.as_str())).collect())
433 .collect(),
434 },
435 OwnedPredictionResult::Rectification {
436 input_path,
437 index,
438 input_img,
439 rectified_img,
440 } => PredictionResult::Rectification {
441 input_path: input_path
442 .iter()
443 .map(|s| Cow::Borrowed(s.as_str()))
444 .collect(),
445 index: index.clone(),
446 input_img: input_img.iter().collect(),
447 rectified_img: rectified_img.iter().collect(),
448 },
449 }
450 }
451}
452
/// Conversion of a value into a prediction result.
///
/// The implementor picks the concrete result type through the associated
/// type `Out`.
pub trait IntoPrediction {
    /// Type produced by `into_prediction`.
    type Out;

    /// Consumes `self` and converts it into a prediction result.
    ///
    /// # Returns
    ///
    /// The prediction result, of type `Self::Out`.
    fn into_prediction(self) -> Self::Out;
}
466
/// Conversion of a value into an owned prediction result.
///
/// The implementor picks the concrete result type through the associated
/// type `Out`.
pub trait IntoOwnedPrediction {
    /// Type produced by `into_owned_prediction`.
    type Out;

    /// Consumes `self` and converts it into an owned prediction result.
    ///
    /// # Returns
    ///
    /// The owned prediction result, of type `Self::Out`.
    fn into_owned_prediction(self) -> Self::Out;
}
480
481/// Implementation of IntoOwnedPrediction for types that implement IntoPrediction.
482///
483/// This implementation allows types that implement IntoPrediction to be converted
484/// into owned prediction results.
485impl<T> IntoOwnedPrediction for T
486where
487 T: IntoPrediction,
488 T::Out: Into<OwnedPredictionResult>,
489{
490 type Out = OwnedPredictionResult;
491
492 fn into_owned_prediction(self) -> Self::Out {
493 self.into_prediction().into()
494 }
495}
496
497/// Implementation of From for converting PredictionResult to OwnedPredictionResult.
498///
499/// This implementation allows PredictionResult to be converted to OwnedPredictionResult.
500impl<I> From<PredictionResult<'_, I>> for OwnedPredictionResult<I> {
501 fn from(result: PredictionResult<'_, I>) -> Self {
502 result.into_owned()
503 }
504}