//! Token classification (NER) pipeline: per-token predictions with BIO-tag
//! aggregation into entity spans.

use crate::core::traits::{Model, Tokenizer};
use crate::error::{Result, TrustformersError};
use crate::models::bert::tasks::SequenceClassifierOutput;
use crate::pipeline::{
BasePipeline, Pipeline, PipelineOutput, TokenClassificationOutput as PipelineTokenOutput,
};
use crate::{AutoModel, AutoTokenizer};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

/// Serializable configuration for the token classification pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClassificationConfig {
    /// Maximum input length in tokens.
    pub max_length: usize,
    /// Name of the aggregation strategy (the default is "simple"; see [`AggregationStrategy`]).
    pub aggregation_strategy: String,
    /// Labels to drop from the output (typically the outside tag "O").
    pub ignore_labels: Vec<String>,
}
impl Default for TokenClassificationConfig {
fn default() -> Self {
Self {
max_length: 512,
aggregation_strategy: "simple".to_string(),
ignore_labels: vec!["O".to_string()],
}
}
}

/// Raw per-token prediction produced by the model adapters, before aggregation.
#[derive(Debug, Clone)]
struct TokenOutput {
pub entity: String,
pub score: f32,
pub index: usize,
pub word: String,
pub start: usize,
pub end: usize,
}

/// Strategy for merging per-token predictions into entity-level results.
#[derive(Clone, Debug)]
pub enum AggregationStrategy {
    /// Return every token prediction unchanged.
    None,
    /// Merge consecutive `B-`/`I-` tokens of the same type into one span.
    Simple,
    /// Keep only the first token seen for each entity type.
    First,
    /// Group tokens by entity type and report the mean score over the group.
    Average,
    /// Group tokens by entity type and report the highest-scoring token.
    Max,
}
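
// For a run of raw predictions like ["B-PER John", "I-PER Smith", "B-PER Alice"], the
// strategies above behave roughly as follows (see the aggregation methods further down):
//   None    -> three raw token outputs, tags untouched
//   Simple  -> two "PER" spans: "John Smith" and "Alice"
//   First   -> a single "PER" result: "John" (first occurrence of the type)
//   Average -> a single "PER" result covering all three words with the mean score
//   Max     -> a single "PER" result: the highest-scoring token alone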

/// Token classification (named entity recognition) pipeline.
#[derive(Clone)]
pub struct TokenClassificationPipeline {
base: BasePipeline<AutoModel, AutoTokenizer>,
aggregation_strategy: AggregationStrategy,
labels: Arc<Vec<String>>,
}
impl TokenClassificationPipeline {
    /// Builds a pipeline with simple aggregation and a default CoNLL-style BIO label set.
    pub fn new(model: AutoModel, tokenizer: AutoTokenizer) -> Result<Self> {
Ok(Self {
base: BasePipeline::new(model, tokenizer),
aggregation_strategy: AggregationStrategy::Simple,
labels: Arc::new(vec![
"O".to_string(),
"B-PER".to_string(),
"I-PER".to_string(),
"B-ORG".to_string(),
"I-ORG".to_string(),
"B-LOC".to_string(),
"I-LOC".to_string(),
"B-MISC".to_string(),
"I-MISC".to_string(),
]),
})
}

    /// Sets the aggregation strategy used to merge token predictions into entities.
    pub fn with_aggregation_strategy(mut self, strategy: AggregationStrategy) -> Self {
self.aggregation_strategy = strategy;
self
}

    /// Overrides the label set used to map class indices to BIO tags.
    pub fn with_labels(mut self, labels: Vec<String>) -> Self {
self.labels = Arc::new(labels);
self
}

    /// Tokenizes the input, runs the model, and aggregates per-token predictions into entities.
    fn classify_tokens(&self, text: &str) -> Result<Vec<PipelineTokenOutput>> {
let inputs = self.base.tokenizer.encode(text)?;
match &self.base.model.model_type {
#[cfg(feature = "bert")]
crate::automodel::AutoModelType::BertForSequenceClassification(model) => {
let output = model.forward(inputs.clone())?;
let token_outputs =
self.adapt_sequence_to_token_classification(&output, &inputs, text)?;
let aggregated = self.aggregate_entities(token_outputs);
Ok(aggregated)
},
#[cfg(feature = "bert")]
crate::automodel::AutoModelType::Bert(_model) => {
self.fallback_token_classification(text, &inputs)
},
_ => Err(TrustformersError::runtime_error(
"Model does not support token classification",
)),
}
}
fn classify_tokens_batch(&self, texts: &[String]) -> Result<Vec<Vec<PipelineTokenOutput>>> {
texts.iter().map(|text| self.classify_tokens(text)).collect::<Result<Vec<_>>>()
}
fn aggregate_entities(&self, token_outputs: Vec<TokenOutput>) -> Vec<PipelineTokenOutput> {
match self.aggregation_strategy {
AggregationStrategy::None => {
token_outputs
.into_iter()
.map(|t| PipelineTokenOutput {
entity: t.entity,
score: t.score,
index: t.index,
word: t.word,
start: t.start,
end: t.end,
})
.collect()
},
AggregationStrategy::Simple => {
self.simple_entity_aggregation(token_outputs)
},
AggregationStrategy::First => {
self.first_entity_aggregation(token_outputs)
},
AggregationStrategy::Average => {
self.average_entity_aggregation(token_outputs)
},
AggregationStrategy::Max => {
self.max_entity_aggregation(token_outputs)
},
}
}

    /// Adapts a sentence-level classifier to token classification heuristically: the single
    /// sentence prediction is projected onto capitalized words with a damped confidence.
    fn adapt_sequence_to_token_classification(
        &self,
        output: &SequenceClassifierOutput,
        _inputs: &crate::core::traits::TokenizedInput,
        original_text: &str,
    ) -> Result<Vec<TokenOutput>> {
        let logits = &output.logits;
        let logits_data = logits.data()?;
        let shape = logits.shape();
        if shape.len() < 2 {
            return Err(TrustformersError::runtime_error(
                "Sequence classification output must have shape [batch, num_classes]",
            ));
        }
        let num_classes = shape[shape.len() - 1];
        let mut token_outputs = Vec::new();
        let sequence_logits = &logits_data[0..num_classes.min(logits_data.len())];
        let probs = self.softmax(sequence_logits);
        let (max_class_idx, &max_score) = probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or((0, &0.0));
        let words: Vec<&str> = original_text.split_whitespace().collect();
        // Track a running search offset so repeated words map to their own character spans.
        let mut search_from = 0usize;
        for (word_idx, word) in words.iter().enumerate() {
            if max_score > 0.3 && word.chars().next().map(|c| c.is_uppercase()).unwrap_or(false) {
                let label = if max_class_idx < self.labels.len() {
                    self.labels[max_class_idx].as_str()
                } else {
                    "B-MISC"
                };
                if label != "O" {
                    let word_start = original_text[search_from..]
                        .find(word)
                        .map(|pos| pos + search_from)
                        .unwrap_or(search_from);
                    let word_end = word_start + word.len();
                    search_from = word_end.min(original_text.len());
                    token_outputs.push(TokenOutput {
                        entity: label.to_string(),
                        // Damp the sentence-level confidence; it is not a true per-token score.
                        score: max_score * 0.8,
                        index: word_idx,
                        word: word.to_string(),
                        start: word_start,
                        end: word_end,
                    });
                }
            }
        }
        Ok(token_outputs)
    }

    /// Heuristic fallback for models without a token classification head: tags capitalized
    /// words as `B-MISC` with a fixed confidence.
    fn fallback_token_classification(
        &self,
        text: &str,
        _inputs: &crate::core::traits::TokenizedInput,
    ) -> Result<Vec<PipelineTokenOutput>> {
        let mut results = Vec::new();
        // Regex-style patterns kept for reference only; they are not applied below.
        let _patterns = [
            (r"[A-Z][a-z]+ [A-Z][a-z]+", "B-PER"),
            (r"[A-Z][a-z]+ Inc\.|Corp\.|LLC", "B-ORG"),
            (r"[A-Z][a-z]+, [A-Z][A-Z]", "B-LOC"),
        ];
        let words: Vec<&str> = text.split_whitespace().collect();
        // Track a running search offset so repeated words map to their own character spans.
        let mut search_from = 0usize;
        for (i, word) in words.iter().enumerate() {
            if word.chars().next().map(|c| c.is_uppercase()).unwrap_or(false) {
                let start = text[search_from..]
                    .find(word)
                    .map(|pos| pos + search_from)
                    .unwrap_or(search_from);
                let end = start + word.len();
                search_from = end.min(text.len());
                results.push(PipelineTokenOutput {
                    entity: "B-MISC".to_string(),
                    score: 0.6,
                    index: i,
                    word: word.to_string(),
                    start,
                    end,
                });
            }
        }
        Ok(results)
    }

    /// Numerically stable softmax over a slice of logits.
    fn softmax(&self, logits: &[f32]) -> Vec<f32> {
let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
exp_logits.iter().map(|&x| x / sum_exp).collect()
}

    /// BIO-style aggregation: `B-X` opens a span, `I-X` of the same type extends it, and any
    /// other tag closes it. An `I-` token with no open span is dropped.
    fn simple_entity_aggregation(&self, tokens: Vec<TokenOutput>) -> Vec<PipelineTokenOutput> {
let mut aggregated = Vec::new();
let mut current_entity: Option<PipelineTokenOutput> = None;
for token in tokens {
if token.entity.starts_with("B-") {
if let Some(entity) = current_entity.take() {
aggregated.push(entity);
}
                current_entity = Some(PipelineTokenOutput {
                    entity: token.entity[2..].to_string(),
                    score: token.score,
index: token.index,
word: token.word,
start: token.start,
end: token.end,
});
} else if token.entity.starts_with("I-") {
if let Some(ref mut entity) = current_entity {
if entity.entity == token.entity[2..] {
entity.word = format!("{} {}", entity.word, token.word);
entity.end = token.end;
                        entity.score = (entity.score + token.score) / 2.0;
                    } else {
                        // The entity type changed mid-span: close the current span and start a new one.
                        aggregated.push(current_entity.take().expect("inside if let Some"));
current_entity = Some(PipelineTokenOutput {
entity: token.entity[2..].to_string(),
score: token.score,
index: token.index,
word: token.word,
start: token.start,
end: token.end,
});
}
}
} else if token.entity != "O" {
if let Some(entity) = current_entity.take() {
aggregated.push(entity);
}
aggregated.push(PipelineTokenOutput {
entity: token.entity,
score: token.score,
index: token.index,
word: token.word,
start: token.start,
end: token.end,
});
}
}
if let Some(entity) = current_entity {
aggregated.push(entity);
}
aggregated
}

    /// Keeps only the first token seen for each entity type (BIO prefixes stripped).
    fn first_entity_aggregation(&self, tokens: Vec<TokenOutput>) -> Vec<PipelineTokenOutput> {
let mut seen_entities = HashMap::new();
let mut results = Vec::new();
for token in tokens {
let entity_type = if token.entity.starts_with("B-") || token.entity.starts_with("I-") {
&token.entity[2..]
} else {
&token.entity
};
if !seen_entities.contains_key(entity_type) {
seen_entities.insert(entity_type.to_string(), true);
results.push(PipelineTokenOutput {
entity: entity_type.to_string(),
score: token.score,
index: token.index,
word: token.word,
start: token.start,
end: token.end,
});
}
}
results
}

    /// Groups tokens by entity type and reports the mean score over each group. Distinct
    /// mentions of the same type are merged into one result, and result order follows
    /// `HashMap` iteration order, so it is unspecified.
    fn average_entity_aggregation(&self, tokens: Vec<TokenOutput>) -> Vec<PipelineTokenOutput> {
let mut entity_groups: HashMap<String, Vec<TokenOutput>> = HashMap::new();
for token in tokens {
let entity_type = if token.entity.starts_with("B-") || token.entity.starts_with("I-") {
token.entity[2..].to_string()
} else {
token.entity.clone()
};
entity_groups.entry(entity_type).or_default().push(token);
}
let mut results = Vec::new();
for (entity_type, group) in entity_groups {
if !group.is_empty() {
let avg_score = group.iter().map(|t| t.score).sum::<f32>() / group.len() as f32;
let first_token = &group[0];
let last_token = &group[group.len() - 1];
let combined_word =
group.iter().map(|t| t.word.as_str()).collect::<Vec<_>>().join(" ");
results.push(PipelineTokenOutput {
entity: entity_type,
score: avg_score,
index: first_token.index,
word: combined_word,
start: first_token.start,
end: last_token.end,
});
}
}
results
}

    /// Groups tokens by entity type and keeps only the highest-scoring token per group.
    fn max_entity_aggregation(&self, tokens: Vec<TokenOutput>) -> Vec<PipelineTokenOutput> {
let mut entity_groups: HashMap<String, Vec<TokenOutput>> = HashMap::new();
for token in tokens {
let entity_type = if token.entity.starts_with("B-") || token.entity.starts_with("I-") {
token.entity[2..].to_string()
} else {
token.entity.clone()
};
entity_groups.entry(entity_type).or_default().push(token);
}
let mut results = Vec::new();
for (entity_type, group) in entity_groups {
if let Some(max_token) = group
.iter()
.max_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
{
results.push(PipelineTokenOutput {
entity: entity_type,
score: max_token.score,
index: max_token.index,
word: max_token.word.clone(),
start: max_token.start,
end: max_token.end,
});
}
}
results
}
}
impl Pipeline for TokenClassificationPipeline {
type Input = String;
type Output = PipelineOutput;
fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
let results = self.classify_tokens(&input)?;
Ok(PipelineOutput::TokenClassification(results))
}
fn batch(&self, inputs: Vec<Self::Input>) -> Result<Vec<Self::Output>> {
let batch_results = self.classify_tokens_batch(&inputs)?;
Ok(batch_results.into_iter().map(PipelineOutput::TokenClassification).collect())
}
}
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl crate::pipeline::AsyncPipeline for TokenClassificationPipeline {
type Input = String;
type Output = PipelineOutput;
async fn __call_async__(&self, input: Self::Input) -> Result<Self::Output> {
let pipeline = self.clone();
tokio::task::spawn_blocking(move || pipeline.__call__(input))
.await
.map_err(|e| TrustformersError::pipeline(e.to_string(), "runtime"))?
}
}
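
// Illustrative usage sketch (comment only, not compiled). The loading calls below,
// `AutoModel::from_pretrained` and `AutoTokenizer::from_pretrained`, are assumed names for
// whatever constructors the crate actually exposes; substitute the real loading API.
//
//     let model = AutoModel::from_pretrained("path/to/ner-model")?;
//     let tokenizer = AutoTokenizer::from_pretrained("path/to/ner-model")?;
//     let pipeline = TokenClassificationPipeline::new(model, tokenizer)?
//         .with_aggregation_strategy(AggregationStrategy::Simple);
//     match pipeline.__call__("John Smith works at ACME Corp.".to_string())? {
//         PipelineOutput::TokenClassification(entities) => {
//             for entity in entities {
//                 println!("{} [{}] {:.2}", entity.word, entity.entity, entity.score);
//             }
//         }
//         _ => {}
//     }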
#[cfg(test)]
mod tests {
use super::*;
fn make_token(entity: &str, score: f32, idx: usize, word: &str, start: usize) -> TokenOutput {
let end = start + word.len();
TokenOutput {
entity: entity.to_string(),
score,
index: idx,
word: word.to_string(),
start,
end,
}
}
fn make_pipeline_token(
entity: &str,
score: f32,
idx: usize,
word: &str,
start: usize,
) -> PipelineTokenOutput {
let end = start + word.len();
PipelineTokenOutput {
entity: entity.to_string(),
score,
index: idx,
word: word.to_string(),
start,
end,
}
}
fn softmax(logits: &[f32]) -> Vec<f32> {
let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
exp_logits.iter().map(|&x| x / sum_exp).collect()
}
#[test]
fn test_config_default_values() {
let cfg = TokenClassificationConfig::default();
assert_eq!(cfg.max_length, 512);
assert_eq!(cfg.aggregation_strategy, "simple");
assert!(cfg.ignore_labels.contains(&"O".to_string()));
}
#[test]
fn test_config_clone() {
let cfg = TokenClassificationConfig {
max_length: 256,
..TokenClassificationConfig::default()
};
assert_eq!(cfg.clone().max_length, 256);
}
#[test]
fn test_b_tag_starts_new_entity() {
assert!("B-PER".starts_with("B-"));
assert!("B-ORG".starts_with("B-"));
assert!("B-LOC".starts_with("B-"));
}
#[test]
fn test_i_tag_continues_entity() {
assert!("I-PER".starts_with("I-"));
assert!("I-ORG".starts_with("I-"));
}
#[test]
fn test_o_tag_is_outside() {
assert_eq!("O", "O");
assert!(!"O".starts_with("B-"));
assert!(!"O".starts_with("I-"));
}
#[test]
fn test_bio_entity_type_extraction() {
assert_eq!(&"B-PER"[2..], "PER");
assert_eq!(&"I-ORG"[2..], "ORG");
assert_eq!(&"B-LOC"[2..], "LOC");
}
fn simple_aggregate(tokens: Vec<TokenOutput>) -> Vec<PipelineTokenOutput> {
let mut aggregated: Vec<PipelineTokenOutput> = Vec::new();
let mut current: Option<PipelineTokenOutput> = None;
for token in tokens {
if token.entity.starts_with("B-") {
if let Some(e) = current.take() {
aggregated.push(e);
}
current = Some(PipelineTokenOutput {
entity: token.entity[2..].to_string(),
score: token.score,
index: token.index,
word: token.word.clone(),
start: token.start,
end: token.end,
});
} else if token.entity.starts_with("I-") {
if let Some(ref mut e) = current {
if e.entity == token.entity[2..] {
e.word = format!("{} {}", e.word, token.word);
e.end = token.end;
e.score = (e.score + token.score) / 2.0;
}
}
} else if token.entity != "O" {
if let Some(e) = current.take() {
aggregated.push(e);
}
aggregated.push(PipelineTokenOutput {
entity: token.entity.clone(),
score: token.score,
index: token.index,
word: token.word.clone(),
start: token.start,
end: token.end,
});
}
}
if let Some(e) = current {
aggregated.push(e);
}
aggregated
}
#[test]
fn test_simple_aggregate_single_b_tag() {
let tokens = vec![make_token("B-PER", 0.9, 0, "Alice", 0)];
let result = simple_aggregate(tokens);
assert_eq!(result.len(), 1);
assert_eq!(result[0].entity, "PER");
assert_eq!(result[0].word, "Alice");
}
#[test]
fn test_simple_aggregate_b_i_merged() {
let tokens = vec![
make_token("B-PER", 0.9, 0, "John", 0),
make_token("I-PER", 0.85, 1, "Smith", 5),
];
let result = simple_aggregate(tokens);
assert_eq!(result.len(), 1);
assert_eq!(result[0].entity, "PER");
assert!(result[0].word.contains("John"));
assert!(result[0].word.contains("Smith"));
}
#[test]
fn test_simple_aggregate_two_separate_entities() {
let tokens = vec![
make_token("B-PER", 0.9, 0, "Alice", 0),
make_token("B-ORG", 0.8, 1, "ACME", 6),
];
let result = simple_aggregate(tokens);
assert_eq!(result.len(), 2);
}
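    #[test]
    fn test_simple_aggregate_new_b_tag_closes_previous_entity() {
        // A fresh B- tag after a B-/I- run closes the current entity and starts a new one.
        let tokens = vec![
            make_token("B-PER", 0.9, 0, "John", 0),
            make_token("I-PER", 0.85, 1, "Smith", 5),
            make_token("B-ORG", 0.8, 2, "ACME", 11),
        ];
        let result = simple_aggregate(tokens);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].entity, "PER");
        assert_eq!(result[0].word, "John Smith");
        assert_eq!(result[1].entity, "ORG");
    }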
#[test]
fn test_simple_aggregate_o_tokens_ignored() {
let tokens = vec![
make_token("O", 0.99, 0, "the", 0),
make_token("B-PER", 0.9, 1, "Alice", 4),
];
let result = simple_aggregate(tokens);
assert_eq!(result.len(), 1);
assert_eq!(result[0].entity, "PER");
}
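    #[test]
    fn test_simple_aggregate_dangling_i_tag_dropped() {
        // An I- token with no preceding B- token has no open span to attach to,
        // so the simple strategy used by this helper drops it.
        let tokens = vec![make_token("I-PER", 0.9, 0, "Smith", 0)];
        let result = simple_aggregate(tokens);
        assert!(result.is_empty());
    }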
#[test]
fn test_confidence_score_range() {
let token = make_pipeline_token("B-PER", 0.87, 0, "Alice", 0);
assert!(token.score >= 0.0 && token.score <= 1.0);
}
#[test]
fn test_average_score_after_merge() {
let tokens = vec![
make_token("B-PER", 0.8, 0, "John", 0),
make_token("I-PER", 0.6, 1, "Doe", 5),
];
let result = simple_aggregate(tokens);
assert_eq!(result.len(), 1);
let expected_score = (0.8 + 0.6) / 2.0;
assert!((result[0].score - expected_score).abs() < 1e-5);
}
#[test]
fn test_span_start_end_consistency() {
let token = make_pipeline_token("B-ORG", 0.9, 2, "OpenAI", 10);
assert_eq!(token.end - token.start, "OpenAI".len());
}
#[test]
fn test_merged_span_covers_full_range() {
let tokens = vec![
make_token("B-PER", 0.9, 0, "John", 0), make_token("I-PER", 0.85, 1, "Smith", 5), ];
let result = simple_aggregate(tokens);
assert_eq!(result[0].start, 0);
assert_eq!(result[0].end, 10);
}
#[test]
fn test_aggregation_strategy_variants_constructable() {
let _none = AggregationStrategy::None;
let _simple = AggregationStrategy::Simple;
let _first = AggregationStrategy::First;
let _avg = AggregationStrategy::Average;
let _max = AggregationStrategy::Max;
}
#[test]
fn test_first_entity_aggregation() {
let tokens = vec![
make_token("B-PER", 0.9, 0, "Alice", 0),
make_token("B-PER", 0.8, 1, "Bob", 6),
];
let mut seen: HashMap<String, bool> = HashMap::new();
let mut results = Vec::new();
for token in tokens {
let etype = if token.entity.starts_with("B-") || token.entity.starts_with("I-") {
token.entity[2..].to_string()
} else {
token.entity.clone()
};
if !seen.contains_key(&etype) {
seen.insert(etype.clone(), true);
results.push(make_pipeline_token(
&etype,
token.score,
token.index,
&token.word,
token.start,
));
}
}
assert_eq!(results.len(), 1);
assert_eq!(results[0].word, "Alice");
}
#[test]
fn test_max_entity_aggregation_picks_highest_score() {
let tokens = vec![
make_token("B-PER", 0.6, 0, "Alice", 0),
make_token("B-PER", 0.95, 1, "Bob", 6),
];
let mut groups: HashMap<String, Vec<TokenOutput>> = HashMap::new();
for token in tokens {
let etype = if token.entity.starts_with("B-") || token.entity.starts_with("I-") {
token.entity[2..].to_string()
} else {
token.entity.clone()
};
groups.entry(etype).or_default().push(token);
}
let mut results = Vec::new();
for (etype, group) in groups {
if let Some(max_t) = group
.iter()
.max_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
{
results.push(make_pipeline_token(
&etype,
max_t.score,
max_t.index,
&max_t.word,
max_t.start,
));
}
}
assert_eq!(results.len(), 1);
assert!((results[0].score - 0.95).abs() < 1e-5);
}
#[test]
fn test_average_entity_aggregation() {
let tokens = [
make_token("B-ORG", 0.8, 0, "OpenAI", 0),
make_token("I-ORG", 0.6, 1, "Inc", 7),
];
let group_score: f32 = tokens.iter().map(|t| t.score).sum::<f32>() / tokens.len() as f32;
assert!((group_score - 0.7).abs() < 1e-5);
}
#[test]
fn test_softmax_sums_to_one() {
let logits = vec![1.0, 2.0, 0.5, -1.0];
let probs = softmax(&logits);
let sum: f32 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
#[test]
fn test_softmax_highest_logit_has_highest_prob() {
let logits = vec![0.0, 5.0, 1.0];
let probs = softmax(&logits);
assert!(probs[1] > probs[0] && probs[1] > probs[2]);
}
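    #[test]
    fn test_softmax_uniform_for_equal_logits() {
        // Equal logits should map to a uniform probability distribution.
        let logits = vec![2.0, 2.0, 2.0, 2.0];
        let probs = softmax(&logits);
        for p in &probs {
            assert!((*p - 0.25).abs() < 1e-5);
        }
    }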
    #[test]
    fn test_default_labels_include_bio_tags() {
        let labels = [
            "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
        ];
        // Every non-"O" label carries a BIO prefix, and every "I-X" has a matching "B-X".
        for label in &labels {
            assert!(*label == "O" || label.starts_with("B-") || label.starts_with("I-"));
            if let Some(entity) = label.strip_prefix("I-") {
                assert!(labels.contains(&format!("B-{}", entity).as_str()));
            }
        }
    }
#[test]
fn test_label_count_nine() {
let labels: Vec<String> = vec![
"O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
]
.into_iter()
.map(|s| s.to_string())
.collect();
assert_eq!(labels.len(), 9);
}
}