//! OAuth 2.0 authorization-code (PKCE) support for the MCP endpoint
//! (`shuttle_rs/oauth.rs`).

1use std::path::Path;
2use std::sync::{Arc, Mutex};
3
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use base64::Engine;
6use chrono::{DateTime, Duration, Utc};
7use rusqlite::{params, Connection, OptionalExtension};
8use serde::{Deserialize, Serialize};
9use serde_json::{json, Value};
10use sha2::{Digest, Sha256};
11use uuid::Uuid;
12
13use crate::core::{Result, ShuttleError};
14
/// The only OAuth scope this server advertises in `scopes_supported`, and the
/// scope an access token must carry for `OAuthStore::validate_access_token`
/// to accept it.
const MCP_SCOPE: &str = "mcp";
16
/// Public OAuth settings used to build the metadata documents and to gate
/// authorization-code issuance.
#[derive(Clone)]
pub struct OAuthConfig {
    /// Externally reachable base URL of this server. Endpoint URLs are built by
    /// appending paths to it, so it is expected to have been passed through
    /// [`OAuthConfig::normalize_public_url`] (no trailing `/`).
    pub public_url: String,
    /// Owner-approval token for authorization-code issuance.
    ///
    /// CLI public URL mode requires this to be `Some`; `None` is reserved for
    /// programmatic or local-only runtimes that intentionally skip owner
    /// approval.
    pub admin_token: Option<String>,
}

impl OAuthConfig {
    /// Strips surrounding whitespace and then every trailing `/` from
    /// `public_url`, so that paths can be appended with a single `/`.
    pub fn normalize_public_url(public_url: String) -> String {
        let without_whitespace = public_url.trim();
        let without_slashes = without_whitespace.trim_end_matches('/');
        String::from(without_slashes)
    }

    /// URL of the protected MCP resource: `public_url` followed by `/mcp`.
    pub fn resource_url(&self) -> String {
        let mut url = self.public_url.clone();
        url.push_str("/mcp");
        url
    }
}
37
38#[derive(Clone)]
39pub struct OAuthStore {
40    conn: Arc<Mutex<Connection>>,
41}
42
43impl OAuthStore {
44    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
45        let conn = Connection::open(path).map_err(to_store_error)?;
46        let store = Self {
47            conn: Arc::new(Mutex::new(conn)),
48        };
49        store.init()?;
50        Ok(store)
51    }
52
53    fn init(&self) -> Result<()> {
54        let conn = self
55            .conn
56            .lock()
57            .map_err(|err| ShuttleError::Store(err.to_string()))?;
58        conn.execute_batch(
59            r#"
60            CREATE TABLE IF NOT EXISTS oauth_clients (
61                client_id TEXT PRIMARY KEY NOT NULL,
62                client_secret TEXT,
63                redirect_uris TEXT NOT NULL,
64                client_name TEXT,
65                created_at TEXT NOT NULL
66            );
67
68            CREATE TABLE IF NOT EXISTS oauth_codes (
69                code TEXT PRIMARY KEY NOT NULL,
70                client_id TEXT NOT NULL,
71                redirect_uri TEXT NOT NULL,
72                code_challenge TEXT NOT NULL,
73                code_challenge_method TEXT NOT NULL,
74                scope TEXT NOT NULL,
75                expires_at TEXT NOT NULL,
76                used_at TEXT,
77                created_at TEXT NOT NULL
78            );
79
80            CREATE TABLE IF NOT EXISTS oauth_tokens (
81                token TEXT PRIMARY KEY NOT NULL,
82                client_id TEXT NOT NULL,
83                scope TEXT NOT NULL,
84                expires_at TEXT NOT NULL,
85                created_at TEXT NOT NULL
86            );
87            "#,
88        )
89        .map_err(to_store_error)?;
90        purge_expired(&conn)?;
91        Ok(())
92    }
93
94    pub fn register_client(&self, request: RegisterRequest) -> Result<RegisteredClient> {
95        if request.redirect_uris.is_empty() {
96            return Err(ShuttleError::Store(
97                "redirect_uris must contain at least one URI".to_owned(),
98            ));
99        }
100        let client = RegisteredClient {
101            client_id: token(),
102            client_secret: None,
103            redirect_uris: request.redirect_uris,
104            client_name: request.client_name,
105        };
106        let conn = self
107            .conn
108            .lock()
109            .map_err(|err| ShuttleError::Store(err.to_string()))?;
110        conn.execute(
111            "INSERT INTO oauth_clients (client_id, client_secret, redirect_uris, client_name, created_at)
112             VALUES (?1, ?2, ?3, ?4, ?5)",
113            params![
114                client.client_id,
115                client.client_secret,
116                serde_json::to_string(&client.redirect_uris)
117                    .map_err(|err| ShuttleError::Serialization(err.to_string()))?,
118                client.client_name,
119                Utc::now().to_rfc3339()
120            ],
121        )
122        .map_err(to_store_error)?;
123        Ok(client)
124    }
125
126    pub fn client_allows_redirect(&self, client_id: &str, redirect_uri: &str) -> Result<bool> {
127        let conn = self
128            .conn
129            .lock()
130            .map_err(|err| ShuttleError::Store(err.to_string()))?;
131        let redirect_uris = conn
132            .query_row(
133                "SELECT redirect_uris FROM oauth_clients WHERE client_id = ?1",
134                params![client_id],
135                |row| row.get::<_, String>(0),
136            )
137            .optional()
138            .map_err(to_store_error)?;
139        let Some(redirect_uris) = redirect_uris else {
140            return Ok(false);
141        };
142        let redirect_uris: Vec<String> = serde_json::from_str(&redirect_uris)
143            .map_err(|err| ShuttleError::Serialization(err.to_string()))?;
144        Ok(redirect_uris.iter().any(|uri| uri == redirect_uri))
145    }
146
147    pub fn create_code(&self, request: AuthorizeRequest) -> Result<String> {
148        if request.response_type != "code" {
149            return Err(ShuttleError::Store("response_type must be code".to_owned()));
150        }
151        if !self.client_allows_redirect(&request.client_id, &request.redirect_uri)? {
152            return Err(ShuttleError::Store(
153                "unknown client_id or redirect_uri".to_owned(),
154            ));
155        }
156        if request.code_challenge_method.as_deref() != Some("S256") {
157            return Err(ShuttleError::Store(
158                "code_challenge_method must be S256".to_owned(),
159            ));
160        }
161        let Some(code_challenge) = request.code_challenge else {
162            return Err(ShuttleError::Store("missing code_challenge".to_owned()));
163        };
164        let scope = normalize_scope(request.scope);
165        let code = token();
166        let now = Utc::now();
167        let conn = self
168            .conn
169            .lock()
170            .map_err(|err| ShuttleError::Store(err.to_string()))?;
171        conn.execute(
172            "INSERT INTO oauth_codes (
173                code, client_id, redirect_uri, code_challenge, code_challenge_method,
174                scope, expires_at, created_at
175             ) VALUES (?1, ?2, ?3, ?4, 'S256', ?5, ?6, ?7)",
176            params![
177                code,
178                request.client_id,
179                request.redirect_uri,
180                code_challenge,
181                scope,
182                (now + Duration::minutes(10)).to_rfc3339(),
183                now.to_rfc3339()
184            ],
185        )
186        .map_err(to_store_error)?;
187        Ok(code)
188    }
189
190    pub fn exchange_code(&self, request: TokenRequest) -> Result<TokenResponse> {
191        if request.grant_type != "authorization_code" {
192            return Err(ShuttleError::Store(
193                "grant_type must be authorization_code".to_owned(),
194            ));
195        }
196        let code = request
197            .code
198            .ok_or_else(|| ShuttleError::Store("missing code".to_owned()))?;
199        let verifier = request
200            .code_verifier
201            .ok_or_else(|| ShuttleError::Store("missing code_verifier".to_owned()))?;
202        let mut conn = self
203            .conn
204            .lock()
205            .map_err(|err| ShuttleError::Store(err.to_string()))?;
206        let tx = conn.transaction().map_err(to_store_error)?;
207        let stored = tx
208            .query_row(
209                "SELECT client_id, redirect_uri, code_challenge, scope, expires_at
210                 FROM oauth_codes WHERE code = ?1 AND used_at IS NULL",
211                params![code],
212                |row| {
213                    Ok(StoredCode {
214                        client_id: row.get(0)?,
215                        redirect_uri: row.get(1)?,
216                        code_challenge: row.get(2)?,
217                        scope: row.get(3)?,
218                        expires_at: row.get(4)?,
219                    })
220                },
221            )
222            .optional()
223            .map_err(to_store_error)?;
224        let Some(stored) = stored else {
225            let exists = tx
226                .query_row(
227                    "SELECT 1 FROM oauth_codes WHERE code = ?1",
228                    params![code],
229                    |_| Ok(()),
230                )
231                .optional()
232                .map_err(to_store_error)?
233                .is_some();
234            return Err(ShuttleError::Store(if exists {
235                "code already used".to_owned()
236            } else {
237                "invalid code".to_owned()
238            }));
239        };
240
241        if stored.client_id != request.client_id {
242            return Err(ShuttleError::Store("invalid client_id".to_owned()));
243        }
244        if stored.redirect_uri != request.redirect_uri {
245            return Err(ShuttleError::Store("invalid redirect_uri".to_owned()));
246        }
247        if parse_time(&stored.expires_at)? < Utc::now() {
248            return Err(ShuttleError::Store("code expired".to_owned()));
249        }
250        if pkce_s256(&verifier) != stored.code_challenge {
251            return Err(ShuttleError::Store("invalid code_verifier".to_owned()));
252        }
253
254        tx.execute(
255            "UPDATE oauth_codes SET used_at = ?1 WHERE code = ?2",
256            params![Utc::now().to_rfc3339(), code],
257        )
258        .map_err(to_store_error)?;
259        let token = create_token(&tx, &stored.client_id, &stored.scope)?;
260        tx.commit().map_err(to_store_error)?;
261        Ok(token)
262    }
263
264    pub fn validate_access_token(&self, bearer_token: &str) -> Result<bool> {
265        let conn = self
266            .conn
267            .lock()
268            .map_err(|err| ShuttleError::Store(err.to_string()))?;
269        let row = conn
270            .query_row(
271                "SELECT scope, expires_at FROM oauth_tokens WHERE token = ?1",
272                params![bearer_token],
273                |row| Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)),
274            )
275            .optional()
276            .map_err(to_store_error)?;
277        let Some((scope, expires_at)) = row else {
278            return Ok(false);
279        };
280        Ok(scope.split_whitespace().any(|scope| scope == MCP_SCOPE)
281            && parse_time(&expires_at)? > Utc::now())
282    }
283}
284
/// Body of a dynamic client registration request.
///
/// Only the members below are read. Other members in the JSON are ignored,
/// because the struct does not use `deny_unknown_fields`.
#[derive(Debug, Deserialize)]
pub struct RegisterRequest {
    /// Redirect URIs the client may use. A missing member deserializes as an
    /// empty list, which `OAuthStore::register_client` rejects.
    #[serde(default)]
    pub redirect_uris: Vec<String>,
    /// Optional human-readable client name, stored and echoed back as-is.
    pub client_name: Option<String>,
}
291
292#[derive(Debug, Serialize)]
293pub struct RegisteredClient {
294    pub client_id: String,
295    pub client_secret: Option<String>,
296    pub redirect_uris: Vec<String>,
297    pub client_name: Option<String>,
298}
299
/// Parameters of an authorization request: the authorization-code grant
/// plus the PKCE `code_challenge` / `code_challenge_method` extension.
#[derive(Debug, Deserialize, Clone)]
pub struct AuthorizeRequest {
    /// Must be `"code"`; `OAuthStore::create_code` rejects anything else.
    pub response_type: String,
    /// Registered client identifier.
    pub client_id: String,
    /// Must exactly equal one of the client's registered redirect URIs.
    pub redirect_uri: String,
    /// Opaque client value. Not stored by `OAuthStore`; presumably echoed back
    /// through `authorize_redirect` by the HTTP handler (not visible here).
    pub state: Option<String>,
    /// Requested space-separated scopes, normalized by `normalize_scope`.
    pub scope: Option<String>,
    /// PKCE challenge; `create_code` rejects the request when it is missing.
    pub code_challenge: Option<String>,
    /// PKCE method; `create_code` accepts only `"S256"`.
    pub code_challenge_method: Option<String>,
}
310
311#[derive(Debug, Deserialize)]
312pub struct AuthorizeForm {
313    pub admin_token: String,
314    pub response_type: String,
315    pub client_id: String,
316    pub redirect_uri: String,
317    pub state: Option<String>,
318    pub scope: Option<String>,
319    pub code_challenge: Option<String>,
320    pub code_challenge_method: Option<String>,
321}
322
323impl From<AuthorizeForm> for AuthorizeRequest {
324    fn from(form: AuthorizeForm) -> Self {
325        Self {
326            response_type: form.response_type,
327            client_id: form.client_id,
328            redirect_uri: form.redirect_uri,
329            state: form.state,
330            scope: form.scope,
331            code_challenge: form.code_challenge,
332            code_challenge_method: form.code_challenge_method,
333        }
334    }
335}
336
/// Token endpoint request for the authorization-code grant, including the
/// PKCE `code_verifier`.
///
/// NOTE(review): the derived `Debug` prints `code` and `code_verifier`; avoid
/// logging values of this type.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenRequest {
    /// Must be `"authorization_code"`; other grants are rejected.
    pub grant_type: String,
    /// Must equal the client the code was issued to.
    pub client_id: String,
    /// Must equal the redirect URI used in the authorization request.
    pub redirect_uri: String,
    /// Authorization code to redeem; `exchange_code` requires it.
    pub code: Option<String>,
    /// PKCE verifier whose S256 hash must match the stored challenge;
    /// `exchange_code` requires it.
    pub code_verifier: Option<String>,
}
345
/// Successful token endpoint response body.
///
/// NOTE(review): the derived `Debug` prints `access_token`; avoid logging it.
#[derive(Debug, Serialize)]
pub struct TokenResponse {
    /// Opaque bearer token generated by `token()`.
    pub access_token: String,
    /// Always `"Bearer"` as built by `create_token`.
    pub token_type: &'static str,
    /// Token lifetime in seconds.
    pub expires_in: i64,
    /// Granted space-separated scope string.
    pub scope: String,
}
353
354pub fn authorization_server_metadata(config: &OAuthConfig) -> Value {
355    json!({
356        "issuer": config.public_url,
357        "authorization_endpoint": format!("{}/oauth/authorize", config.public_url),
358        "token_endpoint": format!("{}/oauth/token", config.public_url),
359        "registration_endpoint": format!("{}/oauth/register", config.public_url),
360        "response_types_supported": ["code"],
361        "grant_types_supported": ["authorization_code"],
362        "code_challenge_methods_supported": ["S256"],
363        "token_endpoint_auth_methods_supported": ["none"],
364        "scopes_supported": [MCP_SCOPE],
365    })
366}
367
368pub fn protected_resource_metadata(config: &OAuthConfig) -> Value {
369    json!({
370        "resource": config.resource_url(),
371        "authorization_servers": [config.public_url],
372        "scopes_supported": [MCP_SCOPE],
373        "bearer_methods_supported": ["header"],
374    })
375}
376
/// Build the OAuth 2.0 authorization-code redirect URL (RFC 6749 §4.1.2).
///
/// `code` and `state` are appended as query parameters and percent-encoded,
/// so reserved characters in opaque client state cannot change the query
/// structure.
///
/// - An existing query on `redirect_uri` is retained, as RFC 6749 §3.1.2
///   requires.
/// - A trailing `?` or `&` does not produce an empty parameter.
/// - If `redirect_uri` carries a fragment, the new parameters are inserted
///   before it, so they stay in the query component that reaches the client.
pub fn authorize_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String {
    // Parameters placed after `#` would be kept client-side by user agents
    // and never delivered to the redirection endpoint.
    let (base, fragment) = match redirect_uri.find('#') {
        Some(idx) => redirect_uri.split_at(idx),
        None => (redirect_uri, ""),
    };
    let separator = if !base.contains('?') {
        "?"
    } else if base.ends_with('?') || base.ends_with('&') {
        ""
    } else {
        "&"
    };
    let mut target = format!("{base}{separator}code={}", query_component(code));
    if let Some(state) = state {
        target.push_str("&state=");
        target.push_str(&query_component(state));
    }
    target.push_str(fragment);
    target
}

/// Percent-encodes `value` for use as a URI query component.
///
/// Every byte outside the RFC 3986 unreserved set
/// (`ALPHA / DIGIT / "-" / "." / "_" / "~"`) becomes `%XX` with uppercase hex
/// digits. Multi-byte UTF-8 sequences are therefore encoded byte by byte.
fn query_component(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
408
409fn create_token(conn: &Connection, client_id: &str, scope: &str) -> Result<TokenResponse> {
410    let access_token = token();
411    let now = Utc::now();
412    let expires_in = 3600;
413    conn.execute(
414        "INSERT INTO oauth_tokens (token, client_id, scope, expires_at, created_at)
415         VALUES (?1, ?2, ?3, ?4, ?5)",
416        params![
417            access_token,
418            client_id,
419            scope,
420            (now + Duration::seconds(expires_in)).to_rfc3339(),
421            now.to_rfc3339()
422        ],
423    )
424    .map_err(to_store_error)?;
425    Ok(TokenResponse {
426        access_token,
427        token_type: "Bearer",
428        expires_in,
429        scope: scope.to_owned(),
430    })
431}
432
433fn normalize_scope(scope: Option<String>) -> String {
434    let scope = scope.unwrap_or_else(|| MCP_SCOPE.to_owned());
435    if scope.split_whitespace().any(|scope| scope == MCP_SCOPE) {
436        scope
437    } else {
438        MCP_SCOPE.to_owned()
439    }
440}
441
442fn token() -> String {
443    format!("stl_{}", Uuid::new_v4().simple())
444}
445
446fn pkce_s256(verifier: &str) -> String {
447    let digest = Sha256::digest(verifier.as_bytes());
448    URL_SAFE_NO_PAD.encode(digest)
449}
450
451fn parse_time(value: &str) -> Result<DateTime<Utc>> {
452    DateTime::parse_from_rfc3339(value)
453        .map(|time| time.with_timezone(&Utc))
454        .map_err(|err| ShuttleError::Store(err.to_string()))
455}
456
457fn to_store_error(err: rusqlite::Error) -> ShuttleError {
458    ShuttleError::Store(err.to_string())
459}
460
461fn purge_expired(conn: &Connection) -> Result<()> {
462    let now = Utc::now().to_rfc3339();
463    conn.execute(
464        "DELETE FROM oauth_codes WHERE expires_at < ?1 OR used_at IS NOT NULL",
465        params![now],
466    )
467    .map_err(to_store_error)?;
468    conn.execute(
469        "DELETE FROM oauth_tokens WHERE expires_at < ?1",
470        params![now],
471    )
472    .map_err(to_store_error)?;
473    Ok(())
474}
475
/// Unused authorization-code row loaded by `OAuthStore::exchange_code`.
/// Timestamps are the RFC 3339 strings stored in SQLite.
struct StoredCode {
    /// Client the code was issued to.
    client_id: String,
    /// Redirect URI from the authorization request; must match at exchange.
    redirect_uri: String,
    /// S256 PKCE challenge, compared against `pkce_s256(code_verifier)`.
    code_challenge: String,
    /// Normalized scope copied into the issued access token.
    scope: String,
    /// Expiry timestamp, parsed with `parse_time`.
    expires_at: String,
}
483
#[cfg(test)]
mod tests {
    use super::*;

    const CALLBACK: &str = "https://client.example.test/callback";
    const VERIFIER: &str = "abc123abc123abc123abc123abc123abc123abc123abc123";

    /// Opens a store in a fresh temporary directory and registers one client
    /// whose only redirect URI is [`CALLBACK`].
    ///
    /// The `TempDir` is returned so the caller keeps it alive for as long as
    /// it uses the store (the directory is removed when it is dropped).
    fn store_with_client() -> (tempfile::TempDir, OAuthStore, RegisteredClient) {
        let dir = tempfile::tempdir().unwrap();
        let store = OAuthStore::open(dir.path().join("shuttle.db")).unwrap();
        let client = store
            .register_client(RegisterRequest {
                redirect_uris: vec![CALLBACK.to_owned()],
                client_name: Some("client".to_owned()),
            })
            .unwrap();
        (dir, store, client)
    }

    /// Authorization request for `client_id` using [`CALLBACK`], scope `mcp`,
    /// and an S256 challenge derived from [`VERIFIER`].
    fn authorize_request(client_id: &str, response_type: &str) -> AuthorizeRequest {
        AuthorizeRequest {
            response_type: response_type.to_owned(),
            client_id: client_id.to_owned(),
            redirect_uri: CALLBACK.to_owned(),
            state: None,
            scope: Some("mcp".to_owned()),
            code_challenge: Some(pkce_s256(VERIFIER)),
            code_challenge_method: Some("S256".to_owned()),
        }
    }

    /// Token request redeeming `code` for `client_id` with [`VERIFIER`].
    fn token_request(grant_type: &str, client_id: String, code: String) -> TokenRequest {
        TokenRequest {
            grant_type: grant_type.to_owned(),
            client_id,
            redirect_uri: CALLBACK.to_owned(),
            code: Some(code),
            code_verifier: Some(VERIFIER.to_owned()),
        }
    }

    #[test]
    fn metadata_uses_public_url() {
        let config = OAuthConfig {
            public_url: "https://shuttle.example.test".to_owned(),
            admin_token: None,
        };
        let resource = protected_resource_metadata(&config);
        let server = authorization_server_metadata(&config);

        assert_eq!(resource["resource"], "https://shuttle.example.test/mcp");
        assert_eq!(
            server["token_endpoint"],
            "https://shuttle.example.test/oauth/token"
        );
    }

    #[test]
    fn authorize_redirect_encodes_state_as_query_component() {
        let url = authorize_redirect(
            "https://claude.ai/api/mcp/auth_callback",
            "stl_abc123",
            Some("opaque=value+with/special&fragment#part"),
        );

        assert_eq!(
            url,
            "https://claude.ai/api/mcp/auth_callback?code=stl_abc123&state=opaque%3Dvalue%2Bwith%2Fspecial%26fragment%23part"
        );
    }

    #[test]
    fn authorize_redirect_omits_state_when_absent() {
        let url = authorize_redirect(
            "https://claude.ai/api/mcp/auth_callback",
            "stl_abc123",
            None,
        );

        assert_eq!(
            url,
            "https://claude.ai/api/mcp/auth_callback?code=stl_abc123"
        );
        assert!(!url.contains("state="));
    }

    #[test]
    fn code_exchange_validates_pkce() {
        let (_dir, store, client) = store_with_client();
        let code = store
            .create_code(authorize_request(&client.client_id, "code"))
            .unwrap();

        let token = store
            .exchange_code(token_request("authorization_code", client.client_id, code))
            .unwrap();

        assert!(store.validate_access_token(&token.access_token).unwrap());
    }

    #[test]
    fn code_exchange_rejects_reused_code() {
        let (_dir, store, client) = store_with_client();
        let code = store
            .create_code(authorize_request(&client.client_id, "code"))
            .unwrap();
        let request = token_request("authorization_code", client.client_id, code);

        store.exchange_code(request.clone()).unwrap();
        let err = store.exchange_code(request).unwrap_err();

        assert!(err.to_string().contains("code already used"));
    }

    #[test]
    fn store_validates_oauth_grant_shape() {
        let (_dir, store, client) = store_with_client();

        let authorize_err = store
            .create_code(authorize_request(&client.client_id, "token"))
            .unwrap_err();
        assert!(authorize_err
            .to_string()
            .contains("response_type must be code"));

        let token_err = store
            .exchange_code(token_request(
                "refresh_token",
                client.client_id,
                "stl_missing".to_owned(),
            ))
            .unwrap_err();
        assert!(token_err
            .to_string()
            .contains("grant_type must be authorization_code"));
    }
}