oxideshield_guard/benchmark/
competitor.rs1use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17
18use super::metrics::GuardMetrics;
19
/// Verdict produced by a [`CompetitorAdapter`] for a single piece of content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompetitorResult {
    /// Whether the competitor flagged the content.
    pub flagged: bool,
    /// Confidence reported for the verdict.
    /// NOTE(review): range is not enforced here — presumably 0.0..=1.0; confirm with adapters.
    pub confidence: f64,
    /// Category labels the competitor attached to the verdict.
    pub categories: Vec<String>,
    /// Unparsed competitor response, when the adapter chose to keep it.
    pub raw_response: Option<String>,
}
32
/// Common interface for plugging a competing guard into the comparison.
///
/// Implementors must be `Send + Sync` so adapters can be shared across threads.
pub trait CompetitorAdapter: Send + Sync {
    /// Display name of the competitor.
    fn name(&self) -> &str;

    /// Classifies `content` and returns the competitor's verdict.
    fn check(&self, content: &str) -> CompetitorResult;

    /// Whether this adapter reports fixed reference figures instead of running the
    /// competitor live.
    /// NOTE(review): meaning inferred from the `*Reference` types in this module; confirm.
    fn is_reference(&self) -> bool;
}
44
/// Hard-coded baseline figures for Llama Guard 3.
///
/// Only `latency_ms` is read by `to_guard_metrics`; the score fields are
/// informational and are not used to build the confusion matrix.
#[derive(Debug, Clone)]
pub struct LlamaGuard3Reference {
    /// F1 score as a fraction (default `0.94`).
    pub f1: f64,
    /// Precision as a fraction (default `0.96`).
    pub precision: f64,
    /// Recall as a fraction (default `0.92`).
    pub recall: f64,
    /// Per-request latency in milliseconds (default `105.0`).
    pub latency_ms: f64,
}
58
59impl Default for LlamaGuard3Reference {
60 fn default() -> Self {
61 Self {
62 f1: 0.94,
63 precision: 0.96,
64 recall: 0.92,
65 latency_ms: 105.0, }
67 }
68}
69
70impl LlamaGuard3Reference {
71 pub fn to_guard_metrics(&self) -> GuardMetrics {
73 let mut metrics = GuardMetrics::new("Llama Guard 3 (Reference)");
74 metrics.true_positives = 92; metrics.false_negatives = 8; metrics.false_positives = 4; metrics.true_negatives = 96; metrics.latencies_ms = vec![self.latency_ms; 100];
82 metrics
83 }
84}
85
/// Hard-coded baseline figures for LLM Guard.
///
/// Only `latency_ms` is read by `to_guard_metrics`; the score fields are
/// informational and are not used to build the confusion matrix.
#[derive(Debug, Clone)]
pub struct LLMGuardReference {
    /// F1 score as a fraction (default `0.90`).
    pub f1: f64,
    /// Precision as a fraction (default `0.92`).
    pub precision: f64,
    /// Recall as a fraction (default `0.88`).
    pub recall: f64,
    /// Per-request latency in milliseconds (default `52.0`).
    pub latency_ms: f64,
}
95
96impl Default for LLMGuardReference {
97 fn default() -> Self {
98 Self {
99 f1: 0.90,
100 precision: 0.92,
101 recall: 0.88,
102 latency_ms: 52.0,
103 }
104 }
105}
106
107impl LLMGuardReference {
108 pub fn to_guard_metrics(&self) -> GuardMetrics {
109 let mut metrics = GuardMetrics::new("LLM Guard (Reference)");
110 metrics.true_positives = 88;
111 metrics.false_negatives = 12;
112 metrics.false_positives = 8;
113 metrics.true_negatives = 92;
114 metrics.latencies_ms = vec![self.latency_ms; 100];
115 metrics
116 }
117}
118
/// Hard-coded baseline figures for Lakera Guard.
///
/// Only `latency_ms` is read by `to_guard_metrics`; the score fields are
/// informational and are not used to build the confusion matrix.
#[derive(Debug, Clone)]
pub struct LakeraGuardReference {
    /// F1 score as a fraction (default `0.89`).
    pub f1: f64,
    /// Precision as a fraction (default `0.91`).
    pub precision: f64,
    /// Recall as a fraction (default `0.87`).
    pub recall: f64,
    /// Per-request latency in milliseconds (default `66.0`).
    pub latency_ms: f64,
}
128
129impl Default for LakeraGuardReference {
130 fn default() -> Self {
131 Self {
132 f1: 0.89,
133 precision: 0.91,
134 recall: 0.87,
135 latency_ms: 66.0, }
137 }
138}
139
140impl LakeraGuardReference {
141 pub fn to_guard_metrics(&self) -> GuardMetrics {
142 let mut metrics = GuardMetrics::new("Lakera Guard (Reference)");
143 metrics.true_positives = 87;
144 metrics.false_negatives = 13;
145 metrics.false_positives = 9;
146 metrics.true_negatives = 91;
147 metrics.latencies_ms = vec![self.latency_ms; 100];
148 metrics
149 }
150}
151
/// Hard-coded baseline figures for NeMo Guardrails.
///
/// Only `latency_ms` is read by `to_guard_metrics`; the score fields are
/// informational and are not used to build the confusion matrix.
#[derive(Debug, Clone)]
pub struct NeMoGuardrailsReference {
    /// F1 score as a fraction (default `0.85`).
    pub f1: f64,
    /// Precision as a fraction (default `0.88`).
    pub precision: f64,
    /// Recall as a fraction (default `0.82`).
    pub recall: f64,
    /// Per-request latency in milliseconds (default `150.0`).
    pub latency_ms: f64,
}
160
161impl Default for NeMoGuardrailsReference {
162 fn default() -> Self {
163 Self {
164 f1: 0.85,
165 precision: 0.88,
166 recall: 0.82,
167 latency_ms: 150.0, }
169 }
170}
171
172impl NeMoGuardrailsReference {
173 pub fn to_guard_metrics(&self) -> GuardMetrics {
174 let mut metrics = GuardMetrics::new("NeMo Guardrails (Reference)");
175 metrics.true_positives = 82;
176 metrics.false_negatives = 18;
177 metrics.false_positives = 12;
178 metrics.true_negatives = 88;
179 metrics.latencies_ms = vec![self.latency_ms; 100];
180 metrics
181 }
182}
183
184pub fn all_competitor_references() -> Vec<GuardMetrics> {
186 vec![
187 LlamaGuard3Reference::default().to_guard_metrics(),
188 LLMGuardReference::default().to_guard_metrics(),
189 LakeraGuardReference::default().to_guard_metrics(),
190 NeMoGuardrailsReference::default().to_guard_metrics(),
191 ]
192}
193
/// Side-by-side comparison of OxideShield's metrics against competitor references.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompetitorComparison {
    /// Metrics measured for OxideShield.
    pub oxideshield: GuardMetrics,
    /// Competitor metrics, keyed by each metric set's `name` (see `compare`).
    pub competitors: HashMap<String, GuardMetrics>,
    /// Per-dimension win/loss/tie outcome for OxideShield against the best competitor.
    pub summary: ComparisonSummary,
}
204
/// Labels of compared dimensions (e.g. "F1 Score", "Latency (p50)", "Precision"),
/// bucketed by OxideShield's outcome against the best competitor.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ComparisonSummary {
    /// Dimensions where OxideShield beats the best competitor by more than the tie margin.
    pub wins: Vec<String>,
    /// Dimensions where OxideShield trails the best competitor by more than the tie margin.
    pub losses: Vec<String>,
    /// Dimensions within the tie margin.
    pub ties: Vec<String>,
}
215
216impl CompetitorComparison {
217 pub fn compare(oxideshield: GuardMetrics) -> Self {
219 let competitors: HashMap<String, GuardMetrics> = all_competitor_references()
220 .into_iter()
221 .map(|m| (m.name.clone(), m))
222 .collect();
223
224 let summary = Self::calculate_summary(&oxideshield, &competitors);
225
226 Self {
227 oxideshield,
228 competitors,
229 summary,
230 }
231 }
232
233 fn calculate_summary(
234 oxide: &GuardMetrics,
235 competitors: &HashMap<String, GuardMetrics>,
236 ) -> ComparisonSummary {
237 let mut summary = ComparisonSummary::default();
238
239 let best_competitor_f1 = competitors
241 .values()
242 .map(|m| m.f1_score())
243 .max_by(|a, b| a.partial_cmp(b).unwrap())
244 .unwrap_or(0.0);
245
246 if oxide.f1_score() > best_competitor_f1 + 0.01 {
247 summary.wins.push("F1 Score".into());
248 } else if oxide.f1_score() < best_competitor_f1 - 0.01 {
249 summary.losses.push("F1 Score".into());
250 } else {
251 summary.ties.push("F1 Score".into());
252 }
253
254 let best_competitor_latency = competitors
256 .values()
257 .map(|m| m.p50_latency_ms())
258 .min_by(|a, b| a.partial_cmp(b).unwrap())
259 .unwrap_or(f64::MAX);
260
261 if oxide.p50_latency_ms() < best_competitor_latency * 0.9 {
262 summary.wins.push("Latency (p50)".into());
263 } else if oxide.p50_latency_ms() > best_competitor_latency * 1.1 {
264 summary.losses.push("Latency (p50)".into());
265 } else {
266 summary.ties.push("Latency (p50)".into());
267 }
268
269 let best_competitor_precision = competitors
271 .values()
272 .map(|m| m.precision())
273 .max_by(|a, b| a.partial_cmp(b).unwrap())
274 .unwrap_or(0.0);
275
276 if oxide.precision() > best_competitor_precision + 0.01 {
277 summary.wins.push("Precision".into());
278 } else if oxide.precision() < best_competitor_precision - 0.01 {
279 summary.losses.push("Precision".into());
280 } else {
281 summary.ties.push("Precision".into());
282 }
283
284 summary
285 }
286
287 pub fn to_markdown(&self) -> String {
289 let mut md = String::new();
290
291 md.push_str("## OxideShield vs Competitors\n\n");
292 md.push_str("| Tool | F1 | Precision | Recall | p50 (ms) |\n");
293 md.push_str("|------|-----|-----------|--------|----------|\n");
294
295 md.push_str(&format!(
297 "| **OxideShield** | **{:.3}** | **{:.3}** | **{:.3}** | **{:.1}** |\n",
298 self.oxideshield.f1_score(),
299 self.oxideshield.precision(),
300 self.oxideshield.recall(),
301 self.oxideshield.p50_latency_ms()
302 ));
303
304 for (name, metrics) in &self.competitors {
306 md.push_str(&format!(
307 "| {} | {:.3} | {:.3} | {:.3} | {:.1} |\n",
308 name,
309 metrics.f1_score(),
310 metrics.precision(),
311 metrics.recall(),
312 metrics.p50_latency_ms()
313 ));
314 }
315
316 md.push_str("\n### Summary\n\n");
317 if !self.summary.wins.is_empty() {
318 md.push_str(&format!("- **Wins**: {}\n", self.summary.wins.join(", ")));
319 }
320 if !self.summary.losses.is_empty() {
321 md.push_str(&format!(
322 "- **Losses**: {}\n",
323 self.summary.losses.join(", ")
324 ));
325 }
326 if !self.summary.ties.is_empty() {
327 md.push_str(&format!("- **Ties**: {}\n", self.summary.ties.join(", ")));
328 }
329
330 md
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llama_guard_reference() {
        let llama = LlamaGuard3Reference::default().to_guard_metrics();

        assert!(llama.f1_score() > 0.9, "F1 was {}", llama.f1_score());
        assert!(llama.precision() > 0.9, "precision was {}", llama.precision());
    }

    #[test]
    fn test_all_references() {
        let references = all_competitor_references();
        assert_eq!(references.len(), 4);

        for reference in &references {
            assert!(reference.f1_score() > 0.8, "{}: F1 too low", reference.name);
            assert!(
                reference.p50_latency_ms() > 0.0,
                "{}: non-positive p50 latency",
                reference.name
            );
        }
    }

    #[test]
    fn test_competitor_comparison() {
        let mut candidate = GuardMetrics::new("OxideShield");
        candidate.true_positives = 95;
        candidate.false_negatives = 5;
        candidate.false_positives = 3;
        candidate.true_negatives = 97;
        candidate.latencies_ms = vec![15.0; 100];

        let comparison = CompetitorComparison::compare(candidate);
        assert!(!comparison.competitors.is_empty());

        let report = comparison.to_markdown();
        assert!(report.contains("OxideShield"));
        assert!(report.contains("Llama Guard"));
    }

    #[test]
    fn test_comparison_summary() {
        let mut candidate = GuardMetrics::new("OxideShield");
        candidate.true_positives = 98;
        candidate.false_negatives = 2;
        candidate.false_positives = 1;
        candidate.true_negatives = 99;
        candidate.latencies_ms = vec![5.0; 100];

        // 5 ms is far below 90% of the fastest reference, so latency must be a win.
        let wins = CompetitorComparison::compare(candidate).summary.wins;
        assert!(wins.iter().any(|w| w == "Latency (p50)"));
    }
}