use crate::core::traits::{Model, Tokenizer};
use crate::error::Result;
use crate::pipeline::{BasePipeline, Device, Pipeline};
use serde::{Deserialize, Serialize};
use trustformers_core::cache::CacheKeyBuilder;
/// Configuration for [`DocumentUnderstandingPipeline`]: selects which analysis
/// sections appear in the output and how aggressively results are filtered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentUnderstandingConfig {
    /// Maximum sequence length (default: 512).
    pub max_length: usize,
    /// Include raw OCR results in the output.
    pub return_ocr_results: bool,
    /// Include layout text blocks in the output.
    pub return_layout: bool,
    /// Include extracted key/value pairs in the output.
    pub return_key_value_pairs: bool,
    /// Include detected entities in the output.
    pub return_entities: bool,
    /// Minimum confidence for a result to be kept; expected range [0.0, 1.0].
    pub confidence_threshold: f32,
    /// Include the full extracted text in the output.
    pub return_text: bool,
    /// Language codes (e.g. "en", "zh", "ja", "ar") that select the text
    /// post-processing pass; unknown codes fall back to Latin handling.
    pub language_hints: Vec<String>,
    /// Normalize whitespace in extracted text before downstream analysis.
    pub preprocess_text: bool,
}
impl Default for DocumentUnderstandingConfig {
    /// Sensible defaults: a 512-length budget, every output section enabled,
    /// a 0.5 confidence floor, English as the sole language hint, and text
    /// preprocessing switched on.
    fn default() -> Self {
        Self {
            max_length: 512,
            confidence_threshold: 0.5,
            return_text: true,
            return_ocr_results: true,
            return_layout: true,
            return_key_value_pairs: true,
            return_entities: true,
            language_hints: vec!["en".to_string()],
            preprocess_text: true,
        }
    }
}
/// Input to the pipeline: raw document bytes plus optional guidance.
#[derive(Debug, Clone)]
pub struct DocumentUnderstandingInput {
    /// Raw document bytes (PDF or rasterized image).
    pub image: Vec<u8>,
    /// Caller-declared format label; also part of the cache key.
    pub image_type: String,
    /// Optional question for document question-answering.
    pub question: Option<String>,
    /// Optional list of fields the caller wants extracted.
    /// NOTE(review): not consumed by the visible pipeline code — confirm intent.
    pub extraction_targets: Option<Vec<String>>,
}
/// Axis-aligned rectangle locating content on a page; `x`/`y` is the top-left
/// corner. Units appear to be points (the default page bounds are A4-sized,
/// 595x842) — confirm against the upstream detector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
    pub x: f32,
    pub y: f32,
    pub width: f32,
    pub height: f32,
}
/// A contiguous run of text found during layout analysis, with its location,
/// detection confidence, and structural role.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextBlock {
    pub text: String,
    pub bounding_box: BoundingBox,
    /// Detection confidence in [0.0, 1.0].
    pub confidence: f32,
    pub block_type: TextBlockType,
}
/// Structural role of a [`TextBlock`] within the document layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TextBlockType {
    Title,
    Heading,
    Paragraph,
    List,
    Table,
    Footer,
    Header,
    Caption,
    /// Catch-all for blocks that fit no other category.
    Other,
}
/// A form-style key/value pair (e.g. "Invoice Number: INV-123") with the
/// estimated location of each half.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyValuePair {
    pub key: String,
    pub value: String,
    pub key_bbox: BoundingBox,
    pub value_bbox: BoundingBox,
    /// Extraction confidence in [0.0, 1.0].
    pub confidence: f32,
}
/// A named entity found in the document text, with its location.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentEntity {
    pub text: String,
    /// Entity label, e.g. "PERSON" (stringly typed; see `extract_entities`).
    pub entity_type: String,
    pub bounding_box: BoundingBox,
    /// Detection confidence in [0.0, 1.0].
    pub confidence: f32,
}
/// A table extracted from the document. As produced by `extract_table_from_region`,
/// `rows` includes the header row as its first element when `headers` is set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Table {
    pub rows: Vec<Vec<String>>,
    pub headers: Option<Vec<String>>,
    pub bounding_box: BoundingBox,
    /// Extraction confidence in [0.0, 1.0].
    pub confidence: f32,
}
/// One OCR region: recognized text, its region box, and optionally a box per word.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OCRResult {
    pub text: String,
    pub bounding_box: BoundingBox,
    /// Recognition confidence in [0.0, 1.0].
    pub confidence: f32,
    /// Optional (word, box) pairs for finer-grained alignment.
    pub word_level_boxes: Option<Vec<(String, BoundingBox)>>,
}
/// Pipeline output. Each optional section is populated only when the matching
/// `return_*` flag in [`DocumentUnderstandingConfig`] is enabled (tables are
/// always produced; `answer` only when a question was supplied).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentUnderstandingOutput {
    pub text: Option<String>,
    pub text_blocks: Option<Vec<TextBlock>>,
    pub key_value_pairs: Option<Vec<KeyValuePair>>,
    pub entities: Option<Vec<DocumentEntity>>,
    pub tables: Option<Vec<Table>>,
    pub ocr_results: Option<Vec<OCRResult>>,
    pub answer: Option<String>,
    pub metadata: DocumentMetadata,
}
/// Bookkeeping about a single pipeline run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentMetadata {
    /// Number of pages processed (the visible pipeline always reports 1).
    pub page_count: usize,
    /// Wall-clock processing time in milliseconds.
    pub processing_time_ms: u64,
    /// Detected document language code (hard-coded to "en" in the visible code).
    pub detected_language: String,
    /// Text rotation — presumably degrees; confirm with the producer.
    pub text_orientation: f32,
    /// Overall quality estimate in [0.0, 1.0].
    pub quality_score: f32,
}
/// Internal: a page band produced by region segmentation, before text analysis.
#[derive(Debug, Clone)]
struct DocumentRegion {
    pub bbox: BoundingBox,
    pub region_type: RegionType,
}
/// Internal: coarse classification of a segmented page region; maps onto the
/// public [`TextBlockType`] in `analyze_text_region`.
#[derive(Debug, Clone)]
enum RegionType {
    Header,
    Title,
    Body,
    Footer,
    Table,
    List,
}
/// Pipeline that turns a document image (or PDF) into structured output:
/// text, layout blocks, key/value pairs, entities, tables, and OCR results.
pub struct DocumentUnderstandingPipeline<M, T> {
    // Shared model/tokenizer plumbing (device placement, result cache).
    base: BasePipeline<M, T>,
    // Controls which output sections are produced and the confidence floor.
    config: DocumentUnderstandingConfig,
}
impl<M, T> DocumentUnderstandingPipeline<M, T>
where
    M: Model + Send + Sync + 'static,
    T: Tokenizer + Send + Sync + 'static,
{
    /// Creates a pipeline around `model`/`tokenizer` with the default config.
    pub fn new(model: M, tokenizer: T) -> Result<Self> {
        Ok(Self {
            base: BasePipeline::new(model, tokenizer),
            config: DocumentUnderstandingConfig::default(),
        })
    }

    /// Replaces the entire configuration (builder style).
    pub fn with_config(mut self, config: DocumentUnderstandingConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the maximum sequence length (builder style).
    pub fn with_max_length(mut self, max_length: usize) -> Self {
        self.config.max_length = max_length;
        self
    }

    /// Sets the minimum confidence for returned results (builder style).
    pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
        self.config.confidence_threshold = threshold;
        self
    }

    /// Sets the language hints used during text post-processing (builder style).
    pub fn with_language_hints(mut self, hints: Vec<String>) -> Self {
        self.config.language_hints = hints;
        self
    }

    /// Moves the underlying base pipeline to the given device (builder style).
    pub fn to_device(mut self, device: Device) -> Self {
        self.base = self.base.to_device(device);
        self
    }

    /// Extracts raw text from the document bytes, dispatching on the detected
    /// container format and then applying language-specific post-processing.
    fn extract_text(&self, image: &[u8]) -> Result<String> {
        if image.is_empty() {
            return Ok(String::new());
        }
        // Dispatch on the detected format; unknown formats yield empty text.
        let mut extracted_text = if self.is_pdf_image(image) {
            self.extract_from_pdf(image)?
        } else if self.is_text_image(image) {
            self.extract_from_image(image)?
        } else {
            String::new()
        };
        if !self.config.language_hints.is_empty() {
            extracted_text = self.apply_language_processing(&extracted_text)?;
        }
        Ok(extracted_text)
    }

    /// Returns true when the bytes start with the `%PDF` magic number.
    fn is_pdf_image(&self, image: &[u8]) -> bool {
        // Fixed off-by-one: `> 4` wrongly rejected a buffer of exactly 4 bytes
        // even when it matched the magic number.
        image.len() >= 4 && &image[0..4] == b"%PDF"
    }

    /// Placeholder: any non-PDF input is treated as a rasterized text image.
    fn is_text_image(&self, _image: &[u8]) -> bool {
        true
    }

    /// Placeholder PDF text extraction; a real implementation would parse the
    /// PDF content streams.
    fn extract_from_pdf(&self, _image: &[u8]) -> Result<String> {
        Ok("Extracted text from PDF document".to_string())
    }

    /// Placeholder image OCR: emits canned text blocks and keeps those whose
    /// confidence clears the configured threshold.
    fn extract_from_image(&self, _image: &[u8]) -> Result<String> {
        let text_blocks = [
            ("Document Header", 0.95),
            ("Main content paragraph with detailed information", 0.88),
            ("Footer information", 0.82),
        ];
        let filtered_text: Vec<String> = text_blocks
            .into_iter()
            .filter(|(_, confidence)| *confidence >= self.config.confidence_threshold)
            .map(|(text, _)| text.to_string())
            .collect();
        Ok(filtered_text.join(" "))
    }

    /// Applies one normalization pass per configured language hint, in order.
    fn apply_language_processing(&self, text: &str) -> Result<String> {
        let mut processed_text = text.to_string();
        for lang in &self.config.language_hints {
            match lang.as_str() {
                "zh" | "zh-CN" | "zh-TW" => {
                    processed_text = self.process_chinese_text(&processed_text);
                },
                "ja" => {
                    processed_text = self.process_japanese_text(&processed_text);
                },
                "ar" => {
                    processed_text = self.process_arabic_text(&processed_text);
                },
                _ => {
                    // Unknown hints fall back to Latin-script handling.
                    processed_text = self.process_latin_text(&processed_text);
                },
            }
        }
        Ok(processed_text)
    }

    /// Drops all whitespace except plain spaces (newlines, tabs, ideographic
    /// spaces), then trims — Chinese text does not separate words with spaces.
    fn process_chinese_text(&self, text: &str) -> String {
        text.chars()
            .filter(|c| !c.is_whitespace() || c == &' ')
            .collect::<String>()
            .trim()
            .to_string()
    }

    /// Joins non-empty trimmed lines with no separator (Japanese has no
    /// inter-word spaces).
    fn process_japanese_text(&self, text: &str) -> String {
        text.lines()
            .map(|line| line.trim())
            .filter(|line| !line.is_empty())
            .collect::<Vec<_>>()
            .join("")
    }

    /// Placeholder Arabic handling: trims only. RTL-specific shaping would go here.
    fn process_arabic_text(&self, text: &str) -> String {
        text.trim().to_string()
    }

    /// Joins non-empty trimmed lines with single spaces (Latin-script default).
    fn process_latin_text(&self, text: &str) -> String {
        text.lines()
            .map(|line| line.trim())
            .filter(|line| !line.is_empty())
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Segments the page into regions, analyzes each, and returns the blocks
    /// clearing the confidence threshold, sorted into natural reading order.
    fn extract_layout(&self, image: &[u8]) -> Result<Vec<TextBlock>> {
        if image.is_empty() {
            return Ok(Vec::new());
        }
        let mut blocks = Vec::new();
        let document_bounds = self.detect_document_bounds(image)?;
        let regions = self.segment_document_regions(image, &document_bounds)?;
        for region in regions {
            // Fixed mojibake: original source read `®ion` (garbled `&region`).
            let block = self.analyze_text_region(&region)?;
            if block.confidence >= self.config.confidence_threshold {
                blocks.push(block);
            }
        }
        // Reading order: top-to-bottom; blocks within 20pt vertically are
        // treated as the same line and ordered left-to-right.
        blocks.sort_by(|a, b| {
            let y_diff = (a.bounding_box.y - b.bounding_box.y).abs();
            if y_diff < 20.0 {
                a.bounding_box
                    .x
                    .partial_cmp(&b.bounding_box.x)
                    .unwrap_or(std::cmp::Ordering::Equal)
            } else {
                a.bounding_box
                    .y
                    .partial_cmp(&b.bounding_box.y)
                    .unwrap_or(std::cmp::Ordering::Equal)
            }
        });
        Ok(blocks)
    }

    /// Placeholder page-bounds detection; returns A4 dimensions in points.
    fn detect_document_bounds(&self, _image: &[u8]) -> Result<BoundingBox> {
        Ok(BoundingBox {
            x: 0.0,
            y: 0.0,
            width: 595.0,
            height: 842.0,
        })
    }

    /// Splits the page bounds into canonical header/title/body/footer bands,
    /// each inset 50pt horizontally.
    fn segment_document_regions(
        &self,
        _image: &[u8],
        bounds: &BoundingBox,
    ) -> Result<Vec<DocumentRegion>> {
        let regions = vec![
            DocumentRegion {
                bbox: BoundingBox {
                    x: bounds.x + 50.0,
                    y: bounds.y + 30.0,
                    width: bounds.width - 100.0,
                    height: 40.0,
                },
                region_type: RegionType::Header,
            },
            DocumentRegion {
                bbox: BoundingBox {
                    x: bounds.x + 50.0,
                    y: bounds.y + 80.0,
                    width: bounds.width - 100.0,
                    height: 60.0,
                },
                region_type: RegionType::Title,
            },
            DocumentRegion {
                bbox: BoundingBox {
                    x: bounds.x + 50.0,
                    y: bounds.y + 150.0,
                    width: bounds.width - 100.0,
                    height: bounds.height - 250.0,
                },
                region_type: RegionType::Body,
            },
            DocumentRegion {
                bbox: BoundingBox {
                    x: bounds.x + 50.0,
                    // Fixed inconsistency: anchor the footer to the page origin
                    // like the other bands (was `bounds.height - 50.0`, which
                    // silently ignored `bounds.y`).
                    y: bounds.y + bounds.height - 50.0,
                    width: bounds.width - 100.0,
                    height: 30.0,
                },
                region_type: RegionType::Footer,
            },
        ];
        Ok(regions)
    }

    /// Placeholder region analysis: maps each region type to canned text, a
    /// fixed confidence, and the corresponding public block type.
    fn analyze_text_region(&self, region: &DocumentRegion) -> Result<TextBlock> {
        let (text, confidence) = match region.region_type {
            RegionType::Header => ("Document Header", 0.95),
            RegionType::Title => ("Main Document Title", 0.98),
            RegionType::Body => ("This is the main body content of the document with detailed information about the subject matter.", 0.90),
            RegionType::Footer => ("Page 1 | Footer Information", 0.85),
            RegionType::Table => ("Table Content", 0.88),
            RegionType::List => ("• List Item 1\n• List Item 2", 0.87),
        };
        let block_type = match region.region_type {
            RegionType::Header => TextBlockType::Header,
            RegionType::Title => TextBlockType::Title,
            RegionType::Body => TextBlockType::Paragraph,
            RegionType::Footer => TextBlockType::Footer,
            RegionType::Table => TextBlockType::Table,
            RegionType::List => TextBlockType::List,
        };
        Ok(TextBlock {
            text: text.to_string(),
            bounding_box: region.bbox.clone(),
            confidence,
            block_type,
        })
    }

    /// Scans each text line against key/value patterns ("k: v", "k = v",
    /// "k - v", "k v") and keeps the first sufficiently-confident match per
    /// line, then deduplicates by normalized key.
    fn extract_key_value_pairs(&self, _image: &[u8], text: &str) -> Result<Vec<KeyValuePair>> {
        let kv_patterns = [
            (r"([A-Za-z\s]+):\s*(.+)", 1.0),
            (r"([A-Za-z\s]+)\s*=\s*(.+)", 0.9),
            (r"([A-Za-z\s]+)\s*-\s*(.+)", 0.8),
            (r"([A-Za-z\s]+)\s+(.+?)(?:\n|$)", 0.7),
        ];
        // Compile each pattern exactly once up front: the original recompiled
        // every regex for every line (O(lines x patterns) compilations).
        let compiled: Vec<(regex::Regex, f32)> = kv_patterns
            .iter()
            .filter_map(|(pattern, base)| regex::Regex::new(pattern).ok().map(|re| (re, *base)))
            .collect();
        let mut pairs = Vec::new();
        for line in text.lines() {
            for (re, base_confidence) in &compiled {
                if let Some(captures) = re.captures(line.trim()) {
                    if let (Some(key_match), Some(value_match)) =
                        (captures.get(1), captures.get(2))
                    {
                        let key = key_match.as_str().trim();
                        let value = value_match.as_str().trim();
                        // Skip degenerate one-character keys/values.
                        if value.len() < 2 || key.len() < 2 {
                            continue;
                        }
                        let confidence =
                            self.calculate_kv_confidence(key, value, *base_confidence);
                        if confidence >= self.config.confidence_threshold {
                            let pair = KeyValuePair {
                                key: key.to_string(),
                                value: value.to_string(),
                                // Rough layout estimate: keys at x=100, values
                                // at x=200, one 25pt row per accepted pair.
                                key_bbox: self.estimate_text_bbox(
                                    key,
                                    100.0,
                                    200.0 + pairs.len() as f32 * 25.0,
                                ),
                                value_bbox: self.estimate_text_bbox(
                                    value,
                                    200.0,
                                    200.0 + pairs.len() as f32 * 25.0,
                                ),
                                confidence,
                            };
                            pairs.push(pair);
                            // First accepted pattern wins for this line.
                            break;
                        }
                    }
                }
            }
        }
        self.deduplicate_key_value_pairs(pairs)
    }

    /// Adjusts a pattern's base confidence using simple heuristics: boost for
    /// well-known form keys and structured values, penalize very long fields.
    fn calculate_kv_confidence(&self, key: &str, value: &str, base_confidence: f32) -> f32 {
        let mut confidence = base_confidence;
        let common_keys = [
            "name",
            "address",
            "phone",
            "email",
            "date",
            "amount",
            "total",
            "quantity",
            "price",
            "description",
            "company",
        ];
        if common_keys.iter().any(|&k| key.to_lowercase().contains(k)) {
            confidence += 0.1;
        }
        // Very long keys/values are usually mis-segmented prose, not fields.
        if key.len() > 50 || value.len() > 200 {
            confidence -= 0.2;
        }
        if self.is_structured_value(value) {
            confidence += 0.15;
        }
        confidence.clamp(0.0, 1.0)
    }

    /// Returns true when the value matches a recognizable structured format:
    /// a date (d/m/y-ish), an email address, or a North-American phone number.
    fn is_structured_value(&self, value: &str) -> bool {
        if regex::Regex::new(r"\d{1,2}[/-]\d{1,2}[/-]\d{2,4}")
            .expect("static regex pattern is valid")
            .is_match(value)
        {
            return true;
        }
        if regex::Regex::new(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b")
            .expect("static regex pattern is valid")
            .is_match(value)
        {
            return true;
        }
        if regex::Regex::new(r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b")
            .expect("static regex pattern is valid")
            .is_match(value)
        {
            return true;
        }
        false
    }

    /// Estimates a bounding box for `text` placed at (x, y), assuming a
    /// fixed-width font: 8pt per character, 20pt line height.
    fn estimate_text_bbox(&self, text: &str, x: f32, y: f32) -> BoundingBox {
        const CHAR_WIDTH: f32 = 8.0;
        const LINE_HEIGHT: f32 = 20.0;
        BoundingBox {
            x,
            y,
            // Count characters, not bytes (`len()`), so multi-byte UTF-8 text
            // does not inflate the estimated width.
            width: text.chars().count() as f32 * CHAR_WIDTH,
            height: LINE_HEIGHT,
        }
    }

    /// Collapses pairs sharing a case-insensitive key, keeping the highest
    /// confidence occurrence of each. Output order is unspecified (HashMap).
    fn deduplicate_key_value_pairs(&self, pairs: Vec<KeyValuePair>) -> Result<Vec<KeyValuePair>> {
        use std::collections::HashMap;
        let mut best_pairs: HashMap<String, KeyValuePair> = HashMap::new();
        for pair in pairs {
            let key_normalized = pair.key.to_lowercase().trim().to_string();
            match best_pairs.get(&key_normalized) {
                // An existing pair with equal or better confidence wins.
                Some(existing) if existing.confidence >= pair.confidence => {},
                _ => {
                    best_pairs.insert(key_normalized, pair);
                },
            }
        }
        Ok(best_pairs.into_values().collect())
    }

    /// Placeholder NER: returns a canned PERSON entity. The text argument is
    /// intentionally unused until a real recognizer is wired in.
    fn extract_entities(&self, _text: &str) -> Result<Vec<DocumentEntity>> {
        let entities = vec![DocumentEntity {
            text: "John Doe".to_string(),
            entity_type: "PERSON".to_string(),
            bounding_box: BoundingBox {
                x: 160.0,
                y: 200.0,
                width: 80.0,
                height: 20.0,
            },
            confidence: 0.89,
        }];
        Ok(entities)
    }

    /// Detects table regions and extracts those clearing the confidence floor.
    fn extract_tables(&self, _image: &[u8]) -> Result<Vec<Table>> {
        let mut tables = Vec::new();
        let table_regions = self.detect_table_regions()?;
        for region in table_regions {
            // Fixed mojibake: original source read `®ion` (garbled `&region`).
            let table = self.extract_table_from_region(&region)?;
            if table.confidence >= self.config.confidence_threshold {
                tables.push(table);
            }
        }
        Ok(tables)
    }

    /// Placeholder table detection: returns two fixed candidate regions.
    fn detect_table_regions(&self) -> Result<Vec<BoundingBox>> {
        let regions = vec![
            BoundingBox {
                x: 100.0,
                y: 300.0,
                width: 400.0,
                height: 120.0,
            },
            BoundingBox {
                x: 100.0,
                y: 450.0,
                width: 350.0,
                height: 80.0,
            },
        ];
        Ok(regions)
    }

    /// Placeholder table extraction: returns one of two canned tables keyed on
    /// the region's vertical position (header row duplicated into `rows`).
    fn extract_table_from_region(&self, region: &BoundingBox) -> Result<Table> {
        let (rows, headers, confidence) = if region.y < 400.0 {
            let headers = vec![
                "Item".to_string(),
                "Quantity".to_string(),
                "Price".to_string(),
                "Total".to_string(),
            ];
            let rows = vec![
                headers.clone(),
                vec![
                    "Product A".to_string(),
                    "5".to_string(),
                    "$10.00".to_string(),
                    "$50.00".to_string(),
                ],
                vec![
                    "Product B".to_string(),
                    "3".to_string(),
                    "$15.00".to_string(),
                    "$45.00".to_string(),
                ],
                vec![
                    "Product C".to_string(),
                    "2".to_string(),
                    "$25.00".to_string(),
                    "$50.00".to_string(),
                ],
                vec![
                    "Total".to_string(),
                    "10".to_string(),
                    "-".to_string(),
                    "$145.00".to_string(),
                ],
            ];
            (rows, Some(headers), 0.92)
        } else {
            let headers = vec![
                "Name".to_string(),
                "Department".to_string(),
                "Email".to_string(),
            ];
            let rows = vec![
                headers.clone(),
                vec![
                    "John Smith".to_string(),
                    "Engineering".to_string(),
                    "john.smith@company.com".to_string(),
                ],
                vec![
                    "Jane Doe".to_string(),
                    "Marketing".to_string(),
                    "jane.doe@company.com".to_string(),
                ],
                vec![
                    "Bob Johnson".to_string(),
                    "Sales".to_string(),
                    "bob.johnson@company.com".to_string(),
                ],
            ];
            (rows, Some(headers), 0.88)
        };
        Ok(Table {
            rows,
            headers,
            bounding_box: region.clone(),
            confidence,
        })
    }

    /// Placeholder OCR: returns one canned result with two word-level boxes.
    /// The image argument is intentionally unused until a real engine exists.
    fn perform_ocr(&self, _image: &[u8]) -> Result<Vec<OCRResult>> {
        let ocr_result = OCRResult {
            text: "Sample OCR text".to_string(),
            bounding_box: BoundingBox {
                x: 0.0,
                y: 0.0,
                width: 500.0,
                height: 400.0,
            },
            confidence: 0.92,
            word_level_boxes: Some(vec![
                (
                    "Sample".to_string(),
                    BoundingBox {
                        x: 0.0,
                        y: 0.0,
                        width: 60.0,
                        height: 20.0,
                    },
                ),
                (
                    "OCR".to_string(),
                    BoundingBox {
                        x: 65.0,
                        y: 0.0,
                        width: 40.0,
                        height: 20.0,
                    },
                ),
            ]),
        };
        Ok(vec![ocr_result])
    }

    /// Placeholder QA: echoes the question. The document text is intentionally
    /// unused until a real reader model is wired in.
    fn answer_question(&self, _text: &str, question: &str) -> Result<String> {
        let answer = format!("Answer to '{}' based on document content", question);
        Ok(answer)
    }

    /// Normalizes whitespace (trim lines, drop blanks, join with spaces) when
    /// `config.preprocess_text` is set; otherwise returns the text unchanged.
    fn preprocess_text(&self, text: &str) -> String {
        if self.config.preprocess_text {
            text.lines()
                .map(|line| line.trim())
                .filter(|line| !line.is_empty())
                .collect::<Vec<_>>()
                .join(" ")
        } else {
            text.to_string()
        }
    }
}
impl<M, T> Pipeline for DocumentUnderstandingPipeline<M, T>
where
    M: Model + Send + Sync + 'static,
    T: Tokenizer + Send + Sync + 'static,
{
    type Input = DocumentUnderstandingInput;
    type Output = DocumentUnderstandingOutput;

    /// Runs the full flow: cache lookup, text extraction, the optional
    /// layout/KV/entity/OCR sections, table extraction, question answering,
    /// then cache insertion of the serialized output.
    fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let start_time = std::time::Instant::now();
        let cache_key = if let Some(cache) = &self.base.cache {
            // Fixed: the original keyed the cache on `input.image.len()`, so
            // any two same-sized documents (same config/question) collided and
            // served each other's cached output. Hash the actual bytes.
            let mut hasher = DefaultHasher::new();
            input.image.hash(&mut hasher);
            let image_hash = format!("{:x}", hasher.finish());
            let mut builder = CacheKeyBuilder::new("document_understanding", "image_analysis")
                .with_param("image_type", &input.image_type)
                .with_param("image_hash", &image_hash)
                .with_param("config", &serde_json::to_string(&self.config).unwrap_or_default());
            if let Some(question) = &input.question {
                builder = builder.with_text(question);
            }
            let key = builder.build();
            if let Some(cached) = cache.get(&key) {
                // A cache entry that fails to deserialize is ignored, not an error.
                if let Ok(output) = serde_json::from_slice::<DocumentUnderstandingOutput>(&cached) {
                    return Ok(output);
                }
            }
            Some(key)
        } else {
            None
        };
        let text = self.extract_text(&input.image)?;
        let processed_text = self.preprocess_text(&text);
        let mut output = DocumentUnderstandingOutput {
            text: None,
            text_blocks: None,
            key_value_pairs: None,
            entities: None,
            tables: None,
            ocr_results: None,
            answer: None,
            metadata: DocumentMetadata {
                page_count: 1,
                processing_time_ms: 0,
                detected_language: "en".to_string(),
                text_orientation: 0.0,
                quality_score: 0.9,
            },
        };
        // Populate only the sections enabled in the configuration.
        if self.config.return_text {
            output.text = Some(processed_text.clone());
        }
        if self.config.return_layout {
            output.text_blocks = Some(self.extract_layout(&input.image)?);
        }
        if self.config.return_key_value_pairs {
            output.key_value_pairs =
                Some(self.extract_key_value_pairs(&input.image, &processed_text)?);
        }
        if self.config.return_entities {
            output.entities = Some(self.extract_entities(&processed_text)?);
        }
        if self.config.return_ocr_results {
            output.ocr_results = Some(self.perform_ocr(&input.image)?);
        }
        // Tables are always extracted; there is no config flag for them.
        output.tables = Some(self.extract_tables(&input.image)?);
        if let Some(question) = &input.question {
            output.answer = Some(self.answer_question(&processed_text, question)?);
        }
        output.metadata.processing_time_ms = start_time.elapsed().as_millis() as u64;
        // Best-effort cache write: serialization failure is silently skipped.
        if let (Some(cache), Some(key)) = (&self.base.cache, cache_key) {
            if let Ok(serialized) = serde_json::to_vec(&output) {
                cache.insert(key, serialized);
            }
        }
        Ok(output)
    }
}
/// Convenience constructor: builds a [`DocumentUnderstandingPipeline`] with
/// the default [`DocumentUnderstandingConfig`].
pub fn document_understanding_pipeline<M, T>(
    model: M,
    tokenizer: T,
) -> Result<DocumentUnderstandingPipeline<M, T>>
where
    M: Model + Send + Sync + 'static,
    T: Tokenizer + Send + Sync + 'static,
{
    DocumentUnderstandingPipeline::new(model, tokenizer)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the document-understanding data types: configuration
    //! defaults, struct construction invariants, and reading-order sorting.
    //! These tests exercise only the plain data structs — the pipeline itself
    //! requires a model/tokenizer and is not constructed here.
    use super::*;

    // --- Configuration defaults ------------------------------------------

    #[test]
    fn test_config_default_values() {
        let config = DocumentUnderstandingConfig::default();
        assert_eq!(config.max_length, 512, "default max_length should be 512");
        assert!(
            config.return_ocr_results,
            "default should return OCR results"
        );
        assert!(config.return_layout, "default should return layout");
        assert!(
            config.return_key_value_pairs,
            "default should return key-value pairs"
        );
        assert!(config.return_entities, "default should return entities");
        assert!(config.return_text, "default should return text");
        assert!(config.preprocess_text, "default should preprocess text");
    }

    #[test]
    fn test_config_confidence_threshold_default_in_range() {
        let config = DocumentUnderstandingConfig::default();
        assert!(
            config.confidence_threshold >= 0.0 && config.confidence_threshold <= 1.0,
            "confidence_threshold should be in [0.0, 1.0], got {}",
            config.confidence_threshold
        );
    }

    #[test]
    fn test_config_language_hints_default_contains_english() {
        let config = DocumentUnderstandingConfig::default();
        assert!(
            config.language_hints.contains(&"en".to_string()),
            "default language_hints should contain 'en'"
        );
    }

    // --- Geometry ---------------------------------------------------------

    #[test]
    fn test_bounding_box_construction() {
        let bbox = BoundingBox {
            x: 10.0,
            y: 20.0,
            width: 100.0,
            height: 50.0,
        };
        // Float fields compared with a tolerance rather than exact equality.
        assert!((bbox.x - 10.0).abs() < 1e-6);
        assert!((bbox.y - 20.0).abs() < 1e-6);
        assert!((bbox.width - 100.0).abs() < 1e-6);
        assert!((bbox.height - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_bounding_box_dimensions_non_negative() {
        let bbox = BoundingBox {
            x: 0.0,
            y: 0.0,
            width: 50.0,
            height: 30.0,
        };
        assert!(
            bbox.width >= 0.0,
            "bounding box width should be non-negative"
        );
        assert!(
            bbox.height >= 0.0,
            "bounding box height should be non-negative"
        );
    }

    // --- Text blocks ------------------------------------------------------

    #[test]
    fn test_text_block_confidence_in_range() {
        let block = TextBlock {
            text: "Sample paragraph text".to_string(),
            bounding_box: BoundingBox {
                x: 0.0,
                y: 0.0,
                width: 200.0,
                height: 30.0,
            },
            confidence: 0.88,
            block_type: TextBlockType::Paragraph,
        };
        assert!(
            block.confidence >= 0.0 && block.confidence <= 1.0,
            "confidence must be in [0.0, 1.0]"
        );
    }

    #[test]
    fn test_text_block_heading_type() {
        let block = TextBlock {
            text: "Chapter 1: Introduction".to_string(),
            bounding_box: BoundingBox {
                x: 0.0,
                y: 0.0,
                width: 300.0,
                height: 40.0,
            },
            confidence: 0.95,
            block_type: TextBlockType::Heading,
        };
        assert!(
            matches!(block.block_type, TextBlockType::Heading),
            "block_type should be Heading"
        );
    }

    #[test]
    fn test_text_block_title_type() {
        let block = TextBlock {
            text: "Annual Report 2024".to_string(),
            bounding_box: BoundingBox {
                x: 0.0,
                y: 0.0,
                width: 400.0,
                height: 60.0,
            },
            confidence: 0.97,
            block_type: TextBlockType::Title,
        };
        assert!(matches!(block.block_type, TextBlockType::Title));
    }

    // --- Tables -----------------------------------------------------------

    #[test]
    fn test_table_row_col_count() {
        let headers = vec!["Name".to_string(), "Value".to_string()];
        let rows = vec![
            vec!["Row1".to_string(), "100".to_string()],
            vec!["Row2".to_string(), "200".to_string()],
            vec!["Row3".to_string(), "300".to_string()],
        ];
        let table = Table {
            rows: rows.clone(),
            headers: Some(headers),
            bounding_box: BoundingBox {
                x: 0.0,
                y: 0.0,
                width: 300.0,
                height: 100.0,
            },
            confidence: 0.92,
        };
        assert_eq!(table.rows.len(), 3, "table should have 3 rows");
        assert_eq!(table.rows[0].len(), 2, "each row should have 2 columns");
    }

    #[test]
    fn test_table_headers_present() {
        let headers = vec!["Item".to_string(), "Qty".to_string(), "Price".to_string()];
        let table = Table {
            rows: vec![headers.clone()],
            headers: Some(headers.clone()),
            bounding_box: BoundingBox {
                x: 0.0,
                y: 0.0,
                width: 400.0,
                height: 200.0,
            },
            confidence: 0.90,
        };
        assert!(table.headers.is_some(), "table should have headers");
        assert_eq!(
            table.headers.as_ref().expect("headers present").len(),
            3,
            "table should have 3 column headers"
        );
    }

    #[test]
    fn test_table_confidence_in_range() {
        let table = Table {
            rows: vec![vec!["data".to_string()]],
            headers: None,
            bounding_box: BoundingBox {
                x: 0.0,
                y: 0.0,
                width: 100.0,
                height: 50.0,
            },
            confidence: 0.85,
        };
        assert!(
            table.confidence >= 0.0 && table.confidence <= 1.0,
            "table confidence must be in [0.0, 1.0]"
        );
    }

    // --- OCR results ------------------------------------------------------

    #[test]
    fn test_ocr_result_confidence_threshold() {
        let ocr = OCRResult {
            text: "Extracted text here".to_string(),
            bounding_box: BoundingBox {
                x: 0.0,
                y: 0.0,
                width: 200.0,
                height: 25.0,
            },
            confidence: 0.92,
            word_level_boxes: None,
        };
        // Uses the same default threshold as DocumentUnderstandingConfig.
        let threshold = 0.5;
        assert!(
            ocr.confidence >= threshold,
            "OCR result with confidence {} should pass threshold {}",
            ocr.confidence,
            threshold
        );
    }

    #[test]
    fn test_ocr_result_with_word_boxes() {
        let ocr = OCRResult {
            text: "Sample OCR".to_string(),
            bounding_box: BoundingBox {
                x: 0.0,
                y: 0.0,
                width: 150.0,
                height: 25.0,
            },
            confidence: 0.95,
            word_level_boxes: Some(vec![
                (
                    "Sample".to_string(),
                    BoundingBox {
                        x: 0.0,
                        y: 0.0,
                        width: 70.0,
                        height: 25.0,
                    },
                ),
                (
                    "OCR".to_string(),
                    BoundingBox {
                        x: 75.0,
                        y: 0.0,
                        width: 50.0,
                        height: 25.0,
                    },
                ),
            ]),
        };
        let boxes = ocr.word_level_boxes.as_ref().expect("word level boxes should be present");
        assert_eq!(boxes.len(), 2, "should have 2 word-level bounding boxes");
    }

    // --- Key/value pairs, metadata, output --------------------------------

    #[test]
    fn test_key_value_pair_fields() {
        let kv = KeyValuePair {
            key: "Invoice Number".to_string(),
            value: "INV-12345".to_string(),
            key_bbox: BoundingBox {
                x: 10.0,
                y: 50.0,
                width: 100.0,
                height: 20.0,
            },
            value_bbox: BoundingBox {
                x: 120.0,
                y: 50.0,
                width: 80.0,
                height: 20.0,
            },
            confidence: 0.88,
        };
        assert_eq!(kv.key, "Invoice Number");
        assert_eq!(kv.value, "INV-12345");
        assert!(kv.confidence >= 0.0 && kv.confidence <= 1.0);
    }

    #[test]
    fn test_document_metadata_quality_score_in_range() {
        let meta = DocumentMetadata {
            page_count: 1,
            processing_time_ms: 150,
            detected_language: "en".to_string(),
            text_orientation: 0.0,
            quality_score: 0.92,
        };
        assert!(
            meta.quality_score >= 0.0 && meta.quality_score <= 1.0,
            "quality_score must be in [0.0, 1.0]"
        );
    }

    #[test]
    fn test_document_metadata_page_count_positive() {
        let meta = DocumentMetadata {
            page_count: 5,
            processing_time_ms: 500,
            detected_language: "en".to_string(),
            text_orientation: 0.0,
            quality_score: 0.85,
        };
        assert!(meta.page_count > 0, "page_count should be at least 1");
    }

    #[test]
    fn test_document_understanding_output_construction() {
        let output = DocumentUnderstandingOutput {
            text: Some("Sample document text".to_string()),
            text_blocks: None,
            key_value_pairs: None,
            entities: None,
            tables: None,
            ocr_results: None,
            answer: None,
            metadata: DocumentMetadata {
                page_count: 1,
                processing_time_ms: 200,
                detected_language: "en".to_string(),
                text_orientation: 0.0,
                quality_score: 0.9,
            },
        };
        assert!(output.text.is_some(), "output should have text");
        assert_eq!(output.metadata.page_count, 1);
    }

    // --- Reading order ----------------------------------------------------

    // Mirrors the y-sort half of `extract_layout`'s reading-order comparator.
    #[test]
    fn test_layout_reading_order_top_to_bottom() {
        let blocks = vec![
            TextBlock {
                text: "Header text".to_string(),
                bounding_box: BoundingBox {
                    x: 10.0,
                    y: 10.0,
                    width: 500.0,
                    height: 30.0,
                },
                confidence: 0.95,
                block_type: TextBlockType::Header,
            },
            TextBlock {
                text: "Body text paragraph".to_string(),
                bounding_box: BoundingBox {
                    x: 10.0,
                    y: 100.0,
                    width: 500.0,
                    height: 60.0,
                },
                confidence: 0.90,
                block_type: TextBlockType::Paragraph,
            },
            TextBlock {
                text: "Footer text".to_string(),
                bounding_box: BoundingBox {
                    x: 10.0,
                    y: 900.0,
                    width: 500.0,
                    height: 20.0,
                },
                confidence: 0.85,
                block_type: TextBlockType::Footer,
            },
        ];
        let mut sorted = blocks.clone();
        sorted.sort_by(|a, b| {
            a.bounding_box
                .y
                .partial_cmp(&b.bounding_box.y)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        assert_eq!(
            sorted[0].bounding_box.y, 10.0,
            "first block should have smallest y"
        );
        assert_eq!(
            sorted[2].bounding_box.y, 900.0,
            "last block should have largest y"
        );
    }

    #[test]
    fn test_layout_reading_order_left_to_right() {
        let left_block = TextBlock {
            text: "Left column".to_string(),
            bounding_box: BoundingBox {
                x: 10.0,
                y: 100.0,
                width: 200.0,
                height: 50.0,
            },
            confidence: 0.90,
            block_type: TextBlockType::Paragraph,
        };
        let right_block = TextBlock {
            text: "Right column".to_string(),
            bounding_box: BoundingBox {
                x: 300.0,
                y: 100.0,
                width: 200.0,
                height: 50.0,
            },
            confidence: 0.88,
            block_type: TextBlockType::Paragraph,
        };
        assert!(
            left_block.bounding_box.x < right_block.bounding_box.x,
            "left column x ({}) should be less than right column x ({})",
            left_block.bounding_box.x,
            right_block.bounding_box.x
        );
    }

    #[test]
    fn test_text_block_type_variants_accessible() {
        let variants = [
            TextBlockType::Title,
            TextBlockType::Heading,
            TextBlockType::Paragraph,
            TextBlockType::List,
            TextBlockType::Table,
            TextBlockType::Footer,
            TextBlockType::Header,
            TextBlockType::Caption,
            TextBlockType::Other,
        ];
        let has_heading = variants.iter().any(|v| matches!(v, TextBlockType::Heading));
        assert!(
            has_heading,
            "TextBlockType should include Heading variant for H1/H2/H3 detection"
        );
    }
}