1use std::sync::Arc;
2
3use iroh::{
4 endpoint::{AfterHandshakeOutcome, EndpointHooks, VarInt},
5 protocol::ProtocolHandler,
6 EndpointId, PublicKey,
7};
8use lru::LruCache;
9use n0_future::StreamExt;
10use n0_watcher::Watcher;
11use tokio::{sync::Mutex, time::timeout};
12use tracing::{debug, error, info, trace, warn};
13
14use crate::{
15 auth::{AuthState, RegisterResponse, WatchableRemote},
16 error::InFlightError,
17 Authenticator, AuthenticatorError, ALPN, AUTH_TIMEOUT,
18};
19
20impl ProtocolHandler for Authenticator {
21 async fn accept(
22 &self,
23 connection: iroh::endpoint::Connection,
24 ) -> Result<(), iroh::protocol::AcceptError> {
25 let remote_id = connection.remote_id();
26 trace!("[accept] starting auth protocol accept for {}", remote_id);
27 let res = match timeout(AUTH_TIMEOUT, self.auth_accept(connection)).await {
28 Ok(Ok(())) => {
29 trace!(
30 "[accept] auth_accept succeeded for {}, releasing as Authenticated",
31 remote_id
32 );
33 release_in_flight(self.auth_state.clone(), remote_id, AuthState::Authenticated)
34 .await
35 .ok();
36 Ok(())
37 }
38 Ok(Err(err)) => match &err {
39 AuthenticatorError::AcceptFailedAndBlock(msg, public_key) => {
40 warn!(
41 "[accept] authentication failed and blocking {}: {}",
42 public_key, msg
43 );
44 trace!(
45 "[accept] releasing {} as Blocked after accept failure",
46 remote_id
47 );
48 release_in_flight(self.auth_state.clone(), remote_id, AuthState::Blocked)
49 .await
50 .ok();
51 Err(iroh::protocol::AcceptError::from_err(err))
52 }
53 _ => {
54 warn!("[accept] authentication failed: {}", err);
55 trace!(
56 "[accept] releasing {} as Unauthenticated after accept failure",
57 remote_id
58 );
59 release_in_flight(
60 self.auth_state.clone(),
61 remote_id,
62 AuthState::Unauthenticated,
63 )
64 .await
65 .ok();
66 Err(iroh::protocol::AcceptError::from_err(err))
67 }
68 },
69 Err(_) => {
70 warn!("[accept] authentication failed: timed out");
71 trace!(
72 "[accept] releasing {} as Unauthenticated after accept timeout",
73 remote_id
74 );
75 release_in_flight(
76 self.auth_state.clone(),
77 remote_id,
78 AuthState::Unauthenticated,
79 )
80 .await
81 .ok();
82 Err(iroh::protocol::AcceptError::from_err(
83 AuthenticatorError::AcceptFailed("Authentication timed out".into()),
84 ))
85 }
86 };
87
88 res
89 }
90}
91
92impl EndpointHooks for Authenticator {
93 async fn after_handshake<'a>(
94 &'a self,
95 conn: &'a iroh::endpoint::Connection,
96 ) -> iroh::endpoint::AfterHandshakeOutcome {
97 let endpoint_id = conn.remote_id();
98 trace!(
99 "[after_handshake] entered for {} with alpn {}",
100 endpoint_id,
101 String::from_utf8_lossy(conn.alpn())
102 );
103 if self.is_authenticated(&endpoint_id).await {
104 debug!("[after_handshake] already authenticated: {}", endpoint_id);
105 return AfterHandshakeOutcome::accept();
106 }
107
108 if conn.alpn() == ALPN {
109 debug!(
110 "[after_handshake] accepting auth connection: {}",
111 String::from_utf8_lossy(conn.alpn())
112 );
113 return AfterHandshakeOutcome::accept();
114 }
115
116 let in_flight_watcher = if let Some(watchable) =
118 get_auth_state(self.auth_state.clone(), endpoint_id).await
119 {
120 trace!(
121 "[after_handshake] found auth state for {}: {}",
122 endpoint_id,
123 watchable.state()
124 );
125 match watchable.state() {
126 AuthState::Unauthenticated => {
127 debug!("[after_handshake] no in-flight auth for {}, we are asymetric (the other node successfully authed but we didn't), initiating auth ourself",endpoint_id);
128 match register_in_flight(self.auth_state.clone(), endpoint_id).await {
129 Ok(RegisterResponse::AlreadyInFlight) => {
130 debug!(
131 "[after_handshake] already in-flight auth for {}, waiting for it to complete",
132 endpoint_id
133 );
134 watchable.watcher()
135 }
136 Ok(RegisterResponse::InFlightRegistered) => {
137 debug!(
138 "[after_handshake] registered in-flight auth for {}, performing auth",
139 endpoint_id
140 );
141 let endpoint = match self.endpoint().await {
142 Ok(ep) => ep,
143 Err(_) => {
144 error!("[after_handshake] authenticator endpoint not set");
145 return AfterHandshakeOutcome::Reject {
146 error_code: VarInt::from_u32(500),
147 reason: b"Internal server error".to_vec(),
148 };
149 }
150 };
151 if let Err(err) = self.perform_auth(endpoint_id, endpoint).await {
152 error!(
153 "[after_handshake] authentication failed for {}, rejecting connection with error: {}",
154 endpoint_id, err
155 );
156 return AfterHandshakeOutcome::Reject {
157 error_code: VarInt::from_u32(401),
158 reason: b"Authentication failed".to_vec(),
159 };
160 } else {
161 info!(
162 "[after_handshake] authentication succeeded for {}",
163 endpoint_id
164 );
165 debug!(
166 "[after_handshake] authentication succeeded for {}, waiting for state update",
167 endpoint_id
168 );
169 return iroh::endpoint::AfterHandshakeOutcome::accept();
170 }
171 }
172 _ => {
173 debug!(
174 "[after_handshake] failed to register in-flight auth for {}, rejecting connection",
175 endpoint_id
176 );
177 return AfterHandshakeOutcome::Reject {
178 error_code: VarInt::from_u32(401),
179 reason: b"Authentication failed".to_vec(),
180 };
181 }
182 }
183 }
184 AuthState::InFlight => {
185 debug!(
186 "[after_handshake] waiting for in-flight auth for {}",
187 endpoint_id
188 );
189 watchable.watcher()
190 }
191 AuthState::Authenticated => {
192 debug!(
193 "[after_handshake] already authenticated: {}",
194 conn.remote_id()
195 );
196 return AfterHandshakeOutcome::accept();
197 }
198 AuthState::Blocked => {
199 debug!(
200 "[after_handshake] endpoint {} is blocked, rejecting connection",
201 endpoint_id
202 );
203 return AfterHandshakeOutcome::Reject {
204 error_code: VarInt::from_u32(403),
205 reason: b"Endpoint is blocked".to_vec(),
206 };
207 }
208 }
209 } else {
210 debug!(
211 "[after_handshake] no in-flight auth for {}, rejecting connection",
212 endpoint_id
213 );
214 return AfterHandshakeOutcome::Reject {
215 error_code: VarInt::from_u32(401),
216 reason: b"No authentication in progress".to_vec(),
217 };
218 };
219
220 let wait_for_auth = async {
221 trace!(
222 "[after_handshake] subscribing to auth state updates for {}",
223 endpoint_id
224 );
225 let mut stream = in_flight_watcher.watch().stream();
226 while let Some(in_flight) = stream.next().await {
227 trace!(
228 "[after_handshake] observed auth state update for {} -> {}",
229 endpoint_id,
230 in_flight
231 );
232 if matches!(
233 in_flight,
234 AuthState::Unauthenticated | AuthState::Authenticated | AuthState::Blocked
235 ) {
236 trace!(
237 "[after_handshake] terminal auth state {} reached for {}",
238 in_flight,
239 endpoint_id
240 );
241 return;
242 }
243 }
244 warn!(
245 "[after_handshake] auth state watch stream ended unexpectedly for {}",
246 endpoint_id
247 );
248 };
249
250 match timeout(AUTH_TIMEOUT, wait_for_auth).await {
251 Ok(_) => {
252 if self.is_authenticated(&endpoint_id).await {
253 trace!(
254 "[after_handshake] auth completed successfully for {}",
255 endpoint_id
256 );
257 AfterHandshakeOutcome::accept()
258 } else {
259 warn!(
260 "[after_handshake] auth wait finished for {} but endpoint is not authenticated",
261 endpoint_id
262 );
263 AfterHandshakeOutcome::Reject {
264 error_code: VarInt::from_u32(401),
265 reason: b"Authentication failed".to_vec(),
266 }
267 }
268 }
269 Err(_) => {
270 warn!(
271 "[after_handshake] authentication timed out for {}",
272 endpoint_id
273 );
274 AfterHandshakeOutcome::Reject {
275 error_code: VarInt::from_u32(401),
276 reason: b"Authentication timed out".to_vec(),
277 }
278 }
279 }
280 }
281
282 async fn before_connect<'a>(
283 &'a self,
284 remote_addr: &'a iroh::EndpointAddr,
285 alpn: &'a [u8],
286 ) -> iroh::endpoint::BeforeConnectOutcome {
287 let remote_id = remote_addr.id;
288 trace!(
289 "[before_connect] entered for {} with alpn {}",
290 remote_id,
291 String::from_utf8_lossy(alpn)
292 );
293 if self.is_authenticated(&remote_id).await {
294 debug!("[before_connect] already authenticated: {}", remote_id);
295 return iroh::endpoint::BeforeConnectOutcome::Accept;
296 }
297
298 if alpn == ALPN {
299 debug!(
300 "[before_connect] initiating auth for client connection with alpn {} to {}",
301 String::from_utf8_lossy(alpn),
302 remote_id
303 );
304 return iroh::endpoint::BeforeConnectOutcome::Accept;
305 }
306
307 match register_in_flight(self.auth_state.clone(), remote_id).await {
308 Ok(RegisterResponse::InFlightRegistered) | Ok(RegisterResponse::AlreadyInFlight) => {
309 debug!(
310 "[before_connect] registered in-flight auth for {}, performing auth",
311 remote_id
312 );
313
314 let endpoint = match self.endpoint().await {
315 Ok(ep) => ep,
316 Err(_) => {
317 error!("[before_connect] authenticator endpoint not set");
318 return iroh::endpoint::BeforeConnectOutcome::Reject;
319 }
320 };
321 if let Err(err) = self.perform_auth(remote_id, endpoint).await {
322 error!(
323 "[before_connect] authentication failed for {}, rejecting connection with error: {}",
324 remote_id, err
325 );
326 iroh::endpoint::BeforeConnectOutcome::Reject
327 } else {
328 info!(
329 "[before_connect] authentication succeeded for {}",
330 remote_id
331 );
332 iroh::endpoint::BeforeConnectOutcome::Accept
333 }
334 }
335 Ok(RegisterResponse::AlreadyAuthenticated) => {
336 trace!(
337 "[before_connect] auth already in progress or complete for {}, allowing connect to proceed",
338 remote_id
339 );
340 if self.is_authenticated(&remote_id).await {
341 debug!(
342 "[before_connect] already authenticated (in flight), accepting connection to {}",
343 remote_id
344 );
345 }
346 iroh::endpoint::BeforeConnectOutcome::Accept
347 }
348 Ok(RegisterResponse::AlreadyBlocked) => {
349 debug!(
350 "[before_connect] endpoint {} is blocked, rejecting connection",
351 remote_id
352 );
353 iroh::endpoint::BeforeConnectOutcome::Reject
354 }
355 Err(err) => {
356 warn!(
357 "[before_connect] failed to register in-flight auth for {}: {}",
358 remote_id, err
359 );
360 iroh::endpoint::BeforeConnectOutcome::Reject
361 }
362 }
363 }
364}
365
366pub(crate) async fn register_in_flight(
367 in_flight: Arc<Mutex<LruCache<EndpointId, WatchableRemote>>>,
368 endpoint_id: PublicKey,
369) -> Result<RegisterResponse, InFlightError> {
370 trace!(
371 "[register_in_flight] locking auth cache for {}",
372 endpoint_id
373 );
374 let mut guard = in_flight.lock().await;
375 trace!(
376 "[register_in_flight] auth cache locked for {}, current size {}",
377 endpoint_id,
378 guard.len()
379 );
380 if let Some(entry) = guard.get(&endpoint_id) {
381 let current_state = entry.state();
382 trace!(
383 "[register_in_flight] existing state for {} is {}",
384 endpoint_id,
385 current_state
386 );
387 return match current_state {
388 AuthState::Unauthenticated => {
389 entry.set_state(AuthState::InFlight);
390 trace!(
391 "[register_in_flight] endpoint {} promoted from Unauthenticated to InFlight",
392 endpoint_id
393 );
394 Ok(RegisterResponse::InFlightRegistered)
395 }
396 AuthState::Authenticated => {
397 trace!(
398 "[register_in_flight] endpoint {} already authenticated",
399 endpoint_id
400 );
401 Ok(RegisterResponse::AlreadyAuthenticated)
402 }
403 AuthState::InFlight => {
404 trace!(
405 "[register_in_flight] endpoint {} already has auth in flight",
406 endpoint_id
407 );
408 Ok(RegisterResponse::AlreadyInFlight)
409 }
410 AuthState::Blocked => {
411 trace!("[register_in_flight] endpoint {} is blocked", endpoint_id);
412 Ok(RegisterResponse::AlreadyBlocked)
413 }
414 };
415 }
416
417 let watchable = WatchableRemote::new(endpoint_id);
418 watchable.set_state(AuthState::InFlight);
419 trace!(
420 "[register_in_flight] inserting new auth state entry for {} as InFlight",
421 endpoint_id
422 );
423
424 if let Some(evicted) = guard.put(endpoint_id, watchable) {
425 debug!(
426 "evicting endpoint {} from auth cache due to capacity limit",
427 evicted.id()
428 );
429 }
430
431 Ok(RegisterResponse::InFlightRegistered)
432}
433
434pub(crate) async fn release_in_flight(
435 in_flight: Arc<Mutex<LruCache<EndpointId, WatchableRemote>>>,
436 endpoint_id: PublicKey,
437 target_state: AuthState,
438) -> Result<(), InFlightError> {
439 trace!(
440 "[release_in_flight] requested state release for {} -> {}",
441 endpoint_id,
442 target_state
443 );
444 if target_state == AuthState::InFlight {
445 return Err(InFlightError::PromotionNotAllowed(
446 "cannot release by promoting to InFlight".to_string(),
447 ));
448 }
449 trace!("[release_in_flight] locking auth cache for {}", endpoint_id);
450 let mut guard = in_flight.lock().await;
451 trace!(
452 "[release_in_flight] auth cache locked for {}, current size {}",
453 endpoint_id,
454 guard.len()
455 );
456
457 if let Some(entry) = guard.get(&endpoint_id) {
459 let current_state = entry.state();
460 let target_state_for_logs = target_state.clone();
461 trace!(
462 "[release_in_flight] current state for {} is {}, target {}",
463 endpoint_id,
464 current_state,
465 target_state_for_logs
466 );
467 return match current_state {
468 AuthState::InFlight => {
469 entry.set_state(target_state);
470 trace!(
471 "[release_in_flight] endpoint {} released from InFlight to {}",
472 endpoint_id,
473 target_state_for_logs
474 );
475 Ok(())
476 }
477 AuthState::Authenticated => {
478 if target_state == AuthState::Blocked {
479 entry.set_state(AuthState::Blocked);
480 debug!(
481 "endpoint {} was authenticated but is now blocked, updating state to Blocked",
482 endpoint_id
483 );
484 Ok(())
485 } else {
486 trace!("endpoint {} is already authenticated, no-op", endpoint_id);
487 Ok(())
488 }
489 }
490 AuthState::Unauthenticated => match target_state {
491 AuthState::Blocked => {
492 entry.set_state(AuthState::Blocked);
493 debug!(
494 "endpoint {} was unauthenticated but is now blocked, updating state to Blocked",
495 endpoint_id
496 );
497 Ok(())
498 }
499 AuthState::Authenticated => {
500 trace!("promoting endpoint {} from Unauthenticated to Authenticated (this is required because we can have asymetric failures that lead to this state transition)", endpoint_id);
501
502 entry.set_state(AuthState::Authenticated);
503 Ok(())
504 }
505 AuthState::Unauthenticated => {
506 trace!("endpoint {} is already unauthenticated, no-op", endpoint_id);
507 Ok(())
508 }
509 AuthState::InFlight => {
510 trace!(
511 "cannot promote endpoint {} from Unauthenticated back to InFlight",
512 endpoint_id
513 );
514 Err(InFlightError::PromotionNotAllowed(
515 "cannot promote to InFlight".to_string(),
516 ))
517 }
518 },
519 current_state => {
520 if current_state == target_state {
521 debug!(
522 "endpoint {} is already in target state {}, no state change needed",
523 endpoint_id, target_state
524 );
525 Ok(())
526 } else {
527 warn!(
528 "[release_in_flight] refusing state overwrite for {} from {} to {}",
529 endpoint_id, current_state, target_state
530 );
531 Err(InFlightError::PromotionNotAllowed(format!(
532 "only promote to {} from {} not from {}",
533 target_state,
534 AuthState::InFlight,
535 entry.state()
536 )))
537 }
538 }
539 };
540 }
541
542 let watchable = WatchableRemote::new(endpoint_id);
544 let target_state_for_logs = target_state.clone();
545 watchable.set_state(target_state);
546 trace!(
547 "[release_in_flight] no auth state entry existed for {}, inserting {}",
548 endpoint_id,
549 target_state_for_logs
550 );
551
552 if let Some(evicted) = guard.put(endpoint_id, watchable) {
553 debug!(
554 "evicting endpoint {} from auth cache due to capacity limit",
555 evicted.id()
556 );
557 }
558
559 Ok(())
560}
561
562pub(crate) async fn get_auth_state(
563 auth_state: Arc<Mutex<LruCache<EndpointId, WatchableRemote>>>,
564 endpoint_id: PublicKey,
565) -> Option<WatchableRemote> {
566 trace!("[get_auth_state] locking auth cache for {}", endpoint_id);
567 let mut guard = auth_state.lock().await;
568 let result = guard.get(&endpoint_id).cloned();
569 match result.as_ref() {
570 Some(watchable) => {
571 trace!(
572 "[get_auth_state] found auth state for {}: {}",
573 endpoint_id,
574 watchable.state()
575 );
576 }
577 None => {
578 trace!("[get_auth_state] no auth state found for {}", endpoint_id);
579 }
580 }
581 result
582}