1use std::collections::HashMap;
4
5use super::{TamperError, TamperRegistry};
6
7#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
9pub struct StrategyConfig {
10 pub name: String,
12 pub enabled: bool,
14 pub contexts: Option<Vec<String>>,
16 pub params: Option<HashMap<String, toml::Value>>,
18}
19
20#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
22pub struct TamperConfig {
23 pub strategies: Vec<StrategyConfig>,
25}
26
/// Largest strategy file, in bytes, that `TamperRegistry::load_toml` accepts.
const STRATEGY_FILE_MAX_BYTES: u64 = 256 * 1024;

/// Reads the file at `path` as UTF-8 text, refusing files larger than `max_bytes`.
///
/// At most `max_bytes + 1` bytes are ever read (saturating at `u64::MAX`), so an
/// oversized file is detected without reading it to the end.
///
/// # Errors
///
/// * Any I/O error from opening or reading the file, passed through unchanged.
/// * `ErrorKind::InvalidData` if the file holds more than `max_bytes` bytes.
/// * `ErrorKind::InvalidData` if the contents are not valid UTF-8.
fn read_capped_tamper_text(path: &std::path::Path, max_bytes: u64) -> std::io::Result<String> {
    use std::io::Read;

    let file = std::fs::File::open(path)?;
    // Read one byte past the cap: if that byte arrives, the file is too large.
    // `saturating_add` keeps `max_bytes == u64::MAX` from overflowing. Plain `+ 1`
    // panics in debug builds, and in release builds it wraps to a zero-byte limit
    // that silently yields "".
    let mut limited = file.take(max_bytes.saturating_add(1));
    let mut buf = Vec::with_capacity(8 * 1024);
    limited.read_to_end(&mut buf)?;
    // usize -> u64 is lossless on every platform Rust supports (usize <= 64 bits).
    if (buf.len() as u64) > max_bytes {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "{}: tamper config exceeds {}-byte cap",
                path.display(),
                max_bytes,
            ),
        ));
    }
    String::from_utf8(buf).map_err(|e| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("{}: tamper config is not valid UTF-8: {e}", path.display()),
        )
    })
}
63
64impl TamperRegistry {
65 pub fn load_toml<P: AsRef<std::path::Path>>(
71 &mut self,
72 path: P,
73 ) -> Result<TamperConfig, TamperError> {
74 let path_ref = path.as_ref();
75
76 let meta = std::fs::metadata(path_ref).map_err(|e| {
82 TamperError::LoadError(format!("Failed to stat {}: {e}", path_ref.display()))
83 })?;
84 if meta.len() > STRATEGY_FILE_MAX_BYTES {
85 return Err(TamperError::InvalidConfig(format!(
86 "strategy file {} is {} bytes, exceeds {}-byte cap",
87 path_ref.display(),
88 meta.len(),
89 STRATEGY_FILE_MAX_BYTES,
90 )));
91 }
92
93 let content = read_capped_tamper_text(path_ref, STRATEGY_FILE_MAX_BYTES)
94 .map_err(|e| TamperError::LoadError(format!("Failed to read file: {e}")))?;
95
96 let config: TamperConfig = toml::from_str(&content)
97 .map_err(|e| TamperError::InvalidConfig(format!("Failed to parse TOML: {e}")))?;
98
99 Ok(config)
100 }
101
102 pub fn apply_config(&self, payload: &str, config: &TamperConfig) -> Vec<(String, String)> {
106 let mut results = Vec::new();
107
108 for strategy_config in &config.strategies {
109 if !strategy_config.enabled {
110 continue;
111 }
112
113 if let Some(strategy) = self.get(&strategy_config.name) {
114 let context = strategy_config
115 .contexts
116 .as_ref()
117 .and_then(|v| v.first().map(std::string::String::as_str));
118 let result = if let Some(ref params) = strategy_config.params {
119 strategy.tamper_with_params(payload, context, params)
120 } else {
121 strategy.tamper(payload, context)
122 };
123 results.push((strategy_config.name.clone(), result));
124 }
125 }
126
127 results
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// A uniquely named file in the system temp directory that is deleted on
    /// drop, so cleanup happens even when an assertion in the test panics.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        /// Writes `contents` to a fresh `wafrift-<tag>-<pid>-<nanos>.toml` file.
        fn new(tag: &str, contents: impl AsRef<[u8]>) -> Self {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos())
                .unwrap_or(0);
            let path = std::env::temp_dir().join(format!(
                "wafrift-{tag}-{}-{nanos}.toml",
                std::process::id(),
            ));
            std::fs::write(&path, contents).expect("failed to write temp file");
            Self(path)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort: a file that is already gone is not a test failure.
            std::fs::remove_file(&self.0).ok();
        }
    }

    #[test]
    fn tamper_config_serialization() {
        let config = TamperConfig {
            strategies: vec![
                StrategyConfig {
                    name: "url_encode".to_string(),
                    enabled: true,
                    contexts: Some(vec!["sql".to_string(), "xss".to_string()]),
                    params: None,
                },
                StrategyConfig {
                    name: "base64".to_string(),
                    enabled: false,
                    contexts: None,
                    params: None,
                },
            ],
        };

        let toml_str = toml::to_string(&config).expect("Failed to serialize config");
        assert!(toml_str.contains("url_encode"));
        assert!(toml_str.contains("enabled = true"));
        assert!(toml_str.contains("enabled = false"));

        let deserialized: TamperConfig =
            toml::from_str(&toml_str).expect("Failed to deserialize config");
        assert_eq!(deserialized.strategies.len(), 2);
        assert!(deserialized.strategies[0].enabled);
        assert!(!deserialized.strategies[1].enabled);
    }

    #[test]
    fn apply_config_filters_disabled() {
        let registry = TamperRegistry::with_defaults();
        let config = TamperConfig {
            strategies: vec![
                StrategyConfig {
                    name: "url_encode".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
                StrategyConfig {
                    name: "base64".to_string(),
                    enabled: false,
                    contexts: None,
                    params: None,
                },
            ],
        };

        let results = registry.apply_config("test", &config);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "url_encode");
    }

    #[test]
    fn apply_config_with_context() {
        let registry = TamperRegistry::with_defaults();
        let config = TamperConfig {
            strategies: vec![StrategyConfig {
                name: "sql_comment".to_string(),
                enabled: true,
                contexts: Some(vec!["sql".to_string()]),
                params: None,
            }],
        };

        let results = registry.apply_config("SELECT * FROM", &config);
        assert_eq!(results.len(), 1);
        assert!(results[0].1.contains("/**/"));
    }

    #[test]
    fn strategy_config_roundtrip() {
        let config_str = r#"
[[strategies]]
name = "url_encode"
enabled = true
contexts = ["sql", "xss"]
"#;

        let config: TamperConfig = toml::from_str(config_str).expect("Failed to parse TOML");
        assert_eq!(config.strategies.len(), 1);
        assert_eq!(config.strategies[0].name, "url_encode");
        assert!(config.strategies[0].enabled);
        assert_eq!(
            config.strategies[0].contexts,
            Some(vec!["sql".to_string(), "xss".to_string()])
        );
    }

    #[test]
    fn load_toml_from_strategies_d() {
        let mut registry = TamperRegistry::with_defaults();
        let path = std::path::Path::new(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../strategies.d/core.toml"
        ));

        // Skipped (passes trivially) when the workspace-level strategies.d is absent.
        if path.exists() {
            let config = registry.load_toml(path).expect("Failed to load core.toml");
            let has_url_encode = config
                .strategies
                .iter()
                .any(|s| s.name == "url_encode" && s.enabled);
            assert!(has_url_encode, "core.toml should have url_encode enabled");
        }
    }

    #[test]
    fn tamper_error_invalid_toml() {
        let mut registry = TamperRegistry::with_defaults();
        let file = TempFile::new("invalid-toml", "not valid toml [[");

        let result = registry.load_toml(file.path());
        assert!(matches!(result, Err(TamperError::InvalidConfig(_))));
    }

    #[test]
    fn tamper_error_missing_file() {
        let mut registry = TamperRegistry::with_defaults();
        let result = registry.load_toml("/nonexistent/path/file.toml");
        assert!(matches!(result, Err(TamperError::LoadError(_))));
    }

    #[test]
    fn load_toml_rejects_file_over_cap() {
        let mut registry = TamperRegistry::with_defaults();
        let len = usize::try_from(STRATEGY_FILE_MAX_BYTES + 1).expect("cap fits in usize");
        // One long TOML comment: syntactically valid, so only the size check can reject it.
        let file = TempFile::new("oversized", vec![b'#'; len]);

        let result = registry.load_toml(file.path());
        assert!(matches!(result, Err(TamperError::InvalidConfig(_))));
    }

    #[test]
    fn read_capped_tamper_text_enforces_cap() {
        let file = TempFile::new("capped-read", "abcd");

        let at_cap = read_capped_tamper_text(file.path(), 4).expect("file at cap is accepted");
        assert_eq!(at_cap, "abcd");

        let err = read_capped_tamper_text(file.path(), 3).expect_err("file over cap is rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn read_capped_tamper_text_rejects_non_utf8() {
        let file = TempFile::new("non-utf8", [0xff_u8, 0xfe]);

        let err = read_capped_tamper_text(file.path(), 1024).expect_err("invalid UTF-8");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn layered_tamper_chain() {
        let registry = TamperRegistry::with_defaults();
        let config = TamperConfig {
            strategies: vec![
                StrategyConfig {
                    name: "case_alternation".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
                StrategyConfig {
                    name: "url_encode".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
            ],
        };

        let results = registry.apply_config("select <", &config);
        assert_eq!(results.len(), 2);

        assert!(results.iter().any(|(n, _)| n == "case_alternation"));
        assert!(results.iter().any(|(n, _)| n == "url_encode"));

        let url_result = results.iter().find(|(n, _)| n == "url_encode").unwrap();
        assert!(url_result.1.contains('%'));
    }

    #[test]
    fn tamper_strategy_trait_object_safety() {
        let strategies: Vec<Box<dyn super::super::TamperStrategy>> = vec![
            Box::new(super::super::UrlEncodeTamper),
            Box::new(super::super::Base64Tamper),
            Box::new(super::super::CaseAlternationTamper),
        ];

        for strategy in &strategies {
            let result = strategy.tamper("test", None);
            assert!(!result.is_empty());
            assert!(strategy.aggressiveness() >= 0.0 && strategy.aggressiveness() <= 1.0);
        }
    }

    #[test]
    fn custom_strategy_params() {
        let config = StrategyConfig {
            name: "custom".to_string(),
            enabled: true,
            contexts: None,
            params: {
                let mut map = std::collections::HashMap::new();
                map.insert("level".to_string(), toml::Value::Integer(5));
                map.insert(
                    "prefix".to_string(),
                    toml::Value::String("test_".to_string()),
                );
                Some(map)
            },
        };

        let params = config.params.as_ref().expect("params were set above");
        assert_eq!(params.get("level").unwrap().as_integer(), Some(5));
        assert_eq!(params.get("prefix").unwrap().as_str(), Some("test_"));
    }
}