use color_eyre::{
    eyre::{eyre, Context, ContextCompat},
    Result,
};
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};

use std::{collections::HashSet, fs::File, path::Path, time::Duration};

use serde::{Deserialize, Serialize};

use itertools::Itertools;

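/// Creates a spinner-style progress bar with a steady 100ms tick and the
/// given message, for stages whose length is not known up front.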
fn initialize_spinner_style(msg: String) -> ProgressBar {
    let pb = ProgressBar::new_spinner();
    pb.enable_steady_tick(Duration::from_millis(100));
    pb.with_style(
        ProgressStyle::with_template("{spinner:.blue} {msg}")
            .unwrap()
            .tick_strings(&[
                "▹▹▹▹▹",
                "▸▹▹▹▹",
                "▹▸▹▹▹",
                "▹▹▸▹▹",
                "▹▹▹▸▹",
                "▹▹▹▹▸",
                "▪▪▪▪▪",
            ]),
    )
    .with_message(msg)
}

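/// Creates a determinate progress bar of length `len` that shows elapsed
/// time, position, and ETA alongside the given message.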
fn initialize_progress_bar(msg: String, len: u64) -> ProgressBar {
    let pb = ProgressBar::new(len);
    pb.enable_steady_tick(Duration::from_millis(100));
    pb.with_style(
        ProgressStyle::default_bar()
            .template("{spinner:.blue} {msg} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos:>7}/{len:7} ({eta})")
            .unwrap()
            .progress_chars("##-"),
    )
    .with_message(msg)
}

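/// A single search term read from the terms CSV, with optional free-form
/// metadata that is carried through to the search output.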
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
pub struct SearchTerm {
    /// The term to search for; cleaned with [`clean_text`] on load.
    pub term: String,
    /// Optional metadata copied into each matching output row.
    pub metadata: Option<String>,
}

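/// One match written to the output CSV; borrows from the record, column,
/// and search term that produced it.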
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct SearchOutput<'a> {
    /// Value of the id column, or the record's index if none was given.
    row_id: &'a str,
    /// The search term that matched.
    search_term: &'a str,
    /// The n-gram from the record that the term matched against.
    matched_term: &'a str,
    /// OSA edit distance between the two strings.
    edits: usize,
    /// Jaro-Winkler similarity (1.0 for exact matches).
    similarity_score: f64,
    /// The column the match was found in.
    search_field: &'a str,
    /// Metadata from the matching [`SearchTerm`].
    metadata: &'a Option<String>,
}

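/// Reads search terms from a CSV file, cleans each term with
/// [`clean_text`], and returns them sorted by word count so that
/// [`search`] can group terms of equal length together.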
pub fn read_terms_from_file<P: AsRef<Path>>(p: P) -> Result<Vec<SearchTerm>> {
    let mut rdr = csv::Reader::from_path(p).wrap_err("Unable to read search terms file")?;
    let mut records: Vec<SearchTerm> = Vec::new();
    for (i, row) in rdr
        .deserialize()
        .enumerate()
        .progress_with(initialize_spinner_style(
            "Loading Search Terms...".to_string(),
        ))
    {
        let mut record: SearchTerm =
            row.wrap_err(format!("Could not load search term from line: {}", i))?;
        record.term = clean_text(&record.term);
        records.push(record);
    }
    records.sort_by_key(|x| x.term.split_ascii_whitespace().count());
    Ok(records)
}

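/// Normalizes text for comparison: every character that is not ASCII
/// alphanumeric or a hyphen becomes a space, then the result is trimmed
/// and upper-cased (e.g. "Hello World!" becomes "HELLO WORLD").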
pub fn clean_text(s: &str) -> String {
    s.replace(|c: char| !c.is_ascii_alphanumeric() && c != '-', " ")
        .trim()
        .to_ascii_uppercase()
}

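/// A data file prepared for searching: a CSV reader over the records, the
/// total row count, the resolved search and id columns, and a writer for
/// the results.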
#[derive(Debug)]
pub struct DataSet {
    /// Reader positioned at the start of the data file.
    pub reader: csv::Reader<File>,
    /// Number of records in the file, used to size the progress bar.
    pub rows: usize,
    /// The resolved columns to search.
    pub clean_search_columns: Vec<ColumnInfo>,
    /// The resolved id column, if one was given.
    pub clean_id_column: Option<ColumnInfo>,
    /// Writer for the results file.
    pub writer: csv::Writer<File>,
}

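/// A column's cleaned header name and its zero-based index in the CSV.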
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ColumnInfo {
    pub name: String,
    pub index: usize,
}

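/// Finds `column` in `header` and returns its name and index, or an error
/// if the column is not present.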
pub fn get_column_info<S: AsRef<str> + PartialEq>(header: &[S], column: &S) -> Result<ColumnInfo> {
    let pos = header.iter().position(|h| h == column);
    match pos {
        Some(i) => Ok(ColumnInfo {
            name: column.as_ref().to_string(),
            index: i,
        }),
        None => Err(eyre!("Unable to find column {}", column.as_ref())),
    }
}

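/// Resolves every column in `column_names` against `header`, failing on
/// the first column that cannot be found.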
pub fn collect_column_info<S: AsRef<str> + PartialEq>(
    header: &[S],
    column_names: &[S],
) -> Result<Vec<ColumnInfo>> {
    column_names
        .iter()
        .map(|column| get_column_info(header, column))
        .collect()
}

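/// Opens the data file, cleans its headers, resolves the search and
/// optional id columns, counts the rows for progress reporting, and
/// prepares a writer to `output.csv` in the working directory.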
pub fn initialize_dataset<P: AsRef<Path>>(
    data_file: P,
    search_columns: &[String],
    id_column: Option<String>,
) -> Result<DataSet> {
    let mut rdr = csv::Reader::from_path(&data_file).wrap_err("Unable to initialize csv reader")?;
    let header = rdr
        .headers()
        .wrap_err("Unable to parse csv headers")?
        .iter()
        .map(clean_text)
        .collect_vec();
    let clean_search_cols = search_columns.iter().map(|c| clean_text(c)).collect_vec();
    let clean_id_col = id_column.map(|c| clean_text(&c));
    let column_info = collect_column_info(&header, &clean_search_cols)
        .wrap_err("Unable to collect column indices")?;
    let clean_id_column = clean_id_col
        .map(|c| get_column_info(&header, &c))
        .transpose()?;
    Ok(DataSet {
        // Re-open the file: counting the rows below consumes `rdr`'s
        // record iterator, so the stored reader must be a fresh one.
        reader: csv::Reader::from_path(&data_file).wrap_err("Unable to initialize csv reader")?,
        rows: rdr.records().count(),
        clean_search_columns: column_info,
        clean_id_column,
        writer: csv::Writer::from_path("output.csv")?,
    })
}

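/// Scans every record of the dataset for the given search terms.
///
/// Each search field is cleaned and split into n-grams whose length
/// matches each group of terms; a term matches an n-gram exactly (0
/// edits) or approximately (1-2 edits under OSA distance, gated by a
/// Jaro-Winkler similarity threshold). Matches are serialized to the
/// output CSV and summary statistics are printed at the end.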
pub fn search(mut dataset: DataSet, search_terms: Vec<SearchTerm>) -> Result<()> {
    let mut total_records_with_matches = 0;
    let mut total_records = 0;
    let mut matched_terms: HashSet<&str> = HashSet::new();

    let progress =
        initialize_progress_bar("Searching for matches...".to_string(), dataset.rows as u64);
    for (i, row) in dataset
        .reader
        .records()
        .enumerate()
        .progress_with(progress.clone())
    {
        let record = row.wrap_err(format!("Unable to read record from line {}", i))?;

        let id = match &dataset.clean_id_column {
            Some(c) => record
                .get(c.index)
                .wrap_err(format!(
                    "Unable to read id column {} from line {}",
                    c.name, i
                ))?
                .to_string(),
            None => i.to_string(),
        };

        let mut found_match = false;
        for column in &dataset.clean_search_columns {
            let text = record.get(column.index).wrap_err(format!(
                "Unable to read column {} from line {}",
                column.name, i
            ))?;
            let cleaned_text = clean_text(text);
            let grams = cleaned_text.split_ascii_whitespace().collect_vec();
            // Terms arrive sorted by word count, so group_by yields one
            // group per term length; each group is compared against the
            // n-grams of the same length from this field.
            for (term_len, term_list) in &search_terms
                .iter()
                .group_by(|st| st.term.split_ascii_whitespace().count())
            {
                let combos = if term_len == 1 {
                    term_list.cartesian_product(
                        grams
                            .iter()
                            .unique()
                            .map(|word| word.to_string())
                            .collect_vec(),
                    )
                } else {
                    term_list.cartesian_product(
                        grams
                            .windows(term_len)
                            .unique()
                            .map(|words| words.join(" "))
                            .collect_vec(),
                    )
                };
                for (search_term, comparison_term) in combos {
                    let edits = strsim::osa_distance(&search_term.term, &comparison_term);
                    // Exact matches always count; near matches must also
                    // clear a Jaro-Winkler threshold that rises with the
                    // edit distance.
                    let (similarity_score, threshold) = match edits {
                        0 => (1.0, 1.0),
                        1 => (
                            strsim::jaro_winkler(&search_term.term, &comparison_term),
                            0.95,
                        ),
                        2 => (
                            strsim::jaro_winkler(&search_term.term, &comparison_term),
                            0.97,
                        ),
                        _ => continue,
                    };
                    if similarity_score >= threshold {
                        dataset
                            .writer
                            .serialize(SearchOutput {
                                row_id: &id,
                                search_term: &search_term.term,
                                matched_term: &comparison_term,
                                edits,
                                similarity_score,
                                search_field: &column.name,
                                metadata: &search_term.metadata,
                            })
                            .wrap_err("Unable to serialize output")?;
                        found_match = true;
                        matched_terms.insert(&search_term.term);
                    }
                }
            }
        }
        if found_match {
            total_records_with_matches += 1;
        }
        total_records += 1;
    }
    dataset.writer.flush().wrap_err("Unable to flush writer")?;
    progress.finish_with_message("Done!");

    println!(
        "Found matches in {} of {} records ({:.2}%)",
        total_records_with_matches,
        total_records,
        (total_records_with_matches as f64 / total_records as f64) * 100.0
    );
    println!(
        "Found {} of {} search terms ({:.2}%)",
        matched_terms.len(),
        search_terms.len(),
        (matched_terms.len() as f64 / search_terms.len() as f64) * 100.0
    );

    Ok(())
}

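/// End-to-end entry point: loads the search terms, initializes the
/// dataset, and runs the search.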
pub fn run_searcher<P: AsRef<Path>>(
    data_file: P,
    search_terms_file: P,
    search_columns: Vec<String>,
    id_column: Option<String>,
) -> Result<()> {
    let search_terms = read_terms_from_file(search_terms_file)?;
    let dataset = initialize_dataset(data_file, &search_columns, id_column)?;
    search(dataset, search_terms)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clean_text_no_changes() {
        let s = "This is a test string.";
        assert_eq!(clean_text(s), "this is a test string".to_ascii_uppercase());
    }

    #[test]
    fn test_clean_text_numeric() {
        let s = "This is a test string with 1234 numbers.";
        assert_eq!(
            clean_text(s),
            "this is a test string with 1234 numbers".to_ascii_uppercase()
        );
    }

    #[test]
    fn test_clean_text_symbols() {
        let s = "!@#$%^&*()_+-";
        assert_eq!(clean_text(s), "-");
    }

    #[test]
    fn test_clean_empty() {
        let s = "";
        assert_eq!(clean_text(s), "");
    }

    #[test]
    fn test_clean_end_whitespace() {
        let s = "!! This is a test string. ";
        assert_eq!(clean_text(s), "this is a test string".to_ascii_uppercase());
    }

    #[test]
    fn test_clean_end_whitespace2() {
        let s = "!! This is a test to test- - hyphenated string. ";
        assert_eq!(
            clean_text(s),
            "this is a test to test- - hyphenated string".to_ascii_uppercase()
        );
    }

    #[test]
    fn test_whitespace_split() {
        let s = "!! This is a test to test- - hyphenated string. ";
        assert_eq!(
            clean_text(s),
            "this is a test to test- - hyphenated string".to_ascii_uppercase()
        );
        let c = clean_text(s);
        let v = c.split_ascii_whitespace().collect_vec();
        assert_eq!(
            v,
            vec![
                "THIS",
                "IS",
                "A",
                "TEST",
                "TO",
                "TEST-",
                "-",
                "HYPHENATED",
                "STRING"
            ]
        );
    }

    #[test]
    fn test_get_column_info() {
        let header = vec!["a", "b", "c"];
        let col = "a";
        assert_eq!(get_column_info(&header, &col).unwrap().index, 0);
    }

    #[test]
    fn test_get_column_info_errors() {
        let header = vec!["a", "b", "c"];
        let col = "d";
        assert!(get_column_info(&header, &col).is_err());
    }

    #[test]
    fn test_collect_column_info() {
        let header = vec!["a", "b", "c"];
        let cols = vec!["a", "b"];
        let info = collect_column_info(&header, &cols);
        assert!(info.is_ok());
        let info = info.unwrap();
        assert_eq!(info.len(), 2);
        assert_eq!(
            info,
            vec![
                ColumnInfo {
                    name: "a".to_string(),
                    index: 0
                },
                ColumnInfo {
                    name: "b".to_string(),
                    index: 1
                }
            ]
        );
    }

    #[test]
    fn test_collect_column_info_sample() -> Result<()> {
        let header = csv::Reader::from_path("../data/search_terms.csv")?
            .headers()?
            .into_iter()
            .map(clean_text)
            .collect_vec();
        let cols = ["term", "metadata"]
            .iter()
            .map(|c| clean_text(c))
            .collect_vec();
        let info = collect_column_info(&header, &cols)?;
        assert_eq!(info.len(), 2);
        Ok(())
    }

    #[test]
    fn test_enumerated_reader() {
        let mut reader = csv::Reader::from_path("../data/search_terms.csv").unwrap();
        let (i, _) = reader.records().enumerate().next().unwrap();
        assert_eq!(i, 0);
    }
}
510}