Skip to main content

iroh_auth/
lib.rs

1use n0_watcher::Watchable;
2use std::{
3    collections::BTreeSet,
4    sync::{Arc, Mutex},
5    time::Duration,
6};
7use tracing::{trace, debug, error, info, warn};
8
9use hkdf::Hkdf;
10use iroh::{
11    endpoint::{AfterHandshakeOutcome, Connection, EndpointHooks, VarInt},
12    protocol::ProtocolHandler,
13    Endpoint, PublicKey, Watcher,
14};
15use n0_future::{task::spawn, time::timeout, StreamExt};
16use secrecy::{ExposeSecret, SecretSlice};
17use sha2::Sha512;
18use spake2::{Ed25519Group, Identity, Password, Spake2};
19use subtle::ConstantTimeEq;
20
/// Errors produced by [`Authenticator`].
#[derive(Debug)]
pub enum AuthenticatorError {
    /// Recording an authentication outcome failed (poisoned lock or the
    /// counter update could not be published).
    AddFailed,
    /// An authentication step failed; carries a human-readable description.
    AcceptFailed(String),
    /// Opening an authentication exchange failed; carries a human-readable
    /// description.
    OpenFailed(String),
    /// No endpoint has been registered yet (or its lock was poisoned).
    EndpointNotSet,
}

impl std::fmt::Display for AuthenticatorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AddFailed => f.write_str("Failed to add authenticated ID"),
            Self::AcceptFailed(msg) => write!(f, "Accept failed: {msg}"),
            Self::OpenFailed(msg) => write!(f, "Open failed: {msg}"),
            Self::EndpointNotSet => f.write_str(
                "Authenticator endpoint not set: missing authenticator.start(endpoint)",
            ),
        }
    }
}

impl std::error::Error for AuthenticatorError {}
45
46pub trait IntoSecret {
47    fn into_secret(self) -> SecretSlice<u8>;
48}
49
50impl IntoSecret for SecretSlice<u8> {
51    fn into_secret(self) -> SecretSlice<u8> {
52        self
53    }
54}
55
56impl IntoSecret for String {
57    fn into_secret(self) -> SecretSlice<u8> {
58        SecretSlice::new(self.into_bytes().into_boxed_slice())
59    }
60}
61
62impl IntoSecret for &str {
63    fn into_secret(self) -> SecretSlice<u8> {
64        SecretSlice::new(self.as_bytes().to_vec().into_boxed_slice())
65    }
66}
67
68impl IntoSecret for Vec<u8> {
69    fn into_secret(self) -> SecretSlice<u8> {
70        SecretSlice::new(self.into_boxed_slice())
71    }
72}
73
74impl IntoSecret for &[u8] {
75    fn into_secret(self) -> SecretSlice<u8> {
76        SecretSlice::new(self.to_vec().into_boxed_slice())
77    }
78}
79
80impl<const N: usize> IntoSecret for &[u8; N] {
81    fn into_secret(self) -> SecretSlice<u8> {
82        SecretSlice::new(self.as_slice().to_vec().into_boxed_slice())
83    }
84}
85
86impl IntoSecret for Box<[u8]> {
87    fn into_secret(self) -> SecretSlice<u8> {
88        SecretSlice::new(self)
89    }
90}
91
/// Running totals of authentication outcomes.
///
/// Published through a `Watchable` so that tasks can wait for any change;
/// the values themselves are only compared, never interpreted per-peer.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct WatchableCounter {
    /// Number of successful authentications recorded so far.
    authenticated: usize,
    /// Number of failed authentication attempts recorded so far.
    blocked: usize,
}
97
98#[derive(Debug, Clone)]
99pub struct Authenticator {
100    secret: SecretSlice<u8>,
101    authenticated: Arc<Mutex<BTreeSet<PublicKey>>>,
102    watcher: Watchable<WatchableCounter>,
103    endpoint: Arc<Mutex<Option<iroh::Endpoint>>>,
104}
105
/// ALPN identifier negotiated for the authentication handshake connection.
pub const ALPN: &[u8] = b"/iroh/auth/0.1";
107
108impl Authenticator {
109    pub const ALPN: &'static [u8] = ALPN;
110    const ACCEPT_CONTEXT: &'static [u8] = b"iroh-auth-accept";
111    const OPEN_CONTEXT: &'static [u8] = b"iroh-auth-open";
112
113    pub fn new<S: IntoSecret>(secret: S) -> Self {
114        Self {
115            secret: secret.into_secret(),
116            authenticated: Arc::new(Mutex::new(BTreeSet::new())),
117            watcher: Watchable::new(WatchableCounter::default()),
118            endpoint: Arc::new(Mutex::new(None)),
119        }
120    }
121
122    pub fn set_endpoint(&self, endpoint: &Endpoint) {
123        if let Ok(mut guard) = self.endpoint.lock() {
124            if guard.is_none() {
125                *guard = Some(endpoint.clone());
126                trace!("Authenticator endpoint set to {}", endpoint.id());
127            }
128        }
129    }
130
131    fn id(&self) -> Result<PublicKey, AuthenticatorError> {
132        self.endpoint
133            .lock()
134            .map_err(|_| AuthenticatorError::EndpointNotSet)?
135            .as_ref()
136            .map(|ep| ep.id())
137            .ok_or(AuthenticatorError::EndpointNotSet)
138    }
139
140    fn endpoint(&self) -> Result<iroh::Endpoint, AuthenticatorError> {
141        self.endpoint
142            .lock()
143            .map_err(|_| AuthenticatorError::EndpointNotSet)?
144            .as_ref()
145            .cloned()
146            .ok_or(AuthenticatorError::EndpointNotSet)
147    }
148
149    fn is_authenticated(&self, id: &PublicKey) -> bool {
150        self.authenticated
151            .lock()
152            .map(|set| set.contains(id))
153            .unwrap_or(false)
154    }
155
156    fn add_authenticated(&self, id: PublicKey) -> Result<(), AuthenticatorError> {
157        self.authenticated
158            .lock()
159            .map_err(|_| AuthenticatorError::AddFailed)?
160            .insert(id);
161        let mut counter = self.watcher.get();
162        counter.authenticated += 1;
163        self.watcher
164            .set(counter)
165            .map_err(|_| AuthenticatorError::AddFailed)?;
166        Ok(())
167    }
168
169    fn add_blocked(&self) -> Result<(), AuthenticatorError> {
170        let mut counter = self.watcher.get();
171        counter.blocked += 1;
172        self.watcher
173            .set(counter)
174            .map_err(|_| AuthenticatorError::AddFailed)?;
175        Ok(())
176    }
177
178    #[doc(hidden)]
179    pub fn list_authenticated(&self) -> Vec<PublicKey> {
180        self.authenticated
181            .lock()
182            .map(|set| set.iter().cloned().collect())
183            .unwrap_or_default()
184    }
185
186    /// Accept an incoming connection and perform SPAKE2 authentication.
187    /// On success, adds the remote ID to the authenticated set.
188    /// Returns Ok(()) on success, or an AuthenticatorError on failure.
189    async fn auth_accept(&self, conn: Connection) -> Result<(), AuthenticatorError> {
190        let remote_id = conn.remote_id();
191        debug!("accepting auth connection from {}", remote_id);
192        let (mut send, mut recv) = conn.accept_bi().await.map_err(|err| {
193            error!("accept bidirectional stream failed: {}", err);
194            AuthenticatorError::AcceptFailed(format!("Accept bidirectional stream failed: {}", err))
195        })?;
196
197        let (spake, token_b) = Spake2::<Ed25519Group>::start_b(
198            &Password::new(self.secret.expose_secret()),
199            &Identity::new(conn.remote_id().as_bytes()),
200            &Identity::new(self.id()?.as_bytes()),
201        );
202
203        let mut token_a = [0u8; 33];
204        recv.read_exact(&mut token_a).await.map_err(|err| {
205            error!("failed to read token_a: {}", err);
206            AuthenticatorError::AcceptFailed(format!("Failed to read token_a: {}", err))
207        })?;
208
209        send.write_all(&token_b).await.map_err(|err| {
210            error!("failed to write token_b: {}", err);
211            AuthenticatorError::AcceptFailed(format!("Failed to write token_b: {}", err))
212        })?;
213
214        let shared_secret = spake.finish(&token_a).map_err(|err| {
215            error!("SPAKE2 invalid: {}", err);
216            AuthenticatorError::AcceptFailed(format!("SPAKE2 invalid: {}", err))
217        })?;
218
219        let hk = Hkdf::<Sha512>::new(None, shared_secret.as_slice());
220        let mut accept_key = [0u8; 64];
221        let mut open_key = [0u8; 64];
222        hk.expand(Self::ACCEPT_CONTEXT, &mut accept_key)
223            .map_err(|err| {
224                error!("failed to expand accept_key: {}", err);
225                AuthenticatorError::AcceptFailed(format!("Failed to expand accept_key: {}", err))
226            })?;
227        hk.expand(Self::OPEN_CONTEXT, &mut open_key)
228            .map_err(|err| {
229                error!("failed to expand open_key: {}", err);
230                AuthenticatorError::AcceptFailed(format!("Failed to expand open_key: {}", err))
231            })?;
232
233        send.write_all(&accept_key).await.map_err(|err| {
234            error!("failed to write accept_key: {}", err);
235            AuthenticatorError::AcceptFailed(format!("Failed to write accept_key: {}", err))
236        })?;
237        let mut remote_open_key = [0u8; 64];
238        recv.read_exact(&mut remote_open_key).await.map_err(|err| {
239            error!("failed to read remote_open_key: {}", err);
240            AuthenticatorError::AcceptFailed(format!("Failed to read remote_open_key: {}", err))
241        })?;
242
243        if !bool::from(remote_open_key.ct_eq(&open_key)) {
244            error!("remote open_key mismatch");
245            return Err(AuthenticatorError::AcceptFailed(
246                "Remote open_key mismatch".to_string(),
247            ));
248        }
249
250        self.add_authenticated(conn.remote_id())?;
251        info!("authenticated connection from {}", remote_id);
252
253        Ok(())
254    }
255
256    /// Open an outgoing connection and perform SPAKE2 authentication.
257    /// On success, adds the remote ID to the authenticated set.
258    /// Returns Ok(()) on success, or an AuthenticatorError on failure.
259    async fn auth_open(&self, conn: Connection) -> Result<(), AuthenticatorError> {
260        let remote_id = conn.remote_id();
261        debug!("opening auth connection to {}", remote_id);
262        let (mut send, mut recv) = conn.open_bi().await.map_err(|err| {
263            error!("open bidirectional stream failed: {}", err);
264            AuthenticatorError::AcceptFailed(format!("Open bidirectional stream failed: {}", err))
265        })?;
266
267        let (spake, token_a) = Spake2::<Ed25519Group>::start_a(
268            &Password::new(self.secret.expose_secret()),
269            &Identity::new(self.id()?.as_bytes()),
270            &Identity::new(conn.remote_id().as_bytes()),
271        );
272
273        send.write_all(&token_a).await.map_err(|err| {
274            error!("failed to write token_a: {}", err);
275            AuthenticatorError::AcceptFailed(format!("Failed to write token_a: {}", err))
276        })?;
277
278        let mut token_b = [0u8; 33];
279        recv.read_exact(&mut token_b).await.map_err(|err| {
280            error!("failed to read token_b: {}", err);
281            AuthenticatorError::AcceptFailed(format!("Failed to read token_b: {}", err))
282        })?;
283
284        let shared_secret = spake.finish(&token_b).map_err(|err| {
285            error!("SPAKE2 invalid: {}", err);
286            AuthenticatorError::AcceptFailed(format!("SPAKE2 invalid: {}", err))
287        })?;
288
289        let hk = Hkdf::<Sha512>::new(None, shared_secret.as_slice());
290        let mut accept_key = [0u8; 64];
291        let mut open_key = [0u8; 64];
292        hk.expand(Self::ACCEPT_CONTEXT, &mut accept_key)
293            .map_err(|err| {
294                error!("failed to expand accept_key: {}", err);
295                AuthenticatorError::AcceptFailed(format!("Failed to expand accept_key: {}", err))
296            })?;
297        hk.expand(Self::OPEN_CONTEXT, &mut open_key)
298            .map_err(|err| {
299                error!("failed to expand open_key: {}", err);
300                AuthenticatorError::AcceptFailed(format!("Failed to expand open_key: {}", err))
301            })?;
302
303        let mut remote_accept_key = [0u8; 64];
304        recv.read_exact(&mut remote_accept_key)
305            .await
306            .map_err(|err| {
307                error!("failed to read remote_accept_key: {}", err);
308                AuthenticatorError::AcceptFailed(format!(
309                    "Failed to read remote_accept_key: {}",
310                    err
311                ))
312            })?;
313
314        if !bool::from(remote_accept_key.ct_eq(&accept_key)) {
315            error!("remote accept_key mismatch");
316            return Err(AuthenticatorError::AcceptFailed(
317                "Remote accept_key mismatch".to_string(),
318            ));
319        }
320
321        send.write_all(&open_key).await.map_err(|err| {
322            error!("failed to write open_key: {}", err);
323            AuthenticatorError::AcceptFailed(format!("Failed to write open_key: {}", err))
324        })?;
325        send.finish().map_err(|err| {
326            error!("failed to finish stream: {}", err);
327            AuthenticatorError::AcceptFailed(format!("Failed to finish stream: {}", err))
328        })?;
329
330        conn.closed().await;
331
332        self.add_authenticated(conn.remote_id())?;
333        info!("authenticated connection to {}", remote_id);
334
335        Ok(())
336    }
337}
338
339impl ProtocolHandler for Authenticator {
340    async fn accept(
341        &self,
342        connection: iroh::endpoint::Connection,
343    ) -> Result<(), iroh::protocol::AcceptError> {
344        if let Err(err) = self
345            .auth_accept(connection)
346            .await
347            .map_err(|err| iroh::protocol::AcceptError::from_err(err))
348        {
349            self.add_blocked().ok();
350            Err(err)
351        } else {
352            Ok(())
353        }
354    }
355}
356
357impl EndpointHooks for Authenticator {
358    async fn after_handshake<'a>(
359        &'a self,
360        conn_info: &'a iroh::endpoint::ConnectionInfo,
361    ) -> iroh::endpoint::AfterHandshakeOutcome {
362        if self.is_authenticated(&conn_info.remote_id()) {
363            debug!("already authenticated: {}", conn_info.remote_id());
364            return AfterHandshakeOutcome::accept();
365        }
366
367        if conn_info.alpn() == Self::ALPN {
368            debug!(
369                "skipping auth for connection with alpn {}",
370                String::from_utf8_lossy(conn_info.alpn())
371            );
372            return AfterHandshakeOutcome::accept();
373        }
374
375        let remote_id = conn_info.remote_id();
376        let counter = self.watcher.get();
377
378        let wait_for_auth = async {
379            let mut stream = self.watcher.watch().stream();
380            while let Some(next_counter) = stream.next().await {
381                if next_counter != counter && self.is_authenticated(&remote_id) {
382                    return Ok(()) as Result<(), AuthenticatorError>;
383                }
384            }
385            Err(AuthenticatorError::AcceptFailed(
386                "Watcher stream ended unexpectedly".to_string(),
387            ))
388        };
389
390        match timeout(Duration::from_secs(10), wait_for_auth).await {
391            Ok(_) => AfterHandshakeOutcome::accept(),
392            Err(_) => {
393                warn!("authentication timed out for {}", remote_id);
394                AfterHandshakeOutcome::Reject {
395                    error_code: VarInt::from_u32(401),
396                    reason: b"Authentication timed out".to_vec(),
397                }
398            }
399        }
400    }
401
402    async fn before_connect<'a>(
403        &'a self,
404        remote_addr: &'a iroh::EndpointAddr,
405        alpn: &'a [u8],
406    ) -> iroh::endpoint::BeforeConnectOutcome {
407        if self.is_authenticated(&remote_addr.id) {
408            debug!("already authenticated: {}", remote_addr.id);
409            return iroh::endpoint::BeforeConnectOutcome::Accept;
410        }
411
412        if alpn == Self::ALPN {
413            debug!(
414                "skipping auth for connection to {} with alpn {:?}",
415                remote_addr.id, alpn
416            );
417            return iroh::endpoint::BeforeConnectOutcome::Accept;
418        }
419
420        debug!(
421            "initiating auth for client connection with alpn {} to {}",
422            String::from_utf8_lossy(alpn),
423            remote_addr.id
424        );
425        let endpoint = match self.endpoint() {
426            Ok(ep) => ep,
427            Err(_) => {
428                warn!("authenticator endpoint not set");
429                return iroh::endpoint::BeforeConnectOutcome::Reject;
430            }
431        };
432        spawn({
433            let auth = self.clone();
434            let remote_id = remote_addr.id;
435
436            async move {
437                debug!("background: connecting to {} for auth", remote_id);
438
439                match endpoint.connect(remote_id, Self::ALPN).await {
440                    Ok(conn) => {
441                        debug!("background: connected to {}, performing auth", remote_id);
442                        if let Err(err) = auth.auth_open(conn).await {
443                            auth.add_blocked().ok();
444                            warn!(
445                                "background: authentication failed for {}: {}",
446                                remote_id, err
447                            );
448                        } else {
449                            debug!("background: authentication successful for {}", remote_id);
450                        }
451                    }
452                    Err(e) => {
453                        warn!(
454                            "background: failed to open connection for authentication to {}: {}",
455                            remote_id, e
456                        );
457                    }
458                };
459            }
460        });
461        iroh::endpoint::BeforeConnectOutcome::Accept
462    }
463}
464
#[cfg(test)]
mod tests {
    use iroh::Watcher;

    use super::*;

    /// Both SPAKE2 roles, given the same password and identities, must emit
    /// different first messages yet agree on the final shared key.
    #[test]
    fn test_token_different() {
        let password = Password::new(b"testpassword");
        let opener = Identity::new(b"identityA");
        let acceptor = Identity::new(b"identityB");

        let (side_a, msg_a) = Spake2::<Ed25519Group>::start_a(&password, &opener, &acceptor);
        let (side_b, msg_b) = Spake2::<Ed25519Group>::start_b(&password, &opener, &acceptor);

        assert_ne!(msg_a, msg_b);

        let shared_a = side_a.finish(&msg_b).unwrap();
        let shared_b = side_b.finish(&msg_a).unwrap();
        assert_eq!(shared_a, shared_b);
    }

    /// Protocol that accepts every connection and does nothing with it.
    #[derive(Debug, Clone)]
    struct DummyProtocol;
    impl ProtocolHandler for DummyProtocol {
        async fn accept(&self, _conn: Connection) -> Result<(), iroh::protocol::AcceptError> {
            Ok(())
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_success() {
        let secret = b"supersecrettoken1234567890123456";
        assert!(run_auth_test(secret, secret).await.unwrap());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_failure() {
        let secret_a = b"supersecrettoken1234567890123456";
        let secret_b = b"differentsecrettoken123456789012";
        assert!(!run_auth_test(secret_a, secret_b).await.unwrap());
    }

    /// Binds an endpoint hooked by `auth` and registers it with `auth`.
    async fn bind_endpoint(auth: &Authenticator) -> Result<iroh::Endpoint, String> {
        let endpoint = iroh::Endpoint::builder()
            .hooks(auth.clone())
            .bind()
            .await
            .map_err(|e| e.to_string())?;
        auth.set_endpoint(&endpoint);
        Ok(endpoint)
    }

    /// Starts a router serving both the auth ALPN and the dummy protocol.
    fn spawn_router(endpoint: &iroh::Endpoint, auth: &Authenticator) -> iroh::protocol::Router {
        iroh::protocol::Router::builder(endpoint.clone())
            .accept(Authenticator::ALPN, auth.clone())
            .accept(b"/dummy/1", DummyProtocol)
            .spawn()
    }

    /// Resolves once `auth` has recorded at least one success or failure.
    async fn wait_for_outcome(auth: &Authenticator) {
        use n0_future::StreamExt;

        let mut updates = auth.watcher.watch().stream();
        while let Some(counter) = updates.next().await {
            if counter.authenticated >= 1 || counter.blocked >= 1 {
                break;
            }
        }
    }

    /// Spins up two nodes with the given secrets, dials the dummy protocol
    /// from A to B, and reports whether both sides ended up authenticating
    /// each other. Errors if neither outcome is recorded within 20 seconds.
    async fn run_auth_test(
        secret_a: &'static [u8],
        secret_b: &'static [u8],
    ) -> Result<bool, String> {
        let auth_a = Authenticator::new(secret_a);
        let endpoint_a = bind_endpoint(&auth_a).await?;

        let auth_b = Authenticator::new(secret_b);
        let endpoint_b = bind_endpoint(&auth_b).await?;

        let router_a = spawn_router(&endpoint_a, &auth_a);
        let router_b = spawn_router(&endpoint_b, &auth_b);

        spawn({
            let dialer = endpoint_a.clone();
            let target = endpoint_b.clone();
            async move {
                dialer.connect(target.addr(), b"/dummy/1").await.ok();
            }
        });

        let outcome = timeout(Duration::from_secs(20), async {
            tokio::join!(wait_for_outcome(&auth_a), wait_for_outcome(&auth_b));
        })
        .await;

        router_a.shutdown().await.ok();
        router_b.shutdown().await.ok();

        if outcome.is_err() {
            return Err("Authentication did not complete in time".to_string());
        }

        Ok(auth_a.is_authenticated(&endpoint_b.id()) && auth_b.is_authenticated(&endpoint_a.id()))
    }

    /// Every `IntoSecret` impl must yield exactly the input bytes.
    #[test]
    fn test_into_secret_impls() {
        use secrecy::ExposeSecret;

        const EXPECTED: &[u8; 13] = b"my-secret-key";
        let check = |secret: SecretSlice<u8>| assert_eq!(secret.expose_secret(), EXPECTED);

        check("my-secret-key".into_secret());
        check(String::from("my-secret-key").into_secret());
        check(EXPECTED.to_vec().into_secret());
        check((&EXPECTED[..]).into_secret());
        check(EXPECTED.into_secret());
        check(Box::<[u8]>::from(&EXPECTED[..]).into_secret());
        check(SecretSlice::new(Box::new(*EXPECTED)).into_secret());
    }
}