// mcp_kit/auth/custom.rs - authentication provider for custom HTTP header credentials.

use std::{fmt, future::Future, pin::Pin, sync::Arc};

use crate::{
    auth::{
        credentials::Credentials,
        identity::AuthenticatedIdentity,
        provider::{AuthFuture, AuthProvider},
    },
    error::{McpError, McpResult},
};

12/// A custom async validator for arbitrary header values.
13pub type CustomValidatorFn = Arc<
14    dyn Fn(String) -> Pin<Box<dyn Future<Output = McpResult<AuthenticatedIdentity>> + Send>>
15        + Send
16        + Sync,
17>;
18
19/// Validates a custom HTTP header credential (e.g. `X-Internal-Token`).
20///
21/// The transport layer extracts the value of the configured header and wraps it
22/// in [`Credentials::CustomHeader`] before passing it to this provider.
23///
24/// # Examples
25///
26/// ```rust,no_run
27/// use mcp_kit::auth::{CustomHeaderProvider, AuthenticatedIdentity};
28///
29/// let provider = CustomHeaderProvider::new("x-internal-token", |value| async move {
30///     if value == "trusted" {
31///         Ok(AuthenticatedIdentity::new("internal-service"))
32///     } else {
33///         Err(mcp_kit::McpError::Unauthorized("invalid token".into()))
34///     }
35/// });
36/// ```
37pub struct CustomHeaderProvider {
38    header_name: String,
39    validator: CustomValidatorFn,
40}
41
42impl CustomHeaderProvider {
43    /// Create a provider that validates values from the named header.
44    ///
45    /// `header_name` is case-insensitive; it will be normalised to lowercase.
46    pub fn new<F, Fut>(header_name: impl Into<String>, f: F) -> Self
47    where
48        F: Fn(String) -> Fut + Send + Sync + 'static,
49        Fut: Future<Output = McpResult<AuthenticatedIdentity>> + Send + 'static,
50    {
51        Self {
52            header_name: header_name.into().to_lowercase(),
53            validator: Arc::new(move |v| Box::pin(f(v))),
54        }
55    }
56
57    /// The header name this provider expects (normalised to lowercase).
58    pub fn header_name(&self) -> &str {
59        &self.header_name
60    }
61}
62
63impl AuthProvider for CustomHeaderProvider {
64    fn authenticate<'a>(&'a self, credentials: &'a Credentials) -> AuthFuture<'a> {
65        Box::pin(async move {
66            match credentials {
67                Credentials::CustomHeader { header_name, value }
68                    if header_name == &self.header_name =>
69                {
70                    (self.validator)(value.clone()).await
71                }
72                Credentials::CustomHeader { header_name, .. } => Err(McpError::Unauthorized(
73                    format!("unexpected header: {header_name}"),
74                )),
75                _ => Err(McpError::Unauthorized(format!(
76                    "expected custom header: {}",
77                    self.header_name
78                ))),
79            }
80        })
81    }
82
83    fn accepts(&self, credentials: &Credentials) -> bool {
84        match credentials {
85            Credentials::CustomHeader { header_name, .. } => header_name == &self.header_name,
86            _ => false,
87        }
88    }
89}