Skip to main content

openauth_plugins/mcp/
mod.rs

1//! Model Context Protocol OAuth plugin.
2
3mod authorize;
4mod claims;
5pub mod client;
6mod consent;
7mod metadata;
8mod register;
9mod schema;
10mod session;
11mod shared;
12mod token;
13mod userinfo;
14
15use openauth_core::plugin::AuthPlugin;
16use openauth_core::plugin::{PluginAfterHookAction, PluginAfterHookFuture};
17use openauth_core::{db::User, error::OpenAuthError};
18use serde::{Deserialize, Serialize};
19use serde_json::{Map, Value};
20use std::sync::Arc;
21use thiserror::Error;
22
/// Plugin identifier, shared with the upstream MCP plugin, under which this
/// plugin registers itself.
pub const UPSTREAM_PLUGIN_ID: &str = "mcp";

/// Built-in scopes that [`mcp`] always places at the front of the resolved
/// scope list, ahead of any caller-configured extras.
const DEFAULT_SCOPES: [&str; 4] = ["openid", "profile", "email", "offline_access"];
26
27pub type McpClientIdGenerator = Arc<dyn Fn() -> String + Send + Sync>;
28pub type McpClientSecretGenerator = Arc<dyn Fn() -> String + Send + Sync>;
29pub type McpAdditionalIdTokenClaims =
30    Arc<dyn Fn(&User, &[String]) -> Result<Map<String, Value>, OpenAuthError> + Send + Sync>;
31
32/// Token endpoint authentication methods accepted by dynamic registration.
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
34#[serde(rename_all = "snake_case")]
35pub enum TokenEndpointAuthMethod {
36    None,
37    ClientSecretBasic,
38    ClientSecretPost,
39}
40
41impl TokenEndpointAuthMethod {
42    pub fn as_str(self) -> &'static str {
43        match self {
44            Self::None => "none",
45            Self::ClientSecretBasic => "client_secret_basic",
46            Self::ClientSecretPost => "client_secret_post",
47        }
48    }
49}
50
/// OIDC-style settings consumed by the MCP plugin.
///
/// Every field has a default (see the `Default` impl that follows). [`mcp`]
/// folds these settings into [`ResolvedMcpOptions`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpOidcConfig {
    /// Scopes to support in addition to the built-in set (default: none).
    pub scopes: Vec<String>,
    /// Whitespace-separated default scope string (default `"openid"`).
    pub default_scope: String,
    /// Authorization-code lifetime (default `600`; unit presumably seconds — confirm).
    pub code_expires_in: u64,
    /// Access-token lifetime (default `3600`; same unit as above).
    pub access_token_expires_in: u64,
    /// Refresh-token lifetime (default `604800`; same unit as above).
    pub refresh_token_expires_in: u64,
    /// Allow the `plain` PKCE code-challenge method (default `true`).
    pub allow_plain_code_challenge_method: bool,
    /// Require PKCE on authorization requests (default `false`).
    pub require_pkce: bool,
}

impl Default for McpOidcConfig {
    fn default() -> Self {
        // Lifetimes are written as products. If the unit is seconds, they are
        // 10 minutes, 1 hour and 7 days respectively.
        McpOidcConfig {
            scopes: vec![],
            default_scope: String::from("openid"),
            code_expires_in: 10 * 60,
            access_token_expires_in: 60 * 60,
            refresh_token_expires_in: 7 * 24 * 60 * 60,
            allow_plain_code_challenge_method: true,
            require_pkce: false,
        }
    }
}
76
77/// Metadata extension points for OAuth discovery responses.
78#[derive(Debug, Clone, Default, PartialEq, Serialize)]
79pub struct McpMetadataOverrides {
80    pub authorization_server: Map<String, Value>,
81    pub protected_resource: Map<String, Value>,
82}
83
84/// User-facing MCP plugin options.
85#[derive(Clone, Default)]
86pub struct McpOptions {
87    pub login_page: String,
88    pub consent_page: Option<String>,
89    pub resource: Option<String>,
90    pub oidc_config: McpOidcConfig,
91    pub metadata: McpMetadataOverrides,
92    pub client_id_generator: Option<McpClientIdGenerator>,
93    pub client_secret_generator: Option<McpClientSecretGenerator>,
94    pub additional_id_token_claims: Option<McpAdditionalIdTokenClaims>,
95}
96
97/// Resolved MCP options after upstream-compatible defaults are applied.
98#[derive(Debug, Clone, PartialEq, Serialize)]
99pub struct ResolvedMcpOptions {
100    pub login_page: String,
101    pub consent_page: Option<String>,
102    pub resource: Option<String>,
103    pub scopes: Vec<String>,
104    pub default_scope: Vec<String>,
105    pub code_expires_in: u64,
106    pub access_token_expires_in: u64,
107    pub refresh_token_expires_in: u64,
108    pub allow_plain_code_challenge_method: bool,
109    pub require_pkce: bool,
110    pub metadata: McpMetadataOverrides,
111}
112
113/// Typed MCP plugin returned by [`mcp`].
114#[derive(Debug, Clone)]
115pub struct McpPlugin {
116    pub id: String,
117    pub version: String,
118    pub options: ResolvedMcpOptions,
119    auth_plugin: AuthPlugin,
120}
121
122impl McpPlugin {
123    pub fn into_auth_plugin(self) -> AuthPlugin {
124        self.auth_plugin
125    }
126
127    pub fn as_auth_plugin(&self) -> &AuthPlugin {
128        &self.auth_plugin
129    }
130}
131
132/// MCP configuration errors.
133#[derive(Debug, Clone, PartialEq, Eq, Error)]
134pub enum McpConfigError {
135    #[error("login_page is required")]
136    MissingLoginPage,
137}
138
139/// Build the MCP OAuth plugin.
140pub fn mcp(options: McpOptions) -> Result<McpPlugin, McpConfigError> {
141    if options.login_page.is_empty() {
142        return Err(McpConfigError::MissingLoginPage);
143    }
144    let client_id_generator = options.client_id_generator.clone();
145    let client_secret_generator = options.client_secret_generator.clone();
146    let additional_id_token_claims = options.additional_id_token_claims.clone();
147
148    let mut scopes = DEFAULT_SCOPES
149        .into_iter()
150        .map(str::to_owned)
151        .collect::<Vec<_>>();
152    for scope in options.oidc_config.scopes {
153        if !scope.is_empty() && !scopes.contains(&scope) {
154            scopes.push(scope);
155        }
156    }
157
158    let mut default_scope = options
159        .oidc_config
160        .default_scope
161        .split_whitespace()
162        .filter(|scope| !scope.is_empty())
163        .map(str::to_owned)
164        .collect::<Vec<_>>();
165    if default_scope.is_empty() {
166        default_scope.push("openid".to_owned());
167    }
168
169    let resolved = ResolvedMcpOptions {
170        login_page: options.login_page,
171        consent_page: options.consent_page,
172        resource: options.resource,
173        scopes,
174        default_scope,
175        code_expires_in: options.oidc_config.code_expires_in,
176        access_token_expires_in: options.oidc_config.access_token_expires_in,
177        refresh_token_expires_in: options.oidc_config.refresh_token_expires_in,
178        allow_plain_code_challenge_method: options.oidc_config.allow_plain_code_challenge_method,
179        require_pkce: options.oidc_config.require_pkce,
180        metadata: options.metadata,
181    };
182
183    let auth_plugin = AuthPlugin::new(UPSTREAM_PLUGIN_ID)
184        .with_version(env!("CARGO_PKG_VERSION"))
185        .with_options(serde_json::to_value(&resolved).unwrap_or(serde_json::Value::Null))
186        .with_schema(schema::oauth_application_schema())
187        .with_schema(schema::oauth_access_token_schema())
188        .with_schema(schema::oauth_consent_schema())
189        .with_endpoint(metadata::authorization_server_endpoint(resolved.clone()))
190        .with_endpoint(metadata::protected_resource_endpoint(resolved.clone()))
191        .with_endpoint(register::register_endpoint(
192            resolved.clone(),
193            client_id_generator,
194            client_secret_generator,
195        ))
196        .with_endpoint(authorize::authorize_endpoint(resolved.clone()))
197        .with_endpoint(consent::consent_endpoint(resolved.clone()))
198        .with_endpoint(token::token_endpoint(
199            resolved.clone(),
200            additional_id_token_claims.clone(),
201        ))
202        .with_endpoint(userinfo::userinfo_endpoint(additional_id_token_claims))
203        .with_endpoint(userinfo::jwks_endpoint())
204        .with_endpoint(session::get_session_endpoint())
205        .with_async_after_hook("*", {
206            let resolved = resolved.clone();
207            move |context, request, response| -> PluginAfterHookFuture<'_> {
208                let resolved = resolved.clone();
209                Box::pin(async move {
210                    let response =
211                        authorize::resume_after_login(context, request, response, &resolved)
212                            .await?;
213                    Ok(PluginAfterHookAction::Continue(response))
214                })
215            }
216        });
217
218    Ok(McpPlugin {
219        id: UPSTREAM_PLUGIN_ID.to_owned(),
220        version: env!("CARGO_PKG_VERSION").to_owned(),
221        options: resolved,
222        auth_plugin,
223    })
224}