use std::{
collections::{HashMap, HashSet},
env, fmt, fs,
io::{Read, Write},
path::PathBuf,
};
use crate::error::{Error, Result};
use crate::{checks, checks::Severity, context::ContextConfig};
use serde_derive::{Deserialize, Serialize};
use tracing::debug;
/// Settings for an optional LLM integration (stored under the `llm` key of
/// the settings file; absent unless configured).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct LlmConfig {
    /// Provider name; `"anthropic"` when omitted.
    #[serde(default = "default_llm_provider")]
    pub provider: String,
    /// Model identifier; `"claude-sonnet-4-20250514"` when omitted.
    #[serde(default = "default_llm_model")]
    pub model: String,
    /// Optional base URL, `None` when omitted (presumably overrides the
    /// provider's default endpoint — confirm with the LLM client).
    #[serde(default)]
    pub base_url: Option<String>,
    /// Timeout in milliseconds; 5000 when omitted.
    #[serde(default = "default_llm_timeout_ms")]
    pub timeout_ms: u64,
    /// Token limit passed to the model; 512 when omitted.
    #[serde(default = "default_llm_max_tokens")]
    pub max_tokens: u32,
}

/// Serde default for [`LlmConfig::provider`].
fn default_llm_provider() -> String {
    String::from("anthropic")
}

/// Serde default for [`LlmConfig::model`].
fn default_llm_model() -> String {
    String::from("claude-sonnet-4-20250514")
}

/// Serde default for [`LlmConfig::timeout_ms`].
const fn default_llm_timeout_ms() -> u64 {
    5000
}

/// Serde default for [`LlmConfig::max_tokens`].
const fn default_llm_max_tokens() -> u32 {
    512
}

impl Default for LlmConfig {
    /// Produces the same values serde fills in for an empty `llm:` mapping.
    fn default() -> Self {
        let provider = default_llm_provider();
        let model = default_llm_model();
        let timeout_ms = default_llm_timeout_ms();
        let max_tokens = default_llm_max_tokens();
        Self {
            provider,
            model,
            base_url: None,
            timeout_ms,
            max_tokens,
        }
    }
}
/// Behaviour knobs for agent (non-interactive) usage, stored under `agent`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AgentConfig {
    /// Severity threshold, `High` when omitted (presumably matches at or
    /// above it are denied automatically — confirm with the agent code path).
    #[serde(default = "default_auto_deny_severity")]
    pub auto_deny_severity: Severity,
    /// Whether a human must approve; `false` when omitted.
    #[serde(default)]
    pub require_human_approval: bool,
}

/// Serde default for [`AgentConfig::auto_deny_severity`].
const fn default_auto_deny_severity() -> Severity {
    Severity::High
}

impl Default for AgentConfig {
    /// Mirrors the serde defaults of every field.
    fn default() -> Self {
        let auto_deny_severity = default_auto_deny_severity();
        Self {
            auto_deny_severity,
            require_human_approval: bool::default(),
        }
    }
}
/// Per-tool wrapper configuration, stored under `wrappers`.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub struct WrappersConfig {
    /// Wrapper settings keyed by tool name (presumably the wrapped binary's
    /// name — confirm with the wrapper implementation); empty when omitted.
    #[serde(default)]
    pub tools: HashMap<String, WrapperToolConfig>,
}

/// Settings for a single wrapped tool.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct WrapperToolConfig {
    /// Statement delimiter; `";"` when omitted.
    #[serde(default = "default_wrap_delimiter")]
    pub delimiter: String,
    /// Check groups applied to this tool; empty when omitted.
    #[serde(default)]
    pub check_groups: Vec<String>,
}

/// Serde default for [`WrapperToolConfig::delimiter`].
fn default_wrap_delimiter() -> String {
    String::from(";")
}

impl Default for WrapperToolConfig {
    /// Mirrors the serde defaults of every field.
    fn default() -> Self {
        let delimiter = default_wrap_delimiter();
        Self {
            delimiter,
            check_groups: Vec::new(),
        }
    }
}
/// Maps each check [`Severity`] to a [`Challenge`], stored under
/// `severity_escalation`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct SeverityEscalationConfig {
    /// When `false`, `challenge_for_severity` always yields `None`; `true` when omitted.
    #[serde(default = "default_severity_escalation_enabled")]
    pub enabled: bool,
    /// Challenge for critical findings; `Yes` when omitted.
    #[serde(default = "default_severity_critical")]
    pub critical: Challenge,
    /// Challenge for high findings; `Enter` when omitted.
    #[serde(default = "default_severity_high")]
    pub high: Challenge,
    /// Challenge for medium findings; `Math` when omitted.
    #[serde(default = "default_severity_medium")]
    pub medium: Challenge,
    /// Challenge for low findings; `Math` when omitted.
    #[serde(default = "default_severity_low")]
    pub low: Challenge,
    /// Challenge for informational findings; `Math` when omitted.
    #[serde(default = "default_severity_info")]
    pub info: Challenge,
}

/// Serde default for [`SeverityEscalationConfig::enabled`].
const fn default_severity_escalation_enabled() -> bool {
    true
}

/// Serde default for [`SeverityEscalationConfig::critical`].
const fn default_severity_critical() -> Challenge {
    Challenge::Yes
}

/// Serde default for [`SeverityEscalationConfig::high`].
const fn default_severity_high() -> Challenge {
    Challenge::Enter
}

/// Serde default for [`SeverityEscalationConfig::medium`].
const fn default_severity_medium() -> Challenge {
    Challenge::Math
}

/// Serde default for [`SeverityEscalationConfig::low`].
const fn default_severity_low() -> Challenge {
    Challenge::Math
}

/// Serde default for [`SeverityEscalationConfig::info`].
const fn default_severity_info() -> Challenge {
    Challenge::Math
}

impl Default for SeverityEscalationConfig {
    /// Mirrors the serde defaults of every field.
    fn default() -> Self {
        Self {
            enabled: default_severity_escalation_enabled(),
            critical: default_severity_critical(),
            high: default_severity_high(),
            medium: default_severity_medium(),
            low: default_severity_low(),
            info: default_severity_info(),
        }
    }
}

impl SeverityEscalationConfig {
    /// Returns the challenge configured for `severity`, or `None` when
    /// escalation is disabled.
    #[must_use]
    pub const fn challenge_for_severity(&self, severity: Severity) -> Option<Challenge> {
        if self.enabled {
            let challenge = match severity {
                Severity::Critical => self.critical,
                Severity::High => self.high,
                Severity::Medium => self.medium,
                Severity::Low => self.low,
                Severity::Info => self.info,
            };
            Some(challenge)
        } else {
            None
        }
    }
}
/// Name of the settings file inside the configuration folder; `Config::new`
/// joins it onto the folder to build `setting_file_path`.
const DEFAULT_SETTING_FILE_NAME: &str = "settings.yaml";
/// Challenge used when none is configured; returned by `Challenge::default`
/// and used by `Settings::default`.
pub const DEFAULT_CHALLENGE: Challenge = Challenge::Math;
/// Check groups that are enabled when the settings file does not list any.
pub const DEFAULT_ENABLED_GROUPS: [&str; 16] = [
    "aws",
    "azure",
    "base",
    "database",
    "docker",
    "fs",
    "gcp",
    "git",
    "heroku",
    "kubernetes",
    "mongodb",
    "mysql",
    "network",
    "psql",
    "redis",
    "terraform",
];

/// Serde default for `Settings::enabled_groups`: an owned copy of
/// [`DEFAULT_ENABLED_GROUPS`], in the same order.
fn default_enabled_groups() -> Vec<String> {
    DEFAULT_ENABLED_GROUPS
        .iter()
        .copied()
        .map(String::from)
        .collect()
}

/// Serde default for `Settings::audit_enabled`.
const fn default_audit_enabled() -> bool {
    true
}

/// Serde default for `Settings::blast_radius`.
const fn default_blast_radius() -> bool {
    true
}
/// Kind of confirmation challenge (presumably the prompt a user must answer
/// before a risky command proceeds — confirm with the prompt module).
///
/// Variants (de)serialize by name (`Math`, `Enter`, `Yes`). `Challenge::stricter`
/// orders them `Math` < `Enter` < `Yes`.
#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)]
pub enum Challenge {
    Math,
    Enter,
    Yes,
}
/// Location of the on-disk configuration.
#[derive(Debug)]
pub struct Config {
    // Configuration folder; also holds `audit.log` and the `checks` directory.
    pub root_folder: PathBuf,
    // `root_folder/settings.yaml`.
    pub setting_file_path: PathBuf,
}
/// User settings as stored in `settings.yaml`.
///
/// Every field has a serde default, so a sparse (or empty) file is valid.
/// Unknown keys, such as the legacy `includes` list, are ignored because the
/// struct does not set `deny_unknown_fields`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Settings {
    // Default challenge; `Challenge::default()` (= `DEFAULT_CHALLENGE`) when omitted.
    #[serde(default)]
    pub challenge: Challenge,
    // Groups whose checks are eligible; `DEFAULT_ENABLED_GROUPS` when omitted.
    #[serde(default = "default_enabled_groups")]
    pub enabled_groups: Vec<String>,
    // Groups excluded by `get_active_checks` even if listed in `enabled_groups`.
    #[serde(default)]
    pub disabled_groups: Vec<String>,
    // Check ids excluded by `get_active_checks`.
    #[serde(default)]
    pub ignores_patterns_ids: Vec<String>,
    // Check ids not consumed in this file (presumably denied outright — confirm with callers).
    #[serde(default)]
    pub deny_patterns_ids: Vec<String>,
    // Context detection settings; `ContextConfig::default()` when omitted.
    #[serde(default)]
    pub context: ContextConfig,
    // `true` when omitted (presumably gates writes to `Config::audit_log_path` — confirm).
    #[serde(default = "default_audit_enabled")]
    pub audit_enabled: bool,
    // `true` when omitted; consumed outside this file.
    #[serde(default = "default_blast_radius")]
    pub blast_radius: bool,
    // Optional severity floor, `None` when omitted (semantics defined by callers).
    #[serde(default)]
    pub min_severity: Option<Severity>,
    // Agent-mode settings.
    #[serde(default)]
    pub agent: AgentConfig,
    // LLM integration; `None` unless configured.
    #[serde(default)]
    pub llm: Option<LlmConfig>,
    // Per-tool wrapper settings.
    #[serde(default)]
    pub wrappers: WrappersConfig,
    // Severity → challenge mapping.
    #[serde(default)]
    pub severity_escalation: SeverityEscalationConfig,
    // Challenge overrides keyed by group name (presumably combined via `Challenge::stricter` — confirm).
    #[serde(default)]
    pub group_escalation: HashMap<String, Challenge>,
    // Challenge overrides keyed by check id (presumably combined via `Challenge::stricter` — confirm).
    #[serde(default)]
    pub check_escalation: HashMap<String, Challenge>,
}
impl fmt::Display for Challenge {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Math => write!(f, "Math"),
Self::Enter => write!(f, "Enter"),
Self::Yes => write!(f, "Yes"),
}
}
}
impl Default for Challenge {
fn default() -> Self {
DEFAULT_CHALLENGE
}
}
impl Challenge {
pub fn from_string(str: &str) -> Result<Self> {
match str.to_lowercase().as_str() {
"math" => Ok(Self::Math),
"enter" => Ok(Self::Enter),
"yes" => Ok(Self::Yes),
_ => Err(Error::Config("given challenge name not found".into())),
}
}
#[must_use]
pub fn stricter(self, other: Self) -> Self {
let rank = |c: Self| match c {
Self::Math => 0,
Self::Enter => 1,
Self::Yes => 2,
};
if rank(self) >= rank(other) {
self
} else {
other
}
}
}
impl Default for Settings {
fn default() -> Self {
Self {
challenge: DEFAULT_CHALLENGE,
enabled_groups: default_enabled_groups(),
disabled_groups: vec![],
ignores_patterns_ids: vec![],
deny_patterns_ids: vec![],
context: ContextConfig::default(),
audit_enabled: default_audit_enabled(),
blast_radius: default_blast_radius(),
min_severity: None,
agent: AgentConfig::default(),
llm: None,
wrappers: WrappersConfig::default(),
severity_escalation: SeverityEscalationConfig::default(),
group_escalation: HashMap::new(),
check_escalation: HashMap::new(),
}
}
}
impl Config {
pub fn new(path: Option<&str>) -> Result<Self> {
let package_name = env!("CARGO_PKG_NAME");
let config_folder = match path {
Some(p) => PathBuf::from(p),
None => match dirs::config_dir() {
Some(conf_dir) => conf_dir.join(package_name),
None => return Err(Error::Config("could not get directory path".into())),
},
};
let setting_file_path = config_folder.join(DEFAULT_SETTING_FILE_NAME);
let setting_config = Self {
root_folder: config_folder,
setting_file_path,
};
debug!("configuration settings: {setting_config:?}");
Ok(setting_config)
}
#[must_use]
pub fn audit_log_path(&self) -> PathBuf {
self.root_folder.join("audit.log")
}
#[must_use]
pub fn custom_checks_dir(&self) -> PathBuf {
self.root_folder.join("checks")
}
pub fn get_settings_from_file(&self) -> Result<Settings> {
match self.read_config_file() {
Ok(content) => match serde_yaml::from_str(&content) {
Ok(settings) => Ok(settings),
Err(e) => {
tracing::warn!(
"Settings file could not be parsed, using defaults: {e}. \
Run `shellfirm config reset` to fix."
);
Ok(Settings::default())
}
},
Err(_) if !self.setting_file_path.exists() => Ok(Settings::default()),
Err(e) => Err(e),
}
}
pub fn reset_config(&self) -> Result<()> {
self.ensure_config_dir()?;
fs::File::create(&self.setting_file_path)?;
Ok(())
}
fn ensure_config_dir(&self) -> Result<()> {
if let Err(err) = fs::create_dir_all(&self.root_folder) {
if err.kind() != std::io::ErrorKind::AlreadyExists {
return Err(Error::Config(format!("could not create folder: {err}")));
}
debug!("configuration folder found: {}", self.root_folder.display());
} else {
debug!(
"configuration created in path: {}",
self.root_folder.display()
);
}
Ok(())
}
pub fn save_settings_file_from_struct(&self, settings: &Settings) -> Result<()> {
self.ensure_config_dir()?;
let content = serde_yaml::to_string(settings)?;
let mut file = fs::File::create(&self.setting_file_path)?;
file.write_all(content.as_bytes())?;
debug!(
"settings file crated in path: {}. config data: {:?}",
self.setting_file_path.display(),
settings
);
Ok(())
}
pub fn read_config_file(&self) -> Result<String> {
let mut file = std::fs::File::open(&self.setting_file_path)?;
let mut content = String::new();
file.read_to_string(&mut content)?;
Ok(content)
}
pub fn read_config_as_value(&self) -> Result<serde_yaml::Value> {
let empty_mapping = || serde_yaml::Value::Mapping(serde_yaml::Mapping::default());
match self.read_config_file() {
Ok(content) => {
let value: serde_yaml::Value = serde_yaml::from_str(&content)?;
if value.is_null() {
Ok(empty_mapping())
} else {
Ok(value)
}
}
Err(_) if !self.setting_file_path.exists() => Ok(empty_mapping()),
Err(e) => Err(e),
}
}
pub fn save_config_from_value(&self, value: &serde_yaml::Value) -> Result<()> {
self.ensure_config_dir()?;
let yaml_str = serde_yaml::to_string(value)?;
let _settings: Settings = serde_yaml::from_str(&yaml_str)?;
let mut file = fs::File::create(&self.setting_file_path)?;
file.write_all(yaml_str.as_bytes())?;
Ok(())
}
}
impl Settings {
    /// Returns clones of every known check that is currently active.
    ///
    /// A check is active when its group (`from`) is listed in `enabled_groups`
    /// and not in `disabled_groups`, and its `id` is not in `ignores_patterns_ids`.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn get_active_checks(&self) -> Result<Vec<checks::Check>> {
        fn lookup(values: &[String]) -> HashSet<&str> {
            values.iter().map(String::as_str).collect()
        }

        let enabled = lookup(&self.enabled_groups);
        let disabled = lookup(&self.disabled_groups);
        let ignored_ids = lookup(&self.ignores_patterns_ids);

        let is_active = |check: &checks::Check| -> bool {
            let group = check.from.as_str();
            enabled.contains(group)
                && !disabled.contains(group)
                && !ignored_ids.contains(check.id.as_str())
        };

        let active: Vec<checks::Check> = checks::all_checks_cached()
            .iter()
            .filter(|check| is_active(*check))
            .cloned()
            .collect();
        Ok(active)
    }

    /// Returns the configured `enabled_groups` list as-is (groups in
    /// `disabled_groups` are not removed here).
    #[must_use]
    pub const fn get_active_groups(&self) -> &Vec<String> {
        &self.enabled_groups
    }
}
pub fn value_set(
root: &mut serde_yaml::Value,
path: &str,
new_value: serde_yaml::Value,
) -> Result<()> {
let segments: Vec<&str> = path.split('.').collect();
let mut current = root;
for (i, segment) in segments.iter().enumerate() {
if i == segments.len() - 1 {
let map = current.as_mapping_mut().ok_or_else(|| {
Error::Config(format!("expected a mapping at parent of '{path}'"))
})?;
map.insert(serde_yaml::Value::String((*segment).to_string()), new_value);
return Ok(());
}
let key = serde_yaml::Value::String((*segment).to_string());
if !current.as_mapping().is_some_and(|m| m.contains_key(&key)) {
let map = current
.as_mapping_mut()
.ok_or_else(|| Error::Config(format!("expected a mapping at '{segment}'")))?;
map.insert(
key.clone(),
serde_yaml::Value::Mapping(serde_yaml::Mapping::default()),
);
}
current = current
.get_mut(segment)
.ok_or_else(|| Error::Config(format!("failed to descend into '{segment}'")))?;
}
Ok(())
}
#[cfg(test)]
mod test_config {
    use std::fs::read_dir;

    use tree_fs::Tree;

    use super::*;

    /// Creates a fresh, empty temporary directory for one test.
    fn empty_tree() -> Tree {
        tree_fs::TreeBuilder::default()
            .create()
            .expect("create tree")
    }

    /// Builds a `Config` rooted at `<tree>/app` without writing anything.
    fn initialize_config_folder(temp_dir: &Tree) -> Config {
        let app_folder = temp_dir.root.join("app").display().to_string();
        Config::new(Some(&app_folder)).unwrap()
    }

    /// Like `initialize_config_folder`, then writes an empty settings file.
    fn initialize_config_folder_with_file(temp_dir: &Tree) -> Config {
        let config = initialize_config_folder(temp_dir);
        config.reset_config().unwrap();
        config
    }

    #[test]
    fn new_config_does_not_create_files() {
        let tree = empty_tree();
        let config = initialize_config_folder(&tree);

        assert!(!config.root_folder.is_dir());
        assert!(!config.setting_file_path.is_file());
    }

    #[test]
    fn get_settings_returns_defaults_without_file() {
        let tree = empty_tree();
        let settings = initialize_config_folder(&tree)
            .get_settings_from_file()
            .unwrap();

        assert_eq!(settings.challenge, DEFAULT_CHALLENGE);
        assert_eq!(settings.enabled_groups, default_enabled_groups());
        assert!(settings.audit_enabled);
    }

    #[test]
    fn can_reset_config() {
        let tree = empty_tree();
        let config = initialize_config_folder_with_file(&tree);
        let stored_challenge = || config.get_settings_from_file().unwrap().challenge;

        let mut settings = config.get_settings_from_file().unwrap();
        settings.challenge = Challenge::Yes;
        config.save_settings_file_from_struct(&settings).unwrap();
        assert_eq!(stored_challenge(), Challenge::Yes);

        config.reset_config().unwrap();
        assert_eq!(stored_challenge(), Challenge::Math);
        assert_eq!(read_dir(&config.root_folder).unwrap().count(), 1);
    }

    #[test]
    fn read_config_as_value_empty_file_returns_empty_mapping() {
        let tree = empty_tree();
        let config = initialize_config_folder_with_file(&tree);

        let mut root = config.read_config_as_value().unwrap();
        let starts_empty = root
            .as_mapping()
            .expect("should be a Mapping, not Null")
            .is_empty();
        assert!(starts_empty);

        value_set(
            &mut root,
            "challenge",
            serde_yaml::Value::String("Enter".into()),
        )
        .unwrap();
        let challenge = root.get("challenge").unwrap().as_str().unwrap();
        assert_eq!(challenge, "Enter");
    }

    #[test]
    fn sparse_config_on_fresh_install() {
        let tree = empty_tree();
        let config = initialize_config_folder(&tree);
        assert!(!config.setting_file_path.exists());

        let mut root = config.read_config_as_value().unwrap();
        assert!(root.as_mapping().unwrap().is_empty());

        value_set(
            &mut root,
            "challenge",
            serde_yaml::Value::String("Yes".into()),
        )
        .unwrap();
        config.save_config_from_value(&root).unwrap();

        let content = config.read_config_file().unwrap();
        assert!(content.contains("challenge"));
        assert!(!content.contains("enabled_groups"));

        let settings = config.get_settings_from_file().unwrap();
        assert_eq!(settings.challenge, Challenge::Yes);
        assert_eq!(settings.enabled_groups, default_enabled_groups());
    }
}
#[cfg(test)]
mod test_settings {
    use super::*;

    #[test]
    fn can_get_active_checks() {
        let result = Settings::default().get_active_checks();
        assert!(result.is_ok());
    }

    #[test]
    fn can_get_settings_from_file() {
        let expected = [
            "aws",
            "azure",
            "base",
            "database",
            "docker",
            "fs",
            "gcp",
            "git",
            "heroku",
            "kubernetes",
            "mongodb",
            "mysql",
            "network",
            "psql",
            "redis",
            "terraform",
        ];
        let groups = Settings::default().get_active_groups().clone();
        assert_eq!(groups, expected);
    }

    #[test]
    fn settings_yaml_roundtrip_preserves_enabled_groups() {
        let original = Settings::default();
        let serialized = serde_yaml::to_string(&original).unwrap();
        let restored: Settings = serde_yaml::from_str(&serialized).unwrap();

        assert_eq!(restored.enabled_groups, original.enabled_groups);
        assert!(
            !restored.enabled_groups.is_empty(),
            "enabled_groups must not be empty after roundtrip"
        );
    }

    #[test]
    fn default_settings_produce_nonempty_active_checks() {
        let active = Settings::default().get_active_checks().unwrap();
        assert!(
            !active.is_empty(),
            "Settings::default() must produce active checks"
        );

        let groups: std::collections::HashSet<&str> =
            active.iter().map(|check| check.from.as_str()).collect();
        assert!(groups.contains("fs"), "fs group must be active");
        assert!(groups.contains("git"), "git group must be active");
    }

    #[test]
    fn settings_file_roundtrip_produces_matches() {
        let temp = tree_fs::TreeBuilder::default()
            .create()
            .expect("create tree");
        let app_folder = temp.root.join("app").display().to_string();
        let config = Config::new(Some(&app_folder)).unwrap();
        config.reset_config().unwrap();

        let active = config
            .get_settings_from_file()
            .unwrap()
            .get_active_checks()
            .unwrap();
        assert!(
            !active.is_empty(),
            "Active checks must not be empty after file roundtrip"
        );

        let matches =
            crate::checks::run_check_on_command(&active, "git push --force origin main");
        assert!(
            !matches.is_empty(),
            "git push --force must match after file roundtrip"
        );
    }

    #[test]
    fn old_includes_field_falls_back_to_default_enabled_groups() {
        let old_yaml = "challenge: Math\nincludes:\n - base\n - fs\n - git\n";
        let settings: Settings = serde_yaml::from_str(old_yaml).unwrap();
        assert_eq!(settings.enabled_groups, default_enabled_groups());

        let active = settings.get_active_checks().unwrap();
        assert!(!active.is_empty());
    }
}