Skip to main content

iroh_auth/
protocol.rs

1use std::sync::Arc;
2
3use iroh::{
4    endpoint::{AfterHandshakeOutcome, EndpointHooks, VarInt},
5    protocol::ProtocolHandler,
6    EndpointId, PublicKey,
7};
8use lru::LruCache;
9use n0_future::StreamExt;
10use n0_watcher::Watcher;
11use tokio::{sync::Mutex, time::timeout};
12use tracing::{debug, error, info, trace, warn};
13
14use crate::{
15    auth::{AuthState, RegisterResponse, WatchableRemote},
16    error::InFlightError,
17    Authenticator, AuthenticatorError, ALPN, AUTH_TIMEOUT,
18};
19
20impl ProtocolHandler for Authenticator {
21    async fn accept(
22        &self,
23        connection: iroh::endpoint::Connection,
24    ) -> Result<(), iroh::protocol::AcceptError> {
25        let remote_id = connection.remote_id();
26        trace!("[accept] starting auth protocol accept for {}", remote_id);
27        let res = match timeout(AUTH_TIMEOUT, self.auth_accept(connection)).await {
28            Ok(Ok(())) => {
29                trace!(
30                    "[accept] auth_accept succeeded for {}, releasing as Authenticated",
31                    remote_id
32                );
33                release_in_flight(self.auth_state.clone(), remote_id, AuthState::Authenticated)
34                    .await
35                    .ok();
36                Ok(())
37            }
38            Ok(Err(err)) => match &err {
39                AuthenticatorError::AcceptFailedAndBlock(msg, public_key) => {
40                    warn!(
41                        "[accept] authentication failed and blocking {}: {}",
42                        public_key, msg
43                    );
44                    trace!(
45                        "[accept] releasing {} as Blocked after accept failure",
46                        remote_id
47                    );
48                    release_in_flight(self.auth_state.clone(), remote_id, AuthState::Blocked)
49                        .await
50                        .ok();
51                    Err(iroh::protocol::AcceptError::from_err(err))
52                }
53                _ => {
54                    warn!("[accept] authentication failed: {}", err);
55                    trace!(
56                        "[accept] releasing {} as Unauthenticated after accept failure",
57                        remote_id
58                    );
59                    release_in_flight(
60                        self.auth_state.clone(),
61                        remote_id,
62                        AuthState::Unauthenticated,
63                    )
64                    .await
65                    .ok();
66                    Err(iroh::protocol::AcceptError::from_err(err))
67                }
68            },
69            Err(_) => {
70                warn!("[accept] authentication failed: timed out");
71                trace!(
72                    "[accept] releasing {} as Unauthenticated after accept timeout",
73                    remote_id
74                );
75                release_in_flight(
76                    self.auth_state.clone(),
77                    remote_id,
78                    AuthState::Unauthenticated,
79                )
80                .await
81                .ok();
82                Err(iroh::protocol::AcceptError::from_err(
83                    AuthenticatorError::AcceptFailed("Authentication timed out".into()),
84                ))
85            }
86        };
87
88        res
89    }
90}
91
92impl EndpointHooks for Authenticator {
93    async fn after_handshake<'a>(
94        &'a self,
95        conn: &'a iroh::endpoint::Connection,
96    ) -> iroh::endpoint::AfterHandshakeOutcome {
97        let endpoint_id = conn.remote_id();
98        trace!(
99            "[after_handshake] entered for {} with alpn {}",
100            endpoint_id,
101            String::from_utf8_lossy(conn.alpn())
102        );
103        if self.is_authenticated(&endpoint_id).await {
104            debug!("[after_handshake] already authenticated: {}", endpoint_id);
105            return AfterHandshakeOutcome::accept();
106        }
107
108        if conn.alpn() == ALPN {
109            debug!(
110                "[after_handshake] accepting auth connection: {}",
111                String::from_utf8_lossy(conn.alpn())
112            );
113            return AfterHandshakeOutcome::accept();
114        }
115
116        // wait for authentication to complete
117        let in_flight_watcher = if let Some(watchable) =
118            get_auth_state(self.auth_state.clone(), endpoint_id).await
119        {
120            trace!(
121                "[after_handshake] found auth state for {}: {}",
122                endpoint_id,
123                watchable.state()
124            );
125            match watchable.state() {
126                AuthState::Unauthenticated => {
127                    debug!("[after_handshake] no in-flight auth for {}, we are asymetric (the other node successfully authed but we didn't), initiating auth ourself",endpoint_id);
128                    match register_in_flight(self.auth_state.clone(), endpoint_id).await {
129                        Ok(RegisterResponse::AlreadyInFlight) => {
130                            debug!(
131                                    "[after_handshake] already in-flight auth for {}, waiting for it to complete",
132                                    endpoint_id
133                                );
134                            watchable.watcher()
135                        }
136                        Ok(RegisterResponse::InFlightRegistered) => {
137                            debug!(
138                                    "[after_handshake] registered in-flight auth for {}, performing auth",
139                                    endpoint_id
140                                );
141                            let endpoint = match self.endpoint().await {
142                                Ok(ep) => ep,
143                                Err(_) => {
144                                    error!("[after_handshake] authenticator endpoint not set");
145                                    return AfterHandshakeOutcome::Reject {
146                                        error_code: VarInt::from_u32(500),
147                                        reason: b"Internal server error".to_vec(),
148                                    };
149                                }
150                            };
151                            if let Err(err) = self.perform_auth(endpoint_id, endpoint).await {
152                                error!(
153                                        "[after_handshake] authentication failed for {}, rejecting connection with error: {}",
154                                        endpoint_id, err
155                                    );
156                                return AfterHandshakeOutcome::Reject {
157                                    error_code: VarInt::from_u32(401),
158                                    reason: b"Authentication failed".to_vec(),
159                                };
160                            } else {
161                                info!(
162                                    "[after_handshake] authentication succeeded for {}",
163                                    endpoint_id
164                                );
165                                debug!(
166                                    "[after_handshake] authentication succeeded for {}, waiting for state update",
167                                    endpoint_id
168                                );
169                                return iroh::endpoint::AfterHandshakeOutcome::accept();
170                            }
171                        }
172                        _ => {
173                            debug!(
174                                    "[after_handshake] failed to register in-flight auth for {}, rejecting connection",
175                                    endpoint_id
176                                );
177                            return AfterHandshakeOutcome::Reject {
178                                error_code: VarInt::from_u32(401),
179                                reason: b"Authentication failed".to_vec(),
180                            };
181                        }
182                    }
183                }
184                AuthState::InFlight => {
185                    debug!(
186                        "[after_handshake] waiting for in-flight auth for {}",
187                        endpoint_id
188                    );
189                    watchable.watcher()
190                }
191                AuthState::Authenticated => {
192                    debug!(
193                        "[after_handshake] already authenticated: {}",
194                        conn.remote_id()
195                    );
196                    return AfterHandshakeOutcome::accept();
197                }
198                AuthState::Blocked => {
199                    debug!(
200                        "[after_handshake] endpoint {} is blocked, rejecting connection",
201                        endpoint_id
202                    );
203                    return AfterHandshakeOutcome::Reject {
204                        error_code: VarInt::from_u32(403),
205                        reason: b"Endpoint is blocked".to_vec(),
206                    };
207                }
208            }
209        } else {
210            debug!(
211                "[after_handshake] no in-flight auth for {}, rejecting connection",
212                endpoint_id
213            );
214            return AfterHandshakeOutcome::Reject {
215                error_code: VarInt::from_u32(401),
216                reason: b"No authentication in progress".to_vec(),
217            };
218        };
219
220        let wait_for_auth = async {
221            trace!(
222                "[after_handshake] subscribing to auth state updates for {}",
223                endpoint_id
224            );
225            let mut stream = in_flight_watcher.watch().stream();
226            while let Some(in_flight) = stream.next().await {
227                trace!(
228                    "[after_handshake] observed auth state update for {} -> {}",
229                    endpoint_id,
230                    in_flight
231                );
232                if matches!(
233                    in_flight,
234                    AuthState::Unauthenticated | AuthState::Authenticated | AuthState::Blocked
235                ) {
236                    trace!(
237                        "[after_handshake] terminal auth state {} reached for {}",
238                        in_flight,
239                        endpoint_id
240                    );
241                    return;
242                }
243            }
244            warn!(
245                "[after_handshake] auth state watch stream ended unexpectedly for {}",
246                endpoint_id
247            );
248        };
249
250        match timeout(AUTH_TIMEOUT, wait_for_auth).await {
251            Ok(_) => {
252                if self.is_authenticated(&endpoint_id).await {
253                    trace!(
254                        "[after_handshake] auth completed successfully for {}",
255                        endpoint_id
256                    );
257                    AfterHandshakeOutcome::accept()
258                } else {
259                    warn!(
260                        "[after_handshake] auth wait finished for {} but endpoint is not authenticated",
261                        endpoint_id
262                    );
263                    AfterHandshakeOutcome::Reject {
264                        error_code: VarInt::from_u32(401),
265                        reason: b"Authentication failed".to_vec(),
266                    }
267                }
268            }
269            Err(_) => {
270                warn!(
271                    "[after_handshake] authentication timed out for {}",
272                    endpoint_id
273                );
274                AfterHandshakeOutcome::Reject {
275                    error_code: VarInt::from_u32(401),
276                    reason: b"Authentication timed out".to_vec(),
277                }
278            }
279        }
280    }
281
282    async fn before_connect<'a>(
283        &'a self,
284        remote_addr: &'a iroh::EndpointAddr,
285        alpn: &'a [u8],
286    ) -> iroh::endpoint::BeforeConnectOutcome {
287        let remote_id = remote_addr.id;
288        trace!(
289            "[before_connect] entered for {} with alpn {}",
290            remote_id,
291            String::from_utf8_lossy(alpn)
292        );
293        if self.is_authenticated(&remote_id).await {
294            debug!("[before_connect] already authenticated: {}", remote_id);
295            return iroh::endpoint::BeforeConnectOutcome::Accept;
296        }
297
298        if alpn == ALPN {
299            debug!(
300                "[before_connect] initiating auth for client connection with alpn {} to {}",
301                String::from_utf8_lossy(alpn),
302                remote_id
303            );
304            return iroh::endpoint::BeforeConnectOutcome::Accept;
305        }
306
307        match register_in_flight(self.auth_state.clone(), remote_id).await {
308            Ok(RegisterResponse::InFlightRegistered) | Ok(RegisterResponse::AlreadyInFlight) => {
309                debug!(
310                    "[before_connect] registered in-flight auth for {}, performing auth",
311                    remote_id
312                );
313
314                let endpoint = match self.endpoint().await {
315                    Ok(ep) => ep,
316                    Err(_) => {
317                        error!("[before_connect] authenticator endpoint not set");
318                        return iroh::endpoint::BeforeConnectOutcome::Reject;
319                    }
320                };
321                if let Err(err) = self.perform_auth(remote_id, endpoint).await {
322                    error!(
323                        "[before_connect] authentication failed for {}, rejecting connection with error: {}",
324                        remote_id, err
325                    );
326                    iroh::endpoint::BeforeConnectOutcome::Reject
327                } else {
328                    info!(
329                        "[before_connect] authentication succeeded for {}",
330                        remote_id
331                    );
332                    iroh::endpoint::BeforeConnectOutcome::Accept
333                }
334            }
335            Ok(RegisterResponse::AlreadyAuthenticated) => {
336                trace!(
337                    "[before_connect] auth already in progress or complete for {}, allowing connect to proceed",
338                    remote_id
339                );
340                if self.is_authenticated(&remote_id).await {
341                    debug!(
342                    "[before_connect] already authenticated (in flight), accepting connection to {}",
343                    remote_id
344                );
345                }
346                iroh::endpoint::BeforeConnectOutcome::Accept
347            }
348            Ok(RegisterResponse::AlreadyBlocked) => {
349                debug!(
350                    "[before_connect] endpoint {} is blocked, rejecting connection",
351                    remote_id
352                );
353                iroh::endpoint::BeforeConnectOutcome::Reject
354            }
355            Err(err) => {
356                warn!(
357                    "[before_connect] failed to register in-flight auth for {}: {}",
358                    remote_id, err
359                );
360                iroh::endpoint::BeforeConnectOutcome::Reject
361            }
362        }
363    }
364}
365
366pub(crate) async fn register_in_flight(
367    in_flight: Arc<Mutex<LruCache<EndpointId, WatchableRemote>>>,
368    endpoint_id: PublicKey,
369) -> Result<RegisterResponse, InFlightError> {
370    trace!(
371        "[register_in_flight] locking auth cache for {}",
372        endpoint_id
373    );
374    let mut guard = in_flight.lock().await;
375    trace!(
376        "[register_in_flight] auth cache locked for {}, current size {}",
377        endpoint_id,
378        guard.len()
379    );
380    if let Some(entry) = guard.get(&endpoint_id) {
381        let current_state = entry.state();
382        trace!(
383            "[register_in_flight] existing state for {} is {}",
384            endpoint_id,
385            current_state
386        );
387        return match current_state {
388            AuthState::Unauthenticated => {
389                entry.set_state(AuthState::InFlight);
390                trace!(
391                    "[register_in_flight] endpoint {} promoted from Unauthenticated to InFlight",
392                    endpoint_id
393                );
394                Ok(RegisterResponse::InFlightRegistered)
395            }
396            AuthState::Authenticated => {
397                trace!(
398                    "[register_in_flight] endpoint {} already authenticated",
399                    endpoint_id
400                );
401                Ok(RegisterResponse::AlreadyAuthenticated)
402            }
403            AuthState::InFlight => {
404                trace!(
405                    "[register_in_flight] endpoint {} already has auth in flight",
406                    endpoint_id
407                );
408                Ok(RegisterResponse::AlreadyInFlight)
409            }
410            AuthState::Blocked => {
411                trace!("[register_in_flight] endpoint {} is blocked", endpoint_id);
412                Ok(RegisterResponse::AlreadyBlocked)
413            }
414        };
415    }
416
417    let watchable = WatchableRemote::new(endpoint_id);
418    watchable.set_state(AuthState::InFlight);
419    trace!(
420        "[register_in_flight] inserting new auth state entry for {} as InFlight",
421        endpoint_id
422    );
423
424    if let Some(evicted) = guard.put(endpoint_id, watchable) {
425        debug!(
426            "evicting endpoint {} from auth cache due to capacity limit",
427            evicted.id()
428        );
429    }
430
431    Ok(RegisterResponse::InFlightRegistered)
432}
433
434pub(crate) async fn release_in_flight(
435    in_flight: Arc<Mutex<LruCache<EndpointId, WatchableRemote>>>,
436    endpoint_id: PublicKey,
437    target_state: AuthState,
438) -> Result<(), InFlightError> {
439    trace!(
440        "[release_in_flight] requested state release for {} -> {}",
441        endpoint_id,
442        target_state
443    );
444    if target_state == AuthState::InFlight {
445        return Err(InFlightError::PromotionNotAllowed(
446            "cannot release by promoting to InFlight".to_string(),
447        ));
448    }
449    trace!("[release_in_flight] locking auth cache for {}", endpoint_id);
450    let mut guard = in_flight.lock().await;
451    trace!(
452        "[release_in_flight] auth cache locked for {}, current size {}",
453        endpoint_id,
454        guard.len()
455    );
456
457    // occupied
458    if let Some(entry) = guard.get(&endpoint_id) {
459        let current_state = entry.state();
460        let target_state_for_logs = target_state.clone();
461        trace!(
462            "[release_in_flight] current state for {} is {}, target {}",
463            endpoint_id,
464            current_state,
465            target_state_for_logs
466        );
467        return match current_state {
468            AuthState::InFlight => {
469                entry.set_state(target_state);
470                trace!(
471                    "[release_in_flight] endpoint {} released from InFlight to {}",
472                    endpoint_id,
473                    target_state_for_logs
474                );
475                Ok(())
476            }
477            AuthState::Authenticated => {
478                if target_state == AuthState::Blocked {
479                    entry.set_state(AuthState::Blocked);
480                    debug!(
481                        "endpoint {} was authenticated but is now blocked, updating state to Blocked",
482                        endpoint_id
483                    );
484                    Ok(())
485                } else {
486                    trace!("endpoint {} is already authenticated, no-op", endpoint_id);
487                    Ok(())
488                }
489            }
490            AuthState::Unauthenticated => match target_state {
491                AuthState::Blocked => {
492                    entry.set_state(AuthState::Blocked);
493                    debug!(
494                            "endpoint {} was unauthenticated but is now blocked, updating state to Blocked",
495                            endpoint_id
496                        );
497                    Ok(())
498                }
499                AuthState::Authenticated => {
500                    trace!("promoting endpoint {} from Unauthenticated to Authenticated (this is required because we can have asymetric failures that lead to this state transition)", endpoint_id);
501
502                    entry.set_state(AuthState::Authenticated);
503                    Ok(())
504                }
505                AuthState::Unauthenticated => {
506                    trace!("endpoint {} is already unauthenticated, no-op", endpoint_id);
507                    Ok(())
508                }
509                AuthState::InFlight => {
510                    trace!(
511                        "cannot promote endpoint {} from Unauthenticated back to InFlight",
512                        endpoint_id
513                    );
514                    Err(InFlightError::PromotionNotAllowed(
515                        "cannot promote to InFlight".to_string(),
516                    ))
517                }
518            },
519            current_state => {
520                if current_state == target_state {
521                    debug!(
522                        "endpoint {} is already in target state {}, no state change needed",
523                        endpoint_id, target_state
524                    );
525                    Ok(())
526                } else {
527                    warn!(
528                        "[release_in_flight] refusing state overwrite for {} from {} to {}",
529                        endpoint_id, current_state, target_state
530                    );
531                    Err(InFlightError::PromotionNotAllowed(format!(
532                        "only promote to {} from {} not from {}",
533                        target_state,
534                        AuthState::InFlight,
535                        entry.state()
536                    )))
537                }
538            }
539        };
540    }
541
542    // vacant
543    let watchable = WatchableRemote::new(endpoint_id);
544    let target_state_for_logs = target_state.clone();
545    watchable.set_state(target_state);
546    trace!(
547        "[release_in_flight] no auth state entry existed for {}, inserting {}",
548        endpoint_id,
549        target_state_for_logs
550    );
551
552    if let Some(evicted) = guard.put(endpoint_id, watchable) {
553        debug!(
554            "evicting endpoint {} from auth cache due to capacity limit",
555            evicted.id()
556        );
557    }
558
559    Ok(())
560}
561
562pub(crate) async fn get_auth_state(
563    auth_state: Arc<Mutex<LruCache<EndpointId, WatchableRemote>>>,
564    endpoint_id: PublicKey,
565) -> Option<WatchableRemote> {
566    trace!("[get_auth_state] locking auth cache for {}", endpoint_id);
567    let mut guard = auth_state.lock().await;
568    let result = guard.get(&endpoint_id).cloned();
569    match result.as_ref() {
570        Some(watchable) => {
571            trace!(
572                "[get_auth_state] found auth state for {}: {}",
573                endpoint_id,
574                watchable.state()
575            );
576        }
577        None => {
578            trace!("[get_auth_state] no auth state found for {}", endpoint_id);
579        }
580    }
581    result
582}