1use super::authenticator::DeviceAuthenticator;
43use super::device_id::DeviceId;
44use super::error::SecurityError;
45use super::keypair::DeviceKeypair;
46use crate::transport::{
47 MeshConnection, MeshTransport, NodeId, Result as TransportResult, TransportError,
48};
49use async_trait::async_trait;
50use peat_schema::security::v1::{Challenge, SignedChallengeResponse};
51use std::collections::HashMap;
52use std::sync::{Arc, RwLock};
53
54#[async_trait]
59pub trait AuthenticationChannel: Send + Sync {
60 async fn send_challenge(
62 &self,
63 peer_id: &NodeId,
64 challenge: &Challenge,
65 ) -> Result<(), SecurityError>;
66
67 async fn receive_response(
69 &self,
70 peer_id: &NodeId,
71 ) -> Result<SignedChallengeResponse, SecurityError>;
72
73 async fn send_response(
75 &self,
76 peer_id: &NodeId,
77 response: &SignedChallengeResponse,
78 ) -> Result<(), SecurityError>;
79
80 async fn receive_challenge(&self, peer_id: &NodeId) -> Result<Challenge, SecurityError>;
82}
83
84pub struct SecureMeshTransport<T: MeshTransport, A: AuthenticationChannel> {
90 authenticator: DeviceAuthenticator,
92
93 inner: Arc<T>,
95
96 auth_channel: Arc<A>,
98
99 authenticated_peers: RwLock<HashMap<NodeId, DeviceId>>,
101}
102
103impl<T: MeshTransport, A: AuthenticationChannel> SecureMeshTransport<T, A> {
104 pub fn new(keypair: DeviceKeypair, inner: Arc<T>, auth_channel: Arc<A>) -> Self {
112 Self {
113 authenticator: DeviceAuthenticator::new(keypair),
114 inner,
115 auth_channel,
116 authenticated_peers: RwLock::new(HashMap::new()),
117 }
118 }
119
120 pub fn device_id(&self) -> DeviceId {
122 self.authenticator.device_id()
123 }
124
125 pub fn is_authenticated(&self, peer_id: &NodeId) -> bool {
127 self.authenticated_peers
128 .read()
129 .map(|peers| peers.contains_key(peer_id))
130 .unwrap_or(false)
131 }
132
133 pub fn get_peer_device_id(&self, peer_id: &NodeId) -> Option<DeviceId> {
135 self.authenticated_peers
136 .read()
137 .ok()
138 .and_then(|peers| peers.get(peer_id).copied())
139 }
140
141 pub async fn authenticate_peer(&self, peer_id: &NodeId) -> Result<DeviceId, SecurityError> {
151 if let Some(device_id) = self.get_peer_device_id(peer_id) {
153 return Ok(device_id);
154 }
155
156 let challenge = self.authenticator.generate_challenge();
158 self.auth_channel
159 .send_challenge(peer_id, &challenge)
160 .await?;
161
162 let response = self.auth_channel.receive_response(peer_id).await?;
164 let device_id = self.authenticator.verify_response(&response)?;
165
166 let peer_challenge = self.auth_channel.receive_challenge(peer_id).await?;
168
169 let our_response = self.authenticator.respond_to_challenge(&peer_challenge)?;
171 self.auth_channel
172 .send_response(peer_id, &our_response)
173 .await?;
174
175 if let Ok(mut peers) = self.authenticated_peers.write() {
177 peers.insert(peer_id.clone(), device_id);
178 }
179
180 Ok(device_id)
181 }
182
183 pub fn remove_authenticated_peer(&self, peer_id: &NodeId) {
185 if let Ok(mut peers) = self.authenticated_peers.write() {
186 if let Some(device_id) = peers.remove(peer_id) {
187 self.authenticator.remove_peer(&device_id);
188 }
189 }
190 }
191
192 pub fn authenticated_peer_count(&self) -> usize {
194 self.authenticated_peers
195 .read()
196 .map(|peers| peers.len())
197 .unwrap_or(0)
198 }
199
200 pub fn authenticator(&self) -> &DeviceAuthenticator {
202 &self.authenticator
203 }
204}
205
206#[async_trait]
207impl<T: MeshTransport + 'static, A: AuthenticationChannel + 'static> MeshTransport
208 for SecureMeshTransport<T, A>
209{
210 async fn start(&self) -> TransportResult<()> {
211 self.inner.start().await
212 }
213
214 async fn stop(&self) -> TransportResult<()> {
215 self.inner.stop().await
216 }
217
218 async fn connect(&self, peer_id: &NodeId) -> TransportResult<Box<dyn MeshConnection>> {
219 let conn = self.inner.connect(peer_id).await?;
221
222 self.authenticate_peer(peer_id).await.map_err(|e| {
224 TransportError::ConnectionFailed(format!("Authentication failed: {}", e))
225 })?;
226
227 Ok(Box::new(AuthenticatedConnection {
229 inner: conn,
230 device_id: self.get_peer_device_id(peer_id).ok_or_else(|| {
231 TransportError::ConnectionFailed(
232 "peer device ID missing after authentication".to_string(),
233 )
234 })?,
235 }))
236 }
237
238 async fn disconnect(&self, peer_id: &NodeId) -> TransportResult<()> {
239 self.remove_authenticated_peer(peer_id);
240 self.inner.disconnect(peer_id).await
241 }
242
243 fn get_connection(&self, peer_id: &NodeId) -> Option<Box<dyn MeshConnection>> {
244 if let Some(device_id) = self.get_peer_device_id(peer_id) {
246 self.inner.get_connection(peer_id).map(|conn| {
247 Box::new(AuthenticatedConnection {
248 inner: conn,
249 device_id,
250 }) as Box<dyn MeshConnection>
251 })
252 } else {
253 None
254 }
255 }
256
257 fn peer_count(&self) -> usize {
258 self.authenticated_peer_count()
259 }
260
261 fn connected_peers(&self) -> Vec<NodeId> {
262 self.authenticated_peers
263 .read()
264 .map(|peers| peers.keys().cloned().collect())
265 .unwrap_or_default()
266 }
267
268 fn is_connected(&self, peer_id: &NodeId) -> bool {
269 self.is_authenticated(peer_id) && self.inner.is_connected(peer_id)
270 }
271
272 fn subscribe_peer_events(&self) -> crate::transport::PeerEventReceiver {
273 self.inner.subscribe_peer_events()
275 }
276}
277
278pub struct AuthenticatedConnection {
283 inner: Box<dyn MeshConnection>,
284 device_id: DeviceId,
285}
286
287impl AuthenticatedConnection {
288 pub fn verified_device_id(&self) -> DeviceId {
290 self.device_id
291 }
292}
293
294impl MeshConnection for AuthenticatedConnection {
295 fn peer_id(&self) -> &NodeId {
296 self.inner.peer_id()
297 }
298
299 fn is_alive(&self) -> bool {
300 self.inner.is_alive()
301 }
302
303 fn connected_at(&self) -> std::time::Instant {
304 self.inner.connected_at()
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::{
        MeshConnection, MeshTransport, NodeId, Result as TransportResult, TransportError,
    };
    use std::sync::atomic::{AtomicBool, Ordering};

    /// In-memory transport: "connections" are `MockConnection`s stored by peer ID string.
    struct MockTransport {
        started: AtomicBool,
        connections: RwLock<HashMap<String, MockConnection>>,
    }

    impl MockTransport {
        fn new() -> Self {
            Self {
                started: AtomicBool::new(false),
                connections: RwLock::new(HashMap::new()),
            }
        }
    }

    #[async_trait]
    impl MeshTransport for MockTransport {
        async fn start(&self) -> TransportResult<()> {
            self.started.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn stop(&self) -> TransportResult<()> {
            self.started.store(false, Ordering::SeqCst);
            Ok(())
        }

        async fn connect(&self, peer_id: &NodeId) -> TransportResult<Box<dyn MeshConnection>> {
            if !self.started.load(Ordering::SeqCst) {
                return Err(TransportError::NotStarted);
            }
            let conn = MockConnection::new(peer_id.clone(), true, std::time::Instant::now());
            // The map keeps its own copy because the returned box is owned by the caller.
            self.connections
                .write()
                .unwrap()
                .insert(peer_id.to_string(), conn.snapshot());
            Ok(Box::new(conn))
        }

        async fn disconnect(&self, peer_id: &NodeId) -> TransportResult<()> {
            self.connections
                .write()
                .unwrap()
                .remove(&peer_id.to_string());
            Ok(())
        }

        fn get_connection(&self, peer_id: &NodeId) -> Option<Box<dyn MeshConnection>> {
            self.connections.read().ok().and_then(|conns| {
                conns
                    .get(&peer_id.to_string())
                    .map(|c| Box::new(c.snapshot()) as Box<dyn MeshConnection>)
            })
        }

        fn peer_count(&self) -> usize {
            self.connections.read().map(|c| c.len()).unwrap_or(0)
        }

        fn connected_peers(&self) -> Vec<NodeId> {
            self.connections
                .read()
                .map(|c| c.values().map(|conn| conn.peer_id.clone()).collect())
                .unwrap_or_default()
        }

        fn subscribe_peer_events(&self) -> crate::transport::PeerEventReceiver {
            let (_tx, rx) = tokio::sync::mpsc::channel(256);
            rx
        }
    }

    struct MockConnection {
        peer_id: NodeId,
        alive: AtomicBool,
        connected_at: std::time::Instant,
    }

    impl MockConnection {
        fn new(peer_id: NodeId, alive: bool, connected_at: std::time::Instant) -> Self {
            Self {
                peer_id,
                alive: AtomicBool::new(alive),
                connected_at,
            }
        }

        /// Copies this connection's current state into a new, independent value.
        fn snapshot(&self) -> Self {
            Self::new(
                self.peer_id.clone(),
                self.alive.load(Ordering::SeqCst),
                self.connected_at,
            )
        }
    }

    impl MeshConnection for MockConnection {
        fn peer_id(&self) -> &NodeId {
            &self.peer_id
        }

        fn is_alive(&self) -> bool {
            self.alive.load(Ordering::SeqCst)
        }

        fn connected_at(&self) -> std::time::Instant {
            self.connected_at
        }
    }

    /// Plays the remote peer's side of the handshake.
    ///
    /// It answers our last challenge by signing it with the keypair registered for the
    /// peer, accepts our response without checking it, and issues a fixed challenge.
    struct MockAuthChannel {
        peer_keypairs: RwLock<HashMap<String, DeviceKeypair>>,
        last_challenge: RwLock<Option<Challenge>>,
    }

    impl MockAuthChannel {
        fn new() -> Self {
            Self {
                peer_keypairs: RwLock::new(HashMap::new()),
                last_challenge: RwLock::new(None),
            }
        }

        fn register_peer_keypair(&self, peer_id: &NodeId, keypair: DeviceKeypair) {
            if let Ok(mut peers) = self.peer_keypairs.write() {
                peers.insert(peer_id.to_string(), keypair);
            }
        }
    }

    #[async_trait]
    impl AuthenticationChannel for MockAuthChannel {
        async fn send_challenge(
            &self,
            _peer_id: &NodeId,
            challenge: &Challenge,
        ) -> Result<(), SecurityError> {
            // Remember the challenge so `receive_response` can sign it as the peer.
            if let Ok(mut last) = self.last_challenge.write() {
                *last = Some(challenge.clone());
            }
            Ok(())
        }

        async fn receive_response(
            &self,
            peer_id: &NodeId,
        ) -> Result<SignedChallengeResponse, SecurityError> {
            let keypair = self
                .peer_keypairs
                .read()
                .map_err(|e| SecurityError::Internal(e.to_string()))?
                .get(&peer_id.to_string())
                .cloned()
                .ok_or_else(|| SecurityError::PeerNotFound(peer_id.to_string()))?;

            let challenge = self
                .last_challenge
                .read()
                .map_err(|e| SecurityError::Internal(e.to_string()))?
                .clone()
                .ok_or_else(|| SecurityError::Internal("no challenge sent".to_string()))?;

            let authenticator = DeviceAuthenticator::new(keypair);
            authenticator.respond_to_challenge(&challenge)
        }

        async fn send_response(
            &self,
            _peer_id: &NodeId,
            _response: &SignedChallengeResponse,
        ) -> Result<(), SecurityError> {
            Ok(())
        }

        async fn receive_challenge(&self, _peer_id: &NodeId) -> Result<Challenge, SecurityError> {
            // `expires_at` is set to the largest representable seconds value so the
            // challenge is not rejected as expired.
            Ok(Challenge {
                nonce: vec![0u8; 32],
                timestamp: None,
                challenger_id: "peer".to_string(),
                expires_at: Some(peat_schema::common::v1::Timestamp {
                    seconds: u64::MAX,
                    nanos: 0,
                }),
            })
        }
    }

    #[tokio::test]
    async fn test_secure_transport_creation() {
        let keypair = DeviceKeypair::generate();
        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());

        let secure = SecureMeshTransport::new(keypair, transport, auth_channel);

        assert_eq!(secure.authenticated_peer_count(), 0);
    }

    #[tokio::test]
    async fn test_secure_transport_start_stop() {
        let keypair = DeviceKeypair::generate();
        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());

        let secure = SecureMeshTransport::new(keypair, transport.clone(), auth_channel);

        assert!(!transport.started.load(Ordering::SeqCst));
        secure.start().await.unwrap();
        assert!(transport.started.load(Ordering::SeqCst));
        secure.stop().await.unwrap();
        assert!(!transport.started.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_secure_transport_connect_authenticates() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_id: NodeId = peer_keypair.device_id().into();

        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair.clone());

        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);

        secure.start().await.unwrap();
        let conn = secure.connect(&peer_id).await.unwrap();

        assert!(secure.is_authenticated(&peer_id));
        assert_eq!(conn.peer_id(), &peer_id);
        assert!(conn.is_alive());
    }

    #[tokio::test]
    async fn test_secure_transport_disconnect_removes_auth() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_id: NodeId = peer_keypair.device_id().into();

        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair);

        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);

        secure.start().await.unwrap();
        secure.connect(&peer_id).await.unwrap();
        assert!(secure.is_authenticated(&peer_id));

        secure.disconnect(&peer_id).await.unwrap();
        assert!(!secure.is_authenticated(&peer_id));
    }

    #[tokio::test]
    async fn test_authenticated_connection_exposes_device_id() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_device_id = peer_keypair.device_id();
        let peer_id: NodeId = peer_device_id.into();

        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair);

        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);

        secure.start().await.unwrap();
        let _conn = secure.connect(&peer_id).await.unwrap();

        assert!(secure.is_authenticated(&peer_id));
        assert_eq!(secure.get_peer_device_id(&peer_id), Some(peer_device_id));
    }

    #[tokio::test]
    async fn test_connect_fails_for_unregistered_peer() {
        let our_keypair = DeviceKeypair::generate();
        // No keypair is registered for this peer, so the mock cannot answer our challenge.
        let unknown_peer: NodeId = DeviceKeypair::generate().device_id().into();

        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());

        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);

        secure.start().await.unwrap();
        assert!(secure.connect(&unknown_peer).await.is_err());

        // A failed handshake must leave no trace in the secure layer.
        assert!(!secure.is_authenticated(&unknown_peer));
        assert!(!secure.is_connected(&unknown_peer));
        assert!(secure.get_connection(&unknown_peer).is_none());
        assert_eq!(secure.authenticated_peer_count(), 0);
    }
}