use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
};
use rust_bert::resources::{RemoteResource, Resource};
/// Runs a BERT-based token classification model over two sample sentences
/// and prints every returned token using its `Debug` representation.
///
/// # Errors
///
/// Propagates the error returned by `TokenClassificationModel::new` if the
/// model cannot be constructed from the configuration.
fn main() -> failure::Fallible<()> {
    // The three remote resources for the pretrained BERT NER model, created in
    // the order the configuration constructor expects them.
    let model_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));

    // NOTE(review): the `None` and `false` arguments are presumably the optional
    // merges resource and the lower-casing flag -- confirm against the
    // `TokenClassificationConfig::new` signature of the rust-bert version in use.
    let ner_config = TokenClassificationConfig::new(
        ModelType::Bert,
        model_resource,
        config_resource,
        vocab_resource,
        None,
        false,
        LabelAggregationOption::Mode,
    );

    let ner_model = TokenClassificationModel::new(ner_config)?;

    // Sample sentences, deliberately including non-ASCII characters.
    let sentences = [
        "My name is Amélie. I live in Москва.",
        "Chongqing is a city in China.",
    ];

    // NOTE(review): the two boolean flags presumably control sub-token
    // consolidation and whether special tokens are returned -- confirm against
    // the `TokenClassificationModel::predict` documentation.
    ner_model
        .predict(&sentences, true, false)
        .into_iter()
        .for_each(|token| println!("{:?}", token));

    Ok(())
}