1mod authorize;
4mod claims;
5pub mod client;
6mod consent;
7mod metadata;
8mod register;
9mod schema;
10mod session;
11mod shared;
12mod token;
13mod userinfo;
14
15use openauth_core::plugin::AuthPlugin;
16use openauth_core::plugin::{PluginAfterHookAction, PluginAfterHookFuture};
17use openauth_core::{db::User, error::OpenAuthError};
18use serde::{Deserialize, Serialize};
19use serde_json::{Map, Value};
20use std::sync::Arc;
21use thiserror::Error;
22
/// Identifier this plugin registers under; [`mcp`] also copies it into [`McpPlugin::id`].
pub const UPSTREAM_PLUGIN_ID: &str = "mcp";

/// Scopes that are always supported; [`mcp`] appends configured extras after these,
/// keeping this order.
const DEFAULT_SCOPES: [&str; 4] = [
    "openid",
    "profile",
    "email",
    "offline_access",
];
26
27pub type McpClientIdGenerator = Arc<dyn Fn() -> String + Send + Sync>;
28pub type McpClientSecretGenerator = Arc<dyn Fn() -> String + Send + Sync>;
29pub type McpAdditionalIdTokenClaims =
30 Arc<dyn Fn(&User, &[String]) -> Result<Map<String, Value>, OpenAuthError> + Send + Sync>;
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
34#[serde(rename_all = "snake_case")]
35pub enum TokenEndpointAuthMethod {
36 None,
37 ClientSecretBasic,
38 ClientSecretPost,
39}
40
41impl TokenEndpointAuthMethod {
42 pub fn as_str(self) -> &'static str {
43 match self {
44 Self::None => "none",
45 Self::ClientSecretBasic => "client_secret_basic",
46 Self::ClientSecretPost => "client_secret_post",
47 }
48 }
49}
50
/// OAuth/OIDC settings consumed by [`mcp`] when building [`ResolvedMcpOptions`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpOidcConfig {
    /// Extra scopes appended to the built-in scope list by [`mcp`].
    pub scopes: Vec<String>,
    /// Whitespace-separated scope string; [`mcp`] splits it into tokens and falls back
    /// to `openid` when nothing remains.
    pub default_scope: String,
    /// Authorization-code lifetime (presumably seconds — TODO confirm with the endpoints).
    pub code_expires_in: u64,
    /// Access-token lifetime (presumably seconds — TODO confirm with the endpoints).
    pub access_token_expires_in: u64,
    /// Refresh-token lifetime (presumably seconds — TODO confirm with the endpoints).
    pub refresh_token_expires_in: u64,
    /// Whether the `plain` PKCE code-challenge method is permitted.
    pub allow_plain_code_challenge_method: bool,
    /// Whether PKCE is mandatory.
    pub require_pkce: bool,
}

impl Default for McpOidcConfig {
    /// No extra scopes, default scope `"openid"`, lifetimes 600 / 3600 / 604800,
    /// plain PKCE allowed, PKCE not required.
    fn default() -> Self {
        McpOidcConfig {
            require_pkce: false,
            allow_plain_code_challenge_method: true,
            refresh_token_expires_in: 604_800,
            access_token_expires_in: 3_600,
            code_expires_in: 600,
            default_scope: String::from("openid"),
            scopes: Vec::default(),
        }
    }
}
76
77#[derive(Debug, Clone, Default, PartialEq, Serialize)]
79pub struct McpMetadataOverrides {
80 pub authorization_server: Map<String, Value>,
81 pub protected_resource: Map<String, Value>,
82}
83
84#[derive(Clone, Default)]
86pub struct McpOptions {
87 pub login_page: String,
88 pub consent_page: Option<String>,
89 pub resource: Option<String>,
90 pub oidc_config: McpOidcConfig,
91 pub metadata: McpMetadataOverrides,
92 pub client_id_generator: Option<McpClientIdGenerator>,
93 pub client_secret_generator: Option<McpClientSecretGenerator>,
94 pub additional_id_token_claims: Option<McpAdditionalIdTokenClaims>,
95}
96
97#[derive(Debug, Clone, PartialEq, Serialize)]
99pub struct ResolvedMcpOptions {
100 pub login_page: String,
101 pub consent_page: Option<String>,
102 pub resource: Option<String>,
103 pub scopes: Vec<String>,
104 pub default_scope: Vec<String>,
105 pub code_expires_in: u64,
106 pub access_token_expires_in: u64,
107 pub refresh_token_expires_in: u64,
108 pub allow_plain_code_challenge_method: bool,
109 pub require_pkce: bool,
110 pub metadata: McpMetadataOverrides,
111}
112
113#[derive(Debug, Clone)]
115pub struct McpPlugin {
116 pub id: String,
117 pub version: String,
118 pub options: ResolvedMcpOptions,
119 auth_plugin: AuthPlugin,
120}
121
122impl McpPlugin {
123 pub fn into_auth_plugin(self) -> AuthPlugin {
124 self.auth_plugin
125 }
126
127 pub fn as_auth_plugin(&self) -> &AuthPlugin {
128 &self.auth_plugin
129 }
130}
131
132#[derive(Debug, Clone, PartialEq, Eq, Error)]
134pub enum McpConfigError {
135 #[error("login_page is required")]
136 MissingLoginPage,
137}
138
139pub fn mcp(options: McpOptions) -> Result<McpPlugin, McpConfigError> {
141 if options.login_page.is_empty() {
142 return Err(McpConfigError::MissingLoginPage);
143 }
144 let client_id_generator = options.client_id_generator.clone();
145 let client_secret_generator = options.client_secret_generator.clone();
146 let additional_id_token_claims = options.additional_id_token_claims.clone();
147
148 let mut scopes = DEFAULT_SCOPES
149 .into_iter()
150 .map(str::to_owned)
151 .collect::<Vec<_>>();
152 for scope in options.oidc_config.scopes {
153 if !scope.is_empty() && !scopes.contains(&scope) {
154 scopes.push(scope);
155 }
156 }
157
158 let mut default_scope = options
159 .oidc_config
160 .default_scope
161 .split_whitespace()
162 .filter(|scope| !scope.is_empty())
163 .map(str::to_owned)
164 .collect::<Vec<_>>();
165 if default_scope.is_empty() {
166 default_scope.push("openid".to_owned());
167 }
168
169 let resolved = ResolvedMcpOptions {
170 login_page: options.login_page,
171 consent_page: options.consent_page,
172 resource: options.resource,
173 scopes,
174 default_scope,
175 code_expires_in: options.oidc_config.code_expires_in,
176 access_token_expires_in: options.oidc_config.access_token_expires_in,
177 refresh_token_expires_in: options.oidc_config.refresh_token_expires_in,
178 allow_plain_code_challenge_method: options.oidc_config.allow_plain_code_challenge_method,
179 require_pkce: options.oidc_config.require_pkce,
180 metadata: options.metadata,
181 };
182
183 let auth_plugin = AuthPlugin::new(UPSTREAM_PLUGIN_ID)
184 .with_version(env!("CARGO_PKG_VERSION"))
185 .with_options(serde_json::to_value(&resolved).unwrap_or(serde_json::Value::Null))
186 .with_schema(schema::oauth_application_schema())
187 .with_schema(schema::oauth_access_token_schema())
188 .with_schema(schema::oauth_consent_schema())
189 .with_endpoint(metadata::authorization_server_endpoint(resolved.clone()))
190 .with_endpoint(metadata::protected_resource_endpoint(resolved.clone()))
191 .with_endpoint(register::register_endpoint(
192 resolved.clone(),
193 client_id_generator,
194 client_secret_generator,
195 ))
196 .with_endpoint(authorize::authorize_endpoint(resolved.clone()))
197 .with_endpoint(consent::consent_endpoint(resolved.clone()))
198 .with_endpoint(token::token_endpoint(
199 resolved.clone(),
200 additional_id_token_claims.clone(),
201 ))
202 .with_endpoint(userinfo::userinfo_endpoint(additional_id_token_claims))
203 .with_endpoint(userinfo::jwks_endpoint())
204 .with_endpoint(session::get_session_endpoint())
205 .with_async_after_hook("*", {
206 let resolved = resolved.clone();
207 move |context, request, response| -> PluginAfterHookFuture<'_> {
208 let resolved = resolved.clone();
209 Box::pin(async move {
210 let response =
211 authorize::resume_after_login(context, request, response, &resolved)
212 .await?;
213 Ok(PluginAfterHookAction::Continue(response))
214 })
215 }
216 });
217
218 Ok(McpPlugin {
219 id: UPSTREAM_PLUGIN_ID.to_owned(),
220 version: env!("CARGO_PKG_VERSION").to_owned(),
221 options: resolved,
222 auth_plugin,
223 })
224}