csvsource 0.7.0

Converts a CSV file to SQL INSERT statements.
Documentation
use serde::Serialize;
use itertools::intersperse;
use std::fs::File;
use std::io;
use std::io::{BufWriter, Write};
use std::path::Path;
use std::io::prelude::*;
use std::result::Result;
use tinytemplate::TinyTemplate;

/// Reads the CSV file named by `args.csv` and writes SQL insert statements
/// to the file named by `args.sql`.
///
/// The CSV reader is configured with `args.delimiter` and `args.has_headers`
/// before being handed to `generate_sql_file`.
///
/// # Errors
///
/// Returns an `io::Error` of kind `NotFound` with the message
/// "CSV file not found" when `args.csv` does not exist, and propagates any
/// error raised while opening the CSV file or generating the SQL file.
pub fn convert_to_sql(args: Arguments) -> Result<(), io::Error> {
    if !Path::new(&args.csv).exists() {
        return Err(io::Error::new(io::ErrorKind::NotFound, "CSV file not found"));
    }

    let csv_file = File::open(&args.csv)?;
    let csv_reader = csv::ReaderBuilder::new()
        // Honour the configured field delimiter; without this call the reader
        // always splits on ',' no matter what `args.delimiter` holds.
        .delimiter(args.delimiter)
        .has_headers(args.has_headers)
        .from_reader(io::BufReader::new(csv_file));

    generate_sql_file(args, csv_reader)
}

fn generate_sql_file(args: Arguments, csv_reader: csv::Reader<io::BufReader<File>>) -> Result<(), io::Error> {
    let sql_file = File::create(&args.sql).expect("Unable to create sql file");
    let mut writer = BufWriter::new(sql_file);

    let context = &TemplateContext {
        table: args.table.to_string()
    };
    append_file_content(args.prefix.clone(), context, &mut writer)?;
    generate_sql(&args, csv_reader, &mut writer)?;
    append_file_content(args.suffix, context, &mut writer)?;

    Ok(())
}

/// Writes one or more `insert into` statements to `writer` for the records
/// yielded by `csv_reader`.
///
/// The column list is built once from `args.get_fields` / `args.format_fields`
/// using the reader's header record. When `args.with_transaction` is set the
/// output starts with `begin transaction` and ends with `commit;`.
/// `args.chunk_insert` controls how many rows share one insert statement
/// (0 means one row per statement) and `args.chunk` controls after how many
/// insert statements a `commit` / `begin transaction` pair is emitted
/// (0 disables it).
///
/// # Errors
///
/// Propagates errors from reading the header record and from writing to
/// `writer`; a record that fails to parse is returned as an `io::Error` of
/// kind `InvalidData`.
fn generate_sql(args: &Arguments, mut csv_reader: csv::Reader<io::BufReader<File>>, writer: &mut BufWriter<File>) -> Result<(), io::Error> {
    let insert_fields = args.format_fields(args.get_fields(csv_reader.headers()?));

    // Number of insert statements written since the last commit.
    let mut chunk_count = 0;
    // Number of rows written into the current multi-row insert statement.
    let mut chunk_insert_count = 0;
    // Text emitted before the next statement header or the next row tuple; it
    // terminates whatever was written previously.
    let mut insert_separator = ";\n\n";

    if args.with_transaction {
        write!(writer, "begin transaction")?;
    } else {
        // Nothing precedes the first insert, so there is nothing to terminate.
        insert_separator = "";
    }

    for record in csv_reader.records() {
        // A counter of 0 means a new insert statement has to be opened.
        if chunk_insert_count == 0 {
            // NOTE(review): this commit/begin pair is written even when
            // `with_transaction` is false — confirm callers only set `chunk`
            // together with transactions.
            if args.chunk > 0 && chunk_count == args.chunk {
                write!(writer, ";\n\ncommit;\n\nbegin transaction")?;
                chunk_count = 0;
            }

            write!(writer, "{}insert into {} {} values", insert_separator, args.table.as_str(), insert_fields)?;
            // The first row tuple follows the header directly.
            insert_separator = "";
            chunk_count += 1;
        }

        match record {
            Ok(row) => write!(writer, "{}\n{}", insert_separator, get_values(args, &row))?,
            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e))
        }

        if args.chunk_insert > 0 {
            // Multi-row mode: rows are comma separated until the statement is
            // full, then the next row opens a new statement.
            chunk_insert_count += 1;
            insert_separator = ",";
            if args.chunk_insert == chunk_insert_count {
                chunk_insert_count = 0;
                insert_separator = ";\n\n";
            }
        } else {
            // Single-row mode: every row closes its own statement.
            insert_separator = ";\n\n";
        }
    }

    // Terminate the last statement (and the transaction, if any).
    // NOTE(review): with no records and no transaction this writes a lone ";".
    if args.with_transaction {
        write!(writer, ";\n\ncommit;")?
    } else {
        write!(writer, ";")?
    }

    Ok(())
}

/// Values made available to the prefix/suffix templates rendered by
/// `append_file_content`.
#[derive(Serialize)]
struct TemplateContext {
    /// Target table name; `generate_sql_file` fills it from `Arguments::table`.
    table: String,
}

/// Renders the template file at `path` with `context` and appends the result,
/// followed by a newline, to `writer`.
///
/// A `path` that does not exist is not an error: nothing is written and
/// `Ok(())` is returned. The template is read line by line, so every line
/// ending is normalised to `\n` and the template text always ends with a
/// newline.
///
/// # Errors
///
/// Returns an `io::Error` when the file cannot be opened or read (including
/// invalid UTF-8), when writing to `writer` fails, or — with kind
/// `InvalidInput` — when the template cannot be parsed or rendered.
fn append_file_content(path: String, context: &TemplateContext, writer: &mut BufWriter<File>) -> Result<(), io::Error> {
    if !Path::new(&path).exists() {
        return Ok(());
    }

    let reader = io::BufReader::new(File::open(path)?);
    let mut template = String::new();

    for line in reader.lines() {
        // Propagate read/decoding failures instead of panicking on them.
        template.push_str(&line?);
        template.push('\n');
    }

    let mut tt = TinyTemplate::new();
    tt.add_template("append", &template)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
    let rendered = tt
        .render("append", context)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

    writeln!(writer, "{}", rendered)?;

    Ok(())
}

/// Formats one CSV record as a parenthesised, comma-separated SQL value tuple.
///
/// With `args.typed` each field is rendered by `get_value`; otherwise every
/// field is emitted as a single-quoted string with embedded quotes doubled.
fn get_values(args: &Arguments, record: &csv::StringRecord) -> String {
    let rendered: Vec<String> = record
        .iter()
        .map(|field| {
            if args.typed {
                get_value(field)
            } else {
                format!("'{}'", field.replace("'", "''"))
            }
        })
        .collect();

    format!("({})", rendered.join(", "))
}

/// Renders a single field as a typed SQL literal.
///
/// Numbers and booleans (as recognised by `is_number` / `is_boolean`) are
/// emitted verbatim, an empty field becomes `NULL`, and anything else is
/// single-quoted with embedded quotes doubled.
fn get_value(result: &str) -> String {
    // Numeric and boolean text is passed through without quoting.
    if is_number(result) || is_boolean(String::from(result)) {
        return result.to_string();
    }

    if result.is_empty() {
        return "NULL".to_string();
    }

    format!("'{}'", result.replace("'", "''"))
}

/// Returns `true` when `text` can be emitted unquoted as a numeric SQL literal.
///
/// The text must parse as an `f64` *and* the parsed value must be finite.
/// `f64`'s parser also accepts spellings such as `NaN`, `inf` and `infinity`,
/// and turns overflowing literals into infinity; none of those are safe to
/// write unquoted, so they are rejected here. The empty string is not a
/// number.
fn is_number(text: &str) -> bool {
    matches!(text.parse::<f64>(), Ok(value) if value.is_finite())
}

/// Returns `true` when the text is `true` or `false`, compared
/// case-insensitively.
fn is_boolean(str: String) -> bool {
    let lowered = str.to_lowercase();
    matches!(lowered.as_str(), "true" | "false")
}

/// Options controlling the CSV to SQL conversion performed by `convert_to_sql`.
pub struct Arguments {
    /// Path of the CSV file to read.
    pub csv              : String,
    /// Path of the SQL file to create.
    pub sql              : String,
    /// Presumably the CSV field delimiter byte — TODO confirm against the
    /// code that fills this struct.
    pub delimiter        : u8,
    /// Whether the CSV reader treats the first record as a header row; also
    /// decides whether header names may be used as column names.
    pub has_headers      : bool,
    /// Table name written into every insert statement and exposed to the
    /// prefix/suffix templates.
    pub table            : String,
    /// Explicit column names; when empty and `has_headers` is set, the CSV
    /// header names are used instead.
    pub columns          : Vec<String>,
    /// Number of insert statements after which a `commit` / `begin transaction`
    /// pair is written; 0 disables this.
    pub chunk            : usize,
    /// Number of rows grouped into one insert statement; 0 writes one row per
    /// statement.
    pub chunk_insert     : usize,
    /// Path of a template file rendered before the insert statements; ignored
    /// when the file does not exist.
    pub prefix           : String,
    /// Path of a template file rendered after the insert statements; ignored
    /// when the file does not exist.
    pub suffix           : String,
    /// Wraps the output in `begin transaction` ... `commit;` when set.
    pub with_transaction : bool,
    /// When set, numbers and booleans are written unquoted and empty fields as
    /// `NULL`; otherwise every field is written as a quoted string.
    pub typed            : bool,
}

impl Arguments {
    /// Returns the column names for the insert statements: the CSV `headers`
    /// when no explicit columns were configured and the file has a header row,
    /// otherwise the configured `columns` (which may be empty).
    fn get_fields(&self, headers: &csv::StringRecord) -> Vec<String> {
        let use_headers = self.has_headers && self.columns.is_empty();
        if use_headers {
            headers.iter().map(String::from).collect()
        } else {
            self.columns.clone()
        }
    }

    /// Joins `fields` with `", "` and wraps the result in parentheses.
    fn format_fields(&self, fields: Vec<String>) -> String {
        let joined = intersperse(fields, String::from(", ")).collect::<String>();
        format!("({})", joined)
    }
}