use super::error::{ExtractionError, Result};
use super::patterns::{InvoiceFieldType, PatternLibrary};
use super::types::{
BoundingBox, ExtractedField, InvoiceData, InvoiceField, InvoiceMetadata, Language,
};
use super::validators;
use crate::text::extraction::TextFragment;
/// Extracts structured invoice fields from text fragments.
///
/// Instances are configured and created through [`InvoiceExtractor::builder`].
pub struct InvoiceExtractor {
/// Pattern set queried by `extract` to find candidate field matches in the
/// reconstructed text.
pattern_library: PatternLibrary,
/// Minimum confidence a match must reach; `extract` skips matches that score
/// below this value.
confidence_threshold: f64,
/// Selects which joining path `reconstruct_text` takes when merging fragments.
use_kerning: bool,
/// Configured language, if any. Read by `parse_amount` to pick the number
/// format and reported in the result metadata (English when `None`).
language: Option<Language>,
}
impl InvoiceExtractor {
/// Returns a builder pre-populated with the default extractor settings.
pub fn builder() -> InvoiceExtractorBuilder {
    InvoiceExtractorBuilder::default()
}
/// Extracts invoice fields from `text_fragments`.
///
/// The fragments are merged into one string, the pattern library is run over
/// it, and every match is re-scored with `calculate_confidence`. Matches that
/// score below `confidence_threshold`, or that cannot be converted into an
/// `InvoiceField`, are left out of the result.
///
/// The metadata confidence is the arithmetic mean of the kept fields'
/// confidences (`0.0` when no field was kept), and the metadata language is
/// the configured one, or English when none was configured.
///
/// # Errors
///
/// Returns `ExtractionError::NoTextFound(1)` when `text_fragments` is empty.
pub fn extract(&self, text_fragments: &[TextFragment]) -> Result<InvoiceData> {
    if text_fragments.is_empty() {
        return Err(ExtractionError::NoTextFound(1));
    }

    let full_text = self.reconstruct_text(text_fragments);

    let fields: Vec<_> = self
        .pattern_library
        .match_text(&full_text)
        .into_iter()
        .filter_map(|(field_type, matched_value, base_confidence)| {
            let confidence = self.calculate_confidence(
                &field_type,
                base_confidence,
                &matched_value,
                &full_text,
            );
            if confidence < self.confidence_threshold {
                return None;
            }

            let position = self.find_match_position(&matched_value, text_fragments);
            self.convert_to_invoice_field(field_type, &matched_value)
                .map(|invoice_field| {
                    ExtractedField::new(invoice_field, confidence, position, matched_value)
                })
        })
        .collect();

    // Mean confidence over the kept fields; an empty result scores zero.
    let overall_confidence = match fields.len() {
        0 => 0.0,
        count => fields.iter().map(|field| field.confidence).sum::<f64>() / count as f64,
    };

    let language = self.language.unwrap_or(Language::English);
    let metadata = InvoiceMetadata::new(1, overall_confidence).with_language(language);

    Ok(InvoiceData::new(fields, metadata))
}
/// Extracts invoice fields from a plain string.
///
/// The text is wrapped in a single placeholder `TextFragment` positioned at
/// the origin (zero width, height and font size of `12.0`, no font name or
/// color, not bold or italic) and handed to [`InvoiceExtractor::extract`].
///
/// # Errors
///
/// Returns `ExtractionError::NoTextFound(1)` when `text` is empty, and
/// otherwise whatever `extract` returns.
pub fn extract_from_text(&self, text: &str) -> Result<InvoiceData> {
    if text.is_empty() {
        return Err(ExtractionError::NoTextFound(1));
    }

    let synthetic = TextFragment {
        text: String::from(text),
        x: 0.0,
        y: 0.0,
        width: 0.0,
        height: 12.0,
        font_size: 12.0,
        font_name: None,
        is_bold: false,
        is_italic: false,
        color: None,
        space_decisions: Vec::new(),
    };

    self.extract(std::slice::from_ref(&synthetic))
}
/// Merges the fragment texts, in order, into one string.
///
/// With `use_kerning` disabled the texts are joined with a single space.
/// With it enabled the fragments are walked pairwise and a separator is
/// chosen from the two neighbours' font names; both outcomes of that choice
/// are currently the same single space.
///
/// Returns an empty string for an empty slice.
fn reconstruct_text(&self, fragments: &[TextFragment]) -> String {
    if !self.use_kerning {
        let pieces: Vec<&str> = fragments
            .iter()
            .map(|fragment| fragment.text.as_str())
            .collect();
        return pieces.join(" ");
    }

    // One extra byte per fragment leaves room for the separators.
    let capacity: usize = fragments
        .iter()
        .map(|fragment| fragment.text.len() + 1)
        .sum();
    let mut merged = String::with_capacity(capacity);

    let mut remaining = fragments.iter().peekable();
    while let Some(current) = remaining.next() {
        merged.push_str(&current.text);

        // A separator is only emitted between two fragments, never after the last.
        if let Some(following) = remaining.peek() {
            let separator = match (&current.font_name, &following.font_name) {
                (Some(left), Some(right)) if left == right => " ",
                _ => " ",
            };
            merged.push_str(separator);
        }
    }

    merged
}
/// Parses a formatted amount such as `1,234.56` or `1.234,56` into an `f64`.
///
/// When the configured language is Spanish, German or Italian the value is
/// read in European notation: `.` is dropped as a grouping separator and `,`
/// is treated as the decimal point. For every other language (or none) `,`
/// is dropped as a grouping separator.
///
/// Leading and trailing whitespace is ignored. Returns `None` when the
/// normalized text is not a number, or when it parses to a non-finite value.
fn parse_amount(&self, value: &str) -> Option<f64> {
    let uses_european_format = matches!(
        self.language,
        Some(Language::Spanish) | Some(Language::German) | Some(Language::Italian)
    );

    // Single pass: drop the grouping separator and map the decimal separator to '.'.
    let normalized: String = value
        .trim()
        .chars()
        .filter_map(|ch| match (ch, uses_european_format) {
            ('.', true) | (',', false) => None,
            (',', true) => Some('.'),
            _ => Some(ch),
        })
        .collect();

    // `f64::from_str` also accepts spellings like "inf" and "nan"; those are
    // never valid monetary amounts, so they are rejected here.
    normalized
        .parse::<f64>()
        .ok()
        .filter(|amount| amount.is_finite())
}
/// Computes the final confidence for one pattern match.
///
/// The result is `base_confidence` plus the adjustment returned by the
/// validator for the field's kind (dates, amounts, invoice numbers and VAT
/// numbers have validators; every other kind contributes `0.0`) plus the
/// keyword proximity bonus, clamped to the range `0.0..=1.0`.
fn calculate_confidence(
    &self,
    field_type: &InvoiceFieldType,
    base_confidence: f64,
    matched_value: &str,
    full_text: &str,
) -> f64 {
    let validation_adjustment = match field_type {
        InvoiceFieldType::InvoiceNumber => validators::validate_invoice_number(matched_value),
        InvoiceFieldType::VatNumber => validators::validate_vat_number(matched_value),
        InvoiceFieldType::InvoiceDate | InvoiceFieldType::DueDate => {
            validators::validate_date(matched_value)
        }
        InvoiceFieldType::TotalAmount
        | InvoiceFieldType::TaxAmount
        | InvoiceFieldType::NetAmount
        | InvoiceFieldType::LineItemUnitPrice => validators::validate_amount(matched_value),
        // No validator exists for these kinds, so the base score is left as is.
        InvoiceFieldType::SupplierName
        | InvoiceFieldType::CustomerName
        | InvoiceFieldType::Currency
        | InvoiceFieldType::ArticleNumber
        | InvoiceFieldType::LineItemDescription
        | InvoiceFieldType::LineItemQuantity => 0.0,
    };

    let proximity_bonus = self.calculate_proximity_bonus(field_type, matched_value, full_text);

    (base_confidence + validation_adjustment + proximity_bonus).clamp(0.0, 1.0)
}
/// Returns a confidence bonus based on how close `matched_value` sits to a
/// keyword that typically labels a field of kind `field_type`.
///
/// The first occurrence of `matched_value` in `full_text` is located, then
/// the nearest case-insensitive occurrence of any keyword for the field kind
/// is searched for (keywords match as substrings). The distance is the gap
/// between the two start offsets, measured in bytes of the lowercased text:
///
/// * `0..=20`   -> `0.15`
/// * `21..=50`  -> `0.10`
/// * `51..=100` -> `0.05`
/// * anything further, or no keyword found -> `0.0`
///
/// Returns `0.0` for field kinds without keywords and when `matched_value`
/// does not occur in `full_text`.
fn calculate_proximity_bonus(
    &self,
    field_type: &InvoiceFieldType,
    matched_value: &str,
    full_text: &str,
) -> f64 {
    let keywords: &[&str] = match field_type {
        InvoiceFieldType::InvoiceNumber => &[
            "Invoice", "Factura", "Rechnung", "Fattura", "Number", "Número", "Nr",
        ],
        InvoiceFieldType::InvoiceDate => &["Date", "Fecha", "Datum", "Data", "Invoice Date"],
        InvoiceFieldType::DueDate => &["Due", "Vencimiento", "Fällig", "Scadenza", "Payment"],
        InvoiceFieldType::TotalAmount => &[
            "Total",
            "Grand Total",
            "Amount Due",
            "Gesamtbetrag",
            "Totale",
        ],
        InvoiceFieldType::TaxAmount => &["VAT", "IVA", "MwSt", "Tax", "Impuesto"],
        InvoiceFieldType::NetAmount => &[
            "Subtotal",
            "Net",
            "Neto",
            "Nettobetrag",
            "Imponibile",
            "Base",
        ],
        InvoiceFieldType::VatNumber => &["VAT", "CIF", "NIF", "USt", "Partita IVA", "Tax ID"],
        InvoiceFieldType::CustomerName => &["Bill to", "Customer", "Client", "Cliente"],
        InvoiceFieldType::SupplierName => &["From", "Supplier", "Vendor", "Proveedor"],
        // These kinds have no labelling keywords, so they never earn a bonus.
        InvoiceFieldType::Currency
        | InvoiceFieldType::ArticleNumber
        | InvoiceFieldType::LineItemDescription
        | InvoiceFieldType::LineItemQuantity
        | InvoiceFieldType::LineItemUnitPrice => return 0.0,
    };

    let match_pos = match full_text.find(matched_value) {
        Some(pos) => pos,
        None => return 0.0,
    };

    // Lowercase the text once instead of once per keyword.
    let text_lower = full_text.to_lowercase();

    // Keyword offsets are found in the lowercased text, whose byte offsets can
    // differ from the original when a character changes length on lowercasing.
    // Lowercasing the prefix alone moves the match offset into the same space.
    let match_pos_lower = full_text[..match_pos].to_lowercase().len();

    // Measure against every occurrence of every keyword so that an early,
    // unrelated hit (e.g. "total" inside "Subtotal") cannot hide a closer one.
    let mut min_distance = usize::MAX;
    for keyword in keywords {
        let keyword_lower = keyword.to_lowercase();
        for (keyword_pos, _) in text_lower.match_indices(keyword_lower.as_str()) {
            let distance = keyword_pos.max(match_pos_lower) - keyword_pos.min(match_pos_lower);
            min_distance = min_distance.min(distance);
        }
    }

    match min_distance {
        0..=20 => 0.15,
        21..=50 => 0.10,
        51..=100 => 0.05,
        _ => 0.0,
    }
}
/// Picks a bounding box for `matched_value`.
///
/// Uses the geometry of the first fragment whose text contains
/// `matched_value`. When no single fragment contains it, the first
/// fragment's geometry is used instead; with no fragments at all the box is
/// all zeros.
fn find_match_position(&self, matched_value: &str, fragments: &[TextFragment]) -> BoundingBox {
    fragments
        .iter()
        .find(|candidate| candidate.text.contains(matched_value))
        .or_else(|| fragments.first())
        .map(|source| BoundingBox::new(source.x, source.y, source.width, source.height))
        .unwrap_or_else(|| BoundingBox::new(0.0, 0.0, 0.0, 0.0))
}
/// Converts a matched string into the typed `InvoiceField` for `field_type`.
///
/// Textual kinds carry an owned copy of `value`. Numeric kinds (total, tax
/// and net amounts, line item quantity and unit price) go through
/// `parse_amount`, and the function returns `None` when that parse fails.
fn convert_to_invoice_field(
    &self,
    field_type: InvoiceFieldType,
    value: &str,
) -> Option<InvoiceField> {
    let text = || value.to_string();
    let amount = || self.parse_amount(value);

    let field = match field_type {
        // Textual fields.
        InvoiceFieldType::InvoiceNumber => InvoiceField::InvoiceNumber(text()),
        InvoiceFieldType::InvoiceDate => InvoiceField::InvoiceDate(text()),
        InvoiceFieldType::DueDate => InvoiceField::DueDate(text()),
        InvoiceFieldType::VatNumber => InvoiceField::VatNumber(text()),
        InvoiceFieldType::SupplierName => InvoiceField::SupplierName(text()),
        InvoiceFieldType::CustomerName => InvoiceField::CustomerName(text()),
        InvoiceFieldType::Currency => InvoiceField::Currency(text()),
        InvoiceFieldType::ArticleNumber => InvoiceField::ArticleNumber(text()),
        InvoiceFieldType::LineItemDescription => InvoiceField::LineItemDescription(text()),
        // Numeric fields; `?` bails out with `None` when the number is unparsable.
        InvoiceFieldType::TotalAmount => InvoiceField::TotalAmount(amount()?),
        InvoiceFieldType::TaxAmount => InvoiceField::TaxAmount(amount()?),
        InvoiceFieldType::NetAmount => InvoiceField::NetAmount(amount()?),
        InvoiceFieldType::LineItemQuantity => InvoiceField::LineItemQuantity(amount()?),
        InvoiceFieldType::LineItemUnitPrice => InvoiceField::LineItemUnitPrice(amount()?),
    };

    Some(field)
}
}
/// Builder for [`InvoiceExtractor`].
///
/// Collects the optional settings and produces the extractor in `build`.
pub struct InvoiceExtractorBuilder {
/// Language handed to the extractor. `build` also uses it to pick the
/// pattern library when no custom patterns were supplied.
language: Option<Language>,
/// Minimum confidence for a match to be kept; the setter clamps it to
/// `0.0..=1.0`.
confidence_threshold: f64,
/// Value copied into the extractor's `use_kerning` flag.
use_kerning: bool,
/// Caller-supplied pattern library; when present `build` uses it instead of
/// a language-based or default library.
custom_patterns: Option<PatternLibrary>,
}
impl InvoiceExtractorBuilder {
/// Creates a builder with the default settings: no language, no custom
/// patterns, kerning enabled and a confidence threshold of `0.7`.
pub fn new() -> Self {
    Self {
        custom_patterns: None,
        language: None,
        use_kerning: true,
        confidence_threshold: 0.7,
    }
}
pub fn with_language(mut self, lang: &str) -> Self {
self.language = Language::from_code(lang);
self
}
pub fn confidence_threshold(mut self, threshold: f64) -> Self {
self.confidence_threshold = threshold.clamp(0.0, 1.0);
self
}
pub fn use_kerning(mut self, enabled: bool) -> Self {
self.use_kerning = enabled;
self
}
pub fn with_custom_patterns(mut self, patterns: PatternLibrary) -> Self {
self.custom_patterns = Some(patterns);
self
}
pub fn build(self) -> InvoiceExtractor {
let pattern_library = if let Some(custom) = self.custom_patterns {
custom
} else if let Some(lang) = self.language {
PatternLibrary::with_language(lang)
} else {
PatternLibrary::new()
};
InvoiceExtractor {
pattern_library,
confidence_threshold: self.confidence_threshold,
use_kerning: self.use_kerning,
language: self.language,
}
}
}
impl Default for InvoiceExtractorBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder left untouched yields threshold `0.7`, kerning on, no language.
    #[test]
    fn test_builder_defaults() {
        let InvoiceExtractor {
            confidence_threshold,
            use_kerning,
            language,
            ..
        } = InvoiceExtractor::builder().build();

        assert_eq!(confidence_threshold, 0.7);
        assert!(use_kerning);
        assert!(language.is_none());
    }

    /// The code `"es"` resolves to Spanish and is carried into the extractor.
    #[test]
    fn test_builder_with_language() {
        let language = InvoiceExtractor::builder()
            .with_language("es")
            .build()
            .language;

        assert_eq!(language, Some(Language::Spanish));
    }

    /// An in-range threshold is stored unchanged.
    #[test]
    fn test_builder_confidence_threshold() {
        let builder = InvoiceExtractor::builder().confidence_threshold(0.9);
        let extractor = builder.build();

        assert_eq!(extractor.confidence_threshold, 0.9);
    }

    /// Disabling kerning on the builder is reflected in the extractor.
    #[test]
    fn test_builder_use_kerning() {
        let kerning = InvoiceExtractor::builder()
            .use_kerning(false)
            .build()
            .use_kerning;

        assert!(!kerning);
    }

    /// The flag round-trips for both explicit values and defaults to `true`.
    #[test]
    fn test_use_kerning_stored_for_future_use() {
        let cases = vec![
            (
                InvoiceExtractor::builder().use_kerning(true),
                true,
                "use_kerning should be stored as true",
            ),
            (
                InvoiceExtractor::builder().use_kerning(false),
                false,
                "use_kerning should be stored as false",
            ),
            (
                InvoiceExtractor::builder(),
                true,
                "use_kerning should default to true",
            ),
        ];

        for (builder, expected, message) in cases {
            assert!(builder.build().use_kerning == expected, "{}", message);
        }
    }
}