quickner/
models.rs

1// quickner
2//
3// NER tool for quick and simple NER annotation
4// Copyright (C) 2023, Omar MHAIMDAT
5//
6// Licensed under Mozilla Public License 2.0
7//
8
9use crate::{
10    config::{Config, Filters, Format},
11    utils::get_progress_bar,
12};
13use log::{error, info};
14use rayon::prelude::*;
15use serde::{Deserialize, Serialize};
16use std::{
17    collections::{HashMap, HashSet},
18    io::Write,
19};
20use std::{env, error::Error};
21use std::{path::Path, time::Instant};
22
23pub struct Quickner {
24    /// Path to the configuration file
25    /// Default: ./config.toml
26    pub config: Config,
27    pub config_file: Config,
28}
29
30#[derive(Eq, PartialEq, Serialize, Deserialize, Clone, Hash, Debug)]
31pub struct Text {
32    pub text: String,
33}
34
35#[derive(Eq, PartialEq, Hash, Serialize, Deserialize, Clone, Debug)]
36pub struct Entity {
37    pub name: String,
38    pub label: String,
39}
40
41#[derive(Serialize, Deserialize, Clone, Debug)]
42pub struct Annotation {
43    pub id: u32,
44    pub text: String,
45    pub label: Vec<(usize, usize, String)>,
46}
47
48impl Annotation {
49    pub fn new(id: u32, text: String, label: Vec<(usize, usize, String)>) -> Self {
50        Annotation { id, text, label }
51    }
52
53    pub fn from_string(text: String) -> Self {
54        Annotation {
55            id: 0,
56            text,
57            label: Vec::new(),
58        }
59    }
60
61    pub fn annotate(&mut self, entities: HashSet<Entity>) {
62        let label = Annotations::find_index(self.text.clone(), entities);
63        match label {
64            Some(label) => self.label = label,
65            None => self.label = Vec::new(),
66        }
67    }
68}
69
70impl Format {
71    pub fn save(&self, annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
72        match self {
73            Format::Spacy => Format::spacy(annotations, path),
74            Format::Jsonl => Format::jsonl(annotations, path),
75            Format::Csv => Format::csv(annotations, path),
76            Format::Brat => Format::brat(annotations, path),
77            Format::Conll => Format::conll(annotations, path),
78        }
79    }
80
81    fn remove_extension_from_path(path: &str) -> String {
82        let mut path = path.to_string();
83        if path.contains('.') {
84            path.truncate(path.rfind('.').unwrap());
85        }
86        path
87    }
88
89    fn spacy(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
90        // Save as such [["text", {"entity": [[0, 4, "ORG"], [5, 10, "ORG"]]}]]
91
92        // Transform Vec<(String, HashMap<String, Vec<(usize, usize, String)>>)> into Structure
93        #[derive(Serialize)]
94        struct SpacyEntity {
95            entity: HashMap<String, Vec<(usize, usize, String)>>,
96        }
97
98        let path = Format::remove_extension_from_path(path);
99        let mut file = std::fs::File::create(format!("{path}.json"))?;
100        let annotations_tranformed: Vec<(String, SpacyEntity)> = annotations
101            .into_iter()
102            .map(|annotation| {
103                let mut map = HashMap::new();
104                map.insert("entity".to_string(), annotation.label);
105                (annotation.text, SpacyEntity { entity: map })
106            })
107            .collect();
108        let json = serde_json::to_string(&annotations_tranformed).unwrap();
109        file.write_all(json.as_bytes())?;
110        Ok(path)
111    }
112
113    fn jsonl(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
114        // Save as such {"text": "text", "label": [[0, 4, "ORG"], [5, 10, "ORG"]]}
115        let path = Format::remove_extension_from_path(path);
116        let mut file = std::fs::File::create(format!("{path}.jsonl"))?;
117        for annotation in annotations {
118            let json = serde_json::to_string(&annotation).unwrap();
119            file.write_all(json.as_bytes())?;
120            file.write_all(b"\n")?;
121        }
122        Ok(path)
123    }
124
125    fn csv(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
126        // Save as such "text", "label"
127        let path = Format::remove_extension_from_path(path);
128        let mut file = std::fs::File::create(format!("{path}.csv"))?;
129        for annotation in annotations {
130            let json = serde_json::to_string(&annotation).unwrap();
131            file.write_all(json.as_bytes())?;
132            file.write_all(b"\n")?;
133        }
134        Ok(path)
135    }
136
137    fn brat(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
138        // Save .ann and .txt files
139        let path = Format::remove_extension_from_path(path);
140        let mut file_ann = std::fs::File::create(format!("{path}.ann"))?;
141        let mut file_txt = std::fs::File::create(format!("{path}.txt"))?;
142        for annotation in annotations {
143            let text = annotation.text;
144            file_txt.write_all(text.as_bytes())?;
145            file_txt.write_all(b"\n")?;
146            for (id, (start, end, label)) in annotation.label.into_iter().enumerate() {
147                let entity = text[start..end].to_string();
148                let line = format!("T{id}\t{label}\t{start}\t{end}\t{entity}");
149                file_ann.write_all(line.as_bytes())?;
150                file_ann.write_all(b"\n")?;
151            }
152        }
153        Ok(path)
154    }
155
156    fn conll(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
157        // for reference: https://simpletransformers.ai/docs/ner-data-formats/
158        let path = Format::remove_extension_from_path(path);
159        let mut file = std::fs::File::create(format!("{path}.txt"))?;
160        let annotations_tranformed: Vec<Vec<(String, String)>> = annotations
161            .into_iter()
162            .map(|annotation| {
163                let text = annotation.text;
164                // Split text into words
165                let words: Vec<&str> = text.split_whitespace().collect();
166                // If the word is not associated with an entity, then it is an "O"
167                let mut labels: Vec<String> = vec!["O".to_string(); words.len()];
168                // For each entity, find the word that contains it and assign the label to it
169                for (start, end, label) in annotation.label {
170                    let entity = text[start..end].to_string();
171                    // Find the index of the word that contains the entity
172                    let index = words.iter().position(|&word| word.contains(&entity));
173                    if index.is_none() {
174                        continue;
175                    }
176                    let index = index.unwrap();
177                    // If the word is the same as the entity, then it is a "B" label
178                    labels[index] = label;
179                }
180                // Combine the words and labels into a single vector
181                words
182                    .iter()
183                    .zip(labels.iter())
184                    .map(|(word, label)| (word.to_string(), label.to_string()))
185                    .collect()
186            })
187            .collect();
188        // Save the data, one line per word with the word and label separated by a space
189        for annotation in annotations_tranformed {
190            for (word, label) in annotation {
191                let line = format!("{word}\t{label}");
192                file.write_all(line.as_bytes())?;
193                file.write_all(b"\n")?;
194            }
195            file.write_all(b"\n")?;
196        }
197        Ok(path)
198    }
199}
200
201impl PartialEq for Annotation {
202    fn eq(&self, other: &Self) -> bool {
203        self.id == other.id
204    }
205}
206
207#[derive(Serialize, Deserialize, Clone)]
208pub struct Annotations {
209    pub annotations: Vec<Annotation>,
210    pub entities: HashSet<Entity>,
211    pub texts: Vec<Text>,
212}
213
214impl Annotations {
215    pub fn new(entities: HashSet<Entity>, texts: Vec<Text>) -> Self {
216        Annotations {
217            annotations: Vec::new(),
218            entities,
219            texts,
220        }
221    }
222
223    fn find_index(text: String, entities: HashSet<Entity>) -> Option<Vec<(usize, usize, String)>> {
224        // let mut annotations = Vec::new();
225        let annotations = entities.iter().map(|entity| {
226            let target_len = entity.name.len();
227            for (start, _) in text.to_lowercase().match_indices(entity.name.as_str()) {
228                if start == 0
229                    || text.chars().nth(start - 1).unwrap().is_whitespace()
230                    || text.chars().nth(start - 1).unwrap().is_ascii_punctuation()
231                    || ((start + target_len) == text.len()
232                        || text
233                            .chars()
234                            .nth(start + target_len)
235                            .unwrap_or('N')
236                            .is_whitespace()
237                        || (text
238                            .chars()
239                            .nth(start + target_len)
240                            .unwrap_or('N')
241                            .is_ascii_punctuation()
242                            && text.chars().nth(start + target_len).unwrap() != '.'
243                            && (start > 0 && text.chars().nth(start - 1).unwrap() != '.')))
244                {
245                    return (start, start + target_len, entity.label.to_string());
246                }
247            }
248            (0, 0, String::new())
249        });
250        let annotations: Vec<(usize, usize, String)> = annotations
251            .filter(|(_, _, label)| !label.is_empty())
252            .collect();
253        if !annotations.is_empty() {
254            Some(annotations)
255        } else {
256            None
257        }
258    }
259
260    pub fn annotate(&mut self) {
261        let pb = get_progress_bar(self.texts.len() as u64);
262        pb.set_message("Annotating texts");
263        let start = Instant::now();
264        self.texts
265            .par_iter()
266            .enumerate()
267            .map(|(i, text)| {
268                let t = text.text.clone();
269                let index = Annotations::find_index(t, self.entities.clone());
270                let mut index = match index {
271                    Some(index) => index,
272                    None => vec![],
273                };
274                index.sort_by(|a, b| a.0.cmp(&b.0));
275                pb.inc(1);
276                Annotation {
277                    id: (i + 1) as u32,
278                    text: text.text.clone(),
279                    label: index,
280                }
281            })
282            .collect_into_vec(&mut self.annotations);
283        let end = start.elapsed();
284        println!(
285            "Time elapsed in find_index() is: {:?} for {} texts",
286            end,
287            self.texts.len() * self.entities.len()
288        );
289        pb.finish();
290    }
291}
292
293impl Quickner {
294    /// Creates a new instance of Quickner
295    /// If no configuration file is provided, the default configuration file is used.
296    /// Default: ./config.toml
297    pub fn new(config_file: Option<&str>) -> Self {
298        println!("New instance of Quickner");
299        println!("Configuration file: {config_file:?}");
300        let config_file = match config_file {
301            Some(config_file) => config_file.to_string(),
302            None => "./config.toml".to_string(),
303        };
304        // Check if the configuration file path exists
305        if Path::new(config_file.as_str()).exists() {
306            info!("Configuration file: {}", config_file.as_str());
307        } else {
308            println!("Configuration file {} does not exist", config_file.as_str());
309            error!("Configuration file {} does not exist", config_file.as_str());
310            std::process::exit(1);
311        }
312        let config = Config::from_file(config_file.as_str());
313        Quickner {
314            config,
315            config_file: Config::from_file(config_file.as_str()),
316        }
317    }
318
319    fn parse_config(&self) -> Config {
320        let mut config = self.config.clone();
321        config.entities.filters.set_special_characters();
322        config.texts.filters.set_special_characters();
323        let log_level_is_set = env::var("QUICKNER_LOG_LEVEL_SET").ok();
324        if log_level_is_set.is_none() {
325            match config.logging {
326                Some(ref mut logging) => {
327                    env_logger::Builder::from_env(
328                        env_logger::Env::default().default_filter_or(logging.level.as_str()),
329                    )
330                    .init();
331                    env::set_var("QUICKNER_LOG_LEVEL_SET", "true");
332                }
333                None => {
334                    env_logger::Builder::from_env(
335                        env_logger::Env::default().default_filter_or("info"),
336                    )
337                    .init();
338                    env::set_var("QUICKNER_LOG_LEVEL_SET", "true");
339                }
340            };
341        }
342
343        config
344    }
345
346    /// Returns a list of Annotations
347    pub fn process(&self, save: bool) -> Result<Annotations, Box<dyn Error>> {
348        let config = self.parse_config();
349        config.summary();
350
351        info!("----------------------------------------");
352        let entities: HashSet<Entity> = self.entities(
353            config.entities.input.path.as_str(),
354            config.entities.filters,
355            config.entities.input.filter.unwrap_or(false),
356        );
357        let texts: HashSet<Text> = self.texts(
358            config.texts.input.path.as_str(),
359            config.texts.filters,
360            config.texts.input.filter.unwrap_or(false),
361        );
362        let texts: Vec<Text> = texts.into_iter().collect();
363        let excludes: HashSet<String> = match config.entities.excludes.path {
364            Some(path) => {
365                info!("Reading excludes from {}", path.as_str());
366                self.excludes(path.as_str())
367            }
368            None => {
369                info!("No excludes file provided");
370                HashSet::new()
371            }
372        };
373        // Remove excludes from entities
374        let entities: HashSet<Entity> = entities
375            .iter()
376            .filter(|entity| !excludes.contains(&entity.name))
377            .cloned()
378            .collect();
379        info!("{} entities found", entities.len());
380        info!("{} texts found", texts.len());
381        let mut annotations = Annotations::new(entities, texts);
382        annotations.annotate();
383        info!("{} annotations found", annotations.annotations.len());
384        // annotations.save(&config.annotations.output.path);
385        if save {
386            let save = config.annotations.format.save(
387                annotations.annotations.clone(),
388                &config.annotations.output.path,
389            );
390            match save {
391                Ok(_) => info!(
392                    "Annotations saved with format {:?}",
393                    config.annotations.format
394                ),
395                Err(e) => error!("Unable to save the annotations: {}", e),
396            }
397        }
398        // Transform annotations to Python objects
399        // List of tuples (text, [[start, end, label], [start, end, label], ...
400        // let annotations_py: Vec<(String, Vec<(usize, usize, String)>)> =
401        //     annotations.transform_annotations();
402        // Ok(annotations_py)
403        Ok(annotations)
404    }
405
406    fn entities(&self, path: &str, filters: Filters, filter: bool) -> HashSet<Entity> {
407        // Read CSV file and parse it
408        // Expect columns: name, label
409        info!("Reading entities from {}", path);
410        let rdr = csv::Reader::from_path(path);
411        match rdr {
412            Ok(mut rdr) => {
413                let mut entities = HashSet::new();
414                for result in rdr.deserialize() {
415                    let record: Result<Entity, csv::Error> = result;
416                    match record {
417                        Ok(entity) => {
418                            if filter {
419                                if filters.is_valid(&entity.name) {
420                                    entities.insert(entity);
421                                }
422                            } else {
423                                entities.insert(entity);
424                            }
425                        }
426                        Err(e) => {
427                            error!("Unable to parse the entities file: {}", e);
428                            std::process::exit(1);
429                        }
430                    }
431                }
432                entities
433            }
434            Err(e) => {
435                error!("Unable to parse the entities file: {}", e);
436                std::process::exit(1);
437            }
438        }
439    }
440
441    fn texts(&self, path: &str, filters: Filters, filter: bool) -> HashSet<Text> {
442        // Read CSV file and parse it
443        // Expect columns: texts
444        info!("Reading texts from {}", path);
445        let rdr = csv::Reader::from_path(path);
446        match rdr {
447            Ok(mut rdr) => {
448                let mut texts = HashSet::new();
449                for result in rdr.deserialize() {
450                    let record: Result<Text, csv::Error> = result;
451                    match record {
452                        Ok(text) => {
453                            if filter {
454                                if filters.is_valid(&text.text) {
455                                    texts.insert(text);
456                                }
457                            } else {
458                                texts.insert(text);
459                            }
460                        }
461                        Err(e) => {
462                            error!("Unable to parse the texts file: {}", e);
463                            std::process::exit(1);
464                        }
465                    }
466                }
467                texts
468            }
469            Err(e) => {
470                error!("Unable to parse the texts file: {}", e);
471                std::process::exit(1);
472            }
473        }
474    }
475
476    fn excludes(&self, path: &str) -> HashSet<String> {
477        // Read CSV file and parse it
478        let rdr = csv::Reader::from_path(path);
479        match rdr {
480            Ok(mut rdr) => {
481                let mut excludes = HashSet::new();
482                for result in rdr.records() {
483                    let record = result.unwrap();
484                    excludes.insert(record[0].to_string());
485                }
486                excludes
487            }
488            Err(e) => {
489                error!("Unable to parse the excludes file: {}", e);
490                std::process::exit(1);
491            }
492        }
493    }
494}