use crate::common::error::RustBertError;
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
use serde::{Deserialize, Serialize};
use crate::pipelines::common::TokenizerOption;
#[cfg(feature = "remote")]
use {
crate::{
mobilebert::{
MobileBertConfigResources, MobileBertModelResources, MobileBertVocabResources,
},
pipelines::{
common::{ModelResource, ModelType},
token_classification::LabelAggregationOption,
},
resources::RemoteResource,
},
tch::Device,
};
/// # Part-of-speech tag for a single word.
///
/// `Clone` and `PartialEq` are derived in addition to `Debug` so callers can
/// duplicate and compare predictions without re-running the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct POSTag {
    /// String representation of the tagged word
    pub word: String,
    /// Confidence score of the predicted label
    pub score: f64,
    /// Part-of-speech label assigned by the model
    pub label: String,
}
// # Configuration for part-of-speech tagging.
// Thin newtype over a `TokenClassificationConfig`; build one via
// `POSConfig::default()` (remote feature) or `From<TokenClassificationConfig>`.
pub struct POSConfig {
// Underlying token classification pipeline configuration (private;
// extracted again via `From<POSConfig> for TokenClassificationConfig`)
token_classification_config: TokenClassificationConfig,
}
#[cfg(feature = "remote")]
impl Default for POSConfig {
// Default configuration: MobileBert fine-tuned for English POS tagging,
// with model/config/vocab resources fetched as pretrained remote resources.
fn default() -> POSConfig {
POSConfig {
token_classification_config: TokenClassificationConfig {
model_type: ModelType::MobileBert,
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
MobileBertModelResources::MOBILEBERT_ENGLISH_POS,
))),
config_resource: Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_ENGLISH_POS,
)),
vocab_resource: Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_ENGLISH_POS,
)),
// No merges file for this model's tokenizer.
merges_resource: None,
// Tokenizer pre-processing: lower-case input, strip accents.
lower_case: true,
strip_accents: Some(true),
add_prefix_space: None,
// Run on GPU when available, otherwise fall back to CPU.
device: Device::cuda_if_available(),
kind: None,
// Sub-token label aggregation strategy — presumably keeps the label of
// the first sub-token of each word (see `LabelAggregationOption`).
label_aggregation_function: LabelAggregationOption::First,
batch_size: 64,
},
}
}
}
// Wrap a fully specified token classification configuration as a `POSConfig`.
impl From<TokenClassificationConfig> for POSConfig {
fn from(token_classification_config: TokenClassificationConfig) -> Self {
POSConfig {
token_classification_config,
}
}
}
// Unwrap a `POSConfig` back into the underlying token classification
// configuration (used when constructing the inner `TokenClassificationModel`).
impl From<POSConfig> for TokenClassificationConfig {
fn from(pos_config: POSConfig) -> Self {
pos_config.token_classification_config
}
}
// # Part-of-speech tagging model.
// Thin wrapper around a `TokenClassificationModel` that converts its token
// predictions into `POSTag` values (see `POSModel::predict`).
pub struct POSModel {
// Inner pipeline performing the actual token classification
token_classification_model: TokenClassificationModel,
}
impl POSModel {
    /// Build a new `POSModel` from a configuration.
    ///
    /// # Arguments
    ///
    /// * `pos_config` - `POSConfig` referencing the model, configuration and
    ///   vocabulary resources and the device placement
    ///
    /// # Errors
    ///
    /// Returns a `RustBertError` if the underlying token classification model
    /// cannot be created from the configuration.
    pub fn new(pos_config: POSConfig) -> Result<POSModel, RustBertError> {
        let model = TokenClassificationModel::new(pos_config.into())?;
        Ok(POSModel {
            token_classification_model: model,
        })
    }

    /// Build a new `POSModel` with a caller-provided tokenizer.
    ///
    /// # Arguments
    ///
    /// * `pos_config` - `POSConfig` referencing the model resources
    /// * `tokenizer` - `TokenizerOption` to use instead of the default tokenizer
    ///
    /// # Errors
    ///
    /// Returns a `RustBertError` if the underlying token classification model
    /// cannot be created from the configuration.
    pub fn new_with_tokenizer(
        pos_config: POSConfig,
        tokenizer: TokenizerOption,
    ) -> Result<POSModel, RustBertError> {
        let model = TokenClassificationModel::new_with_tokenizer(pos_config.into(), tokenizer)?;
        Ok(POSModel {
            token_classification_model: model,
        })
    }

    /// Get a reference to the model's tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        self.token_classification_model.get_tokenizer()
    }

    /// Get a mutable reference to the model's tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        self.token_classification_model.get_tokenizer_mut()
    }

    /// Extract part-of-speech tags for the tokens of each input text.
    ///
    /// # Arguments
    ///
    /// * `input` - slice of texts to tag
    ///
    /// # Returns
    ///
    /// One `Vec<POSTag>` per input text, in input order.
    pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<POSTag>>
    where
        S: AsRef<str>,
    {
        self.token_classification_model
            .predict(input, true, false)
            .into_iter()
            .map(|sequence_tokens| {
                sequence_tokens
                    .into_iter()
                    // Single pass: apply the punctuation override and convert
                    // to `POSTag` in the same closure (the original did this in
                    // two consecutive `map`s over the same tokens).
                    .map(|mut token| {
                        // Pure-punctuation tokens predicted with low confidence
                        // (or a NaN score) are relabeled with the generic
                        // punctuation tag "." at full confidence.
                        // Fixed: uses short-circuiting `&&`/`||` instead of the
                        // original bitwise `&`/`|`, which always evaluated both
                        // operands.
                        if Self::is_punctuation(token.text.as_str())
                            && (token.score < 0.5 || token.score.is_nan())
                        {
                            token.label = String::from(".");
                            token.score = 1f64;
                        }
                        POSTag {
                            word: token.text,
                            score: token.score,
                            label: token.label,
                        }
                    })
                    .collect::<Vec<POSTag>>()
            })
            .collect::<Vec<Vec<POSTag>>>()
    }

    /// Returns `true` if `string` is non-empty and consists entirely of ASCII
    /// punctuation characters.
    ///
    /// The explicit empty-string guard fixes an edge case: `all` on an empty
    /// iterator is vacuously `true`, so an empty token text would previously
    /// have been treated as punctuation and relabeled.
    fn is_punctuation(string: &str) -> bool {
        !string.is_empty() && string.chars().all(|c| c.is_ascii_punctuation())
    }
}
#[cfg(test)]
mod test {
use super::*;
#[test]
// Compile-time check that `POSModel` is `Send` (usable across threads);
// ignored by default because constructing the default config downloads
// remote model resources.
#[ignore] fn test() {
let config = POSConfig::default();
let _: Box<dyn Send> = Box::new(POSModel::new(config));
}
}