// oxideshield_guard/benchmark/datasets.rs
//! Benchmark Datasets
//!
//! Standardized datasets for evaluating LLM security tools.
//! Includes both attack datasets and benign datasets for measuring
//! detection rate and false positive rate.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// A single benchmark sample with a ground-truth label.
///
/// Samples are the atomic unit of evaluation: a detector is run on
/// `text` and its verdict is compared against `is_attack`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkSample {
    /// Unique identifier (prefixed with the dataset name when datasets are combined).
    pub id: String,
    /// The input text/prompt fed to the system under test.
    pub text: String,
    /// Ground truth: `true` if this sample is an attack, `false` if benign.
    pub is_attack: bool,
    /// Attack category (if applicable), e.g. a Debug-formatted attack type.
    pub category: Option<String>,
    /// Subcategory or short description of the attack type.
    pub subcategory: Option<String>,
    /// Name of the source dataset this sample came from.
    pub source: String,
    /// Additional free-form metadata.
    pub metadata: HashMap<String, String>,
}

/// A named, versioned collection of [`BenchmarkSample`]s with provenance info.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkDataset {
    /// Dataset name (e.g. "JailbreakBench", "Combined").
    pub name: String,
    /// Dataset version string.
    pub version: String,
    /// Human-readable description of what the dataset covers.
    pub description: String,
    /// The samples making up the dataset.
    pub samples: Vec<BenchmarkSample>,
    /// Source URL or reference (comma-joined dataset names for combined sets).
    pub source: String,
}

44impl BenchmarkDataset {
45    /// Create a new empty dataset
46    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
47        Self {
48            name: name.into(),
49            version: version.into(),
50            description: String::new(),
51            samples: Vec::new(),
52            source: String::new(),
53        }
54    }
55
56    /// Add description
57    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
58        self.description = desc.into();
59        self
60    }
61
62    /// Add source
63    pub fn with_source(mut self, source: impl Into<String>) -> Self {
64        self.source = source.into();
65        self
66    }
67
68    /// Add a sample
69    pub fn add_sample(mut self, sample: BenchmarkSample) -> Self {
70        self.samples.push(sample);
71        self
72    }
73
74    /// Add multiple samples
75    pub fn add_samples(mut self, samples: Vec<BenchmarkSample>) -> Self {
76        self.samples.extend(samples);
77        self
78    }
79
80    /// Get attack samples only
81    pub fn attack_samples(&self) -> Vec<&BenchmarkSample> {
82        self.samples.iter().filter(|s| s.is_attack).collect()
83    }
84
85    /// Get benign samples only
86    pub fn benign_samples(&self) -> Vec<&BenchmarkSample> {
87        self.samples.iter().filter(|s| !s.is_attack).collect()
88    }
89
90    /// Total sample count
91    pub fn len(&self) -> usize {
92        self.samples.len()
93    }
94
95    /// Check if empty
96    pub fn is_empty(&self) -> bool {
97        self.samples.is_empty()
98    }
99
100    /// Get samples by category
101    pub fn samples_by_category(&self, category: &str) -> Vec<&BenchmarkSample> {
102        self.samples
103            .iter()
104            .filter(|s| s.category.as_deref() == Some(category))
105            .collect()
106    }
107}
108
/// Create the OxideShield standard benchmark dataset.
///
/// Combines attack samples and benign samples from the adversarial module.
/// Requires the `adversarial-samples` feature.
#[cfg(feature = "adversarial-samples")]
pub fn oxideshield_benchmark_dataset() -> BenchmarkDataset {
    use crate::adversarial::{all_attack_samples, all_benign_samples};

    let mut dataset = BenchmarkDataset::new("OxideShield Standard", "1.0.0")
        .with_description("Standard benchmark dataset for OxideShield evaluation")
        .with_source("OxideShield internal");

    // Attack samples: category is the Debug form of the attack type.
    let attacks = all_attack_samples().into_iter().map(|s| BenchmarkSample {
        id: s.id,
        text: s.prompt,
        is_attack: true,
        category: Some(format!("{:?}", s.attack_type)),
        subcategory: Some(s.description),
        source: s.source,
        metadata: HashMap::new(),
    });
    dataset.samples.extend(attacks);

    // Benign samples: category is the Debug form of the benign category.
    let benign = all_benign_samples().into_iter().map(|s| BenchmarkSample {
        id: s.id,
        text: s.prompt,
        is_attack: false,
        category: Some(format!("{:?}", s.category)),
        subcategory: Some(s.description),
        source: "OxideShield benign".into(),
        metadata: HashMap::new(),
    });
    dataset.samples.extend(benign);

    dataset
}

/// Create a JailbreakBench-based dataset.
/// Requires the `adversarial-samples` feature.
#[cfg(feature = "adversarial-samples")]
pub fn jailbreakbench_dataset() -> BenchmarkDataset {
    use crate::adversarial::jailbreakbench_behaviors;

    let mut dataset = BenchmarkDataset::new("JailbreakBench", "1.0.0")
        .with_description("Behaviors from JailbreakBench (jailbreakbench.github.io)")
        .with_source("https://jailbreakbench.github.io/");

    // Every JailbreakBench behavior is labeled as an attack.
    let samples = jailbreakbench_behaviors().into_iter().map(|b| BenchmarkSample {
        id: b.id,
        text: b.prompt,
        is_attack: true,
        category: Some(b.category.to_string()),
        subcategory: Some(b.behavior),
        source: "JailbreakBench".into(),
        metadata: HashMap::new(),
    });
    dataset.samples.extend(samples);

    dataset
}

/// Create a prompt injection focused dataset.
/// Requires the `adversarial-samples` feature.
#[cfg(feature = "adversarial-samples")]
pub fn prompt_injection_dataset() -> BenchmarkDataset {
    use crate::adversarial::{prompt_injection_samples, system_prompt_leak_samples};

    let mut dataset = BenchmarkDataset::new("Prompt Injection", "1.0.0")
        .with_description("Focused dataset for prompt injection and system prompt leak attacks")
        .with_source("OxideShield curated");

    // Direct prompt-injection attacks.
    let injections = prompt_injection_samples().into_iter().map(|s| BenchmarkSample {
        id: s.id,
        text: s.prompt,
        is_attack: true,
        category: Some("PromptInjection".into()),
        subcategory: Some(s.description),
        source: s.source,
        metadata: HashMap::new(),
    });
    dataset.samples.extend(injections);

    // Attempts to exfiltrate the system prompt.
    let leaks = system_prompt_leak_samples().into_iter().map(|s| BenchmarkSample {
        id: s.id,
        text: s.prompt,
        is_attack: true,
        category: Some("SystemPromptLeak".into()),
        subcategory: Some(s.description),
        source: s.source,
        metadata: HashMap::new(),
    });
    dataset.samples.extend(leaks);

    dataset
}

/// Create an adversarial suffix focused dataset (AutoDAN, GCG).
/// Requires the `adversarial-samples` feature.
#[cfg(feature = "adversarial-samples")]
pub fn adversarial_suffix_dataset() -> BenchmarkDataset {
    use crate::adversarial::{autodan_samples, gcg_samples};

    let mut dataset = BenchmarkDataset::new("Adversarial Suffixes", "1.0.0")
        .with_description("AutoDAN and GCG style adversarial suffix attacks")
        .with_source("Research paper simulations");

    // AutoDAN-style suffix attacks.
    let autodan = autodan_samples().into_iter().map(|s| BenchmarkSample {
        id: s.id,
        text: s.prompt,
        is_attack: true,
        category: Some("AutoDAN".into()),
        subcategory: Some(s.description),
        source: s.source,
        metadata: HashMap::new(),
    });
    dataset.samples.extend(autodan);

    // GCG-style suffix attacks.
    let gcg = gcg_samples().into_iter().map(|s| BenchmarkSample {
        id: s.id,
        text: s.prompt,
        is_attack: true,
        category: Some("GCG".into()),
        subcategory: Some(s.description),
        source: s.source,
        metadata: HashMap::new(),
    });
    dataset.samples.extend(gcg);

    dataset
}

248/// Combine multiple datasets
249pub fn combined_dataset(datasets: Vec<BenchmarkDataset>) -> BenchmarkDataset {
250    let mut combined = BenchmarkDataset::new("Combined", "1.0.0")
251        .with_description("Combined dataset from multiple sources");
252
253    let sources: Vec<String> = datasets.iter().map(|d| d.name.clone()).collect();
254    combined.source = sources.join(", ");
255
256    for dataset in datasets {
257        for mut sample in dataset.samples {
258            // Prefix ID with dataset name to avoid collisions
259            sample.id = format!("{}_{}", dataset.name.replace(' ', "_"), sample.id);
260            combined.samples.push(sample);
261        }
262    }
263
264    combined
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_oxideshield_dataset() {
        let dataset = oxideshield_benchmark_dataset();
        assert!(!dataset.is_empty());
        // The standard dataset must contain both labels so that detection
        // rate AND false-positive rate can be measured.
        assert!(!dataset.attack_samples().is_empty());
        assert!(!dataset.benign_samples().is_empty());
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_jailbreakbench_dataset() {
        let dataset = jailbreakbench_dataset();
        assert!(!dataset.is_empty());
        // All JailbreakBench samples are attacks
        assert_eq!(dataset.attack_samples().len(), dataset.len());
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_prompt_injection_dataset() {
        let dataset = prompt_injection_dataset();
        assert!(dataset.len() >= 13); // 8 PI + 5 SL samples
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_adversarial_suffix_dataset() {
        let dataset = adversarial_suffix_dataset();
        assert!(dataset.len() >= 6); // 3 AutoDAN + 3 GCG
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_combined_dataset() {
        let d1 = prompt_injection_dataset();
        let d2 = adversarial_suffix_dataset();
        let combined = combined_dataset(vec![d1, d2]);
        assert!(combined.len() >= 19);
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_samples_by_category() {
        let dataset = oxideshield_benchmark_dataset();
        let pi_samples = dataset.samples_by_category("PromptInjection");
        assert!(!pi_samples.is_empty());
    }

    #[test]
    fn test_benchmark_dataset_new() {
        let dataset = BenchmarkDataset::new("Test", "1.0.0");
        assert!(dataset.is_empty());
        assert_eq!(dataset.len(), 0);
    }

    #[test]
    fn test_combined_empty_datasets() {
        // Edge case: combining nothing yields an empty "Combined" dataset.
        let combined = combined_dataset(vec![]);
        assert!(combined.is_empty());
    }
}