1use crate::core::error::{AdversariaError, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Config {
8 pub version: String,
9 pub default_provider: String,
10 pub providers: HashMap<String, ProviderConfig>,
11 pub suites: SuitesConfig,
12 pub reporting: ReportingConfig,
13 pub plugins: PluginsConfig,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ProviderConfig {
18 pub api_key: Option<String>,
19 pub api_base: Option<String>,
20 pub model: String,
21 pub timeout_seconds: Option<u64>,
22 pub max_retries: Option<u32>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct SuitesConfig {
27 pub directory: PathBuf,
28 pub enabled_suites: Vec<String>,
29 #[serde(default)]
30 pub custom_suites: Vec<PathBuf>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ReportingConfig {
35 pub output_directory: PathBuf,
36 pub format: ReportFormat,
37 #[serde(default = "default_keep_reports")]
38 pub keep_reports: usize,
39}
40
/// Serde fallback for `ReportingConfig::keep_reports` when the key is absent
/// from the configuration file.
fn default_keep_reports() -> usize {
    const KEEP_REPORTS_FALLBACK: usize = 100;
    KEEP_REPORTS_FALLBACK
}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(rename_all = "lowercase")]
47pub enum ReportFormat {
48 Json,
49 Yaml,
50 Both,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct PluginsConfig {
55 pub directory: PathBuf,
56 #[serde(default)]
57 pub enabled: bool,
58}
59
60impl Config {
61 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
62 let content = std::fs::read_to_string(path.as_ref())
63 .map_err(|e| AdversariaError::Config(format!("Failed to read config file: {}", e)))?;
64
65 serde_yaml::from_str(&content)
66 .map_err(|e| AdversariaError::Config(format!("Failed to parse config file: {}", e)))
67 }
68
69 pub fn default_config_path() -> PathBuf {
70 PathBuf::from("adversaria.config.yaml")
71 }
72
73 pub fn create_default<P: AsRef<Path>>(path: P) -> Result<()> {
74 let default_config = Self::default();
75 let yaml = serde_yaml::to_string(&default_config)?;
76 std::fs::write(path, yaml)?;
77 Ok(())
78 }
79}
80
81impl Default for Config {
82 fn default() -> Self {
83 let mut providers = HashMap::new();
84
85 providers.insert(
86 "openai".to_string(),
87 ProviderConfig {
88 api_key: None,
89 api_base: Some("https://api.openai.com/v1".to_string()),
90 model: "gpt-4".to_string(),
91 timeout_seconds: Some(30),
92 max_retries: Some(3),
93 },
94 );
95
96 providers.insert(
97 "anthropic".to_string(),
98 ProviderConfig {
99 api_key: None,
100 api_base: Some("https://api.anthropic.com/v1".to_string()),
101 model: "claude-3-opus-20240229".to_string(),
102 timeout_seconds: Some(30),
103 max_retries: Some(3),
104 },
105 );
106
107 providers.insert(
108 "ollama".to_string(),
109 ProviderConfig {
110 api_key: None,
111 api_base: Some("http://localhost:11434".to_string()),
112 model: "llama2".to_string(),
113 timeout_seconds: Some(60),
114 max_retries: Some(2),
115 },
116 );
117
118 Self {
119 version: "1.0".to_string(),
120 default_provider: "openai".to_string(),
121 providers,
122 suites: SuitesConfig {
123 directory: PathBuf::from("./suites"),
124 enabled_suites: vec![
125 "prompt_injection".to_string(),
126 "jailbreak".to_string(),
127 "role_confusion".to_string(),
128 "data_exfiltration".to_string(),
129 ],
130 custom_suites: vec![],
131 },
132 reporting: ReportingConfig {
133 output_directory: PathBuf::from("./reports"),
134 format: ReportFormat::Json,
135 keep_reports: 100,
136 },
137 plugins: PluginsConfig {
138 directory: PathBuf::from("./plugins"),
139 enabled: true,
140 },
141 }
142 }
143}