//! pg2parquet 0.2.2
//!
//! Command line tool for exporting PostgreSQL tables or queries into Parquet files
#![allow(unused_imports)]
#![allow(dead_code)]
use std::{sync::Arc, path::PathBuf, process};

use clap::{Command, CommandFactory, Parser, ValueEnum};
use parquet::{basic::{ZstdLevel, BrotliLevel, GzipLevel, Compression}, file::properties::DEFAULT_WRITE_BATCH_SIZE};
use postgres_cloner::{SchemaSettingsArrayHandling, SchemaSettingsEnumHandling, SchemaSettingsFloat16Handling, SchemaSettingsIntervalHandling, SchemaSettingsJsonHandling, SchemaSettingsMacaddrHandling, SchemaSettingsNumericHandling};

mod postgresutils;
mod myfrom;
mod level_index;
mod parquetinfo;
mod playground;
mod parquet_writer;
mod postgres_cloner;
mod pg_custom_types;
mod datatypes;
mod appenders;

#[cfg(not(any(target_family = "windows", target_arch = "riscv64")))]
use jemallocator::Jemalloc;

use crate::postgres_cloner::SchemaSettings;

#[cfg(not(any(target_family = "windows", target_arch = "riscv64")))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;


// Top-level CLI definition: each variant is a subcommand. The `///` doc
// comments double as clap help text, so they are left untouched here.
#[derive(Parser, Debug, Clone)]
#[command(name = "pg2parquet")]
#[command(bin_name = "pg2parquet")]
#[command(version)]
enum CliCommand {
    /// Dumps something from a parquet file
    #[command(arg_required_else_help = true, hide = true)]
    ParquetInfo(ParquetInfoArgs),
    // Internal development/debugging command, hidden from --help.
    #[command(arg_required_else_help = true, hide = true)]
    PlaygroundCreateSomething(PlaygroundCreateSomethingArgs),
    // Build-time helper generating man pages and shell completions;
    // only compiled with the `manpage_gen` feature.
    #[command(arg_required_else_help = true, hide = true)]
    #[cfg(feature = "manpage_gen")]
    ManpageGen(ManpageGenArgs),
    /// Exports a PostgreSQL table or query to a Parquet file
    #[command(arg_required_else_help = true)]
    Export(ExportArgs)
}

// Arguments of the `export` subcommand. `///` comments are the clap help
// strings (user-visible), so only `//` notes are added here.
// Note: --query and --table exclusivity is enforced at runtime in
// perform_export, not declaratively via clap.
#[derive(clap::Args, Debug, Clone)]
struct ExportArgs {
    /// Path to the output file. If the file exists, it will be overwritten.
    #[arg(long, short = 'o')]
    output_file: PathBuf,
    /// SQL query to execute. Exclusive with --table
    #[arg(long, short = 'q')]
    query: Option<String>,
    /// Which table should be exported. Exclusive with --query
    #[arg(long, short = 't')]
    table: Option<String>,
    /// Compression applied on the output file. Default: zstd, change to Snappy or None if it's too slow
    #[arg(long, hide_short_help = true)]
    compression: Option<ParquetCompression>,
    /// Compression level of the output file compressor. Only relevant for zstd, brotli and gzip. Default: 3
    #[arg(long, hide_short_help = true)]
    compression_level: Option<i32>,
    /// Avoid printing unnecessary information (schema and progress). Only errors will be written to stderr
    #[arg(long, hide_short_help = true)]
    quiet: bool,
    // Connection parameters (host/user/... or a connection string).
    #[command(flatten)]
    postgres: PostgresConnArgs,
    // Type-mapping options controlling the generated Parquet schema.
    #[command(flatten)]
    schema_settings: SchemaSettingsArgs,
}

// Value of the --sslmode option; mirrors libpq's sslmode semantics
// (disable / prefer / require). The `///` comments are the clap help text.
#[derive(clap::ValueEnum, Debug, Clone)]
enum SslMode {
    /// Do not use TLS.
    Disable,
    /// Attempt to connect with TLS but allow sessions without (default behavior compiled with SSL support).
    Prefer,
    /// Require the use of TLS.
    Require,
}

// PostgreSQL connection options, flattened into the `export` subcommand.
// Debug is implemented manually (below) to avoid leaking passwords;
// that is why it is not derived here.
#[derive(clap::Args, Clone)]
pub struct PostgresConnArgs {
    /// PostgreSQL connection URL (postgres://...) or connection string. Mutually exclusive with individual connection parameters.
    /// The connection URL may be also provided using the DATABASE_URL environment variable (recommended if it contains a password)
    /// See https://docs.rs/postgres/latest/postgres/config/struct.Config.html for a list of supported options
    #[arg(long="connection", short='c', conflicts_with_all = ["host", "user", "dbname", "port", "password", "sslmode"])]
    connection_string: Option<String>,

    /// Database server host
    #[arg(short='H', long)]
    host: Option<String>,
    /// Database user name. If not specified, PGUSER environment variable is used.
    #[arg(short='U', long)]
    user: Option<String>,
    // Name of the database to export from.
    // NOTE(review): has no `///` help text, so --help shows nothing for it.
    #[arg(short='d', long)]
    dbname: Option<String>,
    // Database server port; presumably defaults to 5432 downstream — verify
    // in postgres_cloner before documenting.
    #[arg(short='p', long)]
    port: Option<u16>,
    /// Password to use for the connection. It is recommended to use the PGPASSWORD environment variable instead, since process arguments are visible to other users on the system.
    #[arg(long)]
    password: Option<String>,
    /// Controls whether to use SSL/TLS to connect to the server.
    #[arg(long="sslmode", alias="tlsmode", alias="ssl-mode", alias="tls-mode")]
    sslmode: Option<SslMode>,
    /// File with a TLS root certificate in PEM or DER (.crt) format. When specified, the default CA certificates are considered untrusted. The option can be specified multiple times. The option implies --sslmode=require.
    #[arg(long="ssl-root-cert", alias="tls-root-cert")]
    ssl_root_cert: Option<Vec<PathBuf>>,
    /// File with a TLS client certificate in PEM format. Note that you also need to use the ssl-client-key option
    #[arg(long="ssl-client-cert", alias="tls-client-cert")]
    ssl_client_cert: Option<PathBuf>,
    /// File with a TLS client key in PEM format. 
    #[arg(long="ssl-client-key", alias="tls-client-key")]
    ssl_client_key: Option<PathBuf>,
}

impl PostgresConnArgs {
    /// Checks that enough information is present to open a connection:
    /// either a connection string (via --connection or the DATABASE_URL /
    /// POSTGRES_URL environment variables), or --host together with --dbname.
    ///
    /// Returns a user-facing error message otherwise.
    fn validate(&self) -> Result<(), String> {
        let has_connection_string = self.connection_string.is_some()
            || std::env::var("DATABASE_URL").is_ok()
            || std::env::var("POSTGRES_URL").is_ok();
        let has_host_and_db = self.host.is_some() && self.dbname.is_some();

        if has_connection_string || has_host_and_db {
            Ok(())
        } else {
            Err("Either --connection <CONNECTION_STRING> or --host <HOST> and --dbname <DBNAME> must be provided, or set the DATABASE_URL environment variable".to_string())
        }
    }
}

impl std::fmt::Debug for PostgresConnArgs {
    /// Debug formatting that never leaks credentials: the `password` field is
    /// never printed, and the connection string is redacted when parsing it
    /// reveals an embedded password.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_struct("PostgresConnArgs");
        if let Some(connection_string) = &self.connection_string {
            // Parse the string to find out whether it embeds a password.
            let parsed = connection_string.parse::<postgres::config::Config>();
            if matches!(parsed, Ok(c) if c.get_password().is_some()) {
                // don't expose passwords in debug output
                s.field("connection_string", &"<contains password>");
            } else {
                s.field("connection_string", connection_string);
            }
        }
        if let Some(host) = &self.host { s.field("host", host); }
        if let Some(user) = &self.user { s.field("user", user); }
        if let Some(dbname) = &self.dbname { s.field("dbname", dbname); }
        if let Some(port) = &self.port { s.field("port", port); }
        if let Some(sslmode) = &self.sslmode { s.field("sslmode", sslmode); }
        if let Some(ssl_root_cert) = &self.ssl_root_cert { s.field("ssl_root_cert", ssl_root_cert); }
        // These two were previously omitted from the output; the certificate
        // and key *paths* are not secrets, so include them for diagnostics.
        if let Some(ssl_client_cert) = &self.ssl_client_cert { s.field("ssl_client_cert", ssl_client_cert); }
        if let Some(ssl_client_key) = &self.ssl_client_key { s.field("ssl_client_key", ssl_client_key); }
        s.finish()
    }
}

// Schema/type-mapping options, flattened into the `export` subcommand.
// Copied field-by-field into postgres_cloner::SchemaSettings in perform_export.
#[derive(clap::Args, Debug, Clone)]
pub struct SchemaSettingsArgs {
    /// How to handle `macaddr` columns
    #[arg(long, hide_short_help = true, default_value = "text")]
    macaddr_handling: SchemaSettingsMacaddrHandling,
    /// How to handle `json` and `jsonb` columns
    #[arg(long, hide_short_help = true, default_value = "text")]
	json_handling: SchemaSettingsJsonHandling,
    /// How to handle enum (Enumerated Type) columns 
    #[arg(long, hide_short_help = true, default_value = "text")]
    enum_handling: SchemaSettingsEnumHandling,
    /// How to handle `interval` columns
    #[arg(long, hide_short_help = true, default_value = "interval")]
    interval_handling: SchemaSettingsIntervalHandling,
    /// How to handle `numeric` columns
    #[arg(long, hide_short_help = true, default_value = "double")]
    numeric_handling: SchemaSettingsNumericHandling,
    /// How many decimal digits after the decimal point are stored in the Parquet file in DECIMAL data type.
    #[arg(long, hide_short_help = true, default_value_t = 18)]
	decimal_scale: i32,
    /// How many decimal digits are allowed in numeric/DECIMAL column. By default 38, the largest value which fits in 128 bits. If <= 9, the column is stored as INT32; if <= 18, the column is stored as INT64; otherwise BYTE_ARRAY.
    #[arg(long, hide_short_help = true, default_value_t = 38)]
    decimal_precision: u32,
    /// Parquet does not support multi-dimensional arrays and arrays with different starting index. pg2parquet flattens the arrays, and this options allows including the stripped information in additional columns.
    #[arg(long, hide_short_help = true, default_value = "plain")]
    array_handling: SchemaSettingsArrayHandling,
    /// How to handle `float16` (half-precision) columns
    #[arg(long, hide_short_help = true, default_value = "float32")]
    float16_handling: SchemaSettingsFloat16Handling,
}


// Compression codec chosen via --compression; mapped onto
// parquet::basic::Compression in get_compression.
#[derive(ValueEnum, Debug, Clone)]
enum ParquetCompression {
    None,
    Snappy,
    Gzip,
    Lzo,
    Brotli,
    Lz4,
    Zstd,
}

// Arguments of the hidden `manpage-gen` subcommand (feature-gated).
#[derive(clap::Args, Debug, Clone)]
#[cfg(feature = "manpage_gen")]
struct ManpageGenArgs {
    // Directory where the rendered man pages (.1 files) are written.
    #[arg(long)]
    out_manpage: Option<PathBuf>,
    // Directory where shell completion scripts are written.
    #[arg(long)]
    out_completion: Option<PathBuf>,
}

// Arguments of the hidden `parquet-info` subcommand.
#[derive(clap::Args, Debug, Clone)]
struct ParquetInfoArgs {
    // Path of the Parquet file to inspect (positional argument).
    parquet_file: PathBuf,
}

// Arguments of the hidden `playground-create-something` dev subcommand.
#[derive(clap::Args, Debug, Clone)]
struct PlaygroundCreateSomethingArgs {
    // Path of the Parquet file to create (positional argument).
    parquet_file: PathBuf,
}

/// Unwraps `r`, or reports the error to stderr and exits with status 1.
///
/// On failure, the command line is re-parsed purely so the offending command
/// can be included in the error report; an unparsable command line is
/// tolerated and reported generically.
fn handle_result<T, TErr: ToString>(r: Result<T, TErr>) -> T {
    match r {
        Ok(v) => v,
        Err(e) => {
            let args = CliCommand::try_parse();
            match args.ok() {
                // fixed typo: "occured" -> "occurred"
                Some(a) => eprintln!("Error occurred while executing command {:#?}", a),
                None => eprintln!("Error occurred while executing an unparsable command"),
            };
            eprintln!();
            eprintln!("{}", e.to_string());
            process::exit(1);
        }
    }
}

fn get_compression(args: &ExportArgs) -> Result<parquet::basic::Compression, parquet::errors::ParquetError> {
    let lvl = args.compression_level;
    let level_not_supported = ||
        if lvl.is_some() {
            Err(parquet::errors::ParquetError::General(format!(
                "Compression algorithm {:?} does not allow setting --compression-level option",
                args.compression.as_ref().unwrap_or(&ParquetCompression::Zstd)
            )))
        } else {
            Ok(())
        };
    let compression = match args.compression {
        None => parquet::basic::Compression::ZSTD(ZstdLevel::try_new(lvl.unwrap_or(3))?),
        Some(ParquetCompression::Brotli) => parquet::basic::Compression::BROTLI(BrotliLevel::try_new(lvl.unwrap_or(3) as u32)?),
        Some(ParquetCompression::Gzip) => parquet::basic::Compression::GZIP(GzipLevel::try_new(lvl.unwrap_or(3) as u32)?),
        Some(ParquetCompression::Zstd) => parquet::basic::Compression::ZSTD(ZstdLevel::try_new(lvl.unwrap_or(3))?),
        Some(ParquetCompression::Lzo) => { level_not_supported()?; parquet::basic::Compression::LZO }
        Some(ParquetCompression::Lz4) => { level_not_supported()?; parquet::basic::Compression::LZ4 }
        Some(ParquetCompression::Snappy) => { level_not_supported()?; parquet::basic::Compression::SNAPPY }
        Some(ParquetCompression::None) => { level_not_supported()?; parquet::basic::Compression::UNCOMPRESSED }
    };
    Ok(compression)
}

/// Runs the `export` subcommand: validates the query/table options, builds
/// the Parquet writer properties and hands off to
/// `postgres_cloner::execute_copy`. Exits the process on invalid options.
fn perform_export(args: ExportArgs) {
    // Exactly one of --query / --table must be given.
    match (&args.query, &args.table) {
        (Some(_), Some(_)) => {
            eprintln!("Either query or table must be specified, but not both");
            process::exit(1);
        }
        (None, None) => {
            eprintln!("Either query or table must be specified");
            process::exit(1);
        }
        _ => {}
    }

    let compression = match get_compression(&args) {
        Ok(c) => c,
        Err(e) => {
            eprintln!("Invalid combination of compression and compression_level: {}", e);
            process::exit(1);
        }
    };

    // use smaller page size if shitty compression is chosen
    let weak_compression = matches!(
        compression,
        Compression::UNCOMPRESSED | Compression::SNAPPY | Compression::LZO | Compression::LZ4
    ) || matches!(compression, Compression::ZSTD(lvl) if lvl.compression_level() <= 2);
    // otherwise prefer larger page size to improve the compression ratio slightly
    // the parquet library doesn't parallelize compression anyway
    let batch_size = if weak_compression { DEFAULT_WRITE_BATCH_SIZE } else { 1024 * 128 };

    let props = Arc::new(
        parquet::file::properties::WriterProperties::builder()
            .set_compression(compression)
            .set_write_batch_size(batch_size)
            .set_created_by(format!("pg2parquet version {}, using {}", env!("CARGO_PKG_VERSION"), parquet::file::properties::DEFAULT_CREATED_BY))
            .build()
    );

    // Copy the CLI schema options into the cloner's settings struct.
    let s = args.schema_settings;
    let settings = SchemaSettings {
        macaddr_handling: s.macaddr_handling,
        json_handling: s.json_handling,
        enum_handling: s.enum_handling,
        interval_handling: s.interval_handling,
        numeric_handling: s.numeric_handling,
        decimal_scale: s.decimal_scale,
        decimal_precision: s.decimal_precision,
        array_handling: s.array_handling,
        float16_handling: s.float16_handling,
    };
    let query = match args.query {
        Some(q) => q,
        // validated above: if --query is absent, --table must be present
        None => format!("SELECT * FROM {}", args.table.unwrap()),
    };
    let result = postgres_cloner::execute_copy(&args.postgres, &query, &args.output_file, props, args.quiet, &settings);
    let _stats = handle_result(result);
}

/// Parses the command line and performs the extra validation that clap's
/// declarative constraints cannot express (presence of connection options).
/// Exits with status 2 on invalid connection arguments.
fn parse_args() -> CliCommand {
    let args = CliCommand::parse();

    match &args {
        CliCommand::Export(export_args) => {
            if let Err(err) = export_args.postgres.validate() {
                eprintln!("Error: {}", err);
                process::exit(2);
            }
        }
        // Other subcommands need no additional validation.
        _ => {}
    }

    args
}

/// Entry point: installs a panic hook asking users to file a bug report,
/// parses the CLI and dispatches to the selected subcommand.
fn main() {
    // Chain onto the default panic hook so the normal panic message and
    // backtrace are still printed before the bug-report plea.
    let default_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(move |x| {
        default_hook(x);
        eprintln!();
        eprintln!("pg2parquet probably should not crash in this way, could you please report a bug at https://github.com/exyi/pg2parquet/issues/new? (ideally with the backtrace and some info on what you did)");
    }));
    let args = parse_args();

    match args {
        // Hidden diagnostic command: prints metadata of an existing file.
        CliCommand::ParquetInfo(args) => {
            eprintln!("parquet file: {:?}", args.parquet_file);
            parquetinfo::print_parquet_info(&args.parquet_file);
        },
        // Hidden dev command: writes a test parquet file.
        CliCommand::PlaygroundCreateSomething(args) => {
            eprintln!("parquet file: {:?}", args.parquet_file);
            playground::create_something(&args.parquet_file);
        },
        CliCommand::Export(args) => {
            perform_export(args);
        }
        // Build-time helper (only with the `manpage_gen` feature): renders
        // man pages and shell completion scripts into the given directories.
        #[cfg(feature = "manpage_gen")]
        CliCommand::ManpageGen(args) => {
            let mut cmd = CliCommand::command();
            // {
            //     let x = cmd.find_subcommand_mut("pg2parquet-help").unwrap();
            //     *x = x.clone().hide(true);
            // }
            cmd.build(); // prepends `pg2parquet-` to subcommands
            if let Some(out_dir) = args.out_manpage {
                // Main page: the top-level command followed by the `export`
                // subcommand, concatenated into a single pg2parquet.1 file.
                let man = clap_mangen::Man::new(cmd.clone());
                let mut buffer = Vec::new();
                man.render(&mut buffer).unwrap();

                let man_export = clap_mangen::Man::new(cmd.get_subcommands().find(|c| c.get_name() == "export").unwrap().clone());
                man_export.render(&mut buffer).unwrap();

                std::fs::write(out_dir.join(&format!("{}.1", cmd.get_name())), buffer).unwrap();

                // One additional page per visible subcommand.
                for cmd_ref in cmd.get_subcommands().filter(|c| !c.is_hide_set()) {
                    let cmd = cmd_ref.clone();
                    let man = clap_mangen::Man::new(cmd);
                    let mut buffer = Vec::new();
                    man.render(&mut buffer).unwrap();
                    std::fs::write(out_dir.join(&format!("pg2parquet-{}.1", cmd_ref.get_name())), buffer).unwrap();
                }
            }

            if let Some(out_dir) = args.out_completion {
                // One completion script per supported shell.
                let name = "pg2parquet";
                let file = |name: &str| std::fs::File::create(out_dir.join(name)).unwrap();
                clap_complete::generate(clap_complete::shells::Bash, &mut cmd, name, &mut file("complete.bash"));
                clap_complete::generate(clap_complete::shells::Zsh, &mut cmd, name, &mut file("complete.zsh"));
                clap_complete::generate(clap_complete::shells::Elvish, &mut cmd, name, &mut file("complete.elv"));
                clap_complete::generate(clap_complete::shells::Fish, &mut cmd, name, &mut file("complete.fish"));
                clap_complete::generate(clap_complete_nushell::Nushell, &mut cmd, name, &mut file("complete.nu"));
            }
        }
    }
}