1use std::path::Path;
2use std::sync::{Arc, Mutex};
3
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use base64::Engine;
6use chrono::{DateTime, Duration, Utc};
7use rusqlite::{params, Connection, OptionalExtension};
8use serde::{Deserialize, Serialize};
9use serde_json::{json, Value};
10use sha2::{Digest, Sha256};
11use uuid::Uuid;
12
13use crate::core::{Result, ShuttleError};
14
/// The only OAuth scope this server advertises; access tokens must carry it
/// to be accepted for the MCP endpoint.
const MCP_SCOPE: &str = "mcp";
16
/// Deployment settings used to build OAuth endpoint and metadata URLs.
///
/// This type does not derive `Debug`; keep it that way (or redact the
/// field by hand) so `admin_token` is never exposed through `{:?}`.
#[derive(Clone)]
pub struct OAuthConfig {
    /// Externally reachable base URL. Endpoint paths are appended with a
    /// leading `/`, so it is expected to be normalised with
    /// `OAuthConfig::normalize_public_url` (no trailing slash).
    pub public_url: String,
    /// Optional administrator secret; presumably checked by the authorize
    /// handler before a code is issued. TODO confirm against the handlers.
    pub admin_token: Option<String>,
}
27
28impl OAuthConfig {
29 pub fn normalize_public_url(public_url: String) -> String {
30 public_url.trim().trim_end_matches('/').to_owned()
31 }
32
33 pub fn resource_url(&self) -> String {
34 format!("{}/mcp", self.public_url)
35 }
36}
37
38#[derive(Clone)]
39pub struct OAuthStore {
40 conn: Arc<Mutex<Connection>>,
41}
42
43impl OAuthStore {
44 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
45 let conn = Connection::open(path).map_err(to_store_error)?;
46 let store = Self {
47 conn: Arc::new(Mutex::new(conn)),
48 };
49 store.init()?;
50 Ok(store)
51 }
52
53 fn init(&self) -> Result<()> {
54 let conn = self
55 .conn
56 .lock()
57 .map_err(|err| ShuttleError::Store(err.to_string()))?;
58 conn.execute_batch(
59 r#"
60 CREATE TABLE IF NOT EXISTS oauth_clients (
61 client_id TEXT PRIMARY KEY NOT NULL,
62 client_secret TEXT,
63 redirect_uris TEXT NOT NULL,
64 client_name TEXT,
65 created_at TEXT NOT NULL
66 );
67
68 CREATE TABLE IF NOT EXISTS oauth_codes (
69 code TEXT PRIMARY KEY NOT NULL,
70 client_id TEXT NOT NULL,
71 redirect_uri TEXT NOT NULL,
72 code_challenge TEXT NOT NULL,
73 code_challenge_method TEXT NOT NULL,
74 scope TEXT NOT NULL,
75 expires_at TEXT NOT NULL,
76 used_at TEXT,
77 created_at TEXT NOT NULL
78 );
79
80 CREATE TABLE IF NOT EXISTS oauth_tokens (
81 token TEXT PRIMARY KEY NOT NULL,
82 client_id TEXT NOT NULL,
83 scope TEXT NOT NULL,
84 expires_at TEXT NOT NULL,
85 created_at TEXT NOT NULL
86 );
87 "#,
88 )
89 .map_err(to_store_error)?;
90 purge_expired(&conn)?;
91 Ok(())
92 }
93
94 pub fn register_client(&self, request: RegisterRequest) -> Result<RegisteredClient> {
95 if request.redirect_uris.is_empty() {
96 return Err(ShuttleError::Store(
97 "redirect_uris must contain at least one URI".to_owned(),
98 ));
99 }
100 let client = RegisteredClient {
101 client_id: token(),
102 client_secret: None,
103 redirect_uris: request.redirect_uris,
104 client_name: request.client_name,
105 };
106 let conn = self
107 .conn
108 .lock()
109 .map_err(|err| ShuttleError::Store(err.to_string()))?;
110 conn.execute(
111 "INSERT INTO oauth_clients (client_id, client_secret, redirect_uris, client_name, created_at)
112 VALUES (?1, ?2, ?3, ?4, ?5)",
113 params![
114 client.client_id,
115 client.client_secret,
116 serde_json::to_string(&client.redirect_uris)
117 .map_err(|err| ShuttleError::Serialization(err.to_string()))?,
118 client.client_name,
119 Utc::now().to_rfc3339()
120 ],
121 )
122 .map_err(to_store_error)?;
123 Ok(client)
124 }
125
126 pub fn client_allows_redirect(&self, client_id: &str, redirect_uri: &str) -> Result<bool> {
127 let conn = self
128 .conn
129 .lock()
130 .map_err(|err| ShuttleError::Store(err.to_string()))?;
131 let redirect_uris = conn
132 .query_row(
133 "SELECT redirect_uris FROM oauth_clients WHERE client_id = ?1",
134 params![client_id],
135 |row| row.get::<_, String>(0),
136 )
137 .optional()
138 .map_err(to_store_error)?;
139 let Some(redirect_uris) = redirect_uris else {
140 return Ok(false);
141 };
142 let redirect_uris: Vec<String> = serde_json::from_str(&redirect_uris)
143 .map_err(|err| ShuttleError::Serialization(err.to_string()))?;
144 Ok(redirect_uris.iter().any(|uri| uri == redirect_uri))
145 }
146
147 pub fn create_code(&self, request: AuthorizeRequest) -> Result<String> {
148 if request.response_type != "code" {
149 return Err(ShuttleError::Store("response_type must be code".to_owned()));
150 }
151 if !self.client_allows_redirect(&request.client_id, &request.redirect_uri)? {
152 return Err(ShuttleError::Store(
153 "unknown client_id or redirect_uri".to_owned(),
154 ));
155 }
156 if request.code_challenge_method.as_deref() != Some("S256") {
157 return Err(ShuttleError::Store(
158 "code_challenge_method must be S256".to_owned(),
159 ));
160 }
161 let Some(code_challenge) = request.code_challenge else {
162 return Err(ShuttleError::Store("missing code_challenge".to_owned()));
163 };
164 let scope = normalize_scope(request.scope);
165 let code = token();
166 let now = Utc::now();
167 let conn = self
168 .conn
169 .lock()
170 .map_err(|err| ShuttleError::Store(err.to_string()))?;
171 conn.execute(
172 "INSERT INTO oauth_codes (
173 code, client_id, redirect_uri, code_challenge, code_challenge_method,
174 scope, expires_at, created_at
175 ) VALUES (?1, ?2, ?3, ?4, 'S256', ?5, ?6, ?7)",
176 params![
177 code,
178 request.client_id,
179 request.redirect_uri,
180 code_challenge,
181 scope,
182 (now + Duration::minutes(10)).to_rfc3339(),
183 now.to_rfc3339()
184 ],
185 )
186 .map_err(to_store_error)?;
187 Ok(code)
188 }
189
190 pub fn exchange_code(&self, request: TokenRequest) -> Result<TokenResponse> {
191 if request.grant_type != "authorization_code" {
192 return Err(ShuttleError::Store(
193 "grant_type must be authorization_code".to_owned(),
194 ));
195 }
196 let code = request
197 .code
198 .ok_or_else(|| ShuttleError::Store("missing code".to_owned()))?;
199 let verifier = request
200 .code_verifier
201 .ok_or_else(|| ShuttleError::Store("missing code_verifier".to_owned()))?;
202 let mut conn = self
203 .conn
204 .lock()
205 .map_err(|err| ShuttleError::Store(err.to_string()))?;
206 let tx = conn.transaction().map_err(to_store_error)?;
207 let stored = tx
208 .query_row(
209 "SELECT client_id, redirect_uri, code_challenge, scope, expires_at
210 FROM oauth_codes WHERE code = ?1 AND used_at IS NULL",
211 params![code],
212 |row| {
213 Ok(StoredCode {
214 client_id: row.get(0)?,
215 redirect_uri: row.get(1)?,
216 code_challenge: row.get(2)?,
217 scope: row.get(3)?,
218 expires_at: row.get(4)?,
219 })
220 },
221 )
222 .optional()
223 .map_err(to_store_error)?;
224 let Some(stored) = stored else {
225 let exists = tx
226 .query_row(
227 "SELECT 1 FROM oauth_codes WHERE code = ?1",
228 params![code],
229 |_| Ok(()),
230 )
231 .optional()
232 .map_err(to_store_error)?
233 .is_some();
234 return Err(ShuttleError::Store(if exists {
235 "code already used".to_owned()
236 } else {
237 "invalid code".to_owned()
238 }));
239 };
240
241 if stored.client_id != request.client_id {
242 return Err(ShuttleError::Store("invalid client_id".to_owned()));
243 }
244 if stored.redirect_uri != request.redirect_uri {
245 return Err(ShuttleError::Store("invalid redirect_uri".to_owned()));
246 }
247 if parse_time(&stored.expires_at)? < Utc::now() {
248 return Err(ShuttleError::Store("code expired".to_owned()));
249 }
250 if pkce_s256(&verifier) != stored.code_challenge {
251 return Err(ShuttleError::Store("invalid code_verifier".to_owned()));
252 }
253
254 tx.execute(
255 "UPDATE oauth_codes SET used_at = ?1 WHERE code = ?2",
256 params![Utc::now().to_rfc3339(), code],
257 )
258 .map_err(to_store_error)?;
259 let token = create_token(&tx, &stored.client_id, &stored.scope)?;
260 tx.commit().map_err(to_store_error)?;
261 Ok(token)
262 }
263
264 pub fn validate_access_token(&self, bearer_token: &str) -> Result<bool> {
265 let conn = self
266 .conn
267 .lock()
268 .map_err(|err| ShuttleError::Store(err.to_string()))?;
269 let row = conn
270 .query_row(
271 "SELECT scope, expires_at FROM oauth_tokens WHERE token = ?1",
272 params![bearer_token],
273 |row| Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)),
274 )
275 .optional()
276 .map_err(to_store_error)?;
277 let Some((scope, expires_at)) = row else {
278 return Ok(false);
279 };
280 Ok(scope.split_whitespace().any(|scope| scope == MCP_SCOPE)
281 && parse_time(&expires_at)? > Utc::now())
282 }
283}
284
285#[derive(Debug, Deserialize)]
286pub struct RegisterRequest {
287 #[serde(default)]
288 pub redirect_uris: Vec<String>,
289 pub client_name: Option<String>,
290}
291
292#[derive(Debug, Serialize)]
293pub struct RegisteredClient {
294 pub client_id: String,
295 pub client_secret: Option<String>,
296 pub redirect_uris: Vec<String>,
297 pub client_name: Option<String>,
298}
299
300#[derive(Debug, Deserialize, Clone)]
301pub struct AuthorizeRequest {
302 pub response_type: String,
303 pub client_id: String,
304 pub redirect_uri: String,
305 pub state: Option<String>,
306 pub scope: Option<String>,
307 pub code_challenge: Option<String>,
308 pub code_challenge_method: Option<String>,
309}
310
311#[derive(Debug, Deserialize)]
312pub struct AuthorizeForm {
313 pub admin_token: String,
314 pub response_type: String,
315 pub client_id: String,
316 pub redirect_uri: String,
317 pub state: Option<String>,
318 pub scope: Option<String>,
319 pub code_challenge: Option<String>,
320 pub code_challenge_method: Option<String>,
321}
322
323impl From<AuthorizeForm> for AuthorizeRequest {
324 fn from(form: AuthorizeForm) -> Self {
325 Self {
326 response_type: form.response_type,
327 client_id: form.client_id,
328 redirect_uri: form.redirect_uri,
329 state: form.state,
330 scope: form.scope,
331 code_challenge: form.code_challenge,
332 code_challenge_method: form.code_challenge_method,
333 }
334 }
335}
336
337#[derive(Debug, Clone, Deserialize)]
338pub struct TokenRequest {
339 pub grant_type: String,
340 pub client_id: String,
341 pub redirect_uri: String,
342 pub code: Option<String>,
343 pub code_verifier: Option<String>,
344}
345
346#[derive(Debug, Serialize)]
347pub struct TokenResponse {
348 pub access_token: String,
349 pub token_type: &'static str,
350 pub expires_in: i64,
351 pub scope: String,
352}
353
354pub fn authorization_server_metadata(config: &OAuthConfig) -> Value {
355 json!({
356 "issuer": config.public_url,
357 "authorization_endpoint": format!("{}/oauth/authorize", config.public_url),
358 "token_endpoint": format!("{}/oauth/token", config.public_url),
359 "registration_endpoint": format!("{}/oauth/register", config.public_url),
360 "response_types_supported": ["code"],
361 "grant_types_supported": ["authorization_code"],
362 "code_challenge_methods_supported": ["S256"],
363 "token_endpoint_auth_methods_supported": ["none"],
364 "scopes_supported": [MCP_SCOPE],
365 })
366}
367
368pub fn protected_resource_metadata(config: &OAuthConfig) -> Value {
369 json!({
370 "resource": config.resource_url(),
371 "authorization_servers": [config.public_url],
372 "scopes_supported": [MCP_SCOPE],
373 "bearer_methods_supported": ["header"],
374 })
375}
376
377pub fn authorize_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String {
383 let mut target = format!(
384 "{}{}code={}",
385 redirect_uri,
386 if redirect_uri.contains('?') { "&" } else { "?" },
387 query_component(code)
388 );
389 if let Some(state) = state {
390 target.push_str("&state=");
391 target.push_str(&query_component(state));
392 }
393 target
394}
395
/// Percent-encodes `value` for use as a URL query component.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - . _ ~`) pass through
/// unchanged. Every other byte of the UTF-8 encoding becomes `%XX` with
/// uppercase hex digits.
fn query_component(value: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    value
        .bytes()
        .fold(String::with_capacity(value.len()), |mut out, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            out
        })
}
408
409fn create_token(conn: &Connection, client_id: &str, scope: &str) -> Result<TokenResponse> {
410 let access_token = token();
411 let now = Utc::now();
412 let expires_in = 3600;
413 conn.execute(
414 "INSERT INTO oauth_tokens (token, client_id, scope, expires_at, created_at)
415 VALUES (?1, ?2, ?3, ?4, ?5)",
416 params![
417 access_token,
418 client_id,
419 scope,
420 (now + Duration::seconds(expires_in)).to_rfc3339(),
421 now.to_rfc3339()
422 ],
423 )
424 .map_err(to_store_error)?;
425 Ok(TokenResponse {
426 access_token,
427 token_type: "Bearer",
428 expires_in,
429 scope: scope.to_owned(),
430 })
431}
432
433fn normalize_scope(scope: Option<String>) -> String {
434 let scope = scope.unwrap_or_else(|| MCP_SCOPE.to_owned());
435 if scope.split_whitespace().any(|scope| scope == MCP_SCOPE) {
436 scope
437 } else {
438 MCP_SCOPE.to_owned()
439 }
440}
441
442fn token() -> String {
443 format!("stl_{}", Uuid::new_v4().simple())
444}
445
446fn pkce_s256(verifier: &str) -> String {
447 let digest = Sha256::digest(verifier.as_bytes());
448 URL_SAFE_NO_PAD.encode(digest)
449}
450
451fn parse_time(value: &str) -> Result<DateTime<Utc>> {
452 DateTime::parse_from_rfc3339(value)
453 .map(|time| time.with_timezone(&Utc))
454 .map_err(|err| ShuttleError::Store(err.to_string()))
455}
456
457fn to_store_error(err: rusqlite::Error) -> ShuttleError {
458 ShuttleError::Store(err.to_string())
459}
460
461fn purge_expired(conn: &Connection) -> Result<()> {
462 let now = Utc::now().to_rfc3339();
463 conn.execute(
464 "DELETE FROM oauth_codes WHERE expires_at < ?1 OR used_at IS NOT NULL",
465 params![now],
466 )
467 .map_err(to_store_error)?;
468 conn.execute(
469 "DELETE FROM oauth_tokens WHERE expires_at < ?1",
470 params![now],
471 )
472 .map_err(to_store_error)?;
473 Ok(())
474}
475
/// An unredeemed authorization code row, as loaded by `exchange_code`.
struct StoredCode {
    /// Client the code was issued to.
    client_id: String,
    /// Redirect URI used in the authorization request.
    redirect_uri: String,
    /// PKCE S256 challenge to verify the `code_verifier` against.
    code_challenge: String,
    /// Scope to grant to the resulting access token.
    scope: String,
    /// RFC 3339 expiry timestamp.
    expires_at: String,
}
483
#[cfg(test)]
mod tests {
    use super::*;

    const REDIRECT_URI: &str = "https://client.example.test/callback";
    const VERIFIER: &str = "abc123abc123abc123abc123abc123abc123abc123abc123";

    /// Opens a store in a fresh temporary directory and registers one client
    /// for `REDIRECT_URI`.
    ///
    /// The directory guard is returned so the database file lives for as
    /// long as the caller keeps it bound (bind it to `_dir`, not `_`).
    fn store_with_client() -> (tempfile::TempDir, OAuthStore, RegisteredClient) {
        let dir = tempfile::tempdir().unwrap();
        let store = OAuthStore::open(dir.path().join("shuttle.db")).unwrap();
        let client = store
            .register_client(RegisterRequest {
                redirect_uris: vec![REDIRECT_URI.to_owned()],
                client_name: Some("client".to_owned()),
            })
            .unwrap();
        (dir, store, client)
    }

    /// A PKCE (S256) authorization request for `client_id` derived from
    /// `VERIFIER`.
    fn authorize_request(client_id: &str, response_type: &str) -> AuthorizeRequest {
        AuthorizeRequest {
            response_type: response_type.to_owned(),
            client_id: client_id.to_owned(),
            redirect_uri: REDIRECT_URI.to_owned(),
            state: None,
            scope: Some("mcp".to_owned()),
            code_challenge: Some(pkce_s256(VERIFIER)),
            code_challenge_method: Some("S256".to_owned()),
        }
    }

    /// A token request redeeming `code` with `VERIFIER`.
    fn token_request(grant_type: &str, client_id: String, code: String) -> TokenRequest {
        TokenRequest {
            grant_type: grant_type.to_owned(),
            client_id,
            redirect_uri: REDIRECT_URI.to_owned(),
            code: Some(code),
            code_verifier: Some(VERIFIER.to_owned()),
        }
    }

    #[test]
    fn metadata_uses_public_url() {
        let config = OAuthConfig {
            public_url: "https://shuttle.example.test".to_owned(),
            admin_token: None,
        };

        let resource = protected_resource_metadata(&config);
        let server = authorization_server_metadata(&config);

        assert_eq!(resource["resource"], "https://shuttle.example.test/mcp");
        assert_eq!(
            server["token_endpoint"],
            "https://shuttle.example.test/oauth/token"
        );
    }

    #[test]
    fn authorize_redirect_encodes_state_as_query_component() {
        let url = authorize_redirect(
            "https://claude.ai/api/mcp/auth_callback",
            "stl_abc123",
            Some("opaque=value+with/special&fragment#part"),
        );
        assert_eq!(
            url,
            "https://claude.ai/api/mcp/auth_callback?code=stl_abc123&state=opaque%3Dvalue%2Bwith%2Fspecial%26fragment%23part"
        );
    }

    #[test]
    fn authorize_redirect_omits_state_when_absent() {
        let url = authorize_redirect(
            "https://claude.ai/api/mcp/auth_callback",
            "stl_abc123",
            None,
        );
        assert_eq!(
            url,
            "https://claude.ai/api/mcp/auth_callback?code=stl_abc123"
        );
        assert!(!url.contains("state="));
    }

    #[test]
    fn code_exchange_validates_pkce() {
        let (_dir, store, client) = store_with_client();
        let code = store
            .create_code(authorize_request(&client.client_id, "code"))
            .unwrap();

        let token = store
            .exchange_code(token_request("authorization_code", client.client_id, code))
            .unwrap();

        assert!(store.validate_access_token(&token.access_token).unwrap());
    }

    #[test]
    fn code_exchange_rejects_reused_code() {
        let (_dir, store, client) = store_with_client();
        let code = store
            .create_code(authorize_request(&client.client_id, "code"))
            .unwrap();
        let request = token_request("authorization_code", client.client_id, code);

        store.exchange_code(request.clone()).unwrap();
        let err = store.exchange_code(request).unwrap_err();

        assert!(err.to_string().contains("code already used"));
    }

    #[test]
    fn store_validates_oauth_grant_shape() {
        let (_dir, store, client) = store_with_client();

        let authorize_err = store
            .create_code(authorize_request(&client.client_id, "token"))
            .unwrap_err();
        assert!(authorize_err
            .to_string()
            .contains("response_type must be code"));

        let token_err = store
            .exchange_code(token_request(
                "refresh_token",
                client.client_id,
                "stl_missing".to_owned(),
            ))
            .unwrap_err();
        assert!(token_err
            .to_string()
            .contains("grant_type must be authorization_code"));
    }
}
644}