oxideshield_guard/benchmark/
datasets.rs1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct BenchmarkSample {
13 pub id: String,
15 pub text: String,
17 pub is_attack: bool,
19 pub category: Option<String>,
21 pub subcategory: Option<String>,
23 pub source: String,
25 pub metadata: HashMap<String, String>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct BenchmarkDataset {
32 pub name: String,
34 pub version: String,
36 pub description: String,
38 pub samples: Vec<BenchmarkSample>,
40 pub source: String,
42}
43
44impl BenchmarkDataset {
45 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
47 Self {
48 name: name.into(),
49 version: version.into(),
50 description: String::new(),
51 samples: Vec::new(),
52 source: String::new(),
53 }
54 }
55
56 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
58 self.description = desc.into();
59 self
60 }
61
62 pub fn with_source(mut self, source: impl Into<String>) -> Self {
64 self.source = source.into();
65 self
66 }
67
68 pub fn add_sample(mut self, sample: BenchmarkSample) -> Self {
70 self.samples.push(sample);
71 self
72 }
73
74 pub fn add_samples(mut self, samples: Vec<BenchmarkSample>) -> Self {
76 self.samples.extend(samples);
77 self
78 }
79
80 pub fn attack_samples(&self) -> Vec<&BenchmarkSample> {
82 self.samples.iter().filter(|s| s.is_attack).collect()
83 }
84
85 pub fn benign_samples(&self) -> Vec<&BenchmarkSample> {
87 self.samples.iter().filter(|s| !s.is_attack).collect()
88 }
89
90 pub fn len(&self) -> usize {
92 self.samples.len()
93 }
94
95 pub fn is_empty(&self) -> bool {
97 self.samples.is_empty()
98 }
99
100 pub fn samples_by_category(&self, category: &str) -> Vec<&BenchmarkSample> {
102 self.samples
103 .iter()
104 .filter(|s| s.category.as_deref() == Some(category))
105 .collect()
106 }
107}
108
/// Builds the standard OxideShield benchmark dataset.
///
/// The dataset holds every sample from `all_attack_samples()` followed by
/// every sample from `all_benign_samples()`:
/// - Attack samples are labelled as attacks and categorised by the `Debug`
///   rendering of their attack type.
/// - Benign samples are labelled benign and categorised by the `Debug`
///   rendering of their category.
///
/// For both kinds, the sample description becomes the subcategory.
#[cfg(feature = "adversarial-samples")]
pub fn oxideshield_benchmark_dataset() -> BenchmarkDataset {
    use crate::adversarial::{all_attack_samples, all_benign_samples};

    let mut benchmark = BenchmarkDataset::new("OxideShield Standard", "1.0.0")
        .with_description("Standard benchmark dataset for OxideShield evaluation")
        .with_source("OxideShield internal");

    // Attacks go in first, then the benign prompts, each in table order.
    let attacks = all_attack_samples().into_iter().map(|attack| BenchmarkSample {
        id: attack.id,
        text: attack.prompt,
        is_attack: true,
        category: Some(format!("{:?}", attack.attack_type)),
        subcategory: Some(attack.description),
        source: attack.source,
        metadata: HashMap::new(),
    });
    benchmark.samples.extend(attacks);

    let benign = all_benign_samples().into_iter().map(|prompt| BenchmarkSample {
        id: prompt.id,
        text: prompt.prompt,
        is_attack: false,
        category: Some(format!("{:?}", prompt.category)),
        subcategory: Some(prompt.description),
        source: String::from("OxideShield benign"),
        metadata: HashMap::new(),
    });
    benchmark.samples.extend(benign);

    benchmark
}

/// Builds a dataset from the JailbreakBench behaviors.
///
/// Every behavior is labelled as an attack. Its category is the behavior's
/// category rendered with `to_string()`, and its subcategory is the behavior
/// name.
#[cfg(feature = "adversarial-samples")]
pub fn jailbreakbench_dataset() -> BenchmarkDataset {
    use crate::adversarial::jailbreakbench_behaviors;

    let mut jbb = BenchmarkDataset::new("JailbreakBench", "1.0.0")
        .with_description("Behaviors from JailbreakBench (jailbreakbench.github.io)")
        .with_source("https://jailbreakbench.github.io/");

    let converted = jailbreakbench_behaviors().into_iter().map(|entry| BenchmarkSample {
        id: entry.id,
        text: entry.prompt,
        is_attack: true,
        category: Some(entry.category.to_string()),
        subcategory: Some(entry.behavior),
        source: String::from("JailbreakBench"),
        metadata: HashMap::new(),
    });
    jbb.samples.extend(converted);

    jbb
}

/// Builds a focused dataset of prompt-injection and system-prompt-leak attacks.
///
/// Prompt-injection samples come first, tagged `"PromptInjection"`. The
/// system-prompt-leak samples follow, tagged `"SystemPromptLeak"`. All are
/// attacks, and each sample's description becomes its subcategory.
#[cfg(feature = "adversarial-samples")]
pub fn prompt_injection_dataset() -> BenchmarkDataset {
    use crate::adversarial::{prompt_injection_samples, system_prompt_leak_samples};

    let mut focused = BenchmarkDataset::new("Prompt Injection", "1.0.0")
        .with_description("Focused dataset for prompt injection and system prompt leak attacks")
        .with_source("OxideShield curated");

    let injections = prompt_injection_samples().into_iter().map(|item| BenchmarkSample {
        id: item.id,
        text: item.prompt,
        is_attack: true,
        category: Some(String::from("PromptInjection")),
        subcategory: Some(item.description),
        source: item.source,
        metadata: HashMap::new(),
    });
    focused.samples.extend(injections);

    let leaks = system_prompt_leak_samples().into_iter().map(|item| BenchmarkSample {
        id: item.id,
        text: item.prompt,
        is_attack: true,
        category: Some(String::from("SystemPromptLeak")),
        subcategory: Some(item.description),
        source: item.source,
        metadata: HashMap::new(),
    });
    focused.samples.extend(leaks);

    focused
}

/// Builds a dataset of adversarial-suffix style attacks.
///
/// AutoDAN samples come first, tagged `"AutoDAN"`. GCG samples follow,
/// tagged `"GCG"`. All are attacks, and each sample's description becomes its
/// subcategory.
#[cfg(feature = "adversarial-samples")]
pub fn adversarial_suffix_dataset() -> BenchmarkDataset {
    use crate::adversarial::{autodan_samples, gcg_samples};

    let mut suffixes = BenchmarkDataset::new("Adversarial Suffixes", "1.0.0")
        .with_description("AutoDAN and GCG style adversarial suffix attacks")
        .with_source("Research paper simulations");

    let autodan = autodan_samples().into_iter().map(|item| BenchmarkSample {
        id: item.id,
        text: item.prompt,
        is_attack: true,
        category: Some(String::from("AutoDAN")),
        subcategory: Some(item.description),
        source: item.source,
        metadata: HashMap::new(),
    });
    suffixes.samples.extend(autodan);

    let gcg = gcg_samples().into_iter().map(|item| BenchmarkSample {
        id: item.id,
        text: item.prompt,
        is_attack: true,
        category: Some(String::from("GCG")),
        subcategory: Some(item.description),
        source: item.source,
        metadata: HashMap::new(),
    });
    suffixes.samples.extend(gcg);

    suffixes
}

248pub fn combined_dataset(datasets: Vec<BenchmarkDataset>) -> BenchmarkDataset {
250 let mut combined = BenchmarkDataset::new("Combined", "1.0.0")
251 .with_description("Combined dataset from multiple sources");
252
253 let sources: Vec<String> = datasets.iter().map(|d| d.name.clone()).collect();
254 combined.source = sources.join(", ");
255
256 for dataset in datasets {
257 for mut sample in dataset.samples {
258 sample.id = format!("{}_{}", dataset.name.replace(' ', "_"), sample.id);
260 combined.samples.push(sample);
261 }
262 }
263
264 combined
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal sample for the feature-independent tests below.
    fn sample(id: &str, is_attack: bool, category: Option<&str>) -> BenchmarkSample {
        BenchmarkSample {
            id: id.to_string(),
            text: format!("text for {}", id),
            is_attack,
            category: category.map(str::to_string),
            subcategory: None,
            source: "unit-test".to_string(),
            metadata: HashMap::new(),
        }
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_oxideshield_dataset() {
        let dataset = oxideshield_benchmark_dataset();
        assert!(!dataset.is_empty());
        assert!(!dataset.attack_samples().is_empty());
        assert!(!dataset.benign_samples().is_empty());
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_jailbreakbench_dataset() {
        let dataset = jailbreakbench_dataset();
        assert!(!dataset.is_empty());
        // Every JailbreakBench behavior is labelled as an attack.
        assert_eq!(dataset.attack_samples().len(), dataset.len());
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_prompt_injection_dataset() {
        let dataset = prompt_injection_dataset();
        assert!(dataset.len() >= 13);
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_adversarial_suffix_dataset() {
        let dataset = adversarial_suffix_dataset();
        assert!(dataset.len() >= 6);
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_combined_dataset() {
        let d1 = prompt_injection_dataset();
        let d2 = adversarial_suffix_dataset();
        let expected = d1.len() + d2.len();
        let combined = combined_dataset(vec![d1, d2]);
        // Combining neither drops nor duplicates samples.
        assert_eq!(combined.len(), expected);
        assert!(combined.len() >= 19);
    }

    #[cfg(feature = "adversarial-samples")]
    #[test]
    fn test_samples_by_category() {
        let dataset = oxideshield_benchmark_dataset();
        let pi_samples = dataset.samples_by_category("PromptInjection");
        assert!(!pi_samples.is_empty());
    }

    #[test]
    fn test_benchmark_dataset_new() {
        let dataset = BenchmarkDataset::new("Test", "1.0.0");
        assert!(dataset.is_empty());
        assert_eq!(dataset.len(), 0);
    }

    #[test]
    fn test_builder_and_filters() {
        let dataset = BenchmarkDataset::new("Test", "1.0.0")
            .with_description("desc")
            .with_source("src")
            .add_sample(sample("a", true, Some("PromptInjection")))
            .add_samples(vec![
                sample("b", false, None),
                sample("c", true, Some("GCG")),
            ]);
        assert_eq!(dataset.description, "desc");
        assert_eq!(dataset.source, "src");
        assert_eq!(dataset.len(), 3);
        assert_eq!(dataset.attack_samples().len(), 2);
        assert_eq!(dataset.benign_samples().len(), 1);

        let pi = dataset.samples_by_category("PromptInjection");
        assert_eq!(pi.len(), 1);
        assert_eq!(pi[0].id, "a");
        assert!(dataset.samples_by_category("Missing").is_empty());
    }

    #[test]
    fn test_combined_prefixes_ids() {
        let first = BenchmarkDataset::new("Prompt Injection", "1.0.0")
            .add_sample(sample("x", true, None));
        let second = BenchmarkDataset::new("GCG", "1.0.0").add_sample(sample("y", true, None));
        let combined = combined_dataset(vec![first, second]);
        assert_eq!(combined.source, "Prompt Injection, GCG");
        let ids: Vec<&str> = combined.samples.iter().map(|s| s.id.as_str()).collect();
        assert_eq!(ids, ["Prompt_Injection_x", "GCG_y"]);
    }

    #[test]
    fn test_combined_empty_datasets() {
        let combined = combined_dataset(vec![]);
        assert!(combined.is_empty());
        assert_eq!(combined.name, "Combined");
        assert!(combined.source.is_empty());
    }
}