use super::authenticator::DeviceAuthenticator;
use super::device_id::DeviceId;
use super::error::SecurityError;
use super::keypair::DeviceKeypair;
use crate::transport::{
MeshConnection, MeshTransport, NodeId, Result as TransportResult, TransportError,
};
use async_trait::async_trait;
use peat_schema::security::v1::{Challenge, SignedChallengeResponse};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Message channel over which a [`SecureMeshTransport`] runs its mutual
/// challenge-response handshake with a single peer.
///
/// `SecureMeshTransport::authenticate_peer` drives the methods in this order:
/// `send_challenge` → `receive_response` → `receive_challenge` →
/// `send_response`. Any error returned here aborts that handshake.
#[async_trait]
pub trait AuthenticationChannel: Send + Sync {
    /// Sends a challenge generated by this node to `peer_id`.
    async fn send_challenge(
        &self,
        peer_id: &NodeId,
        challenge: &Challenge,
    ) -> Result<(), SecurityError>;

    /// Waits for `peer_id`'s signed answer to the challenge we sent.
    async fn receive_response(
        &self,
        peer_id: &NodeId,
    ) -> Result<SignedChallengeResponse, SecurityError>;

    /// Sends this node's signed answer to the challenge `peer_id` issued.
    async fn send_response(
        &self,
        peer_id: &NodeId,
        response: &SignedChallengeResponse,
    ) -> Result<(), SecurityError>;

    /// Waits for the challenge `peer_id` issues to this node.
    async fn receive_challenge(
        &self,
        peer_id: &NodeId,
    ) -> Result<Challenge, SecurityError>;
}
/// A [`MeshTransport`] wrapper that only reports peers which have completed a
/// mutual challenge-response handshake.
///
/// `T` is the transport that carries the actual connections. `A` carries the
/// handshake messages.
pub struct SecureMeshTransport<T, A>
where
    T: MeshTransport,
    A: AuthenticationChannel,
{
    /// Holds this node's keypair; issues, answers and verifies challenges.
    authenticator: DeviceAuthenticator,
    /// Wrapped transport that owns the real connections.
    inner: Arc<T>,
    /// Channel used to exchange challenges and responses with peers.
    auth_channel: Arc<A>,
    /// Peers that finished the handshake, keyed by node ID, with the device ID
    /// their signed response was verified against.
    authenticated_peers: RwLock<HashMap<NodeId, DeviceId>>,
}
impl<T: MeshTransport, A: AuthenticationChannel> SecureMeshTransport<T, A> {
pub fn new(keypair: DeviceKeypair, inner: Arc<T>, auth_channel: Arc<A>) -> Self {
Self {
authenticator: DeviceAuthenticator::new(keypair),
inner,
auth_channel,
authenticated_peers: RwLock::new(HashMap::new()),
}
}
pub fn device_id(&self) -> DeviceId {
self.authenticator.device_id()
}
pub fn is_authenticated(&self, peer_id: &NodeId) -> bool {
self.authenticated_peers
.read()
.map(|peers| peers.contains_key(peer_id))
.unwrap_or(false)
}
pub fn get_peer_device_id(&self, peer_id: &NodeId) -> Option<DeviceId> {
self.authenticated_peers
.read()
.ok()
.and_then(|peers| peers.get(peer_id).copied())
}
pub async fn authenticate_peer(&self, peer_id: &NodeId) -> Result<DeviceId, SecurityError> {
if let Some(device_id) = self.get_peer_device_id(peer_id) {
return Ok(device_id);
}
let challenge = self.authenticator.generate_challenge();
self.auth_channel
.send_challenge(peer_id, &challenge)
.await?;
let response = self.auth_channel.receive_response(peer_id).await?;
let device_id = self.authenticator.verify_response(&response)?;
let peer_challenge = self.auth_channel.receive_challenge(peer_id).await?;
let our_response = self.authenticator.respond_to_challenge(&peer_challenge)?;
self.auth_channel
.send_response(peer_id, &our_response)
.await?;
if let Ok(mut peers) = self.authenticated_peers.write() {
peers.insert(peer_id.clone(), device_id);
}
Ok(device_id)
}
pub fn remove_authenticated_peer(&self, peer_id: &NodeId) {
if let Ok(mut peers) = self.authenticated_peers.write() {
if let Some(device_id) = peers.remove(peer_id) {
self.authenticator.remove_peer(&device_id);
}
}
}
pub fn authenticated_peer_count(&self) -> usize {
self.authenticated_peers
.read()
.map(|peers| peers.len())
.unwrap_or(0)
}
pub fn authenticator(&self) -> &DeviceAuthenticator {
&self.authenticator
}
}
#[async_trait]
impl<T: MeshTransport + 'static, A: AuthenticationChannel + 'static> MeshTransport
    for SecureMeshTransport<T, A>
{
    /// Starts the inner transport. Authentication state is not touched.
    async fn start(&self) -> TransportResult<()> {
        self.inner.start().await
    }

    /// Stops the inner transport. The authenticated-peer table is left as is.
    async fn stop(&self) -> TransportResult<()> {
        self.inner.stop().await
    }

    /// Opens a connection on the inner transport, then authenticates the peer.
    ///
    /// # Errors
    ///
    /// Errors from the inner `connect` are propagated. If authentication fails,
    /// or the peer is not recorded afterwards, `TransportError::ConnectionFailed`
    /// is returned. In those cases the connection just opened on the inner
    /// transport is torn down first (best effort), so no unauthenticated link
    /// is left behind.
    async fn connect(&self, peer_id: &NodeId) -> TransportResult<Box<dyn MeshConnection>> {
        let conn = self.inner.connect(peer_id).await?;

        let verified = match self.authenticate_peer(peer_id).await {
            Ok(_) => self.get_peer_device_id(peer_id).ok_or_else(|| {
                TransportError::ConnectionFailed(
                    "peer device ID missing after authentication".to_string(),
                )
            }),
            Err(e) => Err(TransportError::ConnectionFailed(format!(
                "Authentication failed: {}",
                e
            ))),
        };

        match verified {
            Ok(device_id) => Ok(Box::new(AuthenticatedConnection {
                inner: conn,
                device_id,
            }) as Box<dyn MeshConnection>),
            Err(err) => {
                drop(conn);
                // Best effort: the authentication error is what the caller
                // needs to see, so a failed cleanup is deliberately ignored.
                let _ = self.inner.disconnect(peer_id).await;
                Err(err)
            }
        }
    }

    /// Forgets `peer_id`'s authentication, then disconnects it on the inner
    /// transport.
    async fn disconnect(&self, peer_id: &NodeId) -> TransportResult<()> {
        self.remove_authenticated_peer(peer_id);
        self.inner.disconnect(peer_id).await
    }

    /// Returns the inner connection wrapped with its verified device ID.
    ///
    /// Returns `None` if the peer is not authenticated or the inner transport
    /// has no connection to it.
    fn get_connection(&self, peer_id: &NodeId) -> Option<Box<dyn MeshConnection>> {
        let device_id = self.get_peer_device_id(peer_id)?;
        let inner = self.inner.get_connection(peer_id)?;
        Some(Box::new(AuthenticatedConnection { inner, device_id }) as Box<dyn MeshConnection>)
    }

    /// Counts authenticated peers, not raw inner connections.
    fn peer_count(&self) -> usize {
        self.authenticated_peer_count()
    }

    /// Lists authenticated peers.
    ///
    /// Returns an empty list if the peer table's lock is poisoned.
    fn connected_peers(&self) -> Vec<NodeId> {
        self.authenticated_peers
            .read()
            .map(|peers| peers.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// True only when the peer is authenticated and the inner transport also
    /// reports it as connected.
    fn is_connected(&self, peer_id: &NodeId) -> bool {
        self.is_authenticated(peer_id) && self.inner.is_connected(peer_id)
    }

    /// Forwards the inner transport's peer events.
    ///
    /// NOTE(review): these are not filtered by authentication state.
    fn subscribe_peer_events(&self) -> crate::transport::PeerEventReceiver {
        self.inner.subscribe_peer_events()
    }
}
/// A [`MeshConnection`] paired with the device ID verified for its peer during
/// the authentication handshake.
pub struct AuthenticatedConnection {
    /// Device ID that the peer's signed challenge response was verified against.
    device_id: DeviceId,
    /// Connection from the wrapped transport. All trait calls delegate to it.
    inner: Box<dyn MeshConnection>,
}
impl AuthenticatedConnection {
    /// Returns the device ID verified for this connection's peer.
    pub fn verified_device_id(&self) -> DeviceId {
        let Self { device_id, .. } = self;
        *device_id
    }
}
impl MeshConnection for AuthenticatedConnection {
fn peer_id(&self) -> &NodeId {
self.inner.peer_id()
}
fn is_alive(&self) -> bool {
self.inner.is_alive()
}
fn connected_at(&self) -> std::time::Instant {
self.inner.connected_at()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::{
        MeshConnection, MeshTransport, NodeId, Result as TransportResult, TransportError,
    };
    use std::sync::atomic::{AtomicBool, Ordering};

    /// In-memory transport that keeps one `MockConnection` per peer.
    /// `connect` fails with `NotStarted` until `start` has been called.
    struct MockTransport {
        started: AtomicBool,
        connections: RwLock<HashMap<String, MockConnection>>,
    }

    impl MockTransport {
        fn new() -> Self {
            Self {
                started: AtomicBool::new(false),
                connections: RwLock::new(HashMap::new()),
            }
        }
    }

    #[async_trait]
    impl MeshTransport for MockTransport {
        async fn start(&self) -> TransportResult<()> {
            self.started.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn stop(&self) -> TransportResult<()> {
            self.started.store(false, Ordering::SeqCst);
            Ok(())
        }

        async fn connect(&self, peer_id: &NodeId) -> TransportResult<Box<dyn MeshConnection>> {
            if !self.started.load(Ordering::SeqCst) {
                return Err(TransportError::NotStarted);
            }
            let conn = MockConnection::new(peer_id.clone(), std::time::Instant::now());
            // Keep an independent copy so `get_connection` can hand out new boxes.
            self.connections
                .write()
                .unwrap()
                .insert(peer_id.to_string(), conn.snapshot());
            Ok(Box::new(conn))
        }

        async fn disconnect(&self, peer_id: &NodeId) -> TransportResult<()> {
            self.connections
                .write()
                .unwrap()
                .remove(&peer_id.to_string());
            Ok(())
        }

        fn get_connection(&self, peer_id: &NodeId) -> Option<Box<dyn MeshConnection>> {
            self.connections.read().ok().and_then(|conns| {
                conns
                    .get(&peer_id.to_string())
                    .map(|c| Box::new(c.snapshot()) as Box<dyn MeshConnection>)
            })
        }

        fn peer_count(&self) -> usize {
            self.connections.read().map(|c| c.len()).unwrap_or(0)
        }

        fn connected_peers(&self) -> Vec<NodeId> {
            self.connections
                .read()
                .map(|c| c.values().map(|conn| conn.peer_id.clone()).collect())
                .unwrap_or_default()
        }

        fn subscribe_peer_events(&self) -> crate::transport::PeerEventReceiver {
            let (_tx, rx) = tokio::sync::mpsc::channel(256);
            rx
        }
    }

    struct MockConnection {
        peer_id: NodeId,
        alive: AtomicBool,
        connected_at: std::time::Instant,
    }

    impl MockConnection {
        /// A live connection to `peer_id` opened at `connected_at`.
        fn new(peer_id: NodeId, connected_at: std::time::Instant) -> Self {
            Self {
                peer_id,
                alive: AtomicBool::new(true),
                connected_at,
            }
        }

        /// An independent copy with the same peer, timestamp and current liveness.
        fn snapshot(&self) -> Self {
            Self {
                peer_id: self.peer_id.clone(),
                alive: AtomicBool::new(self.alive.load(Ordering::SeqCst)),
                connected_at: self.connected_at,
            }
        }
    }

    impl MeshConnection for MockConnection {
        fn peer_id(&self) -> &NodeId {
            &self.peer_id
        }

        fn is_alive(&self) -> bool {
            self.alive.load(Ordering::SeqCst)
        }

        fn connected_at(&self) -> std::time::Instant {
            self.connected_at
        }
    }

    /// Answers our challenge by signing it with the keypair registered for the
    /// peer. The "peer's" own challenge is a fixed all-zero nonce that never
    /// expires.
    struct MockAuthChannel {
        peer_keypairs: RwLock<HashMap<String, DeviceKeypair>>,
        last_challenge: RwLock<Option<Challenge>>,
    }

    impl MockAuthChannel {
        fn new() -> Self {
            Self {
                peer_keypairs: RwLock::new(HashMap::new()),
                last_challenge: RwLock::new(None),
            }
        }

        fn register_peer_keypair(&self, peer_id: &NodeId, keypair: DeviceKeypair) {
            if let Ok(mut peers) = self.peer_keypairs.write() {
                peers.insert(peer_id.to_string(), keypair);
            }
        }
    }

    #[async_trait]
    impl AuthenticationChannel for MockAuthChannel {
        async fn send_challenge(
            &self,
            _peer_id: &NodeId,
            challenge: &Challenge,
        ) -> Result<(), SecurityError> {
            if let Ok(mut last) = self.last_challenge.write() {
                *last = Some(challenge.clone());
            }
            Ok(())
        }

        async fn receive_response(
            &self,
            peer_id: &NodeId,
        ) -> Result<SignedChallengeResponse, SecurityError> {
            let keypair = self
                .peer_keypairs
                .read()
                .map_err(|e| SecurityError::Internal(e.to_string()))?
                .get(&peer_id.to_string())
                .cloned()
                .ok_or_else(|| SecurityError::PeerNotFound(peer_id.to_string()))?;
            let challenge = self
                .last_challenge
                .read()
                .map_err(|e| SecurityError::Internal(e.to_string()))?
                .clone()
                .ok_or_else(|| SecurityError::Internal("no challenge sent".to_string()))?;
            let authenticator = DeviceAuthenticator::new(keypair);
            authenticator.respond_to_challenge(&challenge)
        }

        async fn send_response(
            &self,
            _peer_id: &NodeId,
            _response: &SignedChallengeResponse,
        ) -> Result<(), SecurityError> {
            Ok(())
        }

        async fn receive_challenge(&self, _peer_id: &NodeId) -> Result<Challenge, SecurityError> {
            Ok(Challenge {
                nonce: vec![0u8; 32],
                timestamp: None,
                challenger_id: "peer".to_string(),
                expires_at: Some(peat_schema::common::v1::Timestamp {
                    seconds: u64::MAX,
                    nanos: 0,
                }),
            })
        }
    }

    #[tokio::test]
    async fn test_secure_transport_creation() {
        let keypair = DeviceKeypair::generate();
        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        let secure = SecureMeshTransport::new(keypair, transport, auth_channel);
        assert_eq!(secure.authenticated_peer_count(), 0);
    }

    #[tokio::test]
    async fn test_secure_transport_start_stop() {
        let keypair = DeviceKeypair::generate();
        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        let secure = SecureMeshTransport::new(keypair, transport.clone(), auth_channel);
        assert!(!transport.started.load(Ordering::SeqCst));
        secure.start().await.unwrap();
        assert!(transport.started.load(Ordering::SeqCst));
        secure.stop().await.unwrap();
        assert!(!transport.started.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_secure_transport_connect_authenticates() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_id: NodeId = peer_keypair.device_id().into();
        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair);
        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);
        secure.start().await.unwrap();
        let conn = secure.connect(&peer_id).await.unwrap();
        assert!(secure.is_authenticated(&peer_id));
        assert_eq!(secure.peer_count(), 1);
        assert!(secure.get_connection(&peer_id).is_some());
        assert_eq!(conn.peer_id(), &peer_id);
        assert!(conn.is_alive());
    }

    #[tokio::test]
    async fn test_secure_transport_disconnect_removes_auth() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_id: NodeId = peer_keypair.device_id().into();
        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair);
        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);
        secure.start().await.unwrap();
        secure.connect(&peer_id).await.unwrap();
        assert!(secure.is_authenticated(&peer_id));
        secure.disconnect(&peer_id).await.unwrap();
        assert!(!secure.is_authenticated(&peer_id));
        assert!(secure.get_connection(&peer_id).is_none());
    }

    #[tokio::test]
    async fn test_authenticated_connection_exposes_device_id() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_device_id = peer_keypair.device_id();
        let peer_id: NodeId = peer_device_id.into();
        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair);
        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);
        secure.start().await.unwrap();
        let _conn = secure.connect(&peer_id).await.unwrap();
        assert!(secure.is_authenticated(&peer_id));
        assert_eq!(secure.get_peer_device_id(&peer_id), Some(peer_device_id));
    }

    #[tokio::test]
    async fn test_connect_rejects_peer_that_cannot_authenticate() {
        let our_keypair = DeviceKeypair::generate();
        let stranger_keypair = DeviceKeypair::generate();
        let peer_id: NodeId = stranger_keypair.device_id().into();
        let transport = Arc::new(MockTransport::new());
        // No keypair is registered, so the mock channel cannot answer our challenge.
        let auth_channel = Arc::new(MockAuthChannel::new());
        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);
        secure.start().await.unwrap();
        let result = secure.connect(&peer_id).await;
        assert!(matches!(result, Err(TransportError::ConnectionFailed(_))));
        assert!(!secure.is_authenticated(&peer_id));
        assert_eq!(secure.authenticated_peer_count(), 0);
        assert!(secure.get_connection(&peer_id).is_none());
    }
}