1use ed25519_dalek::Verifier;
22
23use super::certificate::{MeshCertificate, MeshTier};
24use super::error::SecurityError;
25use super::keypair::DeviceKeypair;
26
/// State of a node's enrollment as carried in an [`EnrollmentResponse`].
///
/// `Denied` and `Revoked` carry a free-form reason string; the other two
/// states carry no payload.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EnrollmentStatus {
    /// Wire tag 0.
    Pending,
    /// Wire tag 1.
    Approved,
    /// Wire tag 2; `reason` explains the refusal.
    Denied { reason: String },
    /// Wire tag 3; `reason` explains the revocation.
    Revoked { reason: String },
}

impl EnrollmentStatus {
    /// Returns the single-byte wire tag identifying this variant.
    ///
    /// Only the variant is encoded here; any `reason` payload is serialized
    /// separately by the response encoder.
    pub fn to_byte(&self) -> u8 {
        match *self {
            EnrollmentStatus::Pending => 0,
            EnrollmentStatus::Approved => 1,
            EnrollmentStatus::Denied { .. } => 2,
            EnrollmentStatus::Revoked { .. } => 3,
        }
    }
}
51
/// A request from a device to be enrolled in a mesh.
///
/// The request is self-signed: `signature` is an Ed25519 signature, made with
/// the private key matching `subject_public_key`, over every other field in
/// their wire encoding. Verifying it therefore shows the sender holds that key.
#[derive(Clone, Debug)]
pub struct EnrollmentRequest {
    /// Ed25519 public key of the device asking to be enrolled.
    pub subject_public_key: [u8; 32],
    /// Mesh the device wants to join (one-byte length prefix on the wire).
    pub mesh_id: String,
    /// Node identifier to place in the issued certificate (one-byte length
    /// prefix on the wire).
    pub node_id: String,
    /// Tier the device asks for. Note: `StaticEnrollmentService` grants the
    /// tier bound to the bootstrap token and does not read this field.
    pub requested_tier: MeshTier,
    /// Opaque enrollment credential (two-byte little-endian length prefix on
    /// the wire).
    pub bootstrap_token: Vec<u8>,
    /// Client-side timestamp in milliseconds; the tests fill it with
    /// milliseconds since the UNIX epoch — presumably that is the convention.
    pub timestamp_ms: u64,
    /// Ed25519 signature over the encoded fields above.
    pub signature: [u8; 64],
}
76
77impl EnrollmentRequest {
78 pub fn new(
80 keypair: &DeviceKeypair,
81 mesh_id: String,
82 node_id: String,
83 requested_tier: MeshTier,
84 bootstrap_token: Vec<u8>,
85 timestamp_ms: u64,
86 ) -> Self {
87 let mut req = Self {
88 subject_public_key: keypair.public_key_bytes(),
89 mesh_id,
90 node_id,
91 requested_tier,
92 bootstrap_token,
93 timestamp_ms,
94 signature: [0u8; 64],
95 };
96 let signable = req.signable_bytes();
97 req.signature = keypair.sign(&signable).to_bytes();
98 req
99 }
100
101 pub fn verify_signature(&self) -> Result<(), SecurityError> {
103 let vk = ed25519_dalek::VerifyingKey::from_bytes(&self.subject_public_key)
104 .map_err(|e| SecurityError::InvalidPublicKey(e.to_string()))?;
105 let sig = ed25519_dalek::Signature::from_bytes(&self.signature);
106 let signable = self.signable_bytes();
107 vk.verify(&signable, &sig)
108 .map_err(|e| SecurityError::InvalidSignature(e.to_string()))
109 }
110
111 fn signable_bytes(&self) -> Vec<u8> {
112 let mut buf = Vec::with_capacity(
113 32 + 1
114 + self.mesh_id.len()
115 + 1
116 + self.node_id.len()
117 + 1
118 + 2
119 + self.bootstrap_token.len()
120 + 8,
121 );
122 buf.extend_from_slice(&self.subject_public_key);
123 buf.push(self.mesh_id.len() as u8);
124 buf.extend_from_slice(self.mesh_id.as_bytes());
125 buf.push(self.node_id.len() as u8);
126 buf.extend_from_slice(self.node_id.as_bytes());
127 buf.push(self.requested_tier.to_byte());
128 buf.extend_from_slice(&(self.bootstrap_token.len() as u16).to_le_bytes());
129 buf.extend_from_slice(&self.bootstrap_token);
130 buf.extend_from_slice(&self.timestamp_ms.to_le_bytes());
131 buf
132 }
133
134 pub fn encode(&self) -> Vec<u8> {
136 let mut buf = self.signable_bytes();
137 buf.extend_from_slice(&self.signature);
138 buf
139 }
140
141 pub fn decode(data: &[u8]) -> Result<Self, SecurityError> {
143 if data.len() < 109 {
145 return Err(SecurityError::SerializationError(format!(
146 "enrollment request too short: {} bytes (min 109)",
147 data.len()
148 )));
149 }
150
151 let mut pos = 0;
152
153 let mut subject_public_key = [0u8; 32];
154 subject_public_key.copy_from_slice(&data[pos..pos + 32]);
155 pos += 32;
156
157 let mesh_id_len = data[pos] as usize;
158 pos += 1;
159
160 if pos + mesh_id_len + 1 > data.len() {
161 return Err(SecurityError::SerializationError(
162 "enrollment request truncated at mesh_id".to_string(),
163 ));
164 }
165
166 let mesh_id = String::from_utf8(data[pos..pos + mesh_id_len].to_vec())
167 .map_err(|e| SecurityError::SerializationError(format!("invalid mesh_id: {e}")))?;
168 pos += mesh_id_len;
169
170 let node_id_len = data[pos] as usize;
171 pos += 1;
172
173 if pos + node_id_len + 1 + 2 > data.len() {
174 return Err(SecurityError::SerializationError(
175 "enrollment request truncated at node_id".to_string(),
176 ));
177 }
178
179 let node_id = String::from_utf8(data[pos..pos + node_id_len].to_vec())
180 .map_err(|e| SecurityError::SerializationError(format!("invalid node_id: {e}")))?;
181 pos += node_id_len;
182
183 let requested_tier = MeshTier::from_byte(data[pos])
184 .ok_or_else(|| SecurityError::SerializationError("invalid tier byte".to_string()))?;
185 pos += 1;
186
187 let token_len = u16::from_le_bytes(data[pos..pos + 2].try_into().unwrap()) as usize;
188 pos += 2;
189
190 if pos + token_len + 8 + 64 > data.len() {
191 return Err(SecurityError::SerializationError(
192 "enrollment request truncated at token".to_string(),
193 ));
194 }
195
196 let bootstrap_token = data[pos..pos + token_len].to_vec();
197 pos += token_len;
198
199 let timestamp_ms = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
200 pos += 8;
201
202 let mut signature = [0u8; 64];
203 signature.copy_from_slice(&data[pos..pos + 64]);
204
205 Ok(Self {
206 subject_public_key,
207 mesh_id,
208 node_id,
209 requested_tier,
210 bootstrap_token,
211 timestamp_ms,
212 signature,
213 })
214 }
215}
216
/// Reply to an [`EnrollmentRequest`].
///
/// The constructors in this file attach a certificate only to `Approved`
/// responses. `Denied` and `Pending` responses carry neither a certificate
/// nor a formation secret.
#[derive(Clone, Debug)]
pub struct EnrollmentResponse {
    /// Outcome of the request.
    pub status: EnrollmentStatus,
    /// Certificate issued to the subject; `Some` only for approvals built via
    /// `EnrollmentResponse::approved`.
    pub certificate: Option<MeshCertificate>,
    /// Optional secret bytes handed to an approved node. NOTE(review):
    /// presumably shared mesh-formation key material — confirm with the
    /// consumer; `StaticEnrollmentService` always sends `None`.
    pub formation_secret: Option<Vec<u8>>,
    /// Timestamp in milliseconds. `StaticEnrollmentService` sets it to the
    /// server's clock on approval and echoes the request's timestamp on
    /// denial.
    pub timestamp_ms: u64,
}
233
234impl EnrollmentResponse {
235 pub fn approved(
237 certificate: MeshCertificate,
238 formation_secret: Option<Vec<u8>>,
239 timestamp_ms: u64,
240 ) -> Self {
241 Self {
242 status: EnrollmentStatus::Approved,
243 certificate: Some(certificate),
244 formation_secret,
245 timestamp_ms,
246 }
247 }
248
249 pub fn denied(reason: String, timestamp_ms: u64) -> Self {
251 Self {
252 status: EnrollmentStatus::Denied { reason },
253 certificate: None,
254 formation_secret: None,
255 timestamp_ms,
256 }
257 }
258
259 pub fn pending(timestamp_ms: u64) -> Self {
261 Self {
262 status: EnrollmentStatus::Pending,
263 certificate: None,
264 formation_secret: None,
265 timestamp_ms,
266 }
267 }
268
269 pub fn encode(&self) -> Vec<u8> {
278 let mut buf = Vec::with_capacity(64);
279
280 buf.push(self.status.to_byte());
282 match &self.status {
283 EnrollmentStatus::Denied { reason } | EnrollmentStatus::Revoked { reason } => {
284 let reason_bytes = reason.as_bytes();
285 buf.extend_from_slice(&(reason_bytes.len() as u16).to_le_bytes());
286 buf.extend_from_slice(reason_bytes);
287 }
288 _ => {
289 buf.extend_from_slice(&0u16.to_le_bytes());
290 }
291 }
292
293 if let Some(ref cert) = self.certificate {
295 let cert_bytes = cert.encode();
296 buf.push(1);
297 buf.extend_from_slice(&(cert_bytes.len() as u16).to_le_bytes());
298 buf.extend_from_slice(&cert_bytes);
299 } else {
300 buf.push(0);
301 }
302
303 if let Some(ref secret) = self.formation_secret {
305 buf.push(1);
306 buf.extend_from_slice(&(secret.len() as u16).to_le_bytes());
307 buf.extend_from_slice(secret);
308 } else {
309 buf.push(0);
310 }
311
312 buf.extend_from_slice(&self.timestamp_ms.to_le_bytes());
314
315 buf
316 }
317
318 pub fn decode(data: &[u8]) -> Result<Self, SecurityError> {
320 if data.len() < 13 {
322 return Err(SecurityError::SerializationError(format!(
323 "enrollment response too short: {} bytes (min 13)",
324 data.len()
325 )));
326 }
327
328 let mut pos = 0;
329
330 let status_byte = data[pos];
331 pos += 1;
332
333 let reason_len = u16::from_le_bytes(data[pos..pos + 2].try_into().unwrap()) as usize;
334 pos += 2;
335
336 if pos + reason_len >= data.len() {
337 return Err(SecurityError::SerializationError(
338 "enrollment response truncated at reason".to_string(),
339 ));
340 }
341
342 let reason = if reason_len > 0 {
343 String::from_utf8(data[pos..pos + reason_len].to_vec())
344 .map_err(|e| SecurityError::SerializationError(format!("invalid reason: {e}")))?
345 } else {
346 String::new()
347 };
348 pos += reason_len;
349
350 let status = match status_byte {
351 0 => EnrollmentStatus::Pending,
352 1 => EnrollmentStatus::Approved,
353 2 => EnrollmentStatus::Denied { reason },
354 3 => EnrollmentStatus::Revoked { reason },
355 _ => {
356 return Err(SecurityError::SerializationError(format!(
357 "invalid status byte: {status_byte}"
358 )))
359 }
360 };
361
362 if pos >= data.len() {
364 return Err(SecurityError::SerializationError(
365 "enrollment response truncated at certificate flag".to_string(),
366 ));
367 }
368 let has_cert = data[pos];
369 pos += 1;
370
371 let certificate = if has_cert == 1 {
372 if pos + 2 > data.len() {
373 return Err(SecurityError::SerializationError(
374 "enrollment response truncated at cert_len".to_string(),
375 ));
376 }
377 let cert_len = u16::from_le_bytes(data[pos..pos + 2].try_into().unwrap()) as usize;
378 pos += 2;
379 if pos + cert_len > data.len() {
380 return Err(SecurityError::SerializationError(
381 "enrollment response truncated at certificate".to_string(),
382 ));
383 }
384 let cert = MeshCertificate::decode(&data[pos..pos + cert_len])?;
385 pos += cert_len;
386 Some(cert)
387 } else {
388 None
389 };
390
391 if pos >= data.len() {
393 return Err(SecurityError::SerializationError(
394 "enrollment response truncated at secret flag".to_string(),
395 ));
396 }
397 let has_secret = data[pos];
398 pos += 1;
399
400 let formation_secret = if has_secret == 1 {
401 if pos + 2 > data.len() {
402 return Err(SecurityError::SerializationError(
403 "enrollment response truncated at secret_len".to_string(),
404 ));
405 }
406 let secret_len = u16::from_le_bytes(data[pos..pos + 2].try_into().unwrap()) as usize;
407 pos += 2;
408 if pos + secret_len > data.len() {
409 return Err(SecurityError::SerializationError(
410 "enrollment response truncated at secret".to_string(),
411 ));
412 }
413 let secret = data[pos..pos + secret_len].to_vec();
414 pos += secret_len;
415 Some(secret)
416 } else {
417 None
418 };
419
420 if pos + 8 > data.len() {
422 return Err(SecurityError::SerializationError(
423 "enrollment response truncated at timestamp".to_string(),
424 ));
425 }
426 let timestamp_ms = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
427
428 Ok(Self {
429 status,
430 certificate,
431 formation_secret,
432 timestamp_ms,
433 })
434 }
435}
436
/// An authority that decides enrollment requests and tracks enrollment state.
///
/// Methods are async (via `async_trait`), and implementors must be
/// `Send + Sync` so one service can be shared across tasks.
#[async_trait::async_trait]
pub trait EnrollmentService: Send + Sync {
    /// Decides `request` and returns the response to send back.
    ///
    /// In this file's `StaticEnrollmentService`, an invalid request signature
    /// yields `Err`, while policy refusals (wrong mesh, unknown token) yield
    /// `Ok` with a `Denied` response. Other implementations may differ —
    /// confirm per implementor.
    async fn process_request(
        &self,
        request: &EnrollmentRequest,
    ) -> Result<EnrollmentResponse, SecurityError>;

    /// Reports the enrollment status recorded for the device with public key
    /// `subject_key`.
    async fn check_status(&self, subject_key: &[u8; 32])
        -> Result<EnrollmentStatus, SecurityError>;

    /// Revokes the enrollment of the device with public key `subject_key`,
    /// recording `reason`. Implementations without revocation support return
    /// an error (as `StaticEnrollmentService` does).
    async fn revoke(&self, subject_key: &[u8; 32], reason: String) -> Result<(), SecurityError>;
}
467
468pub struct StaticEnrollmentService {
473 authority: DeviceKeypair,
475 mesh_id: String,
477 allowed_tokens: std::collections::HashMap<Vec<u8>, (MeshTier, u8)>,
479 validity_ms: u64,
481}
482
483impl StaticEnrollmentService {
484 pub fn new(authority: DeviceKeypair, mesh_id: String, validity_ms: u64) -> Self {
486 Self {
487 authority,
488 mesh_id,
489 allowed_tokens: std::collections::HashMap::new(),
490 validity_ms,
491 }
492 }
493
494 pub fn add_token(&mut self, token: Vec<u8>, tier: MeshTier, permissions: u8) {
496 self.allowed_tokens.insert(token, (tier, permissions));
497 }
498}
499
500#[async_trait::async_trait]
501impl EnrollmentService for StaticEnrollmentService {
502 async fn process_request(
503 &self,
504 request: &EnrollmentRequest,
505 ) -> Result<EnrollmentResponse, SecurityError> {
506 request.verify_signature()?;
508
509 if request.mesh_id != self.mesh_id {
511 return Ok(EnrollmentResponse::denied(
512 format!(
513 "mesh ID mismatch: expected {}, got {}",
514 self.mesh_id, request.mesh_id
515 ),
516 request.timestamp_ms,
517 ));
518 }
519
520 let (tier, permissions) = match self.allowed_tokens.get(&request.bootstrap_token) {
522 Some(entry) => *entry,
523 None => {
524 return Ok(EnrollmentResponse::denied(
525 "invalid bootstrap token".to_string(),
526 request.timestamp_ms,
527 ));
528 }
529 };
530
531 let now = std::time::SystemTime::now()
533 .duration_since(std::time::UNIX_EPOCH)
534 .unwrap()
535 .as_millis() as u64;
536
537 let cert = MeshCertificate::new(
538 request.subject_public_key,
539 self.mesh_id.clone(),
540 request.node_id.clone(),
541 tier,
542 permissions,
543 now,
544 now + self.validity_ms,
545 self.authority.public_key_bytes(),
546 )
547 .signed(&self.authority);
548
549 Ok(EnrollmentResponse::approved(cert, None, now))
550 }
551
552 async fn check_status(
553 &self,
554 _subject_key: &[u8; 32],
555 ) -> Result<EnrollmentStatus, SecurityError> {
556 Ok(EnrollmentStatus::Pending)
558 }
559
560 async fn revoke(&self, _subject_key: &[u8; 32], _reason: String) -> Result<(), SecurityError> {
561 Err(SecurityError::Internal(
563 "static enrollment service does not support revocation".to_string(),
564 ))
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::super::certificate::permissions;
    use super::*;

    /// Current wall-clock time in milliseconds since the UNIX epoch.
    fn now_ms() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64
    }

    /// A signed request with fixed, short field values, for negative-path tests.
    fn sample_request(member: &DeviceKeypair) -> EnrollmentRequest {
        EnrollmentRequest::new(
            member,
            "A1B2C3D4".to_string(),
            "node".to_string(),
            MeshTier::Edge,
            b"token".to_vec(),
            now_ms(),
        )
    }

    #[test]
    fn test_enrollment_request_sign_verify() {
        let member = DeviceKeypair::generate();
        let now = now_ms();

        let req = EnrollmentRequest::new(
            &member,
            "A1B2C3D4".to_string(),
            "tac-west-1".to_string(),
            MeshTier::Tactical,
            b"bootstrap-token-123".to_vec(),
            now,
        );

        assert!(req.verify_signature().is_ok());
        assert_eq!(req.subject_public_key, member.public_key_bytes());
        assert_eq!(req.mesh_id, "A1B2C3D4");
        assert_eq!(req.node_id, "tac-west-1");
    }

    #[test]
    fn test_enrollment_request_encode_decode() {
        let member = DeviceKeypair::generate();
        let now = now_ms();

        let req = EnrollmentRequest::new(
            &member,
            "A1B2C3D4".to_string(),
            "edge-unit-7".to_string(),
            MeshTier::Edge,
            b"token".to_vec(),
            now,
        );

        let encoded = req.encode();
        let decoded = EnrollmentRequest::decode(&encoded).unwrap();

        assert_eq!(decoded.subject_public_key, req.subject_public_key);
        assert_eq!(decoded.mesh_id, req.mesh_id);
        assert_eq!(decoded.node_id, "edge-unit-7");
        assert_eq!(decoded.requested_tier, req.requested_tier);
        assert_eq!(decoded.bootstrap_token, req.bootstrap_token);
        assert_eq!(decoded.timestamp_ms, req.timestamp_ms);
        assert!(decoded.verify_signature().is_ok());
    }

    #[test]
    fn test_enrollment_request_decode_too_short() {
        assert!(EnrollmentRequest::decode(&[0u8; 10]).is_err());
    }

    #[test]
    fn test_enrollment_request_decode_truncated() {
        let member = DeviceKeypair::generate();
        let encoded = sample_request(&member).encode();
        // Every strict prefix of a valid encoding must be rejected.
        for len in 0..encoded.len() {
            assert!(
                EnrollmentRequest::decode(&encoded[..len]).is_err(),
                "a {len}-byte prefix decoded successfully"
            );
        }
    }

    #[test]
    fn test_enrollment_request_tampered_fails_verification() {
        let member = DeviceKeypair::generate();
        let mut encoded = sample_request(&member).encode();
        // Byte 33 is the first mesh_id byte ('A'); flipping its low bit keeps
        // it valid ASCII ('@') so decoding succeeds but the signature must not.
        encoded[33] ^= 0x01;
        let decoded = EnrollmentRequest::decode(&encoded).unwrap();
        assert_eq!(decoded.mesh_id, "@1B2C3D4");
        assert!(decoded.verify_signature().is_err());
    }

    #[test]
    fn test_enrollment_response_approved() {
        let authority = DeviceKeypair::generate();
        let now = now_ms();

        let cert = MeshCertificate::new_root(
            &authority,
            "DEADBEEF".to_string(),
            "enterprise-0".to_string(),
            MeshTier::Enterprise,
            now,
            now + 3600000,
        );

        let resp = EnrollmentResponse::approved(cert, Some(b"secret".to_vec()), now);
        assert_eq!(resp.status, EnrollmentStatus::Approved);
        assert!(resp.certificate.is_some());
        assert!(resp.formation_secret.is_some());
    }

    #[test]
    fn test_enrollment_response_denied() {
        let now = now_ms();
        let resp = EnrollmentResponse::denied("bad token".to_string(), now);
        assert_eq!(
            resp.status,
            EnrollmentStatus::Denied {
                reason: "bad token".to_string()
            }
        );
        assert!(resp.certificate.is_none());
    }

    #[tokio::test]
    async fn test_static_enrollment_service_approve() {
        let authority = DeviceKeypair::generate();
        let member = DeviceKeypair::generate();
        let now = now_ms();
        let validity = 24 * 60 * 60 * 1000; // 24 hours in milliseconds
        let mut service =
            StaticEnrollmentService::new(authority.clone(), "DEADBEEF".to_string(), validity);
        service.add_token(
            b"valid-token".to_vec(),
            MeshTier::Tactical,
            permissions::STANDARD,
        );

        let req = EnrollmentRequest::new(
            &member,
            "DEADBEEF".to_string(),
            "tac-node-1".to_string(),
            MeshTier::Tactical,
            b"valid-token".to_vec(),
            now,
        );

        let resp = service.process_request(&req).await.unwrap();
        assert_eq!(resp.status, EnrollmentStatus::Approved);

        let cert = resp.certificate.unwrap();
        assert!(cert.verify().is_ok());
        assert_eq!(cert.subject_public_key, member.public_key_bytes());
        assert_eq!(cert.node_id, "tac-node-1");
        assert_eq!(cert.tier, MeshTier::Tactical);
        assert_eq!(cert.permissions, permissions::STANDARD);
        assert_eq!(cert.issuer_public_key, authority.public_key_bytes());
    }

    #[tokio::test]
    async fn test_static_enrollment_service_deny_bad_token() {
        let authority = DeviceKeypair::generate();
        let member = DeviceKeypair::generate();
        let now = now_ms();

        let service = StaticEnrollmentService::new(authority, "DEADBEEF".to_string(), 3600000);

        let req = EnrollmentRequest::new(
            &member,
            "DEADBEEF".to_string(),
            "tac-node-2".to_string(),
            MeshTier::Tactical,
            b"invalid-token".to_vec(),
            now,
        );

        let resp = service.process_request(&req).await.unwrap();
        match resp.status {
            EnrollmentStatus::Denied { reason } => {
                assert!(reason.contains("invalid bootstrap token"));
            }
            other => panic!("expected Denied, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_static_enrollment_service_deny_wrong_mesh() {
        let authority = DeviceKeypair::generate();
        let member = DeviceKeypair::generate();
        let now = now_ms();

        let mut service = StaticEnrollmentService::new(authority, "DEADBEEF".to_string(), 3600000);
        service.add_token(b"token".to_vec(), MeshTier::Tactical, permissions::STANDARD);

        let req = EnrollmentRequest::new(
            &member,
            "WRONG_MESH".to_string(),
            "tac-node-3".to_string(),
            MeshTier::Tactical,
            b"token".to_vec(),
            now,
        );

        let resp = service.process_request(&req).await.unwrap();
        match resp.status {
            EnrollmentStatus::Denied { reason } => {
                assert!(reason.contains("mesh ID mismatch"));
            }
            other => panic!("expected Denied, got {:?}", other),
        }
    }

    #[test]
    fn test_enrollment_status_byte() {
        assert_eq!(EnrollmentStatus::Pending.to_byte(), 0);
        assert_eq!(EnrollmentStatus::Approved.to_byte(), 1);
        assert_eq!(
            EnrollmentStatus::Denied {
                reason: "x".to_string()
            }
            .to_byte(),
            2
        );
        assert_eq!(
            EnrollmentStatus::Revoked {
                reason: "x".to_string()
            }
            .to_byte(),
            3
        );
    }

    #[test]
    fn test_enrollment_response_approved_encode_decode() {
        let authority = DeviceKeypair::generate();
        let now = now_ms();

        let cert = MeshCertificate::new_root(
            &authority,
            "DEADBEEF".to_string(),
            "enterprise-0".to_string(),
            MeshTier::Enterprise,
            now,
            now + 3600000,
        );

        let resp =
            EnrollmentResponse::approved(cert.clone(), Some(b"formation-secret".to_vec()), now);
        let encoded = resp.encode();
        let decoded = EnrollmentResponse::decode(&encoded).unwrap();

        assert_eq!(decoded.status, EnrollmentStatus::Approved);
        assert_eq!(decoded.timestamp_ms, now);

        let decoded_cert = decoded.certificate.unwrap();
        assert_eq!(decoded_cert.subject_public_key, cert.subject_public_key);
        assert_eq!(decoded_cert.mesh_id, cert.mesh_id);
        assert_eq!(decoded_cert.node_id, "enterprise-0");
        assert!(decoded_cert.verify().is_ok());

        assert_eq!(decoded.formation_secret, Some(b"formation-secret".to_vec()));
    }

    #[test]
    fn test_enrollment_response_denied_encode_decode() {
        let now = now_ms();
        let resp = EnrollmentResponse::denied("bad token".to_string(), now);
        let encoded = resp.encode();
        let decoded = EnrollmentResponse::decode(&encoded).unwrap();

        assert_eq!(
            decoded.status,
            EnrollmentStatus::Denied {
                reason: "bad token".to_string()
            }
        );
        assert!(decoded.certificate.is_none());
        assert!(decoded.formation_secret.is_none());
        assert_eq!(decoded.timestamp_ms, now);
    }

    #[test]
    fn test_enrollment_response_pending_encode_decode() {
        let now = now_ms();
        let resp = EnrollmentResponse::pending(now);
        let encoded = resp.encode();
        let decoded = EnrollmentResponse::decode(&encoded).unwrap();

        assert_eq!(decoded.status, EnrollmentStatus::Pending);
        assert!(decoded.certificate.is_none());
        assert!(decoded.formation_secret.is_none());
        assert_eq!(decoded.timestamp_ms, now);
    }

    #[test]
    fn test_enrollment_response_decode_too_short() {
        assert!(EnrollmentResponse::decode(&[0u8; 5]).is_err());
    }

    #[test]
    fn test_enrollment_response_decode_truncated() {
        let encoded = EnrollmentResponse::denied("bad token".to_string(), now_ms()).encode();
        // Every strict prefix of a valid encoding must be rejected.
        for len in 0..encoded.len() {
            assert!(
                EnrollmentResponse::decode(&encoded[..len]).is_err(),
                "a {len}-byte prefix decoded successfully"
            );
        }
    }

    #[test]
    fn test_enrollment_response_decode_invalid_status() {
        let mut encoded = EnrollmentResponse::pending(now_ms()).encode();
        encoded[0] = 9;
        assert!(EnrollmentResponse::decode(&encoded).is_err());
    }

    #[test]
    fn test_enrollment_response_no_secret_encode_decode() {
        let authority = DeviceKeypair::generate();
        let now = now_ms();

        let cert = MeshCertificate::new_root(
            &authority,
            "DEADBEEF".to_string(),
            "node-1".to_string(),
            MeshTier::Tactical,
            now,
            now + 3600000,
        );

        let resp = EnrollmentResponse::approved(cert, None, now);
        let encoded = resp.encode();
        let decoded = EnrollmentResponse::decode(&encoded).unwrap();

        assert_eq!(decoded.status, EnrollmentStatus::Approved);
        assert!(decoded.certificate.is_some());
        assert!(decoded.formation_secret.is_none());
    }
}