use super::{FeatureExtractor, FeatureExtractorConfig};
use crate::auto::types::{DocumentFormat, FeatureInput, FeatureOutput, SpecialToken};
use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Feature extractor for document inputs.
///
/// Decodes raw document bytes (plain text, HTML or Markdown) into text and
/// derives a fixed-size `max_length x feature_size` grid of simple per-word
/// features from it.
#[derive(Debug, Clone)]
pub struct DocumentFeatureExtractor {
/// Settings that determine the dimensions of the emitted feature grid.
config: DocumentFeatureConfig,
}
impl DocumentFeatureExtractor {
pub fn new(config: DocumentFeatureConfig) -> Self {
Self { config }
}
/// Converts raw document bytes into plain text according to `format`.
///
/// * `Text` content must be valid UTF-8 and is returned unchanged.
/// * `Html` and `Markdown` content is decoded lossily (invalid sequences
///   become U+FFFD) and then stripped of markup by the matching helper.
///
/// # Errors
///
/// Returns `TrustformersError::Io` when `Text` content is not valid UTF-8,
/// and an invalid-input error for every format without an arm below.
fn preprocess_document(&self, content: &[u8], format: DocumentFormat) -> Result<String> {
    match format {
        // Validate the borrowed bytes first so nothing is copied on failure.
        DocumentFormat::Text => std::str::from_utf8(content)
            .map(str::to_owned)
            .map_err(|e| TrustformersError::Io {
                message: format!("Failed to decode text: {}", e),
                path: None,
                suggestion: Some(
                    "Check that the input is valid UTF-8 encoded text".to_string(),
                ),
            }),
        DocumentFormat::Html => {
            let html = String::from_utf8_lossy(content);
            Ok(self.extract_text_from_html(&html))
        }
        DocumentFormat::Markdown => {
            let markdown = String::from_utf8_lossy(content);
            Ok(self.extract_text_from_markdown(&markdown))
        }
        _ => Err(TrustformersError::invalid_input(
            format!("Unsupported document format: {:?}", format),
            Some("document_format"),
            // Only advertise formats that actually have a match arm above;
            // the previous hint listed PDF, which always ends up here.
            Some("supported format (TXT, HTML, Markdown)"),
            Some(format!("{:?}", format)),
        )),
    }
}
/// Removes HTML tags from `html` and returns the remaining text with every
/// run of whitespace collapsed to a single space.
///
/// A `<` opens a tag only when it is immediately followed by an ASCII
/// letter, `/`, `!` or `?`; everything up to and including the next `>` is
/// then dropped, so tags with attributes (`<a href="x">`) are removed as a
/// whole instead of leaking `<a` and `href="x">` into the text. Each removed
/// tag acts as a word separator.
///
/// Stray `<` / `>` characters that are not part of a tag are kept and set
/// apart from neighbouring text. Entities are not decoded, the contents of
/// `<script>` / `<style>` elements are not removed, and a `>` inside a
/// quoted attribute value ends the tag early.
fn extract_text_from_html(&self, html: &str) -> String {
    let mut text = String::with_capacity(html.len());
    let mut chars = html.chars().peekable();
    let mut in_tag = false;

    while let Some(ch) = chars.next() {
        if in_tag {
            if ch == '>' {
                in_tag = false;
                // A tag separates the words on either side of it.
                text.push(' ');
            }
            continue;
        }

        match ch {
            '<' => {
                let opens_tag = matches!(
                    chars.peek(),
                    Some(&next) if next.is_ascii_alphabetic() || matches!(next, '/' | '!' | '?')
                );
                if opens_tag {
                    in_tag = true;
                } else {
                    // Literal '<' (e.g. "a < b"): keep it as its own token start.
                    text.push_str(" <");
                }
            }
            // Literal '>' outside of any tag: keep it as a token end.
            '>' => text.push_str("> "),
            other => text.push(other),
        }
    }

    text.split_whitespace().collect::<Vec<_>>().join(" ")
}
/// Reduces Markdown to plain text, one space-separated chunk per non-empty
/// source line.
///
/// For each line:
/// 1. surrounding whitespace is trimmed;
/// 2. a leading ATX heading marker (one to six `#` followed by whitespace)
///    is removed. Only the start of the line is considered, so `## Title`
///    yields `Title` (not `#Title`) and text such as `C# is` is left intact;
/// 3. every `*` and `` ` `` character is dropped (emphasis, bullets, code);
/// 4. the result is trimmed again, and lines that end up empty are skipped.
fn extract_text_from_markdown(&self, markdown: &str) -> String {
    markdown
        .lines()
        .map(|line| {
            let line = line.trim();

            // '#' is a single ASCII byte, so `level` is a valid char boundary.
            let level = line.bytes().take_while(|&b| b == b'#').count();
            let after_marker = &line[level..];
            let body = if (1..=6).contains(&level)
                && after_marker.starts_with(char::is_whitespace)
            {
                after_marker
            } else {
                line
            };

            let stripped: String = body
                .chars()
                .filter(|c| !matches!(c, '*' | '`'))
                .collect();
            stripped.trim().to_owned()
        })
        .filter(|line| !line.is_empty())
        .collect::<Vec<_>>()
        .join(" ")
}
/// Builds the flat feature grid for `text`.
///
/// The text is split on whitespace; the first `max_length` words each
/// produce one row of `feature_size` values, and rows beyond the last word
/// are all zeros. The result is row-major with length
/// `max_length * feature_size`.
///
/// Column `i` of a row holds channel `i % 4`:
/// * 0 - word length in bytes divided by 20;
/// * 1 - token index divided by `max_length`;
/// * 2 - `1.0` if every character is alphabetic, else `0.0`;
/// * 3 - `1.0` if any character is uppercase, else `0.0`.
///
/// This function never returns `Err`; the `Result` is kept for the callers.
fn extract_document_features(&self, text: &str) -> Result<Vec<f32>> {
    let max_length = self.config.max_length;
    let feature_size = self.config.feature_size;

    let mut features = Vec::with_capacity(max_length * feature_size);
    let mut words = text.split_whitespace();

    for token_idx in 0..max_length {
        // Compute the four channel values once per token instead of
        // re-scanning the word for every one of the `feature_size` columns.
        let channels: [f32; 4] = match words.next() {
            Some(word) => [
                word.len() as f32 / 20.0,
                token_idx as f32 / max_length as f32,
                if word.chars().all(|c| c.is_alphabetic()) { 1.0 } else { 0.0 },
                if word.chars().any(|c| c.is_uppercase()) { 1.0 } else { 0.0 },
            ],
            // Padding row for positions past the end of the text.
            None => [0.0; 4],
        };
        features.extend(channels.iter().copied().cycle().take(feature_size));
    }

    Ok(features)
}
}
impl FeatureExtractor for DocumentFeatureExtractor {
/// Extracts features from a `FeatureInput::Document`.
///
/// The document bytes are converted to text, turned into a
/// `max_length x feature_size` feature grid, and returned together with the
/// shape, a `"format"` metadata entry (the `Debug` name of the document
/// format), an attention mask and `[CLS]` / `[SEP]` markers at the first and
/// last position.
///
/// # Errors
///
/// Returns an invalid-input error for any non-document input, and forwards
/// errors from decoding the document.
fn extract_features(&self, input: &FeatureInput) -> Result<FeatureOutput> {
    match input {
        // The caller-supplied document metadata is not used here, so it is
        // deliberately left unbound.
        FeatureInput::Document {
            content, format, ..
        } => {
            let text = self.preprocess_document(content, *format)?;
            let features = self.extract_document_features(&text)?;
            let max_length = self.config.max_length;

            let mut output_metadata = HashMap::new();
            output_metadata.insert(
                "format".to_string(),
                serde_json::Value::String(format!("{:?}", format)),
            );

            Ok(FeatureOutput {
                features,
                shape: vec![max_length, self.config.feature_size],
                metadata: output_metadata,
                // NOTE(review): every position is marked as attended, including
                // the zero-filled padding rows - confirm this is intended.
                attention_mask: Some(vec![1; max_length]),
                special_tokens: vec![
                    SpecialToken {
                        token_type: "CLS".to_string(),
                        position: 0,
                        value: "[CLS]".to_string(),
                    },
                    SpecialToken {
                        token_type: "SEP".to_string(),
                        // `saturating_sub` avoids an underflow when
                        // `max_length` is configured as 0.
                        position: max_length.saturating_sub(1),
                        value: "[SEP]".to_string(),
                    },
                ],
            })
        }
        _ => Err(TrustformersError::invalid_input(
            "Document feature extractor requires document input".to_string(),
            Some("input_type".to_string()),
            Some("DocumentInput".to_string()),
            Some("Other input type".to_string()),
        )),
    }
}
fn config(&self) -> &dyn FeatureExtractorConfig {
&self.config
}
}
/// Configuration for `DocumentFeatureExtractor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentFeatureConfig {
/// Number of token positions (rows) in the emitted feature grid.
pub max_length: usize,
/// Number of feature values emitted per token position.
pub feature_size: usize,
/// Filled from the `has_visual_segment_embedding` key by `from_config`.
/// NOTE(review): not read by the extraction code visible in this file.
pub include_layout: bool,
/// Filled from the `has_spatial_attention_bias` key by `from_config`.
/// NOTE(review): not read by the extraction code visible in this file.
pub include_visual_features: bool,
/// Optional batch-size limit reported through `FeatureExtractorConfig::max_batch_size`.
pub max_batch_size: Option<usize>,
}
impl DocumentFeatureConfig {
/// Builds a configuration from a model config JSON object.
///
/// Keys read (with the fallback used when a key is missing or has an
/// unusable type/value):
/// * `max_position_embeddings` -> `max_length` (512)
/// * `hidden_size` -> `feature_size` (768)
/// * `has_visual_segment_embedding` -> `include_layout` (`false`)
/// * `has_spatial_attention_bias` -> `include_visual_features` (`false`)
/// * `max_batch_size` -> `max_batch_size` (`None`)
///
/// Integer values that do not fit in `usize` are treated like missing keys
/// instead of being silently truncated.
///
/// This function currently never returns `Err`.
pub fn from_config(config: &serde_json::Value) -> Result<Self> {
    // Reads a non-negative integer key, rejecting values that overflow `usize`.
    let read_usize = |key: &str| -> Option<usize> {
        let raw = config.get(key)?.as_u64()?;
        <usize as std::convert::TryFrom<u64>>::try_from(raw).ok()
    };
    // Reads a boolean key, defaulting to `false`.
    let read_flag =
        |key: &str| -> bool { config.get(key).and_then(|v| v.as_bool()).unwrap_or(false) };

    Ok(Self {
        max_length: read_usize("max_position_embeddings").unwrap_or(512),
        feature_size: read_usize("hidden_size").unwrap_or(768),
        // NOTE(review): going by the key names, "visual segment embedding"
        // looks like the visual flag and "spatial attention bias" like the
        // layout flag, i.e. the two mappings below may be swapped. Kept as in
        // the original - confirm against the intended semantics.
        include_layout: read_flag("has_visual_segment_embedding"),
        include_visual_features: read_flag("has_spatial_attention_bias"),
        max_batch_size: read_usize("max_batch_size"),
    })
}
}
impl FeatureExtractorConfig for DocumentFeatureConfig {
    /// Number of feature values produced per token position.
    fn feature_size(&self) -> usize {
        let Self { feature_size, .. } = self;
        *feature_size
    }

    /// Document extraction always reports batching as supported.
    fn supports_batching(&self) -> bool {
        true
    }

    /// Configured batch-size limit, if any.
    fn max_batch_size(&self) -> Option<usize> {
        let Self { max_batch_size, .. } = self;
        *max_batch_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auto::types::DocumentMetadata;

    /// Builds an extractor with visual features disabled and the given
    /// dimensions, layout flag and batch-size limit.
    fn build_extractor(
        max_length: usize,
        feature_size: usize,
        include_layout: bool,
        max_batch_size: usize,
    ) -> DocumentFeatureExtractor {
        DocumentFeatureExtractor::new(DocumentFeatureConfig {
            max_length,
            feature_size,
            include_layout,
            include_visual_features: false,
            max_batch_size: Some(max_batch_size),
        })
    }

    /// The 512 x 768 extractor used by most tests.
    fn default_extractor() -> DocumentFeatureExtractor {
        build_extractor(512, 768, false, 8)
    }

    /// Wraps raw bytes in a document input without metadata.
    fn document_input(content: &[u8], format: DocumentFormat) -> FeatureInput {
        FeatureInput::Document {
            content: content.to_vec(),
            format,
            metadata: None,
        }
    }

    // The extractor exposes the values of the config it was built from.
    #[test]
    fn test_document_feature_extractor_creation() {
        let extractor = default_extractor();
        let exposed = extractor.config();
        assert_eq!(exposed.feature_size(), 768);
        assert!(exposed.supports_batching());
        assert_eq!(exposed.max_batch_size(), Some(8));
    }

    // Plain-text input yields a full-size grid, a mask and two special tokens.
    #[test]
    fn test_text_document_extraction() {
        let extractor = default_extractor();
        let content = b"This is a test document with some text content.";
        let input = FeatureInput::Document {
            content: content.to_vec(),
            format: DocumentFormat::Text,
            metadata: Some(DocumentMetadata {
                page_count: Some(1),
                author: Some("Test Author".to_string()),
                title: Some("Test Document".to_string()),
                creation_date: None,
            }),
        };

        let result = extractor.extract_features(&input);
        assert!(result.is_ok());
        let output = result.expect("operation failed in test");

        assert_eq!(output.features.len(), 512 * 768);
        assert_eq!(output.shape, vec![512, 768]);
        assert!(output.attention_mask.is_some());
        assert_eq!(output.special_tokens.len(), 2);
    }

    // HTML input is accepted and produces a grid of the configured size.
    #[test]
    fn test_html_document_extraction() {
        let extractor = build_extractor(256, 384, false, 4);
        let html_content = br#"
<html>
<body>
<h1>Document Title</h1>
<p>This is the main content of the document.</p>
</body>
</html>
"#;
        let input = document_input(html_content, DocumentFormat::Html);

        let result = extractor.extract_features(&input);
        assert!(result.is_ok());
        let output = result.expect("operation failed in test");

        assert_eq!(output.features.len(), 256 * 384);
        assert_eq!(output.shape, vec![256, 384]);
    }

    // Markdown input is accepted and produces a grid of the configured size.
    #[test]
    fn test_markdown_document_extraction() {
        let extractor = build_extractor(128, 256, true, 16);
        let markdown_content = br#"
# Document Title
This is a **bold** text and *italic* text.
## Section Header
Some `code` and regular text.
"#;
        let input = document_input(markdown_content, DocumentFormat::Markdown);

        let result = extractor.extract_features(&input);
        assert!(result.is_ok());
        let output = result.expect("operation failed in test");

        assert_eq!(output.features.len(), 128 * 256);
        assert_eq!(output.shape, vec![128, 256]);
    }

    // Every recognised key of a model config is picked up.
    #[test]
    fn test_document_config_from_json() {
        let model_config = serde_json::json!({
            "model_type": "layoutlm",
            "hidden_size": 768,
            "max_position_embeddings": 512,
            "has_visual_segment_embedding": true,
            "has_spatial_attention_bias": true,
            "max_batch_size": 8
        });

        let parsed =
            DocumentFeatureConfig::from_config(&model_config).expect("operation failed in test");

        assert_eq!(parsed.feature_size, 768);
        assert_eq!(parsed.max_length, 512);
        assert!(parsed.include_layout);
        assert!(parsed.include_visual_features);
        assert_eq!(parsed.max_batch_size, Some(8));
    }

    // An empty model config falls back to the documented defaults.
    #[test]
    fn test_document_config_defaults() {
        let minimal_config = serde_json::json!({});

        let parsed =
            DocumentFeatureConfig::from_config(&minimal_config).expect("operation failed in test");

        assert_eq!(parsed.feature_size, 768);
        assert_eq!(parsed.max_length, 512);
        assert!(!parsed.include_layout);
        assert!(!parsed.include_visual_features);
        assert_eq!(parsed.max_batch_size, None);
    }

    // Non-document input is rejected with an invalid-input error.
    #[test]
    fn test_invalid_input_type() {
        let extractor = default_extractor();
        let input = FeatureInput::Text {
            content: "This is not a document input".to_string(),
            metadata: None,
        };

        let result = extractor.extract_features(&input);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TrustformersError::InvalidInput { .. }
        ));
    }

    // Simple tags are stripped while the text between them is kept.
    #[test]
    fn test_html_text_extraction() {
        let extractor = default_extractor();
        let html = "<h1>Title</h1><p>Paragraph content</p>";

        let extracted = extractor.extract_text_from_html(html);

        for kept in ["Title", "Paragraph", "content"] {
            assert!(extracted.contains(kept));
        }
        for removed in ["<h1>", "</p>"] {
            assert!(!extracted.contains(removed));
        }
    }

    // Heading, emphasis and code markers are stripped; the text is kept.
    #[test]
    fn test_markdown_text_extraction() {
        let extractor = default_extractor();
        let markdown = "# Title\n\n**Bold text** and *italic text*\n\n`code`";

        let extracted = extractor.extract_text_from_markdown(markdown);

        for kept in ["Title", "Bold text", "italic text", "code"] {
            assert!(extracted.contains(kept));
        }
        for removed in ["# ", "**", "*", "`"] {
            assert!(!extracted.contains(removed));
        }
    }

    // CLS sits at position 0 and SEP at the last position of the grid.
    #[test]
    fn test_feature_extraction_with_special_tokens() {
        let extractor = build_extractor(10, 4, false, 1);
        let input = document_input(b"short text", DocumentFormat::Text);

        let result = extractor.extract_features(&input).expect("operation failed in test");

        assert_eq!(result.special_tokens.len(), 2);
        let cls = &result.special_tokens[0];
        let sep = &result.special_tokens[1];
        assert_eq!(cls.token_type, "CLS");
        assert_eq!(cls.position, 0);
        assert_eq!(sep.token_type, "SEP");
        assert_eq!(sep.position, 9);
    }
}