Skip to main content

oar_ocr_core/domain/
predictions.rs

1//! Prediction result types for the OCR pipeline.
2//!
3//! This module defines various types and traits for representing and working with
4//! prediction results in the OCR pipeline. It includes enums for different types
5//! of predictions (detection, recognition, classification, rectification) and
6//! traits for converting between different representations.
7
8use image::RgbImage;
9use serde::{Deserialize, Serialize};
10use std::borrow::Cow;
11use std::sync::Arc;
12
13/// Enum representing different types of prediction results.
14///
15/// This enum is used to represent the results of different types of predictions
16/// in the OCR pipeline, such as text detection, text recognition, image classification,
17/// and image rectification.
18///
19/// # Type Parameters
20///
21/// * `'a` - The lifetime of the borrowed data.
22/// * `I` - The type of the input images.
23#[derive(Debug, Clone)]
24pub enum PredictionResult<'a, I = Arc<RgbImage>> {
25    /// Results from text detection.
26    Detection {
27        /// The input paths of the images.
28        input_path: Vec<Cow<'a, str>>,
29        /// The indices of the images in the batch.
30        index: Vec<usize>,
31        /// The input images.
32        input_img: Vec<I>,
33        /// The detected polygons.
34        dt_polys: Vec<Vec<crate::processors::BoundingBox>>,
35        /// The scores for the detected polygons.
36        dt_scores: Vec<Vec<f32>>,
37    },
38    /// Results from text recognition.
39    Recognition {
40        /// The input paths of the images.
41        input_path: Vec<Cow<'a, str>>,
42        /// The indices of the images in the batch.
43        index: Vec<usize>,
44        /// The input images.
45        input_img: Vec<I>,
46        /// The recognized text.
47        rec_text: Vec<Cow<'a, str>>,
48        /// The scores for the recognized text.
49        rec_score: Vec<f32>,
50    },
51    /// Results from image classification.
52    Classification {
53        /// The input paths of the images.
54        input_path: Vec<Cow<'a, str>>,
55        /// The indices of the images in the batch.
56        index: Vec<usize>,
57        /// The input images.
58        input_img: Vec<I>,
59        /// The class IDs for the classifications.
60        class_ids: Vec<Vec<usize>>,
61        /// The scores for the classifications.
62        scores: Vec<Vec<f32>>,
63        /// The label names for the classifications.
64        label_names: Vec<Vec<Cow<'a, str>>>,
65    },
66    /// Results from image rectification.
67    Rectification {
68        /// The input paths of the images.
69        input_path: Vec<Cow<'a, str>>,
70        /// The indices of the images in the batch.
71        index: Vec<usize>,
72        /// The input images.
73        input_img: Vec<I>,
74        /// The rectified images.
75        rectified_img: Vec<I>,
76    },
77}
78
79/// Enum representing owned prediction results.
80///
81/// This enum is similar to PredictionResult, but uses owned String values instead
82/// of borrowed Cow values. It also implements Serialize and Deserialize traits
83/// for easy serialization and deserialization.
84///
85/// # Type Parameters
86///
87/// * `I` - The type of the input images.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub enum OwnedPredictionResult<I = Arc<RgbImage>> {
90    /// Results from text detection.
91    Detection {
92        /// The input paths of the images.
93        input_path: Vec<String>,
94        /// The indices of the images in the batch.
95        index: Vec<usize>,
96        /// The input images.
97        #[serde(skip)]
98        input_img: Vec<I>,
99        /// The detected polygons.
100        dt_polys: Vec<Vec<crate::processors::BoundingBox>>,
101        /// The scores for the detected polygons.
102        dt_scores: Vec<Vec<f32>>,
103    },
104    /// Results from text recognition.
105    Recognition {
106        /// The input paths of the images.
107        input_path: Vec<String>,
108        /// The indices of the images in the batch.
109        index: Vec<usize>,
110        /// The input images.
111        #[serde(skip)]
112        input_img: Vec<I>,
113        /// The recognized text.
114        rec_text: Vec<String>,
115        /// The scores for the recognized text.
116        rec_score: Vec<f32>,
117    },
118    /// Results from image classification.
119    Classification {
120        /// The input paths of the images.
121        input_path: Vec<String>,
122        /// The indices of the images in the batch.
123        index: Vec<usize>,
124        /// The input images.
125        #[serde(skip)]
126        input_img: Vec<I>,
127        /// The class IDs for the classifications.
128        class_ids: Vec<Vec<usize>>,
129        /// The scores for the classifications.
130        scores: Vec<Vec<f32>>,
131        /// The label names for the classifications.
132        label_names: Vec<Vec<String>>,
133    },
134    /// Results from image rectification.
135    Rectification {
136        /// The input paths of the images.
137        input_path: Vec<String>,
138        /// The indices of the images in the batch.
139        index: Vec<usize>,
140        /// The input images.
141        #[serde(skip)]
142        input_img: Vec<I>,
143        /// The rectified images.
144        #[serde(skip)]
145        rectified_img: Vec<I>,
146    },
147}
148
149/// Implementation of methods for PredictionResult.
150impl<'a, I> PredictionResult<'a, I> {
151    /// Gets the input paths of the images.
152    ///
153    /// # Returns
154    ///
155    /// A slice of the input paths.
156    pub fn input_paths(&self) -> &[Cow<'a, str>] {
157        match self {
158            PredictionResult::Detection { input_path, .. } => input_path,
159            PredictionResult::Recognition { input_path, .. } => input_path,
160            PredictionResult::Classification { input_path, .. } => input_path,
161            PredictionResult::Rectification { input_path, .. } => input_path,
162        }
163    }
164
165    /// Gets the indices of the images in the batch.
166    ///
167    /// # Returns
168    ///
169    /// A slice of the indices.
170    pub fn indices(&self) -> &[usize] {
171        match self {
172            PredictionResult::Detection { index, .. } => index,
173            PredictionResult::Recognition { index, .. } => index,
174            PredictionResult::Classification { index, .. } => index,
175            PredictionResult::Rectification { index, .. } => index,
176        }
177    }
178
179    /// Gets the input images.
180    ///
181    /// # Returns
182    ///
183    /// A slice of the input images.
184    pub fn input_images(&self) -> &[I] {
185        match self {
186            PredictionResult::Detection { input_img, .. } => input_img,
187            PredictionResult::Recognition { input_img, .. } => input_img,
188            PredictionResult::Classification { input_img, .. } => input_img,
189            PredictionResult::Rectification { input_img, .. } => input_img,
190        }
191    }
192
193    /// Checks if the prediction result is a detection result.
194    ///
195    /// # Returns
196    ///
197    /// True if the prediction result is a detection result, false otherwise.
198    pub fn is_detection(&self) -> bool {
199        matches!(self, PredictionResult::Detection { .. })
200    }
201
202    /// Checks if the prediction result is a recognition result.
203    ///
204    /// # Returns
205    ///
206    /// True if the prediction result is a recognition result, false otherwise.
207    pub fn is_recognition(&self) -> bool {
208        matches!(self, PredictionResult::Recognition { .. })
209    }
210
211    /// Checks if the prediction result is a classification result.
212    ///
213    /// # Returns
214    ///
215    /// True if the prediction result is a classification result, false otherwise.
216    pub fn is_classification(&self) -> bool {
217        matches!(self, PredictionResult::Classification { .. })
218    }
219
220    /// Checks if the prediction result is a rectification result.
221    ///
222    /// # Returns
223    ///
224    /// True if the prediction result is a rectification result, false otherwise.
225    pub fn is_rectification(&self) -> bool {
226        matches!(self, PredictionResult::Rectification { .. })
227    }
228
229    /// Converts the prediction result to an owned prediction result.
230    ///
231    /// # Returns
232    ///
233    /// An OwnedPredictionResult with the same data.
234    pub fn into_owned(self) -> OwnedPredictionResult<I> {
235        match self {
236            PredictionResult::Detection {
237                input_path,
238                index,
239                input_img,
240                dt_polys,
241                dt_scores,
242            } => OwnedPredictionResult::Detection {
243                input_path: input_path.into_iter().map(|cow| cow.into_owned()).collect(),
244                index,
245                input_img,
246                dt_polys,
247                dt_scores,
248            },
249            PredictionResult::Recognition {
250                input_path,
251                index,
252                input_img,
253                rec_text,
254                rec_score,
255            } => OwnedPredictionResult::Recognition {
256                input_path: input_path.into_iter().map(|cow| cow.into_owned()).collect(),
257                index,
258                input_img,
259                rec_text: rec_text.into_iter().map(|cow| cow.into_owned()).collect(),
260                rec_score,
261            },
262            PredictionResult::Classification {
263                input_path,
264                index,
265                input_img,
266                class_ids,
267                scores,
268                label_names,
269            } => OwnedPredictionResult::Classification {
270                input_path: input_path.into_iter().map(|cow| cow.into_owned()).collect(),
271                index,
272                input_img,
273                class_ids,
274                scores,
275                label_names: label_names
276                    .into_iter()
277                    .map(|vec| vec.into_iter().map(|cow| cow.into_owned()).collect())
278                    .collect(),
279            },
280            PredictionResult::Rectification {
281                input_path,
282                index,
283                input_img,
284                rectified_img,
285            } => OwnedPredictionResult::Rectification {
286                input_path: input_path.into_iter().map(|cow| cow.into_owned()).collect(),
287                index,
288                input_img,
289                rectified_img,
290            },
291        }
292    }
293}
294
295/// Implementation of methods for OwnedPredictionResult.
296impl<I> OwnedPredictionResult<I> {
297    /// Gets the input paths of the images.
298    ///
299    /// # Returns
300    ///
301    /// A slice of the input paths.
302    pub fn input_paths(&self) -> &[String] {
303        match self {
304            OwnedPredictionResult::Detection { input_path, .. } => input_path,
305            OwnedPredictionResult::Recognition { input_path, .. } => input_path,
306            OwnedPredictionResult::Classification { input_path, .. } => input_path,
307            OwnedPredictionResult::Rectification { input_path, .. } => input_path,
308        }
309    }
310
311    /// Gets the indices of the images in the batch.
312    ///
313    /// # Returns
314    ///
315    /// A slice of the indices.
316    pub fn indices(&self) -> &[usize] {
317        match self {
318            OwnedPredictionResult::Detection { index, .. } => index,
319            OwnedPredictionResult::Recognition { index, .. } => index,
320            OwnedPredictionResult::Classification { index, .. } => index,
321            OwnedPredictionResult::Rectification { index, .. } => index,
322        }
323    }
324
325    /// Gets the input images.
326    ///
327    /// # Returns
328    ///
329    /// A slice of the input images.
330    pub fn input_images(&self) -> &[I] {
331        match self {
332            OwnedPredictionResult::Detection { input_img, .. } => input_img,
333            OwnedPredictionResult::Recognition { input_img, .. } => input_img,
334            OwnedPredictionResult::Classification { input_img, .. } => input_img,
335            OwnedPredictionResult::Rectification { input_img, .. } => input_img,
336        }
337    }
338
339    /// Checks if the prediction result is a detection result.
340    ///
341    /// # Returns
342    ///
343    /// True if the prediction result is a detection result, false otherwise.
344    pub fn is_detection(&self) -> bool {
345        matches!(self, OwnedPredictionResult::Detection { .. })
346    }
347
348    /// Checks if the prediction result is a recognition result.
349    ///
350    /// # Returns
351    ///
352    /// True if the prediction result is a recognition result, false otherwise.
353    pub fn is_recognition(&self) -> bool {
354        matches!(self, OwnedPredictionResult::Recognition { .. })
355    }
356
357    /// Checks if the prediction result is a classification result.
358    ///
359    /// # Returns
360    ///
361    /// True if the prediction result is a classification result, false otherwise.
362    pub fn is_classification(&self) -> bool {
363        matches!(self, OwnedPredictionResult::Classification { .. })
364    }
365
366    /// Checks if the prediction result is a rectification result.
367    ///
368    /// # Returns
369    ///
370    /// True if the prediction result is a rectification result, false otherwise.
371    pub fn is_rectification(&self) -> bool {
372        matches!(self, OwnedPredictionResult::Rectification { .. })
373    }
374
375    /// Converts the owned prediction result to a borrowed prediction result.
376    ///
377    /// # Returns
378    ///
379    /// A PredictionResult with borrowed data.
380    pub fn as_prediction_result(&self) -> PredictionResult<'_, &I> {
381        match self {
382            OwnedPredictionResult::Detection {
383                input_path,
384                index,
385                input_img,
386                dt_polys,
387                dt_scores,
388            } => PredictionResult::Detection {
389                input_path: input_path
390                    .iter()
391                    .map(|s| Cow::Borrowed(s.as_str()))
392                    .collect(),
393                index: index.clone(),
394                input_img: input_img.iter().collect(),
395                dt_polys: dt_polys.clone(),
396                dt_scores: dt_scores.clone(),
397            },
398            OwnedPredictionResult::Recognition {
399                input_path,
400                index,
401                input_img,
402                rec_text,
403                rec_score,
404            } => PredictionResult::Recognition {
405                input_path: input_path
406                    .iter()
407                    .map(|s| Cow::Borrowed(s.as_str()))
408                    .collect(),
409                index: index.clone(),
410                input_img: input_img.iter().collect(),
411                rec_text: rec_text.iter().map(|s| Cow::Borrowed(s.as_str())).collect(),
412                rec_score: rec_score.clone(),
413            },
414            OwnedPredictionResult::Classification {
415                input_path,
416                index,
417                input_img,
418                class_ids,
419                scores,
420                label_names,
421            } => PredictionResult::Classification {
422                input_path: input_path
423                    .iter()
424                    .map(|s| Cow::Borrowed(s.as_str()))
425                    .collect(),
426                index: index.clone(),
427                input_img: input_img.iter().collect(),
428                class_ids: class_ids.clone(),
429                scores: scores.clone(),
430                label_names: label_names
431                    .iter()
432                    .map(|vec| vec.iter().map(|s| Cow::Borrowed(s.as_str())).collect())
433                    .collect(),
434            },
435            OwnedPredictionResult::Rectification {
436                input_path,
437                index,
438                input_img,
439                rectified_img,
440            } => PredictionResult::Rectification {
441                input_path: input_path
442                    .iter()
443                    .map(|s| Cow::Borrowed(s.as_str()))
444                    .collect(),
445                index: index.clone(),
446                input_img: input_img.iter().collect(),
447                rectified_img: rectified_img.iter().collect(),
448            },
449        }
450    }
451}
452
/// Conversion of a value into a prediction result.
///
/// Implementors pick the concrete result type through the `Out` associated
/// type; the trait itself places no bounds on it.
pub trait IntoPrediction {
    /// Type produced by the conversion.
    type Out;
    /// Consumes `self` and produces the prediction result.
    ///
    /// # Returns
    ///
    /// The converted value, of type `Self::Out`.
    fn into_prediction(self) -> Self::Out;
}
466
/// Conversion of a value into an owned prediction result.
///
/// Implementors pick the concrete result type through the `Out` associated
/// type; the trait itself places no bounds on it.
pub trait IntoOwnedPrediction {
    /// Type produced by the conversion.
    type Out;
    /// Consumes `self` and produces the owned prediction result.
    ///
    /// # Returns
    ///
    /// The converted value, of type `Self::Out`.
    fn into_owned_prediction(self) -> Self::Out;
}
480
481/// Implementation of IntoOwnedPrediction for types that implement IntoPrediction.
482///
483/// This implementation allows types that implement IntoPrediction to be converted
484/// into owned prediction results.
485impl<T> IntoOwnedPrediction for T
486where
487    T: IntoPrediction,
488    T::Out: Into<OwnedPredictionResult>,
489{
490    type Out = OwnedPredictionResult;
491
492    fn into_owned_prediction(self) -> Self::Out {
493        self.into_prediction().into()
494    }
495}
496
497/// Implementation of From for converting PredictionResult to OwnedPredictionResult.
498///
499/// This implementation allows PredictionResult to be converted to OwnedPredictionResult.
500impl<I> From<PredictionResult<'_, I>> for OwnedPredictionResult<I> {
501    fn from(result: PredictionResult<'_, I>) -> Self {
502        result.into_owned()
503    }
504}