elefant-tools 0.0.2

A library for doing things like pg_dump and pg_restore, with extra features, and probably more bugs.
Documentation
use crate::object_id::ObjectId;
use crate::postgres_client_wrapper::FromPgChar;
use crate::quoting::AttemptedKeywordUsage::TypeOrFunctionName;
use crate::quoting::{quote_value_string, IdentifierQuoter, Quotable};
use crate::whitespace_ignorant_string::WhitespaceIgnorantString;
use crate::{ElefantToolsError, PostgresSchema};
use ordered_float::NotNan;
use serde::{Deserialize, Serialize};
use std::fmt::Display;

#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
pub enum FunctionKind {
    #[default]
    Function,
    Procedure,
    Aggregate,
    Window,
}

impl FromPgChar for FunctionKind {
    fn from_pg_char(c: char) -> Result<Self, crate::ElefantToolsError> {
        match c {
            'f' => Ok(FunctionKind::Function),
            'p' => Ok(FunctionKind::Procedure),
            'a' => Ok(FunctionKind::Aggregate),
            'w' => Ok(FunctionKind::Window),
            _ => Err(ElefantToolsError::UnknownFunctionKind(c.to_string())),
        }
    }
}

#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
pub enum Volatility {
    Immutable,
    Stable,
    #[default]
    Volatile,
}

impl FromPgChar for Volatility {
    fn from_pg_char(c: char) -> Result<Self, ElefantToolsError> {
        match c {
            'i' => Ok(Volatility::Immutable),
            's' => Ok(Volatility::Stable),
            'v' => Ok(Volatility::Volatile),
            _ => Err(ElefantToolsError::UnknownVolatility(c.to_string())),
        }
    }
}

#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
pub enum Parallel {
    Safe,
    Restricted,
    #[default]
    Unsafe,
}

impl FromPgChar for Parallel {
    fn from_pg_char(c: char) -> Result<Self, ElefantToolsError> {
        match c {
            's' => Ok(Parallel::Safe),
            'r' => Ok(Parallel::Restricted),
            'u' => Ok(Parallel::Unsafe),
            _ => Err(ElefantToolsError::UnknownParallel(c.to_string())),
        }
    }
}

#[derive(Debug, Eq, PartialEq, Default, Clone, Serialize, Deserialize)]
pub struct PostgresFunction {
    pub function_name: String,
    pub language: String,
    pub estimated_cost: NotNan<f32>,
    pub estimated_rows: NotNan<f32>,
    pub support_function: Option<String>,
    pub kind: FunctionKind,
    pub security_definer: bool,
    pub leak_proof: bool,
    pub strict: bool,
    pub returns_set: bool,
    pub volatility: Volatility,
    pub parallel: Parallel,
    pub sql_body: WhitespaceIgnorantString,
    pub configuration: Option<Vec<String>>,
    pub arguments: String,
    pub result: Option<String>,
    pub comment: Option<String>,
    pub object_id: ObjectId,
    pub depends_on: Vec<ObjectId>,
}

impl PostgresFunction {
    pub fn get_create_statement(
        &self,
        schema: &PostgresSchema,
        identifier_quoter: &IdentifierQuoter,
    ) -> String {
        let fn_name = format!(
            "{}.{}",
            schema.name.quote(identifier_quoter, TypeOrFunctionName),
            &self
                .function_name
                .quote(identifier_quoter, TypeOrFunctionName)
        );

        let function_keyword = if self.kind == FunctionKind::Procedure {
            "procedure"
        } else {
            "function"
        };

        let mut sql = format!(
            "create {} {} ({})",
            function_keyword, fn_name, self.arguments
        );

        if let Some(result) = &self.result {
            sql.push_str(" returns ");

            sql.push_str(result);
        }

        sql.push_str(" language ");
        sql.push_str(&self.language);

        if self.kind == FunctionKind::Window {
            sql.push_str("window ");
        }

        if self.kind != FunctionKind::Procedure {
            match self.volatility {
                Volatility::Immutable => sql.push_str(" immutable "),
                Volatility::Stable => sql.push_str(" stable "),
                Volatility::Volatile => sql.push_str(" volatile "),
            }

            match self.parallel {
                Parallel::Safe => sql.push_str(" parallel safe "),
                Parallel::Restricted => sql.push_str(" parallel restricted "),
                Parallel::Unsafe => sql.push_str(" parallel unsafe "),
            }

            if self.leak_proof {
                sql.push_str(" leakproof ");
            }

            if self.strict {
                sql.push_str(" strict ");
            }
        }

        if self.security_definer {
            sql.push_str(" security definer ");
        }

        if let Some(configuration) = &self.configuration {
            sql.push_str(" set ");
            for cfg in configuration {
                sql.push_str(cfg);
            }
        }

        if self.kind != FunctionKind::Procedure {
            sql.push_str("cost ");
            sql.push_str(&self.estimated_cost.to_string());

            if self.estimated_rows.into_inner() > 0. {
                sql.push_str(" rows ");
                sql.push_str(&self.estimated_rows.to_string());
            }

            if let Some(support_function_name) = &self.support_function {
                sql.push_str(" support ");
                sql.push_str(support_function_name);
            }
        }

        sql.push_str(" as $$");
        sql.push_str(&self.sql_body);
        sql.push_str("$$;");

        if let Some(comment) = &self.comment {
            sql.push_str("\ncomment on ");
            sql.push_str(function_keyword);
            String::push_str(&mut sql, " ");
            sql.push_str(&fn_name);
            sql.push_str(" is ");
            sql.push_str(&quote_value_string(comment));
            sql.push(';');
        }

        sql
    }
}

#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
pub enum FinalModify {
    #[default]
    ReadOnly,
    Shareable,
    ReadWrite,
}

impl FromPgChar for FinalModify {
    fn from_pg_char(c: char) -> Result<Self, ElefantToolsError> {
        match c {
            'r' => Ok(FinalModify::ReadOnly),
            's' => Ok(FinalModify::Shareable),
            'w' => Ok(FinalModify::ReadWrite),
            _ => Err(ElefantToolsError::UnknownAggregateFinalFunctionModify(
                c.to_string(),
            )),
        }
    }
}

impl Display for FinalModify {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let str = match self {
            FinalModify::ReadOnly => "read_only",
            FinalModify::Shareable => "shareable",
            FinalModify::ReadWrite => "read_write",
        };
        write!(f, "{}", str)
    }
}

#[derive(Debug, Eq, PartialEq, Default, Clone, Serialize, Deserialize)]
pub struct PostgresAggregateFunction {
    pub function_name: String,
    pub arguments: String,
    pub state_transition_function: String,
    pub final_function: Option<String>,
    pub combine_function: Option<String>,
    pub serial_function: Option<String>,
    pub deserial_function: Option<String>,
    pub moving_state_transition_function: Option<String>,
    pub inverse_moving_state_transition_function: Option<String>,
    pub moving_final_function: Option<String>,
    pub final_extra_data: bool,
    pub moving_final_extra_data: bool,
    pub final_modify: FinalModify,
    pub moving_final_modify: FinalModify,
    pub sort_operator: Option<String>,
    pub transition_type: String,
    pub transition_space: Option<i32>,
    pub moving_transition_type: Option<String>,
    pub moving_transition_space: Option<i32>,
    pub initial_value: Option<String>,
    pub moving_initial_value: Option<String>,
    pub parallel: Parallel,
    pub object_id: ObjectId,
    pub depends_on: Vec<ObjectId>,
}

impl PostgresAggregateFunction {
    pub fn get_create_statement(
        &self,
        schema: &PostgresSchema,
        identifier_quoter: &IdentifierQuoter,
    ) -> String {
        let fn_name = format!(
            "{}.{}",
            schema.name.quote(identifier_quoter, TypeOrFunctionName),
            &self
                .function_name
                .quote(identifier_quoter, TypeOrFunctionName)
        );

        let mut sql = format!("create aggregate {} ({}) (\n", fn_name, self.arguments);

        sql.push_str("\tsfunc = ");
        sql.push_str(&self.state_transition_function);
        sql.push_str(",\n\tstype=");
        sql.push_str(&self.transition_type);

        if let Some(transition_space) = &self.transition_space {
            sql.push_str(",\n\tsspace=");
            sql.push_str(&transition_space.to_string());
        }

        if let Some(serial_function) = &self.serial_function {
            sql.push_str(",\n\tsfunc=");
            sql.push_str(serial_function);
        }

        if let Some(deserial_function) = &self.deserial_function {
            sql.push_str(",\n\tdfunc=");
            sql.push_str(deserial_function);
        }

        if let Some(initial_value) = &self.initial_value {
            sql.push_str(",\n\tinitcond=");
            sql.push_str(initial_value);
        }

        if let Some(final_function) = &self.final_function {
            sql.push_str(",\n\tfinalfunc=");
            sql.push_str(final_function);

            sql.push_str(",\n\tfinalfunc_modify=");
            sql.push_str(&self.final_modify.to_string());

            if self.final_extra_data {
                sql.push_str(",\n\tfinalfunc_extra");
            }
        }

        if let Some(moving_state_transition_function) = &self.moving_state_transition_function {
            sql.push_str(",\n\tmsfunc=");
            sql.push_str(moving_state_transition_function);
        }

        if let Some(inverse_moving_state_transition_function) =
            &self.inverse_moving_state_transition_function
        {
            sql.push_str(",\n\tminv_sfunc=");
            sql.push_str(inverse_moving_state_transition_function);
        }

        if let Some(moving_final_function) = &self.moving_final_function {
            sql.push_str(",\n\tmfinalfunc=");
            sql.push_str(moving_final_function);

            sql.push_str(",\n\tmfinalfunc_modify=");
            sql.push_str(&self.moving_final_modify.to_string());

            if self.moving_final_extra_data {
                sql.push_str(",\n\tmfinalfunc_extra");
            }
        }

        if let Some(moving_transition_type) = &self.moving_transition_type {
            sql.push_str(",\n\tmstype=");
            sql.push_str(moving_transition_type);

            if let Some(moving_transition_space) = &self.moving_transition_space {
                sql.push_str(",\n\tmsspace=");
                sql.push_str(&moving_transition_space.to_string());
            }
        }

        if let Some(moving_initial_value) = &self.moving_initial_value {
            sql.push_str(",\n\tminitcond=");
            sql.push_str(moving_initial_value);
        }

        if let Some(sort_operator) = &self.sort_operator {
            sql.push_str(",\n\tsortop=");
            sql.push_str(sort_operator);
        }

        match self.parallel {
            Parallel::Safe => sql.push_str(",\n\tparallel=safe"),
            Parallel::Restricted => sql.push_str(",\n\tparallel=restricted"),
            Parallel::Unsafe => sql.push_str(",\n\tparallel=unsafe"),
        }

        sql.push_str("\n);");

        sql
    }
}