// rust_bert/pipelines/ner.rs
1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright (c) 2018 chakki (https://github.com/chakki-works/seqeval/blob/master/seqeval/metrics/sequence_labeling.py)
3// Copyright 2019 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! # Named Entity Recognition pipeline
15//! Extracts entities (Person, Location, Organization, Miscellaneous) from text.
16//! Pretrained models are available for the following languages:
17//! - English
18//! - German
19//! - Spanish
20//! - Dutch
21//!
//! The default NER model is an English BERT cased large model fine-tuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
23//! All resources for this model can be downloaded using the Python utility script included in this repository.
24//! 1. Set-up a Python virtual environment and install dependencies (in ./requirements.txt)
25//! 2. Run the conversion script python /utils/download-dependencies_bert_ner.py.
26//! The dependencies will be downloaded to the user's home directory, under ~/rustbert/bert-ner
27//!
//! The example below illustrates how to run the model for the default English NER model
29//! ```no_run
30//! use rust_bert::pipelines::ner::NERModel;
31//! # fn main() -> anyhow::Result<()> {
32//! let ner_model = NERModel::new(Default::default())?;
33//!
34//! let input = [
35//! "My name is Amy. I live in Paris.",
36//! "Paris is a city in France.",
37//! ];
38//! let output = ner_model.predict(&input);
39//! # Ok(())
40//! # }
41//! ```
42//! Output: \
43//! ```no_run
44//! # use rust_bert::pipelines::ner::Entity;
45//! # use rust_tokenizers::Offset;
46//! # let output =
47//! [
48//! [
49//! Entity {
50//! word: String::from("Amy"),
51//! score: 0.9986,
52//! label: String::from("I-PER"),
53//! offset: Offset { begin: 11, end: 14 },
54//! },
55//! Entity {
56//! word: String::from("Paris"),
57//! score: 0.9985,
58//! label: String::from("I-LOC"),
59//! offset: Offset { begin: 26, end: 31 },
60//! },
61//! ],
62//! [
63//! Entity {
64//! word: String::from("Paris"),
65//! score: 0.9988,
66//! label: String::from("I-LOC"),
67//! offset: Offset { begin: 0, end: 5 },
68//! },
69//! Entity {
70//! word: String::from("France"),
71//! score: 0.9993,
72//! label: String::from("I-LOC"),
73//! offset: Offset { begin: 19, end: 25 },
74//! },
75//! ],
76//! ]
77//! # ;
78//! ```
79//!
80//! To run the pipeline for another language, change the NERModel configuration from its default:
81//!
82//! ```no_run
83//! use rust_bert::pipelines::common::ModelType;
84//! use rust_bert::pipelines::ner::NERModel;
85//! use rust_bert::pipelines::token_classification::TokenClassificationConfig;
86//! use rust_bert::resources::RemoteResource;
87//! use rust_bert::roberta::{
88//! RobertaConfigResources, RobertaModelResources, RobertaVocabResources,
89//! };
90//! use tch::Device;
91//!
92//! # fn main() -> anyhow::Result<()> {
93//! use rust_bert::pipelines::common::ModelResource;
94//! let ner_config = TokenClassificationConfig {
95//! model_type: ModelType::XLMRoberta,
96//! model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
97//! RobertaModelResources::XLM_ROBERTA_NER_DE,
98//! ))),
99//! config_resource: Box::new(RemoteResource::from_pretrained(
100//! RobertaConfigResources::XLM_ROBERTA_NER_DE,
101//! )),
102//! vocab_resource: Box::new(RemoteResource::from_pretrained(
103//! RobertaVocabResources::XLM_ROBERTA_NER_DE,
104//! )),
105//! lower_case: false,
106//! device: Device::cuda_if_available(),
107//! ..Default::default()
108//! };
109//!
110//! let ner_model = NERModel::new(ner_config)?;
111//!
112//! // Define input
113//! let input = [
114//! "Mein Name ist Amélie. Ich lebe in Paris.",
115//! "Paris ist eine Stadt in Frankreich.",
116//! ];
117//! let output = ner_model.predict(&input);
118//! # Ok(())
119//! # }
120//! ```
121//! The XLMRoberta models for the languages are defined as follows:
122//!
123//! | **Language** |**Model name**|
124//! :-----:|:----:
125//! English| XLM_ROBERTA_NER_EN |
126//! German| XLM_ROBERTA_NER_DE |
127//! Spanish| XLM_ROBERTA_NER_ES |
128//! Dutch| XLM_ROBERTA_NER_NL |
129
130use crate::common::error::RustBertError;
131use crate::pipelines::common::TokenizerOption;
132use crate::pipelines::token_classification::{
133 Token, TokenClassificationConfig, TokenClassificationModel,
134};
135use rust_tokenizers::Offset;
136use serde::{Deserialize, Serialize};
137
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Entity generated by a `NERModel`
pub struct Entity {
    /// String representation of the Entity (token texts joined with spaces for consolidated entities)
    pub word: String,
    /// Confidence score (product of individual token scores for consolidated entities)
    pub score: f64,
    /// Entity label (e.g. ORG, LOC...); prefixed (e.g. `I-LOC`) from `predict`,
    /// prefix-stripped from `predict_full_entities`
    pub label: String,
    /// Start/end offsets of the entity in the input sequence
    pub offset: Offset,
}
150
// Type alias kept for backward compatibility: the NER pipeline shares its
// configuration with the generic token classification pipeline.
type NERConfig = TokenClassificationConfig;
153
/// # NERModel to extract named entities
pub struct NERModel {
    // Underlying generic token classification pipeline; NER post-processes its output.
    token_classification_model: TokenClassificationModel,
}
158
159impl NERModel {
160 /// Build a new `NERModel`
161 ///
162 /// # Arguments
163 ///
164 /// * `ner_config` - `NERConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
165 ///
166 /// # Example
167 ///
168 /// ```no_run
169 /// # fn main() -> anyhow::Result<()> {
170 /// use rust_bert::pipelines::ner::NERModel;
171 ///
172 /// let ner_model = NERModel::new(Default::default())?;
173 /// # Ok(())
174 /// # }
175 /// ```
176 pub fn new(ner_config: NERConfig) -> Result<NERModel, RustBertError> {
177 let model = TokenClassificationModel::new(ner_config)?;
178 Ok(NERModel {
179 token_classification_model: model,
180 })
181 }
182
183 /// Build a new `NERModel` with a provided tokenizer.
184 ///
185 /// # Arguments
186 ///
187 /// * `ner_config` - `NERConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
188 /// * `tokenizer` - `TokenizerOption` tokenizer to use for token classification
189 ///
190 /// # Example
191 ///
192 /// ```no_run
193 /// # fn main() -> anyhow::Result<()> {
194 /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
195 /// use rust_bert::pipelines::ner::NERModel;
196 /// let tokenizer = TokenizerOption::from_file(
197 /// ModelType::Bert,
198 /// "path/to/vocab.txt",
199 /// None,
200 /// false,
201 /// None,
202 /// None,
203 /// )?;
204 /// let ner_model = NERModel::new_with_tokenizer(Default::default(), tokenizer)?;
205 /// # Ok(())
206 /// # }
207 /// ```
208 pub fn new_with_tokenizer(
209 ner_config: NERConfig,
210 tokenizer: TokenizerOption,
211 ) -> Result<NERModel, RustBertError> {
212 let model = TokenClassificationModel::new_with_tokenizer(ner_config, tokenizer)?;
213 Ok(NERModel {
214 token_classification_model: model,
215 })
216 }
217
218 /// Get a reference to the model tokenizer.
219 pub fn get_tokenizer(&self) -> &TokenizerOption {
220 self.token_classification_model.get_tokenizer()
221 }
222
223 /// Get a mutable reference to the model tokenizer.
224 pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
225 self.token_classification_model.get_tokenizer_mut()
226 }
227
228 /// Extract entities from a text
229 ///
230 /// # Arguments
231 ///
232 /// * `input` - `&[&str]` Array of texts to extract entities from.
233 ///
234 /// # Returns
235 ///
236 /// * `Vec<Vec<Entity>>` containing extracted entities
237 ///
238 /// # Example
239 ///
240 /// ```no_run
241 /// # fn main() -> anyhow::Result<()> {
242 /// # use rust_bert::pipelines::ner::NERModel;
243 ///
244 /// let ner_model = NERModel::new(Default::default())?;
245 /// let input = [
246 /// "My name is Amy. I live in Paris.",
247 /// "Paris is a city in France.",
248 /// ];
249 /// let output = ner_model.predict(&input);
250 /// # Ok(())
251 /// # }
252 /// ```
253 pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<Entity>>
254 where
255 S: AsRef<str>,
256 {
257 self.token_classification_model
258 .predict(input, true, false)
259 .into_iter()
260 .map(|sequence_tokens| {
261 sequence_tokens
262 .into_iter()
263 .filter(|token| token.label != "O")
264 .map(|token| Entity {
265 offset: token.offset.unwrap(),
266 word: token.text,
267 score: token.score,
268 label: token.label,
269 })
270 .collect::<Vec<Entity>>()
271 })
272 .collect::<Vec<Vec<Entity>>>()
273 }
274
275 /// Extract full entities from a text performing entity chunking. Follows the algorithm for entities
276 /// chunking described in [Erik F. Tjong Kim Sang, Jorn Veenstra, Representing Text Chunks](https://www.aclweb.org/anthology/E99-1023/)
277 /// The proposed implementation is inspired by the [Python seqeval library](https://github.com/chakki-works/seqeval) (shared under MIT license).
278 ///
279 /// # Arguments
280 ///
281 /// * `input` - `&[&str]` Array of texts to extract entities from.
282 ///
283 /// # Returns
284 ///
285 /// * `Vec<Entity>` containing consolidated extracted entities
286 ///
287 /// # Example
288 ///
289 /// ```no_run
290 /// # fn main() -> anyhow::Result<()> {
291 /// # use rust_bert::pipelines::ner::NERModel;
292 ///
293 /// let ner_model = NERModel::new(Default::default())?;
294 /// let input = ["Asked John Smith about Acme Corp"];
295 /// let output = ner_model.predict_full_entities(&input);
296 /// # Ok(())
297 /// # }
298 /// ```
299 ///
300 /// Outputs:
301 ///
302 /// Output: \
303 /// ```no_run
304 /// # use rust_bert::pipelines::question_answering::Answer;
305 /// # use rust_bert::pipelines::ner::Entity;
306 /// # use rust_tokenizers::Offset;
307 /// # let output =
308 /// [[
309 /// Entity {
310 /// word: String::from("John Smith"),
311 /// score: 0.9747,
312 /// label: String::from("PER"),
313 /// offset: Offset { begin: 6, end: 16 },
314 /// },
315 /// Entity {
316 /// word: String::from("Acme Corp"),
317 /// score: 0.8847,
318 /// label: String::from("I-LOC"),
319 /// offset: Offset { begin: 23, end: 32 },
320 /// },
321 /// ]]
322 /// # ;
323 /// ```
324 pub fn predict_full_entities<S>(&self, input: &[S]) -> Vec<Vec<Entity>>
325 where
326 S: AsRef<str>,
327 {
328 let tokens = self.token_classification_model.predict(input, true, false);
329 let mut entities: Vec<Vec<Entity>> = Vec::new();
330
331 for sequence_tokens in tokens {
332 entities.push(Self::consolidate_entities(&sequence_tokens));
333 }
334 entities
335 }
336
337 fn consolidate_entities(tokens: &[Token]) -> Vec<Entity> {
338 let mut entities: Vec<Entity> = Vec::new();
339
340 let mut entity_builder = EntityBuilder::new();
341 for (position, token) in tokens.iter().enumerate() {
342 let tag = token.get_tag();
343 let label = token.get_label();
344 if let Some(entity) = entity_builder.handle_current_tag(tag, label, position, tokens) {
345 entities.push(entity)
346 }
347 }
348 if let Some(entity) = entity_builder.flush_and_reset(tokens.len(), tokens) {
349 entities.push(entity);
350 }
351 entities
352 }
353}
354
/// Incremental builder that consolidates consecutive token-level tags into
/// full entities while scanning a sequence left to right.
struct EntityBuilder<'a> {
    /// Start position, tag and label of the entity currently being built
    /// (`None` when no entity is in progress).
    previous_node: Option<(usize, Tag, &'a str)>,
}
358
359impl<'a> EntityBuilder<'a> {
360 fn new() -> Self {
361 EntityBuilder {
362 previous_node: None,
363 }
364 }
365
366 fn handle_current_tag(
367 &mut self,
368 tag: Tag,
369 label: &'a str,
370 position: usize,
371 tokens: &[Token],
372 ) -> Option<Entity> {
373 match tag {
374 Tag::Outside => self.flush_and_reset(position, tokens),
375 Tag::Begin | Tag::Single => {
376 let entity = self.flush_and_reset(position, tokens);
377 self.start_new(position, tag, label);
378 entity
379 }
380 Tag::Inside | Tag::End => {
381 if let Some((_, previous_tag, previous_label)) = self.previous_node {
382 if (previous_tag == Tag::End)
383 | (previous_tag == Tag::Single)
384 | (previous_label != label)
385 {
386 let entity = self.flush_and_reset(position, tokens);
387 self.start_new(position, tag, label);
388 entity
389 } else {
390 None
391 }
392 } else {
393 self.start_new(position, tag, label);
394 None
395 }
396 }
397 }
398 }
399
400 fn flush_and_reset(&mut self, position: usize, tokens: &[Token]) -> Option<Entity> {
401 let entity = if let Some((start, _, label)) = self.previous_node {
402 let entity_tokens = &tokens[start..position];
403 Some(Entity {
404 word: entity_tokens
405 .iter()
406 .map(|token| token.text.as_str())
407 .collect::<Vec<&str>>()
408 .join(" "),
409 score: entity_tokens.iter().map(|token| token.score).product(),
410 label: label.to_string(),
411 offset: Offset {
412 begin: entity_tokens.first()?.offset?.begin,
413 end: entity_tokens.last()?.offset?.end,
414 },
415 })
416 } else {
417 None
418 };
419 self.previous_node = None;
420 entity
421 }
422
423 fn start_new(&mut self, position: usize, tag: Tag, label: &'a str) {
424 self.previous_node = Some((position, tag, label))
425 }
426}
427
/// IOBES-style chunk tag parsed from the prefix of a token label.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Tag {
    /// `B-`: first token of a multi-token entity
    Begin,
    /// `I-`: token inside an entity
    Inside,
    /// `O`: token outside any entity
    Outside,
    /// `E-`: last token of an entity
    End,
    /// `S-`: single-token entity
    Single,
}
436
437impl Token {
438 fn get_tag(&self) -> Tag {
439 match self.label.split('-').collect::<Vec<&str>>()[0] {
440 "B" => Tag::Begin,
441 "I" => Tag::Inside,
442 "O" => Tag::Outside,
443 "E" => Tag::End,
444 "S" => Tag::Single,
445 _ => panic!("Invalid tag encountered for token {:?}", self),
446 }
447 }
448
449 fn get_label(&self) -> &str {
450 let split_label = self.label.split('-').collect::<Vec<&str>>();
451 if split_label.len() > 1 {
452 split_label[1]
453 } else {
454 ""
455 }
456 }
457}
458
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        // Coercing the boxed result to `dyn Send` forces the compiler to prove
        // that `NERModel` (wrapped in the constructor's `Result`) is `Send`.
        let config = NERConfig::default();
        let _: Box<dyn Send> = Box::new(NERModel::new(config));
    }
}
469}