1use std::{collections::HashSet, future::Future, pin::Pin, sync::Arc};
2
3use crate::{
4 auth::{
5 credentials::Credentials,
6 identity::AuthenticatedIdentity,
7 provider::{AuthFuture, AuthProvider},
8 },
9 error::{McpError, McpResult},
10};
11
12pub type ApiKeyValidatorFn = Arc<
14 dyn Fn(String) -> Pin<Box<dyn Future<Output = McpResult<AuthenticatedIdentity>> + Send>>
15 + Send
16 + Sync,
17>;
18
19enum Inner {
20 Static(HashSet<String>),
21 Custom(ApiKeyValidatorFn),
22}
23
/// Authentication provider for `Credentials::ApiKey` credentials.
///
/// Keys are checked either against a fixed allow-list (built with `ApiKeyProvider::new`)
/// or by a custom async validator (built with `ApiKeyProvider::with_validator`).
pub struct ApiKeyProvider {
    // Key-checking strategy chosen at construction time. The field is private, so
    // callers must go through the constructors.
    inner: Inner,
}
48
49impl ApiKeyProvider {
50 pub fn new(keys: impl IntoIterator<Item = impl Into<String>>) -> Self {
53 Self {
54 inner: Inner::Static(keys.into_iter().map(Into::into).collect()),
55 }
56 }
57
58 pub fn with_validator<F, Fut>(f: F) -> Self
60 where
61 F: Fn(String) -> Fut + Send + Sync + 'static,
62 Fut: Future<Output = McpResult<AuthenticatedIdentity>> + Send + 'static,
63 {
64 Self {
65 inner: Inner::Custom(Arc::new(move |key| Box::pin(f(key)))),
66 }
67 }
68}
69
70impl AuthProvider for ApiKeyProvider {
71 fn authenticate<'a>(&'a self, credentials: &'a Credentials) -> AuthFuture<'a> {
72 Box::pin(async move {
73 match credentials {
74 Credentials::ApiKey { key } => match &self.inner {
75 Inner::Static(set) => {
76 if set.contains(key.as_str()) {
77 Ok(AuthenticatedIdentity::new(key.clone()))
78 } else {
79 Err(McpError::Unauthorized("invalid API key".into()))
80 }
81 }
82 Inner::Custom(f) => f(key.clone()).await,
83 },
84 _ => Err(McpError::Unauthorized("expected API key".into())),
85 }
86 })
87 }
88
89 fn accepts(&self, credentials: &Credentials) -> bool {
90 matches!(credentials, Credentials::ApiKey { .. })
91 }
92}