//! psqlgen 0.1.2
//!
//! PostgreSQL query generation from migrations.
use std::{fs, path::PathBuf};

use sqlparser::{dialect::AnsiDialect, parser::Parser};

use crate::error::QueryGenError;


/// What is needed to generate queries for one table: its name and its column names.
///
/// Produced by parsing a `CREATE TABLE` migration; `field_names` keeps the column
/// order in which the columns were declared in the migration.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct QueryNeeds {
    /// Unqualified table name (no schema prefix).
    pub table_name: String,
    /// Column names in declaration order.
    pub field_names: Vec<String>,
}

/// Parses every `.sql` file directly inside `input_dir` into a [`QueryNeeds`].
///
/// This will only check one level deep. Not recursive: subdirectories are ignored.
///
/// A file is considered a migration when its directory entry is a regular file
/// (`DirEntry::file_type` does not follow symlinks) and its UTF-8 file name ends in
/// `.sql`. Candidates are processed in sorted path order, because the iteration
/// order of `read_dir` is platform-dependent. Sorting keeps the returned list stable
/// between runs.
///
/// Files that cannot be read or parsed are reported on stdout together with their
/// path and then skipped. They do not abort the whole collection. A summary of how
/// many `.sql` files were seen and how many failed is printed at the end.
///
/// # Errors
///
/// * [`QueryGenError::PathNotDir`] if `input_dir` is not a directory.
/// * [`QueryGenError::InvalidPath`] if the directory cannot be listed.
pub fn collect_sql_migrations(input_dir: PathBuf) -> Result<Vec<QueryNeeds>, QueryGenError> {
    if !input_dir.is_dir() {
        return Err(QueryGenError::PathNotDir);
    }
    let items_in_dir = match input_dir.read_dir() {
        Ok(dir) => dir,
        Err(error) => {
            println!("{error}");
            return Err(QueryGenError::InvalidPath);
        }
    };

    // Gather candidate paths first so they can be sorted before processing.
    let mut sql_paths: Vec<PathBuf> = Vec::new();
    for item_result in items_in_dir {
        let item = match item_result {
            Ok(item) => item,
            Err(error) => {
                // Report unreadable entries instead of silently dropping them, so a
                // missing migration does not go unnoticed.
                println!("{error}");
                continue;
            }
        };
        let is_file = matches!(item.file_type(), Ok(file_type) if file_type.is_file());
        // Non-UTF-8 file names are skipped.
        let is_sql = matches!(
            item.file_name().to_str(),
            Some(file_name) if file_name.ends_with(".sql")
        );
        if is_file && is_sql {
            sql_paths.push(item.path());
        }
    }
    sql_paths.sort();

    let sql_files = sql_paths.len();
    let mut failed_files = 0;
    let mut query_needs_list = Vec::with_capacity(sql_files);
    for path in &sql_paths {
        match fs::read_to_string(path) {
            Ok(file_contents_string) => match collect_fields_from_migration(file_contents_string) {
                Ok(query_needs) => query_needs_list.push(query_needs),
                Err(error) => {
                    failed_files += 1;
                    println!("{}: {:?}", path.display(), error);
                }
            },
            Err(error) => {
                failed_files += 1;
                println!("{}: {error}", path.display());
            }
        }
    }
    println!("Number of .sql Files: {sql_files}.");
    println!("Failed reading {failed_files} files.");
    Ok(query_needs_list)
}

/// Extracts the table name and column names from one migration's SQL text.
///
/// The SQL is parsed with sqlparser's [`AnsiDialect`]. Statements are scanned in
/// order and the first `CREATE TABLE` is used. Statements before it (for example
/// `CREATE EXTENSION` or `SET`) are skipped, and statements after it are ignored.
///
/// The table name is the last component of the possibly schema-qualified object
/// name, so `CREATE TABLE public.users (...)` yields `users`, not `public`.
///
/// # Errors
///
/// * [`QueryGenError::InvalidSql`] if the text fails to parse (the parser message is
///   printed to stdout), contains no statements, or the `CREATE TABLE` name is empty.
/// * [`QueryGenError::NotACreateTableMigration`] if none of the statements is a
///   `CREATE TABLE`.
fn collect_fields_from_migration(file_contents_string: String) -> Result<QueryNeeds, QueryGenError> {
    // NOTE(review): this crate targets PostgreSQL. `PostgreSqlDialect` may accept
    // Postgres-only syntax that `AnsiDialect` rejects; consider switching.
    let ast = match Parser::parse_sql(&AnsiDialect {}, &file_contents_string) {
        Ok(ast) => ast,
        Err(error) => {
            // `InvalidSql` carries no detail, so surface the parser's message here.
            println!("{error}");
            return Err(QueryGenError::InvalidSql);
        }
    };
    if ast.is_empty() {
        return Err(QueryGenError::InvalidSql);
    }
    for statement in ast {
        // `..` keeps this pattern compiling when sqlparser adds fields to `CreateTable`.
        if let sqlparser::ast::Statement::CreateTable { mut name, columns, .. } = statement {
            // `ObjectName` stores `schema.table` parts in order; the table is the last.
            let table_name = match name.0.pop() {
                Some(ident) => ident.value,
                None => return Err(QueryGenError::InvalidSql),
            };
            let field_names: Vec<String> = columns.into_iter().map(|col| col.name.value).collect();
            return Ok(QueryNeeds { table_name, field_names });
        }
    }
    Err(QueryGenError::NotACreateTableMigration)
}

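/*
Example `Parser::parse_sql` output (AST debug dump) for a `product_variant` CREATE TABLE
migration, kept as a reference for the shape of `Statement::CreateTable`.
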
[CreateTable { or_replace: false, temporary: false, external: false, global: None, if_not_exists: true, transient: false, name: ObjectName([Ident { value: "product_variant", quote_style: None }]), columns: [ColumnDef { name: Ident { value: "id", quote_style: None }, data_type: Uuid, collation: None, options: [ColumnOptionDef { name: None, option: Unique { is_primary: true } }] }, ColumnDef { name: Ident { value: "product_id", quote_style: None }, data_type: Uuid, collation: None, options: [ColumnOptionDef { name: None, option: NotNull }] }, ColumnDef { name: Ident { value: "variant_type", quote_style: None }, data_type: Varchar(None), collation: None, options: [ColumnOptionDef { name: None, option: NotNull }] }, ColumnDef { name: Ident { value: "variant_label", quote_style: None }, data_type: Varchar(None), collation: None, options: [ColumnOptionDef { name: None, option: NotNull }] }, ColumnDef { name: Ident { value: "time_created", quote_style: None }, data_type: Timestamp(None, Tz), collation: None, options: [ColumnOptionDef { name: None, option: NotNull }] }], constraints: [], hive_distribution: NONE, hive_formats: Some(HiveFormat { row_format: None, storage: None, location: None }), table_properties: [], with_options: [], file_format: None, location: None, query: None, without_rowid: false, like: None, clone: None, engine: None, default_charset: None, collation: None, on_commit: None, on_cluster: None, order_by: None, strict: false }]
*/