Skip to main content

oxideshield_guard/benchmark/
competitor.rs

1//! Competitor Adapters
2//!
3//! Adapters for comparing OxideShield against competitor tools.
4//! Provides interfaces for both direct API calls and mock/reference implementations.
5//!
6//! ## Competitors
7//!
8//! | Tool | Integration Method |
9//! |------|-------------------|
10//! | Llama Guard 3 | Reference metrics (F1=0.94) |
11//! | LLM Guard | Python subprocess |
12//! | NeMo Guardrails | Reference metrics |
13//! | Lakera Guard | API adapter |
14
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17
18use super::metrics::GuardMetrics;
19
/// Result from a competitor check
///
/// This is the value returned by [`CompetitorAdapter::check`]. It derives
/// `Serialize`/`Deserialize`, so it can be written to and read back from any
/// serde-supported format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompetitorResult {
    /// Whether content was flagged
    pub flagged: bool,
    /// Confidence score (0-1)
    pub confidence: f64,
    /// Categories triggered
    pub categories: Vec<String>,
    /// Raw response (if available); `None` when the adapter has nothing to report
    pub raw_response: Option<String>,
}
32
/// Trait for competitor adapters
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
/// `check` has no error channel: it always returns a [`CompetitorResult`].
pub trait CompetitorAdapter: Send + Sync {
    /// Get the competitor name
    fn name(&self) -> &str;

    /// Check content
    ///
    /// Takes the text to evaluate and returns the competitor's verdict.
    fn check(&self, content: &str) -> CompetitorResult;

    /// Check if this is a reference implementation (not real API)
    fn is_reference(&self) -> bool;
}
44
/// Reference figures for Llama Guard 3.
///
/// Based on published benchmarks from Meta AI; nothing here is measured by
/// this crate.
#[derive(Debug, Clone)]
pub struct LlamaGuard3Reference {
    /// F1 score quoted for the tool
    pub f1: f64,
    /// Precision quoted for the tool
    pub precision: f64,
    /// Recall quoted for the tool
    pub recall: f64,
    /// Average latency, in milliseconds
    pub latency_ms: f64,
}

impl Default for LlamaGuard3Reference {
    /// Returns the reference figures: F1 0.94, precision 0.96, recall 0.92,
    /// latency 105 ms.
    fn default() -> Self {
        let latency_ms = 105.0; // ~100ms typical
        LlamaGuard3Reference {
            latency_ms,
            recall: 0.92,
            precision: 0.96,
            f1: 0.94,
        }
    }
}
69
70impl LlamaGuard3Reference {
71    /// Create reference metrics object
72    pub fn to_guard_metrics(&self) -> GuardMetrics {
73        let mut metrics = GuardMetrics::new("Llama Guard 3 (Reference)");
74        // Simulate metrics based on reference values
75        // Assuming 100 samples: precision = TP/(TP+FP), recall = TP/(TP+FN)
76        // With precision=0.96, recall=0.92, we can derive:
77        metrics.true_positives = 92; // recall * 100 attacks
78        metrics.false_negatives = 8; // attacks - TP
79        metrics.false_positives = 4; // TP/precision - TP ≈ 4
80        metrics.true_negatives = 96; // benign - FP
81        metrics.latencies_ms = vec![self.latency_ms; 100];
82        metrics
83    }
84}
85
/// Reference figures for LLM Guard.
///
/// Based on published benchmarks from ProtectAI; nothing here is measured by
/// this crate.
#[derive(Debug, Clone)]
pub struct LLMGuardReference {
    /// F1 score quoted for the tool
    pub f1: f64,
    /// Precision quoted for the tool
    pub precision: f64,
    /// Recall quoted for the tool
    pub recall: f64,
    /// Latency, in milliseconds
    pub latency_ms: f64,
}

impl Default for LLMGuardReference {
    /// Returns the reference figures: F1 0.90, precision 0.92, recall 0.88,
    /// latency 52 ms.
    fn default() -> Self {
        let latency_ms = 52.0;
        LLMGuardReference {
            latency_ms,
            recall: 0.88,
            precision: 0.92,
            f1: 0.90,
        }
    }
}
106
107impl LLMGuardReference {
108    pub fn to_guard_metrics(&self) -> GuardMetrics {
109        let mut metrics = GuardMetrics::new("LLM Guard (Reference)");
110        metrics.true_positives = 88;
111        metrics.false_negatives = 12;
112        metrics.false_positives = 8;
113        metrics.true_negatives = 92;
114        metrics.latencies_ms = vec![self.latency_ms; 100];
115        metrics
116    }
117}
118
/// Reference figures for Lakera Guard.
///
/// Based on published API benchmarks; nothing here is measured by this crate.
#[derive(Debug, Clone)]
pub struct LakeraGuardReference {
    /// F1 score quoted for the tool
    pub f1: f64,
    /// Precision quoted for the tool
    pub precision: f64,
    /// Recall quoted for the tool
    pub recall: f64,
    /// Latency, in milliseconds
    pub latency_ms: f64,
}

impl Default for LakeraGuardReference {
    /// Returns the reference figures: F1 0.89, precision 0.91, recall 0.87,
    /// latency 66 ms.
    fn default() -> Self {
        let latency_ms = 66.0; // Advertised ~66ms
        LakeraGuardReference {
            latency_ms,
            recall: 0.87,
            precision: 0.91,
            f1: 0.89,
        }
    }
}
139
140impl LakeraGuardReference {
141    pub fn to_guard_metrics(&self) -> GuardMetrics {
142        let mut metrics = GuardMetrics::new("Lakera Guard (Reference)");
143        metrics.true_positives = 87;
144        metrics.false_negatives = 13;
145        metrics.false_positives = 9;
146        metrics.true_negatives = 91;
147        metrics.latencies_ms = vec![self.latency_ms; 100];
148        metrics
149    }
150}
151
/// Reference figures for NeMo Guardrails.
///
/// Nothing here is measured by this crate.
#[derive(Debug, Clone)]
pub struct NeMoGuardrailsReference {
    /// F1 score quoted for the tool
    pub f1: f64,
    /// Precision quoted for the tool
    pub precision: f64,
    /// Recall quoted for the tool
    pub recall: f64,
    /// Latency, in milliseconds
    pub latency_ms: f64,
}

impl Default for NeMoGuardrailsReference {
    /// Returns the reference figures: F1 0.85, precision 0.88, recall 0.82,
    /// latency 150 ms.
    fn default() -> Self {
        let latency_ms = 150.0; // Embedding + KNN lookup
        NeMoGuardrailsReference {
            latency_ms,
            recall: 0.82,
            precision: 0.88,
            f1: 0.85,
        }
    }
}
171
172impl NeMoGuardrailsReference {
173    pub fn to_guard_metrics(&self) -> GuardMetrics {
174        let mut metrics = GuardMetrics::new("NeMo Guardrails (Reference)");
175        metrics.true_positives = 82;
176        metrics.false_negatives = 18;
177        metrics.false_positives = 12;
178        metrics.true_negatives = 88;
179        metrics.latencies_ms = vec![self.latency_ms; 100];
180        metrics
181    }
182}
183
184/// Get all reference competitor metrics
185pub fn all_competitor_references() -> Vec<GuardMetrics> {
186    vec![
187        LlamaGuard3Reference::default().to_guard_metrics(),
188        LLMGuardReference::default().to_guard_metrics(),
189        LakeraGuardReference::default().to_guard_metrics(),
190        NeMoGuardrailsReference::default().to_guard_metrics(),
191    ]
192}
193
/// Comparison table for competitor metrics
///
/// Built by `CompetitorComparison::compare`, which fills `competitors` from
/// `all_competitor_references()` and derives `summary` from the two metric
/// sets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompetitorComparison {
    /// OxideShield metrics
    pub oxideshield: GuardMetrics,
    /// Competitor reference metrics, keyed by each entry's `name`
    pub competitors: HashMap<String, GuardMetrics>,
    /// Win/loss/tie summary
    pub summary: ComparisonSummary,
}
204
/// Summary of wins/losses/ties
///
/// Each vector holds human-readable metric labels (for example `"F1 Score"`
/// or `"Latency (p50)"`). The `Default` value has all three vectors empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ComparisonSummary {
    /// Metrics where OxideShield wins
    pub wins: Vec<String>,
    /// Metrics where OxideShield loses
    pub losses: Vec<String>,
    /// Metrics that are tied
    pub ties: Vec<String>,
}
215
216impl CompetitorComparison {
217    /// Create a comparison between OxideShield and competitors
218    pub fn compare(oxideshield: GuardMetrics) -> Self {
219        let competitors: HashMap<String, GuardMetrics> = all_competitor_references()
220            .into_iter()
221            .map(|m| (m.name.clone(), m))
222            .collect();
223
224        let summary = Self::calculate_summary(&oxideshield, &competitors);
225
226        Self {
227            oxideshield,
228            competitors,
229            summary,
230        }
231    }
232
233    fn calculate_summary(
234        oxide: &GuardMetrics,
235        competitors: &HashMap<String, GuardMetrics>,
236    ) -> ComparisonSummary {
237        let mut summary = ComparisonSummary::default();
238
239        // Compare F1
240        let best_competitor_f1 = competitors
241            .values()
242            .map(|m| m.f1_score())
243            .max_by(|a, b| a.partial_cmp(b).unwrap())
244            .unwrap_or(0.0);
245
246        if oxide.f1_score() > best_competitor_f1 + 0.01 {
247            summary.wins.push("F1 Score".into());
248        } else if oxide.f1_score() < best_competitor_f1 - 0.01 {
249            summary.losses.push("F1 Score".into());
250        } else {
251            summary.ties.push("F1 Score".into());
252        }
253
254        // Compare latency (lower is better)
255        let best_competitor_latency = competitors
256            .values()
257            .map(|m| m.p50_latency_ms())
258            .min_by(|a, b| a.partial_cmp(b).unwrap())
259            .unwrap_or(f64::MAX);
260
261        if oxide.p50_latency_ms() < best_competitor_latency * 0.9 {
262            summary.wins.push("Latency (p50)".into());
263        } else if oxide.p50_latency_ms() > best_competitor_latency * 1.1 {
264            summary.losses.push("Latency (p50)".into());
265        } else {
266            summary.ties.push("Latency (p50)".into());
267        }
268
269        // Compare precision
270        let best_competitor_precision = competitors
271            .values()
272            .map(|m| m.precision())
273            .max_by(|a, b| a.partial_cmp(b).unwrap())
274            .unwrap_or(0.0);
275
276        if oxide.precision() > best_competitor_precision + 0.01 {
277            summary.wins.push("Precision".into());
278        } else if oxide.precision() < best_competitor_precision - 0.01 {
279            summary.losses.push("Precision".into());
280        } else {
281            summary.ties.push("Precision".into());
282        }
283
284        summary
285    }
286
287    /// Generate markdown comparison table
288    pub fn to_markdown(&self) -> String {
289        let mut md = String::new();
290
291        md.push_str("## OxideShield vs Competitors\n\n");
292        md.push_str("| Tool | F1 | Precision | Recall | p50 (ms) |\n");
293        md.push_str("|------|-----|-----------|--------|----------|\n");
294
295        // OxideShield first
296        md.push_str(&format!(
297            "| **OxideShield** | **{:.3}** | **{:.3}** | **{:.3}** | **{:.1}** |\n",
298            self.oxideshield.f1_score(),
299            self.oxideshield.precision(),
300            self.oxideshield.recall(),
301            self.oxideshield.p50_latency_ms()
302        ));
303
304        // Competitors
305        for (name, metrics) in &self.competitors {
306            md.push_str(&format!(
307                "| {} | {:.3} | {:.3} | {:.3} | {:.1} |\n",
308                name,
309                metrics.f1_score(),
310                metrics.precision(),
311                metrics.recall(),
312                metrics.p50_latency_ms()
313            ));
314        }
315
316        md.push_str("\n### Summary\n\n");
317        if !self.summary.wins.is_empty() {
318            md.push_str(&format!("- **Wins**: {}\n", self.summary.wins.join(", ")));
319        }
320        if !self.summary.losses.is_empty() {
321            md.push_str(&format!(
322                "- **Losses**: {}\n",
323                self.summary.losses.join(", ")
324            ));
325        }
326        if !self.summary.ties.is_empty() {
327            md.push_str(&format!("- **Ties**: {}\n", self.summary.ties.join(", ")));
328        }
329
330        md
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// The Llama Guard 3 reference counts must yield F1 and precision above 0.9.
    #[test]
    fn test_llama_guard_reference() {
        let metrics = LlamaGuard3Reference::default().to_guard_metrics();

        assert!(metrics.f1_score() > 0.9);
        assert!(metrics.precision() > 0.9);
    }

    /// All four references are present and each has plausible figures.
    #[test]
    fn test_all_references() {
        let references = all_competitor_references();
        assert_eq!(references.len(), 4);

        for reference in &references {
            assert!(reference.f1_score() > 0.8);
            assert!(reference.p50_latency_ms() > 0.0);
        }
    }

    /// A comparison carries competitor data and renders both sides in markdown.
    #[test]
    fn test_competitor_comparison() {
        let mut ours = GuardMetrics::new("OxideShield");
        ours.latencies_ms = vec![15.0; 100];
        ours.true_positives = 95;
        ours.false_negatives = 5;
        ours.false_positives = 3;
        ours.true_negatives = 97;

        let comparison = CompetitorComparison::compare(ours);
        assert!(!comparison.competitors.is_empty());

        let rendered = comparison.to_markdown();
        assert!(rendered.contains("OxideShield"));
        assert!(rendered.contains("Llama Guard"));
    }

    /// With a 5 ms p50, OxideShield must be recorded as the latency winner.
    #[test]
    fn test_comparison_summary() {
        // Set metrics to be better than all competitors
        let mut ours = GuardMetrics::new("OxideShield");
        ours.latencies_ms = vec![5.0; 100]; // Much faster
        ours.true_positives = 98;
        ours.false_negatives = 2;
        ours.false_positives = 1;
        ours.true_negatives = 99;

        let comparison = CompetitorComparison::compare(ours);

        // Should win on latency
        let latency_won = comparison
            .summary
            .wins
            .iter()
            .any(|metric| metric == "Latency (p50)");
        assert!(latency_won);
    }
}
393}