use crate::domain::Domain;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
use thiserror::Error;
/// Errors reported while validating or enforcing a license.
#[derive(Debug, Error, Clone)]
pub enum LicenseError {
    /// The requested domain is not covered by the license.
    #[error("Domain '{0}' is not licensed")]
    DomainNotLicensed(Domain),

    /// The named feature is not covered by the license.
    #[error("Feature '{0}' is not licensed")]
    FeatureNotLicensed(String),

    /// GPU-native kernels were requested without GPU-native entitlement.
    #[error("GPU-native kernels require Enterprise license")]
    GpuNativeNotLicensed,

    /// The license expired; the payload is the instant reported as the expiry.
    #[error("License expired at {0}")]
    Expired(DateTime<Utc>),

    /// The kernel limit carried in the payload was exceeded.
    #[error("Maximum kernel count ({0}) exceeded")]
    KernelLimitExceeded(usize),

    /// License validation failed; the payload is a free-form reason.
    #[error("License validation failed: {0}")]
    ValidationFailed(String),

    /// The supplied license key is invalid.
    #[error("Invalid license key")]
    InvalidKey,

    /// No valid license could be found.
    #[error("No valid license found")]
    NotFound,
}
/// Convenience alias for results whose error type is [`LicenseError`].
pub type LicenseResult<T> = std::result::Result<T, LicenseError>;
/// Contract implemented by every license backend.
///
/// Implementors must be `Send + Sync` and `Debug`, so a validator can be
/// shared across threads and embedded in debug-printable types.
pub trait LicenseValidator: Send + Sync + fmt::Debug {
    /// Checks whether `domain` may be used under this license.
    ///
    /// # Errors
    ///
    /// Returns a [`LicenseError`] when the implementor denies access.
    fn validate_domain(&self, domain: Domain) -> LicenseResult<()>;

    /// Checks whether the feature named `feature` may be used.
    ///
    /// # Errors
    ///
    /// Returns a [`LicenseError`] when the implementor denies access.
    fn validate_feature(&self, feature: &str) -> LicenseResult<()>;

    /// Reports whether GPU-native kernels are enabled.
    fn gpu_native_enabled(&self) -> bool;

    /// Returns the domains this validator reports as licensed.
    fn licensed_domains(&self) -> &[Domain];

    /// Returns the expiry instant, or `None` when there is no expiry.
    fn expires_at(&self) -> Option<DateTime<Utc>>;

    /// Reports whether the license is currently usable.
    ///
    /// The default implementation treats a license without an expiry as
    /// always valid, and otherwise requires the current UTC time to be
    /// strictly before the expiry.
    fn is_valid(&self) -> bool {
        match self.expires_at() {
            Some(expiry) => Utc::now() < expiry,
            None => true,
        }
    }

    /// Returns the tier of the license.
    fn tier(&self) -> LicenseTier;

    /// Returns the kernel limit, if any.
    ///
    /// NOTE(review): `None` presumably means "unlimited"; nothing in this
    /// file enforces the limit, so confirm against the caller.
    fn max_kernels(&self) -> Option<usize>;
}
/// Commercial tier of a license.
///
/// The tier determines GPU-native support (`supports_gpu_native`) and the
/// default kernel limit (`default_max_kernels`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LicenseTier {
    /// Development tier: GPU-native capable, no default kernel limit.
    Development,
    /// Community tier: no GPU-native support, default limit of 5 kernels.
    Community,
    /// Professional tier: no GPU-native support, default limit of 50 kernels.
    Professional,
    /// Enterprise tier: GPU-native capable, no default kernel limit.
    Enterprise,
}
impl LicenseTier {
#[must_use]
pub const fn supports_gpu_native(&self) -> bool {
matches!(self, LicenseTier::Development | LicenseTier::Enterprise)
}
#[must_use]
pub const fn default_max_kernels(&self) -> Option<usize> {
match self {
LicenseTier::Development => None, LicenseTier::Community => Some(5),
LicenseTier::Professional => Some(50),
LicenseTier::Enterprise => None, }
}
}
/// Newtype wrapper around the string that identifies a license.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LicenseId(pub String);
impl LicenseId {
    /// Builds a license identifier from anything convertible into a `String`.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        Self(raw)
    }
}
impl fmt::Display for LicenseId {
    /// Writes the raw identifier string.
    ///
    /// Width, fill, and alignment flags on the outer formatter are not
    /// applied; the identifier is emitted verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}
/// Serializable description of what a license holder is entitled to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct License {
    /// Identifier of this license.
    pub id: LicenseId,
    /// Commercial tier of the license.
    pub tier: LicenseTier,
    /// Domains explicitly granted by the license.
    pub domains: HashSet<Domain>,
    /// Feature names explicitly granted by the license.
    pub features: HashSet<String>,
    /// Whether GPU-native kernels were granted.
    ///
    /// `StandardLicenseValidator` additionally requires the tier to support
    /// GPU-native execution before reporting it as enabled.
    pub gpu_native: bool,
    /// Expiry instant; `None` means the license never expires.
    pub expires_at: Option<DateTime<Utc>>,
    /// Kernel limit, if any.
    ///
    /// NOTE(review): `None` presumably means "unlimited" — confirm with the
    /// code that enforces the limit.
    pub max_kernels: Option<usize>,
    /// Name of the license holder.
    pub holder: String,
}
impl License {
#[must_use]
pub fn development() -> Self {
Self {
id: LicenseId::new("dev-license"),
tier: LicenseTier::Development,
domains: Domain::ALL.iter().copied().collect(),
features: HashSet::new(), gpu_native: true,
expires_at: None,
max_kernels: None,
holder: "Development".to_string(),
}
}
#[must_use]
pub fn enterprise(holder: impl Into<String>, expires_at: Option<DateTime<Utc>>) -> Self {
Self {
id: LicenseId::new(format!("enterprise-{}", chrono::Utc::now().timestamp())),
tier: LicenseTier::Enterprise,
domains: Domain::ALL.iter().copied().collect(),
features: HashSet::new(), gpu_native: true,
expires_at,
max_kernels: None,
holder: holder.into(),
}
}
#[must_use]
pub fn professional(
holder: impl Into<String>,
domains: HashSet<Domain>,
expires_at: Option<DateTime<Utc>>,
) -> Self {
Self {
id: LicenseId::new(format!("professional-{}", chrono::Utc::now().timestamp())),
tier: LicenseTier::Professional,
domains,
features: HashSet::new(),
gpu_native: false, expires_at,
max_kernels: Some(50),
holder: holder.into(),
}
}
#[must_use]
pub fn community(holder: impl Into<String>) -> Self {
let mut domains = HashSet::new();
domains.insert(Domain::Core);
domains.insert(Domain::GraphAnalytics);
domains.insert(Domain::StatisticalML);
Self {
id: LicenseId::new(format!("community-{}", chrono::Utc::now().timestamp())),
tier: LicenseTier::Community,
domains,
features: HashSet::new(),
gpu_native: false,
expires_at: None, max_kernels: Some(5),
holder: holder.into(),
}
}
#[must_use]
pub fn with_domain(mut self, domain: Domain) -> Self {
self.domains.insert(domain);
self
}
#[must_use]
pub fn with_feature(mut self, feature: impl Into<String>) -> Self {
self.features.insert(feature.into());
self
}
}
/// [`LicenseValidator`] that enforces the terms of a single [`License`].
#[derive(Debug)]
pub struct StandardLicenseValidator {
    /// License whose terms are enforced.
    license: License,
    /// Domains reported by `licensed_domains`, resolved once in `new`.
    ///
    /// The trait method hands out a borrowed slice, which cannot be produced
    /// on demand from the license's `HashSet`, so the list is cached here.
    licensed: Vec<Domain>,
}

impl StandardLicenseValidator {
    /// Wraps `license` in a validator and precomputes its licensed domains.
    #[must_use]
    pub fn new(license: License) -> Self {
        // Walking `Domain::ALL` instead of the set keeps the reported order
        // stable between runs.
        // NOTE(review): assumes every `Domain` value appears in `Domain::ALL`;
        // a granted domain missing from that list would not be reported.
        let licensed: Vec<Domain> = Domain::ALL
            .iter()
            .copied()
            .filter(|domain| {
                Self::is_unrestricted(license.tier) || license.domains.contains(domain)
            })
            .collect();
        Self { license, licensed }
    }

    /// Returns the wrapped license.
    #[must_use]
    pub fn license(&self) -> &License {
        &self.license
    }

    /// Development and Enterprise tiers bypass per-domain and per-feature checks.
    fn is_unrestricted(tier: LicenseTier) -> bool {
        matches!(tier, LicenseTier::Development | LicenseTier::Enterprise)
    }

    /// Fails with [`LicenseError::Expired`] when the license is no longer valid.
    fn ensure_not_expired(&self) -> LicenseResult<()> {
        if self.is_valid() {
            Ok(())
        } else {
            // `is_valid` only fails when an expiry is set, so the fallback to
            // the current time is purely defensive.
            Err(LicenseError::Expired(
                self.license.expires_at.unwrap_or_else(Utc::now),
            ))
        }
    }
}

impl LicenseValidator for StandardLicenseValidator {
    /// Allows `domain` when the license is unexpired and either the tier is
    /// Development/Enterprise or the domain was explicitly granted.
    ///
    /// # Errors
    ///
    /// * [`LicenseError::Expired`] when the license has expired.
    /// * [`LicenseError::DomainNotLicensed`] when the domain is not granted.
    fn validate_domain(&self, domain: Domain) -> LicenseResult<()> {
        self.ensure_not_expired()?;
        if Self::is_unrestricted(self.license.tier) || self.license.domains.contains(&domain) {
            Ok(())
        } else {
            Err(LicenseError::DomainNotLicensed(domain))
        }
    }

    /// Allows `feature` when the license is unexpired and one of the
    /// following holds: the tier is Development/Enterprise, the feature was
    /// explicitly granted, or the text before the first `.` parses to a
    /// granted domain.
    ///
    /// # Errors
    ///
    /// * [`LicenseError::Expired`] when the license has expired.
    /// * [`LicenseError::FeatureNotLicensed`] when nothing grants the feature.
    fn validate_feature(&self, feature: &str) -> LicenseResult<()> {
        self.ensure_not_expired()?;
        if Self::is_unrestricted(self.license.tier) || self.license.features.contains(feature) {
            return Ok(());
        }
        // A feature named `<domain>.<rest>` is implicitly granted by its domain.
        let owning_domain = feature
            .split_once('.')
            .and_then(|(domain_str, _)| Domain::parse(domain_str));
        let granted_by_domain = match owning_domain {
            Some(domain) => self.license.domains.contains(&domain),
            None => false,
        };
        if granted_by_domain {
            Ok(())
        } else {
            Err(LicenseError::FeatureNotLicensed(feature.to_string()))
        }
    }

    /// GPU-native execution requires both the license flag and a tier that
    /// supports it.
    fn gpu_native_enabled(&self) -> bool {
        self.license.gpu_native && self.license.tier.supports_gpu_native()
    }

    /// Returns every domain for Development/Enterprise tiers, and otherwise
    /// only the domains granted by the license, in `Domain::ALL` order.
    fn licensed_domains(&self) -> &[Domain] {
        &self.licensed
    }

    /// Returns the license's expiry instant, if any.
    fn expires_at(&self) -> Option<DateTime<Utc>> {
        self.license.expires_at
    }

    /// Returns the license's tier.
    fn tier(&self) -> LicenseTier {
        self.license.tier
    }

    /// Returns the license's kernel limit, if any.
    fn max_kernels(&self) -> Option<usize> {
        self.license.max_kernels
    }
}
/// Stateless validator that accepts every domain and feature, enables
/// GPU-native kernels, never expires, and reports the `Development` tier.
#[derive(Debug, Default, Clone)]
pub struct DevelopmentLicense;
impl LicenseValidator for DevelopmentLicense {
    /// Always reports the `Development` tier.
    fn tier(&self) -> LicenseTier {
        LicenseTier::Development
    }

    /// Never expires.
    fn expires_at(&self) -> Option<DateTime<Utc>> {
        None
    }

    /// Has no kernel limit.
    fn max_kernels(&self) -> Option<usize> {
        None
    }

    /// GPU-native kernels are always enabled.
    fn gpu_native_enabled(&self) -> bool {
        true
    }

    /// Every domain in `Domain::ALL` is licensed.
    fn licensed_domains(&self) -> &[Domain] {
        Domain::ALL
    }

    /// Accepts every domain unconditionally.
    fn validate_domain(&self, _domain: Domain) -> LicenseResult<()> {
        Ok(())
    }

    /// Accepts every feature unconditionally.
    fn validate_feature(&self, _feature: &str) -> LicenseResult<()> {
        Ok(())
    }
}
/// Pairs a borrowed validator with the domain it should be checked against.
#[derive(Debug)]
pub struct LicenseGuard<'a> {
    /// Validator that performs the actual license checks.
    validator: &'a dyn LicenseValidator,
    /// Domain validated by `check`.
    domain: Domain,
}
impl<'a> LicenseGuard<'a> {
    /// Creates a guard that checks `domain` against `validator`.
    #[must_use]
    pub fn new(validator: &'a dyn LicenseValidator, domain: Domain) -> Self {
        Self { validator, domain }
    }

    /// Validates the guard's domain.
    ///
    /// # Errors
    ///
    /// Propagates whatever the validator's `validate_domain` returns.
    pub fn check(&self) -> LicenseResult<()> {
        let domain = self.domain;
        self.validator.validate_domain(domain)
    }

    /// Validates `feature` with the underlying validator.
    ///
    /// The guard's own domain is not consulted; the feature name is passed
    /// through unchanged.
    ///
    /// # Errors
    ///
    /// Propagates whatever the validator's `validate_feature` returns.
    pub fn check_feature(&self, feature: &str) -> LicenseResult<()> {
        self.validator.validate_feature(feature)
    }

    /// Verifies that GPU-native kernels are enabled.
    ///
    /// # Errors
    ///
    /// Returns [`LicenseError::GpuNativeNotLicensed`] when the validator
    /// reports GPU-native execution as disabled.
    pub fn check_gpu_native(&self) -> LicenseResult<()> {
        if !self.validator.gpu_native_enabled() {
            return Err(LicenseError::GpuNativeNotLicensed);
        }
        Ok(())
    }
}
/// Reference-counted handle to any [`LicenseValidator`] implementation.
pub type SharedLicenseValidator = Arc<dyn LicenseValidator>;
/// Returns a shared [`DevelopmentLicense`] validator.
#[must_use]
pub fn dev_license() -> SharedLicenseValidator {
    let validator: SharedLicenseValidator = Arc::new(DevelopmentLicense);
    validator
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The unit development validator allows everything and never expires.
    #[test]
    fn test_development_license() {
        let dev = DevelopmentLicense;
        for domain in [Domain::GraphAnalytics, Domain::RiskAnalytics] {
            assert!(dev.validate_domain(domain).is_ok());
        }
        assert!(dev.validate_feature("GraphAnalytics.PageRank").is_ok());
        assert!(dev.gpu_native_enabled());
        assert!(dev.is_valid());
        assert_eq!(dev.tier(), LicenseTier::Development);
    }

    /// Community licenses cover three domains, no GPU-native, five kernels.
    #[test]
    fn test_community_license() {
        let validator = StandardLicenseValidator::new(License::community("Test User"));
        for granted in [Domain::Core, Domain::GraphAnalytics, Domain::StatisticalML] {
            assert!(validator.validate_domain(granted).is_ok());
        }
        assert!(validator.validate_domain(Domain::RiskAnalytics).is_err());
        assert!(!validator.gpu_native_enabled());
        assert_eq!(validator.max_kernels(), Some(5));
    }

    /// Enterprise licenses cover every domain with GPU-native and no limit.
    #[test]
    fn test_enterprise_license() {
        let validator =
            StandardLicenseValidator::new(License::enterprise("Enterprise User", None));
        for domain in [Domain::GraphAnalytics, Domain::RiskAnalytics, Domain::Banking] {
            assert!(validator.validate_domain(domain).is_ok());
        }
        assert!(validator.gpu_native_enabled());
        assert_eq!(validator.max_kernels(), None);
    }

    /// An expiry in the past invalidates even an Enterprise license.
    #[test]
    fn test_expired_license() {
        let yesterday = Utc::now() - chrono::Duration::days(1);
        let validator =
            StandardLicenseValidator::new(License::enterprise("Expired User", Some(yesterday)));
        assert!(!validator.is_valid());
        assert!(validator.validate_domain(Domain::Core).is_err());
    }

    /// A guard forwards all three kinds of checks to its validator.
    #[test]
    fn test_license_guard() {
        let dev = DevelopmentLicense;
        let guard = LicenseGuard::new(&dev, Domain::GraphAnalytics);
        assert!(guard.check().is_ok());
        assert!(guard.check_feature("GraphAnalytics.PageRank").is_ok());
        assert!(guard.check_gpu_native().is_ok());
    }

    /// Only Development and Enterprise tiers support GPU-native kernels.
    #[test]
    fn test_license_tier_properties() {
        for capable in [LicenseTier::Development, LicenseTier::Enterprise] {
            assert!(capable.supports_gpu_native());
        }
        for incapable in [LicenseTier::Professional, LicenseTier::Community] {
            assert!(!incapable.supports_gpu_native());
        }
    }
}