use crate::common::error::RustBertError;
use crate::pipelines::common::TokenizerOption;
use crate::pipelines::token_classification::{
Token, TokenClassificationConfig, TokenClassificationModel,
};
use rust_tokenizers::Offset;
use serde::{Deserialize, Serialize};
/// Entity detected in an input sequence by the NER pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Surface text of the entity (joined token texts for consolidated entities).
    pub word: String,
    /// Confidence score (product of per-token scores for consolidated entities).
    pub score: f64,
    /// Entity class label (the part of the token label after the tag prefix).
    pub label: String,
    /// Character offset span of the entity in the source string.
    pub offset: Offset,
}
/// Configuration for the NER pipeline — identical to the generic
/// token-classification configuration.
type NERConfig = TokenClassificationConfig;
/// Named-entity-recognition pipeline wrapping a generic token-classification model.
pub struct NERModel {
    // Underlying model that performs per-token label prediction.
    token_classification_model: TokenClassificationModel,
}
impl NERModel {
    /// Builds a new `NERModel` from a token-classification configuration.
    ///
    /// # Errors
    /// Returns a `RustBertError` if the underlying `TokenClassificationModel`
    /// cannot be created.
    pub fn new(ner_config: NERConfig) -> Result<NERModel, RustBertError> {
        let model = TokenClassificationModel::new(ner_config)?;
        Ok(NERModel {
            token_classification_model: model,
        })
    }

    /// Builds a new `NERModel` using a caller-supplied tokenizer.
    ///
    /// # Errors
    /// Returns a `RustBertError` if the underlying `TokenClassificationModel`
    /// cannot be created.
    pub fn new_with_tokenizer(
        ner_config: NERConfig,
        tokenizer: TokenizerOption,
    ) -> Result<NERModel, RustBertError> {
        let model = TokenClassificationModel::new_with_tokenizer(ner_config, tokenizer)?;
        Ok(NERModel {
            token_classification_model: model,
        })
    }

    /// Shared reference to the underlying tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        self.token_classification_model.get_tokenizer()
    }

    /// Mutable reference to the underlying tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        self.token_classification_model.get_tokenizer_mut()
    }

    /// Extracts one `Entity` per labelled token from each input sequence.
    ///
    /// Tokens labelled `"O"` (outside any entity) are dropped. Labelled
    /// tokens that carry no character offset are skipped instead of
    /// panicking (the previous implementation called `unwrap` on
    /// `token.offset`, which would abort on e.g. special tokens).
    pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<Entity>>
    where
        S: AsRef<str>,
    {
        self.token_classification_model
            .predict(input, true, false)
            .into_iter()
            .map(|sequence_tokens| {
                sequence_tokens
                    .into_iter()
                    .filter(|token| token.label != "O")
                    // `filter_map` + `?` silently skips labelled tokens that
                    // have no offset rather than panicking.
                    .filter_map(|token| {
                        Some(Entity {
                            offset: token.offset?,
                            word: token.text,
                            score: token.score,
                            label: token.label,
                        })
                    })
                    .collect::<Vec<Entity>>()
            })
            .collect::<Vec<Vec<Entity>>>()
    }

    /// Extracts consolidated entities: consecutive tokens belonging to the
    /// same entity (per their IOBES tags) are merged into a single `Entity`.
    pub fn predict_full_entities<S>(&self, input: &[S]) -> Vec<Vec<Entity>>
    where
        S: AsRef<str>,
    {
        self.token_classification_model
            .predict(input, true, false)
            .iter()
            .map(|sequence_tokens| Self::consolidate_entities(sequence_tokens))
            .collect()
    }

    /// Merges a sequence of IOBES-tagged tokens into complete entities using
    /// an incremental `EntityBuilder`.
    fn consolidate_entities(tokens: &[Token]) -> Vec<Entity> {
        let mut entities: Vec<Entity> = Vec::new();
        let mut entity_builder = EntityBuilder::new();
        for (position, token) in tokens.iter().enumerate() {
            let tag = token.get_tag();
            let label = token.get_label();
            if let Some(entity) = entity_builder.handle_current_tag(tag, label, position, tokens) {
                entities.push(entity)
            }
        }
        // Flush any entity still in progress at the end of the sequence.
        if let Some(entity) = entity_builder.flush_and_reset(tokens.len(), tokens) {
            entities.push(entity);
        }
        entities
    }
}
/// Incremental accumulator that merges consecutive IOBES-tagged tokens into
/// full entities.
struct EntityBuilder<'a> {
    // (start position, tag, label) of the entity currently being built,
    // or `None` when no entity is in progress.
    previous_node: Option<(usize, Tag, &'a str)>,
}
impl<'a> EntityBuilder<'a> {
    /// Creates a builder with no entity in progress.
    fn new() -> Self {
        EntityBuilder {
            previous_node: None,
        }
    }

    /// Feeds the tag/label of the token at `position` into the builder.
    /// Returns a finished `Entity` whenever the current token closes one.
    fn handle_current_tag(
        &mut self,
        tag: Tag,
        label: &'a str,
        position: usize,
        tokens: &[Token],
    ) -> Option<Entity> {
        match tag {
            // "O" ends any in-progress entity.
            Tag::Outside => self.flush_and_reset(position, tokens),
            // "B"/"S" end the previous entity and start a new one.
            Tag::Begin | Tag::Single => {
                let entity = self.flush_and_reset(position, tokens);
                self.start_new(position, tag, label);
                entity
            }
            Tag::Inside | Tag::End => {
                if let Some((_, previous_tag, previous_label)) = self.previous_node {
                    // An "I"/"E" token only continues the current entity if the
                    // previous tag allows continuation and the labels match;
                    // otherwise it implicitly starts a new entity.
                    if (previous_tag == Tag::End)
                        | (previous_tag == Tag::Single)
                        | (previous_label != label)
                    {
                        let entity = self.flush_and_reset(position, tokens);
                        self.start_new(position, tag, label);
                        entity
                    } else {
                        None
                    }
                } else {
                    self.start_new(position, tag, label);
                    None
                }
            }
        }
    }

    /// Finishes the in-progress entity (covering tokens `start..position`),
    /// if any, and clears the builder state.
    ///
    /// Bug fix: the state is now cleared with `Option::take()` *before* any
    /// `?` early-return. Previously, a token with a missing offset made the
    /// `?` operators return `None` before `self.previous_node` was reset,
    /// leaving stale state that corrupted consolidation of later entities.
    fn flush_and_reset(&mut self, position: usize, tokens: &[Token]) -> Option<Entity> {
        // `take()` resets `previous_node` to `None` unconditionally.
        let (start, _, label) = self.previous_node.take()?;
        let entity_tokens = &tokens[start..position];
        Some(Entity {
            word: entity_tokens
                .iter()
                .map(|token| token.text.as_str())
                .collect::<Vec<&str>>()
                .join(" "),
            // Entity confidence is the product of its token scores.
            score: entity_tokens.iter().map(|token| token.score).product(),
            label: label.to_string(),
            offset: Offset {
                begin: entity_tokens.first()?.offset?.begin,
                end: entity_tokens.last()?.offset?.end,
            },
        })
    }

    /// Records the start of a new entity at `position`.
    fn start_new(&mut self, position: usize, tag: Tag, label: &'a str) {
        self.previous_node = Some((position, tag, label))
    }
}
/// IOBES tagging-scheme prefix parsed from a token label.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Tag {
    /// "B-": first token of a multi-token entity.
    Begin,
    /// "I-": interior token of an entity.
    Inside,
    /// "O": token outside any entity.
    Outside,
    /// "E-": last token of an entity.
    End,
    /// "S-": single-token entity.
    Single,
}
impl Token {
    /// Parses the IOBES tag prefix from this token's label
    /// (e.g. "B-PER" -> `Tag::Begin`).
    ///
    /// # Panics
    /// Panics if the label does not start with one of `B`, `I`, `O`, `E`, `S`.
    fn get_tag(&self) -> Tag {
        // Only the text before the first '-' matters; `next()` avoids
        // collecting the full split into a `Vec`.
        match self.label.split('-').next().unwrap_or("") {
            "B" => Tag::Begin,
            "I" => Tag::Inside,
            "O" => Tag::Outside,
            "E" => Tag::End,
            "S" => Tag::Single,
            _ => panic!("Invalid tag encountered for token {:?}", self),
        }
    }

    /// Returns the entity class of the label ("B-PER" -> "PER"), or "" for
    /// labels with no '-' (such as "O").
    ///
    /// `splitn(2, '-')` keeps the entire remainder after the first dash, so
    /// classes that themselves contain a dash (e.g. "I-CREATIVE-WORK" ->
    /// "CREATIVE-WORK") are no longer truncated, as they were by the previous
    /// `split` + index implementation.
    fn get_label(&self) -> &str {
        self.label.splitn(2, '-').nth(1).unwrap_or("")
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: the result of constructing an `NERModel` is `Send`.
    /// Ignored by default since actually building the model requires
    /// external resources.
    #[test]
    #[ignore]
    fn test() {
        let ner_config = NERConfig::default();
        let _: Box<dyn Send> = Box::new(NERModel::new(ner_config));
    }
}