use crate::lint::LintSettings;
use anyhow::{Context, Result, anyhow};
use clap::ValueEnum;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::path::Path;
/// Global options read from the `[settings]` table of the config file.
///
/// Every field may be omitted from the file: `input` falls back to `None` (it is
/// an `Option`), and the remaining fields are `#[serde(default)]`, so they take
/// their type's `Default` value (`None`, an empty collection, or
/// `LintSettings::default()`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Settings {
// NOTE(review): presumably the default schema input for targets that set no
// `input` of their own — confirm against callers.
pub input: Option<String>,
/// Variable files listed at the settings level.
// How these combine with per-target `var_files` is decided outside this module.
#[serde(default)]
pub var_files: Vec<String>,
/// Name → value string map.
// NOTE(review): presumably environment variables to expose — confirm usage.
#[serde(default)]
pub env: HashMap<String, String>,
// NOTE(review): the names suggest the backend used when running tests — confirm.
#[serde(default)]
pub test_backend: Option<String>,
// NOTE(review): presumably the connection string used when running tests — confirm.
#[serde(default)]
pub test_dsn: Option<String>,
/// Lint configuration; see `LintSettings`.
#[serde(default)]
pub lint: LintSettings,
}
/// Configuration for a single target (one entry of the `targets` array).
///
/// `name` and `backend` are required. Any key in the target table that does not
/// match one of the named fields below is collected into `options`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TargetConfig {
    /// Target name.
    pub name: String,
    /// Backend identifier, e.g. `"postgres"` or `"prisma"`.
    pub backend: String,
    /// Optional input for this target.
    // NOTE(review): presumably overrides `Settings::input` — confirm against callers.
    pub input: Option<String>,
    /// Optional output for this target.
    pub output: Option<String>,
    /// Resource kind names to include. An empty list means every kind
    /// (see [`TargetConfig::get_include_set`]).
    #[serde(default)]
    pub include: Vec<String>,
    /// Resource kind names to exclude (see [`TargetConfig::get_exclude_set`]).
    #[serde(default)]
    pub exclude: Vec<String>,
    /// Variables for this target, kept as raw TOML values keyed by name.
    #[serde(default)]
    pub vars: HashMap<String, toml::Value>,
    /// Variable files for this target.
    #[serde(default)]
    pub var_files: Vec<String>,
    /// Every remaining key of the target table, captured via `#[serde(flatten)]`.
    // Presumably backend-specific options interpreted by the selected backend.
    #[serde(flatten)]
    pub options: HashMap<String, toml::Value>,
}
/// Top-level structure of the config file (`dbschema.toml` when loaded via `load_config`).
///
/// `settings` may be omitted and then takes `Settings::default()`. `targets` has
/// no serde default, so a file without a `targets` key fails to deserialize.
// NOTE(review): add `#[serde(default)]` to `targets` if a config without targets
// should be accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Global settings (the `[settings]` table).
#[serde(default)]
pub settings: Settings,
/// One entry per target (usually written as `[[targets]]` tables).
pub targets: Vec<TargetConfig>,
}
/// A category of schema resource that a target can include or exclude.
///
/// In config files these are written in snake_case (`tables`, `event_triggers`, ...).
/// The `Display` and `FromStr` impls hold the name tables and must stay in sync
/// with each other and with [`ResourceKind::ALL`] when a variant is added.
// NOTE(review): clap's `ValueEnum` derive names CLI values in kebab-case by default
// (e.g. `event-triggers`), which differs from the snake_case names produced by
// `Display` — confirm the two spellings are intended to differ.
// Plain `//` comments are used on the variants below on purpose: `///` doc comments
// on `ValueEnum` variants would become clap help text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, ValueEnum)]
pub enum ResourceKind {
Schemas,
Enums,
Domains,
Types,
Tables,
Views,
// Config name: "materialized".
Materialized,
Aggregates,
Operators,
Functions,
Procedures,
Triggers,
Rules,
// Config name: "event_triggers".
EventTriggers,
Extensions,
Collations,
Sequences,
Indexes,
Statistics,
Policies,
Roles,
Tablespaces,
Grants,
// Foreign-data kinds; config names "foreign_data_wrappers", "foreign_servers", "foreign_tables".
ForeignDataWrappers,
ForeignServers,
ForeignTables,
// Text-search kinds; config names "text_search_dictionaries", "text_search_configurations",
// "text_search_templates", "text_search_parsers".
TextSearchDictionaries,
TextSearchConfigurations,
TextSearchTemplates,
TextSearchParsers,
Publications,
Subscriptions,
Tests,
}
impl ResourceKind {
pub const ALL: [ResourceKind; 33] = [
ResourceKind::Schemas,
ResourceKind::Enums,
ResourceKind::Domains,
ResourceKind::Types,
ResourceKind::Tables,
ResourceKind::Views,
ResourceKind::Materialized,
ResourceKind::Aggregates,
ResourceKind::Operators,
ResourceKind::Functions,
ResourceKind::Procedures,
ResourceKind::Triggers,
ResourceKind::Rules,
ResourceKind::EventTriggers,
ResourceKind::Extensions,
ResourceKind::Collations,
ResourceKind::Sequences,
ResourceKind::Indexes,
ResourceKind::Statistics,
ResourceKind::Policies,
ResourceKind::Roles,
ResourceKind::Tablespaces,
ResourceKind::Grants,
ResourceKind::ForeignDataWrappers,
ResourceKind::ForeignServers,
ResourceKind::ForeignTables,
ResourceKind::TextSearchDictionaries,
ResourceKind::TextSearchConfigurations,
ResourceKind::TextSearchTemplates,
ResourceKind::TextSearchParsers,
ResourceKind::Publications,
ResourceKind::Subscriptions,
ResourceKind::Tests,
];
pub fn default_include_set() -> HashSet<ResourceKind> {
Self::ALL.iter().copied().collect()
}
}
impl fmt::Display for ResourceKind {
    /// Writes the kind's config-file name in snake_case (e.g. `event_triggers`).
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write!`, so width, alignment and
    /// precision flags such as `{:<24}` are honoured, just as they are for `str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            ResourceKind::Schemas => "schemas",
            ResourceKind::Enums => "enums",
            ResourceKind::Domains => "domains",
            ResourceKind::Types => "types",
            ResourceKind::Tables => "tables",
            ResourceKind::Views => "views",
            ResourceKind::Materialized => "materialized",
            ResourceKind::Aggregates => "aggregates",
            ResourceKind::Operators => "operators",
            ResourceKind::Functions => "functions",
            ResourceKind::Procedures => "procedures",
            ResourceKind::Triggers => "triggers",
            ResourceKind::Rules => "rules",
            ResourceKind::EventTriggers => "event_triggers",
            ResourceKind::Extensions => "extensions",
            ResourceKind::Collations => "collations",
            ResourceKind::Sequences => "sequences",
            ResourceKind::Indexes => "indexes",
            ResourceKind::Statistics => "statistics",
            ResourceKind::Policies => "policies",
            ResourceKind::Roles => "roles",
            ResourceKind::Tablespaces => "tablespaces",
            ResourceKind::Grants => "grants",
            ResourceKind::ForeignDataWrappers => "foreign_data_wrappers",
            ResourceKind::ForeignServers => "foreign_servers",
            ResourceKind::ForeignTables => "foreign_tables",
            ResourceKind::TextSearchDictionaries => "text_search_dictionaries",
            ResourceKind::TextSearchConfigurations => "text_search_configurations",
            ResourceKind::TextSearchTemplates => "text_search_templates",
            ResourceKind::TextSearchParsers => "text_search_parsers",
            ResourceKind::Publications => "publications",
            ResourceKind::Subscriptions => "subscriptions",
            ResourceKind::Tests => "tests",
        };
        f.pad(name)
    }
}
impl std::str::FromStr for ResourceKind {
    type Err = String;

    /// Parses a resource kind name such as `"tables"` or `"event_triggers"`.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace. `-` is
    /// treated like `_`, so the kebab-case spelling that clap's `ValueEnum` derive
    /// uses by default (e.g. `"event-triggers"`) is accepted as well.
    ///
    /// # Errors
    ///
    /// Returns `"invalid resource kind: <input>"`, quoting the original input,
    /// when the name is not recognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.trim().to_lowercase().replace('-', "_");
        match normalized.as_str() {
            "schemas" => Ok(ResourceKind::Schemas),
            "enums" => Ok(ResourceKind::Enums),
            "domains" => Ok(ResourceKind::Domains),
            "types" => Ok(ResourceKind::Types),
            "tables" => Ok(ResourceKind::Tables),
            "views" => Ok(ResourceKind::Views),
            "materialized" => Ok(ResourceKind::Materialized),
            "aggregates" => Ok(ResourceKind::Aggregates),
            "operators" => Ok(ResourceKind::Operators),
            "functions" => Ok(ResourceKind::Functions),
            "procedures" => Ok(ResourceKind::Procedures),
            "triggers" => Ok(ResourceKind::Triggers),
            "rules" => Ok(ResourceKind::Rules),
            "event_triggers" => Ok(ResourceKind::EventTriggers),
            "extensions" => Ok(ResourceKind::Extensions),
            "collations" => Ok(ResourceKind::Collations),
            "sequences" => Ok(ResourceKind::Sequences),
            "indexes" => Ok(ResourceKind::Indexes),
            "statistics" => Ok(ResourceKind::Statistics),
            "policies" => Ok(ResourceKind::Policies),
            "roles" => Ok(ResourceKind::Roles),
            "tablespaces" => Ok(ResourceKind::Tablespaces),
            "grants" => Ok(ResourceKind::Grants),
            "foreign_data_wrappers" => Ok(ResourceKind::ForeignDataWrappers),
            "foreign_servers" => Ok(ResourceKind::ForeignServers),
            "foreign_tables" => Ok(ResourceKind::ForeignTables),
            "text_search_dictionaries" => Ok(ResourceKind::TextSearchDictionaries),
            "text_search_configurations" => Ok(ResourceKind::TextSearchConfigurations),
            "text_search_templates" => Ok(ResourceKind::TextSearchTemplates),
            "text_search_parsers" => Ok(ResourceKind::TextSearchParsers),
            "publications" => Ok(ResourceKind::Publications),
            "subscriptions" => Ok(ResourceKind::Subscriptions),
            "tests" => Ok(ResourceKind::Tests),
            _ => Err(format!("invalid resource kind: {}", s)),
        }
    }
}
impl TargetConfig {
    /// Resolves `include` into a set of resource kinds.
    ///
    /// An empty `include` list selects every kind ([`ResourceKind::default_include_set`]).
    ///
    /// # Errors
    ///
    /// Fails on the first entry that is not a recognised resource kind name.
    pub fn get_include_set(&self) -> Result<HashSet<ResourceKind>> {
        match self.include.as_slice() {
            [] => Ok(ResourceKind::default_include_set()),
            names => parse_resource_kinds(names),
        }
    }

    /// Resolves `exclude` into a set of resource kinds; an empty list yields an empty set.
    ///
    /// # Errors
    ///
    /// Fails on the first entry that is not a recognised resource kind name.
    pub fn get_exclude_set(&self) -> Result<HashSet<ResourceKind>> {
        let names = self.exclude.as_slice();
        parse_resource_kinds(names)
    }
}
/// Parses every name in `values` into a [`ResourceKind`] and collects them into a set.
///
/// Stops at the first invalid name and returns its `FromStr` error message
/// wrapped in an `anyhow::Error`.
fn parse_resource_kinds(values: &[String]) -> Result<HashSet<ResourceKind>> {
    let mut kinds = HashSet::with_capacity(values.len());
    for value in values {
        let kind: ResourceKind = value.parse().map_err(|msg: String| anyhow!(msg))?;
        kinds.insert(kind);
    }
    Ok(kinds)
}
/// Loads `dbschema.toml` from the current working directory.
///
/// A thin wrapper over [`load_config_from_path`]: it yields `Ok(None)` when the
/// file is absent and an error when it cannot be read or parsed.
pub fn load_config() -> Result<Option<Config>> {
    const DEFAULT_CONFIG_FILE: &str = "dbschema.toml";
    let path = Path::new(DEFAULT_CONFIG_FILE);
    load_config_from_path(path)
}
pub fn load_config_from_path(path: &Path) -> Result<Option<Config>> {
if !path.exists() {
return Ok(None);
}
let content = std::fs::read_to_string(path)?;
let config: Config = toml::from_str(&content)?;
Ok(Some(config))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TargetConfig` named "test" with the given backend and
    /// include/exclude lists; every other field is `None` or empty.
    fn target(backend: &str, include: &[&str], exclude: &[&str]) -> TargetConfig {
        TargetConfig {
            name: "test".to_string(),
            backend: backend.to_string(),
            input: None,
            output: None,
            include: include.iter().map(|s| s.to_string()).collect(),
            exclude: exclude.iter().map(|s| s.to_string()).collect(),
            vars: Default::default(),
            var_files: vec![],
            options: Default::default(),
        }
    }

    #[test]
    fn test_resource_kind_from_str() {
        assert_eq!("tables".parse::<ResourceKind>(), Ok(ResourceKind::Tables));
        assert_eq!("TABLES".parse::<ResourceKind>(), Ok(ResourceKind::Tables));
        assert!("invalid".parse::<ResourceKind>().is_err());
    }

    /// `Display` and `FromStr` keep separate name tables; check they agree for every kind.
    #[test]
    fn test_resource_kind_display_round_trips() {
        for kind in ResourceKind::ALL {
            assert_eq!(
                kind.to_string().parse::<ResourceKind>(),
                Ok(kind),
                "round trip failed for {:?}",
                kind
            );
        }
    }

    #[test]
    fn test_target_config_include_all() {
        let include_set = target("postgres", &[], &[]).get_include_set().unwrap();
        assert!(include_set.contains(&ResourceKind::Tables));
        assert!(include_set.contains(&ResourceKind::Enums));
        assert!(include_set.contains(&ResourceKind::EventTriggers));
        assert!(include_set.contains(&ResourceKind::Aggregates));
        assert!(include_set.contains(&ResourceKind::Collations));
        assert!(include_set.contains(&ResourceKind::Indexes));
        assert!(include_set.contains(&ResourceKind::ForeignDataWrappers));
        assert_eq!(include_set.len(), ResourceKind::ALL.len());
    }

    #[test]
    fn test_target_config_include_specific() {
        let include_set = target("prisma", &["tables", "enums"], &[])
            .get_include_set()
            .unwrap();
        assert!(include_set.contains(&ResourceKind::Tables));
        assert!(include_set.contains(&ResourceKind::Enums));
        assert!(!include_set.contains(&ResourceKind::Functions));
        assert_eq!(include_set.len(), 2);
    }

    #[test]
    fn test_target_config_exclude() {
        let target = target("postgres", &[], &["functions", "triggers"]);
        let include_set = target.get_include_set().unwrap();
        let exclude_set = target.get_exclude_set().unwrap();
        assert!(include_set.contains(&ResourceKind::Tables));
        assert!(!exclude_set.contains(&ResourceKind::Tables));
        assert!(exclude_set.contains(&ResourceKind::Functions));
        assert!(exclude_set.contains(&ResourceKind::Triggers));
        assert!(!exclude_set.contains(&ResourceKind::EventTriggers));
    }

    #[test]
    fn test_target_config_invalid_resource() {
        let err = target("postgres", &["not_a_resource"], &[])
            .get_include_set()
            .unwrap_err();
        assert!(err.to_string().contains("invalid resource kind"));
    }

    /// Pins the on-disk format: `settings` is optional, list fields default to
    /// empty, and unknown target keys land in the flattened `options` map.
    #[test]
    fn test_config_minimal_toml() {
        let config: Config = toml::from_str(
            r#"
[[targets]]
name = "db"
backend = "postgres"
schema = "public"
"#,
        )
        .unwrap();
        assert!(config.settings.var_files.is_empty());
        assert_eq!(config.targets.len(), 1);
        let t = &config.targets[0];
        assert_eq!(t.name, "db");
        assert_eq!(t.backend, "postgres");
        assert!(t.include.is_empty());
        assert_eq!(
            t.options.get("schema").and_then(|v| v.as_str()),
            Some("public")
        );
    }
}