// peat_protocol/security/transport.rs
//! Secure transport wrapper for authenticated mesh connections.
//!
//! This module provides a decorator that adds Ed25519 authentication to any
//! `MeshTransport` implementation. It runs a mutual challenge-response
//! handshake, whose messages travel over an [`AuthenticationChannel`], to
//! authenticate peers before allowing sync operations.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────┐
//! │          SecureMeshTransport            │
//! │  ┌──────────────────────────────────┐   │
//! │  │     DeviceAuthenticator          │   │
//! │  │  (Ed25519 challenge-response)    │   │
//! │  └──────────────────────────────────┘   │
//! │                  │                      │
//! │  ┌──────────────────────────────────┐   │
//! │  │   Inner MeshTransport            │   │
//! │  │   (Iroh or Ditto)                │   │
//! │  └──────────────────────────────────┘   │
//! └─────────────────────────────────────────┘
//! ```
//!
//! # Example
//!
//! ```ignore
//! use std::sync::Arc;
//! use peat_protocol::security::{DeviceKeypair, SecureMeshTransport};
//! use peat_protocol::transport::MeshTransport;
//!
//! // Create a keypair, the inner transport, and a channel for auth messages
//! let keypair = DeviceKeypair::generate();
//! let inner_transport: Arc<MyTransport> = ...; // any `MeshTransport`
//! let auth_channel: Arc<MyAuthChannel> = ...; // any `AuthenticationChannel`
//!
//! // Wrap with security
//! let secure_transport = SecureMeshTransport::new(keypair, inner_transport, auth_channel);
//!
//! // `connect` opens the inner connection, then authenticates the peer
//! let conn = secure_transport.connect(&peer_id).await?;
//! // The peer's DeviceId has now been cryptographically verified
//! let device_id = secure_transport.get_peer_device_id(&peer_id);
//! ```
41
42use super::authenticator::DeviceAuthenticator;
43use super::device_id::DeviceId;
44use super::error::SecurityError;
45use super::keypair::DeviceKeypair;
46use crate::transport::{
47    MeshConnection, MeshTransport, NodeId, Result as TransportResult, TransportError,
48};
49use async_trait::async_trait;
50use peat_schema::security::v1::{Challenge, SignedChallengeResponse};
51use std::collections::HashMap;
52use std::sync::{Arc, RwLock};
53
54/// Authentication callback for custom transport-level auth message exchange.
55///
56/// This trait allows the secure transport to exchange authentication messages
57/// over any underlying transport mechanism.
58#[async_trait]
59pub trait AuthenticationChannel: Send + Sync {
60    /// Send an authentication challenge to a peer.
61    async fn send_challenge(
62        &self,
63        peer_id: &NodeId,
64        challenge: &Challenge,
65    ) -> Result<(), SecurityError>;
66
67    /// Receive a challenge response from a peer.
68    async fn receive_response(
69        &self,
70        peer_id: &NodeId,
71    ) -> Result<SignedChallengeResponse, SecurityError>;
72
73    /// Send a challenge response to a peer.
74    async fn send_response(
75        &self,
76        peer_id: &NodeId,
77        response: &SignedChallengeResponse,
78    ) -> Result<(), SecurityError>;
79
80    /// Receive a challenge from a peer.
81    async fn receive_challenge(&self, peer_id: &NodeId) -> Result<Challenge, SecurityError>;
82}
83
84/// Secure mesh transport that requires authentication before sync.
85///
86/// This wrapper adds Ed25519-based challenge-response authentication to any
87/// `MeshTransport` implementation. Peers must complete mutual authentication
88/// before the connection is considered established.
89pub struct SecureMeshTransport<T: MeshTransport, A: AuthenticationChannel> {
90    /// The device authenticator for crypto operations
91    authenticator: DeviceAuthenticator,
92
93    /// The underlying transport
94    inner: Arc<T>,
95
96    /// Authentication channel for message exchange
97    auth_channel: Arc<A>,
98
99    /// Mapping from NodeId to DeviceId for authenticated peers
100    authenticated_peers: RwLock<HashMap<NodeId, DeviceId>>,
101}
102
103impl<T: MeshTransport, A: AuthenticationChannel> SecureMeshTransport<T, A> {
104    /// Create a new secure transport wrapper.
105    ///
106    /// # Arguments
107    ///
108    /// * `keypair` - This device's keypair for authentication
109    /// * `inner` - The underlying transport to wrap
110    /// * `auth_channel` - Channel for exchanging authentication messages
111    pub fn new(keypair: DeviceKeypair, inner: Arc<T>, auth_channel: Arc<A>) -> Self {
112        Self {
113            authenticator: DeviceAuthenticator::new(keypair),
114            inner,
115            auth_channel,
116            authenticated_peers: RwLock::new(HashMap::new()),
117        }
118    }
119
120    /// Get this device's ID.
121    pub fn device_id(&self) -> DeviceId {
122        self.authenticator.device_id()
123    }
124
125    /// Check if a peer is authenticated.
126    pub fn is_authenticated(&self, peer_id: &NodeId) -> bool {
127        self.authenticated_peers
128            .read()
129            .map(|peers| peers.contains_key(peer_id))
130            .unwrap_or(false)
131    }
132
133    /// Get the DeviceId for an authenticated peer.
134    pub fn get_peer_device_id(&self, peer_id: &NodeId) -> Option<DeviceId> {
135        self.authenticated_peers
136            .read()
137            .ok()
138            .and_then(|peers| peers.get(peer_id).copied())
139    }
140
141    /// Authenticate a peer using challenge-response.
142    ///
143    /// This performs mutual authentication:
144    /// 1. We send a challenge to the peer
145    /// 2. Peer responds with signed challenge
146    /// 3. We verify the response
147    /// 4. Peer sends us a challenge
148    /// 5. We respond with signed challenge
149    /// 6. Both sides are now authenticated
150    pub async fn authenticate_peer(&self, peer_id: &NodeId) -> Result<DeviceId, SecurityError> {
151        // Check if already authenticated
152        if let Some(device_id) = self.get_peer_device_id(peer_id) {
153            return Ok(device_id);
154        }
155
156        // Step 1: Generate and send challenge
157        let challenge = self.authenticator.generate_challenge();
158        self.auth_channel
159            .send_challenge(peer_id, &challenge)
160            .await?;
161
162        // Step 2: Receive and verify response
163        let response = self.auth_channel.receive_response(peer_id).await?;
164        let device_id = self.authenticator.verify_response(&response)?;
165
166        // Step 3: Receive challenge from peer (mutual auth)
167        let peer_challenge = self.auth_channel.receive_challenge(peer_id).await?;
168
169        // Step 4: Respond to peer's challenge
170        let our_response = self.authenticator.respond_to_challenge(&peer_challenge)?;
171        self.auth_channel
172            .send_response(peer_id, &our_response)
173            .await?;
174
175        // Cache the authenticated peer
176        if let Ok(mut peers) = self.authenticated_peers.write() {
177            peers.insert(peer_id.clone(), device_id);
178        }
179
180        Ok(device_id)
181    }
182
183    /// Remove a peer from the authenticated cache.
184    pub fn remove_authenticated_peer(&self, peer_id: &NodeId) {
185        if let Ok(mut peers) = self.authenticated_peers.write() {
186            if let Some(device_id) = peers.remove(peer_id) {
187                self.authenticator.remove_peer(&device_id);
188            }
189        }
190    }
191
192    /// Get the number of authenticated peers.
193    pub fn authenticated_peer_count(&self) -> usize {
194        self.authenticated_peers
195            .read()
196            .map(|peers| peers.len())
197            .unwrap_or(0)
198    }
199
200    /// Get the underlying authenticator (for testing or advanced use).
201    pub fn authenticator(&self) -> &DeviceAuthenticator {
202        &self.authenticator
203    }
204}
205
206#[async_trait]
207impl<T: MeshTransport + 'static, A: AuthenticationChannel + 'static> MeshTransport
208    for SecureMeshTransport<T, A>
209{
210    async fn start(&self) -> TransportResult<()> {
211        self.inner.start().await
212    }
213
214    async fn stop(&self) -> TransportResult<()> {
215        self.inner.stop().await
216    }
217
218    async fn connect(&self, peer_id: &NodeId) -> TransportResult<Box<dyn MeshConnection>> {
219        // First establish the underlying connection
220        let conn = self.inner.connect(peer_id).await?;
221
222        // Then authenticate the peer
223        self.authenticate_peer(peer_id).await.map_err(|e| {
224            TransportError::ConnectionFailed(format!("Authentication failed: {}", e))
225        })?;
226
227        // Return an authenticated connection wrapper
228        Ok(Box::new(AuthenticatedConnection {
229            inner: conn,
230            device_id: self.get_peer_device_id(peer_id).ok_or_else(|| {
231                TransportError::ConnectionFailed(
232                    "peer device ID missing after authentication".to_string(),
233                )
234            })?,
235        }))
236    }
237
238    async fn disconnect(&self, peer_id: &NodeId) -> TransportResult<()> {
239        self.remove_authenticated_peer(peer_id);
240        self.inner.disconnect(peer_id).await
241    }
242
243    fn get_connection(&self, peer_id: &NodeId) -> Option<Box<dyn MeshConnection>> {
244        // Only return connection if peer is authenticated
245        if let Some(device_id) = self.get_peer_device_id(peer_id) {
246            self.inner.get_connection(peer_id).map(|conn| {
247                Box::new(AuthenticatedConnection {
248                    inner: conn,
249                    device_id,
250                }) as Box<dyn MeshConnection>
251            })
252        } else {
253            None
254        }
255    }
256
257    fn peer_count(&self) -> usize {
258        self.authenticated_peer_count()
259    }
260
261    fn connected_peers(&self) -> Vec<NodeId> {
262        self.authenticated_peers
263            .read()
264            .map(|peers| peers.keys().cloned().collect())
265            .unwrap_or_default()
266    }
267
268    fn is_connected(&self, peer_id: &NodeId) -> bool {
269        self.is_authenticated(peer_id) && self.inner.is_connected(peer_id)
270    }
271
272    fn subscribe_peer_events(&self) -> crate::transport::PeerEventReceiver {
273        // Delegate to inner transport - events are emitted at the transport layer
274        self.inner.subscribe_peer_events()
275    }
276}
277
278/// An authenticated connection wrapper.
279///
280/// This wraps an underlying `MeshConnection` and tracks the verified DeviceId
281/// of the remote peer.
282pub struct AuthenticatedConnection {
283    inner: Box<dyn MeshConnection>,
284    device_id: DeviceId,
285}
286
287impl AuthenticatedConnection {
288    /// Get the verified DeviceId of the remote peer.
289    pub fn verified_device_id(&self) -> DeviceId {
290        self.device_id
291    }
292}
293
294impl MeshConnection for AuthenticatedConnection {
295    fn peer_id(&self) -> &NodeId {
296        self.inner.peer_id()
297    }
298
299    fn is_alive(&self) -> bool {
300        self.inner.is_alive()
301    }
302
303    fn connected_at(&self) -> std::time::Instant {
304        self.inner.connected_at()
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::{
        MeshConnection, MeshTransport, NodeId, Result as TransportResult, TransportError,
    };
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Mock transport for testing
    struct MockTransport {
        started: AtomicBool,
        connections: RwLock<HashMap<String, MockConnection>>,
    }

    impl MockTransport {
        fn new() -> Self {
            Self {
                started: AtomicBool::new(false),
                connections: RwLock::new(HashMap::new()),
            }
        }
    }

    #[async_trait]
    impl MeshTransport for MockTransport {
        async fn start(&self) -> TransportResult<()> {
            self.started.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn stop(&self) -> TransportResult<()> {
            self.started.store(false, Ordering::SeqCst);
            Ok(())
        }

        async fn connect(&self, peer_id: &NodeId) -> TransportResult<Box<dyn MeshConnection>> {
            if !self.started.load(Ordering::SeqCst) {
                return Err(TransportError::NotStarted);
            }
            let now = std::time::Instant::now();
            // One copy is kept in the registry, an identical one goes to the caller.
            let make_conn = || MockConnection {
                peer_id: peer_id.clone(),
                alive: AtomicBool::new(true),
                connected_at: now,
            };
            self.connections
                .write()
                .unwrap()
                .insert(peer_id.to_string(), make_conn());
            Ok(Box::new(make_conn()))
        }

        async fn disconnect(&self, peer_id: &NodeId) -> TransportResult<()> {
            self.connections
                .write()
                .unwrap()
                .remove(&peer_id.to_string());
            Ok(())
        }

        fn get_connection(&self, peer_id: &NodeId) -> Option<Box<dyn MeshConnection>> {
            self.connections.read().ok().and_then(|conns| {
                conns.get(&peer_id.to_string()).map(|c| {
                    Box::new(MockConnection {
                        peer_id: c.peer_id.clone(),
                        alive: AtomicBool::new(c.alive.load(Ordering::SeqCst)),
                        connected_at: c.connected_at,
                    }) as Box<dyn MeshConnection>
                })
            })
        }

        fn peer_count(&self) -> usize {
            self.connections.read().map(|c| c.len()).unwrap_or(0)
        }

        fn connected_peers(&self) -> Vec<NodeId> {
            self.connections
                .read()
                .map(|c| c.values().map(|conn| conn.peer_id.clone()).collect())
                .unwrap_or_default()
        }

        fn subscribe_peer_events(&self) -> crate::transport::PeerEventReceiver {
            let (_tx, rx) = tokio::sync::mpsc::channel(256);
            rx
        }
    }

    struct MockConnection {
        peer_id: NodeId,
        alive: AtomicBool,
        connected_at: std::time::Instant,
    }

    impl MeshConnection for MockConnection {
        fn peer_id(&self) -> &NodeId {
            &self.peer_id
        }

        fn is_alive(&self) -> bool {
            self.alive.load(Ordering::SeqCst)
        }

        fn connected_at(&self) -> std::time::Instant {
            self.connected_at
        }
    }

    /// Mock auth channel that answers our challenge with the keypair registered
    /// for the peer. Peers without a registered keypair make `receive_response`
    /// fail with `PeerNotFound`.
    struct MockAuthChannel {
        /// Peer keypairs for simulating responses
        peer_keypairs: RwLock<HashMap<String, DeviceKeypair>>,
        /// Last challenge sent (for consistent response)
        last_challenge: RwLock<Option<Challenge>>,
    }

    impl MockAuthChannel {
        fn new() -> Self {
            Self {
                peer_keypairs: RwLock::new(HashMap::new()),
                last_challenge: RwLock::new(None),
            }
        }

        fn register_peer_keypair(&self, peer_id: &NodeId, keypair: DeviceKeypair) {
            if let Ok(mut peers) = self.peer_keypairs.write() {
                peers.insert(peer_id.to_string(), keypair);
            }
        }
    }

    #[async_trait]
    impl AuthenticationChannel for MockAuthChannel {
        async fn send_challenge(
            &self,
            _peer_id: &NodeId,
            challenge: &Challenge,
        ) -> Result<(), SecurityError> {
            // Store the challenge for when we need to create a response
            if let Ok(mut last) = self.last_challenge.write() {
                *last = Some(challenge.clone());
            }
            Ok(())
        }

        async fn receive_response(
            &self,
            peer_id: &NodeId,
        ) -> Result<SignedChallengeResponse, SecurityError> {
            // Return a valid response from the peer's keypair
            let keypair = self
                .peer_keypairs
                .read()
                .map_err(|e| SecurityError::Internal(e.to_string()))?
                .get(&peer_id.to_string())
                .cloned()
                .ok_or_else(|| SecurityError::PeerNotFound(peer_id.to_string()))?;

            // Use the challenge that was sent (with correct challenger_id)
            let challenge = self
                .last_challenge
                .read()
                .map_err(|e| SecurityError::Internal(e.to_string()))?
                .clone()
                .ok_or_else(|| SecurityError::Internal("no challenge sent".to_string()))?;

            let authenticator = DeviceAuthenticator::new(keypair);
            authenticator.respond_to_challenge(&challenge)
        }

        async fn send_response(
            &self,
            _peer_id: &NodeId,
            _response: &SignedChallengeResponse,
        ) -> Result<(), SecurityError> {
            Ok(())
        }

        async fn receive_challenge(&self, _peer_id: &NodeId) -> Result<Challenge, SecurityError> {
            Ok(Challenge {
                nonce: vec![0u8; 32],
                timestamp: None,
                challenger_id: "peer".to_string(),
                expires_at: Some(peat_schema::common::v1::Timestamp {
                    seconds: u64::MAX,
                    nanos: 0,
                }),
            })
        }
    }

    #[tokio::test]
    async fn test_secure_transport_creation() {
        let keypair = DeviceKeypair::generate();
        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());

        let secure = SecureMeshTransport::new(keypair, transport, auth_channel);

        assert_eq!(secure.authenticated_peer_count(), 0);
    }

    #[tokio::test]
    async fn test_secure_transport_start_stop() {
        let keypair = DeviceKeypair::generate();
        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());

        let secure = SecureMeshTransport::new(keypair, transport.clone(), auth_channel);

        assert!(!transport.started.load(Ordering::SeqCst));
        secure.start().await.unwrap();
        assert!(transport.started.load(Ordering::SeqCst));
        secure.stop().await.unwrap();
        assert!(!transport.started.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_secure_transport_connect_authenticates() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_id: NodeId = peer_keypair.device_id().into();

        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair.clone());

        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);

        secure.start().await.unwrap();
        let conn = secure.connect(&peer_id).await.unwrap();

        assert!(secure.is_authenticated(&peer_id));
        assert_eq!(conn.peer_id(), &peer_id);
        assert!(conn.is_alive());
    }

    #[tokio::test]
    async fn test_secure_transport_disconnect_removes_auth() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_id: NodeId = peer_keypair.device_id().into();

        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair);

        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);

        secure.start().await.unwrap();
        secure.connect(&peer_id).await.unwrap();
        assert!(secure.is_authenticated(&peer_id));

        secure.disconnect(&peer_id).await.unwrap();
        assert!(!secure.is_authenticated(&peer_id));
    }

    #[tokio::test]
    async fn test_authenticated_connection_exposes_device_id() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_device_id = peer_keypair.device_id();
        let peer_id: NodeId = peer_device_id.into();

        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair);

        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);

        secure.start().await.unwrap();
        let _conn = secure.connect(&peer_id).await.unwrap();

        // Verify we can get the peer's device ID through the transport
        assert!(secure.is_authenticated(&peer_id));
        assert_eq!(secure.get_peer_device_id(&peer_id), Some(peer_device_id));
    }

    #[tokio::test]
    async fn test_connect_fails_for_unauthenticated_peer() {
        let our_keypair = DeviceKeypair::generate();
        let stranger: NodeId = DeviceKeypair::generate().device_id().into();

        let transport = Arc::new(MockTransport::new());
        // No keypair registered: the mock channel cannot answer our challenge.
        let auth_channel = Arc::new(MockAuthChannel::new());

        let secure = SecureMeshTransport::new(our_keypair, transport, auth_channel);

        secure.start().await.unwrap();
        let result = secure.connect(&stranger).await;

        assert!(matches!(result, Err(TransportError::ConnectionFailed(_))));
        assert!(!secure.is_authenticated(&stranger));
        assert!(!secure.is_connected(&stranger));
        assert!(secure.get_connection(&stranger).is_none());
        assert_eq!(secure.peer_count(), 0);
    }

    #[tokio::test]
    async fn test_get_connection_requires_authentication() {
        let our_keypair = DeviceKeypair::generate();
        let peer_keypair = DeviceKeypair::generate();
        let peer_id: NodeId = peer_keypair.device_id().into();

        let transport = Arc::new(MockTransport::new());
        let auth_channel = Arc::new(MockAuthChannel::new());
        auth_channel.register_peer_keypair(&peer_id, peer_keypair);

        let secure = SecureMeshTransport::new(our_keypair, transport.clone(), auth_channel);
        secure.start().await.unwrap();

        // A connection opened directly on the inner transport is not exposed.
        transport.connect(&peer_id).await.unwrap();
        assert!(transport.get_connection(&peer_id).is_some());
        assert!(secure.get_connection(&peer_id).is_none());

        // After authenticating through the secure transport it is.
        secure.connect(&peer_id).await.unwrap();
        assert!(secure.get_connection(&peer_id).is_some());
        assert_eq!(secure.peer_count(), 1);
        assert_eq!(secure.connected_peers(), vec![peer_id.clone()]);
    }
}