use chrono::{DateTime, Utc};
use std::env;
use thiserror::Error;
// NOTE(review): hard-coded secret committed to source control. It is not
// referenced anywhere in this module (hence `allow(dead_code)`): neither key
// layout accepted by `LicenseValidator::validate` is checked against it. If
// signature verification is added, load the secret from configuration instead
// of embedding it here.
#[allow(dead_code)]
const SECRET_KEY: &str = "pg-api-license-secret-2024-production";
/// Reasons a license key is rejected or a loaded license is refused.
#[derive(Error, Debug)]
#[allow(dead_code)] // `InvalidSignature` and `ValidationFailed` are not constructed in this module yet.
pub enum LicenseError {
    /// Returned by `LicenseValidator::validate_from_env` when `LICENSE_KEY` cannot be read.
    #[error("No license key provided")]
    NoLicense,

    /// The key matches none of the accepted key layouts.
    #[error("Invalid license key format")]
    InvalidFormat,

    /// The license is past its expiry instant plus the grace period.
    #[error("License has expired")]
    Expired,

    /// Reserved for a failed signature check.
    #[error("Invalid license signature")]
    InvalidSignature,

    /// Generic validation failure carrying a free-form description.
    #[error("License validation failed: {0}")]
    ValidationFailed(String),
}
/// A license key that passed `LicenseValidator::validate`, together with the
/// limits derived from it.
#[derive(Debug, Clone)]
pub struct License {
    /// The raw key string exactly as supplied.
    #[allow(dead_code)] // kept for reference; not read within this module
    pub key: String,

    /// Tier decoded from the key.
    pub license_type: LicenseType,

    /// Expiry instant. `validate` derives it from the validation time, not from the key.
    pub expires_at: DateTime<Utc>,

    /// Maximum number of connections allowed for this tier.
    pub max_connections: u32,

    /// Validity flag. `validate` always sets it to `true`.
    pub is_valid: bool,
}
/// Commercial tier of a license. The tier determines its connection limit.
#[derive(Debug, Clone, PartialEq)]
pub enum LicenseType {
    /// Trial tier (`TRI` code in legacy keys).
    Trial,
    /// Standard tier (`STA` code in legacy keys).
    Standard,
    /// Enterprise tier (`ENT` code in legacy keys; also assigned to 8-group hex keys).
    Enterprise,
}

impl LicenseType {
    /// Parses the three-letter tier code used in legacy `PG-<CODE>-…` keys.
    ///
    /// The comparison is ASCII case-insensitive, so `"sta"`, `"Sta"` and
    /// `"STA"` all yield [`LicenseType::Standard`]. Any other input returns `None`.
    fn from_str(s: &str) -> Option<Self> {
        // `eq_ignore_ascii_case` avoids the allocation made by `to_uppercase`.
        // Unlike full Unicode uppercasing, it also cannot map non-ASCII
        // look-alikes ('ſ' -> 'S', 'ı' -> 'I') onto a valid code.
        if s.eq_ignore_ascii_case("TRI") {
            Some(LicenseType::Trial)
        } else if s.eq_ignore_ascii_case("STA") {
            Some(LicenseType::Standard)
        } else if s.eq_ignore_ascii_case("ENT") {
            Some(LicenseType::Enterprise)
        } else {
            None
        }
    }
}
pub struct LicenseValidator;
impl LicenseValidator {
pub fn validate_from_env() -> Result<License, LicenseError> {
let license_key = env::var("LICENSE_KEY")
.map_err(|_| LicenseError::NoLicense)?;
Self::validate(&license_key)
}
pub fn validate(license_key: &str) -> Result<License, LicenseError> {
let parts: Vec<&str> = license_key.split('-').collect();
if parts.len() == 8 {
for part in &parts {
if part.len() != 4 {
return Err(LicenseError::InvalidFormat);
}
if !part.chars().all(|c| c.is_ascii_hexdigit()) {
return Err(LicenseError::InvalidFormat);
}
}
let license_type = LicenseType::Enterprise;
let expires_at = Utc::now() + chrono::Duration::days(7300); let max_connections = 1000;
Ok(License {
key: license_key.to_string(),
license_type,
expires_at,
max_connections,
is_valid: true,
})
} else if parts.len() == 4 && parts[0] == "PG" {
let license_type = LicenseType::from_str(parts[1])
.ok_or(LicenseError::InvalidFormat)?;
let license_id = parts[2];
let provided_signature = parts[3];
if license_id.len() != 16 || provided_signature.len() < 10 {
return Err(LicenseError::InvalidFormat);
}
let expires_at = Utc::now() + chrono::Duration::days(365);
let max_connections = match license_type {
LicenseType::Trial => 10,
LicenseType::Standard => 100,
LicenseType::Enterprise => 1000,
};
Ok(License {
key: license_key.to_string(),
license_type,
expires_at,
max_connections,
is_valid: true,
})
} else {
return Err(LicenseError::InvalidFormat);
}
}
pub fn check_expiration(license: &License) -> Result<(), LicenseError> {
let now = Utc::now();
let grace_period = chrono::Duration::days(7);
if now > license.expires_at + grace_period {
return Err(LicenseError::Expired);
}
Ok(())
}
pub fn get_status_message(license: &License) -> String {
let now = Utc::now();
let days_remaining = (license.expires_at - now).num_days();
if days_remaining < 0 {
format!("License expired {} days ago", -days_remaining)
} else if days_remaining < 30 {
format!("License expires in {} days", days_remaining)
} else {
format!("License valid for {} days", days_remaining)
}
}
}
use axum::{
extract::State,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use crate::models::AppState;
/// Axum middleware that enforces license expiry.
///
/// Once the loaded license is past its expiry plus grace period, requests are
/// rejected with `402 Payment Required`. `/health`, `/ready` and
/// `/v1/license` are always forwarded. When no license is loaded, the request
/// is still forwarded and only a warning is logged.
pub async fn check_license_middleware(
    State(app_state): State<AppState>,
    request: axum::http::Request<axum::body::Body>,
    next: Next,
) -> Response {
    // Probe and license-status endpoints must keep answering regardless of license state.
    if matches!(request.uri().path(), "/health" | "/ready" | "/v1/license") {
        return next.run(request).await;
    }
    match &*app_state.license {
        Some(license) => {
            if let Err(e) = LicenseValidator::check_expiration(license) {
                tracing::error!("License expired: {}", e);
                // A static body is enough here; the former `format!` had no arguments.
                return (
                    StatusCode::PAYMENT_REQUIRED,
                    "License expired. Please renew your license.",
                )
                    .into_response();
            }
        }
        None => {
            // NOTE(review): without a license every request is still served, and
            // this warning is emitted once per request.
            tracing::warn!("Operating without valid license");
        }
    }
    next.run(request).await
}
use serde::Serialize;
/// JSON body returned by the license-status handler `get_license_info`.
///
/// That handler leaves every `Option` field as `None` when no license is loaded.
#[derive(Serialize)]
pub struct LicenseInfo {
    /// The license's `is_valid` flag, or `false` when no license is loaded.
    pub valid: bool,
    /// Debug name of the tier, e.g. `"Enterprise"`.
    pub license_type: Option<String>,
    /// Expiry instant rendered as RFC 3339.
    pub expires_at: Option<String>,
    /// Connection limit for the tier.
    pub max_connections: Option<u32>,
    /// Human-readable summary of the license state.
    pub status_message: String,
}
/// Handler that reports the currently loaded license as JSON.
///
/// It always responds with `200 OK`. When no license is loaded, the body has
/// `valid: false` and no tier, expiry or connection data.
///
/// NOTE(review): `valid` mirrors `License::is_valid` and does not consult
/// `check_expiration`, so it can read `true` while the middleware rejects requests.
pub async fn get_license_info(
    State(app_state): State<AppState>,
) -> impl IntoResponse {
    let info = if let Some(license) = &*app_state.license {
        LicenseInfo {
            valid: license.is_valid,
            license_type: Some(format!("{:?}", license.license_type)),
            expires_at: Some(license.expires_at.to_rfc3339()),
            max_connections: Some(license.max_connections),
            status_message: LicenseValidator::get_status_message(license),
        }
    } else {
        LicenseInfo {
            valid: false,
            license_type: None,
            expires_at: None,
            max_connections: None,
            status_message: String::from("No license key provided"),
        }
    };
    (StatusCode::OK, axum::Json(info))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a license with the given expiry. The other fields do not affect
    /// the expiry and status checks exercised below.
    fn license_expiring_at(expires_at: DateTime<Utc>) -> License {
        License {
            key: String::new(),
            license_type: LicenseType::Trial,
            expires_at,
            max_connections: 10,
            is_valid: true,
        }
    }

    #[test]
    fn test_license_type_parsing() {
        assert_eq!(LicenseType::from_str("TRI"), Some(LicenseType::Trial));
        assert_eq!(LicenseType::from_str("STA"), Some(LicenseType::Standard));
        assert_eq!(LicenseType::from_str("ENT"), Some(LicenseType::Enterprise));
        assert_eq!(LicenseType::from_str("ent"), Some(LicenseType::Enterprise));
        assert_eq!(LicenseType::from_str("XXX"), None);
    }

    #[test]
    fn test_license_validation_format() {
        let valid_legacy = "PG-STA-1234567890ABCDEF-SIGNATURE123";
        let legacy = LicenseValidator::validate(valid_legacy).expect("legacy key should validate");
        assert_eq!(legacy.license_type, LicenseType::Standard);

        let valid_new = "1477-90BD-C7BE-0CE9-1798-64D5-9616-FCCD";
        let license = LicenseValidator::validate(valid_new).expect("hex key should validate");
        assert_eq!(license.license_type, LicenseType::Enterprise);
        assert_eq!(license.max_connections, 1000);

        let invalid_keys = [
            "",
            "INVALID-KEY",
            "PG-XXX-1234567890ABCDEF-SIG",
            "PG-STA-SHORT-SIG",
            "PG-STA-1234567890ABCDEF-S",
            "1234-ABCD-EFGH",
            "1234-567",
        ];
        for key in invalid_keys.iter() {
            assert!(
                LicenseValidator::validate(key).is_err(),
                "{:?} should be rejected",
                key
            );
        }
    }

    #[test]
    fn test_legacy_key_connection_limits() {
        let cases: [(&str, LicenseType, u32); 4] = [
            ("TRI", LicenseType::Trial, 10),
            ("STA", LicenseType::Standard, 100),
            ("ENT", LicenseType::Enterprise, 1000),
            ("sta", LicenseType::Standard, 100),
        ];
        for (code, expected_type, expected_limit) in cases.iter() {
            let key = format!("PG-{}-1234567890ABCDEF-SIGNATURE123", code);
            let license = LicenseValidator::validate(&key).expect("well-formed legacy key");
            assert_eq!(&license.license_type, expected_type);
            assert_eq!(license.max_connections, *expected_limit);
        }
    }

    #[test]
    fn test_hex_key_rejects_malformed_groups() {
        let malformed = [
            "1477-90BD-C7BE-0CE9-1798-64D5-9616-FCCZ", // 'Z' is not a hex digit
            "1477-90BD-C7BE-0CE9-1798-64D5-9616-FCC",  // last group too short
            "1477-90BD-C7BE-0CE9-1798-64D5-9616-FCCDD", // last group too long
            "1477-90BD-C7BE-0CE9-1798-64D5-9616",      // only seven groups
        ];
        for key in malformed.iter() {
            assert!(
                LicenseValidator::validate(key).is_err(),
                "{:?} should be rejected",
                key
            );
        }
    }

    #[test]
    fn test_expiration_grace_period() {
        let now = Utc::now();
        let future = license_expiring_at(now + chrono::Duration::days(30));
        assert!(LicenseValidator::check_expiration(&future).is_ok());

        // Still inside the 7-day grace period.
        let recently_expired = license_expiring_at(now - chrono::Duration::days(3));
        assert!(LicenseValidator::check_expiration(&recently_expired).is_ok());

        // Past the grace period.
        let long_expired = license_expiring_at(now - chrono::Duration::days(8));
        assert!(matches!(
            LicenseValidator::check_expiration(&long_expired),
            Err(LicenseError::Expired)
        ));
    }

    #[test]
    fn test_status_message_buckets() {
        let now = Utc::now();

        let far = license_expiring_at(now + chrono::Duration::days(90));
        assert!(LicenseValidator::get_status_message(&far).starts_with("License valid for "));

        let soon = license_expiring_at(now + chrono::Duration::days(10));
        assert!(LicenseValidator::get_status_message(&soon).starts_with("License expires in "));

        let past = license_expiring_at(now - chrono::Duration::days(5));
        assert_eq!(
            LicenseValidator::get_status_message(&past),
            "License expired 5 days ago"
        );
    }
}