Skip to main content

iroh_auth/
lib.rs

1use n0_watcher::Watchable;
2use std::{
3    collections::BTreeSet,
4    sync::{Arc, Mutex},
5    time::Duration,
6};
7use tracing::{trace, debug, error, info, warn};
8
9use hkdf::Hkdf;
10use iroh::{
11    endpoint::{AfterHandshakeOutcome, Connection, EndpointHooks, VarInt},
12    protocol::ProtocolHandler,
13    Endpoint, PublicKey, Watcher,
14};
15use n0_future::{task::spawn, time::timeout, StreamExt};
16use secrecy::{ExposeSecret, SecretSlice};
17use sha2::Sha512;
18use spake2::{Ed25519Group, Identity, Password, Spake2};
19use subtle::ConstantTimeEq;
20
21// Errors
/// Errors returned by [`Authenticator`] operations.
#[derive(Debug)]
pub enum AuthenticatorError {
    /// Recording an authentication result failed (the set lock was poisoned
    /// or the counter update was rejected).
    AddFailed,
    /// A handshake step failed; rendered as `Accept failed: <detail>`.
    AcceptFailed(String),
    /// A failure attributed to the opening side; rendered as `Open failed: <detail>`.
    OpenFailed(String),
    /// No endpoint has been registered through [`Authenticator::set_endpoint`].
    EndpointNotSet,
}

impl std::fmt::Display for AuthenticatorError {
    /// Writes the human-readable message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AddFailed => f.write_str("Failed to add authenticated ID"),
            Self::AcceptFailed(msg) => write!(f, "Accept failed: {msg}"),
            Self::OpenFailed(msg) => write!(f, "Open failed: {msg}"),
            // The hint names `set_endpoint`, the method that actually
            // registers the endpoint on an `Authenticator`.
            Self::EndpointNotSet => f.write_str(
                "Authenticator endpoint not set: missing authenticator.set_endpoint(endpoint)",
            ),
        }
    }
}

impl std::error::Error for AuthenticatorError {}
45
/// Conversion of secret material into a [`SecretSlice<u8>`].
///
/// [`Authenticator::new`] accepts any type that implements this trait.
pub trait IntoSecret {
    /// Consumes `self` and returns its bytes wrapped in a [`SecretSlice<u8>`].
    fn into_secret(self) -> SecretSlice<u8>;
}
49
50impl IntoSecret for SecretSlice<u8> {
51    fn into_secret(self) -> SecretSlice<u8> {
52        self
53    }
54}
55
56impl IntoSecret for String {
57    fn into_secret(self) -> SecretSlice<u8> {
58        SecretSlice::new(self.into_bytes().into_boxed_slice())
59    }
60}
61
62impl IntoSecret for &str {
63    fn into_secret(self) -> SecretSlice<u8> {
64        SecretSlice::new(self.as_bytes().to_vec().into_boxed_slice())
65    }
66}
67
68impl IntoSecret for Vec<u8> {
69    fn into_secret(self) -> SecretSlice<u8> {
70        SecretSlice::new(self.into_boxed_slice())
71    }
72}
73
74impl IntoSecret for &[u8] {
75    fn into_secret(self) -> SecretSlice<u8> {
76        SecretSlice::new(self.to_vec().into_boxed_slice())
77    }
78}
79
80impl<const N: usize> IntoSecret for &[u8; N] {
81    fn into_secret(self) -> SecretSlice<u8> {
82        SecretSlice::new(self.as_slice().to_vec().into_boxed_slice())
83    }
84}
85
86impl IntoSecret for Box<[u8]> {
87    fn into_secret(self) -> SecretSlice<u8> {
88        SecretSlice::new(self)
89    }
90}
91
/// Running totals published through the authenticator's `Watchable`, so that
/// tasks can wait for the next authentication outcome.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct WatchableCounter {
    /// Incremented each time a peer is recorded as authenticated.
    authenticated: usize,
    /// Incremented each time a failed authentication attempt is recorded.
    blocked: usize,
}
97
/// Shared-secret authenticator for iroh endpoints.
///
/// It is used both as a `ProtocolHandler` (serving the auth ALPN) and as
/// `EndpointHooks` (gating other connections on a completed handshake).
/// Clones share the authenticated set and the endpoint slot through `Arc`.
#[derive(Debug, Clone)]
pub struct Authenticator {
    /// Shared secret used as the SPAKE2 password.
    secret: SecretSlice<u8>,
    /// Ids of peers that completed the handshake successfully.
    authenticated: Arc<Mutex<BTreeSet<PublicKey>>>,
    /// Publishes authenticated/blocked counters; watched to learn about new outcomes.
    watcher: Watchable<WatchableCounter>,
    /// Endpoint registered via `set_endpoint`; `None` until then.
    endpoint: Arc<Mutex<Option<iroh::Endpoint>>>,
}
105
106impl Authenticator {
    /// ALPN identifier of the authentication protocol itself.
    pub const ALPN: &'static [u8] = b"/iroh/auth/0.1";
    /// HKDF info string used to derive the accept-side confirmation key.
    const ACCEPT_CONTEXT: &'static [u8] = b"iroh-auth-accept";
    /// HKDF info string used to derive the open-side confirmation key.
    const OPEN_CONTEXT: &'static [u8] = b"iroh-auth-open";
110
111    pub fn new<S: IntoSecret>(secret: S) -> Self {
112        Self {
113            secret: secret.into_secret(),
114            authenticated: Arc::new(Mutex::new(BTreeSet::new())),
115            watcher: Watchable::new(WatchableCounter::default()),
116            endpoint: Arc::new(Mutex::new(None)),
117        }
118    }
119
120    pub fn set_endpoint(&self, endpoint: &Endpoint) {
121        if let Ok(mut guard) = self.endpoint.lock() {
122            if guard.is_none() {
123                *guard = Some(endpoint.clone());
124                trace!("Authenticator endpoint set to {}", endpoint.id());
125            }
126        }
127    }
128
129    fn id(&self) -> Result<PublicKey, AuthenticatorError> {
130        self.endpoint
131            .lock()
132            .map_err(|_| AuthenticatorError::EndpointNotSet)?
133            .as_ref()
134            .map(|ep| ep.id())
135            .ok_or(AuthenticatorError::EndpointNotSet)
136    }
137
138    fn endpoint(&self) -> Result<iroh::Endpoint, AuthenticatorError> {
139        self.endpoint
140            .lock()
141            .map_err(|_| AuthenticatorError::EndpointNotSet)?
142            .as_ref()
143            .cloned()
144            .ok_or(AuthenticatorError::EndpointNotSet)
145    }
146
147    pub fn is_authenticated(&self, id: &PublicKey) -> bool {
148        self.authenticated
149            .lock()
150            .map(|set| set.contains(id))
151            .unwrap_or(false)
152    }
153
154    fn add_authenticated(&self, id: PublicKey) -> Result<(), AuthenticatorError> {
155        self.authenticated
156            .lock()
157            .map_err(|_| AuthenticatorError::AddFailed)?
158            .insert(id);
159        let mut counter = self.watcher.get();
160        counter.authenticated += 1;
161        self.watcher
162            .set(counter)
163            .map_err(|_| AuthenticatorError::AddFailed)?;
164        Ok(())
165    }
166
167    fn add_blocked(&self) -> Result<(), AuthenticatorError> {
168        let mut counter = self.watcher.get();
169        counter.blocked += 1;
170        self.watcher
171            .set(counter)
172            .map_err(|_| AuthenticatorError::AddFailed)?;
173        Ok(())
174    }
175
176    #[doc(hidden)]
177    pub fn list_authenticated(&self) -> Vec<PublicKey> {
178        self.authenticated
179            .lock()
180            .map(|set| set.iter().cloned().collect())
181            .unwrap_or_default()
182    }
183
184    /// Accept an incoming connection and perform SPAKE2 authentication.
185    /// On success, adds the remote ID to the authenticated set.
186    /// Returns Ok(()) on success, or an AuthenticatorError on failure.
187    pub async fn auth_accept(&self, conn: Connection) -> Result<(), AuthenticatorError> {
188        let remote_id = conn.remote_id();
189        debug!("accepting auth connection from {}", remote_id);
190        let (mut send, mut recv) = conn.accept_bi().await.map_err(|err| {
191            error!("accept bidirectional stream failed: {}", err);
192            AuthenticatorError::AcceptFailed(format!("Accept bidirectional stream failed: {}", err))
193        })?;
194
195        let (spake, token_b) = Spake2::<Ed25519Group>::start_b(
196            &Password::new(self.secret.expose_secret()),
197            &Identity::new(conn.remote_id().as_bytes()),
198            &Identity::new(self.id()?.as_bytes()),
199        );
200
201        let mut token_a = [0u8; 33];
202        recv.read_exact(&mut token_a).await.map_err(|err| {
203            error!("failed to read token_a: {}", err);
204            AuthenticatorError::AcceptFailed(format!("Failed to read token_a: {}", err))
205        })?;
206
207        send.write_all(&token_b).await.map_err(|err| {
208            error!("failed to write token_b: {}", err);
209            AuthenticatorError::AcceptFailed(format!("Failed to write token_b: {}", err))
210        })?;
211
212        let shared_secret = spake.finish(&token_a).map_err(|err| {
213            error!("SPAKE2 invalid: {}", err);
214            AuthenticatorError::AcceptFailed(format!("SPAKE2 invalid: {}", err))
215        })?;
216
217        let hk = Hkdf::<Sha512>::new(None, shared_secret.as_slice());
218        let mut accept_key = [0u8; 64];
219        let mut open_key = [0u8; 64];
220        hk.expand(Self::ACCEPT_CONTEXT, &mut accept_key)
221            .map_err(|err| {
222                error!("failed to expand accept_key: {}", err);
223                AuthenticatorError::AcceptFailed(format!("Failed to expand accept_key: {}", err))
224            })?;
225        hk.expand(Self::OPEN_CONTEXT, &mut open_key)
226            .map_err(|err| {
227                error!("failed to expand open_key: {}", err);
228                AuthenticatorError::AcceptFailed(format!("Failed to expand open_key: {}", err))
229            })?;
230
231        send.write_all(&accept_key).await.map_err(|err| {
232            error!("failed to write accept_key: {}", err);
233            AuthenticatorError::AcceptFailed(format!("Failed to write accept_key: {}", err))
234        })?;
235        let mut remote_open_key = [0u8; 64];
236        recv.read_exact(&mut remote_open_key).await.map_err(|err| {
237            error!("failed to read remote_open_key: {}", err);
238            AuthenticatorError::AcceptFailed(format!("Failed to read remote_open_key: {}", err))
239        })?;
240
241        if !bool::from(remote_open_key.ct_eq(&open_key)) {
242            error!("remote open_key mismatch");
243            return Err(AuthenticatorError::AcceptFailed(
244                "Remote open_key mismatch".to_string(),
245            ));
246        }
247
248        self.add_authenticated(conn.remote_id())?;
249        info!("authenticated connection from {}", remote_id);
250
251        Ok(())
252    }
253
254    /// Open an outgoing connection and perform SPAKE2 authentication.
255    /// On success, adds the remote ID to the authenticated set.
256    /// Returns Ok(()) on success, or an AuthenticatorError on failure.
257    pub async fn auth_open(&self, conn: Connection) -> Result<(), AuthenticatorError> {
258        let remote_id = conn.remote_id();
259        debug!("opening auth connection to {}", remote_id);
260        let (mut send, mut recv) = conn.open_bi().await.map_err(|err| {
261            error!("open bidirectional stream failed: {}", err);
262            AuthenticatorError::AcceptFailed(format!("Open bidirectional stream failed: {}", err))
263        })?;
264
265        let (spake, token_a) = Spake2::<Ed25519Group>::start_a(
266            &Password::new(self.secret.expose_secret()),
267            &Identity::new(self.id()?.as_bytes()),
268            &Identity::new(conn.remote_id().as_bytes()),
269        );
270
271        send.write_all(&token_a).await.map_err(|err| {
272            error!("failed to write token_a: {}", err);
273            AuthenticatorError::AcceptFailed(format!("Failed to write token_a: {}", err))
274        })?;
275
276        let mut token_b = [0u8; 33];
277        recv.read_exact(&mut token_b).await.map_err(|err| {
278            error!("failed to read token_b: {}", err);
279            AuthenticatorError::AcceptFailed(format!("Failed to read token_b: {}", err))
280        })?;
281
282        let shared_secret = spake.finish(&token_b).map_err(|err| {
283            error!("SPAKE2 invalid: {}", err);
284            AuthenticatorError::AcceptFailed(format!("SPAKE2 invalid: {}", err))
285        })?;
286
287        let hk = Hkdf::<Sha512>::new(None, shared_secret.as_slice());
288        let mut accept_key = [0u8; 64];
289        let mut open_key = [0u8; 64];
290        hk.expand(Self::ACCEPT_CONTEXT, &mut accept_key)
291            .map_err(|err| {
292                error!("failed to expand accept_key: {}", err);
293                AuthenticatorError::AcceptFailed(format!("Failed to expand accept_key: {}", err))
294            })?;
295        hk.expand(Self::OPEN_CONTEXT, &mut open_key)
296            .map_err(|err| {
297                error!("failed to expand open_key: {}", err);
298                AuthenticatorError::AcceptFailed(format!("Failed to expand open_key: {}", err))
299            })?;
300
301        let mut remote_accept_key = [0u8; 64];
302        recv.read_exact(&mut remote_accept_key)
303            .await
304            .map_err(|err| {
305                error!("failed to read remote_accept_key: {}", err);
306                AuthenticatorError::AcceptFailed(format!(
307                    "Failed to read remote_accept_key: {}",
308                    err
309                ))
310            })?;
311
312        if !bool::from(remote_accept_key.ct_eq(&accept_key)) {
313            error!("remote accept_key mismatch");
314            return Err(AuthenticatorError::AcceptFailed(
315                "Remote accept_key mismatch".to_string(),
316            ));
317        }
318
319        send.write_all(&open_key).await.map_err(|err| {
320            error!("failed to write open_key: {}", err);
321            AuthenticatorError::AcceptFailed(format!("Failed to write open_key: {}", err))
322        })?;
323        send.finish().map_err(|err| {
324            error!("failed to finish stream: {}", err);
325            AuthenticatorError::AcceptFailed(format!("Failed to finish stream: {}", err))
326        })?;
327
328        conn.closed().await;
329
330        self.add_authenticated(conn.remote_id())?;
331        info!("authenticated connection to {}", remote_id);
332
333        Ok(())
334    }
335}
336
337impl ProtocolHandler for Authenticator {
338    async fn accept(
339        &self,
340        connection: iroh::endpoint::Connection,
341    ) -> Result<(), iroh::protocol::AcceptError> {
342        if let Err(err) = self
343            .auth_accept(connection)
344            .await
345            .map_err(|err| iroh::protocol::AcceptError::from_err(err))
346        {
347            self.add_blocked().ok();
348            Err(err)
349        } else {
350            Ok(())
351        }
352    }
353}
354
355impl EndpointHooks for Authenticator {
356    async fn after_handshake<'a>(
357        &'a self,
358        conn_info: &'a iroh::endpoint::ConnectionInfo,
359    ) -> iroh::endpoint::AfterHandshakeOutcome {
360        if self.is_authenticated(&conn_info.remote_id()) {
361            debug!("already authenticated: {}", conn_info.remote_id());
362            return AfterHandshakeOutcome::accept();
363        }
364
365        if conn_info.alpn() == Self::ALPN {
366            debug!(
367                "skipping auth for connection with alpn {}",
368                String::from_utf8_lossy(conn_info.alpn())
369            );
370            return AfterHandshakeOutcome::accept();
371        }
372
373        let remote_id = conn_info.remote_id();
374        let counter = self.watcher.get();
375
376        let wait_for_auth = async {
377            let mut stream = self.watcher.watch().stream();
378            while let Some(next_counter) = stream.next().await {
379                if next_counter != counter && self.is_authenticated(&remote_id) {
380                    return Ok(()) as Result<(), AuthenticatorError>;
381                }
382            }
383            Err(AuthenticatorError::AcceptFailed(
384                "Watcher stream ended unexpectedly".to_string(),
385            ))
386        };
387
388        match timeout(Duration::from_secs(10), wait_for_auth).await {
389            Ok(_) => AfterHandshakeOutcome::accept(),
390            Err(_) => {
391                warn!("authentication timed out for {}", remote_id);
392                AfterHandshakeOutcome::Reject {
393                    error_code: VarInt::from_u32(401),
394                    reason: b"Authentication timed out".to_vec(),
395                }
396            }
397        }
398    }
399
400    async fn before_connect<'a>(
401        &'a self,
402        remote_addr: &'a iroh::EndpointAddr,
403        alpn: &'a [u8],
404    ) -> iroh::endpoint::BeforeConnectOutcome {
405        if self.is_authenticated(&remote_addr.id) {
406            debug!("already authenticated: {}", remote_addr.id);
407            return iroh::endpoint::BeforeConnectOutcome::Accept;
408        }
409
410        if alpn == Self::ALPN {
411            debug!(
412                "skipping auth for connection to {} with alpn {:?}",
413                remote_addr.id, alpn
414            );
415            return iroh::endpoint::BeforeConnectOutcome::Accept;
416        }
417
418        debug!(
419            "initiating auth for client connection with alpn {} to {}",
420            String::from_utf8_lossy(alpn),
421            remote_addr.id
422        );
423        let endpoint = match self.endpoint() {
424            Ok(ep) => ep,
425            Err(_) => {
426                warn!("authenticator endpoint not set");
427                return iroh::endpoint::BeforeConnectOutcome::Reject;
428            }
429        };
430        spawn({
431            let auth = self.clone();
432            let remote_id = remote_addr.id;
433
434            async move {
435                debug!("background: connecting to {} for auth", remote_id);
436
437                match endpoint.connect(remote_id, Self::ALPN).await {
438                    Ok(conn) => {
439                        debug!("background: connected to {}, performing auth", remote_id);
440                        if let Err(err) = auth.auth_open(conn).await {
441                            auth.add_blocked().ok();
442                            warn!(
443                                "background: authentication failed for {}: {}",
444                                remote_id, err
445                            );
446                        } else {
447                            debug!("background: authentication successful for {}", remote_id);
448                        }
449                    }
450                    Err(e) => {
451                        warn!(
452                            "background: failed to open connection for authentication to {}: {}",
453                            remote_id, e
454                        );
455                    }
456                };
457            }
458        });
459        iroh::endpoint::BeforeConnectOutcome::Accept
460    }
461}
462
463#[cfg(test)]
464mod tests {
465    use iroh::Watcher;
466
467    use super::*;
468    #[test]
469    fn test_token_different() {
470        let password = b"testpassword";
471        let id_a = b"identityA";
472        let id_b = b"identityB";
473
474        let (spake_a, token_a) = Spake2::<Ed25519Group>::start_a(
475            &Password::new(password),
476            &Identity::new(id_a),
477            &Identity::new(id_b),
478        );
479
480        let (spake_b, token_b) = Spake2::<Ed25519Group>::start_b(
481            &Password::new(password),
482            &Identity::new(id_a),
483            &Identity::new(id_b),
484        );
485
486        assert_ne!(token_a, token_b);
487
488        let key_a = spake_a.finish(&token_b).unwrap();
489        let key_b = spake_b.finish(&token_a).unwrap();
490
491        assert_eq!(key_a, key_b);
492    }
493
    /// Stand-in application protocol: accepts every connection and does
    /// nothing with it.
    #[derive(Debug, Clone)]
    struct DummyProtocol;
    impl ProtocolHandler for DummyProtocol {
        // Unconditionally succeeds; the connection itself is ignored.
        async fn accept(&self, _conn: Connection) -> Result<(), iroh::protocol::AcceptError> {
            Ok(())
        }
    }
501
502    #[tokio::test(flavor = "multi_thread")]
503    async fn test_auth_success() {
504        let secret = b"supersecrettoken1234567890123456";
505        assert!(run_auth_test(secret, secret).await.unwrap());
506    }
507
508    #[tokio::test(flavor = "multi_thread")]
509    async fn test_auth_failure() {
510        let secret_a = b"supersecrettoken1234567890123456";
511        let secret_b = b"differentsecrettoken123456789012";
512        assert!(!run_auth_test(secret_a, secret_b).await.unwrap());
513    }
514
515    async fn run_auth_test(
516        secret_a: &'static [u8],
517        secret_b: &'static [u8],
518    ) -> Result<bool, String> {
519
520        let auth_a = Authenticator::new(secret_a);
521        let endpoint_a = iroh::Endpoint::builder()
522            .hooks(auth_a.clone())
523            .bind()
524            .await
525            .map_err(|e| e.to_string())?;
526        auth_a.set_endpoint(&endpoint_a);
527        let dummy_a = DummyProtocol;
528
529        let auth_b = Authenticator::new(secret_b);
530        let endpoint_b = iroh::Endpoint::builder()
531            .hooks(auth_b.clone())
532            .bind()
533            .await
534            .map_err(|e| e.to_string())?;
535        auth_b.set_endpoint(&endpoint_b);
536        let dummy_b = DummyProtocol;
537
538        let router_a = iroh::protocol::Router::builder(endpoint_a.clone())
539            .accept(Authenticator::ALPN, auth_a.clone())
540            .accept(b"/dummy/1", dummy_a)
541            .spawn();
542
543        let router_b = iroh::protocol::Router::builder(endpoint_b.clone())
544            .accept(Authenticator::ALPN, auth_b.clone())
545            .accept(b"/dummy/1", dummy_b)
546            .spawn();
547
548        spawn({
549            let endpoint_a = endpoint_a.clone();
550            let endpoint_b = endpoint_b.clone();
551            async move {
552                endpoint_a
553                    .connect(endpoint_b.addr(), b"/dummy/1")
554                    .await
555                    .ok();
556            }
557        });
558
559        let wait_loop = async {
560            use n0_future::StreamExt;
561
562            let wait_a = async {
563                let mut stream = auth_a.watcher.watch().stream();
564                while let Some(counter) = stream.next().await {
565                    if counter.authenticated >= 1 || counter.blocked >= 1 {
566                        break;
567                    }
568                }
569            };
570            let wait_b = async {
571                let mut stream = auth_b.watcher.watch().stream();
572                while let Some(counter) = stream.next().await {
573                    if counter.authenticated >= 1 || counter.blocked >= 1 {
574                        break;
575                    }
576                }
577            };
578            tokio::join!(wait_a, wait_b);
579        };
580
581        if timeout(Duration::from_secs(20), wait_loop).await.is_err() {
582            router_a.shutdown().await.ok();
583            router_b.shutdown().await.ok();
584            return Err("Authentication did not complete in time".to_string());
585        }
586
587        router_a.shutdown().await.ok();
588        router_b.shutdown().await.ok();
589
590        Ok(auth_a.is_authenticated(&endpoint_b.id()) && auth_b.is_authenticated(&endpoint_a.id()))
591    }
592
593    #[test]
594    fn test_into_secret_impls() {
595        use secrecy::ExposeSecret;
596
597        let expected_bytes = b"my-secret-key";
598
599        // &str
600        let secret = "my-secret-key".into_secret();
601        assert_eq!(secret.expose_secret(), expected_bytes);
602
603        // String
604        let secret = String::from("my-secret-key").into_secret();
605        assert_eq!(secret.expose_secret(), expected_bytes);
606        // Vec<u8>
607        let secret = b"my-secret-key".to_vec().into_secret();
608        assert_eq!(secret.expose_secret(), expected_bytes);
609
610        // &[u8]
611        let bytes: &[u8] = b"my-secret-key";
612        let secret = bytes.into_secret();
613        assert_eq!(secret.expose_secret(), expected_bytes);
614
615        // &[u8; N]
616        let bytes: &[u8; 13] = b"my-secret-key";
617        let secret = bytes.into_secret();
618        assert_eq!(secret.expose_secret(), expected_bytes);
619
620        // Box<[u8]>
621        let bytes: Box<[u8]> = Box::new(*b"my-secret-key");
622        let secret = bytes.into_secret();
623        assert_eq!(secret.expose_secret(), expected_bytes);
624
625        // SecretSlice<u8>
626        let ps = SecretSlice::new(Box::new(*b"my-secret-key"));
627        let secret = ps.into_secret();
628        assert_eq!(secret.expose_secret(), expected_bytes);
629    }
630}