use ed25519_dalek::Verifier;
use super::certificate::{MeshCertificate, MeshTier};
use super::error::SecurityError;
use super::keypair::DeviceKeypair;
/// Lifecycle state of an enrollment, as carried in an [`EnrollmentResponse`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EnrollmentStatus {
    /// No final decision has been reported.
    Pending,
    /// Enrollment was granted.
    Approved,
    /// Enrollment was refused; `reason` is a human-readable explanation.
    Denied { reason: String },
    /// Enrollment was revoked; `reason` is a human-readable explanation.
    Revoked { reason: String },
}

impl EnrollmentStatus {
    /// Returns the single-byte wire tag for this status.
    ///
    /// Tags: `Pending` = 0, `Approved` = 1, `Denied` = 2, `Revoked` = 3.
    /// The `reason` payload of `Denied`/`Revoked` is not part of the tag and
    /// has to be encoded separately by the caller.
    pub fn to_byte(&self) -> u8 {
        match *self {
            EnrollmentStatus::Pending => 0,
            EnrollmentStatus::Approved => 1,
            EnrollmentStatus::Denied { .. } => 2,
            EnrollmentStatus::Revoked { .. } => 3,
        }
    }
}
/// A self-signed request from a device asking to join a mesh.
///
/// The request is signed with the private key matching `subject_public_key`,
/// so an authority can check that the requester actually holds that key.
///
/// Wire format (all integers little-endian):
///
/// `subject_public_key(32) | mesh_id_len(u8) | mesh_id | node_id_len(u8) | node_id
///  | tier(1) | token_len(u16) | bootstrap_token | timestamp_ms(u64) | signature(64)`
///
/// The signature covers every field before it.
#[derive(Clone, Debug)]
pub struct EnrollmentRequest {
    /// Ed25519 public key of the requesting device; also the key that signs the request.
    pub subject_public_key: [u8; 32],
    /// Mesh the device wants to join (at most 255 bytes when encoded).
    pub mesh_id: String,
    /// Node identifier the device asks to be enrolled as (at most 255 bytes when encoded).
    pub node_id: String,
    /// Tier the device asks for; an authority may grant a different one.
    pub requested_tier: MeshTier,
    /// Opaque bootstrap credential presented to the authority (at most 65535 bytes when encoded).
    pub bootstrap_token: Vec<u8>,
    /// Request creation time in milliseconds, supplied by the caller
    /// (the tests use milliseconds since the UNIX epoch).
    pub timestamp_ms: u64,
    /// Ed25519 signature over the encoding of every other field.
    pub signature: [u8; 64],
}

impl EnrollmentRequest {
    /// Smallest valid encoding: key(32) + mesh_id_len(1) + node_id_len(1) + tier(1)
    /// + token_len(2) + timestamp(8) + signature(64), with every variable field empty.
    const MIN_ENCODED_LEN: usize = 32 + 1 + 1 + 1 + 2 + 8 + 64;
    /// Longest `mesh_id` / `node_id` representable by their one-byte length prefix.
    const MAX_ID_LEN: usize = u8::MAX as usize;
    /// Longest `bootstrap_token` representable by its two-byte length prefix.
    const MAX_TOKEN_LEN: usize = u16::MAX as usize;

    /// Builds a request for `keypair`'s public key and signs it with `keypair`.
    ///
    /// Identifiers longer than 255 bytes or tokens longer than 65535 bytes cannot be
    /// encoded faithfully; such requests are rejected by [`Self::verify_signature`].
    pub fn new(
        keypair: &DeviceKeypair,
        mesh_id: String,
        node_id: String,
        requested_tier: MeshTier,
        bootstrap_token: Vec<u8>,
        timestamp_ms: u64,
    ) -> Self {
        let mut req = Self {
            subject_public_key: keypair.public_key_bytes(),
            mesh_id,
            node_id,
            requested_tier,
            bootstrap_token,
            timestamp_ms,
            // Placeholder; replaced by the real signature below.
            signature: [0u8; 64],
        };
        let signable = req.signable_bytes();
        req.signature = keypair.sign(&signable).to_bytes();
        req
    }

    /// Checks that `signature` is a valid Ed25519 signature by `subject_public_key`
    /// over this request's encoding.
    ///
    /// # Errors
    ///
    /// - [`SecurityError::SerializationError`] if a variable-length field is too long
    ///   for its length prefix. The signed encoding would then be ambiguous, so the
    ///   signature could not be trusted to bind the fields.
    /// - [`SecurityError::InvalidPublicKey`] if `subject_public_key` is rejected by
    ///   `VerifyingKey::from_bytes`.
    /// - [`SecurityError::InvalidSignature`] if signature verification fails.
    pub fn verify_signature(&self) -> Result<(), SecurityError> {
        self.check_field_lengths()?;
        let vk = ed25519_dalek::VerifyingKey::from_bytes(&self.subject_public_key)
            .map_err(|e| SecurityError::InvalidPublicKey(e.to_string()))?;
        let sig = ed25519_dalek::Signature::from_bytes(&self.signature);
        let signable = self.signable_bytes();
        vk.verify(&signable, &sig)
            .map_err(|e| SecurityError::InvalidSignature(e.to_string()))
    }

    /// Rejects fields that `signable_bytes` cannot length-prefix without truncation.
    fn check_field_lengths(&self) -> Result<(), SecurityError> {
        let too_long = |field: &str, len: usize, max: usize| {
            SecurityError::SerializationError(format!(
                "{field} too long: {len} bytes (max {max})"
            ))
        };
        if self.mesh_id.len() > Self::MAX_ID_LEN {
            return Err(too_long("mesh_id", self.mesh_id.len(), Self::MAX_ID_LEN));
        }
        if self.node_id.len() > Self::MAX_ID_LEN {
            return Err(too_long("node_id", self.node_id.len(), Self::MAX_ID_LEN));
        }
        if self.bootstrap_token.len() > Self::MAX_TOKEN_LEN {
            return Err(too_long(
                "bootstrap_token",
                self.bootstrap_token.len(),
                Self::MAX_TOKEN_LEN,
            ));
        }
        Ok(())
    }

    /// Encodes every field except the signature, in wire order.
    ///
    /// The `as` length casts truncate oversized fields; `check_field_lengths`
    /// guards against that before any signature is verified.
    fn signable_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(
            32 + 1
                + self.mesh_id.len()
                + 1
                + self.node_id.len()
                + 1
                + 2
                + self.bootstrap_token.len()
                + 8,
        );
        buf.extend_from_slice(&self.subject_public_key);
        buf.push(self.mesh_id.len() as u8);
        buf.extend_from_slice(self.mesh_id.as_bytes());
        buf.push(self.node_id.len() as u8);
        buf.extend_from_slice(self.node_id.as_bytes());
        buf.push(self.requested_tier.to_byte());
        buf.extend_from_slice(&(self.bootstrap_token.len() as u16).to_le_bytes());
        buf.extend_from_slice(&self.bootstrap_token);
        buf.extend_from_slice(&self.timestamp_ms.to_le_bytes());
        buf
    }

    /// Serializes the request: the signed fields followed by the 64-byte signature.
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = self.signable_bytes();
        buf.extend_from_slice(&self.signature);
        buf
    }

    /// Parses a request produced by [`Self::encode`].
    ///
    /// Only the framing is checked here: the signature is NOT verified; call
    /// [`Self::verify_signature`] for that. Bytes after the signature are ignored.
    ///
    /// # Errors
    ///
    /// [`SecurityError::SerializationError`] if `data` is truncated, an identifier is
    /// not valid UTF-8, or the tier byte is not recognised by `MeshTier::from_byte`.
    pub fn decode(data: &[u8]) -> Result<Self, SecurityError> {
        if data.len() < Self::MIN_ENCODED_LEN {
            return Err(SecurityError::SerializationError(format!(
                "enrollment request too short: {} bytes (min {})",
                data.len(),
                Self::MIN_ENCODED_LEN
            )));
        }
        let mut pos = 0;

        let mut subject_public_key = [0u8; 32];
        subject_public_key.copy_from_slice(&data[pos..pos + 32]);
        pos += 32;

        let mesh_id_len = usize::from(data[pos]);
        pos += 1;
        // Also require the node_id length byte that follows mesh_id.
        if pos + mesh_id_len + 1 > data.len() {
            return Err(SecurityError::SerializationError(
                "enrollment request truncated at mesh_id".to_string(),
            ));
        }
        let mesh_id = String::from_utf8(data[pos..pos + mesh_id_len].to_vec())
            .map_err(|e| SecurityError::SerializationError(format!("invalid mesh_id: {e}")))?;
        pos += mesh_id_len;

        let node_id_len = usize::from(data[pos]);
        pos += 1;
        // Also require the tier byte and the two-byte token length that follow node_id.
        if pos + node_id_len + 1 + 2 > data.len() {
            return Err(SecurityError::SerializationError(
                "enrollment request truncated at node_id".to_string(),
            ));
        }
        let node_id = String::from_utf8(data[pos..pos + node_id_len].to_vec())
            .map_err(|e| SecurityError::SerializationError(format!("invalid node_id: {e}")))?;
        pos += node_id_len;

        let requested_tier = MeshTier::from_byte(data[pos])
            .ok_or_else(|| SecurityError::SerializationError("invalid tier byte".to_string()))?;
        pos += 1;

        let token_len = usize::from(u16::from_le_bytes([data[pos], data[pos + 1]]));
        pos += 2;
        // Token, timestamp and signature must all fit in what is left.
        if pos + token_len + 8 + 64 > data.len() {
            return Err(SecurityError::SerializationError(
                "enrollment request truncated at token".to_string(),
            ));
        }
        let bootstrap_token = data[pos..pos + token_len].to_vec();
        pos += token_len;

        let timestamp_ms = u64::from_le_bytes(
            data[pos..pos + 8]
                .try_into()
                .expect("bounds checked above: slice is exactly 8 bytes"),
        );
        pos += 8;

        let mut signature = [0u8; 64];
        signature.copy_from_slice(&data[pos..pos + 64]);

        Ok(Self {
            subject_public_key,
            mesh_id,
            node_id,
            requested_tier,
            bootstrap_token,
            timestamp_ms,
            signature,
        })
    }
}
/// An enrollment authority's answer to an [`EnrollmentRequest`].
///
/// Wire format (all integers little-endian):
///
/// `status(1) | reason_len(u16) | reason | cert_flag(0|1) [cert_len(u16) | cert]
///  | secret_flag(0|1) [secret_len(u16) | secret] | timestamp_ms(u64)`
///
/// `reason_len` is 0 for statuses that carry no reason.
#[derive(Clone, Debug)]
pub struct EnrollmentResponse {
    /// Outcome of the request.
    pub status: EnrollmentStatus,
    /// Issued certificate; set by [`EnrollmentResponse::approved`], `None` from the
    /// other constructors.
    pub certificate: Option<MeshCertificate>,
    /// Optional secret delivered alongside an approval (its purpose is defined by the
    /// caller — TODO confirm).
    pub formation_secret: Option<Vec<u8>>,
    /// Response timestamp in milliseconds, as supplied by the constructor's caller.
    pub timestamp_ms: u64,
}

impl EnrollmentResponse {
    /// Smallest valid encoding: status(1) + reason_len(2) + cert_flag(1)
    /// + secret_flag(1) + timestamp(8), with no reason, certificate or secret.
    const MIN_ENCODED_LEN: usize = 1 + 2 + 1 + 1 + 8;
    /// Longest payload representable by a `u16` length prefix.
    const MAX_FIELD_LEN: usize = u16::MAX as usize;

    /// Builds an `Approved` response carrying `certificate` and an optional secret.
    pub fn approved(
        certificate: MeshCertificate,
        formation_secret: Option<Vec<u8>>,
        timestamp_ms: u64,
    ) -> Self {
        Self {
            status: EnrollmentStatus::Approved,
            certificate: Some(certificate),
            formation_secret,
            timestamp_ms,
        }
    }

    /// Builds a `Denied` response with the given human-readable `reason`.
    pub fn denied(reason: String, timestamp_ms: u64) -> Self {
        Self {
            status: EnrollmentStatus::Denied { reason },
            certificate: None,
            formation_secret: None,
            timestamp_ms,
        }
    }

    /// Builds a `Pending` response with no certificate or secret.
    pub fn pending(timestamp_ms: u64) -> Self {
        Self {
            status: EnrollmentStatus::Pending,
            certificate: None,
            formation_secret: None,
            timestamp_ms,
        }
    }

    /// Serializes the response to its wire format (see the type-level docs).
    ///
    /// A `Denied`/`Revoked` reason longer than 65535 bytes is cut at the last UTF-8
    /// character boundary that fits, so the frame stays well-formed.
    ///
    /// # Panics
    ///
    /// Panics if the encoded certificate or the formation secret is longer than
    /// 65535 bytes. Its length cannot be represented in the `u16` prefix, and
    /// emitting a truncated prefix would silently corrupt the frame.
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(64);
        buf.push(self.status.to_byte());
        match &self.status {
            EnrollmentStatus::Denied { reason } | EnrollmentStatus::Revoked { reason } => {
                let reason_bytes = Self::clamp_reason(reason).as_bytes();
                Self::put_len_prefixed(&mut buf, reason_bytes, "reason");
            }
            _ => {
                buf.extend_from_slice(&0u16.to_le_bytes());
            }
        }
        match &self.certificate {
            Some(cert) => {
                let cert_bytes = cert.encode();
                buf.push(1);
                Self::put_len_prefixed(&mut buf, &cert_bytes, "certificate");
            }
            None => buf.push(0),
        }
        match &self.formation_secret {
            Some(secret) => {
                buf.push(1);
                Self::put_len_prefixed(&mut buf, secret, "formation secret");
            }
            None => buf.push(0),
        }
        buf.extend_from_slice(&self.timestamp_ms.to_le_bytes());
        buf
    }

    /// Returns the longest prefix of `reason` that fits a `u16` length prefix
    /// without splitting a UTF-8 character.
    fn clamp_reason(reason: &str) -> &str {
        if reason.len() <= Self::MAX_FIELD_LEN {
            return reason;
        }
        let mut end = Self::MAX_FIELD_LEN;
        while !reason.is_char_boundary(end) {
            end -= 1;
        }
        &reason[..end]
    }

    /// Appends `bytes` preceded by its length as a little-endian `u16`.
    ///
    /// Panics if `bytes` is longer than `u16::MAX`.
    fn put_len_prefixed(buf: &mut Vec<u8>, bytes: &[u8], field: &str) {
        let len = u16::try_from(bytes.len()).unwrap_or_else(|_| {
            panic!(
                "enrollment response {field} is {} bytes; the wire format allows at most {}",
                bytes.len(),
                u16::MAX
            )
        });
        buf.extend_from_slice(&len.to_le_bytes());
        buf.extend_from_slice(bytes);
    }

    /// Interprets a presence flag, which must be exactly 0 (absent) or 1 (present).
    fn parse_flag(byte: u8, field: &str) -> Result<bool, SecurityError> {
        match byte {
            0 => Ok(false),
            1 => Ok(true),
            other => Err(SecurityError::SerializationError(format!(
                "invalid {field} flag: {other}"
            ))),
        }
    }

    /// Parses a response produced by [`Self::encode`].
    ///
    /// A reason attached to a `Pending`/`Approved` status is accepted and discarded.
    /// A certificate, if present, is decoded with `MeshCertificate::decode` but its
    /// signature is not verified here. Bytes after the timestamp are ignored.
    ///
    /// # Errors
    ///
    /// [`SecurityError::SerializationError`] if `data` is truncated, the status byte
    /// is unknown, the reason is not valid UTF-8, or a presence flag is neither 0 nor
    /// 1. Errors from `MeshCertificate::decode` are propagated unchanged.
    pub fn decode(data: &[u8]) -> Result<Self, SecurityError> {
        if data.len() < Self::MIN_ENCODED_LEN {
            return Err(SecurityError::SerializationError(format!(
                "enrollment response too short: {} bytes (min {})",
                data.len(),
                Self::MIN_ENCODED_LEN
            )));
        }
        let mut pos = 0;
        let status_byte = data[pos];
        pos += 1;

        let reason_len = usize::from(u16::from_le_bytes([data[pos], data[pos + 1]]));
        pos += 2;
        // `>=` rather than `>`: at least the certificate flag must follow the reason.
        if pos + reason_len >= data.len() {
            return Err(SecurityError::SerializationError(
                "enrollment response truncated at reason".to_string(),
            ));
        }
        let reason = if reason_len > 0 {
            String::from_utf8(data[pos..pos + reason_len].to_vec())
                .map_err(|e| SecurityError::SerializationError(format!("invalid reason: {e}")))?
        } else {
            String::new()
        };
        pos += reason_len;

        let status = match status_byte {
            0 => EnrollmentStatus::Pending,
            1 => EnrollmentStatus::Approved,
            2 => EnrollmentStatus::Denied { reason },
            3 => EnrollmentStatus::Revoked { reason },
            _ => {
                return Err(SecurityError::SerializationError(format!(
                    "invalid status byte: {status_byte}"
                )))
            }
        };

        // In bounds: guaranteed by the reason check above.
        let has_cert = Self::parse_flag(data[pos], "certificate")?;
        pos += 1;
        let certificate = if has_cert {
            if pos + 2 > data.len() {
                return Err(SecurityError::SerializationError(
                    "enrollment response truncated at cert_len".to_string(),
                ));
            }
            let cert_len = usize::from(u16::from_le_bytes([data[pos], data[pos + 1]]));
            pos += 2;
            if pos + cert_len > data.len() {
                return Err(SecurityError::SerializationError(
                    "enrollment response truncated at certificate".to_string(),
                ));
            }
            let cert = MeshCertificate::decode(&data[pos..pos + cert_len])?;
            pos += cert_len;
            Some(cert)
        } else {
            None
        };

        if pos >= data.len() {
            return Err(SecurityError::SerializationError(
                "enrollment response truncated at secret flag".to_string(),
            ));
        }
        let has_secret = Self::parse_flag(data[pos], "secret")?;
        pos += 1;
        let formation_secret = if has_secret {
            if pos + 2 > data.len() {
                return Err(SecurityError::SerializationError(
                    "enrollment response truncated at secret_len".to_string(),
                ));
            }
            let secret_len = usize::from(u16::from_le_bytes([data[pos], data[pos + 1]]));
            pos += 2;
            if pos + secret_len > data.len() {
                return Err(SecurityError::SerializationError(
                    "enrollment response truncated at secret".to_string(),
                ));
            }
            let secret = data[pos..pos + secret_len].to_vec();
            pos += secret_len;
            Some(secret)
        } else {
            None
        };

        if pos + 8 > data.len() {
            return Err(SecurityError::SerializationError(
                "enrollment response truncated at timestamp".to_string(),
            ));
        }
        let timestamp_ms = u64::from_le_bytes(
            data[pos..pos + 8]
                .try_into()
                .expect("bounds checked above: slice is exactly 8 bytes"),
        );

        Ok(Self {
            status,
            certificate,
            formation_secret,
            timestamp_ms,
        })
    }
}
/// Asynchronous interface to an enrollment authority.
///
/// Implementors must be `Send + Sync` so a single service can be shared across tasks.
#[async_trait::async_trait]
pub trait EnrollmentService: Send + Sync {
/// Evaluates `request` and returns the authority's response.
///
/// An `Err` signals a failure to process the request. The tests exercise an
/// implementation that returns `Err` for a bad signature, but reports mesh or
/// token mismatches as an `Ok` response with a `Denied` status.
async fn process_request(
&self,
request: &EnrollmentRequest,
) -> Result<EnrollmentResponse, SecurityError>;
/// Reports the enrollment status known for the device with public key `subject_key`.
async fn check_status(&self, subject_key: &[u8; 32])
-> Result<EnrollmentStatus, SecurityError>;
/// Revokes the enrollment of the device with public key `subject_key`, recording
/// `reason`. Implementations may not support revocation and return `Err`.
async fn revoke(&self, subject_key: &[u8; 32], reason: String) -> Result<(), SecurityError>;
}
/// An [`EnrollmentService`] backed by a fixed, in-memory table of bootstrap tokens.
///
/// Each registered token maps to the tier and permission bits placed in the
/// certificate issued to a device presenting that token.
pub struct StaticEnrollmentService {
    /// Keypair whose public key is the certificate issuer and which signs issued certificates.
    authority: DeviceKeypair,
    /// Mesh ID that incoming requests must name.
    mesh_id: String,
    /// Bootstrap token -> (tier, permission bits) granted to holders of that token.
    allowed_tokens: std::collections::HashMap<Vec<u8>, (MeshTier, u8)>,
    /// Lifetime of issued certificates, in milliseconds.
    validity_ms: u64,
}

impl StaticEnrollmentService {
    /// Creates a service for `mesh_id` that signs with `authority` and issues
    /// certificates valid for `validity_ms` milliseconds.
    ///
    /// The token table starts empty; register tokens with [`Self::add_token`].
    pub fn new(authority: DeviceKeypair, mesh_id: String, validity_ms: u64) -> Self {
        let allowed_tokens = std::collections::HashMap::default();
        Self {
            authority,
            mesh_id,
            allowed_tokens,
            validity_ms,
        }
    }

    /// Registers `token` so that it grants `tier` and `permissions`.
    ///
    /// Registering the same token again replaces its previous grant.
    pub fn add_token(&mut self, token: Vec<u8>, tier: MeshTier, permissions: u8) {
        let grant = (tier, permissions);
        let _previous_grant = self.allowed_tokens.insert(token, grant);
    }
}
#[async_trait::async_trait]
impl EnrollmentService for StaticEnrollmentService {
    /// Approves a request whose signature verifies, whose `mesh_id` matches this
    /// service's mesh, and whose bootstrap token is registered.
    ///
    /// The issued certificate:
    /// - carries the token's tier and permissions (`requested_tier` is not consulted);
    /// - is valid from the current system time for `validity_ms` milliseconds, with
    ///   the expiry saturating at `u64::MAX` instead of overflowing;
    /// - is signed by the authority key.
    ///
    /// A mesh or token mismatch yields `Ok` with a `Denied` response stamped with the
    /// request's own timestamp. An approval is stamped with the service's clock.
    ///
    /// NOTE(review): `request.timestamp_ms` is not checked for freshness.
    ///
    /// # Errors
    ///
    /// - Any error from [`EnrollmentRequest::verify_signature`].
    /// - [`SecurityError::Internal`] if the system clock reads earlier than the UNIX epoch.
    async fn process_request(
        &self,
        request: &EnrollmentRequest,
    ) -> Result<EnrollmentResponse, SecurityError> {
        request.verify_signature()?;
        if request.mesh_id != self.mesh_id {
            return Ok(EnrollmentResponse::denied(
                format!(
                    "mesh ID mismatch: expected {}, got {}",
                    self.mesh_id, request.mesh_id
                ),
                request.timestamp_ms,
            ));
        }
        let (tier, permissions) = match self.allowed_tokens.get(&request.bootstrap_token) {
            Some(entry) => *entry,
            None => {
                return Ok(EnrollmentResponse::denied(
                    "invalid bootstrap token".to_string(),
                    request.timestamp_ms,
                ));
            }
        };
        // A clock set before the epoch is an environment fault, not a reason to panic.
        let now_millis = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|e| {
                SecurityError::Internal(format!("system clock is before UNIX epoch: {e}"))
            })?
            .as_millis();
        let now = u64::try_from(now_millis).unwrap_or(u64::MAX);
        // Saturate so a very large validity (e.g. u64::MAX for "never expires") neither
        // panics in debug builds nor wraps to an already-expired certificate.
        let expires_at = now.saturating_add(self.validity_ms);
        let cert = MeshCertificate::new(
            request.subject_public_key,
            self.mesh_id.clone(),
            request.node_id.clone(),
            tier,
            permissions,
            now,
            expires_at,
            self.authority.public_key_bytes(),
        )
        .signed(&self.authority);
        Ok(EnrollmentResponse::approved(cert, None, now))
    }

    /// This service keeps no enrollment records, so it always reports `Pending`
    /// regardless of `subject_key`.
    async fn check_status(
        &self,
        _subject_key: &[u8; 32],
    ) -> Result<EnrollmentStatus, SecurityError> {
        Ok(EnrollmentStatus::Pending)
    }

    /// Always fails with [`SecurityError::Internal`]; revocation is not supported.
    async fn revoke(&self, _subject_key: &[u8; 32], _reason: String) -> Result<(), SecurityError> {
        Err(SecurityError::Internal(
            "static enrollment service does not support revocation".to_string(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::super::certificate::permissions;
    use super::*;

    /// Current wall-clock time in milliseconds since the UNIX epoch.
    fn now_ms() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64
    }

    #[test]
    fn test_enrollment_request_sign_verify() {
        let member = DeviceKeypair::generate();
        let now = now_ms();
        let req = EnrollmentRequest::new(
            &member,
            "A1B2C3D4".to_string(),
            "tac-west-1".to_string(),
            MeshTier::Tactical,
            b"bootstrap-token-123".to_vec(),
            now,
        );
        assert!(req.verify_signature().is_ok());
        assert_eq!(req.subject_public_key, member.public_key_bytes());
        assert_eq!(req.mesh_id, "A1B2C3D4");
        assert_eq!(req.node_id, "tac-west-1");
    }

    #[test]
    fn test_enrollment_request_tampered_field_fails_verification() {
        let member = DeviceKeypair::generate();
        let mut req = EnrollmentRequest::new(
            &member,
            "A1B2C3D4".to_string(),
            "tac-west-1".to_string(),
            MeshTier::Tactical,
            b"token".to_vec(),
            now_ms(),
        );
        // Changing any signed field after signing must invalidate the signature.
        req.node_id = "tac-west-2".to_string();
        assert!(req.verify_signature().is_err());
    }

    #[test]
    fn test_enrollment_request_encode_decode() {
        let member = DeviceKeypair::generate();
        let now = now_ms();
        let req = EnrollmentRequest::new(
            &member,
            "A1B2C3D4".to_string(),
            "edge-unit-7".to_string(),
            MeshTier::Edge,
            b"token".to_vec(),
            now,
        );
        let encoded = req.encode();
        let decoded = EnrollmentRequest::decode(&encoded).unwrap();
        assert_eq!(decoded.subject_public_key, req.subject_public_key);
        assert_eq!(decoded.mesh_id, req.mesh_id);
        assert_eq!(decoded.node_id, "edge-unit-7");
        assert_eq!(decoded.requested_tier, req.requested_tier);
        assert_eq!(decoded.bootstrap_token, req.bootstrap_token);
        assert_eq!(decoded.timestamp_ms, req.timestamp_ms);
        assert!(decoded.verify_signature().is_ok());
    }

    #[test]
    fn test_enrollment_request_decode_too_short() {
        assert!(EnrollmentRequest::decode(&[0u8; 10]).is_err());
    }

    #[test]
    fn test_enrollment_request_decode_truncated_signature() {
        let member = DeviceKeypair::generate();
        let req = EnrollmentRequest::new(
            &member,
            "A1B2C3D4".to_string(),
            "edge-unit-7".to_string(),
            MeshTier::Edge,
            b"token".to_vec(),
            now_ms(),
        );
        let encoded = req.encode();
        // Longer than the minimum, but missing the last signature byte.
        assert!(EnrollmentRequest::decode(&encoded[..encoded.len() - 1]).is_err());
    }

    #[test]
    fn test_enrollment_response_approved() {
        let authority = DeviceKeypair::generate();
        let now = now_ms();
        let cert = MeshCertificate::new_root(
            &authority,
            "DEADBEEF".to_string(),
            "enterprise-0".to_string(),
            MeshTier::Enterprise,
            now,
            now + 3600000,
        );
        let resp = EnrollmentResponse::approved(cert, Some(b"secret".to_vec()), now);
        assert_eq!(resp.status, EnrollmentStatus::Approved);
        assert!(resp.certificate.is_some());
        assert!(resp.formation_secret.is_some());
    }

    #[test]
    fn test_enrollment_response_denied() {
        let now = now_ms();
        let resp = EnrollmentResponse::denied("bad token".to_string(), now);
        assert_eq!(
            resp.status,
            EnrollmentStatus::Denied {
                reason: "bad token".to_string()
            }
        );
        assert!(resp.certificate.is_none());
    }

    #[tokio::test]
    async fn test_static_enrollment_service_approve() {
        let authority = DeviceKeypair::generate();
        let member = DeviceKeypair::generate();
        let now = now_ms();
        let validity = 24 * 60 * 60 * 1000;
        let mut service =
            StaticEnrollmentService::new(authority.clone(), "DEADBEEF".to_string(), validity);
        service.add_token(
            b"valid-token".to_vec(),
            MeshTier::Tactical,
            permissions::STANDARD,
        );
        let req = EnrollmentRequest::new(
            &member,
            "DEADBEEF".to_string(),
            "tac-node-1".to_string(),
            MeshTier::Tactical,
            b"valid-token".to_vec(),
            now,
        );
        let resp = service.process_request(&req).await.unwrap();
        assert_eq!(resp.status, EnrollmentStatus::Approved);
        let cert = resp.certificate.unwrap();
        assert!(cert.verify().is_ok());
        assert_eq!(cert.subject_public_key, member.public_key_bytes());
        assert_eq!(cert.node_id, "tac-node-1");
        assert_eq!(cert.tier, MeshTier::Tactical);
        assert_eq!(cert.permissions, permissions::STANDARD);
        assert_eq!(cert.issuer_public_key, authority.public_key_bytes());
    }

    #[tokio::test]
    async fn test_static_enrollment_service_deny_bad_token() {
        let authority = DeviceKeypair::generate();
        let member = DeviceKeypair::generate();
        let now = now_ms();
        let service = StaticEnrollmentService::new(authority, "DEADBEEF".to_string(), 3600000);
        let req = EnrollmentRequest::new(
            &member,
            "DEADBEEF".to_string(),
            "tac-node-2".to_string(),
            MeshTier::Tactical,
            b"invalid-token".to_vec(),
            now,
        );
        let resp = service.process_request(&req).await.unwrap();
        match resp.status {
            EnrollmentStatus::Denied { reason } => {
                assert!(reason.contains("invalid bootstrap token"));
            }
            other => panic!("expected Denied, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_static_enrollment_service_deny_wrong_mesh() {
        let authority = DeviceKeypair::generate();
        let member = DeviceKeypair::generate();
        let now = now_ms();
        let mut service = StaticEnrollmentService::new(authority, "DEADBEEF".to_string(), 3600000);
        service.add_token(b"token".to_vec(), MeshTier::Tactical, permissions::STANDARD);
        let req = EnrollmentRequest::new(
            &member,
            "WRONG_MESH".to_string(),
            "tac-node-3".to_string(),
            MeshTier::Tactical,
            b"token".to_vec(),
            now,
        );
        let resp = service.process_request(&req).await.unwrap();
        match resp.status {
            EnrollmentStatus::Denied { reason } => {
                assert!(reason.contains("mesh ID mismatch"));
            }
            other => panic!("expected Denied, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_static_enrollment_service_rejects_bad_signature() {
        let authority = DeviceKeypair::generate();
        let member = DeviceKeypair::generate();
        let mut service = StaticEnrollmentService::new(authority, "DEADBEEF".to_string(), 3600000);
        service.add_token(b"token".to_vec(), MeshTier::Tactical, permissions::STANDARD);
        let mut req = EnrollmentRequest::new(
            &member,
            "DEADBEEF".to_string(),
            "tac-node-4".to_string(),
            MeshTier::Tactical,
            b"token".to_vec(),
            now_ms(),
        );
        // A tampered request is an error, not a Denied response.
        req.node_id = "tac-node-5".to_string();
        assert!(service.process_request(&req).await.is_err());
    }

    #[tokio::test]
    async fn test_static_enrollment_service_revoke_unsupported() {
        let service =
            StaticEnrollmentService::new(DeviceKeypair::generate(), "DEADBEEF".to_string(), 3600000);
        assert!(service
            .revoke(&[0u8; 32], "compromised".to_string())
            .await
            .is_err());
    }

    #[test]
    fn test_enrollment_status_byte() {
        assert_eq!(EnrollmentStatus::Pending.to_byte(), 0);
        assert_eq!(EnrollmentStatus::Approved.to_byte(), 1);
        assert_eq!(
            EnrollmentStatus::Denied {
                reason: "x".to_string()
            }
            .to_byte(),
            2
        );
        assert_eq!(
            EnrollmentStatus::Revoked {
                reason: "x".to_string()
            }
            .to_byte(),
            3
        );
    }

    #[test]
    fn test_enrollment_response_approved_encode_decode() {
        let authority = DeviceKeypair::generate();
        let now = now_ms();
        let cert = MeshCertificate::new_root(
            &authority,
            "DEADBEEF".to_string(),
            "enterprise-0".to_string(),
            MeshTier::Enterprise,
            now,
            now + 3600000,
        );
        let resp =
            EnrollmentResponse::approved(cert.clone(), Some(b"formation-secret".to_vec()), now);
        let encoded = resp.encode();
        let decoded = EnrollmentResponse::decode(&encoded).unwrap();
        assert_eq!(decoded.status, EnrollmentStatus::Approved);
        assert_eq!(decoded.timestamp_ms, now);
        let decoded_cert = decoded.certificate.unwrap();
        assert_eq!(decoded_cert.subject_public_key, cert.subject_public_key);
        assert_eq!(decoded_cert.mesh_id, cert.mesh_id);
        assert_eq!(decoded_cert.node_id, "enterprise-0");
        assert!(decoded_cert.verify().is_ok());
        assert_eq!(decoded.formation_secret, Some(b"formation-secret".to_vec()));
    }

    #[test]
    fn test_enrollment_response_denied_encode_decode() {
        let now = now_ms();
        let resp = EnrollmentResponse::denied("bad token".to_string(), now);
        let encoded = resp.encode();
        let decoded = EnrollmentResponse::decode(&encoded).unwrap();
        assert_eq!(
            decoded.status,
            EnrollmentStatus::Denied {
                reason: "bad token".to_string()
            }
        );
        assert!(decoded.certificate.is_none());
        assert!(decoded.formation_secret.is_none());
        assert_eq!(decoded.timestamp_ms, now);
    }

    #[test]
    fn test_enrollment_response_revoked_encode_decode() {
        let now = now_ms();
        let resp = EnrollmentResponse {
            status: EnrollmentStatus::Revoked {
                reason: "key compromised".to_string(),
            },
            certificate: None,
            formation_secret: None,
            timestamp_ms: now,
        };
        let decoded = EnrollmentResponse::decode(&resp.encode()).unwrap();
        assert_eq!(decoded.status, resp.status);
        assert!(decoded.certificate.is_none());
        assert!(decoded.formation_secret.is_none());
        assert_eq!(decoded.timestamp_ms, now);
    }

    #[test]
    fn test_enrollment_response_pending_encode_decode() {
        let now = now_ms();
        let resp = EnrollmentResponse::pending(now);
        let encoded = resp.encode();
        let decoded = EnrollmentResponse::decode(&encoded).unwrap();
        assert_eq!(decoded.status, EnrollmentStatus::Pending);
        assert!(decoded.certificate.is_none());
        assert!(decoded.formation_secret.is_none());
    }

    #[test]
    fn test_enrollment_response_decode_too_short() {
        assert!(EnrollmentResponse::decode(&[0u8; 5]).is_err());
    }

    #[test]
    fn test_enrollment_response_decode_truncated_timestamp() {
        let encoded = EnrollmentResponse::denied("bad token".to_string(), now_ms()).encode();
        assert!(EnrollmentResponse::decode(&encoded[..encoded.len() - 1]).is_err());
    }

    #[test]
    fn test_enrollment_response_decode_invalid_status() {
        let mut encoded = EnrollmentResponse::pending(now_ms()).encode();
        encoded[0] = 9;
        assert!(EnrollmentResponse::decode(&encoded).is_err());
    }

    #[test]
    fn test_enrollment_response_no_secret_encode_decode() {
        let authority = DeviceKeypair::generate();
        let now = now_ms();
        let cert = MeshCertificate::new_root(
            &authority,
            "DEADBEEF".to_string(),
            "node-1".to_string(),
            MeshTier::Tactical,
            now,
            now + 3600000,
        );
        let resp = EnrollmentResponse::approved(cert, None, now);
        let encoded = resp.encode();
        let decoded = EnrollmentResponse::decode(&encoded).unwrap();
        assert_eq!(decoded.status, EnrollmentStatus::Approved);
        assert!(decoded.certificate.is_some());
        assert!(decoded.formation_secret.is_none());
    }
}