1use crate::guard::{Guard, GuardCheckResult};
36use std::collections::HashMap;
37use std::sync::Arc;
38
39use oxide_license::{Feature, LicenseError, LicenseValidator};
40
/// Errors produced while creating guards or checking their license requirements.
#[derive(Debug, thiserror::Error)]
pub enum LicensedGuardError {
    /// The active license does not include the requested feature.
    ///
    /// Built from `LicenseError::FeatureNotLicensed` by the `From` impl below,
    /// keeping all three fields intact.
    #[error("Feature '{feature}' requires {required_tier} license (current: {current_tier})")]
    FeatureNotLicensed {
        /// Name of the feature that was requested.
        feature: String,
        /// Tier the feature requires.
        required_tier: String,
        /// Tier of the currently active license.
        current_tier: String,
    },

    /// License validation failed for any reason other than a missing feature.
    /// Carries the `Display` text of the underlying `LicenseError`.
    #[error("License validation failed: {0}")]
    ValidationFailed(String),

    /// No factory is registered in the `GuardRegistry` under the given name.
    #[error("Guard '{0}' not found in registry")]
    GuardNotFound(String),

    /// Building a guard failed.
    // NOTE(review): not constructed in this module; presumably returned by
    // factory implementations — confirm.
    #[error("Guard creation failed: {0}")]
    CreationFailed(String),

    /// A guard or factory configuration was invalid.
    // NOTE(review): not constructed in this module; presumably used by
    // factories when interpreting `GuardFactoryConfig::params` — confirm.
    #[error("Configuration error: {0}")]
    ConfigError(String),
}
68
69impl From<LicenseError> for LicensedGuardError {
70 fn from(err: LicenseError) -> Self {
71 match err {
72 LicenseError::FeatureNotLicensed {
73 feature,
74 required_tier,
75 current_tier,
76 } => LicensedGuardError::FeatureNotLicensed {
77 feature,
78 required_tier,
79 current_tier,
80 },
81 _ => LicensedGuardError::ValidationFailed(err.to_string()),
82 }
83 }
84}
85
/// Result alias used throughout the licensed-guard API, with [`LicensedGuardError`] as the error.
pub type LicensedGuardResult<T> = std::result::Result<T, LicensedGuardError>;
88
/// A [`Guard`] that declares which license feature and tier it requires.
pub trait LicensedGuard: Guard {
    /// Name of the license feature required to use this guard.
    fn required_feature(&self) -> &str;

    /// Name of the license tier required to use this guard.
    fn required_tier(&self) -> &str;

    /// Whether this guard belongs to the Community tier.
    ///
    /// The default implementation returns `false`; implementors that are
    /// Community-tier need to override it.
    fn is_community(&self) -> bool {
        false
    }
}
105
/// Wraps a guard together with the license feature and tier it needs.
///
/// The wrapper delegates all `Guard` behavior to `inner`; the license itself
/// is only checked when `validate_license` is called explicitly.
pub struct LicenseCheckedGuard<G: Guard> {
    /// The wrapped guard that performs the actual checks.
    inner: G,
    /// License feature name (e.g. `"semantic_guard"`) mapped to a `Feature` on validation.
    feature_name: String,
    /// License tier name reported via `LicensedGuard::required_tier`.
    tier_name: String,
    /// Shared validator consulted by `validate_license`.
    validator: Arc<LicenseValidator>,
}
116
117impl<G: Guard> LicenseCheckedGuard<G> {
118 pub fn new(
120 inner: G,
121 feature_name: impl Into<String>,
122 tier_name: impl Into<String>,
123 validator: Arc<LicenseValidator>,
124 ) -> Self {
125 Self {
126 inner,
127 feature_name: feature_name.into(),
128 tier_name: tier_name.into(),
129 validator,
130 }
131 }
132
133 pub fn inner(&self) -> &G {
135 &self.inner
136 }
137
138 pub fn feature_name(&self) -> &str {
140 &self.feature_name
141 }
142
143 pub fn tier_name(&self) -> &str {
145 &self.tier_name
146 }
147
148 pub async fn validate_license(&self) -> LicensedGuardResult<()> {
150 let feature = match self.feature_name.as_str() {
151 "semantic_guard" => Feature::SemanticGuard,
152 "ml_classifier" => Feature::MlClassifier,
153 "scanner" | "advanced_probes" => Feature::Scanner,
154 "multi_layer_defense" => Feature::MultiLayerDefense,
155 "telemetry" => Feature::Telemetry,
156 "bundled_embeddings" => Feature::BundledEmbeddings,
157 "compliance_reports" => Feature::ComplianceReports,
158 "threat_intel" => Feature::ThreatIntel,
159 "dashboard" => Feature::Dashboard,
160 "api_access" => Feature::ApiAccess,
161 "proxy_basic" | "proxy_gateway" => Feature::ProxyBasic,
162 "webhook_alerts" => Feature::WebhookAlerts,
163 "rate_limiting" => Feature::RateLimiting,
164 "streaming_guards" => Feature::StreamingGuards,
165 "custom_guards" => Feature::CustomGuards,
166 "private_models" => Feature::PrivateModels,
167 "sso_saml" => Feature::SsoSaml,
168 "embedding_privacy" => Feature::EmbeddingPrivacy,
169 _ => {
170 return Ok(());
172 }
173 };
174
175 self.validator.require_feature(feature).await?;
176 Ok(())
177 }
178}
179
180impl<G: Guard> Guard for LicenseCheckedGuard<G> {
181 fn name(&self) -> &str {
182 self.inner.name()
183 }
184
185 fn check(&self, content: &str) -> GuardCheckResult {
186 self.inner.check(content)
189 }
190
191 fn action(&self) -> crate::guard::GuardAction {
192 self.inner.action()
193 }
194
195 fn severity_threshold(&self) -> oxideshield_core::Severity {
196 self.inner.severity_threshold()
197 }
198}
199
200impl<G: Guard> LicensedGuard for LicenseCheckedGuard<G> {
201 fn required_feature(&self) -> &str {
202 &self.feature_name
203 }
204
205 fn required_tier(&self) -> &str {
206 &self.tier_name
207 }
208}
209
/// Builds a concrete guard type from a [`GuardFactoryConfig`].
///
/// Every implementor also gets [`DynGuardFactory`] through the blanket impl in
/// this file, which is what lets it be stored in a [`GuardRegistry`].
#[async_trait::async_trait]
pub trait GuardFactory: Send + Sync {
    /// Concrete guard type produced by this factory.
    type Guard: Guard + 'static;

    /// Builds a new guard from `config`.
    ///
    /// This method performs no license check itself; `GuardRegistry::create_licensed`
    /// gates on `required_feature` before calling it.
    async fn create(&self, config: &GuardFactoryConfig) -> LicensedGuardResult<Self::Guard>;

    /// Identifier for the kind of guard this factory builds.
    fn guard_type(&self) -> &str;

    /// License feature name required to build this guard, or `None` when no
    /// feature gate applies (the registry then skips validation).
    fn required_feature(&self) -> Option<&str>;

    /// License tier name required to build this guard.
    fn required_tier(&self) -> &str;
}
231
/// Configuration handed to a guard factory when building a guard.
#[derive(Debug, Clone, Default)]
pub struct GuardFactoryConfig {
    /// Caller-supplied name; not interpreted by this module.
    pub name: String,
    /// Free-form string parameters; not interpreted by this module
    /// (presumably read by individual factories — confirm per factory).
    pub params: HashMap<String, String>,
    /// When `true`, `GuardRegistry::create_licensed` skips license validation entirely.
    pub skip_license_check: bool,
}
242
243impl GuardFactoryConfig {
244 pub fn new(name: impl Into<String>) -> Self {
246 Self {
247 name: name.into(),
248 params: HashMap::new(),
249 skip_license_check: false,
250 }
251 }
252
253 pub fn with_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
255 self.params.insert(key.into(), value.into());
256 self
257 }
258
259 pub fn skip_license_check(mut self, skip: bool) -> Self {
261 self.skip_license_check = skip;
262 self
263 }
264}
265
/// Registry of type-erased guard factories, keyed by registration name.
pub struct GuardRegistry {
    /// Factories keyed by the `name` passed to `GuardRegistry::register`.
    factories: HashMap<String, Arc<dyn DynGuardFactory>>,
}
273
/// Type-erased counterpart of [`GuardFactory`] that returns boxed guards.
///
/// `GuardRegistry` stores factories as `Arc<dyn DynGuardFactory>`. Every
/// `GuardFactory` implements this trait through the blanket impl below, so
/// factories producing different guard types can share one registry.
#[async_trait::async_trait]
pub trait DynGuardFactory: Send + Sync {
    /// Builds a guard from `config` and returns it as a trait object.
    async fn create_boxed(
        &self,
        config: &GuardFactoryConfig,
    ) -> LicensedGuardResult<Box<dyn Guard>>;

    /// Identifier for the kind of guard this factory builds.
    fn guard_type(&self) -> &str;

    /// License feature name required to build this guard, if any.
    fn required_feature(&self) -> Option<&str>;

    /// License tier name required to build this guard.
    fn required_tier(&self) -> &str;
}
292
293#[async_trait::async_trait]
294impl<F: GuardFactory> DynGuardFactory for F {
295 async fn create_boxed(
296 &self,
297 config: &GuardFactoryConfig,
298 ) -> LicensedGuardResult<Box<dyn Guard>> {
299 let guard = self.create(config).await?;
300 Ok(Box::new(guard))
301 }
302
303 fn guard_type(&self) -> &str {
304 GuardFactory::guard_type(self)
305 }
306
307 fn required_feature(&self) -> Option<&str> {
308 GuardFactory::required_feature(self)
309 }
310
311 fn required_tier(&self) -> &str {
312 GuardFactory::required_tier(self)
313 }
314}
315
316impl GuardRegistry {
317 pub fn new() -> Self {
319 Self {
320 factories: HashMap::new(),
321 }
322 }
323
324 pub fn with_builtins() -> Self {
326 Self::new()
329 }
330
331 pub fn register<F: DynGuardFactory + 'static>(&mut self, name: &str, factory: F) {
333 self.factories.insert(name.to_string(), Arc::new(factory));
334 }
335
336 pub fn has(&self, name: &str) -> bool {
338 self.factories.contains_key(name)
339 }
340
341 pub fn list(&self) -> Vec<&str> {
343 self.factories.keys().map(|s| s.as_str()).collect()
344 }
345
346 pub fn info(&self, name: &str) -> Option<GuardInfo> {
348 self.factories.get(name).map(|f| GuardInfo {
349 name: f.guard_type().to_string(),
350 required_feature: f.required_feature().map(|s| s.to_string()),
351 required_tier: f.required_tier().to_string(),
352 })
353 }
354
355 pub async fn create(
357 &self,
358 name: &str,
359 config: &GuardFactoryConfig,
360 ) -> LicensedGuardResult<Box<dyn Guard>> {
361 let factory = self
362 .factories
363 .get(name)
364 .ok_or_else(|| LicensedGuardError::GuardNotFound(name.to_string()))?;
365
366 factory.create_boxed(config).await
367 }
368
369 pub async fn create_licensed(
371 &self,
372 name: &str,
373 config: &GuardFactoryConfig,
374 validator: &LicenseValidator,
375 ) -> LicensedGuardResult<Box<dyn Guard>> {
376 let factory = self
377 .factories
378 .get(name)
379 .ok_or_else(|| LicensedGuardError::GuardNotFound(name.to_string()))?;
380
381 if !config.skip_license_check {
383 if let Some(feature_name) = factory.required_feature() {
384 let feature = match feature_name {
385 "semantic_guard" => Feature::SemanticGuard,
386 "ml_classifier" => Feature::MlClassifier,
387 "scanner" | "advanced_probes" => Feature::Scanner,
388 "multi_layer_defense" => Feature::MultiLayerDefense,
389 "telemetry" => Feature::Telemetry,
390 "bundled_embeddings" => Feature::BundledEmbeddings,
391 "compliance_reports" => Feature::ComplianceReports,
392 "threat_intel" => Feature::ThreatIntel,
393 "dashboard" => Feature::Dashboard,
394 "api_access" => Feature::ApiAccess,
395 "proxy_basic" | "proxy_gateway" => Feature::ProxyBasic,
396 "embedding_privacy" => Feature::EmbeddingPrivacy,
397 _ => {
398 return factory.create_boxed(config).await;
400 }
401 };
402 validator.require_feature(feature).await?;
403 }
404 }
405
406 factory.create_boxed(config).await
407 }
408}
409
410impl Default for GuardRegistry {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
/// Summary of a registered factory, as returned by `GuardRegistry::info`.
#[derive(Debug, Clone)]
pub struct GuardInfo {
    /// The factory's `guard_type()` (not necessarily the key it was registered under).
    pub name: String,
    /// License feature name the factory requires, if any.
    pub required_feature: Option<String>,
    /// License tier name the factory requires.
    pub required_tier: String,
}
426
/// Static license requirements for known guards, as
/// `(guard name, feature name, tier name)` tuples.
///
/// Community-tier guards use an empty feature name. `guard_license_requirement`
/// returns the first entry whose guard name matches, so guard names should be unique.
///
/// NOTE(review): the feature names `agentic_guard`, `swarm_guard` and
/// `containment_policy` have no `Feature` mapping in
/// `LicenseCheckedGuard::validate_license` or `GuardRegistry::create_licensed`,
/// so license validation is skipped for them — confirm this is intended.
pub const GUARD_LICENSE_REQUIREMENTS: &[(&str, &str, &str)] = &[
    // Community tier: no feature gate.
    ("PatternGuard", "", "Community"),
    ("LengthGuard", "", "Community"),
    ("EncodingGuard", "", "Community"),
    ("PerplexityGuard", "", "Community"),
    ("PIIGuard", "", "Community"),
    ("ToxicityGuard", "", "Community"),
    ("StructuredOutputGuard", "", "Community"),
    // Professional tier.
    ("SemanticSimilarityGuard", "semantic_guard", "Professional"),
    ("MLClassifierGuard", "ml_classifier", "Professional"),
    ("AgenticGuard", "agentic_guard", "Professional"),
    ("SwarmGuard", "swarm_guard", "Professional"),
    ("ContainmentPolicy", "containment_policy", "Professional"),
    ("EmbeddingPIIFilter", "embedding_privacy", "Professional"),
    // Enterprise tier.
    ("CustomGuard", "custom_guards", "Enterprise"),
];
450
451pub fn guard_license_requirement(guard_name: &str) -> Option<(&'static str, &'static str)> {
453 GUARD_LICENSE_REQUIREMENTS
454 .iter()
455 .find(|(name, _, _)| *name == guard_name)
456 .map(|(_, feature, tier)| (*feature, *tier))
457}
458
459pub fn is_community_guard(guard_name: &str) -> bool {
461 guard_license_requirement(guard_name)
462 .map(|(_, tier)| tier == "Community")
463 .unwrap_or(false)
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_guard_license_requirements() {
        // Community-tier guards.
        assert!(is_community_guard("PatternGuard"));
        assert!(is_community_guard("PIIGuard"));
        assert!(is_community_guard("ToxicityGuard"));

        // Paid-tier guards.
        assert!(!is_community_guard("SemanticSimilarityGuard"));
        assert!(!is_community_guard("MLClassifierGuard"));
        assert!(!is_community_guard("CustomGuard"));

        // Guards missing from the table are never Community-tier.
        assert!(!is_community_guard("UnknownGuard"));
    }

    #[test]
    fn test_guard_license_requirement_lookup() {
        assert_eq!(
            guard_license_requirement("PatternGuard"),
            Some(("", "Community"))
        );
        assert_eq!(
            guard_license_requirement("SemanticSimilarityGuard"),
            Some(("semantic_guard", "Professional"))
        );
        assert_eq!(
            guard_license_requirement("CustomGuard"),
            Some(("custom_guards", "Enterprise"))
        );
        assert_eq!(guard_license_requirement("UnknownGuard"), None);
        // Lookup is exact and case-sensitive.
        assert_eq!(guard_license_requirement("patternguard"), None);
    }

    #[test]
    fn test_guard_license_requirement_names_are_unique() {
        // Lookups return the first match, so duplicate names would shadow entries.
        let mut names: Vec<&str> = GUARD_LICENSE_REQUIREMENTS
            .iter()
            .map(|(name, _, _)| *name)
            .collect();
        let total = names.len();
        names.sort_unstable();
        names.dedup();
        assert_eq!(names.len(), total);
    }

    #[test]
    fn test_guard_factory_config() {
        let config = GuardFactoryConfig::new("test")
            .with_param("threshold", "0.8")
            .with_param("action", "block")
            .skip_license_check(true);

        assert_eq!(config.name, "test");
        assert_eq!(config.params.get("threshold"), Some(&"0.8".to_string()));
        assert_eq!(config.params.get("action"), Some(&"block".to_string()));
        assert!(config.skip_license_check);
    }

    #[test]
    fn test_guard_factory_config_defaults() {
        let config = GuardFactoryConfig::new("defaults");
        assert_eq!(config.name, "defaults");
        assert!(config.params.is_empty());
        assert!(!config.skip_license_check);
    }

    #[test]
    fn test_guard_factory_config_param_overwrite() {
        let config = GuardFactoryConfig::new("t")
            .with_param("k", "first")
            .with_param("k", "second");
        assert_eq!(config.params.len(), 1);
        assert_eq!(config.params.get("k").map(String::as_str), Some("second"));
    }

    #[test]
    fn test_guard_registry() {
        let registry = GuardRegistry::new();
        assert!(!registry.has("semantic"));
        assert!(registry.list().is_empty());
        assert!(registry.info("semantic").is_none());
    }

    #[test]
    fn test_guard_registry_default_is_empty() {
        let registry = GuardRegistry::default();
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_licensed_guard_error_display() {
        let err = LicensedGuardError::FeatureNotLicensed {
            feature: "semantic_guard".to_string(),
            required_tier: "Professional".to_string(),
            current_tier: "Community".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Feature 'semantic_guard' requires Professional license (current: Community)"
        );

        assert_eq!(
            LicensedGuardError::ValidationFailed("expired".to_string()).to_string(),
            "License validation failed: expired"
        );
        assert_eq!(
            LicensedGuardError::GuardNotFound("semantic".to_string()).to_string(),
            "Guard 'semantic' not found in registry"
        );
    }

    #[test]
    fn test_license_error_conversion_preserves_fields() {
        let err = LicensedGuardError::from(LicenseError::FeatureNotLicensed {
            feature: "custom_guards".to_string(),
            required_tier: "Enterprise".to_string(),
            current_tier: "Professional".to_string(),
        });
        match err {
            LicensedGuardError::FeatureNotLicensed {
                feature,
                required_tier,
                current_tier,
            } => {
                assert_eq!(feature, "custom_guards");
                assert_eq!(required_tier, "Enterprise");
                assert_eq!(current_tier, "Professional");
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }
}