1use n0_watcher::Watchable;
2use std::{
3 collections::BTreeSet,
4 sync::{Arc, Mutex},
5 time::Duration,
6};
7use tracing::{trace, debug, error, info, warn};
8
9use hkdf::Hkdf;
10use iroh::{
11 endpoint::{AfterHandshakeOutcome, Connection, EndpointHooks, VarInt},
12 protocol::ProtocolHandler,
13 Endpoint, PublicKey, Watcher,
14};
15use n0_future::{task::spawn, time::timeout, StreamExt};
16use secrecy::{ExposeSecret, SecretSlice};
17use sha2::Sha512;
18use spake2::{Ed25519Group, Identity, Password, Spake2};
19use subtle::ConstantTimeEq;
20
/// Errors produced while authenticating a peer or recording the outcome.
#[derive(Debug)]
pub enum AuthenticatorError {
    /// The authenticated-peer set or the outcome counter could not be updated.
    AddFailed,
    /// The accepting side of the handshake failed; the payload describes the step.
    AcceptFailed(String),
    /// The opening side of the handshake failed; the payload describes the step.
    OpenFailed(String),
    /// No endpoint has been installed with `Authenticator::set_endpoint` yet.
    EndpointNotSet,
}

impl std::fmt::Display for AuthenticatorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AddFailed => f.write_str("Failed to add authenticated ID"),
            Self::AcceptFailed(msg) => write!(f, "Accept failed: {}", msg),
            Self::OpenFailed(msg) => write!(f, "Open failed: {}", msg),
            // The hint names `set_endpoint`, which is the method that actually
            // installs the endpoint; there is no `start` method on the authenticator.
            Self::EndpointNotSet => f.write_str(
                "Authenticator endpoint not set: missing authenticator.set_endpoint(endpoint)",
            ),
        }
    }
}

impl std::error::Error for AuthenticatorError {}
45
46pub trait IntoSecret {
47 fn into_secret(self) -> SecretSlice<u8>;
48}
49
50impl IntoSecret for SecretSlice<u8> {
51 fn into_secret(self) -> SecretSlice<u8> {
52 self
53 }
54}
55
56impl IntoSecret for String {
57 fn into_secret(self) -> SecretSlice<u8> {
58 SecretSlice::new(self.into_bytes().into_boxed_slice())
59 }
60}
61
62impl IntoSecret for &str {
63 fn into_secret(self) -> SecretSlice<u8> {
64 SecretSlice::new(self.as_bytes().to_vec().into_boxed_slice())
65 }
66}
67
68impl IntoSecret for Vec<u8> {
69 fn into_secret(self) -> SecretSlice<u8> {
70 SecretSlice::new(self.into_boxed_slice())
71 }
72}
73
74impl IntoSecret for &[u8] {
75 fn into_secret(self) -> SecretSlice<u8> {
76 SecretSlice::new(self.to_vec().into_boxed_slice())
77 }
78}
79
80impl<const N: usize> IntoSecret for &[u8; N] {
81 fn into_secret(self) -> SecretSlice<u8> {
82 SecretSlice::new(self.as_slice().to_vec().into_boxed_slice())
83 }
84}
85
86impl IntoSecret for Box<[u8]> {
87 fn into_secret(self) -> SecretSlice<u8> {
88 SecretSlice::new(self)
89 }
90}
91
/// Outcome counters published through the authenticator's watchable.
///
/// Both fields start at zero and are only ever incremented by the
/// authenticator.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct WatchableCounter {
    /// Incremented each time a peer is recorded as authenticated.
    authenticated: usize,
    /// Incremented each time an authentication attempt is recorded as failed.
    blocked: usize,
}
97
98#[derive(Debug, Clone)]
99pub struct Authenticator {
100 secret: SecretSlice<u8>,
101 authenticated: Arc<Mutex<BTreeSet<PublicKey>>>,
102 watcher: Watchable<WatchableCounter>,
103 endpoint: Arc<Mutex<Option<iroh::Endpoint>>>,
104}
105
/// ALPN protocol identifier under which the authentication handshake runs.
pub const ALPN: &[u8] = b"/iroh/auth/0.1";
107
108impl Authenticator {
109 pub const ALPN: &'static [u8] = ALPN;
110 const ACCEPT_CONTEXT: &'static [u8] = b"iroh-auth-accept";
111 const OPEN_CONTEXT: &'static [u8] = b"iroh-auth-open";
112
113 pub fn new<S: IntoSecret>(secret: S) -> Self {
114 Self {
115 secret: secret.into_secret(),
116 authenticated: Arc::new(Mutex::new(BTreeSet::new())),
117 watcher: Watchable::new(WatchableCounter::default()),
118 endpoint: Arc::new(Mutex::new(None)),
119 }
120 }
121
122 pub fn set_endpoint(&self, endpoint: &Endpoint) {
123 if let Ok(mut guard) = self.endpoint.lock() {
124 if guard.is_none() {
125 *guard = Some(endpoint.clone());
126 trace!("Authenticator endpoint set to {}", endpoint.id());
127 }
128 }
129 }
130
131 fn id(&self) -> Result<PublicKey, AuthenticatorError> {
132 self.endpoint
133 .lock()
134 .map_err(|_| AuthenticatorError::EndpointNotSet)?
135 .as_ref()
136 .map(|ep| ep.id())
137 .ok_or(AuthenticatorError::EndpointNotSet)
138 }
139
140 fn endpoint(&self) -> Result<iroh::Endpoint, AuthenticatorError> {
141 self.endpoint
142 .lock()
143 .map_err(|_| AuthenticatorError::EndpointNotSet)?
144 .as_ref()
145 .cloned()
146 .ok_or(AuthenticatorError::EndpointNotSet)
147 }
148
149 fn is_authenticated(&self, id: &PublicKey) -> bool {
150 self.authenticated
151 .lock()
152 .map(|set| set.contains(id))
153 .unwrap_or(false)
154 }
155
156 fn add_authenticated(&self, id: PublicKey) -> Result<(), AuthenticatorError> {
157 self.authenticated
158 .lock()
159 .map_err(|_| AuthenticatorError::AddFailed)?
160 .insert(id);
161 let mut counter = self.watcher.get();
162 counter.authenticated += 1;
163 self.watcher
164 .set(counter)
165 .map_err(|_| AuthenticatorError::AddFailed)?;
166 Ok(())
167 }
168
169 fn add_blocked(&self) -> Result<(), AuthenticatorError> {
170 let mut counter = self.watcher.get();
171 counter.blocked += 1;
172 self.watcher
173 .set(counter)
174 .map_err(|_| AuthenticatorError::AddFailed)?;
175 Ok(())
176 }
177
178 #[doc(hidden)]
179 pub fn list_authenticated(&self) -> Vec<PublicKey> {
180 self.authenticated
181 .lock()
182 .map(|set| set.iter().cloned().collect())
183 .unwrap_or_default()
184 }
185
186 async fn auth_accept(&self, conn: Connection) -> Result<(), AuthenticatorError> {
190 let remote_id = conn.remote_id();
191 debug!("accepting auth connection from {}", remote_id);
192 let (mut send, mut recv) = conn.accept_bi().await.map_err(|err| {
193 error!("accept bidirectional stream failed: {}", err);
194 AuthenticatorError::AcceptFailed(format!("Accept bidirectional stream failed: {}", err))
195 })?;
196
197 let (spake, token_b) = Spake2::<Ed25519Group>::start_b(
198 &Password::new(self.secret.expose_secret()),
199 &Identity::new(conn.remote_id().as_bytes()),
200 &Identity::new(self.id()?.as_bytes()),
201 );
202
203 let mut token_a = [0u8; 33];
204 recv.read_exact(&mut token_a).await.map_err(|err| {
205 error!("failed to read token_a: {}", err);
206 AuthenticatorError::AcceptFailed(format!("Failed to read token_a: {}", err))
207 })?;
208
209 send.write_all(&token_b).await.map_err(|err| {
210 error!("failed to write token_b: {}", err);
211 AuthenticatorError::AcceptFailed(format!("Failed to write token_b: {}", err))
212 })?;
213
214 let shared_secret = spake.finish(&token_a).map_err(|err| {
215 error!("SPAKE2 invalid: {}", err);
216 AuthenticatorError::AcceptFailed(format!("SPAKE2 invalid: {}", err))
217 })?;
218
219 let hk = Hkdf::<Sha512>::new(None, shared_secret.as_slice());
220 let mut accept_key = [0u8; 64];
221 let mut open_key = [0u8; 64];
222 hk.expand(Self::ACCEPT_CONTEXT, &mut accept_key)
223 .map_err(|err| {
224 error!("failed to expand accept_key: {}", err);
225 AuthenticatorError::AcceptFailed(format!("Failed to expand accept_key: {}", err))
226 })?;
227 hk.expand(Self::OPEN_CONTEXT, &mut open_key)
228 .map_err(|err| {
229 error!("failed to expand open_key: {}", err);
230 AuthenticatorError::AcceptFailed(format!("Failed to expand open_key: {}", err))
231 })?;
232
233 send.write_all(&accept_key).await.map_err(|err| {
234 error!("failed to write accept_key: {}", err);
235 AuthenticatorError::AcceptFailed(format!("Failed to write accept_key: {}", err))
236 })?;
237 let mut remote_open_key = [0u8; 64];
238 recv.read_exact(&mut remote_open_key).await.map_err(|err| {
239 error!("failed to read remote_open_key: {}", err);
240 AuthenticatorError::AcceptFailed(format!("Failed to read remote_open_key: {}", err))
241 })?;
242
243 if !bool::from(remote_open_key.ct_eq(&open_key)) {
244 error!("remote open_key mismatch");
245 return Err(AuthenticatorError::AcceptFailed(
246 "Remote open_key mismatch".to_string(),
247 ));
248 }
249
250 self.add_authenticated(conn.remote_id())?;
251 info!("authenticated connection from {}", remote_id);
252
253 Ok(())
254 }
255
256 async fn auth_open(&self, conn: Connection) -> Result<(), AuthenticatorError> {
260 let remote_id = conn.remote_id();
261 debug!("opening auth connection to {}", remote_id);
262 let (mut send, mut recv) = conn.open_bi().await.map_err(|err| {
263 error!("open bidirectional stream failed: {}", err);
264 AuthenticatorError::AcceptFailed(format!("Open bidirectional stream failed: {}", err))
265 })?;
266
267 let (spake, token_a) = Spake2::<Ed25519Group>::start_a(
268 &Password::new(self.secret.expose_secret()),
269 &Identity::new(self.id()?.as_bytes()),
270 &Identity::new(conn.remote_id().as_bytes()),
271 );
272
273 send.write_all(&token_a).await.map_err(|err| {
274 error!("failed to write token_a: {}", err);
275 AuthenticatorError::AcceptFailed(format!("Failed to write token_a: {}", err))
276 })?;
277
278 let mut token_b = [0u8; 33];
279 recv.read_exact(&mut token_b).await.map_err(|err| {
280 error!("failed to read token_b: {}", err);
281 AuthenticatorError::AcceptFailed(format!("Failed to read token_b: {}", err))
282 })?;
283
284 let shared_secret = spake.finish(&token_b).map_err(|err| {
285 error!("SPAKE2 invalid: {}", err);
286 AuthenticatorError::AcceptFailed(format!("SPAKE2 invalid: {}", err))
287 })?;
288
289 let hk = Hkdf::<Sha512>::new(None, shared_secret.as_slice());
290 let mut accept_key = [0u8; 64];
291 let mut open_key = [0u8; 64];
292 hk.expand(Self::ACCEPT_CONTEXT, &mut accept_key)
293 .map_err(|err| {
294 error!("failed to expand accept_key: {}", err);
295 AuthenticatorError::AcceptFailed(format!("Failed to expand accept_key: {}", err))
296 })?;
297 hk.expand(Self::OPEN_CONTEXT, &mut open_key)
298 .map_err(|err| {
299 error!("failed to expand open_key: {}", err);
300 AuthenticatorError::AcceptFailed(format!("Failed to expand open_key: {}", err))
301 })?;
302
303 let mut remote_accept_key = [0u8; 64];
304 recv.read_exact(&mut remote_accept_key)
305 .await
306 .map_err(|err| {
307 error!("failed to read remote_accept_key: {}", err);
308 AuthenticatorError::AcceptFailed(format!(
309 "Failed to read remote_accept_key: {}",
310 err
311 ))
312 })?;
313
314 if !bool::from(remote_accept_key.ct_eq(&accept_key)) {
315 error!("remote accept_key mismatch");
316 return Err(AuthenticatorError::AcceptFailed(
317 "Remote accept_key mismatch".to_string(),
318 ));
319 }
320
321 send.write_all(&open_key).await.map_err(|err| {
322 error!("failed to write open_key: {}", err);
323 AuthenticatorError::AcceptFailed(format!("Failed to write open_key: {}", err))
324 })?;
325 send.finish().map_err(|err| {
326 error!("failed to finish stream: {}", err);
327 AuthenticatorError::AcceptFailed(format!("Failed to finish stream: {}", err))
328 })?;
329
330 conn.closed().await;
331
332 self.add_authenticated(conn.remote_id())?;
333 info!("authenticated connection to {}", remote_id);
334
335 Ok(())
336 }
337}
338
339impl ProtocolHandler for Authenticator {
340 async fn accept(
341 &self,
342 connection: iroh::endpoint::Connection,
343 ) -> Result<(), iroh::protocol::AcceptError> {
344 if let Err(err) = self
345 .auth_accept(connection)
346 .await
347 .map_err(|err| iroh::protocol::AcceptError::from_err(err))
348 {
349 self.add_blocked().ok();
350 Err(err)
351 } else {
352 Ok(())
353 }
354 }
355}
356
357impl EndpointHooks for Authenticator {
358 async fn after_handshake<'a>(
359 &'a self,
360 conn_info: &'a iroh::endpoint::ConnectionInfo,
361 ) -> iroh::endpoint::AfterHandshakeOutcome {
362 if self.is_authenticated(&conn_info.remote_id()) {
363 debug!("already authenticated: {}", conn_info.remote_id());
364 return AfterHandshakeOutcome::accept();
365 }
366
367 if conn_info.alpn() == Self::ALPN {
368 debug!(
369 "skipping auth for connection with alpn {}",
370 String::from_utf8_lossy(conn_info.alpn())
371 );
372 return AfterHandshakeOutcome::accept();
373 }
374
375 let remote_id = conn_info.remote_id();
376 let counter = self.watcher.get();
377
378 let wait_for_auth = async {
379 let mut stream = self.watcher.watch().stream();
380 while let Some(next_counter) = stream.next().await {
381 if next_counter != counter && self.is_authenticated(&remote_id) {
382 return Ok(()) as Result<(), AuthenticatorError>;
383 }
384 }
385 Err(AuthenticatorError::AcceptFailed(
386 "Watcher stream ended unexpectedly".to_string(),
387 ))
388 };
389
390 match timeout(Duration::from_secs(10), wait_for_auth).await {
391 Ok(_) => AfterHandshakeOutcome::accept(),
392 Err(_) => {
393 warn!("authentication timed out for {}", remote_id);
394 AfterHandshakeOutcome::Reject {
395 error_code: VarInt::from_u32(401),
396 reason: b"Authentication timed out".to_vec(),
397 }
398 }
399 }
400 }
401
402 async fn before_connect<'a>(
403 &'a self,
404 remote_addr: &'a iroh::EndpointAddr,
405 alpn: &'a [u8],
406 ) -> iroh::endpoint::BeforeConnectOutcome {
407 if self.is_authenticated(&remote_addr.id) {
408 debug!("already authenticated: {}", remote_addr.id);
409 return iroh::endpoint::BeforeConnectOutcome::Accept;
410 }
411
412 if alpn == Self::ALPN {
413 debug!(
414 "skipping auth for connection to {} with alpn {:?}",
415 remote_addr.id, alpn
416 );
417 return iroh::endpoint::BeforeConnectOutcome::Accept;
418 }
419
420 debug!(
421 "initiating auth for client connection with alpn {} to {}",
422 String::from_utf8_lossy(alpn),
423 remote_addr.id
424 );
425 let endpoint = match self.endpoint() {
426 Ok(ep) => ep,
427 Err(_) => {
428 warn!("authenticator endpoint not set");
429 return iroh::endpoint::BeforeConnectOutcome::Reject;
430 }
431 };
432 spawn({
433 let auth = self.clone();
434 let remote_id = remote_addr.id;
435
436 async move {
437 debug!("background: connecting to {} for auth", remote_id);
438
439 match endpoint.connect(remote_id, Self::ALPN).await {
440 Ok(conn) => {
441 debug!("background: connected to {}, performing auth", remote_id);
442 if let Err(err) = auth.auth_open(conn).await {
443 auth.add_blocked().ok();
444 warn!(
445 "background: authentication failed for {}: {}",
446 remote_id, err
447 );
448 } else {
449 debug!("background: authentication successful for {}", remote_id);
450 }
451 }
452 Err(e) => {
453 warn!(
454 "background: failed to open connection for authentication to {}: {}",
455 remote_id, e
456 );
457 }
458 };
459 }
460 });
461 iroh::endpoint::BeforeConnectOutcome::Accept
462 }
463}
464
#[cfg(test)]
mod tests {
    use iroh::Watcher;

    use super::*;

    /// The two SPAKE2 sides emit different messages but derive the same key.
    #[test]
    fn test_token_different() {
        let shared_password = b"testpassword";
        let identity_a = b"identityA";
        let identity_b = b"identityB";

        let (side_a, message_a) = Spake2::<Ed25519Group>::start_a(
            &Password::new(shared_password),
            &Identity::new(identity_a),
            &Identity::new(identity_b),
        );

        let (side_b, message_b) = Spake2::<Ed25519Group>::start_b(
            &Password::new(shared_password),
            &Identity::new(identity_a),
            &Identity::new(identity_b),
        );

        assert_ne!(message_a, message_b);

        let derived_a = side_a.finish(&message_b).unwrap();
        let derived_b = side_b.finish(&message_a).unwrap();

        assert_eq!(derived_a, derived_b);
    }

    /// Application protocol that accepts every connection and does nothing.
    #[derive(Debug, Clone)]
    struct DummyProtocol;
    impl ProtocolHandler for DummyProtocol {
        async fn accept(&self, _conn: Connection) -> Result<(), iroh::protocol::AcceptError> {
            Ok(())
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_success() {
        let shared = b"supersecrettoken1234567890123456";
        let mutually_authenticated = run_auth_test(shared, shared).await.unwrap();
        assert!(mutually_authenticated);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_failure() {
        let first = b"supersecrettoken1234567890123456";
        let second = b"differentsecrettoken123456789012";
        let mutually_authenticated = run_auth_test(first, second).await.unwrap();
        assert!(!mutually_authenticated);
    }

    /// Binds an endpoint with a fresh authenticator installed as its hooks.
    async fn bind_with_auth(secret: &'static [u8]) -> Result<(Authenticator, Endpoint), String> {
        let auth = Authenticator::new(secret);
        let endpoint = iroh::Endpoint::builder()
            .hooks(auth.clone())
            .bind()
            .await
            .map_err(|e| e.to_string())?;
        auth.set_endpoint(&endpoint);
        Ok((auth, endpoint))
    }

    /// Serves the authentication protocol and the dummy protocol on `endpoint`.
    fn spawn_router(endpoint: &Endpoint, auth: &Authenticator) -> iroh::protocol::Router {
        iroh::protocol::Router::builder(endpoint.clone())
            .accept(Authenticator::ALPN, auth.clone())
            .accept(b"/dummy/1", DummyProtocol)
            .spawn()
    }

    /// Resolves once `auth` has recorded at least one success or failure.
    async fn wait_for_outcome(auth: &Authenticator) {
        use n0_future::StreamExt;

        let mut updates = auth.watcher.watch().stream();
        while let Some(counter) = updates.next().await {
            if counter.authenticated >= 1 || counter.blocked >= 1 {
                break;
            }
        }
    }

    /// Connects endpoint A to endpoint B on the dummy protocol and reports
    /// whether both sides ended up authenticating each other.
    async fn run_auth_test(
        secret_a: &'static [u8],
        secret_b: &'static [u8],
    ) -> Result<bool, String> {
        let (auth_a, endpoint_a) = bind_with_auth(secret_a).await?;
        let (auth_b, endpoint_b) = bind_with_auth(secret_b).await?;

        let router_a = spawn_router(&endpoint_a, &auth_a);
        let router_b = spawn_router(&endpoint_b, &auth_b);

        // The connection result itself is irrelevant; only the recorded
        // authentication outcome is inspected.
        spawn({
            let dialer = endpoint_a.clone();
            let target = endpoint_b.clone();
            async move {
                dialer.connect(target.addr(), b"/dummy/1").await.ok();
            }
        });

        let both_settled = async {
            tokio::join!(wait_for_outcome(&auth_a), wait_for_outcome(&auth_b));
        };
        let settled_in_time = timeout(Duration::from_secs(20), both_settled)
            .await
            .is_ok();

        router_a.shutdown().await.ok();
        router_b.shutdown().await.ok();

        if !settled_in_time {
            return Err("Authentication did not complete in time".to_string());
        }

        let a_trusts_b = auth_a.is_authenticated(&endpoint_b.id());
        Ok(a_trusts_b && auth_b.is_authenticated(&endpoint_a.id()))
    }

    /// Every `IntoSecret` implementation must expose the original bytes.
    #[test]
    fn test_into_secret_impls() {
        use secrecy::ExposeSecret;

        let expected_bytes = b"my-secret-key";

        let from_str = "my-secret-key".into_secret();
        assert_eq!(from_str.expose_secret(), expected_bytes);

        let from_string = String::from("my-secret-key").into_secret();
        assert_eq!(from_string.expose_secret(), expected_bytes);

        let from_vec = b"my-secret-key".to_vec().into_secret();
        assert_eq!(from_vec.expose_secret(), expected_bytes);

        let slice: &[u8] = b"my-secret-key";
        let from_slice = slice.into_secret();
        assert_eq!(from_slice.expose_secret(), expected_bytes);

        let array: &[u8; 13] = b"my-secret-key";
        let from_array = array.into_secret();
        assert_eq!(from_array.expose_secret(), expected_bytes);

        let boxed: Box<[u8]> = Box::new(*b"my-secret-key");
        let from_boxed = boxed.into_secret();
        assert_eq!(from_boxed.expose_secret(), expected_bytes);

        let already_secret = SecretSlice::new(Box::new(*b"my-secret-key"));
        let from_secret = already_secret.into_secret();
        assert_eq!(from_secret.expose_secret(), expected_bytes);
    }
}