use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Axis-aligned rectangle described by two corners, `(x1, y1)` and `(x2, y2)`.
///
/// The geometry code in this file treats `(x1, y1)` as the minimum corner and
/// `(x2, y2)` as the maximum corner; nothing enforces that ordering when a value
/// is built.
///
/// NOTE(review): the fixtures in this file only use coordinates in `[0, 1]`, which
/// suggests normalized page coordinates — confirm before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct BoundingBox {
/// X coordinate of the minimum corner.
pub x1: f32,
/// Y coordinate of the minimum corner.
pub y1: f32,
/// X coordinate of the maximum corner.
pub x2: f32,
/// Y coordinate of the maximum corner.
pub y2: f32,
}
impl BoundingBox {
pub fn new(x1: f32, y1: f32, x2: f32, y2: f32) -> Self {
Self { x1, y1, x2, y2 }
}
pub fn area(&self) -> f32 {
(self.x2 - self.x1).max(0.0) * (self.y2 - self.y1).max(0.0)
}
pub fn iou(&self, other: &BoundingBox) -> f32 {
let x1 = self.x1.max(other.x1);
let y1 = self.y1.max(other.y1);
let x2 = self.x2.min(other.x2);
let y2 = self.y2.min(other.y2);
let intersection = (x2 - x1).max(0.0) * (y2 - y1).max(0.0);
let union = self.area() + other.area() - intersection;
if union > 0.0 {
intersection / union
} else {
0.0
}
}
pub fn overlaps(&self, other: &BoundingBox, threshold: f32) -> bool {
self.iou(other) >= threshold
}
}
/// A gold (reference) entity annotation: its text, its type label and where it sits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualGold {
/// Surface text of the entity.
pub text: String,
/// Entity type label (the fixtures in this file use e.g. `"MONEY"`, `"DATE"`).
pub entity_type: String,
/// Region occupied by the entity.
pub bbox: BoundingBox,
}
impl VisualGold {
pub fn new(text: impl Into<String>, entity_type: impl Into<String>, bbox: BoundingBox) -> Self {
Self {
text: text.into(),
entity_type: entity_type.into(),
bbox,
}
}
}
/// A predicted entity produced by a system under evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualPrediction {
/// Predicted surface text.
pub text: String,
/// Predicted entity type label.
pub entity_type: String,
/// Predicted region.
pub bbox: BoundingBox,
/// Score attached to the prediction. `evaluate_visual_ner` in this file does not
/// read it. NOTE(review): presumably a probability in `[0, 1]` — confirm.
pub confidence: f32,
}
/// Options controlling how predictions are compared with gold annotations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualEvalConfig {
/// Minimum IoU (inclusive) for a predicted box to count as matching a gold box.
pub iou_threshold: f32,
/// When `true`, texts are lowercased before being compared.
pub case_insensitive: bool,
/// When `true`, texts are split on whitespace and re-joined with single spaces
/// before being compared (this also drops leading and trailing whitespace).
pub normalize_whitespace: bool,
/// When `true`, a prediction is only compared with gold entities that have the
/// same `entity_type`.
pub require_type_match: bool,
}
impl Default for VisualEvalConfig {
fn default() -> Self {
Self {
iou_threshold: 0.5,
case_insensitive: false,
normalize_whitespace: true,
require_type_match: true,
}
}
}
/// Aggregate scores produced by `evaluate_visual_ner`.
///
/// Each precision is the matching count divided by `num_predicted`, each recall is
/// the matching count divided by `num_gold`, and a ratio with a zero denominator is
/// reported as `0.0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualNERMetrics {
/// `text_matches / num_predicted`.
pub text_precision: f64,
/// `text_matches / num_gold`.
pub text_recall: f64,
/// Harmonic mean of `text_precision` and `text_recall`.
pub text_f1: f64,
/// Mean IoU over the compared (prediction, gold) pairs whose IoU was positive;
/// `0.0` when there were none.
pub mean_iou: f64,
/// `box_matches / num_predicted`.
pub box_precision: f64,
/// `box_matches / num_gold`.
pub box_recall: f64,
/// Harmonic mean of `box_precision` and `box_recall`.
pub box_f1: f64,
/// `e2e_matches / num_predicted`.
pub e2e_precision: f64,
/// `e2e_matches / num_gold`.
pub e2e_recall: f64,
/// Harmonic mean of `e2e_precision` and `e2e_recall`.
pub e2e_f1: f64,
/// Per-type breakdown, keyed by entity type; it has an entry for every type that
/// occurs in the gold set or in the predictions.
pub per_type: HashMap<String, VisualTypeMetrics>,
/// Number of predictions that were evaluated.
pub num_predicted: usize,
/// Number of gold entities that were evaluated.
pub num_gold: usize,
/// Number of gold entities matched on text.
pub text_matches: usize,
/// Number of gold entities matched on bounding box (IoU at or above the threshold).
pub box_matches: usize,
/// Number of gold entities matched on both text and bounding box by one prediction.
pub e2e_matches: usize,
}
/// F1 scores for a single entity type, as computed by `evaluate_visual_ner`.
///
/// Each F1 is `0.0` when the type has no gold entities or no predictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualTypeMetrics {
/// The entity type these scores describe (same value as the map key).
pub entity_type: String,
/// F1 of text matches for this type.
pub text_f1: f64,
/// F1 of bounding-box matches for this type.
pub box_f1: f64,
/// F1 of end-to-end (text and box) matches for this type.
pub e2e_f1: f64,
/// Number of gold entities of this type.
pub support: usize,
}
pub fn evaluate_visual_ner(
gold: &[VisualGold],
pred: &[VisualPrediction],
config: &VisualEvalConfig,
) -> VisualNERMetrics {
let mut text_matches = 0;
let mut box_matches = 0;
let mut e2e_matches = 0;
let mut iou_sum = 0.0f64;
let mut iou_count = 0;
let mut type_stats: HashMap<String, (usize, usize, usize, usize, usize)> = HashMap::new();
let mut gold_text_matched = vec![false; gold.len()];
let mut gold_box_matched = vec![false; gold.len()];
let mut gold_e2e_matched = vec![false; gold.len()];
for g in gold {
type_stats
.entry(g.entity_type.clone())
.or_insert((0, 0, 0, 0, 0))
.0 += 1;
}
for p in pred {
type_stats
.entry(p.entity_type.clone())
.or_insert((0, 0, 0, 0, 0))
.1 += 1;
}
for p in pred {
let pred_text = normalize_text(&p.text, config);
for (g_idx, g) in gold.iter().enumerate() {
if config.require_type_match && p.entity_type != g.entity_type {
continue;
}
let gold_text = normalize_text(&g.text, config);
let text_match = pred_text == gold_text;
let iou = p.bbox.iou(&g.bbox);
let box_match = iou >= config.iou_threshold;
if iou > 0.0 {
iou_sum += iou as f64;
iou_count += 1;
}
if text_match && !gold_text_matched[g_idx] {
gold_text_matched[g_idx] = true;
text_matches += 1;
if let Some(stats) = type_stats.get_mut(&g.entity_type) {
stats.2 += 1;
}
}
if box_match && !gold_box_matched[g_idx] {
gold_box_matched[g_idx] = true;
box_matches += 1;
if let Some(stats) = type_stats.get_mut(&g.entity_type) {
stats.3 += 1;
}
}
if text_match && box_match && !gold_e2e_matched[g_idx] {
gold_e2e_matched[g_idx] = true;
e2e_matches += 1;
if let Some(stats) = type_stats.get_mut(&g.entity_type) {
stats.4 += 1;
}
break; }
}
}
let num_gold = gold.len();
let num_pred = pred.len();
let text_precision = if num_pred > 0 {
text_matches as f64 / num_pred as f64
} else {
0.0
};
let text_recall = if num_gold > 0 {
text_matches as f64 / num_gold as f64
} else {
0.0
};
let text_f1 = f1(text_precision, text_recall);
let box_precision = if num_pred > 0 {
box_matches as f64 / num_pred as f64
} else {
0.0
};
let box_recall = if num_gold > 0 {
box_matches as f64 / num_gold as f64
} else {
0.0
};
let box_f1 = f1(box_precision, box_recall);
let e2e_precision = if num_pred > 0 {
e2e_matches as f64 / num_pred as f64
} else {
0.0
};
let e2e_recall = if num_gold > 0 {
e2e_matches as f64 / num_gold as f64
} else {
0.0
};
let e2e_f1 = f1(e2e_precision, e2e_recall);
let mean_iou = if iou_count > 0 {
iou_sum / iou_count as f64
} else {
0.0
};
let per_type: HashMap<_, _> = type_stats
.into_iter()
.map(|(et, (gold_count, pred_count, text_tp, box_tp, e2e_tp))| {
let text_f1 = if gold_count > 0 && pred_count > 0 {
let p = text_tp as f64 / pred_count as f64;
let r = text_tp as f64 / gold_count as f64;
f1(p, r)
} else {
0.0
};
let box_f1 = if gold_count > 0 && pred_count > 0 {
let p = box_tp as f64 / pred_count as f64;
let r = box_tp as f64 / gold_count as f64;
f1(p, r)
} else {
0.0
};
let e2e_f1 = if gold_count > 0 && pred_count > 0 {
let p = e2e_tp as f64 / pred_count as f64;
let r = e2e_tp as f64 / gold_count as f64;
f1(p, r)
} else {
0.0
};
(
et.clone(),
VisualTypeMetrics {
entity_type: et,
text_f1,
box_f1,
e2e_f1,
support: gold_count,
},
)
})
.collect();
VisualNERMetrics {
text_precision,
text_recall,
text_f1,
mean_iou,
box_precision,
box_recall,
box_f1,
e2e_precision,
e2e_recall,
e2e_f1,
per_type,
num_predicted: num_pred,
num_gold,
text_matches,
box_matches,
e2e_matches,
}
}
/// Returns `text` prepared for comparison according to `config`.
///
/// Lowercasing (when `case_insensitive` is set) is applied first; whitespace
/// normalization (when `normalize_whitespace` is set) then splits on whitespace and
/// re-joins the pieces with single spaces.
fn normalize_text(text: &str, config: &VisualEvalConfig) -> String {
    let cased = if config.case_insensitive {
        text.to_lowercase()
    } else {
        text.to_string()
    };

    if !config.normalize_whitespace {
        return cased;
    }

    let words: Vec<&str> = cased.split_whitespace().collect();
    words.join(" ")
}
/// Harmonic mean of `precision` and `recall`.
///
/// Returns `0.0` whenever their sum is not strictly positive (which includes a NaN
/// sum), so the division is never performed with a zero denominator.
fn f1(precision: f64, recall: f64) -> f64 {
    let denominator = precision + recall;
    if denominator > 0.0 {
        2.0 * precision * recall / denominator
    } else {
        0.0
    }
}
/// Returns a small, hand-written set of document snippets for exercising the evaluator.
///
/// Each item pairs a snippet of document text with the gold entities it contains.
pub fn synthetic_visual_examples() -> Vec<(String, Vec<VisualGold>)> {
    let invoice_header = (
        String::from("Invoice #12345"),
        vec![VisualGold::new(
            "Invoice #12345",
            "DOCUMENT_ID",
            BoundingBox::new(0.1, 0.05, 0.4, 0.1),
        )],
    );

    let totals = (
        String::from("Total: $1,234.56\nDate: 2024-01-15"),
        vec![
            VisualGold::new("$1,234.56", "MONEY", BoundingBox::new(0.5, 0.8, 0.7, 0.85)),
            VisualGold::new("2024-01-15", "DATE", BoundingBox::new(0.5, 0.7, 0.7, 0.75)),
        ],
    );

    let sender = (
        String::from("Acme Corp\n123 Main St, City"),
        vec![
            VisualGold::new("Acme Corp", "ORG", BoundingBox::new(0.1, 0.1, 0.35, 0.15)),
            VisualGold::new(
                "123 Main St, City",
                "ADDRESS",
                BoundingBox::new(0.1, 0.16, 0.5, 0.21),
            ),
        ],
    );

    vec![invoice_header, totals, sender]
}
// Unit tests for the bounding-box geometry, the evaluator and the synthetic fixtures.
#[cfg(test)]
mod tests {
use super::*;
// A 0.5 x 0.5 box has area 0.25.
#[test]
fn test_bounding_box_area() {
let bbox = BoundingBox::new(0.0, 0.0, 0.5, 0.5);
assert!((bbox.area() - 0.25).abs() < 0.001);
}
// Two identical boxes have an IoU of 1.
#[test]
fn test_bounding_box_iou_identical() {
let bbox1 = BoundingBox::new(0.1, 0.1, 0.5, 0.5);
let bbox2 = BoundingBox::new(0.1, 0.1, 0.5, 0.5);
assert!((bbox1.iou(&bbox2) - 1.0).abs() < 0.001);
}
// Disjoint boxes have an IoU of 0.
#[test]
fn test_bounding_box_iou_no_overlap() {
let bbox1 = BoundingBox::new(0.0, 0.0, 0.2, 0.2);
let bbox2 = BoundingBox::new(0.5, 0.5, 0.7, 0.7);
assert!(bbox1.iou(&bbox2) < 0.001);
}
// Overlap of 0.0625 over a union of 0.4375 gives an IoU of about 0.143.
#[test]
fn test_bounding_box_iou_partial() {
let bbox1 = BoundingBox::new(0.0, 0.0, 0.5, 0.5);
let bbox2 = BoundingBox::new(0.25, 0.25, 0.75, 0.75);
let iou = bbox1.iou(&bbox2);
assert!(iou > 0.1 && iou < 0.2);
}
// A prediction identical to the gold entity scores 1.0 on text and end-to-end F1.
#[test]
fn test_evaluate_perfect_match() {
let gold = vec![VisualGold::new(
"Invoice",
"DOC",
BoundingBox::new(0.1, 0.1, 0.3, 0.15),
)];
let pred = vec![VisualPrediction {
text: "Invoice".to_string(),
entity_type: "DOC".to_string(),
bbox: BoundingBox::new(0.1, 0.1, 0.3, 0.15),
confidence: 0.95,
}];
let config = VisualEvalConfig::default();
let metrics = evaluate_visual_ner(&gold, &pred, &config);
assert!((metrics.text_f1 - 1.0).abs() < 0.001);
assert!((metrics.e2e_f1 - 1.0).abs() < 0.001);
}
// Correct text in a non-overlapping box counts for text F1 but not end-to-end F1.
#[test]
fn test_evaluate_text_only_match() {
let gold = vec![VisualGold::new(
"Invoice",
"DOC",
BoundingBox::new(0.1, 0.1, 0.3, 0.15),
)];
let pred = vec![VisualPrediction {
text: "Invoice".to_string(),
entity_type: "DOC".to_string(),
// This box does not intersect the gold box, so the IoU is 0.
bbox: BoundingBox::new(0.5, 0.5, 0.7, 0.6), confidence: 0.95,
}];
let config = VisualEvalConfig::default();
let metrics = evaluate_visual_ner(&gold, &pred, &config);
assert!((metrics.text_f1 - 1.0).abs() < 0.001);
assert!(metrics.e2e_f1 < 0.5); }
// Every synthetic example has non-empty text and well-ordered boxes inside [0, 1].
#[test]
fn test_synthetic_examples_valid() {
let examples = synthetic_visual_examples();
assert!(!examples.is_empty());
for (text, entities) in &examples {
assert!(!text.is_empty());
for entity in entities {
assert!(entity.bbox.x1 >= 0.0 && entity.bbox.x1 <= 1.0);
assert!(entity.bbox.y1 >= 0.0 && entity.bbox.y1 <= 1.0);
assert!(entity.bbox.x2 >= entity.bbox.x1 && entity.bbox.x2 <= 1.0);
assert!(entity.bbox.y2 >= entity.bbox.y1 && entity.bbox.y2 <= 1.0);
}
}
}
}