// oxideshield_guard/benchmark/runner.rs

use std::time::Instant;

use crate::guard::Guard;
use crate::multilayer::MultiLayerDefense;

use super::datasets::BenchmarkDataset;
use super::metrics::GuardMetrics;
/// Controls how a benchmark run is executed: how many untimed warmup passes
/// precede measurement and how many timed passes each sample receives.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Number of leading dataset samples checked once, untimed, before
    /// measurement starts (capped at the dataset length by the runner).
    pub warmup_iterations: usize,
    /// Number of timed passes per sample; each pass records one observation.
    pub iterations: usize,
    /// Whether samples should be shuffled before the run.
    /// NOTE(review): `BenchmarkRunner` in this file never reads this flag —
    /// confirm where (or whether) shuffling is applied.
    pub shuffle: bool,
    /// RNG seed used when `shuffle` is enabled.
    /// NOTE(review): likewise unused by the runner here — TODO confirm.
    pub seed: Option<u64>,
}
25
26impl Default for BenchmarkConfig {
27 fn default() -> Self {
28 Self {
29 warmup_iterations: 3,
30 iterations: 1,
31 shuffle: false,
32 seed: None,
33 }
34 }
35}
36
37impl BenchmarkConfig {
38 pub fn quick() -> Self {
40 Self {
41 warmup_iterations: 1,
42 iterations: 1,
43 shuffle: false,
44 seed: None,
45 }
46 }
47
48 pub fn thorough() -> Self {
50 Self {
51 warmup_iterations: 5,
52 iterations: 3,
53 shuffle: true,
54 seed: Some(42),
55 }
56 }
57}
58
/// Executes guards and multi-layer defenses over benchmark datasets,
/// timing each check and collecting the results into [`GuardMetrics`].
pub struct BenchmarkRunner {
    // Run parameters: warmup and per-sample iteration counts.
    config: BenchmarkConfig,
}
63
64impl BenchmarkRunner {
65 pub fn new(config: BenchmarkConfig) -> Self {
67 Self { config }
68 }
69
70 pub fn default_runner() -> Self {
72 Self::new(BenchmarkConfig::default())
73 }
74
75 pub fn benchmark_guard(&self, guard: &dyn Guard, dataset: &BenchmarkDataset) -> GuardMetrics {
77 let mut metrics = GuardMetrics::new(guard.name());
78
79 for sample in dataset
81 .samples
82 .iter()
83 .take(self.config.warmup_iterations.min(dataset.len()))
84 {
85 let _ = guard.check(&sample.text);
86 }
87
88 for sample in &dataset.samples {
90 for _ in 0..self.config.iterations {
91 let start = Instant::now();
92 let result = guard.check(&sample.text);
93 let elapsed = start.elapsed();
94
95 let detected = !result.passed;
96 metrics.record(
97 detected,
98 sample.is_attack,
99 elapsed,
100 sample.category.as_deref(),
101 );
102 }
103 }
104
105 metrics
106 }
107
108 pub fn benchmark_multilayer(
110 &self,
111 defense: &MultiLayerDefense,
112 dataset: &BenchmarkDataset,
113 ) -> GuardMetrics {
114 let mut metrics = GuardMetrics::new(defense.name());
115
116 for sample in dataset
118 .samples
119 .iter()
120 .take(self.config.warmup_iterations.min(dataset.len()))
121 {
122 let _ = defense.check(&sample.text);
123 }
124
125 for sample in &dataset.samples {
127 for _ in 0..self.config.iterations {
128 let start = Instant::now();
129 let result = defense.check(&sample.text);
130 let elapsed = start.elapsed();
131
132 let detected = !result.passed;
133 metrics.record(
134 detected,
135 sample.is_attack,
136 elapsed,
137 sample.category.as_deref(),
138 );
139 }
140 }
141
142 metrics
143 }
144
145 pub fn benchmark_guards(
147 &self,
148 guards: &[&dyn Guard],
149 dataset: &BenchmarkDataset,
150 ) -> Vec<GuardMetrics> {
151 guards
152 .iter()
153 .map(|guard| self.benchmark_guard(*guard, dataset))
154 .collect()
155 }
156}
157
158impl Default for BenchmarkRunner {
159 fn default() -> Self {
160 Self::default_runner()
161 }
162}
163
/// Serializable summary of one dataset's benchmark run: dataset composition
/// plus one [`GuardMetrics`] entry per benchmarked guard/defense.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BenchmarkSuiteResults {
    /// Name of the dataset that was benchmarked.
    pub dataset_name: String,
    /// Total number of samples in the dataset.
    pub dataset_size: usize,
    /// Number of samples labeled as attacks.
    pub attack_count: usize,
    /// Number of samples labeled as benign.
    pub benign_count: usize,
    /// Metrics for each guard (and the multi-layer defense, if any).
    pub guard_metrics: Vec<GuardMetrics>,
    /// RFC 3339 UTC timestamp of when the results were assembled.
    pub timestamp: String,
}
180
181impl BenchmarkSuiteResults {
182 pub fn new(dataset: &BenchmarkDataset, guard_metrics: Vec<GuardMetrics>) -> Self {
184 Self {
185 dataset_name: dataset.name.clone(),
186 dataset_size: dataset.len(),
187 attack_count: dataset.attack_samples().len(),
188 benign_count: dataset.benign_samples().len(),
189 guard_metrics,
190 timestamp: chrono::Utc::now().to_rfc3339(),
191 }
192 }
193
194 pub fn best_by_f1(&self) -> Option<&GuardMetrics> {
196 self.guard_metrics
197 .iter()
198 .max_by(|a, b| a.f1_score().partial_cmp(&b.f1_score()).unwrap())
199 }
200
201 pub fn fastest_by_p50(&self) -> Option<&GuardMetrics> {
203 self.guard_metrics
204 .iter()
205 .min_by(|a, b| a.p50_latency_ms().partial_cmp(&b.p50_latency_ms()).unwrap())
206 }
207}
208
/// Fluent builder that collects guards, an optional multi-layer defense, and
/// datasets, then runs the full cross-product as a benchmark suite.
pub struct BenchmarkSuiteBuilder {
    // Configuration handed to the runner (cloned per `run`).
    config: BenchmarkConfig,
    // Guards benchmarked against every dataset, in insertion order.
    guards: Vec<Box<dyn Guard>>,
    // Optional defense pipeline, benchmarked after the individual guards.
    multilayer: Option<MultiLayerDefense>,
    // Datasets to run; one `BenchmarkSuiteResults` is produced per dataset.
    datasets: Vec<BenchmarkDataset>,
}
216
217impl BenchmarkSuiteBuilder {
218 pub fn new() -> Self {
220 Self {
221 config: BenchmarkConfig::default(),
222 guards: Vec::new(),
223 multilayer: None,
224 datasets: Vec::new(),
225 }
226 }
227
228 pub fn with_config(mut self, config: BenchmarkConfig) -> Self {
230 self.config = config;
231 self
232 }
233
234 pub fn add_guard(mut self, guard: Box<dyn Guard>) -> Self {
236 self.guards.push(guard);
237 self
238 }
239
240 pub fn with_multilayer(mut self, defense: MultiLayerDefense) -> Self {
242 self.multilayer = Some(defense);
243 self
244 }
245
246 pub fn add_dataset(mut self, dataset: BenchmarkDataset) -> Self {
248 self.datasets.push(dataset);
249 self
250 }
251
252 pub fn run(&self) -> Vec<BenchmarkSuiteResults> {
254 let runner = BenchmarkRunner::new(self.config.clone());
255 let mut all_results = Vec::new();
256
257 for dataset in &self.datasets {
258 let mut guard_metrics = Vec::new();
259
260 for guard in &self.guards {
262 let metrics = runner.benchmark_guard(guard.as_ref(), dataset);
263 guard_metrics.push(metrics);
264 }
265
266 if let Some(ref defense) = self.multilayer {
268 let metrics = runner.benchmark_multilayer(defense, dataset);
269 guard_metrics.push(metrics);
270 }
271
272 all_results.push(BenchmarkSuiteResults::new(dataset, guard_metrics));
273 }
274
275 all_results
276 }
277}
278
279impl Default for BenchmarkSuiteBuilder {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "adversarial-samples")]
    use crate::guard::LengthGuard;
    #[cfg(feature = "adversarial-samples")]
    use crate::guards::PerplexityGuard;
    #[cfg(feature = "adversarial-samples")]
    use crate::benchmark::datasets::oxideshield_benchmark_dataset;

    /// End-to-end run of a single guard over the bundled dataset.
    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_benchmark_runner() {
        let runner = BenchmarkRunner::new(BenchmarkConfig::quick());
        let guard = LengthGuard::new("length").with_max_chars(100);
        let dataset = oxideshield_benchmark_dataset();

        let metrics = runner.benchmark_guard(&guard, &dataset);
        assert!(metrics.total_samples > 0);
        // Idiomatic emptiness check (clippy::len_zero) instead of `len() > 0`.
        assert!(!metrics.latencies_ms.is_empty());
    }

    /// Builder produces one result per dataset and one metrics entry per guard.
    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_benchmark_suite_builder() {
        let results = BenchmarkSuiteBuilder::new()
            .with_config(BenchmarkConfig::quick())
            .add_guard(Box::new(LengthGuard::new("length").with_max_chars(100)))
            .add_guard(Box::new(PerplexityGuard::new("perplexity")))
            .add_dataset(oxideshield_benchmark_dataset())
            .run();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].guard_metrics.len(), 2);
    }

    /// The suite-level selectors find a best/fastest entry when metrics exist.
    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_suite_results() {
        let results = BenchmarkSuiteBuilder::new()
            .with_config(BenchmarkConfig::quick())
            .add_guard(Box::new(LengthGuard::new("length").with_max_chars(100)))
            .add_guard(Box::new(PerplexityGuard::new("perplexity")))
            .add_dataset(oxideshield_benchmark_dataset())
            .run();

        let suite = &results[0];
        assert!(suite.best_by_f1().is_some());
        assert!(suite.fastest_by_p50().is_some());
    }

    /// `quick()` uses a single warmup pass and a single timed pass.
    #[test]
    fn test_benchmark_config_quick() {
        let config = BenchmarkConfig::quick();
        assert_eq!(config.warmup_iterations, 1);
        assert_eq!(config.iterations, 1);
        assert!(!config.shuffle);
    }
}