1use n0_watcher::Watchable;
2use std::{
3 collections::BTreeSet,
4 sync::{Arc, Mutex},
5 time::Duration,
6};
7use tracing::{trace, debug, error, info, warn};
8
9use hkdf::Hkdf;
10use iroh::{
11 endpoint::{AfterHandshakeOutcome, Connection, EndpointHooks, VarInt},
12 protocol::ProtocolHandler,
13 Endpoint, PublicKey, Watcher,
14};
15use n0_future::{task::spawn, time::timeout, StreamExt};
16use secrecy::{ExposeSecret, SecretSlice};
17use sha2::Sha512;
18use spake2::{Ed25519Group, Identity, Password, Spake2};
19use subtle::ConstantTimeEq;
20
/// Errors produced by the authenticator while recording results or running
/// the SPAKE2 handshake.
#[derive(Debug)]
pub enum AuthenticatorError {
    /// The authenticated-peer set or the outcome counter could not be updated
    /// (poisoned lock, or the counter update was rejected).
    AddFailed,
    /// A failure on the accepting side of the handshake; carries a
    /// human-readable description.
    AcceptFailed(String),
    /// Intended for failures on the opening (initiating) side of the
    /// handshake; carries a human-readable description.
    OpenFailed(String),
    /// No endpoint has been registered (or its lock is poisoned).
    EndpointNotSet,
}

impl std::fmt::Display for AuthenticatorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthenticatorError::AddFailed => write!(f, "Failed to add authenticated ID"),
            AuthenticatorError::AcceptFailed(msg) => write!(f, "Accept failed: {}", msg),
            AuthenticatorError::OpenFailed(msg) => write!(f, "Open failed: {}", msg),
            // Point users at the method that actually registers the endpoint.
            AuthenticatorError::EndpointNotSet => write!(
                f,
                "Authenticator endpoint not set: missing authenticator.set_endpoint(endpoint)"
            ),
        }
    }
}

impl std::error::Error for AuthenticatorError {}
45
46pub trait IntoSecret {
47 fn into_secret(self) -> SecretSlice<u8>;
48}
49
50impl IntoSecret for SecretSlice<u8> {
51 fn into_secret(self) -> SecretSlice<u8> {
52 self
53 }
54}
55
56impl IntoSecret for String {
57 fn into_secret(self) -> SecretSlice<u8> {
58 SecretSlice::new(self.into_bytes().into_boxed_slice())
59 }
60}
61
62impl IntoSecret for &str {
63 fn into_secret(self) -> SecretSlice<u8> {
64 SecretSlice::new(self.as_bytes().to_vec().into_boxed_slice())
65 }
66}
67
68impl IntoSecret for Vec<u8> {
69 fn into_secret(self) -> SecretSlice<u8> {
70 SecretSlice::new(self.into_boxed_slice())
71 }
72}
73
74impl IntoSecret for &[u8] {
75 fn into_secret(self) -> SecretSlice<u8> {
76 SecretSlice::new(self.to_vec().into_boxed_slice())
77 }
78}
79
80impl<const N: usize> IntoSecret for &[u8; N] {
81 fn into_secret(self) -> SecretSlice<u8> {
82 SecretSlice::new(self.as_slice().to_vec().into_boxed_slice())
83 }
84}
85
86impl IntoSecret for Box<[u8]> {
87 fn into_secret(self) -> SecretSlice<u8> {
88 SecretSlice::new(self)
89 }
90}
91
/// Running tallies of handshake outcomes.
///
/// Both fields only ever grow; the value is published through a `Watchable`
/// so tasks waiting on it are woken whenever either tally changes.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
struct WatchableCounter {
    /// Successful handshakes recorded (bumped on every success).
    authenticated: usize,
    /// Failed handshakes recorded.
    blocked: usize,
}
97
98#[derive(Debug, Clone)]
99pub struct Authenticator {
100 secret: SecretSlice<u8>,
101 authenticated: Arc<Mutex<BTreeSet<PublicKey>>>,
102 watcher: Watchable<WatchableCounter>,
103 endpoint: Arc<Mutex<Option<iroh::Endpoint>>>,
104}
105
106impl Authenticator {
107 pub const ALPN: &'static [u8] = b"/iroh/auth/0.1";
108 const ACCEPT_CONTEXT: &'static [u8] = b"iroh-auth-accept";
109 const OPEN_CONTEXT: &'static [u8] = b"iroh-auth-open";
110
111 pub fn new<S: IntoSecret>(secret: S) -> Self {
112 Self {
113 secret: secret.into_secret(),
114 authenticated: Arc::new(Mutex::new(BTreeSet::new())),
115 watcher: Watchable::new(WatchableCounter::default()),
116 endpoint: Arc::new(Mutex::new(None)),
117 }
118 }
119
120 pub fn set_endpoint(&self, endpoint: &Endpoint) {
121 if let Ok(mut guard) = self.endpoint.lock() {
122 if guard.is_none() {
123 *guard = Some(endpoint.clone());
124 trace!("Authenticator endpoint set to {}", endpoint.id());
125 }
126 }
127 }
128
129 fn id(&self) -> Result<PublicKey, AuthenticatorError> {
130 self.endpoint
131 .lock()
132 .map_err(|_| AuthenticatorError::EndpointNotSet)?
133 .as_ref()
134 .map(|ep| ep.id())
135 .ok_or(AuthenticatorError::EndpointNotSet)
136 }
137
138 fn endpoint(&self) -> Result<iroh::Endpoint, AuthenticatorError> {
139 self.endpoint
140 .lock()
141 .map_err(|_| AuthenticatorError::EndpointNotSet)?
142 .as_ref()
143 .cloned()
144 .ok_or(AuthenticatorError::EndpointNotSet)
145 }
146
147 pub fn is_authenticated(&self, id: &PublicKey) -> bool {
148 self.authenticated
149 .lock()
150 .map(|set| set.contains(id))
151 .unwrap_or(false)
152 }
153
154 fn add_authenticated(&self, id: PublicKey) -> Result<(), AuthenticatorError> {
155 self.authenticated
156 .lock()
157 .map_err(|_| AuthenticatorError::AddFailed)?
158 .insert(id);
159 let mut counter = self.watcher.get();
160 counter.authenticated += 1;
161 self.watcher
162 .set(counter)
163 .map_err(|_| AuthenticatorError::AddFailed)?;
164 Ok(())
165 }
166
167 fn add_blocked(&self) -> Result<(), AuthenticatorError> {
168 let mut counter = self.watcher.get();
169 counter.blocked += 1;
170 self.watcher
171 .set(counter)
172 .map_err(|_| AuthenticatorError::AddFailed)?;
173 Ok(())
174 }
175
176 #[doc(hidden)]
177 pub fn list_authenticated(&self) -> Vec<PublicKey> {
178 self.authenticated
179 .lock()
180 .map(|set| set.iter().cloned().collect())
181 .unwrap_or_default()
182 }
183
184 pub async fn auth_accept(&self, conn: Connection) -> Result<(), AuthenticatorError> {
188 let remote_id = conn.remote_id();
189 debug!("accepting auth connection from {}", remote_id);
190 let (mut send, mut recv) = conn.accept_bi().await.map_err(|err| {
191 error!("accept bidirectional stream failed: {}", err);
192 AuthenticatorError::AcceptFailed(format!("Accept bidirectional stream failed: {}", err))
193 })?;
194
195 let (spake, token_b) = Spake2::<Ed25519Group>::start_b(
196 &Password::new(self.secret.expose_secret()),
197 &Identity::new(conn.remote_id().as_bytes()),
198 &Identity::new(self.id()?.as_bytes()),
199 );
200
201 let mut token_a = [0u8; 33];
202 recv.read_exact(&mut token_a).await.map_err(|err| {
203 error!("failed to read token_a: {}", err);
204 AuthenticatorError::AcceptFailed(format!("Failed to read token_a: {}", err))
205 })?;
206
207 send.write_all(&token_b).await.map_err(|err| {
208 error!("failed to write token_b: {}", err);
209 AuthenticatorError::AcceptFailed(format!("Failed to write token_b: {}", err))
210 })?;
211
212 let shared_secret = spake.finish(&token_a).map_err(|err| {
213 error!("SPAKE2 invalid: {}", err);
214 AuthenticatorError::AcceptFailed(format!("SPAKE2 invalid: {}", err))
215 })?;
216
217 let hk = Hkdf::<Sha512>::new(None, shared_secret.as_slice());
218 let mut accept_key = [0u8; 64];
219 let mut open_key = [0u8; 64];
220 hk.expand(Self::ACCEPT_CONTEXT, &mut accept_key)
221 .map_err(|err| {
222 error!("failed to expand accept_key: {}", err);
223 AuthenticatorError::AcceptFailed(format!("Failed to expand accept_key: {}", err))
224 })?;
225 hk.expand(Self::OPEN_CONTEXT, &mut open_key)
226 .map_err(|err| {
227 error!("failed to expand open_key: {}", err);
228 AuthenticatorError::AcceptFailed(format!("Failed to expand open_key: {}", err))
229 })?;
230
231 send.write_all(&accept_key).await.map_err(|err| {
232 error!("failed to write accept_key: {}", err);
233 AuthenticatorError::AcceptFailed(format!("Failed to write accept_key: {}", err))
234 })?;
235 let mut remote_open_key = [0u8; 64];
236 recv.read_exact(&mut remote_open_key).await.map_err(|err| {
237 error!("failed to read remote_open_key: {}", err);
238 AuthenticatorError::AcceptFailed(format!("Failed to read remote_open_key: {}", err))
239 })?;
240
241 if !bool::from(remote_open_key.ct_eq(&open_key)) {
242 error!("remote open_key mismatch");
243 return Err(AuthenticatorError::AcceptFailed(
244 "Remote open_key mismatch".to_string(),
245 ));
246 }
247
248 self.add_authenticated(conn.remote_id())?;
249 info!("authenticated connection from {}", remote_id);
250
251 Ok(())
252 }
253
254 pub async fn auth_open(&self, conn: Connection) -> Result<(), AuthenticatorError> {
258 let remote_id = conn.remote_id();
259 debug!("opening auth connection to {}", remote_id);
260 let (mut send, mut recv) = conn.open_bi().await.map_err(|err| {
261 error!("open bidirectional stream failed: {}", err);
262 AuthenticatorError::AcceptFailed(format!("Open bidirectional stream failed: {}", err))
263 })?;
264
265 let (spake, token_a) = Spake2::<Ed25519Group>::start_a(
266 &Password::new(self.secret.expose_secret()),
267 &Identity::new(self.id()?.as_bytes()),
268 &Identity::new(conn.remote_id().as_bytes()),
269 );
270
271 send.write_all(&token_a).await.map_err(|err| {
272 error!("failed to write token_a: {}", err);
273 AuthenticatorError::AcceptFailed(format!("Failed to write token_a: {}", err))
274 })?;
275
276 let mut token_b = [0u8; 33];
277 recv.read_exact(&mut token_b).await.map_err(|err| {
278 error!("failed to read token_b: {}", err);
279 AuthenticatorError::AcceptFailed(format!("Failed to read token_b: {}", err))
280 })?;
281
282 let shared_secret = spake.finish(&token_b).map_err(|err| {
283 error!("SPAKE2 invalid: {}", err);
284 AuthenticatorError::AcceptFailed(format!("SPAKE2 invalid: {}", err))
285 })?;
286
287 let hk = Hkdf::<Sha512>::new(None, shared_secret.as_slice());
288 let mut accept_key = [0u8; 64];
289 let mut open_key = [0u8; 64];
290 hk.expand(Self::ACCEPT_CONTEXT, &mut accept_key)
291 .map_err(|err| {
292 error!("failed to expand accept_key: {}", err);
293 AuthenticatorError::AcceptFailed(format!("Failed to expand accept_key: {}", err))
294 })?;
295 hk.expand(Self::OPEN_CONTEXT, &mut open_key)
296 .map_err(|err| {
297 error!("failed to expand open_key: {}", err);
298 AuthenticatorError::AcceptFailed(format!("Failed to expand open_key: {}", err))
299 })?;
300
301 let mut remote_accept_key = [0u8; 64];
302 recv.read_exact(&mut remote_accept_key)
303 .await
304 .map_err(|err| {
305 error!("failed to read remote_accept_key: {}", err);
306 AuthenticatorError::AcceptFailed(format!(
307 "Failed to read remote_accept_key: {}",
308 err
309 ))
310 })?;
311
312 if !bool::from(remote_accept_key.ct_eq(&accept_key)) {
313 error!("remote accept_key mismatch");
314 return Err(AuthenticatorError::AcceptFailed(
315 "Remote accept_key mismatch".to_string(),
316 ));
317 }
318
319 send.write_all(&open_key).await.map_err(|err| {
320 error!("failed to write open_key: {}", err);
321 AuthenticatorError::AcceptFailed(format!("Failed to write open_key: {}", err))
322 })?;
323 send.finish().map_err(|err| {
324 error!("failed to finish stream: {}", err);
325 AuthenticatorError::AcceptFailed(format!("Failed to finish stream: {}", err))
326 })?;
327
328 conn.closed().await;
329
330 self.add_authenticated(conn.remote_id())?;
331 info!("authenticated connection to {}", remote_id);
332
333 Ok(())
334 }
335}
336
337impl ProtocolHandler for Authenticator {
338 async fn accept(
339 &self,
340 connection: iroh::endpoint::Connection,
341 ) -> Result<(), iroh::protocol::AcceptError> {
342 if let Err(err) = self
343 .auth_accept(connection)
344 .await
345 .map_err(|err| iroh::protocol::AcceptError::from_err(err))
346 {
347 self.add_blocked().ok();
348 Err(err)
349 } else {
350 Ok(())
351 }
352 }
353}
354
355impl EndpointHooks for Authenticator {
356 async fn after_handshake<'a>(
357 &'a self,
358 conn_info: &'a iroh::endpoint::ConnectionInfo,
359 ) -> iroh::endpoint::AfterHandshakeOutcome {
360 if self.is_authenticated(&conn_info.remote_id()) {
361 debug!("already authenticated: {}", conn_info.remote_id());
362 return AfterHandshakeOutcome::accept();
363 }
364
365 if conn_info.alpn() == Self::ALPN {
366 debug!(
367 "skipping auth for connection with alpn {}",
368 String::from_utf8_lossy(conn_info.alpn())
369 );
370 return AfterHandshakeOutcome::accept();
371 }
372
373 let remote_id = conn_info.remote_id();
374 let counter = self.watcher.get();
375
376 let wait_for_auth = async {
377 let mut stream = self.watcher.watch().stream();
378 while let Some(next_counter) = stream.next().await {
379 if next_counter != counter && self.is_authenticated(&remote_id) {
380 return Ok(()) as Result<(), AuthenticatorError>;
381 }
382 }
383 Err(AuthenticatorError::AcceptFailed(
384 "Watcher stream ended unexpectedly".to_string(),
385 ))
386 };
387
388 match timeout(Duration::from_secs(10), wait_for_auth).await {
389 Ok(_) => AfterHandshakeOutcome::accept(),
390 Err(_) => {
391 warn!("authentication timed out for {}", remote_id);
392 AfterHandshakeOutcome::Reject {
393 error_code: VarInt::from_u32(401),
394 reason: b"Authentication timed out".to_vec(),
395 }
396 }
397 }
398 }
399
400 async fn before_connect<'a>(
401 &'a self,
402 remote_addr: &'a iroh::EndpointAddr,
403 alpn: &'a [u8],
404 ) -> iroh::endpoint::BeforeConnectOutcome {
405 if self.is_authenticated(&remote_addr.id) {
406 debug!("already authenticated: {}", remote_addr.id);
407 return iroh::endpoint::BeforeConnectOutcome::Accept;
408 }
409
410 if alpn == Self::ALPN {
411 debug!(
412 "skipping auth for connection to {} with alpn {:?}",
413 remote_addr.id, alpn
414 );
415 return iroh::endpoint::BeforeConnectOutcome::Accept;
416 }
417
418 debug!(
419 "initiating auth for client connection with alpn {} to {}",
420 String::from_utf8_lossy(alpn),
421 remote_addr.id
422 );
423 let endpoint = match self.endpoint() {
424 Ok(ep) => ep,
425 Err(_) => {
426 warn!("authenticator endpoint not set");
427 return iroh::endpoint::BeforeConnectOutcome::Reject;
428 }
429 };
430 spawn({
431 let auth = self.clone();
432 let remote_id = remote_addr.id;
433
434 async move {
435 debug!("background: connecting to {} for auth", remote_id);
436
437 match endpoint.connect(remote_id, Self::ALPN).await {
438 Ok(conn) => {
439 debug!("background: connected to {}, performing auth", remote_id);
440 if let Err(err) = auth.auth_open(conn).await {
441 auth.add_blocked().ok();
442 warn!(
443 "background: authentication failed for {}: {}",
444 remote_id, err
445 );
446 } else {
447 debug!("background: authentication successful for {}", remote_id);
448 }
449 }
450 Err(e) => {
451 warn!(
452 "background: failed to open connection for authentication to {}: {}",
453 remote_id, e
454 );
455 }
456 };
457 }
458 });
459 iroh::endpoint::BeforeConnectOutcome::Accept
460 }
461}
462
#[cfg(test)]
mod tests {
    use iroh::Watcher;

    use super::*;

    /// Both SPAKE2 sides produce distinct messages but agree on the key.
    #[test]
    fn test_token_different() {
        let password = Password::new(b"testpassword");
        let id_a = Identity::new(b"identityA");
        let id_b = Identity::new(b"identityB");

        let (spake_a, token_a) = Spake2::<Ed25519Group>::start_a(&password, &id_a, &id_b);
        let (spake_b, token_b) = Spake2::<Ed25519Group>::start_b(&password, &id_a, &id_b);

        assert_ne!(token_a, token_b);

        let key_a = spake_a.finish(&token_b).unwrap();
        let key_b = spake_b.finish(&token_a).unwrap();
        assert_eq!(key_a, key_b);
    }

    /// Protocol that accepts and immediately returns; only used to trigger
    /// the authentication hooks.
    #[derive(Debug, Clone)]
    struct DummyProtocol;

    impl ProtocolHandler for DummyProtocol {
        async fn accept(&self, _conn: Connection) -> Result<(), iroh::protocol::AcceptError> {
            Ok(())
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_success() {
        let secret = b"supersecrettoken1234567890123456";
        assert!(run_auth_test(secret, secret).await.unwrap());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_failure() {
        let secret_a = b"supersecrettoken1234567890123456";
        let secret_b = b"differentsecrettoken123456789012";
        assert!(!run_auth_test(secret_a, secret_b).await.unwrap());
    }

    /// Binds an endpoint with `auth` installed as hooks and registers the
    /// endpoint with `auth`.
    async fn bind_authenticated(auth: &Authenticator) -> Result<Endpoint, String> {
        let endpoint = iroh::Endpoint::builder()
            .hooks(auth.clone())
            .bind()
            .await
            .map_err(|e| e.to_string())?;
        auth.set_endpoint(&endpoint);
        Ok(endpoint)
    }

    /// Spawns a router serving the auth protocol and the dummy protocol.
    fn spawn_router(endpoint: &Endpoint, auth: &Authenticator) -> iroh::protocol::Router {
        iroh::protocol::Router::builder(endpoint.clone())
            .accept(Authenticator::ALPN, auth.clone())
            .accept(b"/dummy/1", DummyProtocol)
            .spawn()
    }

    /// Resolves once `auth` has recorded at least one success or failure.
    async fn wait_until_settled(auth: &Authenticator) {
        use n0_future::StreamExt;

        let mut updates = auth.watcher.watch().stream();
        while let Some(counter) = updates.next().await {
            if counter.authenticated >= 1 || counter.blocked >= 1 {
                break;
            }
        }
    }

    /// Sets up two nodes with the given secrets and dials A -> B.
    ///
    /// Returns whether both sides ended up authenticating each other, or an
    /// error if either side did not settle within 20 seconds.
    async fn run_auth_test(
        secret_a: &'static [u8],
        secret_b: &'static [u8],
    ) -> Result<bool, String> {
        let auth_a = Authenticator::new(secret_a);
        let endpoint_a = bind_authenticated(&auth_a).await?;

        let auth_b = Authenticator::new(secret_b);
        let endpoint_b = bind_authenticated(&auth_b).await?;

        let router_a = spawn_router(&endpoint_a, &auth_a);
        let router_b = spawn_router(&endpoint_b, &auth_b);

        // Dialling on a non-auth ALPN triggers authentication; the outcome of
        // this particular connection is irrelevant.
        spawn({
            let dialer = endpoint_a.clone();
            let target = endpoint_b.clone();
            async move {
                dialer.connect(target.addr(), b"/dummy/1").await.ok();
            }
        });

        let settled = timeout(Duration::from_secs(20), async {
            tokio::join!(wait_until_settled(&auth_a), wait_until_settled(&auth_b));
        })
        .await;

        router_a.shutdown().await.ok();
        router_b.shutdown().await.ok();

        if settled.is_err() {
            return Err("Authentication did not complete in time".to_string());
        }

        Ok(auth_a.is_authenticated(&endpoint_b.id()) && auth_b.is_authenticated(&endpoint_a.id()))
    }

    /// Every `IntoSecret` impl yields the same bytes.
    #[test]
    fn test_into_secret_impls() {
        use secrecy::ExposeSecret;

        let expected: &[u8] = b"my-secret-key";

        let slice: &[u8] = b"my-secret-key";
        let array: &[u8; 13] = b"my-secret-key";
        let boxed: Box<[u8]> = Box::new(*b"my-secret-key");

        let secrets = [
            "my-secret-key".into_secret(),
            String::from("my-secret-key").into_secret(),
            b"my-secret-key".to_vec().into_secret(),
            slice.into_secret(),
            array.into_secret(),
            boxed.into_secret(),
            SecretSlice::new(Box::new(*b"my-secret-key")).into_secret(),
        ];

        for secret in &secrets {
            assert_eq!(secret.expose_secret(), expected);
        }
    }
}