use crate::{apis::coredb_types::CoreDB, controller::patch_cdb_status_merge, defaults, Context, Error};
use kube::api::Api;
use lazy_static::lazy_static;
use regex::Regex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tracing::{debug, error, info, warn};
// Allow-list used by `check_input` to vet extension, database and schema names before
// they are interpolated into SQL in `toggle_extensions`. Equivalent to: starts with an
// ASCII letter, ends with an ASCII letter or digit, contains only ASCII letters, digits,
// `-` and `_`, and is at least two characters long. The pattern is a fixed literal, so
// the `unwrap` only guards against an edit that breaks the regex syntax.
lazy_static! {
    static ref VALID_INPUT: Regex = Regex::new(r"^[a-zA-Z]([a-zA-Z0-9]*[-_]?)*[a-zA-Z0-9]+$").unwrap();
}
// A Postgres extension together with the databases it applies to. Used both for the
// desired state (`CoreDB.spec.extensions`) and for the actual state assembled by
// `get_all_extensions`. `Eq`/`Hash` cover every field, including the order of
// `locations`, and `diff_extensions` relies on that full-value comparison.
// Plain `//` comments are used deliberately: `///` doc comments would be picked up by
// the `JsonSchema` derive and change the generated schema.
#[derive(Clone, Debug, Deserialize, Eq, Hash, JsonSchema, Serialize, PartialEq)]
pub struct Extension {
    // Extension name as used by `CREATE EXTENSION` and `trunk install`.
    pub name: String,
    // Filled from `defaults::default_description` when absent during deserialization.
    #[serde(default = "defaults::default_description")]
    pub description: String,
    // One entry per database the extension applies to.
    pub locations: Vec<ExtensionInstallLocation>,
}
// Default extension: `pg_stat_statements` with a single default install location.
impl Default for Extension {
    fn default() -> Self {
        Extension {
            name: "pg_stat_statements".to_owned(),
            // No leading whitespace: `description` takes part in `Extension` equality, so a
            // stray space would keep this from ever matching the description reported by
            // `LIST_EXTENSIONS_QUERY` (intended to mirror the extension's catalog comment).
            description: "track planning and execution statistics of all SQL statements executed".to_owned(),
            locations: vec![ExtensionInstallLocation::default()],
        }
    }
}
// Where (and whether) an extension should be, or is, installed within one database.
// Plain `//` comments are used deliberately so the `JsonSchema`-generated schema is unchanged.
#[derive(Clone, Debug, Deserialize, Eq, Hash, JsonSchema, Serialize, PartialEq)]
pub struct ExtensionInstallLocation {
    // Desired state: `true` creates the extension, `false` drops it (see `toggle_extensions`).
    // Actual state: whether the extension is installed in `database`.
    pub enabled: bool,
    // Filled from `defaults::default_database` when absent during deserialization.
    #[serde(default = "defaults::default_database")]
    pub database: String,
    // Schema used in `CREATE EXTENSION ... SCHEMA`; filled from `defaults::default_schema`
    // when absent during deserialization.
    #[serde(default = "defaults::default_schema")]
    pub schema: String,
    // Version passed to `trunk install --version`. For discovered state this is the
    // installed version, or the default available version when not installed.
    pub version: Option<String>,
}
impl Default for ExtensionInstallLocation {
fn default() -> Self {
ExtensionInstallLocation {
schema: "public".to_owned(),
database: "postgres".to_owned(),
enabled: true,
version: Some("1.9".to_owned()),
}
}
}
/// One parsed row of `LIST_EXTENSIONS_QUERY` output for a single database
/// (produced by `parse_extensions`).
#[derive(Debug)]
pub struct ExtRow {
    /// Extension name.
    pub name: String,
    /// The extension's comment from `pg_available_extensions`.
    pub description: String,
    /// Installed version when `enabled`, otherwise the available default version.
    pub version: String,
    /// `true` when the extension is installed in the database (psql prints `t`).
    pub enabled: bool,
    /// Installed schema, or the `'public'` placeholder when not installed.
    pub schema: String,
}
// Lists every non-template database on the server, one `datname` per row.
const LIST_DATABASES_QUERY: &str = r#"SELECT datname FROM pg_database WHERE datistemplate = false;"#;
// Lists every extension available to the connected database, one row per name, with
// columns (name, version, enabled, schema, description) — the order `parse_extensions`
// expects.
// - `installed`: rows from `pg_extension` joined to their namespace and to
//   `pg_available_extensions` (for the comment), with the installed version/schema and
//   `enabled = true`. Installed extensions absent from `pg_available_extensions` are
//   dropped by that join.
// - second half of the UNION: every available extension with its `default_version`, a
//   placeholder schema of 'public' and `enabled = false`.
// `DISTINCT ON (name)` with `ORDER BY name ASC, enabled DESC` keeps the installed row
// whenever an extension appears in both halves.
const LIST_EXTENSIONS_QUERY: &str = r#"select
distinct on
(name) *
from
(
select
name,
version,
enabled,
schema,
description
from
(
select
t0.extname as name,
t0.extversion as version,
true as enabled,
t1.nspname as schema,
comment as description
from
(
select
extnamespace,
extname,
extversion
from
pg_extension
) t0,
(
select
oid,
nspname
from
pg_namespace
) t1,
(
select
name,
comment
from
pg_catalog.pg_available_extensions
) t2
where
t1.oid = t0.extnamespace
and t2.name = t0.extname
) installed
union
select
name,
default_version as version,
false as enabled,
'public' as schema,
comment as description
from
pg_catalog.pg_available_extensions
order by
enabled asc
) combined
order by
name asc,
enabled desc
"#;
pub async fn install_extension(
cdb: &CoreDB,
extensions: &[Extension],
ctx: Arc<Context>,
) -> Result<(), Error> {
debug!("extensions to install: {:?}", extensions);
let client = ctx.client.clone();
let pod_name = cdb
.primary_pod(client.clone())
.await
.unwrap()
.metadata
.name
.unwrap();
let mut errors: Vec<Error> = Vec::new();
let num_to_install = extensions.len();
for ext in extensions.iter() {
let version = ext.locations[0].version.clone().unwrap();
let cmd = vec![
"trunk".to_owned(),
"install".to_owned(),
"-r https://registry.pgtrunk.io".to_owned(),
ext.name.clone(),
"--version".to_owned(),
version,
];
let result = cdb.exec(pod_name.clone(), client.clone(), &cmd).await;
match result {
Ok(result) => {
debug!("installed extension: {}", result.stdout.clone().unwrap());
}
Err(err) => {
error!("error installing extension, {}", err);
errors.push(err);
}
}
}
let num_success = num_to_install - errors.len();
info!(
"Successfully installed {} / {} extensions",
num_success, num_to_install
);
Ok(())
}
pub async fn toggle_extensions(
cdb: &CoreDB,
extensions: &[Extension],
ctx: Arc<Context>,
) -> Result<(), Error> {
let client = ctx.client.clone();
for ext in extensions {
let ext_name = ext.name.as_str();
if !check_input(ext_name) {
warn!(
"Extension {} is not formatted properly. Skipping operation.",
ext_name
)
} else {
for ext_loc in ext.locations.iter() {
let database_name = ext_loc.database.to_owned();
if !check_input(&database_name) {
warn!(
"Extension.Database {}.{} is not formatted properly. Skipping operation.",
ext_name, database_name
);
continue;
}
let command = match ext_loc.enabled {
true => {
info!("Creating extension: {}, database {}", ext_name, database_name);
let schema_name = ext_loc.schema.to_owned();
if !check_input(&schema_name) {
warn!(
"Extension.Database.Schema {}.{}.{} is not formatted properly. Skipping operation.",
ext_name, database_name, schema_name
);
continue;
}
format!("CREATE EXTENSION IF NOT EXISTS \"{ext_name}\" SCHEMA {schema_name} cascade;")
}
false => {
info!("Dropping extension: {}, database {}", ext_name, database_name);
format!("DROP EXTENSION IF EXISTS \"{ext_name}\" CASCADE;")
}
};
let result = cdb
.psql(command.clone(), database_name.clone(), client.clone())
.await;
match result {
Ok(result) => {
debug!("Result: {}", result.stdout.clone().unwrap());
}
Err(err) => {
error!("error managing extension");
return Err(err.into());
}
}
}
}
}
Ok(())
}
pub fn check_input(input: &str) -> bool {
VALID_INPUT.is_match(input)
}
/// Returns the names of all non-template databases, queried via psql on `postgres`.
///
/// # Errors
/// Propagates psql execution errors.
///
/// # Panics
/// Panics if psql produced no stdout.
pub async fn list_databases(cdb: &CoreDB, ctx: Arc<Context>) -> Result<Vec<String>, Error> {
    let query = LIST_DATABASES_QUERY.to_owned();
    let database = "postgres".to_owned();
    let output = cdb.psql(query, database, ctx.client.clone()).await?;
    let stdout = output.stdout.unwrap();
    Ok(parse_databases(&stdout))
}
/// Parses aligned psql output of `LIST_DATABASES_QUERY` into database names.
///
/// The first two lines (column header and dashed separator) are skipped, as are blank
/// lines and the `(N row)` / `(N rows)` footer. Each remaining line is taken whole
/// (trimmed). The query returns a single column, so splitting on `|` would only
/// truncate a database name that happens to contain that character.
fn parse_databases(psql_str: &str) -> Vec<String> {
    let databases: Vec<String> = psql_str
        .lines()
        .skip(2)
        .map(str::trim)
        .filter(|line| {
            let is_footer =
                line.starts_with('(') && (line.ends_with(" row)") || line.ends_with(" rows)"));
            if line.is_empty() || is_footer {
                debug!("Done:{:?}", line);
                false
            } else {
                true
            }
        })
        .map(str::to_owned)
        .collect();
    info!("Found {} databases", databases.len());
    databases
}
/// Lists every extension available in `database` together with its install state.
///
/// # Errors
/// Propagates psql execution errors (matching `list_databases`) instead of panicking
/// inside the reconciler.
///
/// # Panics
/// Panics if psql produced no stdout.
pub async fn list_extensions(cdb: &CoreDB, ctx: Arc<Context>, database: &str) -> Result<Vec<ExtRow>, Error> {
    let client = ctx.client.clone();
    let psql_out = cdb
        .psql(
            LIST_EXTENSIONS_QUERY.to_owned(),
            database.to_owned(),
            client,
        )
        .await?;
    let result_string = psql_out.stdout.unwrap();
    Ok(parse_extensions(&result_string))
}
/// Parses aligned psql output of `LIST_EXTENSIONS_QUERY` into [`ExtRow`]s.
///
/// The first two lines (column header and separator) are skipped. Any line with fewer
/// than five `|`-separated fields, such as the `(N rows)` footer or a blank line, is
/// ignored. `enabled` is `true` only for psql's `t`.
fn parse_extensions(psql_str: &str) -> Vec<ExtRow> {
    let mut extensions = vec![];
    for line in psql_str.lines().skip(2) {
        // At most five fields: a `|` inside the free-text description (the last column)
        // stays part of the description instead of silently truncating it.
        let fields: Vec<&str> = line.splitn(5, '|').map(str::trim).collect();
        if fields.len() < 5 {
            debug!("Done:{:?}", fields);
            continue;
        }
        extensions.push(ExtRow {
            name: fields[0].to_owned(),
            version: fields[1].to_owned(),
            enabled: fields[2] == "t",
            schema: fields[3].to_owned(),
            description: fields[4].to_owned(),
        });
    }
    debug!("Found {} extensions", extensions.len());
    extensions
}
/// Collects the actual extension state across every non-template database.
///
/// Rows are grouped by `(name, description)`. Each group becomes one `Extension` whose
/// `locations` holds one entry per database, in the order databases were listed. The
/// result is sorted by extension name.
///
/// # Errors
/// Propagates errors from listing databases or extensions.
pub async fn get_all_extensions(cdb: &CoreDB, ctx: Arc<Context>) -> Result<Vec<Extension>, Error> {
    let databases = list_databases(cdb, ctx.clone()).await?;
    debug!("databases: {:?}", databases);
    let mut grouped: HashMap<(String, String), Vec<ExtensionInstallLocation>> = HashMap::new();
    for database in databases {
        for row in list_extensions(cdb, ctx.clone(), &database).await? {
            let location = ExtensionInstallLocation {
                enabled: row.enabled,
                database: database.clone(),
                schema: row.schema,
                version: Some(row.version),
            };
            grouped
                .entry((row.name, row.description))
                .or_default()
                .push(location);
        }
    }
    let mut found: Vec<Extension> = grouped
        .into_iter()
        .map(|((name, description), locations)| Extension {
            name,
            description,
            locations,
        })
        .collect();
    found.sort_by(|a, b| a.name.cmp(&b.name));
    Ok(found)
}
/// Returns the desired extensions that do not exactly equal any actual extension.
///
/// Comparison is full `Extension` equality (name, description and every location field,
/// including location order). Duplicate desired entries are reported once, and the
/// result is sorted by name.
fn diff_extensions(desired: &[Extension], actual: &[Extension]) -> Vec<Extension> {
    let present: HashSet<&Extension> = actual.iter().collect();
    let mut seen: HashSet<&Extension> = HashSet::new();
    let mut missing: Vec<Extension> = desired
        .iter()
        .filter(|ext| !present.contains(*ext) && seen.insert(*ext))
        .cloned()
        .collect();
    missing.sort_by(|a, b| a.name.cmp(&b.name));
    debug!("Extensions diff: {:?}", missing);
    missing
}
/// Splits changed desired extensions into work items.
///
/// Returns `(changed, to_install)`:
/// - `changed`: extensions found in `actual` (by name) where at least one desired
///   location shares a database with an actual location but differs in `enabled`.
///   These need a create/drop.
/// - `to_install`: extensions whose name does not appear in `actual` at all.
fn extension_plan(have_changed: &[Extension], actual: &[Extension]) -> (Vec<Extension>, Vec<Extension>) {
    let mut changed = Vec::new();
    let mut to_install = Vec::new();
    for wanted in have_changed {
        let mut same_name = actual
            .iter()
            .filter(|existing| existing.name == wanted.name)
            .peekable();
        if same_name.peek().is_none() {
            to_install.push(wanted.clone());
            continue;
        }
        for existing in same_name {
            let needs_toggle = wanted.locations.iter().any(|want| {
                existing
                    .locations
                    .iter()
                    .any(|have| have.database == want.database && have.enabled != want.enabled)
            });
            if needs_toggle {
                debug!("desired: {:?}, actual: {:?}", wanted, existing);
                changed.push(wanted.clone());
            }
        }
    }
    debug!(
        "extension to create/drop: {:?}, extensions to install: {:?}",
        changed, to_install
    );
    (changed, to_install)
}
pub async fn reconcile_extensions(
coredb: &CoreDB,
ctx: Arc<Context>,
cdb_api: &Api<CoreDB>,
name: &str,
) -> Result<Vec<Extension>, Error> {
let actual_extensions = get_all_extensions(coredb, ctx.clone()).await?;
let mut desired_extensions = coredb.spec.extensions.clone();
desired_extensions.sort_by_key(|e| e.name.clone());
let extensions_changed = diff_extensions(&desired_extensions, &actual_extensions);
if extensions_changed.is_empty() {
return Ok(actual_extensions);
}
let (changed_extensions, extensions_to_install) = extension_plan(&extensions_changed, &actual_extensions);
if !changed_extensions.is_empty() || !extensions_to_install.is_empty() {
let status = serde_json::json!({
"status": {"extensionsUpdating": true}
});
let _ = patch_cdb_status_merge(cdb_api, name, status).await;
if !changed_extensions.is_empty() {
toggle_extensions(coredb, &changed_extensions, ctx.clone()).await?;
}
if !extensions_to_install.is_empty() {
install_extension(coredb, &extensions_to_install, ctx.clone()).await?;
}
let status = serde_json::json!({
"status": {"extensionsUpdating": false}
});
let _ = patch_cdb_status_merge(cdb_api, name, status).await;
}
get_all_extensions(coredb, ctx.clone()).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test extension with one location in the `public` schema of the
    /// `postgres` database at version `1.1.1`; only the name and enabled flag vary.
    fn make_extension(name: &str, enabled: bool) -> Extension {
        Extension {
            name: name.to_owned(),
            description: "my description".to_owned(),
            locations: vec![ExtensionInstallLocation {
                enabled,
                database: "postgres".to_owned(),
                schema: "public".to_owned(),
                version: Some("1.1.1".to_owned()),
            }],
        }
    }

    #[test]
    fn test_extension_plan() {
        let postgis_disabled = make_extension("postgis", false);
        let pgmq_disabled = make_extension("pgmq", false);

        // pgmq is not present at all -> it must be installed.
        let diff = vec![pgmq_disabled.clone()];
        let actual = vec![postgis_disabled];
        let (changed, to_install) = extension_plan(&diff, &actual);
        assert!(changed.is_empty());
        assert_eq!(to_install.len(), 1);

        // pgmq is present with the same enabled state -> nothing to do.
        let diff = vec![pgmq_disabled.clone()];
        let actual = vec![pgmq_disabled];
        let (changed, to_install) = extension_plan(&diff, &actual);
        assert!(changed.is_empty());
        assert!(to_install.is_empty());
    }

    #[test]
    fn test_diff_and_plan() {
        let postgis_disabled = make_extension("postgis", false);
        let postgis_enabled = make_extension("postgis", true);
        let pgmq_disabled = make_extension("pgmq", false);
        let pg_stat_enabled = make_extension("pg_stat_statements", true);

        let desired = vec![
            postgis_disabled.clone(),
            pgmq_disabled.clone(),
            pg_stat_enabled.clone(),
        ];
        let actual = vec![postgis_enabled, pgmq_disabled];
        let diff = diff_extensions(&desired, &actual);
        assert_eq!(
            diff.len(),
            2,
            "expected two changed extensions, found extensions {:?}",
            diff
        );
        assert_eq!(diff[0], pg_stat_enabled, "expected pg_stat, found {:?}", diff[0]);
        assert_eq!(diff[1], postgis_disabled, "expected postgis, found {:?}", diff[1]);

        let (changed, to_install) = extension_plan(&diff, &actual);
        assert_eq!(changed.len(), 1);
        assert_eq!(
            changed[0], postgis_disabled,
            "expected postgis changed to disabled, found {:?}",
            changed[0]
        );
        assert_eq!(to_install.len(), 1, "expected 1 install, found {:?}", to_install);
        assert_eq!(
            to_install[0], pg_stat_enabled,
            "expected pg_stat to install, found {:?}",
            to_install[0]
        );
    }

    #[test]
    fn test_diff() {
        let postgis_disabled = make_extension("postgis", false);
        let pgmq_enabled = make_extension("pgmq", true);
        let pgmq_disabled = make_extension("pgmq", false);

        let desired: Vec<Extension> = vec![];
        let actual = vec![postgis_disabled.clone(), pgmq_enabled.clone()];
        let diff = diff_extensions(&desired, &actual);
        assert!(diff.is_empty());

        let desired = vec![postgis_disabled.clone(), pgmq_enabled.clone()];
        let actual = vec![postgis_disabled.clone(), pgmq_disabled.clone()];
        let diff = diff_extensions(&desired, &actual);
        assert_eq!(diff.len(), 1);
        assert_eq!(diff[0], pgmq_enabled);

        let desired = vec![pgmq_enabled.clone(), postgis_disabled.clone()];
        let actual = vec![postgis_disabled.clone(), pgmq_disabled.clone()];
        let diff = diff_extensions(&desired, &actual);
        assert_eq!(diff.len(), 1);
        assert_eq!(diff[0], pgmq_enabled);

        let desired = vec![postgis_disabled.clone(), pgmq_enabled.clone()];
        let actual = vec![postgis_disabled.clone(), pgmq_disabled];
        let diff = diff_extensions(&desired, &actual);
        assert_eq!(diff.len(), 1);
        assert_eq!(diff[0], pgmq_enabled);

        let desired = vec![postgis_disabled.clone(), pgmq_enabled.clone()];
        let actual = vec![postgis_disabled.clone(), pgmq_enabled.clone()];
        let diff = diff_extensions(&desired, &actual);
        assert!(diff.is_empty());

        let desired = vec![postgis_disabled.clone()];
        let actual = vec![postgis_disabled, pgmq_enabled];
        let diff = diff_extensions(&desired, &actual);
        assert!(diff.is_empty());
    }

    #[test]
    fn test_parse_databases() {
        let three_db = " datname
----------
postgres
cat
dog
(3 rows)
";
        let rows = parse_databases(three_db);
        println!("{:?}", rows);
        assert_eq!(rows, vec!["postgres", "cat", "dog"]);

        let one_db = " datname
----------
postgres
(1 row)
";
        let rows = parse_databases(one_db);
        println!("{:?}", rows);
        assert_eq!(rows, vec!["postgres"]);
    }

    #[test]
    fn test_parse_extensions() {
        let ext_psql = " name | version | enabled | schema | description
--------------------+---------+---------+------------+------------------------------------------------------------------------
adminpack | 2.1 | f | public | administrative functions for PostgreSQL
amcheck | 1.3 | f | public | functions for verifying relation integrity
autoinc | 1.0 | f | public | functions for autoincrementing fields
bloom | 1.0 | f | public | bloom access method - signature file based index
btree_gin | 1.3 | f | public | support for indexing common datatypes in GIN
btree_gist | 1.7 | f | public | support for indexing common datatypes in GiST
citext | 1.6 | f | public | data type for case-insensitive character strings
cube | 1.5 | f | public | data type for multidimensional cubes
dblink | 1.2 | f | public | connect to other PostgreSQL databases from within a database
(9 rows)";
        let ext = parse_extensions(ext_psql);
        assert_eq!(ext.len(), 9);

        assert_eq!(ext[0].name, "adminpack");
        assert!(!ext[0].enabled);
        assert_eq!(ext[0].version, "2.1");
        assert_eq!(ext[0].schema, "public");
        assert_eq!(ext[0].description, "administrative functions for PostgreSQL");

        assert_eq!(ext[8].name, "dblink");
        assert!(!ext[8].enabled);
        assert_eq!(ext[8].version, "1.2");
        assert_eq!(ext[8].schema, "public");
        assert_eq!(
            ext[8].description,
            "connect to other PostgreSQL databases from within a database"
        );
    }

    #[test]
    fn test_check_input() {
        let invalids = ["extension--", "data;", "invalid^#$$characters", ";invalid", ""];
        for input in invalids {
            assert!(!check_input(input), "input {} should be invalid", input);
        }
        let valids = [
            "extension_a",
            "schema_abc",
            "extension",
            "NewExtension",
            "NewExtension123",
            "postgis_tiger_geocoder-3",
            "address_standardizer-3",
            "xml2",
        ];
        for input in valids {
            assert!(check_input(input), "input {} should be valid", input);
        }
    }
}