pg-api 0.1.0

A high-performance PostgreSQL REST API driver with rate limiting, connection pooling, and observability
use chrono::{DateTime, Utc};
use std::env;
use thiserror::Error;

// Secret key must match the one in generate_license.py
// NOTE(review): this constant is not referenced anywhere in this file — nothing here
// verifies a license signature against it, which is why `dead_code` is allowed.
// Its value suggests a production secret; keeping it in source control exposes it to
// anyone with repository access. Consider loading it from the environment or a secret
// store once signature verification is actually implemented.
#[allow(dead_code)]
const SECRET_KEY: &str = "pg-api-license-secret-2024-production";

#[derive(Error, Debug)]
#[allow(dead_code)]
pub enum LicenseError {
    #[error("No license key provided")]
    NoLicense,
    
    #[error("Invalid license key format")]
    InvalidFormat,
    
    #[error("License has expired")]
    Expired,
    
    #[error("Invalid license signature")]
    InvalidSignature,
    
    #[error("License validation failed: {0}")]
    ValidationFailed(String),
}

/// A license key that passed format validation, plus the entitlements derived from it.
#[derive(Clone, Debug)]
pub struct License {
    /// The key string exactly as it was handed to the validator.
    // Not read anywhere in this file; retained for reference/diagnostics.
    #[allow(dead_code)]
    pub key: String,
    /// License tier (encoded in legacy keys; license-manager keys default to Enterprise).
    pub license_type: LicenseType,
    /// Instant at which the license lapses, before any grace period is applied.
    pub expires_at: DateTime<Utc>,
    /// Connection cap associated with this license.
    pub max_connections: u32,
    /// Validity flag; the validator sets it to `true` on every successful parse.
    pub is_valid: bool,
}

/// Commercial tier of a license.
///
/// Variant names are observable through `Debug` formatting (the license info
/// endpoint reports `format!("{:?}", ..)`), so renaming them changes API output.
#[derive(Debug, Clone, PartialEq)]
pub enum LicenseType {
    Trial,
    Standard,
    Enterprise,
}

impl LicenseType {
    /// Parses the three-letter tier code used in legacy `PG-TYPE-...` keys.
    ///
    /// The input is upper-cased before comparison, so matching is case-insensitive:
    /// `TRI` → `Trial`, `STA` → `Standard`, `ENT` → `Enterprise`.
    /// Any other input yields `None`.
    fn from_str(code: &str) -> Option<Self> {
        const CODES: [(&str, LicenseType); 3] = [
            ("TRI", LicenseType::Trial),
            ("STA", LicenseType::Standard),
            ("ENT", LicenseType::Enterprise),
        ];

        let wanted = code.to_uppercase();
        CODES
            .iter()
            .find(|(tag, _)| *tag == wanted)
            .map(|(_, tier)| tier.clone())
    }
}

/// Parses and checks license keys.
///
/// Stateless: every operation is an associated function.
pub struct LicenseValidator;

impl LicenseValidator {
    /// Days after `expires_at` during which an expired license is still accepted.
    const EXPIRY_GRACE_DAYS: i64 = 7;

    /// Reads the `LICENSE_KEY` environment variable and validates it.
    ///
    /// Surrounding whitespace (e.g. a trailing newline from a value loaded out of a
    /// file) is stripped before validation.
    ///
    /// # Errors
    ///
    /// * [`LicenseError::NoLicense`] if the variable is unset, is not valid Unicode,
    ///   or is empty / whitespace-only.
    /// * Any error returned by [`LicenseValidator::validate`].
    pub fn validate_from_env() -> Result<License, LicenseError> {
        let raw = env::var("LICENSE_KEY").map_err(|_| LicenseError::NoLicense)?;
        let license_key = raw.trim();
        if license_key.is_empty() {
            // A set-but-blank variable means "no key", not "malformed key".
            return Err(LicenseError::NoLicense);
        }

        Self::validate(license_key)
    }

    /// Validates the *format* of `license_key` and derives its entitlements.
    ///
    /// Two layouts are accepted; the key is used as-is (no trimming):
    ///
    /// 1. **license-manager**: eight `-`-separated groups of exactly four ASCII hex
    ///    digits. Treated as `Enterprise` with 1000 connections, expiring 7300 days
    ///    (~20 years) from now.
    /// 2. **legacy**: `PG-TYPE-ID-SIGNATURE`. `TYPE` must be a tier code (`TRI`,
    ///    `STA`, `ENT`, case-insensitive), `ID` exactly 16 bytes and `SIGNATURE` at
    ///    least 10 bytes. Expires 365 days from now; the connection cap depends on
    ///    the tier (10 / 100 / 1000).
    ///
    /// Expiry is computed relative to the moment of validation, so a key accepted
    /// here is never already expired.
    ///
    /// NOTE(review): no cryptographic check is performed. Legacy signatures are only
    /// length-checked and license-manager keys only hex-checked, so any well-formed
    /// string is accepted and [`LicenseError::InvalidSignature`] is never returned.
    /// Authoritative validation of license-manager keys is said to happen in the
    /// license-manager database — confirm before relying on this for enforcement.
    ///
    /// # Errors
    ///
    /// [`LicenseError::InvalidFormat`] if the key matches neither layout.
    pub fn validate(license_key: &str) -> Result<License, LicenseError> {
        let parts: Vec<&str> = license_key.split('-').collect();

        match parts.as_slice() {
            groups if groups.len() == 8 => Self::validate_manager_key(license_key, groups),
            ["PG", tier, license_id, signature] => {
                Self::validate_legacy_key(license_key, tier, license_id, signature)
            }
            _ => Err(LicenseError::InvalidFormat),
        }
    }

    /// Validates a license-manager key that has already been split into its eight groups.
    fn validate_manager_key(
        license_key: &str,
        groups: &[&str],
    ) -> Result<License, LicenseError> {
        let well_formed = groups
            .iter()
            .all(|group| group.len() == 4 && group.chars().all(|c| c.is_ascii_hexdigit()));
        if !well_formed {
            return Err(LicenseError::InvalidFormat);
        }

        // This layout carries no tier or expiry information, so fixed Enterprise
        // entitlements with a 20-year horizon are assumed.
        Ok(License {
            key: license_key.to_string(),
            license_type: LicenseType::Enterprise,
            expires_at: Utc::now() + chrono::Duration::days(7300),
            max_connections: 1000,
            is_valid: true,
        })
    }

    /// Validates the `TYPE`, `ID` and `SIGNATURE` fields of a legacy `PG-...` key.
    fn validate_legacy_key(
        license_key: &str,
        tier: &str,
        license_id: &str,
        signature: &str,
    ) -> Result<License, LicenseError> {
        let license_type = LicenseType::from_str(tier).ok_or(LicenseError::InvalidFormat)?;

        // Byte-length checks only; the signature content is never verified.
        if license_id.len() != 16 || signature.len() < 10 {
            return Err(LicenseError::InvalidFormat);
        }

        let max_connections = match license_type {
            LicenseType::Trial => 10,
            LicenseType::Standard => 100,
            LicenseType::Enterprise => 1000,
        };

        Ok(License {
            key: license_key.to_string(),
            license_type,
            expires_at: Utc::now() + chrono::Duration::days(365),
            max_connections,
            is_valid: true,
        })
    }

    /// Checks whether `license` is still usable.
    ///
    /// A license keeps working for `EXPIRY_GRACE_DAYS` (7) days after `expires_at`.
    ///
    /// # Errors
    ///
    /// [`LicenseError::Expired`] once the grace period has also elapsed.
    pub fn check_expiration(license: &License) -> Result<(), LicenseError> {
        let grace_period = chrono::Duration::days(Self::EXPIRY_GRACE_DAYS);

        if Utc::now() > license.expires_at + grace_period {
            Err(LicenseError::Expired)
        } else {
            Ok(())
        }
    }

    /// Returns a human-readable summary of the time left on `license`.
    ///
    /// * Past `expires_at`: `"License expired N days ago"`. N counts whole days, so it
    ///   is `0` during the first day after expiry.
    /// * Fewer than 30 whole days left: `"License expires in N days"`.
    /// * Otherwise: `"License valid for N days"`.
    ///
    /// The grace period applied by [`LicenseValidator::check_expiration`] is not
    /// reflected here.
    pub fn get_status_message(license: &License) -> String {
        let now = Utc::now();

        // Decide "expired" on the exact instant rather than on the day count:
        // `num_days` truncates toward zero, so a license that lapsed a few hours ago
        // would otherwise be reported as "expires in 0 days".
        if license.expires_at < now {
            let days_ago = (now - license.expires_at).num_days();
            return format!("License expired {} days ago", days_ago);
        }

        let days_remaining = (license.expires_at - now).num_days();
        if days_remaining < 30 {
            format!("License expires in {} days", days_remaining)
        } else {
            format!("License valid for {} days", days_remaining)
        }
    }
}

/// License middleware for Axum
use axum::{
    extract::State,
    http::StatusCode,
    middleware::Next,
    response::{IntoResponse, Response},
};
use crate::models::AppState;

/// Axum middleware that rejects requests once the loaded license has expired.
///
/// * `/health`, `/ready` and `/v1/license` always pass through, so probes and the
///   license info endpoint keep working.
/// * License present but past its expiry plus grace period: responds
///   `402 Payment Required` without calling the inner handler.
/// * No license loaded: logs a warning (on every non-exempt request) and lets the
///   request through; v0.1.0 does not enforce a license.
pub async fn check_license_middleware(
    State(app_state): State<AppState>,
    request: axum::http::Request<axum::body::Body>,
    next: Next,
) -> Response {
    let path = request.uri().path();
    if matches!(path, "/health" | "/ready" | "/v1/license") {
        return next.run(request).await;
    }

    match &*app_state.license {
        Some(license) => {
            if let Err(e) = LicenseValidator::check_expiration(license) {
                tracing::error!("License expired: {}", e);
                // A static body is enough; no formatting is needed.
                return (
                    StatusCode::PAYMENT_REQUIRED,
                    "License expired. Please renew your license.",
                )
                    .into_response();
            }
        }
        None => {
            // v0.1.0 runs unlicensed and only warns. To enforce a license, return
            // early here instead:
            //     return (
            //         StatusCode::PAYMENT_REQUIRED,
            //         "No valid license found. Please provide a LICENSE_KEY.",
            //     )
            //         .into_response();
            tracing::warn!("Operating without valid license");
        }
    }

    next.run(request).await
}

/// License info endpoint handler
use serde::Serialize;

/// JSON body returned by the license info endpoint.
///
/// The optional fields are `None` when no license is loaded.
/// Field order is kept stable because it determines the serialized key order.
#[derive(Serialize)]
pub struct LicenseInfo {
    /// Copy of `License::is_valid`; `false` when no license is loaded.
    pub valid: bool,
    /// `Debug` name of the tier, e.g. `"Enterprise"`.
    pub license_type: Option<String>,
    /// Expiry instant rendered as RFC 3339.
    pub expires_at: Option<String>,
    /// Connection cap carried by the license.
    pub max_connections: Option<u32>,
    /// Human-readable status summary.
    pub status_message: String,
}

/// Reports the currently loaded license as JSON.
///
/// Always answers `200 OK`. When no license is loaded, the payload has
/// `valid: false`, no optional fields, and the message "No license key provided".
// NOTE(review): `app_state.license` may also be `None` because a key failed
// validation, depending on how `AppState` is populated — the message may then be
// misleading; confirm against the startup code.
pub async fn get_license_info(
    State(app_state): State<AppState>,
) -> impl IntoResponse {
    let info = match &*app_state.license {
        None => LicenseInfo {
            valid: false,
            license_type: None,
            expires_at: None,
            max_connections: None,
            status_message: "No license key provided".to_string(),
        },
        Some(license) => LicenseInfo {
            valid: license.is_valid,
            license_type: Some(format!("{:?}", license.license_type)),
            expires_at: Some(license.expires_at.to_rfc3339()),
            max_connections: Some(license.max_connections),
            status_message: LicenseValidator::get_status_message(license),
        },
    };

    (StatusCode::OK, axum::Json(info))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_license_type_parsing() {
        assert_eq!(LicenseType::from_str("TRI"), Some(LicenseType::Trial));
        assert_eq!(LicenseType::from_str("STA"), Some(LicenseType::Standard));
        assert_eq!(LicenseType::from_str("ENT"), Some(LicenseType::Enterprise));
        // Tier codes are matched case-insensitively.
        assert_eq!(LicenseType::from_str("ent"), Some(LicenseType::Enterprise));
        assert_eq!(LicenseType::from_str("XXX"), None);
        assert_eq!(LicenseType::from_str(""), None);
    }

    #[test]
    fn test_license_validation_format() {
        // Valid legacy format: tier and connection cap come from the TYPE field.
        let legacy = LicenseValidator::validate("PG-STA-1234567890ABCDEF-SIGNATURE123")
            .expect("well-formed legacy key should validate");
        assert_eq!(legacy.license_type, LicenseType::Standard);
        assert_eq!(legacy.max_connections, 100);
        assert!(legacy.is_valid);

        // Valid new format (license-manager): fixed Enterprise entitlements.
        let license = LicenseValidator::validate("1477-90BD-C7BE-0CE9-1798-64D5-9616-FCCD")
            .expect("well-formed license-manager key should validate");
        assert_eq!(license.license_type, LicenseType::Enterprise);
        assert_eq!(license.max_connections, 1000);

        // Lower-case hex digits are accepted as well.
        assert!(LicenseValidator::validate("1477-90bd-c7be-0ce9-1798-64d5-9616-fccd").is_ok());

        let invalid_keys = [
            "",
            "INVALID-KEY",
            "PG-XXX-1234567890ABCDEF-SIG",              // unknown tier code
            "PG-STA-SHORT-SIG",                         // ID is not 16 bytes
            "PG-STA-1234567890ABCDEF-S",                // signature shorter than 10 bytes
            "PG-STA-1234567890ABCDEF-SIGNATURE123-X",   // too many groups for legacy
            "XX-STA-1234567890ABCDEF-SIGNATURE123",     // wrong legacy prefix
            "1234-ABCD-EFGH",                           // wrong number of groups
            "1234-567",                                 // wrong number of groups
            "1477-90BD-C7BE-0CE9-1798-64D5-9616-FCCG",  // 8 groups, `G` is not hex
            "1477-90BD-C7BE-0CE9-1798-64D5-9616-FCC",   // 8 groups, last group too short
            "1477-90BD-C7BE-0CE9-1798-64D5-9616-FCCDD", // 8 groups, last group too long
        ];

        for key in &invalid_keys {
            assert!(
                LicenseValidator::validate(key).is_err(),
                "key {:?} should be rejected",
                key
            );
        }
    }

    #[test]
    fn test_check_expiration_grace_period() {
        let license_expiring_in = |days: i64| License {
            key: String::new(),
            license_type: LicenseType::Trial,
            expires_at: Utc::now() + chrono::Duration::days(days),
            max_connections: 10,
            is_valid: true,
        };

        // Not yet expired.
        assert!(LicenseValidator::check_expiration(&license_expiring_in(30)).is_ok());
        // Expired, but still inside the 7-day grace period.
        assert!(LicenseValidator::check_expiration(&license_expiring_in(-6)).is_ok());
        // Past the grace period.
        assert!(matches!(
            LicenseValidator::check_expiration(&license_expiring_in(-8)),
            Err(LicenseError::Expired)
        ));
    }

    #[test]
    fn test_status_message_buckets() {
        let mut license = LicenseValidator::validate("PG-TRI-1234567890ABCDEF-SIGNATURE123")
            .expect("well-formed legacy key should validate");
        assert!(LicenseValidator::get_status_message(&license).starts_with("License valid for"));

        license.expires_at = Utc::now() + chrono::Duration::days(10);
        assert!(LicenseValidator::get_status_message(&license).starts_with("License expires in"));

        license.expires_at = Utc::now() - chrono::Duration::days(3);
        assert!(LicenseValidator::get_status_message(&license).starts_with("License expired"));
    }
}