use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::time::Duration;
use tokio::fs;
use super::{CheckType, PipelineStage};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetyConfig {
pub enabled: bool,
pub strict_mode: bool,
pub show_progress: bool,
pub parallel_checks: bool,
pub pre_commit: StageConfig,
pub pre_push: StageConfig,
pub publish: StageConfig,
pub bypass: BypassConfig,
}
/// Per-stage settings; `SafetyConfig` holds one instance for each
/// `PipelineStage` (see `SafetyConfig::get_stage_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageConfig {
/// When `false`, `SafetyConfig::is_check_enabled` reports every check of this
/// stage as disabled, regardless of `checks`.
pub enabled: bool,
/// Time budget for the stage in seconds; `SafetyConfig::get_timeout` turns it
/// into a `Duration`.
pub timeout_seconds: u64,
/// Checks that belong to this stage; `SafetyConfig::is_check_enabled` tests
/// membership in this list.
pub checks: Vec<CheckType>,
/// NOTE(review): not read in this file; presumably lets the stage succeed
/// when checks only produce warnings — confirm with the pipeline runner.
pub continue_on_warning: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BypassConfig {
pub enabled: bool,
pub require_reason: bool,
pub require_confirmation: bool,
pub log_bypasses: bool,
pub max_bypasses_per_day: u32,
}
impl Default for BypassConfig {
    /// Builds the default bypass policy: every boolean flag is `true` and the
    /// daily limit is `3`.
    fn default() -> Self {
        Self {
            max_bypasses_per_day: 3,
            log_bypasses: true,
            require_confirmation: true,
            require_reason: true,
            enabled: true,
        }
    }
}
impl BypassConfig {
    /// Returns the `bypass` section of the on-disk safety configuration.
    ///
    /// If `SafetyConfig::load` fails for any reason, its error is discarded and
    /// [`BypassConfig::default`] is returned instead. As a result, this function
    /// currently never yields `Err`.
    pub async fn load_or_default() -> Result<Self> {
        let loaded = SafetyConfig::load().await;
        Ok(loaded.map(|full| full.bypass).unwrap_or_default())
    }
}
impl Default for SafetyConfig {
fn default() -> Self {
Self {
enabled: true,
strict_mode: true,
show_progress: true,
parallel_checks: true,
pre_commit: StageConfig {
enabled: true,
timeout_seconds: 300, checks: CheckType::for_stage(PipelineStage::PreCommit),
continue_on_warning: false,
},
pre_push: StageConfig {
enabled: true,
timeout_seconds: 600, checks: CheckType::for_stage(PipelineStage::PrePush),
continue_on_warning: false,
},
publish: StageConfig {
enabled: true,
timeout_seconds: 900, checks: CheckType::for_stage(PipelineStage::Publish),
continue_on_warning: false,
},
bypass: BypassConfig {
enabled: true, require_reason: true,
require_confirmation: true,
log_bypasses: true,
max_bypasses_per_day: 3,
},
}
}
}
impl SafetyConfig {
pub async fn load_or_default() -> Result<Self> {
match Self::load().await {
Ok(config) => Ok(config),
Err(_) => Ok(Self::default()),
}
}
pub async fn load() -> Result<Self> {
let config_path = Self::config_file_path()?;
let contents = fs::read_to_string(&config_path)
.await
.map_err(|e| Error::config(format!("Failed to read safety config: {}", e)))?;
let config: Self = toml::from_str(&contents)
.map_err(|e| Error::config(format!("Failed to parse safety config: {}", e)))?;
Ok(config)
}
pub async fn save(&self) -> Result<()> {
let config_path = Self::config_file_path()?;
if let Some(parent) = config_path.parent() {
fs::create_dir_all(parent).await?;
}
let contents = toml::to_string_pretty(self)
.map_err(|e| Error::config(format!("Failed to serialize safety config: {}", e)))?;
fs::write(&config_path, contents)
.await
.map_err(|e| Error::config(format!("Failed to write safety config: {}", e)))?;
Ok(())
}
pub fn config_file_path() -> Result<PathBuf> {
let config_dir = crate::config::Config::config_dir_path()?;
Ok(config_dir.join("safety.toml"))
}
pub fn get_stage_config(&self, stage: PipelineStage) -> &StageConfig {
match stage {
PipelineStage::PreCommit => &self.pre_commit,
PipelineStage::PrePush => &self.pre_push,
PipelineStage::Publish => &self.publish,
}
}
pub fn get_stage_config_mut(&mut self, stage: PipelineStage) -> &mut StageConfig {
match stage {
PipelineStage::PreCommit => &mut self.pre_commit,
PipelineStage::PrePush => &mut self.pre_push,
PipelineStage::Publish => &mut self.publish,
}
}
pub fn is_check_enabled(&self, stage: PipelineStage, check: CheckType) -> bool {
let stage_config = self.get_stage_config(stage);
stage_config.enabled && stage_config.checks.contains(&check)
}
pub fn get_timeout(&self, stage: PipelineStage) -> Duration {
Duration::from_secs(self.get_stage_config(stage).timeout_seconds)
}
pub fn set(&mut self, key: &str, value: &str) -> Result<()> {
match key {
key if key.starts_with("bypass.") => self.set_bypass_config(key, value),
key if key.contains('.') => self.set_stage_config(key, value),
_ => self.set_main_config(key, value),
}
}
fn set_main_config(&mut self, key: &str, value: &str) -> Result<()> {
match key {
"enabled" => {
self.enabled = self.parse_bool(value, "enabled")?;
}
"strict_mode" => {
self.strict_mode = self.parse_bool(value, "strict_mode")?;
}
"show_progress" => {
self.show_progress = self.parse_bool(value, "show_progress")?;
}
"parallel_checks" => {
self.parallel_checks = self.parse_bool(value, "parallel_checks")?;
}
_ => return Err(Error::config(format!("Unknown safety config key: {}", key))),
}
Ok(())
}
fn set_stage_config(&mut self, key: &str, value: &str) -> Result<()> {
let (stage, field) = key
.split_once('.')
.ok_or_else(|| Error::config(format!("Invalid config key format: {}", key)))?;
match (stage, field) {
("pre_commit", "enabled") => {
self.pre_commit.enabled = self.parse_bool(value, "pre_commit.enabled")?;
}
("pre_commit", "timeout_seconds") => {
self.pre_commit.timeout_seconds =
self.parse_u64(value, "pre_commit.timeout_seconds")?;
}
("pre_push", "enabled") => {
self.pre_push.enabled = self.parse_bool(value, "pre_push.enabled")?;
}
("pre_push", "timeout_seconds") => {
self.pre_push.timeout_seconds =
self.parse_u64(value, "pre_push.timeout_seconds")?;
}
("publish", "enabled") => {
self.publish.enabled = self.parse_bool(value, "publish.enabled")?;
}
("publish", "timeout_seconds") => {
self.publish.timeout_seconds = self.parse_u64(value, "publish.timeout_seconds")?;
}
_ => return Err(Error::config(format!("Unknown safety config key: {}", key))),
}
Ok(())
}
fn set_bypass_config(&mut self, key: &str, value: &str) -> Result<()> {
match key {
"bypass.enabled" => {
self.bypass.enabled = self.parse_bool(value, "bypass.enabled")?;
}
_ => return Err(Error::config(format!("Unknown safety config key: {}", key))),
}
Ok(())
}
fn parse_bool(&self, value: &str, field: &str) -> Result<bool> {
value
.parse()
.map_err(|_| Error::config(format!("Invalid boolean value for {}", field)))
}
fn parse_u64(&self, value: &str, field: &str) -> Result<u64> {
value
.parse()
.map_err(|_| Error::config(format!("Invalid number for {}", field)))
}
pub fn get(&self, key: &str) -> Option<String> {
match key {
"enabled" => Some(self.enabled.to_string()),
"strict_mode" => Some(self.strict_mode.to_string()),
"show_progress" => Some(self.show_progress.to_string()),
"parallel_checks" => Some(self.parallel_checks.to_string()),
"pre_commit.enabled" => Some(self.pre_commit.enabled.to_string()),
"pre_commit.timeout_seconds" => Some(self.pre_commit.timeout_seconds.to_string()),
"pre_push.enabled" => Some(self.pre_push.enabled.to_string()),
"pre_push.timeout_seconds" => Some(self.pre_push.timeout_seconds.to_string()),
"publish.enabled" => Some(self.publish.enabled.to_string()),
"publish.timeout_seconds" => Some(self.publish.timeout_seconds.to_string()),
"bypass.enabled" => Some(self.bypass.enabled.to_string()),
_ => None,
}
}
}