Skip to main content

adversaria/core/
config.rs

1use crate::core::error::{AdversariaError, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Config {
8    pub version: String,
9    pub default_provider: String,
10    pub providers: HashMap<String, ProviderConfig>,
11    pub suites: SuitesConfig,
12    pub reporting: ReportingConfig,
13    pub plugins: PluginsConfig,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ProviderConfig {
18    pub api_key: Option<String>,
19    pub api_base: Option<String>,
20    pub model: String,
21    pub timeout_seconds: Option<u64>,
22    pub max_retries: Option<u32>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct SuitesConfig {
27    pub directory: PathBuf,
28    pub enabled_suites: Vec<String>,
29    #[serde(default)]
30    pub custom_suites: Vec<PathBuf>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ReportingConfig {
35    pub output_directory: PathBuf,
36    pub format: ReportFormat,
37    #[serde(default = "default_keep_reports")]
38    pub keep_reports: usize,
39}
40
/// Serde fallback for `ReportingConfig::keep_reports` when the key is missing.
fn default_keep_reports() -> usize {
    const KEEP_REPORTS_FALLBACK: usize = 100;
    KEEP_REPORTS_FALLBACK
}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(rename_all = "lowercase")]
47pub enum ReportFormat {
48    Json,
49    Yaml,
50    Both,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct PluginsConfig {
55    pub directory: PathBuf,
56    #[serde(default)]
57    pub enabled: bool,
58}
59
60impl Config {
61    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
62        let content = std::fs::read_to_string(path.as_ref())
63            .map_err(|e| AdversariaError::Config(format!("Failed to read config file: {}", e)))?;
64
65        serde_yaml::from_str(&content)
66            .map_err(|e| AdversariaError::Config(format!("Failed to parse config file: {}", e)))
67    }
68
69    pub fn default_config_path() -> PathBuf {
70        PathBuf::from("adversaria.config.yaml")
71    }
72
73    pub fn create_default<P: AsRef<Path>>(path: P) -> Result<()> {
74        let default_config = Self::default();
75        let yaml = serde_yaml::to_string(&default_config)?;
76        std::fs::write(path, yaml)?;
77        Ok(())
78    }
79}
80
81impl Default for Config {
82    fn default() -> Self {
83        let mut providers = HashMap::new();
84
85        providers.insert(
86            "openai".to_string(),
87            ProviderConfig {
88                api_key: None,
89                api_base: Some("https://api.openai.com/v1".to_string()),
90                model: "gpt-4".to_string(),
91                timeout_seconds: Some(30),
92                max_retries: Some(3),
93            },
94        );
95
96        providers.insert(
97            "anthropic".to_string(),
98            ProviderConfig {
99                api_key: None,
100                api_base: Some("https://api.anthropic.com/v1".to_string()),
101                model: "claude-3-opus-20240229".to_string(),
102                timeout_seconds: Some(30),
103                max_retries: Some(3),
104            },
105        );
106
107        providers.insert(
108            "ollama".to_string(),
109            ProviderConfig {
110                api_key: None,
111                api_base: Some("http://localhost:11434".to_string()),
112                model: "llama2".to_string(),
113                timeout_seconds: Some(60),
114                max_retries: Some(2),
115            },
116        );
117
118        Self {
119            version: "1.0".to_string(),
120            default_provider: "openai".to_string(),
121            providers,
122            suites: SuitesConfig {
123                directory: PathBuf::from("./suites"),
124                enabled_suites: vec![
125                    "prompt_injection".to_string(),
126                    "jailbreak".to_string(),
127                    "role_confusion".to_string(),
128                    "data_exfiltration".to_string(),
129                ],
130                custom_suites: vec![],
131            },
132            reporting: ReportingConfig {
133                output_directory: PathBuf::from("./reports"),
134                format: ReportFormat::Json,
135                keep_reports: 100,
136            },
137            plugins: PluginsConfig {
138                directory: PathBuf::from("./plugins"),
139                enabled: true,
140            },
141        }
142    }
143}