Skip to main content

oxideshield_guard/
licensed.rs

1//! Licensed guard infrastructure for OxideShield.
2//!
3//! This module provides the infrastructure for license-gated guards,
4//! including the `LicensedGuard` trait and guard registry system.
5//!
6//! ## Architecture
7//!
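//! Commercial guards implement `LicensedGuard`, which extends the base `Guard`
//! trait with the license feature and tier a guard requires. The license check
//! itself is performed through a `LicenseValidator` (see
//! `LicenseCheckedGuard::validate_license` and `GuardRegistry::create_licensed`).
//! The `GuardRegistry` allows dynamic registration of guards, enabling
//! plugin-style commercial extensions.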
11//!
12//! ## License Tiers
13//!
//! - **Community**: Basic guards (Pattern, Length, Encoding, Perplexity, PII, Toxicity, StructuredOutput)
15//! - **Professional**: + SemanticSimilarityGuard, MLClassifierGuard, Scanner, MultiLayerDefense
16//! - **Enterprise**: + Custom guards, private models, dashboard access
17//!
18//! ## Usage
19//!
//! ```rust,ignore
//! use oxideshield_guard::licensed::{GuardFactoryConfig, GuardRegistry, LicensedGuard};
//! use oxide_license::{LicenseValidator, Feature};
//!
//! // Create a registry (registering factories needs mutable access)
//! let mut registry = GuardRegistry::new();
//!
//! // Register a licensed guard factory (from a commercial module)
//! registry.register("semantic", SemanticGuardFactory::new());
//!
//! // Create the guard, validating the license first
//! let validator = LicenseValidator::new()?;
//! let config = GuardFactoryConfig::new("semantic");
//! let guard = registry.create_licensed("semantic", &config, &validator).await?;
//! ```
34
35use crate::guard::{Guard, GuardCheckResult};
36use std::collections::HashMap;
37use std::sync::Arc;
38
39use oxide_license::{Feature, LicenseError, LicenseValidator};
40
41/// Error types for licensed guard operations.
42#[derive(Debug, thiserror::Error)]
43pub enum LicensedGuardError {
44    /// The required feature is not licensed.
45    #[error("Feature '{feature}' requires {required_tier} license (current: {current_tier})")]
46    FeatureNotLicensed {
47        feature: String,
48        required_tier: String,
49        current_tier: String,
50    },
51
52    /// License validation failed.
53    #[error("License validation failed: {0}")]
54    ValidationFailed(String),
55
56    /// Guard not found in registry.
57    #[error("Guard '{0}' not found in registry")]
58    GuardNotFound(String),
59
60    /// Guard creation failed.
61    #[error("Guard creation failed: {0}")]
62    CreationFailed(String),
63
64    /// Configuration error.
65    #[error("Configuration error: {0}")]
66    ConfigError(String),
67}
68
69impl From<LicenseError> for LicensedGuardError {
70    fn from(err: LicenseError) -> Self {
71        match err {
72            LicenseError::FeatureNotLicensed {
73                feature,
74                required_tier,
75                current_tier,
76            } => LicensedGuardError::FeatureNotLicensed {
77                feature,
78                required_tier,
79                current_tier,
80            },
81            _ => LicensedGuardError::ValidationFailed(err.to_string()),
82        }
83    }
84}
85
86/// Result type for licensed guard operations.
87pub type LicensedGuardResult<T> = std::result::Result<T, LicensedGuardError>;
88
89/// Guard that requires license validation.
90///
91/// This trait extends the base `Guard` trait with license requirements.
92/// Commercial guards implement this trait to enable license-gated access.
93pub trait LicensedGuard: Guard {
94    /// Returns the feature required to use this guard.
95    fn required_feature(&self) -> &str;
96
97    /// Returns the license tier required for this guard.
98    fn required_tier(&self) -> &str;
99
100    /// Returns true if this guard is available in the Community tier.
101    fn is_community(&self) -> bool {
102        false
103    }
104}
105
106/// Wrapper for a guard that performs license checks.
107///
108/// This wrapper validates the license before each guard check,
109/// returning an error if the required feature is not licensed.
110pub struct LicenseCheckedGuard<G: Guard> {
111    inner: G,
112    feature_name: String,
113    tier_name: String,
114    validator: Arc<LicenseValidator>,
115}
116
117impl<G: Guard> LicenseCheckedGuard<G> {
118    /// Creates a new license-checked guard wrapper.
119    pub fn new(
120        inner: G,
121        feature_name: impl Into<String>,
122        tier_name: impl Into<String>,
123        validator: Arc<LicenseValidator>,
124    ) -> Self {
125        Self {
126            inner,
127            feature_name: feature_name.into(),
128            tier_name: tier_name.into(),
129            validator,
130        }
131    }
132
133    /// Returns the inner guard.
134    pub fn inner(&self) -> &G {
135        &self.inner
136    }
137
138    /// Returns the feature name.
139    pub fn feature_name(&self) -> &str {
140        &self.feature_name
141    }
142
143    /// Returns the tier name.
144    pub fn tier_name(&self) -> &str {
145        &self.tier_name
146    }
147
148    /// Validates the license for this guard.
149    pub async fn validate_license(&self) -> LicensedGuardResult<()> {
150        let feature = match self.feature_name.as_str() {
151            "semantic_guard" => Feature::SemanticGuard,
152            "ml_classifier" => Feature::MlClassifier,
153            "scanner" | "advanced_probes" => Feature::Scanner,
154            "multi_layer_defense" => Feature::MultiLayerDefense,
155            "telemetry" => Feature::Telemetry,
156            "bundled_embeddings" => Feature::BundledEmbeddings,
157            "compliance_reports" => Feature::ComplianceReports,
158            "threat_intel" => Feature::ThreatIntel,
159            "dashboard" => Feature::Dashboard,
160            "api_access" => Feature::ApiAccess,
161            "proxy_basic" | "proxy_gateway" => Feature::ProxyBasic,
162            "webhook_alerts" => Feature::WebhookAlerts,
163            "rate_limiting" => Feature::RateLimiting,
164            "streaming_guards" => Feature::StreamingGuards,
165            "custom_guards" => Feature::CustomGuards,
166            "private_models" => Feature::PrivateModels,
167            "sso_saml" => Feature::SsoSaml,
168            "embedding_privacy" => Feature::EmbeddingPrivacy,
169            _ => {
170                // Unknown feature - allow for community features
171                return Ok(());
172            }
173        };
174
175        self.validator.require_feature(feature).await?;
176        Ok(())
177    }
178}
179
180impl<G: Guard> Guard for LicenseCheckedGuard<G> {
181    fn name(&self) -> &str {
182        self.inner.name()
183    }
184
185    fn check(&self, content: &str) -> GuardCheckResult {
186        // Note: License validation should be done async before calling check()
187        // This sync method assumes validation was already performed
188        self.inner.check(content)
189    }
190
191    fn action(&self) -> crate::guard::GuardAction {
192        self.inner.action()
193    }
194
195    fn severity_threshold(&self) -> oxideshield_core::Severity {
196        self.inner.severity_threshold()
197    }
198}
199
200impl<G: Guard> LicensedGuard for LicenseCheckedGuard<G> {
201    fn required_feature(&self) -> &str {
202        &self.feature_name
203    }
204
205    fn required_tier(&self) -> &str {
206        &self.tier_name
207    }
208}
209
210/// Factory trait for creating guards.
211///
212/// This trait enables dynamic guard creation through the registry.
213/// Each guard type provides a factory that can create instances.
214#[async_trait::async_trait]
215pub trait GuardFactory: Send + Sync {
216    /// The type of guard this factory creates.
217    type Guard: Guard + 'static;
218
219    /// Creates a new guard instance.
220    async fn create(&self, config: &GuardFactoryConfig) -> LicensedGuardResult<Self::Guard>;
221
222    /// Returns the name of the guard type.
223    fn guard_type(&self) -> &str;
224
225    /// Returns the required feature for this guard.
226    fn required_feature(&self) -> Option<&str>;
227
228    /// Returns the required tier for this guard.
229    fn required_tier(&self) -> &str;
230}
231
/// Options handed to a guard factory when it builds a guard.
#[derive(Clone, Debug, Default)]
pub struct GuardFactoryConfig {
    /// Name to give the created guard.
    pub name: String,
    /// Free-form string parameters interpreted by the factory.
    pub params: HashMap<String, String>,
    /// When `true`, the registry's licensed creation path skips the license check.
    pub skip_license_check: bool,
}

impl GuardFactoryConfig {
    /// Builds a config with the given name, no parameters, and license
    /// checking enabled.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..Self::default()
        }
    }

    /// Adds (or overwrites) the parameter `key` with `value`.
    pub fn with_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.params.insert(key, value);
        self
    }

    /// Turns license-check skipping on or off.
    pub fn skip_license_check(self, skip: bool) -> Self {
        Self {
            skip_license_check: skip,
            ..self
        }
    }
}
265
266/// Registry for dynamically registering and creating guards.
267///
268/// The registry enables a plugin architecture where commercial
269/// guards can be registered at runtime.
270pub struct GuardRegistry {
271    factories: HashMap<String, Arc<dyn DynGuardFactory>>,
272}
273
274/// Type-erased guard factory for the registry.
275#[async_trait::async_trait]
276pub trait DynGuardFactory: Send + Sync {
277    /// Creates a boxed guard.
278    async fn create_boxed(
279        &self,
280        config: &GuardFactoryConfig,
281    ) -> LicensedGuardResult<Box<dyn Guard>>;
282
283    /// Returns the guard type name.
284    fn guard_type(&self) -> &str;
285
286    /// Returns the required feature.
287    fn required_feature(&self) -> Option<&str>;
288
289    /// Returns the required tier.
290    fn required_tier(&self) -> &str;
291}
292
293#[async_trait::async_trait]
294impl<F: GuardFactory> DynGuardFactory for F {
295    async fn create_boxed(
296        &self,
297        config: &GuardFactoryConfig,
298    ) -> LicensedGuardResult<Box<dyn Guard>> {
299        let guard = self.create(config).await?;
300        Ok(Box::new(guard))
301    }
302
303    fn guard_type(&self) -> &str {
304        GuardFactory::guard_type(self)
305    }
306
307    fn required_feature(&self) -> Option<&str> {
308        GuardFactory::required_feature(self)
309    }
310
311    fn required_tier(&self) -> &str {
312        GuardFactory::required_tier(self)
313    }
314}
315
316impl GuardRegistry {
317    /// Creates a new empty registry.
318    pub fn new() -> Self {
319        Self {
320            factories: HashMap::new(),
321        }
322    }
323
324    /// Creates a registry with built-in community guards.
325    pub fn with_builtins() -> Self {
326        // Built-in guards are registered here
327        // (Commercial guards would be registered by external modules)
328        Self::new()
329    }
330
331    /// Registers a guard factory.
332    pub fn register<F: DynGuardFactory + 'static>(&mut self, name: &str, factory: F) {
333        self.factories.insert(name.to_string(), Arc::new(factory));
334    }
335
336    /// Returns true if a guard type is registered.
337    pub fn has(&self, name: &str) -> bool {
338        self.factories.contains_key(name)
339    }
340
341    /// Lists all registered guard types.
342    pub fn list(&self) -> Vec<&str> {
343        self.factories.keys().map(|s| s.as_str()).collect()
344    }
345
346    /// Gets information about a registered guard type.
347    pub fn info(&self, name: &str) -> Option<GuardInfo> {
348        self.factories.get(name).map(|f| GuardInfo {
349            name: f.guard_type().to_string(),
350            required_feature: f.required_feature().map(|s| s.to_string()),
351            required_tier: f.required_tier().to_string(),
352        })
353    }
354
355    /// Creates a guard instance.
356    pub async fn create(
357        &self,
358        name: &str,
359        config: &GuardFactoryConfig,
360    ) -> LicensedGuardResult<Box<dyn Guard>> {
361        let factory = self
362            .factories
363            .get(name)
364            .ok_or_else(|| LicensedGuardError::GuardNotFound(name.to_string()))?;
365
366        factory.create_boxed(config).await
367    }
368
369    /// Creates a guard instance with license validation.
370    pub async fn create_licensed(
371        &self,
372        name: &str,
373        config: &GuardFactoryConfig,
374        validator: &LicenseValidator,
375    ) -> LicensedGuardResult<Box<dyn Guard>> {
376        let factory = self
377            .factories
378            .get(name)
379            .ok_or_else(|| LicensedGuardError::GuardNotFound(name.to_string()))?;
380
381        // Validate license if required
382        if !config.skip_license_check {
383            if let Some(feature_name) = factory.required_feature() {
384                let feature = match feature_name {
385                    "semantic_guard" => Feature::SemanticGuard,
386                    "ml_classifier" => Feature::MlClassifier,
387                    "scanner" | "advanced_probes" => Feature::Scanner,
388                    "multi_layer_defense" => Feature::MultiLayerDefense,
389                    "telemetry" => Feature::Telemetry,
390                    "bundled_embeddings" => Feature::BundledEmbeddings,
391                    "compliance_reports" => Feature::ComplianceReports,
392                    "threat_intel" => Feature::ThreatIntel,
393                    "dashboard" => Feature::Dashboard,
394                    "api_access" => Feature::ApiAccess,
395                    "proxy_basic" | "proxy_gateway" => Feature::ProxyBasic,
396                    "embedding_privacy" => Feature::EmbeddingPrivacy,
397                    _ => {
398                        // Unknown feature - skip validation
399                        return factory.create_boxed(config).await;
400                    }
401                };
402                validator.require_feature(feature).await?;
403            }
404        }
405
406        factory.create_boxed(config).await
407    }
408}
409
410impl Default for GuardRegistry {
411    fn default() -> Self {
412        Self::new()
413    }
414}
415
/// Metadata describing a registered guard type.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare descriptors
/// directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GuardInfo {
    /// Guard type name, as reported by the factory.
    pub name: String,
    /// Required license feature (`None` for community guards).
    pub required_feature: Option<String>,
    /// Required license tier name.
    pub required_tier: String,
}
426
/// Guard types and the license each one needs, as
/// `(guard name, required feature, required tier)` triples.
///
/// Community entries use an empty feature string because they need no
/// feature. The table is used for documentation and license lookups.
pub const GUARD_LICENSE_REQUIREMENTS: &[(&str, &str, &str)] = &[
    // --- Community tier: no feature required ---
    ("PatternGuard", "", "Community"),
    ("LengthGuard", "", "Community"),
    ("EncodingGuard", "", "Community"),
    ("PerplexityGuard", "", "Community"),
    ("PIIGuard", "", "Community"),
    ("ToxicityGuard", "", "Community"),
    ("StructuredOutputGuard", "", "Community"),
    // --- Professional tier ---
    ("SemanticSimilarityGuard", "semantic_guard", "Professional"),
    ("MLClassifierGuard", "ml_classifier", "Professional"),
    ("AgenticGuard", "agentic_guard", "Professional"),
    ("SwarmGuard", "swarm_guard", "Professional"),
    ("ContainmentPolicy", "containment_policy", "Professional"),
    ("EmbeddingPIIFilter", "embedding_privacy", "Professional"),
    // --- Enterprise tier (planned) ---
    ("CustomGuard", "custom_guards", "Enterprise"),
];

/// Looks up `(required feature, required tier)` for `guard_name`.
///
/// Returns `None` when the guard is not listed in
/// [`GUARD_LICENSE_REQUIREMENTS`]; community guards report an empty feature.
/// Matching is exact and case-sensitive.
pub fn guard_license_requirement(guard_name: &str) -> Option<(&'static str, &'static str)> {
    for &(name, feature, tier) in GUARD_LICENSE_REQUIREMENTS {
        if name == guard_name {
            return Some((feature, tier));
        }
    }
    None
}

/// Returns `true` only for guards listed with the `"Community"` tier.
///
/// Unknown guards return `false`.
pub fn is_community_guard(guard_name: &str) -> bool {
    matches!(guard_license_requirement(guard_name), Some((_, "Community")))
}
465
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_guard_license_requirements() {
        // Community guards
        assert!(is_community_guard("PatternGuard"));
        assert!(is_community_guard("PIIGuard"));
        assert!(is_community_guard("ToxicityGuard"));
        assert!(is_community_guard("StructuredOutputGuard"));

        // Professional / Enterprise guards
        assert!(!is_community_guard("SemanticSimilarityGuard"));
        assert!(!is_community_guard("MLClassifierGuard"));
        assert!(!is_community_guard("CustomGuard"));

        // Unknown guards
        assert!(!is_community_guard("UnknownGuard"));
    }

    #[test]
    fn test_guard_license_requirement_lookup() {
        assert_eq!(
            guard_license_requirement("SemanticSimilarityGuard"),
            Some(("semantic_guard", "Professional"))
        );
        assert_eq!(
            guard_license_requirement("PatternGuard"),
            Some(("", "Community"))
        );
        assert_eq!(
            guard_license_requirement("CustomGuard"),
            Some(("custom_guards", "Enterprise"))
        );
        assert_eq!(guard_license_requirement("UnknownGuard"), None);
    }

    #[test]
    fn test_requirement_table_is_consistent() {
        // Community entries, and only those, have an empty feature name.
        for (name, feature, tier) in GUARD_LICENSE_REQUIREMENTS {
            assert_eq!(
                feature.is_empty(),
                *tier == "Community",
                "inconsistent license entry for {}",
                name
            );
        }

        // Guard names must be unique, otherwise lookups silently use the first.
        let mut names: Vec<&str> = GUARD_LICENSE_REQUIREMENTS
            .iter()
            .map(|(name, _, _)| *name)
            .collect();
        let total = names.len();
        names.sort_unstable();
        names.dedup();
        assert_eq!(names.len(), total, "duplicate guard names in table");
    }

    #[test]
    fn test_guard_factory_config() {
        let config = GuardFactoryConfig::new("test")
            .with_param("threshold", "0.8")
            .with_param("action", "block")
            .skip_license_check(true);

        assert_eq!(config.name, "test");
        assert_eq!(config.params.get("threshold"), Some(&"0.8".to_string()));
        assert_eq!(config.params.get("action"), Some(&"block".to_string()));
        assert!(config.skip_license_check);

        // License checking is on by default.
        let plain = GuardFactoryConfig::new("plain");
        assert!(plain.params.is_empty());
        assert!(!plain.skip_license_check);
    }

    #[test]
    fn test_guard_registry() {
        let registry = GuardRegistry::new();
        assert!(!registry.has("semantic"));
        assert!(registry.list().is_empty());
        assert!(registry.info("semantic").is_none());

        let defaulted = GuardRegistry::default();
        assert!(defaulted.list().is_empty());
    }

    #[test]
    fn test_licensed_guard_error_display() {
        let err = LicensedGuardError::FeatureNotLicensed {
            feature: "semantic_guard".to_string(),
            required_tier: "Professional".to_string(),
            current_tier: "Community".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Feature 'semantic_guard' requires Professional license (current: Community)"
        );

        assert_eq!(
            LicensedGuardError::GuardNotFound("semantic".to_string()).to_string(),
            "Guard 'semantic' not found in registry"
        );
        assert_eq!(
            LicensedGuardError::ValidationFailed("expired".to_string()).to_string(),
            "License validation failed: expired"
        );
    }
}