Skip to main content

drug_extraction_cli/
lib.rs

1use color_eyre::{
2    eyre::{eyre, Context, ContextCompat},
3    Result,
4};
5use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
6
7use std::{collections::HashSet, fs::File, path::Path, time::Duration};
8
9use serde::{Deserialize, Serialize};
10
11use itertools::Itertools;
12
13/// Create a spinner with default style, takes a message
14fn initialize_spinner_style(msg: String) -> ProgressBar {
15    let pb = ProgressBar::new_spinner();
16    pb.enable_steady_tick(Duration::from_millis(100));
17    pb.with_style(
18        ProgressStyle::with_template("{spinner:.blue} {msg}")
19            .unwrap()
20            .tick_strings(&[
21                "▹▹▹▹▹",
22                "▸▹▹▹▹",
23                "▹▸▹▹▹",
24                "▹▹▸▹▹",
25                "▹▹▹▸▹",
26                "▹▹▹▹▸",
27                "▪▪▪▪▪",
28            ]),
29    )
30    .with_message(msg)
31}
32
33/// Initialize a progress bar with default style, takes a message and length
34fn initialize_progress_bar(msg: String, len: u64) -> ProgressBar {
35    let pb = ProgressBar::new(len);
36    pb.enable_steady_tick(Duration::from_millis(100));
37    pb.with_style(
38        ProgressStyle::default_bar()
39            .template("{spinner:.blue} {msg} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos:>7}/{len:7} ({eta})").unwrap()
40            .progress_chars("##-"),
41    )
42    .with_message(msg)
43}
44
/// Struct to hold search term and metadata.
///
/// Deserialized from the search-terms csv by [read_terms_from_file], which
/// also cleans/uppercases `term` via [clean_text].
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
pub struct SearchTerm {
    /// The search term (cleaned and uppercased after loading)
    pub term: String,
    /// Optional metadata to be included in output; never cleaned or inspected
    pub metadata: Option<String>,
}
53
/// Struct to hold search output.
///
/// One instance is serialized to the output csv per (record, term) match
/// found by [search]; all string fields borrow from the current row/term.
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct SearchOutput<'a> {
    /// The row id of the matched record, either from specified column or line number
    row_id: &'a str,
    /// The search term that was matched
    search_term: &'a str,
    /// The matched term from the record
    matched_term: &'a str,
    /// The number of edits (OSA distance) required to match the search term
    edits: usize,
    /// The similarity score between the search term and the matched term
    /// (1.0 for exact matches, Jaro-Winkler otherwise)
    similarity_score: f64,
    /// The field that was searched
    search_field: &'a str,
    /// The metadata associated with the search term
    metadata: &'a Option<String>,
}
72
73/// Function to read in search terms from a csv file
74/// Performs cleaning of terms, ignoring metadata column
75pub fn read_terms_from_file<P: AsRef<Path>>(p: P) -> Result<Vec<SearchTerm>> {
76    let mut rdr = csv::Reader::from_path(p).wrap_err("Unable to read search terms file")?;
77    let mut records: Vec<SearchTerm> = Vec::new();
78    for (i, row) in rdr
79        .deserialize()
80        .enumerate()
81        .progress_with(initialize_spinner_style(
82            "Loading Search Terms...".to_string(),
83        ))
84    {
85        let mut record: SearchTerm =
86            row.wrap_err(format!("Could not load search term from line: {}", i))?;
87        record.term = clean_text(&record.term);
88        records.push(record);
89    }
90    records.sort_by_key(|x| x.term.split_ascii_whitespace().count());
91    Ok(records)
92}
93
/// Function to remove non-alphanumeric characters from a string
/// keeps hyphens due to their usage in abbreviations/medical terms.
/// Also uppercase for standardization.
/// Example:
/// ```
/// use drug_extraction_cli::clean_text;
///
/// let s = "This is a test-string with 1234 and some punctuation!@#$%^&*()";
/// let cleaned = clean_text(s);
/// assert_ne!(cleaned, "THIS IS A TEST STRING WITH 1234 AND SOME PUNCTUATION");
/// assert_eq!(cleaned, "THIS IS A TEST-STRING WITH 1234 AND SOME PUNCTUATION");
/// ```
pub fn clean_text(s: &str) -> String {
    // Map each char: keep ASCII alphanumerics and '-' (uppercased),
    // replace everything else with a space; then strip edge whitespace.
    let replaced: String = s
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '-' {
                c.to_ascii_uppercase()
            } else {
                ' '
            }
        })
        .collect();
    replaced.trim().to_string()
}
111
/// Struct to hold information about the dataset.
///
/// Built by [initialize_dataset]; bundles the data reader, pre-counted row
/// total (used to size the progress bar), resolved column indices, and the
/// output writer (hard-coded to "output.csv" at construction time).
#[derive(Debug)]
pub struct DataSet {
    /// csv reader for the dataset
    pub reader: csv::Reader<File>,
    /// rows in the dataset from first scan
    pub rows: usize,
    /// indices of the columns to search in the dataset
    pub clean_search_columns: Vec<ColumnInfo>,
    /// index of the column to use as an id; falls back to line number when None
    pub clean_id_column: Option<ColumnInfo>,
    /// csv writer for the output file
    pub writer: csv::Writer<File>,
}
126
/// A resolved csv column: its (cleaned) header name and zero-based index.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ColumnInfo {
    /// Header name as it appears after [clean_text] normalization
    pub name: String,
    /// Zero-based position of the column in the header row
    pub index: usize,
}
132
133/// Function to get the column index for a given column name
134/// Returns an error if the column name is not found
135/// Typically ran using the header from the csv reader and is called
136/// inside the [collect_column_info] function to do this for each target column.
137/// Example:
138/// ```
139/// use drug_extraction_cli::get_column_info;
140///
141/// let header = vec!["ID", "NAME", "DESCRIPTION"];
142/// let column = "NAME";
143/// let column_info = get_column_info(&header, &column);
144/// assert!(column_info.is_ok());
145/// let column_info = column_info.unwrap();
146/// assert_eq!(column_info.name, "NAME");
147/// assert_eq!(column_info.index, 1);
148/// ```
149pub fn get_column_info<S: AsRef<str> + PartialEq>(header: &[S], column: &S) -> Result<ColumnInfo> {
150    let pos = header.iter().position(|h| h == column);
151    match pos {
152        Some(i) => Ok(ColumnInfo {
153            name: column.as_ref().to_string(),
154            index: i,
155        }),
156        None => Err(eyre!("Unable to find column {}", column.as_ref())),
157    }
158}
159
160/// Function to collect column info for each column to search
161/// Typically ran using the header from the csv reader and is called
162/// inside the [initialize_dataset] function to do this for each target column.
163/// Example:
164/// ```
165/// use drug_extraction_cli::collect_column_info;
166///
167/// let header = vec!["ID", "NAME", "DESCRIPTION"];
168/// let columns = vec!["NAME", "DESCRIPTION"];
169/// let column_info = collect_column_info(&header, &columns);
170/// assert!(column_info.is_ok());
171/// let column_info = column_info.unwrap();
172/// assert_eq!(column_info.len(), 2);
173/// assert_eq!(column_info[0].name, "NAME");
174/// assert_eq!(column_info[0].index, 1);
175/// assert_eq!(column_info[1].name, "DESCRIPTION");
176/// assert_eq!(column_info[1].index, 2);
177/// ```
178pub fn collect_column_info<S: AsRef<str> + PartialEq>(
179    header: &[S],
180    column_names: &[S],
181) -> Result<Vec<ColumnInfo>> {
182    column_names
183        .iter()
184        .map(|column| get_column_info(header, column))
185        .collect()
186}
187
188/// Function to initialize the dataset
189pub fn initialize_dataset<P: AsRef<Path>>(
190    data_file: P,
191    search_columns: &[String],
192    id_column: Option<String>,
193) -> Result<DataSet> {
194    let mut rdr = csv::Reader::from_path(&data_file).wrap_err("Unable to initialize csv reader")?;
195    let header = rdr
196        .headers()
197        .wrap_err("Unable to parse csv headers")?
198        .iter()
199        .map(clean_text)
200        .collect_vec();
201    // clean search cols and id col
202    let clean_search_cols = search_columns.iter().map(|c| clean_text(c)).collect_vec();
203    let clean_id_col = id_column.map(|c| clean_text(&c));
204    let column_info = collect_column_info(&header, &clean_search_cols)
205        .wrap_err("Unable to collect column indices")?;
206    let ds = match clean_id_col {
207        Some(c) => DataSet {
208            reader: csv::Reader::from_path(&data_file)
209                .wrap_err("Unable to initialize csv reader")?,
210            rows: rdr.records().count(),
211            clean_search_columns: column_info,
212            clean_id_column: Some(get_column_info(&header, &c)?),
213            writer: csv::Writer::from_path("output.csv")?,
214        },
215        None => DataSet {
216            reader: csv::Reader::from_path(&data_file)
217                .wrap_err("Unable to initialize csv reader")?,
218            rows: rdr.records().count(),
219            clean_search_columns: column_info,
220            clean_id_column: None,
221            writer: csv::Writer::from_path("output.csv")?,
222        },
223    };
224    Ok(ds)
225}
226
227/// Primary search function
228pub fn search(mut dataset: DataSet, search_terms: Vec<SearchTerm>) -> Result<()> {
229    let mut total_records_with_matches = 0;
230    let mut total_records = 0;
231    let mut matched_terms: HashSet<&str> = HashSet::new();
232
233    let spinner =
234        initialize_progress_bar("Searching for matches...".to_string(), dataset.rows as u64);
235    for (i, row) in dataset
236        .reader
237        .records()
238        .enumerate()
239        .progress_with(spinner.clone())
240    {
241        let record = row.wrap_err(format!("Unable to read record from line {}", i))?;
242
243        let id = match &dataset.clean_id_column {
244            Some(c) => record
245                .get(c.index)
246                .wrap_err(format!(
247                    "Unable to read id column {} from line {}",
248                    c.name, i
249                ))?
250                .to_string(),
251            None => i.to_string(),
252        };
253
254        let mut found_match = false;
255        for column in &dataset.clean_search_columns {
256            let text = record.get(column.index).wrap_err(format!(
257                "Unable to read column {} from line {}",
258                column.name, i
259            ))?;
260            let cleaned_text = clean_text(text);
261            let grams = cleaned_text.split_ascii_whitespace().collect_vec();
262            for (term_len, term_list) in &search_terms
263                .iter()
264                .group_by(|st| st.term.split_ascii_whitespace().count())
265            {
266                let combos = if term_len == 1 {
267                    term_list.cartesian_product(
268                        grams
269                            .iter()
270                            .unique()
271                            .map(|word| word.to_string())
272                            .collect_vec(),
273                    )
274                } else {
275                    term_list.cartesian_product(
276                        grams
277                            .windows(term_len)
278                            .unique()
279                            .map(|words| words.join(" "))
280                            .collect_vec(),
281                    )
282                };
283                for (search_term, comparison_term) in combos {
284                    // outside of window
285                    if search_term.term.len().abs_diff(comparison_term.len()) > 2 {
286                        continue;
287                    }
288                    let edits = strsim::osa_distance(&search_term.term, &comparison_term);
289                    match edits {
290                        0 => {
291                            dataset
292                                .writer
293                                .serialize(SearchOutput {
294                                    row_id: &id,
295                                    search_term: &search_term.term,
296                                    matched_term: &comparison_term,
297                                    edits,
298                                    similarity_score: 1.0,
299                                    search_field: &column.name,
300                                    metadata: &search_term.metadata,
301                                })
302                                .wrap_err("Enable to serialize output")?;
303                            found_match = true;
304                            matched_terms.insert(&search_term.term);
305                        }
306                        1 => {
307                            let sim = strsim::jaro_winkler(&search_term.term, &comparison_term);
308                            if sim >= 0.95 {
309                                dataset
310                                    .writer
311                                    .serialize(SearchOutput {
312                                        row_id: &id,
313                                        search_term: &search_term.term,
314                                        matched_term: &comparison_term,
315                                        edits,
316                                        similarity_score: sim,
317                                        search_field: &column.name,
318                                        metadata: &search_term.metadata,
319                                    })
320                                    .wrap_err("Enable to serialize output")?;
321                                found_match = true;
322                                matched_terms.insert(&search_term.term);
323                            }
324                        }
325                        2 => {
326                            let sim = strsim::jaro_winkler(&search_term.term, &comparison_term);
327                            if sim >= 0.97 {
328                                dataset
329                                    .writer
330                                    .serialize(SearchOutput {
331                                        row_id: &id,
332                                        search_term: &search_term.term,
333                                        matched_term: &comparison_term,
334                                        edits,
335                                        similarity_score: sim,
336                                        search_field: &column.name,
337                                        metadata: &search_term.metadata,
338                                    })
339                                    .wrap_err("Enable to serialize output")?;
340                                found_match = true;
341                                matched_terms.insert(&search_term.term);
342                            }
343                        }
344                        _ => continue,
345                    }
346                }
347            }
348        }
349        if found_match {
350            total_records_with_matches += 1;
351        }
352        total_records += 1;
353    }
354    dataset.writer.flush().wrap_err("Unable to flush writer")?;
355    spinner.finish_with_message("Done!");
356
357    println!(
358        "Found matches in {:} of {:} records ({:.2}%)",
359        total_records_with_matches,
360        total_records,
361        (total_records_with_matches as f64 / total_records as f64) * 100.0
362    );
363    println!(
364        "Found {:} of {:} search terms ({:.2}%)",
365        matched_terms.len(),
366        search_terms.len(),
367        (matched_terms.len() as f64 / search_terms.len() as f64) * 100.0
368    );
369
370    Ok(())
371}
372
373pub fn run_searcher<P: AsRef<Path>>(
374    data_file: P,
375    search_terms_file: P,
376    search_columns: Vec<String>,
377    id_column: Option<String>,
378) -> Result<()> {
379    let search_terms = read_terms_from_file(search_terms_file)?;
380    let dataset = initialize_dataset(data_file, &search_columns, id_column)?;
381    search(dataset, search_terms)
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clean_text_no_changes() {
        assert_eq!(clean_text("This is a test string."), "THIS IS A TEST STRING");
    }

    #[test]
    fn test_clean_text_numeric() {
        assert_eq!(
            clean_text("This is a test string with 1234 numbers."),
            "THIS IS A TEST STRING WITH 1234 NUMBERS"
        );
    }

    #[test]
    fn test_clean_text_symbols() {
        // Everything except the trailing hyphen is stripped to whitespace.
        assert_eq!(clean_text("!@#$%^&*()_+-"), "-");
    }

    #[test]
    fn test_clean_empty() {
        assert_eq!(clean_text(""), "");
    }

    #[test]
    fn test_clean_end_whitespace() {
        assert_eq!(
            clean_text("!! This is a test string.   "),
            "THIS IS A TEST STRING"
        );
    }

    #[test]
    fn test_clean_end_whitespace2() {
        assert_eq!(
            clean_text("!! This is a test to test- - hyphenated string.   "),
            "THIS IS A TEST TO TEST- - HYPHENATED STRING"
        );
    }

    #[test]
    fn test_whitespace_split() {
        let cleaned = clean_text("!! This is a test to test- - hyphenated string.   ");
        assert_eq!(cleaned, "THIS IS A TEST TO TEST- - HYPHENATED STRING");
        // A lone hyphen survives cleaning and splits out as its own token.
        let words = cleaned.split_ascii_whitespace().collect_vec();
        assert_eq!(
            words,
            vec![
                "THIS",
                "IS",
                "A",
                "TEST",
                "TO",
                "TEST-",
                "-",
                "HYPHENATED",
                "STRING"
            ]
        );
    }

    #[test]
    fn test_get_column_info() {
        let header = vec!["a", "b", "c"];
        let col = "a";
        assert_eq!(get_column_info(&header, &col).unwrap().index, 0);
    }

    #[test]
    fn test_get_column_info_errors() {
        let header = vec!["a", "b", "c"];
        let col = "d";
        assert!(get_column_info(&header, &col).is_err());
    }

    #[test]
    fn test_collect_column_info() {
        let header = vec!["a", "b", "c"];
        let cols = vec!["a", "b"];
        let resolved = collect_column_info(&header, &cols).unwrap();
        assert_eq!(resolved.len(), 2);
        assert_eq!(
            resolved,
            vec![
                ColumnInfo {
                    name: "a".to_string(),
                    index: 0
                },
                ColumnInfo {
                    name: "b".to_string(),
                    index: 1
                }
            ]
        );
    }

    #[test]
    fn test_collect_column_info_sample() -> Result<()> {
        let header = csv::Reader::from_path("../data/search_terms.csv")?
            .headers()?
            .iter()
            .map(clean_text)
            .collect_vec();
        let cols = ["term", "metadata"]
            .iter()
            .map(|c| clean_text(c))
            .collect_vec();
        let resolved = collect_column_info(&header, &cols)?;
        assert_eq!(resolved.len(), 2);
        Ok(())
    }

    #[test]
    fn test_enumerated_reader() {
        let mut reader = csv::Reader::from_path("../data/search_terms.csv").unwrap();
        let (first_index, _) = reader.records().enumerate().next().unwrap();
        assert_eq!(first_index, 0);
    }
}