/// Languages that [`Language::from_code`] can identify.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Language {
    Spanish,
    English,
    German,
    Italian,
}

impl Language {
    /// Parses a language identifier into a [`Language`].
    ///
    /// Each language is recognised by its two-letter code (`"es"`), its
    /// three-letter code (`"spa"`; German accepts both `"deu"` and `"ger"`)
    /// and its English name (`"spanish"`). Matching ignores ASCII case and
    /// any whitespace surrounding the identifier.
    ///
    /// Returns `None` for every other input, including the empty string.
    pub fn from_code(code: &str) -> Option<Self> {
        // Every accepted spelling is pure ASCII, so an ASCII case-insensitive
        // comparison is enough and no lowercased copy has to be allocated.
        const ALIASES: &[(&str, Language)] = &[
            ("es", Language::Spanish),
            ("spa", Language::Spanish),
            ("spanish", Language::Spanish),
            ("en", Language::English),
            ("eng", Language::English),
            ("english", Language::English),
            ("de", Language::German),
            ("deu", Language::German),
            ("ger", Language::German),
            ("german", Language::German),
            ("it", Language::Italian),
            ("ita", Language::Italian),
            ("italian", Language::Italian),
        ];

        let code = code.trim();
        ALIASES
            .iter()
            .find(|(alias, _)| code.eq_ignore_ascii_case(alias))
            .map(|&(_, language)| language)
    }

    /// Returns the two-letter (ISO 639-1) code of this language.
    ///
    /// The returned string is always accepted by [`Language::from_code`].
    pub fn code(&self) -> &'static str {
        match *self {
            Language::Spanish => "es",
            Language::English => "en",
            Language::German => "de",
            Language::Italian => "it",
        }
    }
}
/// An axis-aligned rectangle described by an anchor point and an extent.
///
/// The rectangle spans `x..=x + width` horizontally and `y..=y + height`
/// vertically; both edges are inclusive (see [`BoundingBox::contains`]).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BoundingBox {
    /// Horizontal coordinate of the anchor corner.
    pub x: f64,
    /// Vertical coordinate of the anchor corner.
    pub y: f64,
    /// Extent along the horizontal axis, measured from `x`.
    pub width: f64,
    /// Extent along the vertical axis, measured from `y`.
    pub height: f64,
}

impl BoundingBox {
    /// Creates a box anchored at `(x, y)` with the given `width` and `height`.
    ///
    /// The values are stored as given; nothing is validated or normalised.
    pub fn new(x: f64, y: f64, width: f64, height: f64) -> Self {
        BoundingBox {
            x,
            y,
            width,
            height,
        }
    }

    /// Reports whether the point `(px, py)` lies inside the box.
    ///
    /// Points exactly on an edge or corner count as inside.
    pub fn contains(&self, px: f64, py: f64) -> bool {
        let horizontal = self.x..=self.x + self.width;
        let vertical = self.y..=self.y + self.height;
        horizontal.contains(&px) && vertical.contains(&py)
    }

    /// Returns `width * height`.
    pub fn area(&self) -> f64 {
        let BoundingBox { width, height, .. } = *self;
        width * height
    }
}
/// A single typed value taken from an invoice.
///
/// Textual values are carried as `String`, numeric ones as `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum InvoiceField {
    InvoiceNumber(String),
    InvoiceDate(String),
    DueDate(String),
    TotalAmount(f64),
    TaxAmount(f64),
    NetAmount(f64),
    VatNumber(String),
    SupplierName(String),
    CustomerName(String),
    Currency(String),
    ArticleNumber(String),
    LineItemDescription(String),
    LineItemQuantity(f64),
    LineItemUnitPrice(f64),
}

impl InvoiceField {
    /// Returns the human-readable label of the variant.
    ///
    /// The label depends only on the variant, never on the value it carries.
    pub fn name(&self) -> &'static str {
        match self {
            Self::InvoiceNumber(..) => "Invoice Number",
            Self::InvoiceDate(..) => "Invoice Date",
            Self::DueDate(..) => "Due Date",
            Self::TotalAmount(..) => "Total Amount",
            Self::TaxAmount(..) => "Tax Amount",
            Self::NetAmount(..) => "Net Amount",
            Self::VatNumber(..) => "VAT Number",
            Self::SupplierName(..) => "Supplier Name",
            Self::CustomerName(..) => "Customer Name",
            Self::Currency(..) => "Currency",
            Self::ArticleNumber(..) => "Article Number",
            Self::LineItemDescription(..) => "Line Item Description",
            Self::LineItemQuantity(..) => "Line Item Quantity",
            Self::LineItemUnitPrice(..) => "Line Item Unit Price",
        }
    }
}
/// One extracted invoice value together with details of where it was found.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedField {
    /// The typed value that was extracted.
    pub field_type: InvoiceField,
    /// Confidence score attached to this extraction.
    // NOTE(review): presumably in 0.0..=1.0, but nothing here enforces a range.
    pub confidence: f64,
    /// Rectangle associated with the value.
    pub position: BoundingBox,
    /// Text stored alongside the typed value.
    pub raw_text: String,
}

impl ExtractedField {
    /// Bundles the given parts into an [`ExtractedField`] without validation.
    pub fn new(
        field_type: InvoiceField,
        confidence: f64,
        position: BoundingBox,
        raw_text: String,
    ) -> Self {
        ExtractedField {
            field_type,
            confidence,
            position,
            raw_text,
        }
    }
}
/// Information about an extraction as a whole rather than a single field.
#[derive(Debug, Clone, PartialEq)]
pub struct InvoiceMetadata {
    /// Page number recorded for the extraction.
    pub page_number: u32,
    /// Confidence score for the extraction as a whole.
    pub extraction_confidence: f64,
    /// Language detected for the document, if any.
    pub detected_language: Option<Language>,
}

impl InvoiceMetadata {
    /// Creates metadata with no detected language.
    ///
    /// Use [`InvoiceMetadata::with_language`] to attach one.
    pub fn new(page_number: u32, extraction_confidence: f64) -> Self {
        InvoiceMetadata {
            page_number,
            extraction_confidence,
            detected_language: None,
        }
    }

    /// Consumes the metadata and returns it with `detected_language` set to
    /// `lang`, replacing any language stored before.
    pub fn with_language(self, lang: Language) -> Self {
        InvoiceMetadata {
            detected_language: Some(lang),
            ..self
        }
    }
}
/// The fields extracted from an invoice plus the metadata describing them.
#[derive(Debug, Clone, PartialEq)]
pub struct InvoiceData {
    /// Extracted fields, in the order they were supplied.
    pub fields: Vec<ExtractedField>,
    /// Metadata for the extraction as a whole.
    pub metadata: InvoiceMetadata,
}

impl InvoiceData {
    /// Wraps the given fields and metadata without modifying either.
    pub fn new(fields: Vec<ExtractedField>, metadata: InvoiceMetadata) -> Self {
        InvoiceData { fields, metadata }
    }

    /// Tells whether `field`'s label (as reported by its `field_type`) equals
    /// `field_name`. The comparison is exact and case-sensitive.
    fn is_named(field: &ExtractedField, field_name: &str) -> bool {
        field.field_type.name() == field_name
    }

    /// Returns every field whose label equals `field_name`, in stored order.
    ///
    /// The result is empty when nothing matches.
    pub fn get_fields(&self, field_name: &str) -> Vec<&ExtractedField> {
        self.fields
            .iter()
            .filter(|candidate| Self::is_named(candidate, field_name))
            .collect()
    }

    /// Returns the first field whose label equals `field_name`, if any.
    pub fn get_field(&self, field_name: &str) -> Option<&ExtractedField> {
        self.fields
            .iter()
            .find(|candidate| Self::is_named(candidate, field_name))
    }

    /// Returns the number of stored fields.
    pub fn field_count(&self) -> usize {
        self.fields.len()
    }

    /// Drops every field whose confidence is below `min_confidence` and
    /// returns the reduced data.
    ///
    /// A field is kept when `confidence >= min_confidence`; a NaN on either
    /// side fails that comparison, so such fields are dropped. When at least
    /// one field survives, `metadata.extraction_confidence` is overwritten
    /// with the mean confidence of the survivors.
    pub fn filter_by_confidence(mut self, min_confidence: f64) -> Self {
        self.fields
            .retain(|field| field.confidence >= min_confidence);

        // With no survivors the mean would be 0/0, so the previous
        // `extraction_confidence` is deliberately left as it was.
        let survivors = self.fields.len();
        if survivors > 0 {
            let total = self
                .fields
                .iter()
                .map(|field| field.confidence)
                .sum::<f64>();
            self.metadata.extraction_confidence = total / survivors as f64;
        }
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a field positioned in a 10x10 box anchored at the origin.
    fn sample_field(field_type: InvoiceField, confidence: f64, raw_text: &str) -> ExtractedField {
        ExtractedField::new(
            field_type,
            confidence,
            BoundingBox::new(0.0, 0.0, 10.0, 10.0),
            raw_text.to_string(),
        )
    }

    /// Builds page-1 invoice data holding an invoice number ("INV-001") and a
    /// total amount (100.0) with the given confidences.
    fn sample_invoice(
        number_confidence: f64,
        total_confidence: f64,
        extraction_confidence: f64,
    ) -> InvoiceData {
        let number = sample_field(
            InvoiceField::InvoiceNumber("INV-001".to_string()),
            number_confidence,
            "INV-001",
        );
        let total = sample_field(InvoiceField::TotalAmount(100.0), total_confidence, "100.00");
        InvoiceData::new(
            vec![number, total],
            InvoiceMetadata::new(1, extraction_confidence),
        )
    }

    #[test]
    fn test_language_from_code() {
        let cases = [
            ("es", Some(Language::Spanish)),
            ("ES", Some(Language::Spanish)),
            ("spanish", Some(Language::Spanish)),
            ("en", Some(Language::English)),
            ("de", Some(Language::German)),
            ("it", Some(Language::Italian)),
            ("fr", None),
            ("invalid", None),
        ];
        for &(code, expected) in &cases {
            assert_eq!(Language::from_code(code), expected, "from_code({:?})", code);
        }
    }

    #[test]
    fn test_language_code() {
        let cases = [
            (Language::Spanish, "es"),
            (Language::English, "en"),
            (Language::German, "de"),
            (Language::Italian, "it"),
        ];
        for &(language, expected) in &cases {
            assert_eq!(language.code(), expected, "{:?}.code()", language);
        }
    }

    #[test]
    fn test_bounding_box_contains() {
        let bbox = BoundingBox::new(10.0, 20.0, 50.0, 30.0);

        // Two opposite corners and an interior point.
        let inside = [(10.0, 20.0), (60.0, 50.0), (35.0, 35.0)];
        for &(px, py) in &inside {
            assert!(bbox.contains(px, py), "({}, {}) should be inside", px, py);
        }

        // One point past each of the four edges.
        let outside = [(5.0, 20.0), (65.0, 35.0), (35.0, 15.0), (35.0, 55.0)];
        for &(px, py) in &outside {
            assert!(!bbox.contains(px, py), "({}, {}) should be outside", px, py);
        }
    }

    #[test]
    fn test_bounding_box_area() {
        let bbox = BoundingBox::new(0.0, 0.0, 10.0, 5.0);
        assert_eq!(bbox.area(), 50.0);
    }

    #[test]
    fn test_invoice_field_name() {
        let number = InvoiceField::InvoiceNumber("INV-001".to_string());
        assert_eq!(number.name(), "Invoice Number");

        let total = InvoiceField::TotalAmount(1234.56);
        assert_eq!(total.name(), "Total Amount");
    }

    #[test]
    fn test_invoice_data_get_field() {
        let data = sample_invoice(0.9, 0.8, 0.85);

        assert_eq!(data.field_count(), 2);
        assert!(data.get_field("Invoice Number").is_some());
        assert!(data.get_field("Total Amount").is_some());
        assert!(data.get_field("Nonexistent").is_none());
    }

    #[test]
    fn test_invoice_data_filter_by_confidence() {
        // Only the invoice number (0.9) clears the 0.7 threshold.
        let filtered = sample_invoice(0.9, 0.5, 0.7).filter_by_confidence(0.7);

        assert_eq!(filtered.field_count(), 1);
        assert!(filtered.get_field("Invoice Number").is_some());
        assert!(filtered.get_field("Total Amount").is_none());
    }
}