1use std::{future::Future, pin::Pin, sync::Arc};
2
3use crate::{
4 auth::{
5 credentials::Credentials,
6 identity::AuthenticatedIdentity,
7 provider::{AuthFuture, AuthProvider},
8 },
9 error::{McpError, McpResult},
10};
11
12pub type CustomValidatorFn = Arc<
14 dyn Fn(String) -> Pin<Box<dyn Future<Output = McpResult<AuthenticatedIdentity>> + Send>>
15 + Send
16 + Sync,
17>;
18
19pub struct CustomHeaderProvider {
38 header_name: String,
39 validator: CustomValidatorFn,
40}
41
42impl CustomHeaderProvider {
43 pub fn new<F, Fut>(header_name: impl Into<String>, f: F) -> Self
47 where
48 F: Fn(String) -> Fut + Send + Sync + 'static,
49 Fut: Future<Output = McpResult<AuthenticatedIdentity>> + Send + 'static,
50 {
51 Self {
52 header_name: header_name.into().to_lowercase(),
53 validator: Arc::new(move |v| Box::pin(f(v))),
54 }
55 }
56
57 pub fn header_name(&self) -> &str {
59 &self.header_name
60 }
61}
62
63impl AuthProvider for CustomHeaderProvider {
64 fn authenticate<'a>(&'a self, credentials: &'a Credentials) -> AuthFuture<'a> {
65 Box::pin(async move {
66 match credentials {
67 Credentials::CustomHeader { header_name, value }
68 if header_name == &self.header_name =>
69 {
70 (self.validator)(value.clone()).await
71 }
72 Credentials::CustomHeader { header_name, .. } => Err(McpError::Unauthorized(
73 format!("unexpected header: {header_name}"),
74 )),
75 _ => Err(McpError::Unauthorized(format!(
76 "expected custom header: {}",
77 self.header_name
78 ))),
79 }
80 })
81 }
82
83 fn accepts(&self, credentials: &Credentials) -> bool {
84 match credentials {
85 Credentials::CustomHeader { header_name, .. } => header_name == &self.header_name,
86 _ => false,
87 }
88 }
89}