// oxideshield_guard/benchmark/runner.rs

1//! Benchmark Runner
2//!
3//! Executes benchmarks against guards and generates metrics.
4
5use std::time::Instant;
6
7use crate::guard::Guard;
8use crate::multilayer::MultiLayerDefense;
9
10use super::datasets::BenchmarkDataset;
11use super::metrics::GuardMetrics;
12
/// Configuration for benchmark runs.
///
/// Controls how many passes the runner makes over a dataset before and
/// during measurement, and whether sample order is randomized.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Number of warmup iterations (not counted in recorded metrics).
    pub warmup_iterations: usize,
    /// Number of measured benchmark iterations per sample.
    pub iterations: usize,
    /// Whether to shuffle samples.
    /// NOTE(review): not consumed by `BenchmarkRunner` in this file —
    /// presumably honored by dataset preparation elsewhere; confirm.
    pub shuffle: bool,
    /// Random seed for shuffling; `None` means unseeded.
    pub seed: Option<u64>,
}
25
26impl Default for BenchmarkConfig {
27    fn default() -> Self {
28        Self {
29            warmup_iterations: 3,
30            iterations: 1,
31            shuffle: false,
32            seed: None,
33        }
34    }
35}
36
37impl BenchmarkConfig {
38    /// Create a quick benchmark config (fewer iterations)
39    pub fn quick() -> Self {
40        Self {
41            warmup_iterations: 1,
42            iterations: 1,
43            shuffle: false,
44            seed: None,
45        }
46    }
47
48    /// Create a thorough benchmark config (more iterations for stable results)
49    pub fn thorough() -> Self {
50        Self {
51            warmup_iterations: 5,
52            iterations: 3,
53            shuffle: true,
54            seed: Some(42),
55        }
56    }
57}
58
/// Benchmark runner for OxideShield guards.
///
/// Holds the [`BenchmarkConfig`] that governs warmup and iteration counts
/// for every benchmark it executes.
pub struct BenchmarkRunner {
    // Run parameters (warmup/measured iteration counts, shuffle settings).
    config: BenchmarkConfig,
}
63
64impl BenchmarkRunner {
65    /// Create a new benchmark runner
66    pub fn new(config: BenchmarkConfig) -> Self {
67        Self { config }
68    }
69
70    /// Create with default config
71    pub fn default_runner() -> Self {
72        Self::new(BenchmarkConfig::default())
73    }
74
75    /// Run benchmark on a single guard
76    pub fn benchmark_guard(&self, guard: &dyn Guard, dataset: &BenchmarkDataset) -> GuardMetrics {
77        let mut metrics = GuardMetrics::new(guard.name());
78
79        // Warmup phase
80        for sample in dataset
81            .samples
82            .iter()
83            .take(self.config.warmup_iterations.min(dataset.len()))
84        {
85            let _ = guard.check(&sample.text);
86        }
87
88        // Benchmark phase
89        for sample in &dataset.samples {
90            for _ in 0..self.config.iterations {
91                let start = Instant::now();
92                let result = guard.check(&sample.text);
93                let elapsed = start.elapsed();
94
95                let detected = !result.passed;
96                metrics.record(
97                    detected,
98                    sample.is_attack,
99                    elapsed,
100                    sample.category.as_deref(),
101                );
102            }
103        }
104
105        metrics
106    }
107
108    /// Run benchmark on multi-layer defense
109    pub fn benchmark_multilayer(
110        &self,
111        defense: &MultiLayerDefense,
112        dataset: &BenchmarkDataset,
113    ) -> GuardMetrics {
114        let mut metrics = GuardMetrics::new(defense.name());
115
116        // Warmup phase
117        for sample in dataset
118            .samples
119            .iter()
120            .take(self.config.warmup_iterations.min(dataset.len()))
121        {
122            let _ = defense.check(&sample.text);
123        }
124
125        // Benchmark phase
126        for sample in &dataset.samples {
127            for _ in 0..self.config.iterations {
128                let start = Instant::now();
129                let result = defense.check(&sample.text);
130                let elapsed = start.elapsed();
131
132                let detected = !result.passed;
133                metrics.record(
134                    detected,
135                    sample.is_attack,
136                    elapsed,
137                    sample.category.as_deref(),
138                );
139            }
140        }
141
142        metrics
143    }
144
145    /// Run benchmark on multiple guards
146    pub fn benchmark_guards(
147        &self,
148        guards: &[&dyn Guard],
149        dataset: &BenchmarkDataset,
150    ) -> Vec<GuardMetrics> {
151        guards
152            .iter()
153            .map(|guard| self.benchmark_guard(*guard, dataset))
154            .collect()
155    }
156}
157
158impl Default for BenchmarkRunner {
159    fn default() -> Self {
160        Self::default_runner()
161    }
162}
163
/// Results from a complete benchmark suite run.
///
/// Captures the dataset's composition alongside one [`GuardMetrics`]
/// entry per guard (and per multi-layer defense, if benchmarked).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BenchmarkSuiteResults {
    /// Name of the dataset used.
    pub dataset_name: String,
    /// Total number of samples in the dataset.
    pub dataset_size: usize,
    /// Number of attack samples.
    pub attack_count: usize,
    /// Number of benign samples.
    pub benign_count: usize,
    /// Metrics for each guard benchmarked against the dataset.
    pub guard_metrics: Vec<GuardMetrics>,
    /// RFC 3339 timestamp of when the results were created.
    pub timestamp: String,
}
180
181impl BenchmarkSuiteResults {
182    /// Create new results
183    pub fn new(dataset: &BenchmarkDataset, guard_metrics: Vec<GuardMetrics>) -> Self {
184        Self {
185            dataset_name: dataset.name.clone(),
186            dataset_size: dataset.len(),
187            attack_count: dataset.attack_samples().len(),
188            benign_count: dataset.benign_samples().len(),
189            guard_metrics,
190            timestamp: chrono::Utc::now().to_rfc3339(),
191        }
192    }
193
194    /// Find the best guard by F1 score
195    pub fn best_by_f1(&self) -> Option<&GuardMetrics> {
196        self.guard_metrics
197            .iter()
198            .max_by(|a, b| a.f1_score().partial_cmp(&b.f1_score()).unwrap())
199    }
200
201    /// Find the fastest guard by p50 latency
202    pub fn fastest_by_p50(&self) -> Option<&GuardMetrics> {
203        self.guard_metrics
204            .iter()
205            .min_by(|a, b| a.p50_latency_ms().partial_cmp(&b.p50_latency_ms()).unwrap())
206    }
207}
208
/// Builder for running benchmark suites.
///
/// Collects guards, an optional multi-layer defense, and one or more
/// datasets, then runs every guard against every dataset via `run`.
pub struct BenchmarkSuiteBuilder {
    // Benchmark configuration applied to every run.
    config: BenchmarkConfig,
    // Individual guards to benchmark, in insertion order.
    guards: Vec<Box<dyn Guard>>,
    // Optional multi-layer defense, benchmarked after individual guards.
    multilayer: Option<MultiLayerDefense>,
    // Datasets to run; each yields one BenchmarkSuiteResults entry.
    datasets: Vec<BenchmarkDataset>,
}
216
217impl BenchmarkSuiteBuilder {
218    /// Create a new builder
219    pub fn new() -> Self {
220        Self {
221            config: BenchmarkConfig::default(),
222            guards: Vec::new(),
223            multilayer: None,
224            datasets: Vec::new(),
225        }
226    }
227
228    /// Set benchmark config
229    pub fn with_config(mut self, config: BenchmarkConfig) -> Self {
230        self.config = config;
231        self
232    }
233
234    /// Add a guard to benchmark
235    pub fn add_guard(mut self, guard: Box<dyn Guard>) -> Self {
236        self.guards.push(guard);
237        self
238    }
239
240    /// Set multi-layer defense to benchmark
241    pub fn with_multilayer(mut self, defense: MultiLayerDefense) -> Self {
242        self.multilayer = Some(defense);
243        self
244    }
245
246    /// Add a dataset
247    pub fn add_dataset(mut self, dataset: BenchmarkDataset) -> Self {
248        self.datasets.push(dataset);
249        self
250    }
251
252    /// Run all benchmarks
253    pub fn run(&self) -> Vec<BenchmarkSuiteResults> {
254        let runner = BenchmarkRunner::new(self.config.clone());
255        let mut all_results = Vec::new();
256
257        for dataset in &self.datasets {
258            let mut guard_metrics = Vec::new();
259
260            // Benchmark individual guards
261            for guard in &self.guards {
262                let metrics = runner.benchmark_guard(guard.as_ref(), dataset);
263                guard_metrics.push(metrics);
264            }
265
266            // Benchmark multi-layer if configured
267            if let Some(ref defense) = self.multilayer {
268                let metrics = runner.benchmark_multilayer(defense, dataset);
269                guard_metrics.push(metrics);
270            }
271
272            all_results.push(BenchmarkSuiteResults::new(dataset, guard_metrics));
273        }
274
275        all_results
276    }
277}
278
279impl Default for BenchmarkSuiteBuilder {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "adversarial-samples")]
    use crate::guard::LengthGuard;
    #[cfg(feature = "adversarial-samples")]
    use crate::guards::PerplexityGuard;
    #[cfg(feature = "adversarial-samples")]
    use crate::benchmark::datasets::oxideshield_benchmark_dataset;

    /// A quick benchmark over the bundled dataset records at least one
    /// sample and one latency observation.
    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_benchmark_runner() {
        let runner = BenchmarkRunner::new(BenchmarkConfig::quick());
        let guard = LengthGuard::new("length").with_max_chars(100);
        let dataset = oxideshield_benchmark_dataset();

        let metrics = runner.benchmark_guard(&guard, &dataset);
        assert!(metrics.total_samples > 0);
        // Idiomatic emptiness check instead of `len() > 0` (clippy::len_zero).
        assert!(!metrics.latencies_ms.is_empty());
    }

    /// The builder produces one result set per dataset, with one metrics
    /// entry per registered guard.
    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_benchmark_suite_builder() {
        let results = BenchmarkSuiteBuilder::new()
            .with_config(BenchmarkConfig::quick())
            .add_guard(Box::new(LengthGuard::new("length").with_max_chars(100)))
            .add_guard(Box::new(PerplexityGuard::new("perplexity")))
            .add_dataset(oxideshield_benchmark_dataset())
            .run();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].guard_metrics.len(), 2);
    }

    /// Suite-level selectors return a winner when metrics are present.
    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_suite_results() {
        let results = BenchmarkSuiteBuilder::new()
            .with_config(BenchmarkConfig::quick())
            .add_guard(Box::new(LengthGuard::new("length").with_max_chars(100)))
            .add_guard(Box::new(PerplexityGuard::new("perplexity")))
            .add_dataset(oxideshield_benchmark_dataset())
            .run();

        let suite = &results[0];
        assert!(suite.best_by_f1().is_some());
        assert!(suite.fastest_by_p50().is_some());
    }

    /// `quick()` uses a single warmup pass and a single measured pass.
    #[test]
    fn test_benchmark_config_quick() {
        let config = BenchmarkConfig::quick();
        assert_eq!(config.warmup_iterations, 1);
        assert_eq!(config.iterations, 1);
        assert!(!config.shuffle);
    }
}
344}