csv_log_cleaner/
lib.rs

1//! Clean CSV files to conform to a type schema by streaming them
2//! through small memory buffers using multiple threads and
3//! logging data loss.
4//!
5//! # Documentation
6//! [Github](https://github.com/ambidextrous/csv_log_cleaner)
7
8use chrono::NaiveDate;
9use csv::{Reader, StringRecord, Writer};
10use rayon::{ThreadPool, ThreadPoolBuildError};
11use rustc_hash::FxHashMap; // Lots of small HashMaps used, so prioritize fast writes and look ups over collision avoidance
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::error::Error;
15use std::io;
16use std::io::ErrorKind;
17use std::iter::Iterator;
18use std::marker::Send;
19use std::sync::mpsc;
20use std::sync::mpsc::{Receiver, Sender};
21use std::sync::{Arc, Mutex};
22use std::vec::Vec;
23
/// Shared lookup tables built once per run by `generate_constants`:
/// the string values treated as nulls and as booleans during cleaning.
#[derive(Debug, Clone)]
struct Constants {
    // Values normalised to the empty string by `Clean::clean`.
    null_vals: Vec<String>,
    // Lower-cased values accepted by `CastsToBool::casts_to_bool`.
    bool_vals: Vec<String>,
}
29
/// The set of column types a schema can declare; deserialized from the
/// `column_type` field of the JSON schema.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Copy)]
enum ColumnType {
    String,
    Int,
    Date,
    Float,
    Enum,
    Bool,
}
39
/// Validated cleaning rules for one CSV column (built from a `JsonColumn`
/// by `generate_validated_schema`, with `Option` fields defaulted).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Column {
    column_type: ColumnType,
    // Written to the output whenever a value fails validation.
    illegal_val_replacement: String,
    // Only meaningful for `ColumnType::Enum` columns.
    legal_vals: Vec<String>,
    // strftime-style date format; only meaningful for `ColumnType::Date`.
    format: String,
}
47
// NOTE(review): this type is not referenced anywhere in the visible code
// (only `JsonSchema` is deserialized); possibly dead — confirm before removal.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct Schema {
    columns: Vec<Column>,
}
52
/// One column entry as it appears in the user-supplied JSON schema;
/// optional fields are validated and defaulted by `generate_validated_schema`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct JsonColumn {
    name: String,
    column_type: ColumnType,
    illegal_val_replacement: Option<String>,
    legal_vals: Option<Vec<String>>,
    format: Option<String>,
}
61
/// Top-level shape of the user-supplied JSON schema document.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct JsonSchema {
    columns: Vec<JsonColumn>,
}
66
/// Bundle of inputs handed to each buffer-processing job.
#[derive(Debug, Clone)]
struct ProcessRowBufferConfig<'a, W>
where
    W: io::Write + Send + Sync,
{
    // Header record, in output column order.
    column_names: &'a StringRecord,
    // Column name -> validated cleaning rules.
    schema_map: &'a FxHashMap<String, Column>,
    // Rows (column name -> raw value) to clean and write.
    row_buffer: &'a [FxHashMap<String, String>],
    constants: &'a Constants,
    // Writer shared by all jobs; output is serialised through this mutex.
    locked_wtr: Arc<Mutex<Writer<W>>>,
    // Owned copies of the header names, used as map keys.
    column_string_names: &'a [String],
    // Channel on which the job reports its per-buffer column log.
    tx: Sender<FxHashMap<String, ColumnLog>>,
}
80
/// Sender used by workers to report per-buffer errors (stringified,
/// because `Box<dyn Error>` is not `Send`).
type StringSender = Sender<Result<(), String>>;

// NOTE(review): "Reciever" is a typo for "Receiver"; left unchanged here
// because renaming the alias would also require touching its use sites.
type StringReciever = Receiver<Result<(), String>>;

/// Sender used by workers to hand back per-buffer column logs.
type MapSender = Sender<FxHashMap<String, ColumnLog>>;

/// Receiver used by the coordinator to collect per-buffer column logs.
type MapReceiver = Receiver<FxHashMap<String, ColumnLog>>;
88
/// Holds data on the invalid count and max and min invalid string values
/// (calculated by String comparison) found in that column.
///
/// # Examples
///
/// ```
/// use csv_log_cleaner::ColumnLog;
///
/// let date_of_birth_column_log = ColumnLog {
///     name: "DATE_OF_BIRTH".to_string(),
///     invalid_count: 2,
///     max_invalid: Some("2444-89-01".to_string()),
///     min_invalid: Some("2004-31-01".to_string()),
/// };
/// ```
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct ColumnLog {
    // The column's name, as it appears in the CSV header.
    pub name: String,
    // Number of values replaced in this column.
    pub invalid_count: i32,
    // Lexicographically largest invalid value seen, if any.
    pub max_invalid: Option<String>,
    // Lexicographically smallest invalid value seen, if any.
    pub min_invalid: Option<String>,
}
111
/// The result of a cleaning run: the number of data rows processed and
/// a per-column data-loss log keyed by column name.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct CleansingLog {
    pub total_rows: i32,
    pub log_map: HashMap<String, ColumnLog>,
}
117
/// A single CSV row, keyed by column name.
type Record = FxHashMap<String, String>;
119
/// Error representing a problem that occurred during the CSV cleaning process.
///
/// ```
/// use csv_log_cleaner::CSVCleansingError;
///
/// assert_eq!(CSVCleansingError::new("Test error message".to_string()).to_string(), "Test error message".to_string());
/// ```
#[derive(Debug)]
pub struct CSVCleansingError {
    message: String,
}

impl CSVCleansingError {
    /// Wrap a human-readable message in a `CSVCleansingError`.
    ///
    /// ```
    /// use csv_log_cleaner::CSVCleansingError;
    ///
    /// assert_eq!(CSVCleansingError::new("Test error message".to_string()).to_string(), "Test error message".to_string());
    /// ```
    pub fn new(message: String) -> CSVCleansingError {
        CSVCleansingError { message }
    }
}

impl Error for CSVCleansingError {}

impl std::fmt::Display for CSVCleansingError {
    // Display is just the stored message, so `to_string()` round-trips it.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}
152
153/// Clean CSV files to conform to a type schema by streaming them through small memory buffers using multiple threads and logging data loss.
154///
155/// # Examples
156///
157/// ```
158/// use std::error::Error;
159/// use csv::{Reader,Writer};
160/// use csv_log_cleaner::{clean_csv, ColumnLog, get_schema_from_json_str};
161/// use tempfile::tempdir;
162/// use std::fs;
163///
164/// // Arrange
165/// let dir = tempdir().expect("To be able to create temporary directory");
166/// let input_csv_data = r#"NAME,AGE,DATE_OF_BIRTH
167/// Raul,27,2004-01-31
168/// Duke,27.8,2004-31-01
169/// "#;
170/// let input_path = dir.path().join("input.csv");
171/// let output_path = dir.path().join("output.csv");
172/// fs::write(input_path.clone(), input_csv_data).expect("To be able to write file");
173/// let mut csv_rdr = Reader::from_path(input_path).expect("To be able to create reader");   
174/// let mut csv_wtr = Writer::from_path(output_path.clone()).expect("To be able to create writer");
175/// let schema_path = dir.path().join("schema.json");
176/// let schema_path_str = schema_path.to_str().unwrap();
177/// let schema_path_string = String::from(schema_path_str);
178/// let schema_string = r#"{
179/// "columns": [
180///     {
181///         "name": "NAME",
182///         "column_type": "String"
183///     },
184///     {
185///         "name": "AGE",
186///         "column_type": "Int"
187///     },
188///     {
189///         "name": "DATE_OF_BIRTH",
190///         "column_type": "Date",
191///         "format": "%Y-%m-%d"
192///     }
193/// ]
194/// }"#;
195/// let schema_map = get_schema_from_json_str(&schema_string).unwrap();
196/// let buffer_size = 1;
197/// let expected_date_of_birth_column_log = ColumnLog {
198///     name: "DATE_OF_BIRTH".to_string(),
199///     invalid_count: 1,
200///     max_invalid: Some("2004-31-01".to_string()),
201///     min_invalid: Some("2004-31-01".to_string()),
202/// };
203///
204///
205/// // Act
206/// let result = clean_csv(&mut csv_rdr, csv_wtr, schema_map, buffer_size);
207/// let output_csv = fs::read_to_string(output_path).expect("To be able to read from file");
208/// let log_map = result.expect("Result to have content").log_map;
209///
210/// // Assert
211/// assert_eq!(log_map.get("DATE_OF_BIRTH").unwrap(), &expected_date_of_birth_column_log);
212/// ```
213pub fn clean_csv<R: io::Read, W: io::Write + std::marker::Send + std::marker::Sync + 'static>(
214    csv_rdr: &mut Reader<R>,
215    csv_wtr: Writer<W>,
216    schema_map: HashMap<String, Column>,
217    buffer_size: usize,
218) -> Result<CleansingLog, CSVCleansingError> {
219    let result = process_rows_private(csv_rdr, csv_wtr, schema_map, buffer_size);
220    match result {
221        Ok(cleansing_log) => Ok(cleansing_log),
222        Err(err) => Err(CSVCleansingError::new(err.to_string())),
223    }
224}
225
/// Internal implementation backing `clean_csv`.
///
/// Streams rows from `csv_rdr` into a buffer of `buffer_size` rows; each
/// full buffer is cleaned and written by a rayon thread-pool job while
/// reading continues on this thread. Returns the total row count plus the
/// combined per-column data-loss log, or the first error encountered.
fn process_rows_private<
    R: io::Read,
    W: io::Write + std::marker::Send + std::marker::Sync + 'static,
>(
    csv_rdr: &mut Reader<R>,
    mut csv_wtr: Writer<W>,
    schema_map: HashMap<String, Column>,
    buffer_size: usize,
) -> Result<CleansingLog, Box<dyn Error>> {
    // Set up multi-threaded processing
    let schema_map = convert_from_std_hashmap(&schema_map);
    // Workers report per-buffer logs on (tx, rx) and stringified errors
    // on (error_tx, error_rx).
    let (tx, rx): (MapSender, MapReceiver) = mpsc::channel();
    let (error_tx, error_rx): (StringSender, StringReciever) = mpsc::channel();
    let mut row_count = 0;
    let constants = generate_constants();
    let column_names = csv_rdr.headers()?.clone();
    // Fail fast if the CSV header and the schema columns disagree.
    check_spec_valid_for_input(&column_names, &schema_map)?;
    csv_wtr.write_record(&column_names)?;
    // The writer is shared across worker jobs via a mutex.
    let locked_wtr = Arc::new(Mutex::new(csv_wtr));
    let column_string_names: Vec<String> = column_names.iter().map(|x| x.to_string()).collect();
    let mut row_buffer = Vec::new();
    let pool = create_thread_pool()?;
    // Counts dispatched jobs, so the log collector knows how many
    // per-buffer logs to wait for.
    let mut job_counter = 0;
    // Read input to buffer on one thread; process and write output on other threads as buffer fills
    for row in csv_rdr.deserialize() {
        row_count += 1;
        let row_map: Record = row?;
        row_buffer.push(row_map);
        if row_buffer.len() == buffer_size {
            job_counter += 1;
            // Clone the job's inputs so the spawned closure can own them
            // ('static lifetime is required by `ThreadPool::spawn`).
            let cloned_row_buffer = row_buffer.clone();
            let cloned_schema_map = schema_map.clone();
            let cloned_column_names = column_names.clone();
            let cloned_constants = constants.clone();
            let cloned_locked_wtr = Arc::clone(&locked_wtr);
            let cloned_column_string_names = column_string_names.clone();
            let thread_tx = tx.clone();
            let thread_error_tx = error_tx.clone();
            pool.spawn(move || {
                let row_buffer_data = ProcessRowBufferConfig {
                    column_names: &cloned_column_names,
                    schema_map: &cloned_schema_map,
                    row_buffer: &cloned_row_buffer,
                    constants: &cloned_constants,
                    locked_wtr: cloned_locked_wtr,
                    column_string_names: &cloned_column_string_names,
                    tx: thread_tx,
                };
                process_row_buffer_errors(row_buffer_data, thread_error_tx)
                    .expect("Fatal error calling ThreadPool::spawn");
            });
            row_buffer.clear();
        }
        // Surface any error a worker has already reported, without blocking.
        for potential_error in error_rx.try_iter() {
            potential_error?;
        }
    }
    // Rebind the original sender; it is consumed by the final partial
    // buffer below, if there is one.
    let thread_tx = tx;

    // Process any remaining rows in buffer
    if !row_buffer.is_empty() {
        job_counter += 1;
        let row_buffer_data = ProcessRowBufferConfig {
            column_names: &column_names,
            schema_map: &schema_map,
            row_buffer: &row_buffer,
            constants: &constants,
            locked_wtr,
            column_string_names: &column_string_names,
            tx: thread_tx,
        };
        // The final partial buffer is processed synchronously on this thread.
        process_row_buffer(row_buffer_data)?;
    }

    // Combined logs and raise any error messages sent by threads
    let combined_log_map =
        generate_combined_log_map(&column_names, column_string_names, rx, job_counter)?;
    for potential_error in error_rx.try_iter() {
        potential_error?;
    }

    Ok(CleansingLog {
        total_rows: row_count,
        log_map: convert_to_std_hashmap(combined_log_map),
    })
}
312
/// Clean every row in `config.row_buffer`, write the cleaned rows, and
/// send this buffer's column log back through `config.tx`.
///
/// Rows are cleaned *before* the writer lock is taken, so the mutex is
/// held only for the actual write calls.
fn process_row_buffer<W>(config: ProcessRowBufferConfig<W>) -> Result<(), Box<dyn Error>>
where
    W: io::Write + Send + Sync,
{
    // Fresh zeroed log for just this buffer; merged later by the coordinator.
    let mut buffer_log_map =
        generate_column_log_map(config.column_names, config.column_string_names);
    let mut cleaned_rows = Vec::new();
    for row_map in config.row_buffer.iter() {
        let cleaned_row = process_row(
            config.column_names,
            config.schema_map,
            row_map.clone(),
            &mut buffer_log_map,
            config.constants,
        )?;
        cleaned_rows.push(cleaned_row);
    }
    // NOTE(review): "aquire" below is a typo for "acquire"; left unchanged
    // because it is a runtime panic string.
    let mut wtr = config
        .locked_wtr
        .lock()
        .expect("Fatal error attempting to aquire Writer in function process_row_buffer");
    for cleaned_row in cleaned_rows.iter() {
        wtr.write_record(cleaned_row)?;
    }
    config.tx.send(buffer_log_map)?;

    Ok(())
}
341
342fn process_row<'a>(
343    ordered_column_names: &'a StringRecord,
344    schema_dict: &'a FxHashMap<String, Column>,
345    row_map: FxHashMap<String, String>,
346    log_map: &'a mut FxHashMap<String, ColumnLog>,
347    constants: &Constants,
348) -> Result<StringRecord, Box<dyn Error>> {
349    let mut processed_row = Vec::new();
350    for column_name in ordered_column_names {
351        let column_value = row_map.get(column_name).ok_or_else(|| {
352            format!("Key error, could not find column_name `{column_name}` in row map")
353        })?;
354        let cleaned_value = column_value.clean(constants);
355        let column = schema_dict.get(column_name).ok_or_else(|| {
356            format!("Key error, could not find column_name `{column_name}` in schema`")
357        })?;
358        let processed_value = cleaned_value.process(column, constants);
359        if processed_value != cleaned_value {
360            let column_log = log_map.get(column_name).ok_or_else(|| {
361                format!("Key error, could not find column_name `{column_name}` in log_map`")
362            })?;
363            let invalid_count = column_log.invalid_count + 1;
364            let mut max_invalid = column_log.max_invalid.clone();
365            let mut min_invalid = column_log.min_invalid.clone();
366            match &column_log.max_invalid {
367                Some(x) => {
368                    if &processed_value > x {
369                        max_invalid = Some(cleaned_value.clone());
370                    }
371                }
372                None => {
373                    max_invalid = Some(cleaned_value.clone());
374                }
375            }
376            match &column_log.min_invalid {
377                Some(x) => {
378                    if &processed_value < x {
379                        min_invalid = Some(cleaned_value.clone());
380                    }
381                }
382                None => {
383                    min_invalid = Some(cleaned_value.clone());
384                }
385            }
386            let column_log_mut = log_map.get_mut(&column_name.to_string()).ok_or_else(|| {
387                format!("Key error, could not find column_name `{column_name}` in log_map`")
388            })?;
389            column_log_mut.invalid_count = invalid_count;
390            column_log_mut.min_invalid = min_invalid;
391            column_log_mut.max_invalid = max_invalid;
392        }
393        processed_row.push(processed_value);
394    }
395    let processed_record = StringRecord::from(processed_row);
396
397    Ok(processed_record)
398}
399
400fn process_row_buffer_errors<W>(
401    config: ProcessRowBufferConfig<W>,
402    error_tx: Sender<Result<(), String>>,
403) -> Result<(), Box<dyn Error>>
404where
405    W: io::Write + Send + Sync,
406{
407    let buffer_processing_result = process_row_buffer(config);
408    if let Err(err) = buffer_processing_result {
409        // Can't send Box<dyn Error>> between threads, so convert e
410        // to String before sending through channel
411        error_tx.send(Err(err.to_string()))?;
412    }
413
414    Ok(())
415}
416
417fn convert_to_std_hashmap(fast_map: FxHashMap<String, ColumnLog>) -> HashMap<String, ColumnLog> {
418    let mut regular_map = HashMap::new();
419    for (key, value) in fast_map {
420        regular_map.insert(key, value);
421    }
422    regular_map
423}
424
425fn convert_from_std_hashmap(input_hashmap: &HashMap<String, Column>) -> FxHashMap<String, Column> {
426    input_hashmap
427        .iter()
428        .map(|(k, v)| (k.clone(), v.clone()))
429        .collect()
430}
431
432fn create_thread_pool() -> Result<ThreadPool, ThreadPoolBuildError> {
433    let core_count = num_cpus::get();
434    let num_threads = if core_count == 1 { 1 } else { core_count - 1 };
435    rayon::ThreadPoolBuilder::new()
436        .num_threads(num_threads)
437        .build()
438}
439
440fn check_spec_valid_for_input(
441    column_names: &StringRecord,
442    schema_map: &FxHashMap<String, Column>,
443) -> Result<(), Box<dyn Error>> {
444    let spec_and_csv_columns_match = are_equal_spec_and_csv_columns(column_names, schema_map);
445    if !spec_and_csv_columns_match {
446        return Err(Box::new(CSVCleansingError::new(
447            "Error: CSV columns and schema columns do not match".to_string(),
448        )));
449    }
450
451    Ok(())
452}
453
454fn are_equal_spec_and_csv_columns(
455    csv_columns_record: &StringRecord,
456    spec: &FxHashMap<String, Column>,
457) -> bool {
458    csv_columns_record.len() == spec.len()
459        && csv_columns_record
460            .iter()
461            .all(|field| spec.contains_key(field))
462}
463
464fn generate_column_log_map(
465    column_names: &StringRecord,
466    column_string_names: &[String],
467) -> FxHashMap<String, ColumnLog> {
468    let column_logs: Vec<ColumnLog> = column_names
469        .clone()
470        .into_iter()
471        .map(|x| ColumnLog {
472            name: x.to_string(),
473            invalid_count: 0,
474            max_invalid: None,
475            min_invalid: None,
476        })
477        .collect();
478    let mut_log_map: FxHashMap<String, ColumnLog> = column_string_names
479        .iter()
480        .cloned()
481        .zip(column_logs.iter().cloned())
482        .collect();
483    mut_log_map
484}
485
486fn generate_combined_log_map(
487    column_names: &StringRecord,
488    column_string_names: Vec<String>,
489    rx: MapReceiver,
490    mut job_counter: i32,
491) -> Result<FxHashMap<String, ColumnLog>, Box<dyn Error>> {
492    let mut combined_log_map = generate_column_log_map(column_names, &column_string_names);
493    for log_map in rx.iter() {
494        job_counter -= 1;
495        for (column_name, column_log) in log_map {
496            let obtained_log = combined_log_map.get(&column_name.clone()).ok_or_else(|| {
497                format!("Key error, could not find column_name `{column_name}` in log map")
498            })?;
499            let updated_log = obtained_log.update(&column_log);
500            combined_log_map.insert(column_name.clone(), updated_log);
501        }
502        if job_counter < 1 {
503            break;
504        }
505    }
506
507    Ok(combined_log_map)
508}
509
510impl ColumnLog {
511    fn update(&self, other: &ColumnLog) -> ColumnLog {
512        assert!(self.name == other.name);
513        let new_invalid_count = self.invalid_count + other.invalid_count;
514        let new_max = match (self.max_invalid.clone(), other.max_invalid.clone()) {
515            (Some(x), Some(y)) => {
516                if x > y {
517                    Some(x)
518                } else {
519                    Some(y)
520                }
521            }
522            (Some(x), None) => Some(x),
523            (None, Some(y)) => Some(y),
524            _ => None,
525        };
526        let new_min = match (self.min_invalid.clone(), other.min_invalid.clone()) {
527            (Some(x), Some(y)) => {
528                if x < y {
529                    Some(x)
530                } else {
531                    Some(y)
532                }
533            }
534            (Some(x), None) => Some(x),
535            (None, Some(y)) => Some(y),
536            _ => None,
537        };
538
539        ColumnLog {
540            name: self.name.clone(),
541            invalid_count: new_invalid_count,
542            max_invalid: new_max,
543            min_invalid: new_min,
544        }
545    }
546}
547
548/// Extracts a Result  containing a JsonSchema CSV cleansing specification from a JSON string specification.
549///
550/// Example
551/// ```
552/// use csv_log_cleaner::get_schema_from_json_str;
553///
554/// let schema_string = r#"{
555/// "columns": [
556///     {
557///         "name": "NAME",
558///         "column_type": "String"
559///     },
560///     {
561///         "name": "AGE",
562///         "column_type": "Int"
563///     },
564///     {
565///         "name": "DATE_OF_BIRTH",
566///         "column_type": "Date",
567///         "format": "%Y-%m-%d"
568///     }
569/// ]
570/// }"#;
571/// let schema_map = get_schema_from_json_str(&schema_string).unwrap();
572/// println!("{:?}", schema_map);
573/// ```
574pub fn get_schema_from_json_str(
575    schema_json_string: &str,
576) -> Result<HashMap<String, Column>, io::Error> {
577    let json_schema: JsonSchema = serde_json::from_str(schema_json_string)?;
578    generate_validated_schema(json_schema)
579}
580
581impl CleansingLog {
582    pub fn json(&self) -> String {
583        let mut combined_string = format!(
584            "{{\n\t\"total_rows\": {},\n\t\"columns_with_errors\": [\n\t\t",
585            self.total_rows
586        );
587        let mut is_first_row = true;
588        for (column_name, column_log) in self.log_map.iter() {
589            let mut max_val = String::new();
590            {
591                if let Some(x) = &column_log.max_invalid {
592                    max_val = x.clone();
593                }
594            }
595            let mut min_val = String::new();
596            {
597                if let Some(x) = &column_log.min_invalid {
598                    min_val = x.clone();
599                }
600            }
601            let invalid_row_count = column_log.invalid_count;
602            let col_string = format!("{{\n\t\t\t\"column_name\": \"{column_name}\",\n\t\t\t\"invalid_row_count\": {invalid_row_count},\n\t\t\t\"max_illegal_val\": \"{max_val}\",\n\t\t\t\"min_illegal_val\": \"{min_val}\"\n\t\t}}");
603            if is_first_row {
604                combined_string = format!("{combined_string}{col_string}");
605            } else {
606                combined_string = format!("{combined_string},{col_string}");
607            }
608            is_first_row = false;
609        }
610        combined_string = format!("{combined_string}\n\t]\n}}");
611        combined_string
612    }
613}
614
615fn generate_constants() -> Constants {
616    let null_vals = vec![
617        "#N/A".to_string(),
618        "#N/A".to_string(),
619        "N/A".to_string(),
620        "#NA".to_string(),
621        "-1.#IND".to_string(),
622        "-1.#QNAN".to_string(),
623        "-NaN".to_string(),
624        "-nan".to_string(),
625        "1.#IND".to_string(),
626        "1.#QNAN".to_string(),
627        "<NA>".to_string(),
628        "N/A".to_string(),
629        "NA".to_string(),
630        "NULL".to_string(),
631        "NaN".to_string(),
632        "n/a".to_string(),
633        "nan".to_string(),
634        "null".to_string(),
635    ];
636    let bool_vals = vec![
637        "true".to_string(),
638        "1".to_string(),
639        "1.0".to_string(),
640        "yes".to_string(),
641        "false".to_string(),
642        "0.0".to_string(),
643        "0".to_string(),
644        "no".to_string(),
645    ];
646    Constants {
647        null_vals,
648        bool_vals,
649    }
650}
651
652fn generate_validated_schema(
653    json_schema: JsonSchema,
654) -> Result<HashMap<String, Column>, io::Error> {
655    let empty_vec: Vec<String> = Vec::new();
656    let empty_string = String::new();
657    let mut column_map: HashMap<String, Column> = HashMap::default();
658    for column in json_schema.columns {
659        let new_col = Column {
660            column_type: column.column_type,
661            illegal_val_replacement: column
662                .illegal_val_replacement
663                .unwrap_or_else(|| empty_string.clone()),
664            legal_vals: column.legal_vals.unwrap_or_else(|| empty_vec.clone()),
665            format: column.format.unwrap_or_else(|| empty_string.clone()),
666        };
667
668        match column.column_type {
669            ColumnType::Date => {
670                if new_col.format.is_empty() {
671                    let custom_error = io::Error::new(
672                        ErrorKind::Other,
673                        "Missing required `format` string value for Date column",
674                    );
675                    return Err(custom_error);
676                }
677            }
678            ColumnType::Enum => {
679                if new_col.legal_vals.is_empty() {
680                    let custom_error = io::Error::new(
681                        ErrorKind::Other,
682                        "Missing required `legal_vals` string list value for Enum column",
683                    );
684                    return Err(custom_error);
685                }
686            }
687            _ => {}
688        }
689        column_map.insert(column.name, new_col);
690    }
691    Ok(column_map)
692}
693
/// Type-validation step: produce the value to write for a column,
/// substituting the column's replacement when the value is invalid.
trait Process {
    fn process(&self, column: &Column, constants: &Constants) -> Self;
}
697
698impl Process for String {
699    fn process(&self, column: &Column, constants: &Constants) -> Self {
700        match column.column_type {
701            ColumnType::String => self.to_string(),
702            ColumnType::Int => {
703                let cleaned = self.de_pseudofloat();
704                if cleaned.casts_to_int() {
705                    cleaned
706                } else {
707                    column.illegal_val_replacement.to_owned()
708                }
709            }
710            ColumnType::Date => {
711                let cleaned = self;
712                if cleaned.casts_to_date(&column.format) {
713                    cleaned.to_string()
714                } else {
715                    column.illegal_val_replacement.to_owned()
716                }
717            }
718            ColumnType::Float => {
719                let cleaned = self;
720                if cleaned.casts_to_float() {
721                    cleaned.to_string()
722                } else {
723                    column.illegal_val_replacement.to_owned()
724                }
725            }
726            ColumnType::Enum => {
727                let cleaned = self;
728                if cleaned.casts_to_enum(&column.legal_vals) {
729                    cleaned.to_string()
730                } else {
731                    column.illegal_val_replacement.to_owned()
732                }
733            }
734            ColumnType::Bool => {
735                let cleaned = self;
736                if cleaned.casts_to_bool(constants) {
737                    cleaned.to_string()
738                } else {
739                    column.illegal_val_replacement.to_owned()
740                }
741            }
742        }
743    }
744}
745
746trait Clean {
747    fn clean(&self, constants: &Constants) -> Self;
748}
749
750impl Clean for String {
751    fn clean(&self, constants: &Constants) -> Self {
752        if constants.null_vals.contains(self) {
753            String::new()
754        } else {
755            self.to_string()
756        }
757    }
758}
759
760trait CastsToBool {
761    fn casts_to_bool(&self, constants: &Constants) -> bool;
762}
763
764impl CastsToBool for String {
765    fn casts_to_bool(&self, constants: &Constants) -> bool {
766        constants.bool_vals.contains(&self.to_lowercase())
767    }
768}
769
trait CastsToEnum {
    fn casts_to_enum(&self, legal_values: &[String]) -> bool;
}

impl CastsToEnum for String {
    /// True when `self` exactly (case-sensitively) matches one of the
    /// column's legal values.
    fn casts_to_enum(&self, legal_values: &[String]) -> bool {
        legal_values.iter().any(|legal| legal == self)
    }
}
779
trait CastsToDate {
    fn casts_to_date(&self, format: &str) -> bool;
}

impl CastsToDate for String {
    // `format` parameter should be a value of the form defined here: https://docs.rs/chrono/latest/chrono/format/strftime/index.html
    /// True when `self` parses as a valid calendar date under the given
    /// strftime-style `format` (via `chrono::NaiveDate`, so impossible
    /// dates such as month 31 are rejected).
    fn casts_to_date(&self, format: &str) -> bool {
        NaiveDate::parse_from_str(self, format).is_ok()
    }
}
790
trait CastsToInt {
    fn casts_to_int(&self) -> bool;
}

impl CastsToInt for String {
    /// True when the value parses as an `i64`.
    ///
    /// Note: returns false if the number is too large or small to be
    /// stored as a 64 bit signed int:
    ///     min val: -9_223_372_036_854_775_808
    ///     max val: 9_223_372_036_854_775_807
    fn casts_to_int(&self) -> bool {
        self.parse::<i64>().is_ok()
    }
}
804
trait CastsToFloat {
    fn casts_to_float(&self) -> bool;
}

impl CastsToFloat for String {
    /// True when the value parses as an `f64`.
    ///
    /// Note: returns false if the number is too large or small to be
    /// represented as a 64 bit float:
    ///     min val: -1.7976931348623157E+308f64
    ///     max val: 1.7976931348623157E+308f64
    fn casts_to_float(&self) -> bool {
        self.parse::<f64>().is_ok()
    }
}
818
trait DePseudofloat {
    fn de_pseudofloat(&self) -> Self;
}

impl DePseudofloat for String {
    /// Strip a literal ".0" suffix ("1.0" -> "1"); any other value is
    /// returned unchanged.
    fn de_pseudofloat(&self) -> Self {
        match self.strip_suffix(".0") {
            Some(stripped) => stripped.to_string(),
            None => self.to_owned(),
        }
    }
}
833
/// Return `value` with its last `n` characters removed; returns the empty
/// string when `n` is at least the character count, and the whole string
/// when `n <= 0`.
fn rem_last_n_chars(value: &str, n: i32) -> &str {
    let mut remaining = value.chars();
    for _ in 0..n {
        if remaining.next_back().is_none() {
            break;
        }
    }
    remaining.as_str()
}
841
842#[cfg(test)]
843mod tests {
844    use super::*;
845
846    #[test]
847    fn clean_string() {
848        // Arrange
849        let input = vec!["NULL".to_string(), String::new(), " dog\t".to_string()];
850        let expected = vec![String::new(), String::new(), " dog\t".to_string()];
851        let constants = generate_constants();
852        // Act
853        let result = input
854            .iter()
855            .map(|x| x.clean(&constants))
856            .collect::<Vec<_>>();
857        // Assert
858        assert_eq!(result, expected);
859    }
860
861    #[test]
862    fn converts_to_int() {
863        // Arrange
864        let input = vec![
865            "1".to_string(),
866            "-3".to_string(),
867            "264634633426".to_string(),
868            "dog".to_string(),
869            "0.4".to_string(),
870            "1.0".to_string(),
871        ];
872        let expected = vec![true, true, true, false, false, false];
873        // Act
874        let result = input.iter().map(|x| x.casts_to_int()).collect::<Vec<_>>();
875        // Assert
876        assert_eq!(result, expected);
877    }
878
879    #[test]
880    fn de_psuedofloats() {
881        // Arrange
882        let input = vec![
883            String::new(),
884            "-3.0".to_string(),
885            "264634633426.0".to_string(),
886            "dog".to_string(),
887            "0.4".to_string(),
888            "1.0".to_string(),
889        ];
890        let expected = vec![
891            String::new(),
892            "-3".to_string(),
893            "264634633426".to_string(),
894            "dog".to_string(),
895            "0.4".to_string(),
896            "1".to_string(),
897        ];
898        // Act
899        let result = input.iter().map(|x| x.de_pseudofloat()).collect::<Vec<_>>();
900        // Assert
901        assert_eq!(result, expected);
902    }
903
904    #[test]
905    fn process_string() {
906        // Arrange
907        let input = vec![String::new(), " foo\t".to_string(), "bar".to_string()];
908        let expected = vec![String::new(), " foo\t".to_string(), "bar".to_string()];
909        let legal_vals: Vec<String> = Vec::new();
910        let column = Column {
911            column_type: ColumnType::String,
912            illegal_val_replacement: String::new(),
913            legal_vals: legal_vals,
914            format: String::new(),
915        };
916        let constants = generate_constants();
917        // Act
918        let result = input
919            .iter()
920            .map(|x| x.process(&column, &constants))
921            .collect::<Vec<_>>();
922        // Assert
923        assert_eq!(result, expected);
924    }
925
926    #[test]
927    fn process_int() {
928        // Arrange
929        let input = vec!["1".to_string(), "-2.0".to_string(), "3134.4".to_string()];
930        let expected = vec!["1".to_string(), "-2".to_string(), String::new()];
931        let legal_vals: Vec<String> = Vec::new();
932        let column = Column {
933            column_type: ColumnType::Int,
934            illegal_val_replacement: String::new(),
935            legal_vals: legal_vals,
936            format: String::new(),
937        };
938        let constants = generate_constants();
939        // Act
940        let result = input
941            .iter()
942            .map(|x| x.process(&column, &constants))
943            .collect::<Vec<_>>();
944        // Assert
945        assert_eq!(result, expected);
946    }
947
948    #[test]
949    fn converts_to_date() {
950        // Arrange
951        let input = vec![
952            "2022-01-31".to_string(),
953            "1878-02-03".to_string(),
954            "2115-04-42".to_string(),
955            "dog".to_string(),
956            "31-01-2022".to_string(),
957        ];
958        let expected = vec![true, true, false, false, false];
959        // Act
960        let result = input
961            .iter()
962            .map(|x| x.casts_to_date(&"%Y-%m-%d".to_string()))
963            .collect::<Vec<_>>();
964        // Assert
965        assert_eq!(result, expected);
966    }
967
968    #[test]
969    fn process_date() {
970        // Arrange
971        let input = vec![
972            "2020-01-01".to_string(),
973            " 2200-12-31\t".to_string(),
974            String::new(),
975        ];
976        let expected = vec!["2020-01-01".to_string(), String::new(), String::new()];
977        let legal_vals: Vec<String> = Vec::new();
978        let column = Column {
979            column_type: ColumnType::Date,
980            illegal_val_replacement: String::new(),
981            legal_vals: legal_vals,
982            format: "%Y-%m-%d".to_string(),
983        };
984        let constants = generate_constants();
985        // Act
986        let result = input
987            .iter()
988            .map(|x| x.process(&column, &constants))
989            .collect::<Vec<_>>();
990        // Assert
991        assert_eq!(result, expected);
992    }
993
994    #[test]
995    fn converts_to_float() {
996        // Arrange
997        let input = vec![
998            "1.0".to_string(),
999            "-3".to_string(),
1000            "264634633426".to_string(),
1001            "dog".to_string(),
1002            "0.4".to_string(),
1003            String::new(),
1004        ];
1005        let expected = vec![true, true, true, false, true, false];
1006        // Act
1007        let result = input.iter().map(|x| x.casts_to_float()).collect::<Vec<_>>();
1008        // Assert
1009        assert_eq!(result, expected);
1010    }
1011
1012    #[test]
1013    fn process_float() {
1014        // Arrange
1015        let input = vec![String::new(), " 0.1\t".to_string(), "123.456".to_string()];
1016        let expected = vec![String::new(), String::new(), "123.456".to_string()];
1017        let legal_vals: Vec<String> = Vec::new();
1018        let column = Column {
1019            column_type: ColumnType::Float,
1020            illegal_val_replacement: String::new(),
1021            legal_vals: legal_vals,
1022            format: String::new(),
1023        };
1024        let constants = generate_constants();
1025        // Act
1026        let result = input
1027            .iter()
1028            .map(|x| x.process(&column, &constants))
1029            .collect::<Vec<_>>();
1030        // Assert
1031        assert_eq!(result, expected);
1032    }
1033
1034    #[test]
1035    fn converts_to_enum() {
1036        // Arrange
1037        let input = vec![
1038            "A".to_string(),
1039            "B".to_string(),
1040            "C".to_string(),
1041            "7".to_string(),
1042            "0.4".to_string(),
1043            String::new(),
1044        ];
1045        let legal = vec!["A".to_string(), "B".to_string()];
1046        let expected = vec![true, true, false, false, false, false];
1047        // Act
1048        let result = input
1049            .iter()
1050            .map(|x| x.casts_to_enum(&legal))
1051            .collect::<Vec<_>>();
1052        // Assert
1053        assert_eq!(result, expected);
1054    }
1055
1056    #[test]
1057    fn converts_to_tool() {
1058        // Arrange
1059        let input = vec![
1060            "true".to_string(),
1061            "false".to_string(),
1062            "True".to_string(),
1063            "False".to_string(),
1064            "0".to_string(),
1065            "1".to_string(),
1066            "dog".to_string(),
1067        ];
1068        let expected = vec![true, true, true, true, true, true, false];
1069        let constants = generate_constants();
1070        // Act
1071        let result = input
1072            .iter()
1073            .map(|x| x.casts_to_bool(&constants))
1074            .collect::<Vec<_>>();
1075        // Assert
1076        assert_eq!(result, expected);
1077    }
1078
1079    #[test]
1080    fn process_enum() {
1081        // Arrange
1082        let input = vec![String::new(), " A\t".to_string(), "B".to_string()];
1083        let expected = vec![String::new(), String::new(), "B".to_string()];
1084        let legal_vals = vec!["A".to_string(), "B".to_string()];
1085        let column = Column {
1086            column_type: ColumnType::Enum,
1087            illegal_val_replacement: String::new(),
1088            legal_vals: legal_vals,
1089            format: String::new(),
1090        };
1091        let constants = generate_constants();
1092        // Act
1093        let result = input
1094            .iter()
1095            .map(|x| x.process(&column, &constants))
1096            .collect::<Vec<_>>();
1097        // Assert
1098        assert_eq!(result, expected);
1099    }
1100
    #[test]
    // Smoke test: a JSON schema exercising every `ColumnType` variant —
    // with the optional fields (`illegal_val_replacement`, `legal_vals`,
    // `format`) variously present, explicitly null, or omitted — must both
    // deserialize into `JsonSchema` and pass `generate_validated_schema`.
    // Either `unwrap()` panicking fails the test.
    fn generate_column() {
        // Arrange
        let raw_schema = r#"
            {
                "columns": [
		    {
			"name": "INT_COLUMN",
			"column_type": "Int",
			"illegal_val_replacement": null,
			"legal_vals": null
		    },
		    {
			"name": "DATE_COLUMN",
			"column_type": "Date",
			"format": "%Y-%m-%d"
		    },
		    {
			"name": "FLOAT_COLUMN",
			"column_type": "Float",
			"illegal_val_replacement": ""
		    },
		    {
			"name": "STRING_COLUMN",
			"column_type": "String"
		    },
		    {
			"name": "BOOL_COLUMN",
			"column_type": "Bool"
		    },
		    {
			"name": "ENUM_COLUMN",
			"column_type": "Enum",
			"illegal_val_replacement": "DEFAULT",
			"legal_vals": ["A", "B", "DEFAULT"]
		    }
                ]
            }"#;
        // Act: deserialize, then validate; both panic on failure.
        let json_schema: JsonSchema = serde_json::from_str(raw_schema).unwrap();

        generate_validated_schema(json_schema).unwrap();
    }
1143}