1use color_eyre::{
2 eyre::{eyre, Context, ContextCompat},
3 Result,
4};
5use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
6
7use std::{collections::HashSet, fs::File, path::Path, time::Duration};
8
9use serde::{Deserialize, Serialize};
10
11use itertools::Itertools;
12
13fn initialize_spinner_style(msg: String) -> ProgressBar {
15 let pb = ProgressBar::new_spinner();
16 pb.enable_steady_tick(Duration::from_millis(100));
17 pb.with_style(
18 ProgressStyle::with_template("{spinner:.blue} {msg}")
19 .unwrap()
20 .tick_strings(&[
21 "▹▹▹▹▹",
22 "▸▹▹▹▹",
23 "▹▸▹▹▹",
24 "▹▹▸▹▹",
25 "▹▹▹▸▹",
26 "▹▹▹▹▸",
27 "▪▪▪▪▪",
28 ]),
29 )
30 .with_message(msg)
31}
32
33fn initialize_progress_bar(msg: String, len: u64) -> ProgressBar {
35 let pb = ProgressBar::new(len);
36 pb.enable_steady_tick(Duration::from_millis(100));
37 pb.with_style(
38 ProgressStyle::default_bar()
39 .template("{spinner:.blue} {msg} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos:>7}/{len:7} ({eta})").unwrap()
40 .progress_chars("##-"),
41 )
42 .with_message(msg)
43}
44
/// One search term deserialized from a row of the search-terms CSV.
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
pub struct SearchTerm {
    /// The term to search for; normalized with `clean_text` right after loading.
    pub term: String,
    /// Optional free-form metadata column, carried through unchanged to the output row.
    pub metadata: Option<String>,
}
53
/// A single match row serialized to the output CSV.
/// Borrows everything from the current record/term to avoid per-match allocation.
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct SearchOutput<'a> {
    /// Value of the id column, or the 0-based row index when no id column is set.
    row_id: &'a str,
    /// The (cleaned) search term that produced this match.
    search_term: &'a str,
    /// The n-gram from the record text that matched.
    matched_term: &'a str,
    /// OSA (restricted Damerau-Levenshtein) edit distance between the two terms.
    edits: usize,
    /// Jaro-Winkler similarity (1.0 for exact matches).
    similarity_score: f64,
    /// Name of the column the match was found in.
    search_field: &'a str,
    /// Metadata carried over from the matching `SearchTerm`.
    metadata: &'a Option<String>,
}
72
73pub fn read_terms_from_file<P: AsRef<Path>>(p: P) -> Result<Vec<SearchTerm>> {
76 let mut rdr = csv::Reader::from_path(p).wrap_err("Unable to read search terms file")?;
77 let mut records: Vec<SearchTerm> = Vec::new();
78 for (i, row) in rdr
79 .deserialize()
80 .enumerate()
81 .progress_with(initialize_spinner_style(
82 "Loading Search Terms...".to_string(),
83 ))
84 {
85 let mut record: SearchTerm =
86 row.wrap_err(format!("Could not load search term from line: {}", i))?;
87 record.term = clean_text(&record.term);
88 records.push(record);
89 }
90 records.sort_by_key(|x| x.term.split_ascii_whitespace().count());
91 Ok(records)
92}
93
/// Normalize `s` for matching: every character that is not an ASCII
/// letter, digit, or hyphen becomes a space, the result is uppercased,
/// and leading/trailing whitespace is trimmed. Interior runs of spaces
/// are kept (callers split on ASCII whitespace).
pub fn clean_text(s: &str) -> String {
    let normalized: String = s
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '-' {
                c.to_ascii_uppercase()
            } else {
                ' '
            }
        })
        .collect();
    normalized.trim().to_string()
}
111
/// A CSV dataset to search, plus the writer that receives match rows.
#[derive(Debug)]
pub struct DataSet {
    /// Reader over the data file, positioned at the first data row.
    pub reader: csv::Reader<File>,
    /// Number of data rows; used to size the search progress bar.
    pub rows: usize,
    /// Cleaned names and indices of the columns to scan for matches.
    pub clean_search_columns: Vec<ColumnInfo>,
    /// Optional id column; when `None`, the row index is used as the id.
    pub clean_id_column: Option<ColumnInfo>,
    /// Writer for serialized `SearchOutput` rows (hard-coded to output.csv).
    pub writer: csv::Writer<File>,
}
126
/// A column's cleaned name together with its positional index in the header.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ColumnInfo {
    pub name: String,
    pub index: usize,
}
132
133pub fn get_column_info<S: AsRef<str> + PartialEq>(header: &[S], column: &S) -> Result<ColumnInfo> {
150 let pos = header.iter().position(|h| h == column);
151 match pos {
152 Some(i) => Ok(ColumnInfo {
153 name: column.as_ref().to_string(),
154 index: i,
155 }),
156 None => Err(eyre!("Unable to find column {}", column.as_ref())),
157 }
158}
159
160pub fn collect_column_info<S: AsRef<str> + PartialEq>(
179 header: &[S],
180 column_names: &[S],
181) -> Result<Vec<ColumnInfo>> {
182 column_names
183 .iter()
184 .map(|column| get_column_info(header, column))
185 .collect()
186}
187
188pub fn initialize_dataset<P: AsRef<Path>>(
190 data_file: P,
191 search_columns: &[String],
192 id_column: Option<String>,
193) -> Result<DataSet> {
194 let mut rdr = csv::Reader::from_path(&data_file).wrap_err("Unable to initialize csv reader")?;
195 let header = rdr
196 .headers()
197 .wrap_err("Unable to parse csv headers")?
198 .iter()
199 .map(clean_text)
200 .collect_vec();
201 let clean_search_cols = search_columns.iter().map(|c| clean_text(c)).collect_vec();
203 let clean_id_col = id_column.map(|c| clean_text(&c));
204 let column_info = collect_column_info(&header, &clean_search_cols)
205 .wrap_err("Unable to collect column indices")?;
206 let ds = match clean_id_col {
207 Some(c) => DataSet {
208 reader: csv::Reader::from_path(&data_file)
209 .wrap_err("Unable to initialize csv reader")?,
210 rows: rdr.records().count(),
211 clean_search_columns: column_info,
212 clean_id_column: Some(get_column_info(&header, &c)?),
213 writer: csv::Writer::from_path("output.csv")?,
214 },
215 None => DataSet {
216 reader: csv::Reader::from_path(&data_file)
217 .wrap_err("Unable to initialize csv reader")?,
218 rows: rdr.records().count(),
219 clean_search_columns: column_info,
220 clean_id_column: None,
221 writer: csv::Writer::from_path("output.csv")?,
222 },
223 };
224 Ok(ds)
225}
226
227pub fn search(mut dataset: DataSet, search_terms: Vec<SearchTerm>) -> Result<()> {
229 let mut total_records_with_matches = 0;
230 let mut total_records = 0;
231 let mut matched_terms: HashSet<&str> = HashSet::new();
232
233 let spinner =
234 initialize_progress_bar("Searching for matches...".to_string(), dataset.rows as u64);
235 for (i, row) in dataset
236 .reader
237 .records()
238 .enumerate()
239 .progress_with(spinner.clone())
240 {
241 let record = row.wrap_err(format!("Unable to read record from line {}", i))?;
242
243 let id = match &dataset.clean_id_column {
244 Some(c) => record
245 .get(c.index)
246 .wrap_err(format!(
247 "Unable to read id column {} from line {}",
248 c.name, i
249 ))?
250 .to_string(),
251 None => i.to_string(),
252 };
253
254 let mut found_match = false;
255 for column in &dataset.clean_search_columns {
256 let text = record.get(column.index).wrap_err(format!(
257 "Unable to read column {} from line {}",
258 column.name, i
259 ))?;
260 let cleaned_text = clean_text(text);
261 let grams = cleaned_text.split_ascii_whitespace().collect_vec();
262 for (term_len, term_list) in &search_terms
263 .iter()
264 .group_by(|st| st.term.split_ascii_whitespace().count())
265 {
266 let combos = if term_len == 1 {
267 term_list.cartesian_product(
268 grams
269 .iter()
270 .unique()
271 .map(|word| word.to_string())
272 .collect_vec(),
273 )
274 } else {
275 term_list.cartesian_product(
276 grams
277 .windows(term_len)
278 .unique()
279 .map(|words| words.join(" "))
280 .collect_vec(),
281 )
282 };
283 for (search_term, comparison_term) in combos {
284 if search_term.term.len().abs_diff(comparison_term.len()) > 2 {
286 continue;
287 }
288 let edits = strsim::osa_distance(&search_term.term, &comparison_term);
289 match edits {
290 0 => {
291 dataset
292 .writer
293 .serialize(SearchOutput {
294 row_id: &id,
295 search_term: &search_term.term,
296 matched_term: &comparison_term,
297 edits,
298 similarity_score: 1.0,
299 search_field: &column.name,
300 metadata: &search_term.metadata,
301 })
302 .wrap_err("Enable to serialize output")?;
303 found_match = true;
304 matched_terms.insert(&search_term.term);
305 }
306 1 => {
307 let sim = strsim::jaro_winkler(&search_term.term, &comparison_term);
308 if sim >= 0.95 {
309 dataset
310 .writer
311 .serialize(SearchOutput {
312 row_id: &id,
313 search_term: &search_term.term,
314 matched_term: &comparison_term,
315 edits,
316 similarity_score: sim,
317 search_field: &column.name,
318 metadata: &search_term.metadata,
319 })
320 .wrap_err("Enable to serialize output")?;
321 found_match = true;
322 matched_terms.insert(&search_term.term);
323 }
324 }
325 2 => {
326 let sim = strsim::jaro_winkler(&search_term.term, &comparison_term);
327 if sim >= 0.97 {
328 dataset
329 .writer
330 .serialize(SearchOutput {
331 row_id: &id,
332 search_term: &search_term.term,
333 matched_term: &comparison_term,
334 edits,
335 similarity_score: sim,
336 search_field: &column.name,
337 metadata: &search_term.metadata,
338 })
339 .wrap_err("Enable to serialize output")?;
340 found_match = true;
341 matched_terms.insert(&search_term.term);
342 }
343 }
344 _ => continue,
345 }
346 }
347 }
348 }
349 if found_match {
350 total_records_with_matches += 1;
351 }
352 total_records += 1;
353 }
354 dataset.writer.flush().wrap_err("Unable to flush writer")?;
355 spinner.finish_with_message("Done!");
356
357 println!(
358 "Found matches in {:} of {:} records ({:.2}%)",
359 total_records_with_matches,
360 total_records,
361 (total_records_with_matches as f64 / total_records as f64) * 100.0
362 );
363 println!(
364 "Found {:} of {:} search terms ({:.2}%)",
365 matched_terms.len(),
366 search_terms.len(),
367 (matched_terms.len() as f64 / search_terms.len() as f64) * 100.0
368 );
369
370 Ok(())
371}
372
373pub fn run_searcher<P: AsRef<Path>>(
374 data_file: P,
375 search_terms_file: P,
376 search_columns: Vec<String>,
377 id_column: Option<String>,
378) -> Result<()> {
379 let search_terms = read_terms_from_file(search_terms_file)?;
380 let dataset = initialize_dataset(data_file, &search_columns, id_column)?;
381 search(dataset, search_terms)
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    // clean_text: already-clean text is only uppercased (punctuation stripped).
    #[test]
    fn test_clean_text_no_changes() {
        let s = "This is a test string.";
        assert_eq!(clean_text(s), "this is a test string".to_ascii_uppercase());
    }

    // clean_text keeps ASCII digits untouched.
    #[test]
    fn test_clean_text_numeric() {
        let s = "This is a test string with 1234 numbers.";
        assert_eq!(
            clean_text(s),
            "this is a test string with 1234 numbers".to_ascii_uppercase()
        );
    }

    // All symbols except '-' become spaces and are then trimmed away.
    #[test]
    fn test_clean_text_symbols() {
        let s = "!@#$%^&*()_+-";
        assert_eq!(clean_text(s), "-");
    }

    // Empty input stays empty.
    #[test]
    fn test_clean_empty() {
        let s = "";
        assert_eq!(clean_text(s), "");
    }

    // Leading symbols and trailing whitespace are trimmed off.
    #[test]
    fn test_clean_end_whitespace() {
        let s = "!! This is a test string. ";
        assert_eq!(clean_text(s), "this is a test string".to_ascii_uppercase());
    }

    // Hyphens survive cleaning, including a bare "-" token.
    #[test]
    fn test_clean_end_whitespace2() {
        let s = "!! This is a test to test- - hyphenated string. ";
        assert_eq!(
            clean_text(s),
            "this is a test to test- - hyphenated string".to_ascii_uppercase()
        );
    }

    // Cleaned text tokenizes the way `search` expects (hyphen tokens kept).
    #[test]
    fn test_whitespace_split() {
        let s = "!! This is a test to test- - hyphenated string. ";
        assert_eq!(
            clean_text(s),
            "this is a test to test- - hyphenated string".to_ascii_uppercase()
        );
        let c = clean_text(s);
        let v = c.split_ascii_whitespace().collect_vec();
        assert_eq!(
            v,
            vec![
                "THIS",
                "IS",
                "A",
                "TEST",
                "TO",
                "TEST-",
                "-",
                "HYPHENATED",
                "STRING"
            ]
        );
    }

    // Column lookup returns the positional index of a present column.
    #[test]
    fn test_get_column_info() {
        let header = vec!["a", "b", "c"];
        let col = "a";
        assert_eq!(get_column_info(&header, &col).unwrap().index, 0);
    }

    // Missing columns produce an error, not a panic.
    #[test]
    fn test_get_column_info_errors() {
        let header = vec!["a", "b", "c"];
        let col = "d";
        assert!(get_column_info(&header, &col).is_err());
    }

    // Bulk lookup preserves the order of the requested columns.
    #[test]
    fn test_collect_column_info() {
        let header = vec!["a", "b", "c"];
        let cols = vec!["a", "b"];
        let info = collect_column_info(&header, &cols);
        assert!(info.is_ok());
        let info = info.unwrap();
        assert_eq!(info.len(), 2);
        assert_eq!(
            info,
            vec![
                ColumnInfo {
                    name: "a".to_string(),
                    index: 0
                },
                ColumnInfo {
                    name: "b".to_string(),
                    index: 1
                }
            ]
        );
    }

    // Integration-style check against the sample fixture file.
    // NOTE(review): depends on ../data/search_terms.csv existing relative
    // to the crate root — fails outside the full repo checkout.
    #[test]
    fn test_collect_column_info_sample() -> Result<()> {
        let header = csv::Reader::from_path("../data/search_terms.csv")?
            .headers()?
            .into_iter()
            .map(clean_text)
            .collect_vec();
        let cols = ["term", "metadata"]
            .iter()
            .map(|c| clean_text(c))
            .collect_vec();
        let info = collect_column_info(&header, &cols)?;
        assert_eq!(info.len(), 2);
        Ok(())
    }

    // csv::Reader skips the header row, so enumeration starts at 0 for the
    // first data row (the row-index fallback id relies on this).
    #[test]
    fn test_enumerated_reader() {
        let mut reader = csv::Reader::from_path("../data/search_terms.csv").unwrap();
        let (i, _) = reader.records().enumerate().next().unwrap();
        assert!(i == 0);
    }
}