Skip to main content

forge_core/oauth/
mod.rs

1//! OAuth 2.1 Authorization Code + PKCE support.
2//!
3//! Forge acts as an OAuth 2.1 Authorization Server for its MCP endpoint.
4//! Enable with `mcp.oauth = true` in `forge.toml`.
5
6pub mod pkce;
7
8use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
9use chrono::{DateTime, Utc};
10use uuid::Uuid;
11
12/// An OAuth 2.1 client registration.
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct OAuthClient {
15    pub client_id: String,
16    pub client_name: Option<String>,
17    pub redirect_uris: Vec<String>,
18    pub token_endpoint_auth_method: String,
19    pub created_at: DateTime<Utc>,
20}
21
22/// A pending authorization code with PKCE challenge.
23#[derive(Debug, Clone)]
24pub struct AuthorizationCode {
25    pub code: String,
26    pub client_id: String,
27    pub user_id: Uuid,
28    pub redirect_uri: String,
29    pub code_challenge: String,
30    pub code_challenge_method: String,
31    pub scopes: Vec<String>,
32    pub expires_at: DateTime<Utc>,
33}
34
35/// 256 bits of CSPRNG entropy (two UUIDv4s), base64url-encoded.
36pub fn generate_random_token() -> String {
37    let a = Uuid::new_v4();
38    let b = Uuid::new_v4();
39    let mut bytes = [0u8; 32];
40    bytes[..16].copy_from_slice(a.as_bytes());
41    bytes[16..].copy_from_slice(b.as_bytes());
42    URL_SAFE_NO_PAD.encode(bytes)
43}
44
/// Validate a redirect URI against a client's registered URIs.
///
/// A URI is accepted when it exactly matches a registered URI (OAuth 2.1 exact
/// string matching). There is one exception, from RFC 8252 Section 7.3: for
/// plain-`http` loopback redirects (`localhost`, `127.0.0.1`, `[::1]`), the port
/// may differ from the registered one, as long as the scheme, host and the rest
/// of the URI (path, query, fragment) match exactly.
///
/// `https` URIs never get the port exception. Look-alike authorities such as
/// `http://localhost.evil.com` or `http://localhost:80@evil.com` are not treated
/// as loopback.
pub fn validate_redirect_uri(requested: &str, registered: &[String]) -> bool {
    // Parse the requested URI once rather than once per registered entry.
    let requested_loopback = split_loopback_uri(requested);
    registered.iter().any(|uri| {
        uri.as_str() == requested
            || matches!(
                (requested_loopback, split_loopback_uri(uri)),
                (Some(req), Some(reg)) if req == reg
            )
    })
}

/// Scheme+host prefixes eligible for the RFC 8252 "any port" loopback exception.
const LOOPBACK_PREFIXES: [&str; 3] = ["http://localhost", "http://127.0.0.1", "http://[::1]"];

/// Split a loopback redirect URI into `(scheme+host, remainder)`, discarding the port.
///
/// e.g. `"http://localhost:12345/callback"` -> `Some(("http://localhost", "/callback"))`.
///
/// Returns `None` unless the URI consists of one of [`LOOPBACK_PREFIXES`],
/// followed by an optional `:<digits>` port, followed by end-of-string, `/`,
/// `?` or `#`. Requiring the authority to end there is what rejects hosts that
/// merely *start with* a loopback name, and userinfo tricks like
/// `http://localhost@evil.com`.
fn split_loopback_uri(uri: &str) -> Option<(&str, &str)> {
    LOOPBACK_PREFIXES.iter().copied().find_map(|prefix| {
        let after_host = uri.strip_prefix(prefix)?;
        let remainder = match after_host.strip_prefix(':') {
            Some(port_and_rest) => {
                let port_len = port_and_rest
                    .find(|c: char| !c.is_ascii_digit())
                    .unwrap_or(port_and_rest.len());
                &port_and_rest[port_len..]
            }
            None => after_host,
        };
        let authority_ended = remainder.is_empty()
            || remainder.starts_with(|c: char| matches!(c, '/' | '?' | '#'));
        authority_ended.then_some((prefix, remainder))
    })
}
85
#[cfg(test)]
// Unit tests may unwrap freely: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Registration list containing exactly one URI.
    fn registered(uri: &str) -> Vec<String> {
        vec![uri.to_owned()]
    }

    #[test]
    fn test_generate_random_token_length() {
        // 32 raw bytes encode to 43 unpadded base64url characters.
        assert_eq!(generate_random_token().len(), 43);
    }

    #[test]
    fn test_generate_random_token_uniqueness() {
        let first = generate_random_token();
        let second = generate_random_token();
        assert_ne!(first, second);
    }

    #[test]
    fn test_validate_redirect_uri_exact_match() {
        let uris = registered("https://example.com/callback");
        assert!(validate_redirect_uri("https://example.com/callback", &uris));
        assert!(!validate_redirect_uri("https://example.com/other", &uris));
    }

    #[test]
    fn test_validate_redirect_uri_localhost_port_exception() {
        let uris = registered("http://localhost:3000/callback");
        for requested in ["http://localhost:9999/callback", "http://localhost:3000/callback"] {
            assert!(validate_redirect_uri(requested, &uris), "{requested}");
        }
    }

    #[test]
    fn test_validate_redirect_uri_localhost_different_path() {
        let uris = registered("http://localhost:3000/callback");
        assert!(!validate_redirect_uri("http://localhost:9999/other", &uris));
    }

    #[test]
    fn test_validate_redirect_uri_no_localhost_exception_for_https() {
        let uris = registered("https://localhost:3000/callback");
        assert!(!validate_redirect_uri("https://localhost:9999/callback", &uris));
    }

    #[test]
    fn test_validate_redirect_uri_127_0_0_1() {
        let uris = registered("http://127.0.0.1:3000/callback");
        assert!(validate_redirect_uri("http://127.0.0.1:9999/callback", &uris));
    }
}