use crate::anti_tamper::HardwareFingerprint;
use crate::error::{LicenseError, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
use std::io::Write;
use std::path::Path;
use uuid::Uuid;
use super::{
detect_format, SneakernetFormat, MAX_SNEAKERNET_JSON_PAYLOAD, REQUEST_MAGIC,
REQUEST_TEXT_PREFIX, REQUEST_TEXT_SUFFIX, REQUEST_VERSION,
};
/// An offline activation request, transported as a binary container or as a
/// base64 text envelope (see `ActivationRequest::to_binary` / `to_base64`).
///
/// The field declaration order determines the JSON object layout produced by
/// serde, which is what ends up inside the binary container.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationRequest {
    /// Unique id of this request; the builder assigns a random (v4) UUID.
    pub request_id: Uuid,
    /// Hardware fingerprint of the machine requesting activation. Only its
    /// `combined_hash` participates in the checksum.
    pub fingerprint: HardwareFingerprint,
    /// Identifier of the product being activated.
    pub product_id: String,
    /// Features requested for the license; omitted from the JSON when empty
    /// and defaulted to an empty list when absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub requested_features: Vec<String>,
    /// Creation time; the builder sets it to `Utc::now()`.
    pub timestamp: DateTime<Utc>,
    /// Request format version recorded in the JSON; the builder sets it to
    /// `REQUEST_VERSION`. This is separate from the version byte in the binary
    /// header.
    pub version: u8,
    /// Free-form key/value data (e.g. customer name/email, order reference).
    /// A `BTreeMap` iterates in key order, which keeps the checksum deterministic.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, String>,
    /// Hex-encoded SHA-256 over the request fields, used to detect corruption.
    /// When present it is re-checked on `from_binary`; when absent, integrity
    /// checks are skipped.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub checksum: Option<String>,
}
impl ActivationRequest {
    /// Size of the fixed binary header preceding the JSON payload:
    /// 4-byte magic, 1-byte container version, 4-byte little-endian payload length.
    const HEADER_LEN: usize = 9;

    /// Returns a fresh [`ActivationRequestBuilder`].
    pub fn builder() -> ActivationRequestBuilder {
        ActivationRequestBuilder::new()
    }

    /// Reads a request from `path`, auto-detecting binary or text encoding.
    ///
    /// # Errors
    /// Propagates I/O errors from reading the file and any error from
    /// [`Self::from_bytes`].
    pub fn load(path: &Path) -> Result<Self> {
        let data = std::fs::read(path)?;
        Self::from_bytes(&data)
    }

    /// Parses a request from raw bytes, dispatching on [`detect_format`].
    ///
    /// Binary input is handed to [`Self::from_binary`]; text input must be valid
    /// UTF-8 and is handed to [`Self::from_base64`].
    ///
    /// # Errors
    /// `InvalidLicenseFormat` when the format is unrecognised, the text is not
    /// UTF-8, or the selected decoder fails.
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        match detect_format(data) {
            Some(SneakernetFormat::Binary) => Self::from_binary(data),
            Some(SneakernetFormat::Text) => {
                let text = std::str::from_utf8(data)
                    .map_err(|e| LicenseError::InvalidLicenseFormat(e.to_string()))?;
                Self::from_base64(text)
            }
            None => Err(LicenseError::InvalidLicenseFormat(
                "Unknown activation request format".to_string(),
            )),
        }
    }

    /// Decodes the binary container: `REQUEST_MAGIC`, a version byte, a u32 LE
    /// payload length, then that many bytes of JSON.
    ///
    /// Bytes following the declared payload are ignored. If the decoded request
    /// carries a checksum, it is verified.
    ///
    /// NOTE(review): the header version byte is only bounded by
    /// `REQUEST_VERSION`; it is not cross-checked against the JSON `version` field.
    ///
    /// # Errors
    /// `InvalidLicenseFormat` if the data is too short, has the wrong magic, an
    /// unsupported version, an oversized or truncated payload, invalid JSON, or
    /// a checksum mismatch.
    pub fn from_binary(data: &[u8]) -> Result<Self> {
        if data.len() < Self::HEADER_LEN {
            return Err(LicenseError::InvalidLicenseFormat(
                "Activation request too short".to_string(),
            ));
        }
        if &data[0..4] != REQUEST_MAGIC {
            return Err(LicenseError::InvalidLicenseFormat(
                "Invalid activation request magic header".to_string(),
            ));
        }
        let version = data[4];
        if version > REQUEST_VERSION {
            return Err(LicenseError::InvalidLicenseFormat(format!(
                "Unsupported activation request version: {} (max supported: {})",
                version, REQUEST_VERSION
            )));
        }
        // u32 -> usize is lossless on the 32/64-bit targets this code runs on.
        let len = u32::from_le_bytes([data[5], data[6], data[7], data[8]]) as usize;
        if len > MAX_SNEAKERNET_JSON_PAYLOAD {
            return Err(LicenseError::InvalidLicenseFormat(format!(
                "Activation request payload exceeds maximum of {} bytes",
                MAX_SNEAKERNET_JSON_PAYLOAD
            )));
        }
        // checked_add + get: never panics or overflows, whatever the declared length.
        let payload = Self::HEADER_LEN
            .checked_add(len)
            .and_then(|end| data.get(Self::HEADER_LEN..end))
            .ok_or_else(|| {
                LicenseError::InvalidLicenseFormat(
                    "Activation request data truncated".to_string(),
                )
            })?;
        let request: Self = serde_json::from_slice(payload)
            .map_err(|e| LicenseError::InvalidLicenseFormat(e.to_string()))?;
        if !request.verify_integrity() {
            return Err(LicenseError::InvalidLicenseFormat(
                "Activation request checksum mismatch - data may be corrupted".to_string(),
            ));
        }
        Ok(request)
    }

    /// Decodes a base64 request, optionally wrapped in the
    /// `REQUEST_TEXT_PREFIX` / `REQUEST_TEXT_SUFFIX` envelope.
    ///
    /// All whitespace inside the base64 body (line wrapping, CR/LF) is ignored.
    ///
    /// # Errors
    /// `InvalidLicenseFormat` if the text starts with the prefix but lacks the
    /// suffix, if the base64 is invalid, or if [`Self::from_binary`] fails.
    pub fn from_base64(text: &str) -> Result<Self> {
        let trimmed = text.trim();
        let base64_content = match trimmed.strip_prefix(REQUEST_TEXT_PREFIX) {
            Some(rest) => rest
                .strip_suffix(REQUEST_TEXT_SUFFIX)
                .map(str::trim)
                .ok_or_else(|| {
                    LicenseError::InvalidLicenseFormat(
                        "Malformed activation request text format".to_string(),
                    )
                })?,
            None => trimmed,
        };
        let clean_base64: String = base64_content
            .chars()
            .filter(|c| !c.is_whitespace())
            .collect();
        let binary = BASE64
            .decode(&clean_base64)
            .map_err(|e| LicenseError::InvalidLicenseFormat(format!("Invalid base64: {}", e)))?;
        Self::from_binary(&binary)
    }

    /// Encodes the request into the binary container understood by
    /// [`Self::from_binary`].
    ///
    /// # Errors
    /// `SerializationError` if JSON encoding fails, or if the payload exceeds
    /// `MAX_SNEAKERNET_JSON_PAYLOAD`. Such a container would be rejected on load,
    /// and a payload over `u32::MAX` would otherwise be silently truncated in
    /// the length field.
    pub fn to_binary(&self) -> Result<Vec<u8>> {
        let encoded = serde_json::to_vec(self)
            .map_err(|e| LicenseError::SerializationError(e.to_string()))?;
        if encoded.len() > MAX_SNEAKERNET_JSON_PAYLOAD || encoded.len() > u32::MAX as usize {
            return Err(LicenseError::SerializationError(format!(
                "Activation request payload of {} bytes exceeds maximum of {} bytes",
                encoded.len(),
                MAX_SNEAKERNET_JSON_PAYLOAD
            )));
        }
        // Cannot truncate: bounded by u32::MAX just above.
        let len = encoded.len() as u32;
        let mut output = Vec::with_capacity(Self::HEADER_LEN + encoded.len());
        output.write_all(REQUEST_MAGIC)?;
        output.write_all(&[REQUEST_VERSION])?;
        output.write_all(&len.to_le_bytes())?;
        output.write_all(&encoded)?;
        Ok(output)
    }

    /// Encodes the request as base64 text wrapped at 64 columns, between the
    /// `REQUEST_TEXT_PREFIX` and `REQUEST_TEXT_SUFFIX` lines.
    ///
    /// # Errors
    /// Any error from [`Self::to_binary`].
    pub fn to_base64(&self) -> Result<String> {
        let binary = self.to_binary()?;
        let base64_content = BASE64.encode(&binary);
        let wrapped: Vec<&str> = base64_content
            .as_bytes()
            .chunks(64)
            .map(|chunk| std::str::from_utf8(chunk).expect("base64 output is pure ASCII"))
            .collect();
        Ok(format!(
            "{}\n{}\n{}",
            REQUEST_TEXT_PREFIX,
            wrapped.join("\n"),
            REQUEST_TEXT_SUFFIX
        ))
    }

    /// Writes the binary container to `path`, replacing any existing file.
    pub fn save_binary(&self, path: &Path) -> Result<()> {
        std::fs::write(path, self.to_binary()?)?;
        Ok(())
    }

    /// Writes the base64 text envelope to `path`, replacing any existing file.
    pub fn save_text(&self, path: &Path) -> Result<()> {
        std::fs::write(path, self.to_base64()?)?;
        Ok(())
    }

    /// Hex-encoded SHA-256 over the request's identifying fields.
    ///
    /// This is an unkeyed hash, so it detects accidental corruption, not
    /// deliberate tampering.
    ///
    /// NOTE(review): fields are fed to the hasher without separators or length
    /// prefixes. For example, features `["ab", "c"]` and `["a", "bc"]` hash
    /// identically. Changing this would invalidate checksums already issued, so
    /// the byte layout is kept as-is.
    fn compute_checksum(&self) -> String {
        let mut hasher = Sha256::new();
        hasher.update(self.request_id.as_bytes());
        hasher.update(self.product_id.as_bytes());
        hasher.update(self.fingerprint.combined_hash.as_bytes());
        hasher.update(self.timestamp.to_rfc3339().as_bytes());
        hasher.update([self.version]);
        for feature in &self.requested_features {
            hasher.update(feature.as_bytes());
        }
        for (key, value) in &self.metadata {
            hasher.update(key.as_bytes());
            hasher.update(value.as_bytes());
        }
        hex::encode(hasher.finalize())
    }

    /// Returns `true` if the stored checksum matches the current field values,
    /// or if no checksum is stored at all.
    pub fn verify_integrity(&self) -> bool {
        match &self.checksum {
            Some(stored) => &self.compute_checksum() == stored,
            // Checksums are optional; absence is not treated as corruption.
            None => true,
        }
    }
}
/// Incrementally assembles an [`ActivationRequest`]; finish with `build()`.
#[derive(Default)]
pub struct ActivationRequestBuilder {
    /// Required product identifier; `build()` fails without it.
    product_id: Option<String>,
    /// Fingerprint to embed; `build()` falls back to the type's default when unset.
    fingerprint: Option<HardwareFingerprint>,
    /// Features to request, in insertion order.
    requested_features: Vec<String>,
    /// Free-form key/value metadata (customer details, order reference, ...).
    metadata: BTreeMap<String, String>,
}
impl ActivationRequestBuilder {
    /// Starts an empty builder (same as `Default::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Embeds an explicitly supplied hardware fingerprint.
    pub fn fingerprint(self, fingerprint: HardwareFingerprint) -> Self {
        Self {
            fingerprint: Some(fingerprint),
            ..self
        }
    }

    /// Embeds the fingerprint returned by `HardwareFingerprint::generate()`.
    pub fn fingerprint_current(self) -> Self {
        self.fingerprint(HardwareFingerprint::generate())
    }

    /// Sets the (required) product identifier.
    pub fn product_id(self, product_id: impl Into<String>) -> Self {
        Self {
            product_id: Some(product_id.into()),
            ..self
        }
    }

    /// Appends a single requested feature.
    pub fn feature(self, feature: impl Into<String>) -> Self {
        self.features(std::iter::once(feature))
    }

    /// Appends every feature yielded by `features`, preserving order.
    pub fn features(mut self, features: impl IntoIterator<Item = impl Into<String>>) -> Self {
        for item in features {
            self.requested_features.push(item.into());
        }
        self
    }

    /// Inserts a metadata entry, overwriting any previous value for `key`.
    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let entry_key: String = key.into();
        let entry_value: String = value.into();
        self.metadata.insert(entry_key, entry_value);
        self
    }

    /// Shorthand for `metadata("customer_name", name)`.
    pub fn customer_name(self, name: impl Into<String>) -> Self {
        self.metadata("customer_name", name)
    }

    /// Shorthand for `metadata("customer_email", email)`.
    pub fn customer_email(self, email: impl Into<String>) -> Self {
        self.metadata("customer_email", email)
    }

    /// Shorthand for `metadata("order_reference", reference)`.
    pub fn order_reference(self, reference: impl Into<String>) -> Self {
        self.metadata("order_reference", reference)
    }

    /// Finalises the request. Assigns a random v4 `request_id`, stamps
    /// `Utc::now()`, sets `version` to `REQUEST_VERSION` and fills in the checksum.
    ///
    /// NOTE(review): if no fingerprint was supplied, `HardwareFingerprint::default()`
    /// is used silently. Confirm that this default is meaningful for activation.
    ///
    /// # Errors
    /// `MissingField("product_id")` when no product id was set.
    pub fn build(self) -> Result<ActivationRequest> {
        let Self {
            product_id,
            fingerprint,
            requested_features,
            metadata,
        } = self;
        let fingerprint = fingerprint.unwrap_or_default();
        let Some(product_id) = product_id else {
            return Err(LicenseError::MissingField("product_id".into()));
        };
        let unsigned = ActivationRequest {
            request_id: Uuid::new_v4(),
            fingerprint,
            product_id,
            requested_features,
            timestamp: Utc::now(),
            version: REQUEST_VERSION,
            metadata,
            checksum: None,
        };
        let checksum = unsigned.compute_checksum();
        Ok(ActivationRequest {
            checksum: Some(checksum),
            ..unsigned
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed, fully-populated fingerprint so tests don't depend on the host.
    fn create_test_fingerprint() -> HardwareFingerprint {
        HardwareFingerprint {
            mac_hashes: vec!["abc123".to_string()],
            disk_hashes: vec!["def456".to_string()],
            hostname_hash: Some("host789".to_string()),
            machine_guid_hash: Some("guid012".to_string()),
            combined_hash: "combined345".to_string(),
        }
    }

    /// Builds a minimal valid request for `product`.
    fn build_request(product: &str) -> ActivationRequest {
        ActivationRequest::builder()
            .product_id(product)
            .fingerprint(create_test_fingerprint())
            .build()
            .unwrap()
    }

    #[test]
    fn test_request_builder() {
        let request = ActivationRequest::builder()
            .product_id("TEST-PRODUCT")
            .fingerprint(create_test_fingerprint())
            .feature("basic")
            .feature("premium")
            .customer_name("John Doe")
            .customer_email("john@example.com")
            .build()
            .unwrap();
        assert_eq!(request.product_id, "TEST-PRODUCT");
        assert_eq!(request.requested_features.len(), 2);
        assert!(request.requested_features.contains(&"basic".to_string()));
        assert!(request.requested_features.contains(&"premium".to_string()));
        assert_eq!(
            request.metadata.get("customer_name"),
            Some(&"John Doe".to_string())
        );
        assert!(request.checksum.is_some());
    }

    #[test]
    fn test_request_builder_missing_product_id() {
        let result = ActivationRequest::builder()
            .fingerprint(create_test_fingerprint())
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_binary_serialization_roundtrip() {
        let request = ActivationRequest::builder()
            .product_id("MY-APP")
            .fingerprint(create_test_fingerprint())
            .feature("pro")
            .build()
            .unwrap();
        let binary = request.to_binary().unwrap();
        assert_eq!(&binary[0..4], REQUEST_MAGIC);
        assert_eq!(binary[4], REQUEST_VERSION);
        let parsed = ActivationRequest::from_binary(&binary).unwrap();
        assert_eq!(parsed.request_id, request.request_id);
        assert_eq!(parsed.product_id, request.product_id);
        assert_eq!(parsed.requested_features, request.requested_features);
        assert_eq!(parsed.checksum, request.checksum);
    }

    #[test]
    fn test_base64_serialization_roundtrip() {
        let request = ActivationRequest::builder()
            .product_id("MY-APP")
            .fingerprint(create_test_fingerprint())
            .feature("enterprise")
            .customer_name("Acme Corp")
            .build()
            .unwrap();
        let text = request.to_base64().unwrap();
        assert!(text.starts_with(REQUEST_TEXT_PREFIX));
        assert!(text.ends_with(REQUEST_TEXT_SUFFIX));
        let parsed = ActivationRequest::from_base64(&text).unwrap();
        assert_eq!(parsed.request_id, request.request_id);
        assert_eq!(parsed.product_id, request.product_id);
        assert_eq!(parsed.requested_features, request.requested_features);
        assert_eq!(
            parsed.metadata.get("customer_name"),
            Some(&"Acme Corp".to_string())
        );
    }

    #[test]
    fn test_auto_detect_format() {
        let request = build_request("AUTO-DETECT");
        let binary = request.to_binary().unwrap();
        let parsed_binary = ActivationRequest::from_bytes(&binary).unwrap();
        assert_eq!(parsed_binary.product_id, "AUTO-DETECT");
        let text = request.to_base64().unwrap();
        let parsed_text = ActivationRequest::from_bytes(text.as_bytes()).unwrap();
        assert_eq!(parsed_text.product_id, "AUTO-DETECT");
    }

    #[test]
    fn test_integrity_verification() {
        let request = build_request("INTEGRITY-TEST");
        assert!(request.verify_integrity());
    }

    #[test]
    fn test_checksum_mismatch_rejected() {
        let mut request = build_request("CHECKSUM-TEST");
        // Mutate a checksummed field without refreshing the stored checksum.
        request.product_id = "TAMPERED".to_string();
        assert!(!request.verify_integrity());
        let binary = request.to_binary().unwrap();
        assert!(ActivationRequest::from_binary(&binary).is_err());
    }

    #[test]
    fn test_request_without_checksum_is_accepted() {
        let mut request = build_request("NO-CHECKSUM");
        request.checksum = None;
        assert!(request.verify_integrity());
        let binary = request.to_binary().unwrap();
        let parsed = ActivationRequest::from_binary(&binary).unwrap();
        assert!(parsed.checksum.is_none());
        assert_eq!(parsed.product_id, "NO-CHECKSUM");
    }

    #[test]
    fn test_invalid_magic_header() {
        let mut bad_data = vec![b'B', b'A', b'D', b'!'];
        bad_data.extend_from_slice(&[1, 0, 0, 0, 0]);
        let result = ActivationRequest::from_binary(&bad_data);
        assert!(result.is_err());
    }

    #[test]
    fn test_oversized_declared_length_rejected() {
        let mut data = REQUEST_MAGIC.to_vec();
        data.push(REQUEST_VERSION);
        data.extend_from_slice(&u32::MAX.to_le_bytes());
        assert!(ActivationRequest::from_binary(&data).is_err());
    }

    #[test]
    fn test_future_version_rejected() {
        if let Some(next_version) = REQUEST_VERSION.checked_add(1) {
            let mut binary = build_request("VERSION-TEST").to_binary().unwrap();
            binary[4] = next_version;
            assert!(ActivationRequest::from_binary(&binary).is_err());
        }
    }

    #[test]
    fn test_truncated_data() {
        let binary = build_request("TRUNCATE-TEST").to_binary().unwrap();
        let truncated = &binary[0..20];
        let result = ActivationRequest::from_binary(truncated);
        assert!(result.is_err());
    }

    #[test]
    fn test_malformed_text_envelope_rejected() {
        // Prefix present but suffix missing.
        let text = format!("{}\nAAAA", REQUEST_TEXT_PREFIX);
        assert!(ActivationRequest::from_base64(&text).is_err());
    }

    #[test]
    fn test_text_with_crlf_line_endings() {
        let request = build_request("CRLF-TEST");
        let text = request.to_base64().unwrap().replace('\n', "\r\n");
        let parsed = ActivationRequest::from_base64(&text).unwrap();
        assert_eq!(parsed.request_id, request.request_id);
    }

    #[test]
    fn test_unknown_bytes_rejected() {
        assert!(ActivationRequest::from_bytes(b"not an activation request").is_err());
        assert!(ActivationRequest::from_bytes(&[]).is_err());
    }
}