rust_bert/pipelines/ner.rs

// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018 chakki (https://github.com/chakki-works/seqeval/blob/master/seqeval/metrics/sequence_labeling.py)
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! # Named Entity Recognition pipeline
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text.
//! Pretrained models are available for the following languages:
//! - English
//! - German
//! - Spanish
//! - Dutch
//!
//! The default NER model is an English BERT cased large model fine-tuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
//! All resources for this model can be downloaded using the Python utility script included in this repository.
//! 1. Set up a Python virtual environment and install the dependencies (in ./requirements.txt)
//! 2. Run the conversion script `python /utils/download-dependencies_bert_ner.py`.
//!     The dependencies will be downloaded to the user's home directory, under ~/rustbert/bert-ner
//!
//! The example below illustrates how to run the default English NER model:
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;
//! # fn main() -> anyhow::Result<()> {
//! let ner_model = NERModel::new(Default::default())?;
//!
//! let input = [
//!     "My name is Amy. I live in Paris.",
//!     "Paris is a city in France.",
//! ];
//! let output = ner_model.predict(&input);
//! # Ok(())
//! # }
//! ```
//! Output: \
//! ```no_run
//! # use rust_bert::pipelines::ner::Entity;
//! # use rust_tokenizers::Offset;
//! # let output =
//! [
//!     [
//!         Entity {
//!             word: String::from("Amy"),
//!             score: 0.9986,
//!             label: String::from("I-PER"),
//!             offset: Offset { begin: 11, end: 14 },
//!         },
//!         Entity {
//!             word: String::from("Paris"),
//!             score: 0.9985,
//!             label: String::from("I-LOC"),
//!             offset: Offset { begin: 26, end: 31 },
//!         },
//!     ],
//!     [
//!         Entity {
//!             word: String::from("Paris"),
//!             score: 0.9988,
//!             label: String::from("I-LOC"),
//!             offset: Offset { begin: 0, end: 5 },
//!         },
//!         Entity {
//!             word: String::from("France"),
//!             score: 0.9993,
//!             label: String::from("I-LOC"),
//!             offset: Offset { begin: 19, end: 25 },
//!         },
//!     ],
//! ]
//! # ;
//! ```
//!
//! To run the pipeline for another language, change the `NERModel` configuration from its default:
//!
//! ```no_run
//! use rust_bert::pipelines::common::{ModelResource, ModelType};
//! use rust_bert::pipelines::ner::NERModel;
//! use rust_bert::pipelines::token_classification::TokenClassificationConfig;
//! use rust_bert::resources::RemoteResource;
//! use rust_bert::roberta::{
//!     RobertaConfigResources, RobertaModelResources, RobertaVocabResources,
//! };
//! use tch::Device;
//!
//! # fn main() -> anyhow::Result<()> {
//! let ner_config = TokenClassificationConfig {
//!     model_type: ModelType::XLMRoberta,
//!     model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
//!         RobertaModelResources::XLM_ROBERTA_NER_DE,
//!     ))),
//!     config_resource: Box::new(RemoteResource::from_pretrained(
//!         RobertaConfigResources::XLM_ROBERTA_NER_DE,
//!     )),
//!     vocab_resource: Box::new(RemoteResource::from_pretrained(
//!         RobertaVocabResources::XLM_ROBERTA_NER_DE,
//!     )),
//!     lower_case: false,
//!     device: Device::cuda_if_available(),
//!     ..Default::default()
//! };
//!
//! let ner_model = NERModel::new(ner_config)?;
//!
//! // Define input
//! let input = [
//!     "Mein Name ist Amélie. Ich lebe in Paris.",
//!     "Paris ist eine Stadt in Frankreich.",
//! ];
//! let output = ner_model.predict(&input);
//! # Ok(())
//! # }
//! ```
//! The XLMRoberta models for each supported language are defined as follows:
//!
//! | **Language** | **Model name**     |
//! |:------------:|:------------------:|
//! | English      | XLM_ROBERTA_NER_EN |
//! | German       | XLM_ROBERTA_NER_DE |
//! | Spanish      | XLM_ROBERTA_NER_ES |
//! | Dutch        | XLM_ROBERTA_NER_NL |
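//!
//! The resource constants for the other languages follow the same pattern. As a minimal sketch
//! (assuming only the resource constants change and the remaining fields keep their defaults),
//! a Dutch configuration could be built as follows:
//!
//! ```no_run
//! # use rust_bert::pipelines::common::{ModelResource, ModelType};
//! # use rust_bert::pipelines::token_classification::TokenClassificationConfig;
//! # use rust_bert::resources::RemoteResource;
//! # use rust_bert::roberta::{
//! #     RobertaConfigResources, RobertaModelResources, RobertaVocabResources,
//! # };
//! let ner_config = TokenClassificationConfig {
//!     model_type: ModelType::XLMRoberta,
//!     model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
//!         RobertaModelResources::XLM_ROBERTA_NER_NL,
//!     ))),
//!     config_resource: Box::new(RemoteResource::from_pretrained(
//!         RobertaConfigResources::XLM_ROBERTA_NER_NL,
//!     )),
//!     vocab_resource: Box::new(RemoteResource::from_pretrained(
//!         RobertaVocabResources::XLM_ROBERTA_NER_NL,
//!     )),
//!     ..Default::default()
//! };
//! ```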

use crate::common::error::RustBertError;
use crate::pipelines::common::TokenizerOption;
use crate::pipelines::token_classification::{
    Token, TokenClassificationConfig, TokenClassificationModel,
};
use rust_tokenizers::Offset;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Entity generated by a `NERModel`
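///
/// The fields can be read directly from the output of a prediction. A minimal sketch (the
/// input sentence and printed format here are illustrative, not part of the library API):
///
/// ```no_run
/// # use rust_bert::pipelines::ner::NERModel;
/// # fn main() -> anyhow::Result<()> {
/// let ner_model = NERModel::new(Default::default())?;
/// for entity in ner_model.predict(&["My name is Amy."]).into_iter().flatten() {
///     println!("{} [{}] (score {:.3})", entity.word, entity.label, entity.score);
/// }
/// # Ok(())
/// # }
/// ```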
pub struct Entity {
    /// String representation of the Entity
    pub word: String,
    /// Confidence score
    pub score: f64,
    /// Entity label (e.g. ORG, LOC...)
    pub label: String,
    /// Token offsets
    pub offset: Offset,
}

// Type alias kept for backward compatibility
type NERConfig = TokenClassificationConfig;

/// # NERModel to extract named entities
pub struct NERModel {
    token_classification_model: TokenClassificationModel,
}

impl NERModel {
    /// Build a new `NERModel`
    ///
    /// # Arguments
    ///
    /// * `ner_config` - `NERConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::ner::NERModel;
    ///
    /// let ner_model = NERModel::new(Default::default())?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(ner_config: NERConfig) -> Result<NERModel, RustBertError> {
        let model = TokenClassificationModel::new(ner_config)?;
        Ok(NERModel {
            token_classification_model: model,
        })
    }

    /// Build a new `NERModel` with a provided tokenizer.
    ///
    /// # Arguments
    ///
    /// * `ner_config` - `NERConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` tokenizer to use for token classification
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// use rust_bert::pipelines::ner::NERModel;
    /// let tokenizer = TokenizerOption::from_file(
    ///     ModelType::Bert,
    ///     "path/to/vocab.txt",
    ///     None,
    ///     false,
    ///     None,
    ///     None,
    /// )?;
    /// let ner_model = NERModel::new_with_tokenizer(Default::default(), tokenizer)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new_with_tokenizer(
        ner_config: NERConfig,
        tokenizer: TokenizerOption,
    ) -> Result<NERModel, RustBertError> {
        let model = TokenClassificationModel::new_with_tokenizer(ner_config, tokenizer)?;
        Ok(NERModel {
            token_classification_model: model,
        })
    }

    /// Get a reference to the model tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        self.token_classification_model.get_tokenizer()
    }

    /// Get a mutable reference to the model tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        self.token_classification_model.get_tokenizer_mut()
    }

    /// Extract entities from a text
    ///
    /// # Arguments
    ///
    /// * `input` - `&[&str]` Array of texts to extract entities from.
    ///
    /// # Returns
    ///
    /// * `Vec<Vec<Entity>>` containing extracted entities
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// # use rust_bert::pipelines::ner::NERModel;
    ///
    /// let ner_model = NERModel::new(Default::default())?;
    /// let input = [
    ///     "My name is Amy. I live in Paris.",
    ///     "Paris is a city in France.",
    /// ];
    /// let output = ner_model.predict(&input);
    /// # Ok(())
    /// # }
    /// ```
    pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<Entity>>
    where
        S: AsRef<str>,
    {
        self.token_classification_model
            .predict(input, true, false)
            .into_iter()
            .map(|sequence_tokens| {
                sequence_tokens
                    .into_iter()
                    .filter(|token| token.label != "O")
                    .map(|token| Entity {
                        offset: token.offset.unwrap(),
                        word: token.text,
                        score: token.score,
                        label: token.label,
                    })
                    .collect::<Vec<Entity>>()
            })
            .collect::<Vec<Vec<Entity>>>()
    }

    /// Extract full entities from a text, performing entity chunking. Follows the entity chunking
    /// algorithm described in [Erik F. Tjong Kim Sang, Jorn Veenstra, Representing Text Chunks](https://www.aclweb.org/anthology/E99-1023/).
    /// The proposed implementation is inspired by the [Python seqeval library](https://github.com/chakki-works/seqeval) (shared under MIT license).
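    /// For example, the token labels `B-PER I-PER O B-ORG` would be consolidated into a single
    /// `PER` entity spanning the first two tokens and a single `ORG` entity covering the last token.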
    ///
    /// # Arguments
    ///
    /// * `input` - `&[&str]` Array of texts to extract entities from.
    ///
    /// # Returns
    ///
    /// * `Vec<Vec<Entity>>` containing consolidated extracted entities
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// # use rust_bert::pipelines::ner::NERModel;
    ///
    /// let ner_model = NERModel::new(Default::default())?;
    /// let input = ["Asked John Smith about Acme Corp"];
    /// let output = ner_model.predict_full_entities(&input);
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// Output: \
    /// ```no_run
    /// # use rust_bert::pipelines::ner::Entity;
    /// # use rust_tokenizers::Offset;
    /// # let output =
    /// [[
    ///     Entity {
    ///         word: String::from("John Smith"),
    ///         score: 0.9747,
    ///         label: String::from("PER"),
    ///         offset: Offset { begin: 6, end: 16 },
    ///     },
    ///     Entity {
    ///         word: String::from("Acme Corp"),
    ///         score: 0.8847,
    ///         label: String::from("ORG"),
    ///         offset: Offset { begin: 23, end: 32 },
    ///     },
    /// ]]
    /// # ;
    /// ```
    pub fn predict_full_entities<S>(&self, input: &[S]) -> Vec<Vec<Entity>>
    where
        S: AsRef<str>,
    {
        let tokens = self.token_classification_model.predict(input, true, false);
        let mut entities: Vec<Vec<Entity>> = Vec::new();

        for sequence_tokens in tokens {
            entities.push(Self::consolidate_entities(&sequence_tokens));
        }
        entities
    }

    fn consolidate_entities(tokens: &[Token]) -> Vec<Entity> {
        let mut entities: Vec<Entity> = Vec::new();

        let mut entity_builder = EntityBuilder::new();
        for (position, token) in tokens.iter().enumerate() {
            let tag = token.get_tag();
            let label = token.get_label();
            if let Some(entity) = entity_builder.handle_current_tag(tag, label, position, tokens) {
                entities.push(entity)
            }
        }
        if let Some(entity) = entity_builder.flush_and_reset(tokens.len(), tokens) {
            entities.push(entity);
        }
        entities
    }
}

/// Helper that incrementally groups BIOES-tagged tokens into consolidated entities.
struct EntityBuilder<'a> {
    previous_node: Option<(usize, Tag, &'a str)>,
}

impl<'a> EntityBuilder<'a> {
    fn new() -> Self {
        EntityBuilder {
            previous_node: None,
        }
    }

    /// Processes the tag of the token at `position`, returning a completed entity whenever the
    /// current tag closes the entity being built.
    fn handle_current_tag(
        &mut self,
        tag: Tag,
        label: &'a str,
        position: usize,
        tokens: &[Token],
    ) -> Option<Entity> {
        match tag {
            // "O": close any entity currently being built
            Tag::Outside => self.flush_and_reset(position, tokens),
            // "B"/"S": close any entity currently being built and start a new one
            Tag::Begin | Tag::Single => {
                let entity = self.flush_and_reset(position, tokens);
                self.start_new(position, tag, label);
                entity
            }
            // "I"/"E": continue the current entity, unless the previous token ended one
            // ("E"/"S" tag) or the label changed, in which case a new entity is started
            Tag::Inside | Tag::End => {
                if let Some((_, previous_tag, previous_label)) = self.previous_node {
                    if (previous_tag == Tag::End)
                        | (previous_tag == Tag::Single)
                        | (previous_label != label)
                    {
                        let entity = self.flush_and_reset(position, tokens);
                        self.start_new(position, tag, label);
                        entity
                    } else {
                        None
                    }
                } else {
                    self.start_new(position, tag, label);
                    None
                }
            }
        }
    }

    /// Builds an entity from the tokens accumulated since `previous_node` was set (up to, but
    /// excluding, the token at `position`) and clears the builder state.
    fn flush_and_reset(&mut self, position: usize, tokens: &[Token]) -> Option<Entity> {
        let entity = if let Some((start, _, label)) = self.previous_node {
            let entity_tokens = &tokens[start..position];
            Some(Entity {
                word: entity_tokens
                    .iter()
                    .map(|token| token.text.as_str())
                    .collect::<Vec<&str>>()
                    .join(" "),
                // The entity score is the product of the scores of its tokens
                score: entity_tokens.iter().map(|token| token.score).product(),
                label: label.to_string(),
                offset: Offset {
                    begin: entity_tokens.first()?.offset?.begin,
                    end: entity_tokens.last()?.offset?.end,
                },
            })
        } else {
            None
        };
        self.previous_node = None;
        entity
    }

    fn start_new(&mut self, position: usize, tag: Tag, label: &'a str) {
        self.previous_node = Some((position, tag, label))
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum Tag {
    Begin,
    Inside,
    Outside,
    End,
    Single,
}

impl Token {
    fn get_tag(&self) -> Tag {
        match self.label.split('-').collect::<Vec<&str>>()[0] {
            "B" => Tag::Begin,
            "I" => Tag::Inside,
            "O" => Tag::Outside,
            "E" => Tag::End,
            "S" => Tag::Single,
            _ => panic!("Invalid tag encountered for token {:?}", self),
        }
    }

    fn get_label(&self) -> &str {
        let split_label = self.label.split('-').collect::<Vec<&str>>();
        if split_label.len() > 1 {
            split_label[1]
        } else {
            ""
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        let config = NERConfig::default();
        let _: Box<dyn Send> = Box::new(NERModel::new(config));
    }
}