use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine;
use chrono::{DateTime, Duration, Utc};
use semver::Version;
use std::collections::{HashMap, HashSet};
use crate::crypto::KeyPair;
use crate::error::{LicenseError, Result};
use crate::models::{LicenseConstraints, LicensePayload, SignedLicense, LICENSE_FORMAT_VERSION};
/// Fluent builder for [`LicensePayload`] values and their signed form, [`SignedLicense`].
///
/// `license_id` and `customer_id` are required; every other field is optional.
/// Setter methods consume the builder and return it, so calls are meant to be chained and
/// finished with a `build_*` method.
#[must_use = "a LicenseBuilder does nothing until one of its `build_*` methods is called"]
#[derive(Debug, Clone, Default)]
pub struct LicenseBuilder {
    /// License identifier (required to build).
    license_id: Option<String>,
    /// Customer identifier (required to build).
    customer_id: Option<String>,
    /// Optional human-readable customer name, copied into the payload as-is.
    customer_name: Option<String>,
    /// Explicit issue time; when `None`, the time of building is used.
    issued_at: Option<DateTime<Utc>>,
    /// Constraints accumulated by the setter methods.
    constraints: LicenseConstraints,
    /// Free-form metadata; stays `None` until the first metadata entry is added.
    metadata: Option<HashMap<String, serde_json::Value>>,
}
impl LicenseBuilder {
    /// Creates an empty builder: no identifiers set and `LicenseConstraints::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the license identifier (required; must not be empty or whitespace-only).
    pub fn license_id(mut self, id: impl Into<String>) -> Self {
        self.license_id = Some(id.into());
        self
    }

    /// Sets the customer identifier (required; must not be empty or whitespace-only).
    pub fn customer_id(mut self, id: impl Into<String>) -> Self {
        self.customer_id = Some(id.into());
        self
    }

    /// Sets the optional human-readable customer name.
    pub fn customer_name(mut self, name: impl Into<String>) -> Self {
        self.customer_name = Some(name.into());
        self
    }

    /// Sets the issue timestamp. When never called, the payload uses the time it is built.
    pub fn issued_at(mut self, time: DateTime<Utc>) -> Self {
        self.issued_at = Some(time);
        self
    }

    /// Sets an absolute expiration timestamp, replacing any earlier expiration.
    pub fn expires_at(mut self, expiration: DateTime<Utc>) -> Self {
        self.constraints.expiration_date = Some(expiration);
        self
    }

    /// Sets the expiration to `Utc::now() + duration`, replacing any earlier expiration.
    ///
    /// `now` is sampled when this method is called, not when the license is built, and it is
    /// independent of any value given to [`issued_at`](Self::issued_at).
    pub fn expires_in(mut self, duration: Duration) -> Self {
        self.constraints.expiration_date = Some(Utc::now() + duration);
        self
    }

    /// Sets the earliest instant at which the license becomes valid.
    pub fn valid_from(mut self, valid_from: DateTime<Utc>) -> Self {
        self.constraints.valid_from = Some(valid_from);
        self
    }

    /// Sets `valid_from` to `Utc::now() + duration`, with `now` sampled when this method is called.
    pub fn valid_after(mut self, duration: Duration) -> Self {
        self.constraints.valid_from = Some(Utc::now() + duration);
        self
    }

    /// Adds one feature to the allowed-features set.
    pub fn allowed_feature(self, feature: impl Into<String>) -> Self {
        self.allowed_features(std::iter::once(feature))
    }

    /// Adds every feature in `features` to the allowed-features set.
    ///
    /// The set is created on first use, so even an empty iterator turns the constraint into
    /// `Some(empty set)` rather than leaving it `None`.
    pub fn allowed_features(
        mut self,
        features: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        Self::extend_set(&mut self.constraints.allowed_features, features);
        self
    }

    /// Adds one feature to the denied-features set.
    pub fn denied_feature(self, feature: impl Into<String>) -> Self {
        self.denied_features(std::iter::once(feature))
    }

    /// Adds every feature in `features` to the denied-features set (created on first use,
    /// even for an empty iterator).
    pub fn denied_features(
        mut self,
        features: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        Self::extend_set(&mut self.constraints.denied_features, features);
        self
    }

    /// Sets the `max_connections` constraint, replacing any earlier value.
    pub fn max_connections(mut self, max: u32) -> Self {
        self.constraints.max_connections = Some(max);
        self
    }

    /// Adds one hostname to the allowed-hostnames set.
    pub fn allowed_hostname(self, hostname: impl Into<String>) -> Self {
        self.allowed_hostnames(std::iter::once(hostname))
    }

    /// Adds every hostname in `hostnames` to the allowed-hostnames set (created on first use,
    /// even for an empty iterator).
    pub fn allowed_hostnames(
        mut self,
        hostnames: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        Self::extend_set(&mut self.constraints.allowed_hostnames, hostnames);
        self
    }

    /// Adds one machine ID to the allowed-machine-IDs set.
    pub fn allowed_machine_id(self, machine_id: impl Into<String>) -> Self {
        self.allowed_machine_ids(std::iter::once(machine_id))
    }

    /// Adds every ID in `machine_ids` to the allowed-machine-IDs set (created on first use,
    /// even for an empty iterator).
    pub fn allowed_machine_ids(
        mut self,
        machine_ids: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        Self::extend_set(&mut self.constraints.allowed_machine_ids, machine_ids);
        self
    }

    /// Sets the minimum software version constraint.
    pub fn minimum_version(mut self, version: Version) -> Self {
        self.constraints.minimum_software_version = Some(version);
        self
    }

    /// Parses `version` as semver and sets it as the minimum software version.
    ///
    /// # Errors
    ///
    /// Returns [`LicenseError::InvalidBuilderValue`] (field `"minimum_version"`) when `version`
    /// is not valid semver.
    pub fn minimum_version_str(self, version: &str) -> Result<Self> {
        let parsed = Self::parse_version("minimum_version", version)?;
        Ok(self.minimum_version(parsed))
    }

    /// Sets the maximum software version constraint.
    pub fn maximum_version(mut self, version: Version) -> Self {
        self.constraints.maximum_software_version = Some(version);
        self
    }

    /// Parses `version` as semver and sets it as the maximum software version.
    ///
    /// # Errors
    ///
    /// Returns [`LicenseError::InvalidBuilderValue`] (field `"maximum_version"`) when `version`
    /// is not valid semver.
    pub fn maximum_version_str(self, version: &str) -> Result<Self> {
        let parsed = Self::parse_version("maximum_version", version)?;
        Ok(self.maximum_version(parsed))
    }

    /// Inserts (or overwrites) a custom constraint under `key`.
    pub fn custom_constraint(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.constraints
            .custom_constraints
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value);
        self
    }

    /// Inserts (or overwrites) a metadata entry under `key`.
    pub fn metadata(self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.add_key_value(key, value)
    }

    /// Inserts (or overwrites) a metadata entry, converting `value` into a JSON value.
    pub fn add_key_value<V>(mut self, key: impl Into<String>, value: V) -> Self
    where
        V: Into<serde_json::Value>,
    {
        self.metadata
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value.into());
        self
    }

    /// Inserts a string metadata entry.
    pub fn add_string(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.add_key_value(key, serde_json::Value::String(value.into()))
    }

    /// Inserts an integer metadata entry.
    pub fn add_i64(self, key: impl Into<String>, value: i64) -> Self {
        self.add_key_value(key, value)
    }

    /// Inserts a boolean metadata entry.
    pub fn add_bool(self, key: impl Into<String>, value: bool) -> Self {
        self.add_key_value(key, value)
    }

    /// Inserts a metadata entry holding a JSON array of strings, preserving iteration order.
    pub fn add_string_array(
        self,
        key: impl Into<String>,
        values: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let array: Vec<serde_json::Value> = values
            .into_iter()
            .map(|s| serde_json::Value::String(s.into()))
            .collect();
        self.add_key_value(key, array)
    }

    /// Replaces the whole constraint set.
    ///
    /// Any constraints configured by earlier calls are discarded; constraint setters called
    /// afterwards modify the replacement.
    pub fn with_constraints(mut self, constraints: LicenseConstraints) -> Self {
        self.constraints = constraints;
        self
    }

    /// Adds `items` to the optional set in `target`, creating the set if it is `None`.
    fn extend_set(
        target: &mut Option<HashSet<String>>,
        items: impl IntoIterator<Item = impl Into<String>>,
    ) {
        target
            .get_or_insert_with(HashSet::new)
            .extend(items.into_iter().map(Into::<String>::into));
    }

    /// Parses a semver string, reporting failures as an `InvalidBuilderValue` for `field`.
    fn parse_version(field: &str, version: &str) -> Result<Version> {
        Version::parse(version).map_err(|e| LicenseError::InvalidBuilderValue {
            field: field.to_string(),
            reason: format!("invalid semver: {}", e),
        })
    }

    /// Returns a required identifier's value, treating unset and blank
    /// (empty or whitespace-only) values alike as missing.
    fn required(field: &Option<String>) -> Option<&str> {
        field.as_deref().filter(|value| !value.trim().is_empty())
    }

    /// Lists the names of required fields that are missing or blank, in declaration order.
    fn validate(&self) -> Vec<String> {
        let mut missing = Vec::new();
        if Self::required(&self.license_id).is_none() {
            missing.push("license_id".to_string());
        }
        if Self::required(&self.customer_id).is_none() {
            missing.push("customer_id".to_string());
        }
        missing
    }

    /// Builds an unsigned payload from the current builder state.
    ///
    /// The builder is only borrowed, so it can be reused afterwards. When no issue time was
    /// set, `Utc::now()` at this call is used.
    ///
    /// # Errors
    ///
    /// Returns [`LicenseError::BuilderIncomplete`] listing every required field that is unset
    /// or blank.
    pub fn build_payload(&self) -> Result<LicensePayload> {
        match (
            Self::required(&self.license_id),
            Self::required(&self.customer_id),
        ) {
            (Some(license_id), Some(customer_id)) => Ok(LicensePayload {
                format_version: LICENSE_FORMAT_VERSION,
                license_id: license_id.to_owned(),
                customer_id: customer_id.to_owned(),
                customer_name: self.customer_name.clone(),
                issued_at: self.issued_at.unwrap_or_else(Utc::now),
                constraints: self.constraints.clone(),
                metadata: self.metadata.clone(),
            }),
            _ => Err(LicenseError::BuilderIncomplete {
                missing_fields: self.validate().join(", "),
            }),
        }
    }

    /// Builds the payload, serializes it to JSON, base64-encodes it with the standard
    /// alphabet, and signs the bytes of that *encoded* string with `key_pair`.
    ///
    /// # Errors
    ///
    /// Propagates [`build_payload`](Self::build_payload) errors, and returns
    /// [`LicenseError::JsonSerializationFailed`] if the payload cannot be serialized.
    pub fn build_and_sign(&self, key_pair: &KeyPair) -> Result<SignedLicense> {
        let payload = self.build_payload()?;
        let payload_json =
            serde_json::to_string(&payload).map_err(|e| LicenseError::JsonSerializationFailed {
                reason: e.to_string(),
            })?;
        let encoded_payload = BASE64_STANDARD.encode(payload_json.as_bytes());
        let signature_base64 = key_pair.sign_base64(encoded_payload.as_bytes());
        Ok(SignedLicense::new(encoded_payload, signature_base64))
    }

    /// Like [`build_and_sign`](Self::build_and_sign), then serializes the signed license
    /// to a JSON string.
    ///
    /// # Errors
    ///
    /// Propagates `build_and_sign` errors; any failure of `SignedLicense::to_json` is reported
    /// as [`LicenseError::JsonSerializationFailed`].
    pub fn build_and_sign_to_json(&self, key_pair: &KeyPair) -> Result<String> {
        let signed_license = self.build_and_sign(key_pair)?;
        signed_license
            .to_json()
            .map_err(|e| LicenseError::JsonSerializationFailed {
                reason: e.to_string(),
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a fresh key pair for tests that need to sign.
    fn create_test_key_pair() -> KeyPair {
        KeyPair::generate().expect("Key generation should succeed")
    }

    #[test]
    fn test_builder_required_fields() {
        let key_pair = create_test_key_pair();
        let result = LicenseBuilder::new().build_and_sign(&key_pair);
        assert!(matches!(
            &result,
            Err(LicenseError::BuilderIncomplete { missing_fields })
                if missing_fields == "license_id, customer_id"
        ));
        let result = LicenseBuilder::new()
            .license_id("LIC-001")
            .build_and_sign(&key_pair);
        assert!(matches!(
            &result,
            Err(LicenseError::BuilderIncomplete { missing_fields })
                if missing_fields == "customer_id"
        ));
        let result = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .build_and_sign(&key_pair);
        assert!(result.is_ok());
    }

    #[test]
    fn test_builder_with_expiration() {
        let key_pair = create_test_key_pair();
        let signed = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .expires_in(Duration::days(30))
            .build_and_sign(&key_pair)
            .expect("Should build license");
        assert!(!signed.encoded_payload.is_empty());
        assert!(!signed.encoded_signature.is_empty());

        // `expires_in` is relative to the moment it is called, so bracket that call.
        let before = Utc::now();
        let payload = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .expires_in(Duration::days(30))
            .build_payload()
            .expect("Should build payload");
        let after = Utc::now();
        let expiration = payload
            .constraints
            .expiration_date
            .expect("Expiration should be set");
        assert!(expiration >= before + Duration::days(30));
        assert!(expiration <= after + Duration::days(30));
    }

    #[test]
    fn test_builder_with_features() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .allowed_feature("premium")
            .allowed_features(vec!["analytics", "reports"])
            .denied_feature("admin")
            .build_payload()
            .expect("Should build payload");
        let allowed = license.constraints.allowed_features.as_ref().unwrap();
        assert_eq!(allowed.len(), 3);
        assert!(allowed.contains("premium"));
        assert!(allowed.contains("analytics"));
        assert!(allowed.contains("reports"));
        let denied = license.constraints.denied_features.as_ref().unwrap();
        assert_eq!(denied.len(), 1);
        assert!(denied.contains("admin"));
    }

    #[test]
    fn test_builder_with_version_constraints() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .minimum_version_str("1.0.0")
            .expect("Valid version")
            .maximum_version_str("2.0.0")
            .expect("Valid version")
            .build_payload()
            .expect("Should build payload");
        assert_eq!(
            license.constraints.minimum_software_version,
            Some(Version::new(1, 0, 0))
        );
        assert_eq!(
            license.constraints.maximum_software_version,
            Some(Version::new(2, 0, 0))
        );
    }

    #[test]
    fn test_builder_with_hostnames() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .allowed_hostname("server1.example.com")
            .allowed_hostnames(vec!["server2.example.com", "server3.example.com"])
            .build_payload()
            .expect("Should build payload");
        let allowed = license.constraints.allowed_hostnames.as_ref().unwrap();
        assert_eq!(allowed.len(), 3);
        assert!(allowed.contains("server1.example.com"));
        assert!(allowed.contains("server2.example.com"));
        assert!(allowed.contains("server3.example.com"));
    }

    #[test]
    fn test_builder_with_metadata() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .customer_name("Acme Corp")
            .metadata("contract_id", serde_json::json!("CONTRACT-2024"))
            .metadata("sales_rep", serde_json::json!("John Doe"))
            .build_payload()
            .expect("Should build payload");
        assert_eq!(license.customer_name.as_deref(), Some("Acme Corp"));
        let metadata = license.metadata.as_ref().unwrap();
        assert_eq!(metadata["contract_id"], serde_json::json!("CONTRACT-2024"));
        assert_eq!(metadata["sales_rep"], serde_json::json!("John Doe"));
    }

    #[test]
    fn test_build_and_sign_to_json() {
        let key_pair = create_test_key_pair();
        let json = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .build_and_sign_to_json(&key_pair)
            .expect("Should build license JSON");
        let parsed: SignedLicense =
            serde_json::from_str(&json).expect("Should parse as SignedLicense");
        assert!(!parsed.encoded_payload.is_empty());
    }

    #[test]
    fn test_invalid_version_string() {
        let result = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .minimum_version_str("not-a-version");
        assert!(matches!(
            &result,
            Err(LicenseError::InvalidBuilderValue { field, .. }) if field == "minimum_version"
        ));
    }

    #[test]
    fn test_valid_from_constraint() {
        let future_time = Utc::now() + Duration::days(7);
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .valid_from(future_time)
            .build_payload()
            .expect("Should build payload");
        assert_eq!(license.constraints.valid_from, Some(future_time));
    }

    #[test]
    fn test_max_connections() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .max_connections(50)
            .build_payload()
            .expect("Should build payload");
        assert_eq!(license.constraints.max_connections, Some(50));
    }

    #[test]
    fn test_custom_constraints() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .custom_constraint("max_storage_gb", serde_json::json!(100))
            .custom_constraint("tier", serde_json::json!("enterprise"))
            .build_payload()
            .expect("Should build payload");
        let custom = license.constraints.custom_constraints.as_ref().unwrap();
        assert_eq!(custom["max_storage_gb"], serde_json::json!(100));
        assert_eq!(custom["tier"], serde_json::json!("enterprise"));
    }

    #[test]
    fn test_add_key_value_string() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .add_key_value("tier", "enterprise")
            .add_string("region", "EU")
            .build_payload()
            .expect("Should build payload");
        assert_eq!(license.get_string("tier"), Some("enterprise"));
        assert_eq!(license.get_string("region"), Some("EU"));
    }

    #[test]
    fn test_add_key_value_integer() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .add_key_value("max_users", 100i64)
            .add_i64("max_projects", 50)
            .build_payload()
            .expect("Should build payload");
        assert_eq!(license.get_i64("max_users"), Some(100));
        assert_eq!(license.get_i64("max_projects"), Some(50));
    }

    #[test]
    fn test_add_key_value_boolean() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .add_key_value("is_trial", false)
            .add_bool("allow_api", true)
            .build_payload()
            .expect("Should build payload");
        assert_eq!(license.get_bool("is_trial"), Some(false));
        assert_eq!(license.get_bool("allow_api"), Some(true));
    }

    #[test]
    fn test_add_key_value_array() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .add_key_value(
                "modules",
                serde_json::json!(["core", "analytics", "export"]),
            )
            .add_string_array("plugins", vec!["plugin1", "plugin2"])
            .build_payload()
            .expect("Should build payload");
        let modules = license.get_string_array("modules").unwrap();
        assert_eq!(modules, vec!["core", "analytics", "export"]);
        let plugins = license.get_string_array("plugins").unwrap();
        assert_eq!(plugins, vec!["plugin1", "plugin2"]);
    }

    #[test]
    fn test_add_key_value_object() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .add_key_value(
                "limits",
                serde_json::json!({
                    "storage_gb": 500,
                    "bandwidth_tb": 10
                }),
            )
            .build_payload()
            .expect("Should build payload");
        let limits = license.get_object("limits").unwrap();
        assert_eq!(limits["storage_gb"], serde_json::json!(500));
        assert_eq!(limits["bandwidth_tb"], serde_json::json!(10));
    }

    #[test]
    fn test_add_key_value_mixed_types() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .add_string("company", "Acme Corp")
            .add_i64("employees", 500)
            .add_bool("enterprise", true)
            .add_string_array("regions", vec!["US", "EU", "APAC"])
            .add_key_value(
                "config",
                serde_json::json!({
                    "theme": "dark",
                    "notifications": true
                }),
            )
            .build_payload()
            .expect("Should build payload");
        assert_eq!(license.get_string("company"), Some("Acme Corp"));
        assert_eq!(license.get_i64("employees"), Some(500));
        assert_eq!(license.get_bool("enterprise"), Some(true));
        let regions = license.get_string_array("regions").unwrap();
        assert_eq!(regions.len(), 3);
        let config = license.get_object("config").unwrap();
        assert_eq!(config["theme"], serde_json::json!("dark"));
    }

    #[test]
    fn test_add_key_value_with_defaults() {
        let license = LicenseBuilder::new()
            .license_id("LIC-001")
            .customer_id("CUST-001")
            .add_i64("limit", 100)
            .build_payload()
            .expect("Should build payload");
        assert_eq!(license.get_i64_or("limit", 50), 100);
        assert_eq!(license.get_i64_or("nonexistent", 50), 50);
        assert_eq!(license.get_string_or("missing", "default"), "default");
        assert!(license.get_bool_or("missing", true));
    }
}