mod database;
mod format;
mod lint;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
pub use banshee_templater::PlaceholderStyle;
pub use database::DatabaseSettings;
pub use format::{
CasePolicy, CommaStyle, FormatSettings, IdentifierCase, IndentUnit, KeywordCase,
PgFormatSettings, StylePreset,
};
pub use lint::{LintSettings, RuleSetting, SeverityLevel};
/// Name of the configuration file that [`BansheeConfig::find_config_file`] looks for
/// in the starting directory and each of its ancestors.
pub const CONFIG_FILE_NAME: &str = "banshee.toml";
/// Settings for completion suggestions (the `[completion]` table).
///
/// Keys are written in kebab-case (e.g. `max-items`). Any key missing from the
/// file keeps the value produced by [`CompletionSettings::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[serde(default)]
pub struct CompletionSettings {
    /// Keyword-completion toggle; defaults to `true`.
    /// NOTE(review): consumed outside this file — presumably gates keyword suggestions.
    pub keywords: bool,
    /// Function-completion toggle; defaults to `true`.
    /// NOTE(review): consumed outside this file — presumably gates function suggestions.
    pub functions: bool,
    /// Item limit for completion results; defaults to `200`.
    pub max_items: usize,
}

impl Default for CompletionSettings {
    /// Both completion kinds enabled, capped at 200 items.
    fn default() -> Self {
        let max_items = 200;
        CompletionSettings {
            max_items,
            functions: true,
            keywords: true,
        }
    }
}
/// Top-level Banshee configuration, normally read from a [`CONFIG_FILE_NAME`] file.
///
/// Each field maps to a kebab-case TOML table (`[format]`, `[lint]`,
/// `[completion]`, `[database]`, `[templater]`). A table that is absent from
/// the file falls back to that section's `Default` value.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
#[serde(rename_all = "kebab-case")]
pub struct BansheeConfig {
    /// `[format]` — formatter style and options.
    pub format: FormatSettings,
    /// `[lint]` — linter enablement, exclusions and per-rule settings.
    pub lint: LintSettings,
    /// `[completion]` — completion suggestion settings.
    pub completion: CompletionSettings,
    /// `[database]` — database connection settings.
    pub database: DatabaseSettings,
    /// `[templater]` — placeholder/templating settings.
    pub templater: TemplaterSettings,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case", default)]
pub struct TemplaterSettings {
pub style: Option<banshee_templater::PlaceholderStyle>,
}
/// Errors produced while loading or parsing a Banshee configuration.
#[derive(Debug)]
pub enum ConfigError {
    /// The configuration file could not be read.
    Io(std::io::Error),
    /// The configuration text was not valid TOML or did not match the expected schema.
    Parse(toml::de::Error),
}
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigError::Io(e) => write!(f, "failed to read config: {e}"),
ConfigError::Parse(e) => write!(f, "invalid config: {e}"),
}
}
}
impl std::error::Error for ConfigError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
ConfigError::Io(e) => Some(e),
ConfigError::Parse(e) => Some(e),
}
}
}
impl BansheeConfig {
    /// Parses a configuration from TOML source text.
    ///
    /// Tables and keys omitted from `src` keep their default values, so an
    /// empty string yields [`BansheeConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Parse`] if `src` is not valid TOML or does not
    /// match the configuration schema (e.g. an unknown severity name).
    pub fn from_toml(src: &str) -> Result<Self, ConfigError> {
        toml::from_str(src).map_err(ConfigError::Parse)
    }

    /// Reads and parses the configuration file at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Io`] if the file cannot be read, and
    /// [`ConfigError::Parse`] if its contents are not a valid configuration.
    pub fn load(path: &Path) -> Result<Self, ConfigError> {
        let src = std::fs::read_to_string(path).map_err(ConfigError::Io)?;
        Self::from_toml(&src)
    }

    /// Searches for a [`CONFIG_FILE_NAME`] file in the directory of `start` and
    /// then in each of its ancestors, returning the nearest one found.
    ///
    /// If `start` is an existing file, the search begins in its parent
    /// directory. Otherwise `start` itself is treated as the starting directory.
    ///
    /// The starting directory is made absolute against the current working
    /// directory and normalized lexically before the walk. Relative inputs such
    /// as `query.sql` or `../queries/q.sql` therefore visit their real ancestors
    /// up to the filesystem root, instead of stopping at (or wrongly checking)
    /// the current directory.
    ///
    /// Returns `None` when no ancestor contains the file.
    #[must_use]
    pub fn find_config_file(start: &Path) -> Option<PathBuf> {
        let base = if start.is_file() { start.parent()? } else { start };
        let mut dir = Self::absolute_lexical(base);
        loop {
            let candidate = dir.join(CONFIG_FILE_NAME);
            if candidate.is_file() {
                return Some(candidate);
            }
            // `pop` fails once we are at the root (or an empty relative path).
            if !dir.pop() {
                return None;
            }
        }
    }

    /// Makes `path` absolute relative to the current directory and resolves
    /// `.` and `..` components lexically, without touching the filesystem.
    ///
    /// After this, repeatedly calling `PathBuf::pop` on the result visits only
    /// genuine ancestors. Symlinks are not resolved. If the current directory
    /// cannot be determined, a relative `path` stays relative, and leading `..`
    /// components that cannot be cancelled are kept.
    fn absolute_lexical(path: &Path) -> PathBuf {
        use std::path::Component;

        let mut out = if path.is_absolute() {
            PathBuf::new()
        } else {
            std::env::current_dir().unwrap_or_default()
        };
        for component in path.components() {
            match component {
                Component::CurDir => {}
                Component::ParentDir => match out.components().next_back() {
                    // `a/b/..` collapses to `a`.
                    Some(Component::Normal(_)) => {
                        out.pop();
                    }
                    // The parent of the root is the root itself.
                    Some(Component::RootDir) => {}
                    // Nothing to cancel (empty path, prefix, or an earlier `..`).
                    _ => out.push(".."),
                },
                other => out.push(other.as_os_str()),
            }
        }
        out
    }

    /// Loads the configuration that applies to `start`, or the defaults if none exists.
    ///
    /// Uses [`BansheeConfig::find_config_file`] to locate the nearest config file.
    /// A missing file is not an error.
    ///
    /// # Errors
    ///
    /// Propagates the [`ConfigError`] from [`BansheeConfig::load`] when a config
    /// file is found but cannot be read or parsed.
    pub fn discover(start: &Path) -> Result<Self, ConfigError> {
        match Self::find_config_file(start) {
            Some(path) => Self::load(&path),
            None => Ok(Self::default()),
        }
    }

    /// Builds the formatter configuration derived from the `[format]` settings.
    ///
    /// This always uses `to_format_config`, even when the style preset is
    /// `Pgformatter`. [`BansheeConfig::format`] instead uses `to_pg_config` for
    /// that preset.
    #[must_use]
    pub fn format_config(&self) -> banshee_format::FormatConfig {
        self.format.to_format_config()
    }

    /// Formats `source` according to the `[format]` settings.
    ///
    /// The `Pgformatter` preset is routed to the pgFormatter-style engine. Every
    /// other preset uses the default formatter.
    #[must_use]
    pub fn format(&self, source: &str) -> String {
        match self.format.style {
            StylePreset::Pgformatter => {
                banshee_format::pg_format::format(source, &self.format.to_pg_config())
            }
            _ => banshee_format::format(source, &self.format.to_format_config()),
        }
    }

    /// Computes the text edits that turn `source` into its formatted form.
    #[must_use]
    pub fn format_edits(&self, source: &str) -> Vec<banshee_format::TextEdit> {
        banshee_format::diff_edits(source, &self.format(source))
    }

    /// Returns the whole-document formatting edits that touch `[start, end]`.
    ///
    /// The entire document is formatted, then the edits are filtered. An edit is
    /// kept when its range overlaps or merely touches the requested range, since
    /// both comparisons are inclusive. This means insertions exactly at `start`
    /// or `end` are included. `start` and `end` are offsets in the same units as
    /// the edits' ranges.
    #[must_use]
    pub fn format_range(
        &self,
        source: &str,
        start: u32,
        end: u32,
    ) -> Vec<banshee_format::TextEdit> {
        self.format_edits(source)
            .into_iter()
            .filter(|e| u32::from(e.range.start()) <= end && start <= u32::from(e.range.end()))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh, empty scratch directory unique to this process and `name`.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir()
            .join(format!("banshee-config-{}-{name}", std::process::id()));
        // Best-effort cleanup of leftovers from an earlier aborted run.
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn empty_toml_is_all_defaults() {
        let cfg = BansheeConfig::from_toml("").unwrap();
        assert!(cfg.lint.enabled);
        assert!(!cfg.database.is_configured());
        assert_eq!(cfg.completion.max_items, 200);
    }

    #[test]
    fn parses_full_config() {
        let src = r#"
[format]
style = "compact"
keyword-case = "lower"
max-width = 100
[lint]
exclude = ["AM01"]
[lint.rules.CP01]
severity = "warning"
[database]
url-env = "DATABASE_URL"
schema = "app"
"#;
        let cfg = BansheeConfig::from_toml(src).unwrap();
        assert_eq!(cfg.format.style, StylePreset::Compact);
        assert_eq!(cfg.format.max_width, Some(100));
        assert!(!cfg.lint.is_rule_enabled("AM01"));
        assert_eq!(cfg.lint.severity_for("CP01"), Some(SeverityLevel::Warning));
        assert_eq!(cfg.database.schema, "app");
        assert_eq!(cfg.database.url_env.as_deref(), Some("DATABASE_URL"));
    }

    #[test]
    fn format_config_reflects_overrides() {
        let cfg = BansheeConfig::from_toml(
            r#"
[format]
style = "sqlstyle"
keyword-case = "lower"
"#,
        )
        .unwrap();
        let fc = cfg.format_config();
        assert_eq!(fc.keyword_case, banshee_format::KeywordCase::Lower);
    }

    #[test]
    fn rejects_unknown_severity() {
        let result = BansheeConfig::from_toml(
            r#"
[lint.rules.CP01]
severity = "bogus"
"#,
        );
        assert!(matches!(result, Err(ConfigError::Parse(_))));
    }

    #[test]
    fn malformed_toml_reports_invalid_config() {
        let err = BansheeConfig::from_toml("[lint").unwrap_err();
        assert!(matches!(err, ConfigError::Parse(_)));
        assert!(err.to_string().starts_with("invalid config: "));
    }

    #[test]
    fn load_missing_file_is_io_error() {
        let dir = scratch_dir("missing");
        let result = BansheeConfig::load(&dir.join(CONFIG_FILE_NAME));
        assert!(matches!(result, Err(ConfigError::Io(_))));
        std::fs::remove_dir_all(&dir).unwrap();
    }

    #[test]
    fn finds_config_in_ancestor_directory() {
        let root = scratch_dir("ancestor");
        let nested = root.join("a").join("b");
        std::fs::create_dir_all(&nested).unwrap();
        let config = root.join(CONFIG_FILE_NAME);
        std::fs::write(&config, "[completion]\nmax-items = 7\n").unwrap();
        let query = nested.join("q.sql");
        std::fs::write(&query, "select 1").unwrap();

        // Starting from a directory and from a file in it must find the same config.
        assert_eq!(BansheeConfig::find_config_file(&nested), Some(config.clone()));
        assert_eq!(BansheeConfig::find_config_file(&query), Some(config));
        assert_eq!(BansheeConfig::discover(&query).unwrap().completion.max_items, 7);

        std::fs::remove_dir_all(&root).unwrap();
    }
}