1use crate::{
10 config::{Config, Filters, Format},
11 utils::get_progress_bar,
12};
13use log::{error, info};
14use rayon::prelude::*;
15use serde::{Deserialize, Serialize};
16use std::{
17 collections::{HashMap, HashSet},
18 io::Write,
19};
20use std::{env, error::Error};
21use std::{path::Path, time::Instant};
22
/// Entry point of the annotation pipeline.
///
/// Built by [`Quickner::new`] from a configuration file and driven by
/// [`Quickner::process`].
pub struct Quickner {
    /// Configuration loaded in `new`; `parse_config` clones it before each run.
    pub config: Config,
    /// A second `Config` value that `new` loads from the same file as `config`.
    pub config_file: Config,
}
29
/// One input text to annotate.
///
/// Instances are deserialized from the rows of the texts CSV file and kept in a
/// `HashSet`, hence the `Eq`/`Hash` derives.
#[derive(Eq, PartialEq, Serialize, Deserialize, Clone, Hash, Debug)]
pub struct Text {
    /// The raw text content.
    pub text: String,
}
34
/// A known entity: the surface string to look for and the label to assign to it.
///
/// Instances are deserialized from the rows of the entities CSV file.
#[derive(Eq, PartialEq, Hash, Serialize, Deserialize, Clone, Debug)]
pub struct Entity {
    /// String searched for in the (lowercased) texts; also matched against the excludes list.
    pub name: String,
    /// Label attached to every span where `name` is found.
    pub label: String,
}
40
/// A text together with the labelled spans found in it.
///
/// `PartialEq` is implemented by hand elsewhere in this file and compares `id` only.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Annotation {
    /// Identifier of the annotation; `Annotations::annotate` assigns 1-based positions.
    pub id: u32,
    /// The annotated text.
    pub text: String,
    /// Labelled spans as `(start, end, label)`, with offsets as produced by
    /// `Annotations::find_index`.
    pub label: Vec<(usize, usize, String)>,
}
47
48impl Annotation {
49 pub fn new(id: u32, text: String, label: Vec<(usize, usize, String)>) -> Self {
50 Annotation { id, text, label }
51 }
52
53 pub fn from_string(text: String) -> Self {
54 Annotation {
55 id: 0,
56 text,
57 label: Vec::new(),
58 }
59 }
60
61 pub fn annotate(&mut self, entities: HashSet<Entity>) {
62 let label = Annotations::find_index(self.text.clone(), entities);
63 match label {
64 Some(label) => self.label = label,
65 None => self.label = Vec::new(),
66 }
67 }
68}
69
70impl Format {
71 pub fn save(&self, annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
72 match self {
73 Format::Spacy => Format::spacy(annotations, path),
74 Format::Jsonl => Format::jsonl(annotations, path),
75 Format::Csv => Format::csv(annotations, path),
76 Format::Brat => Format::brat(annotations, path),
77 Format::Conll => Format::conll(annotations, path),
78 }
79 }
80
/// Returns `path` without the extension of its final component.
///
/// Only the file name is inspected, so dots in directory names (`./out`,
/// `../out`, `results.v2/out`) and a leading dot of a hidden file (`.hidden`)
/// are left untouched. A path without an extension is returned unchanged.
fn remove_extension_from_path(path: &str) -> String {
    // `with_extension("")` drops the extension (and its dot) of the last
    // component only; `path` comes from a `&str`, so the lossy conversion
    // never actually replaces anything.
    Path::new(path)
        .with_extension("")
        .to_string_lossy()
        .into_owned()
}
88
89 fn spacy(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
90 #[derive(Serialize)]
94 struct SpacyEntity {
95 entity: HashMap<String, Vec<(usize, usize, String)>>,
96 }
97
98 let path = Format::remove_extension_from_path(path);
99 let mut file = std::fs::File::create(format!("{path}.json"))?;
100 let annotations_tranformed: Vec<(String, SpacyEntity)> = annotations
101 .into_iter()
102 .map(|annotation| {
103 let mut map = HashMap::new();
104 map.insert("entity".to_string(), annotation.label);
105 (annotation.text, SpacyEntity { entity: map })
106 })
107 .collect();
108 let json = serde_json::to_string(&annotations_tranformed).unwrap();
109 file.write_all(json.as_bytes())?;
110 Ok(path)
111 }
112
113 fn jsonl(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
114 let path = Format::remove_extension_from_path(path);
116 let mut file = std::fs::File::create(format!("{path}.jsonl"))?;
117 for annotation in annotations {
118 let json = serde_json::to_string(&annotation).unwrap();
119 file.write_all(json.as_bytes())?;
120 file.write_all(b"\n")?;
121 }
122 Ok(path)
123 }
124
125 fn csv(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
126 let path = Format::remove_extension_from_path(path);
128 let mut file = std::fs::File::create(format!("{path}.csv"))?;
129 for annotation in annotations {
130 let json = serde_json::to_string(&annotation).unwrap();
131 file.write_all(json.as_bytes())?;
132 file.write_all(b"\n")?;
133 }
134 Ok(path)
135 }
136
137 fn brat(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
138 let path = Format::remove_extension_from_path(path);
140 let mut file_ann = std::fs::File::create(format!("{path}.ann"))?;
141 let mut file_txt = std::fs::File::create(format!("{path}.txt"))?;
142 for annotation in annotations {
143 let text = annotation.text;
144 file_txt.write_all(text.as_bytes())?;
145 file_txt.write_all(b"\n")?;
146 for (id, (start, end, label)) in annotation.label.into_iter().enumerate() {
147 let entity = text[start..end].to_string();
148 let line = format!("T{id}\t{label}\t{start}\t{end}\t{entity}");
149 file_ann.write_all(line.as_bytes())?;
150 file_ann.write_all(b"\n")?;
151 }
152 }
153 Ok(path)
154 }
155
156 fn conll(annotations: Vec<Annotation>, path: &str) -> Result<String, std::io::Error> {
157 let path = Format::remove_extension_from_path(path);
159 let mut file = std::fs::File::create(format!("{path}.txt"))?;
160 let annotations_tranformed: Vec<Vec<(String, String)>> = annotations
161 .into_iter()
162 .map(|annotation| {
163 let text = annotation.text;
164 let words: Vec<&str> = text.split_whitespace().collect();
166 let mut labels: Vec<String> = vec!["O".to_string(); words.len()];
168 for (start, end, label) in annotation.label {
170 let entity = text[start..end].to_string();
171 let index = words.iter().position(|&word| word.contains(&entity));
173 if index.is_none() {
174 continue;
175 }
176 let index = index.unwrap();
177 labels[index] = label;
179 }
180 words
182 .iter()
183 .zip(labels.iter())
184 .map(|(word, label)| (word.to_string(), label.to_string()))
185 .collect()
186 })
187 .collect();
188 for annotation in annotations_tranformed {
190 for (word, label) in annotation {
191 let line = format!("{word}\t{label}");
192 file.write_all(line.as_bytes())?;
193 file.write_all(b"\n")?;
194 }
195 file.write_all(b"\n")?;
196 }
197 Ok(path)
198 }
199}
200
201impl PartialEq for Annotation {
202 fn eq(&self, other: &Self) -> bool {
203 self.id == other.id
204 }
205}
206
/// The working set of an annotation run: the inputs (entities and texts) and
/// the annotations computed from them.
#[derive(Serialize, Deserialize, Clone)]
pub struct Annotations {
    /// Result of the last `annotate` call; empty until then.
    pub annotations: Vec<Annotation>,
    /// Entities to look for in every text.
    pub entities: HashSet<Entity>,
    /// Texts to annotate; annotation ids follow their order in this vector.
    pub texts: Vec<Text>,
}
213
214impl Annotations {
215 pub fn new(entities: HashSet<Entity>, texts: Vec<Text>) -> Self {
216 Annotations {
217 annotations: Vec::new(),
218 entities,
219 texts,
220 }
221 }
222
223 fn find_index(text: String, entities: HashSet<Entity>) -> Option<Vec<(usize, usize, String)>> {
224 let annotations = entities.iter().map(|entity| {
226 let target_len = entity.name.len();
227 for (start, _) in text.to_lowercase().match_indices(entity.name.as_str()) {
228 if start == 0
229 || text.chars().nth(start - 1).unwrap().is_whitespace()
230 || text.chars().nth(start - 1).unwrap().is_ascii_punctuation()
231 || ((start + target_len) == text.len()
232 || text
233 .chars()
234 .nth(start + target_len)
235 .unwrap_or('N')
236 .is_whitespace()
237 || (text
238 .chars()
239 .nth(start + target_len)
240 .unwrap_or('N')
241 .is_ascii_punctuation()
242 && text.chars().nth(start + target_len).unwrap() != '.'
243 && (start > 0 && text.chars().nth(start - 1).unwrap() != '.')))
244 {
245 return (start, start + target_len, entity.label.to_string());
246 }
247 }
248 (0, 0, String::new())
249 });
250 let annotations: Vec<(usize, usize, String)> = annotations
251 .filter(|(_, _, label)| !label.is_empty())
252 .collect();
253 if !annotations.is_empty() {
254 Some(annotations)
255 } else {
256 None
257 }
258 }
259
260 pub fn annotate(&mut self) {
261 let pb = get_progress_bar(self.texts.len() as u64);
262 pb.set_message("Annotating texts");
263 let start = Instant::now();
264 self.texts
265 .par_iter()
266 .enumerate()
267 .map(|(i, text)| {
268 let t = text.text.clone();
269 let index = Annotations::find_index(t, self.entities.clone());
270 let mut index = match index {
271 Some(index) => index,
272 None => vec![],
273 };
274 index.sort_by(|a, b| a.0.cmp(&b.0));
275 pb.inc(1);
276 Annotation {
277 id: (i + 1) as u32,
278 text: text.text.clone(),
279 label: index,
280 }
281 })
282 .collect_into_vec(&mut self.annotations);
283 let end = start.elapsed();
284 println!(
285 "Time elapsed in find_index() is: {:?} for {} texts",
286 end,
287 self.texts.len() * self.entities.len()
288 );
289 pb.finish();
290 }
291}
292
293impl Quickner {
294 pub fn new(config_file: Option<&str>) -> Self {
298 println!("New instance of Quickner");
299 println!("Configuration file: {config_file:?}");
300 let config_file = match config_file {
301 Some(config_file) => config_file.to_string(),
302 None => "./config.toml".to_string(),
303 };
304 if Path::new(config_file.as_str()).exists() {
306 info!("Configuration file: {}", config_file.as_str());
307 } else {
308 println!("Configuration file {} does not exist", config_file.as_str());
309 error!("Configuration file {} does not exist", config_file.as_str());
310 std::process::exit(1);
311 }
312 let config = Config::from_file(config_file.as_str());
313 Quickner {
314 config,
315 config_file: Config::from_file(config_file.as_str()),
316 }
317 }
318
319 fn parse_config(&self) -> Config {
320 let mut config = self.config.clone();
321 config.entities.filters.set_special_characters();
322 config.texts.filters.set_special_characters();
323 let log_level_is_set = env::var("QUICKNER_LOG_LEVEL_SET").ok();
324 if log_level_is_set.is_none() {
325 match config.logging {
326 Some(ref mut logging) => {
327 env_logger::Builder::from_env(
328 env_logger::Env::default().default_filter_or(logging.level.as_str()),
329 )
330 .init();
331 env::set_var("QUICKNER_LOG_LEVEL_SET", "true");
332 }
333 None => {
334 env_logger::Builder::from_env(
335 env_logger::Env::default().default_filter_or("info"),
336 )
337 .init();
338 env::set_var("QUICKNER_LOG_LEVEL_SET", "true");
339 }
340 };
341 }
342
343 config
344 }
345
346 pub fn process(&self, save: bool) -> Result<Annotations, Box<dyn Error>> {
348 let config = self.parse_config();
349 config.summary();
350
351 info!("----------------------------------------");
352 let entities: HashSet<Entity> = self.entities(
353 config.entities.input.path.as_str(),
354 config.entities.filters,
355 config.entities.input.filter.unwrap_or(false),
356 );
357 let texts: HashSet<Text> = self.texts(
358 config.texts.input.path.as_str(),
359 config.texts.filters,
360 config.texts.input.filter.unwrap_or(false),
361 );
362 let texts: Vec<Text> = texts.into_iter().collect();
363 let excludes: HashSet<String> = match config.entities.excludes.path {
364 Some(path) => {
365 info!("Reading excludes from {}", path.as_str());
366 self.excludes(path.as_str())
367 }
368 None => {
369 info!("No excludes file provided");
370 HashSet::new()
371 }
372 };
373 let entities: HashSet<Entity> = entities
375 .iter()
376 .filter(|entity| !excludes.contains(&entity.name))
377 .cloned()
378 .collect();
379 info!("{} entities found", entities.len());
380 info!("{} texts found", texts.len());
381 let mut annotations = Annotations::new(entities, texts);
382 annotations.annotate();
383 info!("{} annotations found", annotations.annotations.len());
384 if save {
386 let save = config.annotations.format.save(
387 annotations.annotations.clone(),
388 &config.annotations.output.path,
389 );
390 match save {
391 Ok(_) => info!(
392 "Annotations saved with format {:?}",
393 config.annotations.format
394 ),
395 Err(e) => error!("Unable to save the annotations: {}", e),
396 }
397 }
398 Ok(annotations)
404 }
405
406 fn entities(&self, path: &str, filters: Filters, filter: bool) -> HashSet<Entity> {
407 info!("Reading entities from {}", path);
410 let rdr = csv::Reader::from_path(path);
411 match rdr {
412 Ok(mut rdr) => {
413 let mut entities = HashSet::new();
414 for result in rdr.deserialize() {
415 let record: Result<Entity, csv::Error> = result;
416 match record {
417 Ok(entity) => {
418 if filter {
419 if filters.is_valid(&entity.name) {
420 entities.insert(entity);
421 }
422 } else {
423 entities.insert(entity);
424 }
425 }
426 Err(e) => {
427 error!("Unable to parse the entities file: {}", e);
428 std::process::exit(1);
429 }
430 }
431 }
432 entities
433 }
434 Err(e) => {
435 error!("Unable to parse the entities file: {}", e);
436 std::process::exit(1);
437 }
438 }
439 }
440
441 fn texts(&self, path: &str, filters: Filters, filter: bool) -> HashSet<Text> {
442 info!("Reading texts from {}", path);
445 let rdr = csv::Reader::from_path(path);
446 match rdr {
447 Ok(mut rdr) => {
448 let mut texts = HashSet::new();
449 for result in rdr.deserialize() {
450 let record: Result<Text, csv::Error> = result;
451 match record {
452 Ok(text) => {
453 if filter {
454 if filters.is_valid(&text.text) {
455 texts.insert(text);
456 }
457 } else {
458 texts.insert(text);
459 }
460 }
461 Err(e) => {
462 error!("Unable to parse the texts file: {}", e);
463 std::process::exit(1);
464 }
465 }
466 }
467 texts
468 }
469 Err(e) => {
470 error!("Unable to parse the texts file: {}", e);
471 std::process::exit(1);
472 }
473 }
474 }
475
476 fn excludes(&self, path: &str) -> HashSet<String> {
477 let rdr = csv::Reader::from_path(path);
479 match rdr {
480 Ok(mut rdr) => {
481 let mut excludes = HashSet::new();
482 for result in rdr.records() {
483 let record = result.unwrap();
484 excludes.insert(record[0].to_string());
485 }
486 excludes
487 }
488 Err(e) => {
489 error!("Unable to parse the excludes file: {}", e);
490 std::process::exit(1);
491 }
492 }
493 }
494}