drug_extraction_cli/lib.rs

use color_eyre::{
    eyre::{eyre, Context, ContextCompat},
    Result,
};
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};

use std::{collections::HashSet, fs::File, path::Path, time::Duration};

use serde::{Deserialize, Serialize};

use itertools::Itertools;

/// Create a spinner with default style, takes a message
fn initialize_spinner_style(msg: String) -> ProgressBar {
    let pb = ProgressBar::new_spinner();
    pb.enable_steady_tick(Duration::from_millis(100));
    pb.with_style(
        ProgressStyle::with_template("{spinner:.blue} {msg}")
            .unwrap()
            .tick_strings(&[
                "▹▹▹▹▹",
                "▸▹▹▹▹",
                "▹▸▹▹▹",
                "▹▹▸▹▹",
                "▹▹▹▸▹",
                "▹▹▹▹▸",
                "▪▪▪▪▪",
            ]),
    )
    .with_message(msg)
}

/// Initialize a progress bar with default style, takes a message and length
fn initialize_progress_bar(msg: String, len: u64) -> ProgressBar {
    let pb = ProgressBar::new(len);
    pb.enable_steady_tick(Duration::from_millis(100));
    pb.with_style(
        ProgressStyle::default_bar()
            .template("{spinner:.blue} {msg} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos:>7}/{len:7} ({eta})").unwrap()
            .progress_chars("##-"),
    )
    .with_message(msg)
}

/// Struct to hold search term and metadata
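///
/// The search-terms csv is expected to provide a header row with `term` and
/// `metadata` columns; an illustrative (hypothetical) file might look like:
/// ```text
/// term,metadata
/// fentanyl,opioid
/// heroin,opioid
/// ```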
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
pub struct SearchTerm {
    /// The search term
    pub term: String,
    /// Optional metadata to be included in output
    pub metadata: Option<String>,
}

/// Struct to hold search output
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct SearchOutput<'a> {
    /// The row id of the matched record, taken from the specified id column or the line number
    row_id: &'a str,
    /// The search term that was matched
    search_term: &'a str,
    /// The matched term from the record
    matched_term: &'a str,
    /// The number of edits required to match the search term
    edits: usize,
    /// The similarity score between the search term and the matched term
    similarity_score: f64,
    /// The field that was searched
    search_field: &'a str,
    /// The metadata associated with the search term
    metadata: &'a Option<String>,
}

/// Function to read search terms from a csv file.
/// Cleans each term with [clean_text] (the metadata column is left untouched)
/// and sorts the terms by word count so they can be grouped by length during search.
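///
/// A minimal usage sketch (the file name here is hypothetical; the csv is expected
/// to have `term` and `metadata` columns):
/// ```no_run
/// use drug_extraction_cli::read_terms_from_file;
///
/// let terms = read_terms_from_file("terms.csv").unwrap();
/// println!("loaded {} search terms", terms.len());
/// ```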
pub fn read_terms_from_file<P: AsRef<Path>>(p: P) -> Result<Vec<SearchTerm>> {
    let mut rdr = csv::Reader::from_path(p).wrap_err("Unable to read search terms file")?;
    let mut records: Vec<SearchTerm> = Vec::new();
    for (i, row) in rdr
        .deserialize()
        .enumerate()
        .progress_with(initialize_spinner_style(
            "Loading Search Terms...".to_string(),
        ))
    {
        let mut record: SearchTerm =
            row.wrap_err(format!("Could not load search term from line: {}", i))?;
        record.term = clean_text(&record.term);
        records.push(record);
    }
    // sort by word count so `search` can group terms of the same length together
    records.sort_by_key(|x| x.term.split_ascii_whitespace().count());
    Ok(records)
}

/// Function to remove non-alphanumeric characters from a string,
/// keeping hyphens due to their use in abbreviations/medical terms.
/// Also uppercases the result for standardization.
/// Example:
/// ```
/// use drug_extraction_cli::clean_text;
///
/// let s = "This is a test-string with 1234 and some punctuation!@#$%^&*()";
/// let cleaned = clean_text(s);
/// assert_ne!(cleaned, "THIS IS A TEST STRING WITH 1234 AND SOME PUNCTUATION");
/// assert_eq!(cleaned, "THIS IS A TEST-STRING WITH 1234 AND SOME PUNCTUATION");
/// ```
pub fn clean_text(s: &str) -> String {
    s.replace(|c: char| !c.is_ascii_alphanumeric() && c != '-', " ")
        .trim()
        .to_ascii_uppercase()
}

/// Struct to hold information about the dataset
#[derive(Debug)]
pub struct DataSet {
    /// csv reader for the dataset
    pub reader: csv::Reader<File>,
    /// number of data rows in the dataset, counted during initialization
    pub rows: usize,
    /// name and index of each column to search in the dataset
    pub clean_search_columns: Vec<ColumnInfo>,
    /// name and index of the column to use as an id, if any
    pub clean_id_column: Option<ColumnInfo>,
    /// csv writer for the output file (`output.csv`)
    pub writer: csv::Writer<File>,
}

/// Name and index of a csv column, resolved from the header
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ColumnInfo {
    pub name: String,
    pub index: usize,
}

/// Function to get the column info (name and index) for a given column name.
/// Returns an error if the column name is not found.
/// Typically run on the header from the csv reader; called inside
/// [collect_column_info] to resolve each target column.
/// Example:
/// ```
/// use drug_extraction_cli::get_column_info;
///
/// let header = vec!["ID", "NAME", "DESCRIPTION"];
/// let column = "NAME";
/// let column_info = get_column_info(&header, &column);
/// assert!(column_info.is_ok());
/// let column_info = column_info.unwrap();
/// assert_eq!(column_info.name, "NAME");
/// assert_eq!(column_info.index, 1);
/// ```
pub fn get_column_info<S: AsRef<str> + PartialEq>(header: &[S], column: &S) -> Result<ColumnInfo> {
    let pos = header.iter().position(|h| h == column);
    match pos {
        Some(i) => Ok(ColumnInfo {
            name: column.as_ref().to_string(),
            index: i,
        }),
        None => Err(eyre!("Unable to find column {}", column.as_ref())),
    }
}

/// Function to collect column info for each column to search.
/// Typically run on the header from the csv reader; called inside
/// [initialize_dataset] to resolve each target column.
/// Example:
/// ```
/// use drug_extraction_cli::collect_column_info;
///
/// let header = vec!["ID", "NAME", "DESCRIPTION"];
/// let columns = vec!["NAME", "DESCRIPTION"];
/// let column_info = collect_column_info(&header, &columns);
/// assert!(column_info.is_ok());
/// let column_info = column_info.unwrap();
/// assert_eq!(column_info.len(), 2);
/// assert_eq!(column_info[0].name, "NAME");
/// assert_eq!(column_info[0].index, 1);
/// assert_eq!(column_info[1].name, "DESCRIPTION");
/// assert_eq!(column_info[1].index, 2);
/// ```
pub fn collect_column_info<S: AsRef<str> + PartialEq>(
    header: &[S],
    column_names: &[S],
) -> Result<Vec<ColumnInfo>> {
    column_names
        .iter()
        .map(|column| get_column_info(header, column))
        .collect()
}

/// Function to initialize the dataset: opens the data csv, cleans the header and
/// target column names, resolves the search and id columns, counts the data rows,
/// and creates a csv writer for `output.csv`.
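///
/// A minimal usage sketch (file and column names here are hypothetical; matches are
/// written to `output.csv` in the working directory):
/// ```no_run
/// use drug_extraction_cli::initialize_dataset;
///
/// let dataset = initialize_dataset(
///     "records.csv",
///     &["TEXT".to_string()],
///     Some("ID".to_string()),
/// )
/// .unwrap();
/// println!("dataset has {} rows", dataset.rows);
/// ```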
pub fn initialize_dataset<P: AsRef<Path>>(
    data_file: P,
    search_columns: &[String],
    id_column: Option<String>,
) -> Result<DataSet> {
    let mut rdr = csv::Reader::from_path(&data_file).wrap_err("Unable to initialize csv reader")?;
    let header = rdr
        .headers()
        .wrap_err("Unable to parse csv headers")?
        .iter()
        .map(clean_text)
        .collect_vec();
    // clean search cols and id col
    let clean_search_cols = search_columns.iter().map(|c| clean_text(c)).collect_vec();
    let clean_id_col = id_column.map(|c| clean_text(&c));
    let column_info = collect_column_info(&header, &clean_search_cols)
        .wrap_err("Unable to collect column indices")?;
    let ds = match clean_id_col {
        Some(c) => DataSet {
            reader: csv::Reader::from_path(&data_file)
                .wrap_err("Unable to initialize csv reader")?,
            rows: rdr.records().count(),
            clean_search_columns: column_info,
            clean_id_column: Some(get_column_info(&header, &c)?),
            writer: csv::Writer::from_path("output.csv")?,
        },
        None => DataSet {
            reader: csv::Reader::from_path(&data_file)
                .wrap_err("Unable to initialize csv reader")?,
            rows: rdr.records().count(),
            clean_search_columns: column_info,
            clean_id_column: None,
            writer: csv::Writer::from_path("output.csv")?,
        },
    };
    Ok(ds)
}

/// Primary search function: scans every search column of each record for exact and
/// fuzzy matches against the search terms, writes matches to `output.csv`, and
/// prints summary statistics.
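///
/// A minimal usage sketch (file and column names here are hypothetical):
/// ```no_run
/// use drug_extraction_cli::{initialize_dataset, read_terms_from_file, search};
///
/// let terms = read_terms_from_file("terms.csv").unwrap();
/// let dataset = initialize_dataset("records.csv", &["TEXT".to_string()], None).unwrap();
/// search(dataset, terms).unwrap();
/// ```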
pub fn search(mut dataset: DataSet, search_terms: Vec<SearchTerm>) -> Result<()> {
    let mut total_records_with_matches = 0;
    let mut total_records = 0;
    let mut matched_terms: HashSet<&str> = HashSet::new();

    let progress_bar =
        initialize_progress_bar("Searching for matches...".to_string(), dataset.rows as u64);
    for (i, row) in dataset
        .reader
        .records()
        .enumerate()
        .progress_with(progress_bar.clone())
    {
        let record = row.wrap_err(format!("Unable to read record from line {}", i))?;

        let id = match &dataset.clean_id_column {
            Some(c) => record
                .get(c.index)
                .wrap_err(format!(
                    "Unable to read id column {} from line {}",
                    c.name, i
                ))?
                .to_string(),
            None => i.to_string(),
        };

        let mut found_match = false;
        for column in &dataset.clean_search_columns {
            let text = record.get(column.index).wrap_err(format!(
                "Unable to read column {} from line {}",
                column.name, i
            ))?;
            let cleaned_text = clean_text(text);
            // tokenize the cleaned field into words
            let grams = cleaned_text.split_ascii_whitespace().collect_vec();
            // group search terms by word count so multi-word terms are compared
            // against n-gram windows of the same length
            for (term_len, term_list) in &search_terms
                .iter()
                .group_by(|st| st.term.split_ascii_whitespace().count())
            {
                let combos = if term_len == 1 {
                    term_list.cartesian_product(
                        grams
                            .iter()
                            .unique()
                            .map(|word| word.to_string())
                            .collect_vec(),
                    )
                } else {
                    term_list.cartesian_product(
                        grams
                            .windows(term_len)
                            .unique()
                            .map(|words| words.join(" "))
                            .collect_vec(),
                    )
                };
                for (search_term, comparison_term) in combos {
                    // OSA (restricted Damerau-Levenshtein) edit distance between term and candidate
                    let edits = strsim::osa_distance(&search_term.term, &comparison_term);
                    match edits {
                        // exact match
                        0 => {
                            dataset
                                .writer
                                .serialize(SearchOutput {
                                    row_id: &id,
                                    search_term: &search_term.term,
                                    matched_term: &comparison_term,
                                    edits,
                                    similarity_score: 1.0,
                                    search_field: &column.name,
                                    metadata: &search_term.metadata,
                                })
                                .wrap_err("Unable to serialize output")?;
                            found_match = true;
                            matched_terms.insert(&search_term.term);
                        }
                        // one edit away: accept only with high Jaro-Winkler similarity
                        1 => {
                            let sim = strsim::jaro_winkler(&search_term.term, &comparison_term);
                            if sim >= 0.95 {
                                dataset
                                    .writer
                                    .serialize(SearchOutput {
                                        row_id: &id,
                                        search_term: &search_term.term,
                                        matched_term: &comparison_term,
                                        edits,
                                        similarity_score: sim,
                                        search_field: &column.name,
                                        metadata: &search_term.metadata,
                                    })
                                    .wrap_err("Unable to serialize output")?;
                                found_match = true;
                                matched_terms.insert(&search_term.term);
                            }
                        }
                        // two edits away: require an even higher similarity
                        2 => {
                            let sim = strsim::jaro_winkler(&search_term.term, &comparison_term);
                            if sim >= 0.97 {
                                dataset
                                    .writer
                                    .serialize(SearchOutput {
                                        row_id: &id,
                                        search_term: &search_term.term,
                                        matched_term: &comparison_term,
                                        edits,
                                        similarity_score: sim,
                                        search_field: &column.name,
                                        metadata: &search_term.metadata,
                                    })
                                    .wrap_err("Unable to serialize output")?;
                                found_match = true;
                                matched_terms.insert(&search_term.term);
                            }
                        }
                        // more than two edits: not a match
                        _ => continue,
                    }
                }
            }
        }
        if found_match {
            total_records_with_matches += 1;
        }
        total_records += 1;
    }
    dataset.writer.flush().wrap_err("Unable to flush writer")?;
    progress_bar.finish_with_message("Done!");

    println!(
        "Found matches in {:} of {:} records ({:.2}%)",
        total_records_with_matches,
        total_records,
        (total_records_with_matches as f64 / total_records as f64) * 100.0
    );
    println!(
        "Found {:} of {:} search terms ({:.2}%)",
        matched_terms.len(),
        search_terms.len(),
        (matched_terms.len() as f64 / search_terms.len() as f64) * 100.0
    );

    Ok(())
}

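/// Convenience wrapper that reads the search terms, initializes the dataset, and
/// runs [search], writing matches to `output.csv`.
///
/// A minimal usage sketch (file and column names here are hypothetical):
/// ```no_run
/// use drug_extraction_cli::run_searcher;
///
/// run_searcher(
///     "records.csv",
///     "terms.csv",
///     vec!["TEXT".to_string()],
///     Some("ID".to_string()),
/// )
/// .unwrap();
/// ```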
pub fn run_searcher<P: AsRef<Path>>(
    data_file: P,
    search_terms_file: P,
    search_columns: Vec<String>,
    id_column: Option<String>,
) -> Result<()> {
    let search_terms = read_terms_from_file(search_terms_file)?;
    let dataset = initialize_dataset(data_file, &search_columns, id_column)?;
    search(dataset, search_terms)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clean_text_no_changes() {
        let s = "This is a test string.";
        assert_eq!(clean_text(s), "this is a test string".to_ascii_uppercase());
    }

    #[test]
    fn test_clean_text_numeric() {
        let s = "This is a test string with 1234 numbers.";
        assert_eq!(
            clean_text(s),
            "this is a test string with 1234 numbers".to_ascii_uppercase()
        );
    }

    #[test]
    fn test_clean_text_symbols() {
        let s = "!@#$%^&*()_+-";
        assert_eq!(clean_text(s), "-");
    }

    #[test]
    fn test_clean_empty() {
        let s = "";
        assert_eq!(clean_text(s), "");
    }

    #[test]
    fn test_clean_end_whitespace() {
        let s = "!! This is a test string.   ";
        assert_eq!(clean_text(s), "this is a test string".to_ascii_uppercase());
    }

    #[test]
    fn test_clean_end_whitespace2() {
        let s = "!! This is a test to test- - hyphenated string.   ";
        assert_eq!(
            clean_text(s),
            "this is a test to test- - hyphenated string".to_ascii_uppercase()
        );
    }

    #[test]
    fn test_whitespace_split() {
        let s = "!! This is a test to test- - hyphenated string.   ";
        assert_eq!(
            clean_text(s),
            "this is a test to test- - hyphenated string".to_ascii_uppercase()
        );
        let c = clean_text(s);
        let v = c.split_ascii_whitespace().collect_vec();
        assert_eq!(
            v,
            vec![
                "THIS",
                "IS",
                "A",
                "TEST",
                "TO",
                "TEST-",
                "-",
                "HYPHENATED",
                "STRING"
            ]
        );
    }

    #[test]
    fn test_get_column_info() {
        let header = vec!["a", "b", "c"];
        let col = "a";
        assert_eq!(get_column_info(&header, &col).unwrap().index, 0);
    }

    #[test]
    fn test_get_column_info_errors() {
        let header = vec!["a", "b", "c"];
        let col = "d";
        assert!(get_column_info(&header, &col).is_err());
    }

    #[test]
    fn test_collect_column_info() {
        let header = vec!["a", "b", "c"];
        let cols = vec!["a", "b"];
        let info = collect_column_info(&header, &cols);
        assert!(info.is_ok());
        let info = info.unwrap();
        assert_eq!(info.len(), 2);
        assert_eq!(
            info,
            vec![
                ColumnInfo {
                    name: "a".to_string(),
                    index: 0
                },
                ColumnInfo {
                    name: "b".to_string(),
                    index: 1
                }
            ]
        );
    }

    #[test]
    fn test_collect_column_info_sample() -> Result<()> {
        let header = csv::Reader::from_path("../data/search_terms.csv")?
            .headers()?
            .into_iter()
            .map(clean_text)
            .collect_vec();
        let cols = ["term", "metadata"]
            .iter()
            .map(|c| clean_text(c))
            .collect_vec();
        let info = collect_column_info(&header, &cols)?;
        assert_eq!(info.len(), 2);
        Ok(())
    }

    #[test]
    fn test_enumerated_reader() {
        let mut reader = csv::Reader::from_path("../data/search_terms.csv").unwrap();
        let (i, _) = reader.records().enumerate().next().unwrap();
        assert!(i == 0);
    }
}