1use std::{collections::HashSet, future::Future, pin::Pin, sync::Arc};
2
3use crate::{
4 auth::{
5 credentials::Credentials,
6 identity::AuthenticatedIdentity,
7 provider::{AuthFuture, AuthProvider},
8 },
9 error::{McpError, McpResult},
10};
11
12pub type BearerValidatorFn = Arc<
14 dyn Fn(String) -> Pin<Box<dyn Future<Output = McpResult<AuthenticatedIdentity>> + Send>>
15 + Send
16 + Sync,
17>;
18
19enum Inner {
20 Static(HashSet<String>),
22 Custom(BearerValidatorFn),
24}
25
26pub struct BearerTokenProvider {
50 inner: Inner,
51}
52
53impl BearerTokenProvider {
54 pub fn new(tokens: impl IntoIterator<Item = impl Into<String>>) -> Self {
57 Self {
58 inner: Inner::Static(tokens.into_iter().map(Into::into).collect()),
59 }
60 }
61
62 pub fn with_validator<F, Fut>(f: F) -> Self
64 where
65 F: Fn(String) -> Fut + Send + Sync + 'static,
66 Fut: Future<Output = McpResult<AuthenticatedIdentity>> + Send + 'static,
67 {
68 Self {
69 inner: Inner::Custom(Arc::new(move |token| Box::pin(f(token)))),
70 }
71 }
72}
73
74impl AuthProvider for BearerTokenProvider {
75 fn authenticate<'a>(&'a self, credentials: &'a Credentials) -> AuthFuture<'a> {
76 Box::pin(async move {
77 match credentials {
78 Credentials::Bearer { token } => match &self.inner {
79 Inner::Static(set) => {
80 if set.contains(token.as_str()) {
81 Ok(AuthenticatedIdentity::new(token.clone()))
82 } else {
83 Err(McpError::Unauthorized("invalid bearer token".into()))
84 }
85 }
86 Inner::Custom(f) => f(token.clone()).await,
87 },
88 _ => Err(McpError::Unauthorized("expected bearer token".into())),
89 }
90 })
91 }
92
93 fn accepts(&self, credentials: &Credentials) -> bool {
94 matches!(credentials, Credentials::Bearer { .. })
95 }
96}