quickner/
models.rs

1// quickner
2//
3// NER tool for quick and simple NER annotation
4// Copyright (C) 2023, Omar MHAIMDAT
5//
6// Licensed under Mozilla Public License 2.0
7//
8
9use crate::{config::Format, Document};
10use serde::{Deserialize, Serialize};
11use std::io::Write;
12
13#[derive(Eq, PartialEq, Serialize, Deserialize, Clone, Hash, Debug)]
14pub struct Text {
15    pub text: String,
16}
17
18#[derive(Serialize, Deserialize, Clone, Debug)]
19pub struct SpacyEntity {
20    pub entity: Vec<(usize, usize, String)>,
21}
22
23impl Format {
24    /// Save annotations to a file in the specified format
25    /// # Examples
26    /// ```
27    /// use quickner::models::Format;
28    /// use quickner::models::Document;
29    ///
30    /// let annotations = vec![Annotation::from_string("Hello World".to_string())];
31    /// let format = Format::Spacy;
32    /// let path = "./test";
33    /// let result = format.save(annotations, path);
34    /// ```
35    /// # Errors
36    /// Returns an error if the file cannot be written
37    /// # Panics
38    /// Panics if the format is not supported
39    pub fn save(&self, annotations: &Vec<Document>, path: &str) -> Result<String, std::io::Error> {
40        match self {
41            Format::Spacy => Format::spacy(annotations, path),
42            Format::Jsonl => Format::jsonl(annotations, path),
43            Format::Csv => Format::csv(annotations, path),
44            Format::Brat => Format::brat(annotations, path),
45            Format::Conll => Format::conll(annotations, path),
46        }
47    }
48
49    fn remove_extension_from_path(path: &str) -> String {
50        let mut path = path.to_string();
51        if path.contains('.') {
52            path.truncate(path.rfind('.').unwrap());
53        }
54        path
55    }
56
57    fn spacy(documents: &Vec<Document>, path: &str) -> Result<String, std::io::Error> {
58        // Save as such [["text", {"entity": [[0, 4, "ORG"], [5, 10, "ORG"]]}]]
59
60        // Transform Vec<(String, HashMap<String, Vec<(usize, usize, String)>>)> into Structure
61
62        let path = Format::remove_extension_from_path(path);
63        let mut file = std::fs::File::create(format!("{path}.json"))?;
64        let annotations_tranformed: Vec<(String, SpacyEntity)> = documents
65            .into_iter()
66            .map(|annotation| {
67                (
68                    (*annotation.text).to_string(),
69                    SpacyEntity {
70                        entity: (*annotation.label).to_vec(),
71                    },
72                )
73            })
74            .collect();
75        let json = serde_json::to_string(&annotations_tranformed).unwrap();
76        file.write_all(json.as_bytes())?;
77        Ok(path)
78    }
79
80    fn jsonl(documents: &Vec<Document>, path: &str) -> Result<String, std::io::Error> {
81        // Save as such {"text": "text", "label": [[0, 4, "ORG"], [5, 10, "ORG"]]}
82        let path = Format::remove_extension_from_path(path);
83        let mut file = std::fs::File::create(format!("{path}.jsonl"))?;
84        for document in documents {
85            let json = serde_json::to_string(&document).unwrap();
86            file.write_all(json.as_bytes())?;
87            file.write_all(b"\n")?;
88        }
89        Ok(path)
90    }
91
92    fn csv(documents: &Vec<Document>, path: &str) -> Result<String, std::io::Error> {
93        // Save as such "text", "label"
94        let path = Format::remove_extension_from_path(path);
95        let mut file = std::fs::File::create(format!("{path}.csv"))?;
96        for document in documents {
97            let json = serde_json::to_string(&document).unwrap();
98            file.write_all(json.as_bytes())?;
99            file.write_all(b"\n")?;
100        }
101        Ok(path)
102    }
103
104    fn brat(documents: &Vec<Document>, path: &str) -> Result<String, std::io::Error> {
105        // Save .ann and .txt files
106        let path = Format::remove_extension_from_path(path);
107        let mut file_ann = std::fs::File::create(format!("{path}.ann"))?;
108        let mut file_txt = std::fs::File::create(format!("{path}.txt"))?;
109        for document in documents {
110            let text = &document.text;
111            file_txt.write_all(text.as_bytes())?;
112            file_txt.write_all(b"\n")?;
113            for (id, (start, end, label)) in (*document.label).to_vec().into_iter().enumerate() {
114                let entity = text[start..end].to_string();
115                let line = format!("T{id}\t{label}\t{start}\t{end}\t{entity}");
116                file_ann.write_all(line.as_bytes())?;
117                file_ann.write_all(b"\n")?;
118            }
119        }
120        Ok(path)
121    }
122
123    fn conll(documents: &Vec<Document>, path: &str) -> Result<String, std::io::Error> {
124        // for reference: https://simpletransformers.ai/docs/ner-data-formats/
125        let path = Format::remove_extension_from_path(path);
126        let mut file = std::fs::File::create(format!("{path}.txt"))?;
127        let annotations_tranformed: Vec<Vec<(String, String)>> = documents
128            .into_iter()
129            .map(|annotation| {
130                let text = &annotation.text;
131                // Split text into words
132                let words: Vec<&str> = text.split_whitespace().collect();
133                // If the word is not associated with an entity, then it is an "O"
134                let mut labels: Vec<String> = vec!["O".to_string(); words.len()];
135                // For each entity, find the word that contains it and assign the label to it
136                for (start, end, label) in (*annotation.label).to_vec() {
137                    let entity = text[start..end].to_string();
138                    // Find the index of the word that contains the entity
139                    let index = words.iter().position(|&word| word.contains(&entity));
140                    if index.is_none() {
141                        continue;
142                    }
143                    let index = index.unwrap();
144                    // If the word is the same as the entity, then it is a "B" label
145                    labels[index] = label;
146                }
147                // Combine the words and labels into a single vector
148                words
149                    .iter()
150                    .zip(labels.iter())
151                    .map(|(word, label)| (word.to_string(), label.to_string()))
152                    .collect()
153            })
154            .collect();
155        // Save the data, one line per word with the word and label separated by a space
156        for annotation in annotations_tranformed {
157            for (word, label) in annotation {
158                let line = format!("{word}\t{label}");
159                file.write_all(line.as_bytes())?;
160                file.write_all(b"\n")?;
161            }
162            file.write_all(b"\n")?;
163        }
164        Ok(path)
165    }
166}