Skip to main content

opensession_api/
oauth.rs

//! Generic OAuth2 provider support.
//!
//! Config-driven: there are no provider-specific code branches. Any OAuth2 provider
//! (GitHub, GitLab, Gitea, OIDC-compatible) can be added via configuration.
//!
//! This module contains only types, URL and request-body builders, and response
//! parsing (JSON and form-encoded). No HTTP calls or DB access — those live in the
//! backend adapters.
8
9use serde::{Deserialize, Serialize};
10
11use crate::ServiceError;
12
13// ── Provider Configuration ──────────────────────────────────────────────────
14
15/// OAuth2 provider configuration. Loaded from environment variables or config file.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct OAuthProviderConfig {
18    /// Unique provider identifier: "github", "gitlab-corp", "gitea-internal"
19    pub id: String,
20    /// UI display name: "GitHub", "GitLab (Corp)"
21    pub display_name: String,
22
23    // OAuth2 endpoints
24    pub authorize_url: String,
25    pub token_url: String,
26    pub userinfo_url: String,
27    /// Optional separate email endpoint (GitHub-specific: /user/emails)
28    pub email_url: Option<String>,
29
30    pub client_id: String,
31    #[serde(skip_serializing)]
32    pub client_secret: String,
33    pub scopes: String,
34
35    /// JSON field mapping from userinfo response to internal fields
36    pub field_map: OAuthFieldMap,
37
38    /// Skip TLS verification for self-hosted instances (dev only)
39    #[serde(default)]
40    pub tls_skip_verify: bool,
41
42    /// External URL for browser redirects (may differ from token_url for Docker setups)
43    pub external_authorize_url: Option<String>,
44}
45
46/// Maps provider-specific JSON field names to our internal fields.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct OAuthFieldMap {
49    /// Field containing the user's unique ID: "id" (GitHub/GitLab) or "sub" (OIDC)
50    pub id: String,
51    /// Field containing the username: "login" (GitHub) or "username" (GitLab)
52    pub username: String,
53    /// Field containing the email: "email"
54    pub email: String,
55    /// Field containing the avatar URL: "avatar_url" or "picture"
56    pub avatar: String,
57}
58
/// Normalized user info extracted from any OAuth provider's userinfo response.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare extracted results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthUserInfo {
    /// Provider config id (e.g. "github")
    pub provider_id: String,
    /// Provider-side user ID (as string)
    pub provider_user_id: String,
    /// Provider-side username
    pub username: String,
    /// Email address, if the provider exposed one
    pub email: Option<String>,
    /// Avatar URL, if the provider exposed one
    pub avatar_url: Option<String>,
}
70
71// ── URL Builders (pure functions, no HTTP) ──────────────────────────────────
72
73/// Build the OAuth authorize URL that the user's browser should be redirected to.
74pub fn build_authorize_url(
75    config: &OAuthProviderConfig,
76    redirect_uri: &str,
77    state: &str,
78) -> String {
79    let base = config
80        .external_authorize_url
81        .as_deref()
82        .unwrap_or(&config.authorize_url);
83
84    format!(
85        "{}?client_id={}&redirect_uri={}&state={}&scope={}&response_type=code",
86        base,
87        urlencoding(&config.client_id),
88        urlencoding(redirect_uri),
89        urlencoding(state),
90        urlencoding(&config.scopes),
91    )
92}
93
94/// Build the JSON body for the token exchange request.
95pub fn build_token_request_body(
96    config: &OAuthProviderConfig,
97    code: &str,
98    redirect_uri: &str,
99) -> serde_json::Value {
100    serde_json::json!({
101        "client_id": config.client_id,
102        "client_secret": config.client_secret,
103        "code": code,
104        "grant_type": "authorization_code",
105        "redirect_uri": redirect_uri,
106    })
107}
108
109/// Build OAuth2 token request as application/x-www-form-urlencoded pairs.
110///
111/// OAuth2 token exchange endpoints are required to support urlencoded form input.
112pub fn build_token_request_form(
113    config: &OAuthProviderConfig,
114    code: &str,
115    redirect_uri: &str,
116) -> Vec<(String, String)> {
117    vec![
118        ("client_id".into(), config.client_id.clone()),
119        ("client_secret".into(), config.client_secret.clone()),
120        ("code".into(), code.to_string()),
121        ("grant_type".into(), "authorization_code".into()),
122        ("redirect_uri".into(), redirect_uri.to_string()),
123    ]
124}
125
126/// Build OAuth2 token request as x-www-form-urlencoded string.
127pub fn build_token_request_form_encoded(
128    config: &OAuthProviderConfig,
129    code: &str,
130    redirect_uri: &str,
131) -> String {
132    build_token_request_form(config, code, redirect_uri)
133        .into_iter()
134        .map(|(k, v)| format!("{}={}", urlencoding(&k), urlencoding(&v)))
135        .collect::<Vec<_>>()
136        .join("&")
137}
138
139/// Parse access_token from OAuth token response.
140///
141/// Supports both JSON (`{\"access_token\":\"...\"}`) and query-string style
142/// (`access_token=...&scope=...`) payloads.
143pub fn parse_access_token_response(raw: &str) -> Result<String, ServiceError> {
144    let body = raw.trim();
145    if body.is_empty() {
146        return Err(ServiceError::Internal(
147            "OAuth token exchange failed: empty response body".into(),
148        ));
149    }
150
151    if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
152        if let Some(token) = json
153            .get("access_token")
154            .and_then(|v| v.as_str())
155            .map(str::trim)
156            .filter(|s| !s.is_empty())
157        {
158            return Ok(token.to_string());
159        }
160
161        let err = json.get("error").and_then(|v| v.as_str());
162        let err_desc = json
163            .get("error_description")
164            .and_then(|v| v.as_str())
165            .or_else(|| json.get("error_message").and_then(|v| v.as_str()));
166
167        let detail = match (err, err_desc) {
168            (Some(e), Some(d)) if !d.is_empty() => format!("{e}: {d}"),
169            (Some(e), _) => e.to_string(),
170            (_, Some(d)) if !d.is_empty() => d.to_string(),
171            _ => "no access_token field in JSON response".to_string(),
172        };
173
174        return Err(ServiceError::Internal(format!(
175            "OAuth token exchange failed: {detail}"
176        )));
177    }
178
179    let mut access_token: Option<String> = None;
180    let mut error: Option<String> = None;
181    let mut error_description: Option<String> = None;
182
183    for pair in body.split('&') {
184        let (k, v) = pair.split_once('=').unwrap_or((pair, ""));
185        let key = decode_form_component(k);
186        let value = decode_form_component(v);
187        match key.as_str() {
188            "access_token" if !value.trim().is_empty() => access_token = Some(value),
189            "error" if !value.trim().is_empty() => error = Some(value),
190            "error_description" if !value.trim().is_empty() => error_description = Some(value),
191            _ => {}
192        }
193    }
194
195    if let Some(token) = access_token {
196        return Ok(token);
197    }
198
199    let detail = match (error, error_description) {
200        (Some(e), Some(d)) => format!("{e}: {d}"),
201        (Some(e), None) => e,
202        (None, Some(d)) => d,
203        (None, None) => "no access_token field in response".to_string(),
204    };
205
206    Err(ServiceError::Internal(format!(
207        "OAuth token exchange failed: {detail}"
208    )))
209}
210
211/// Extract normalized user info from a provider's userinfo JSON response.
212///
213/// `email_json` is an optional array of email objects (GitHub `/user/emails` format)
214/// used when the primary userinfo endpoint doesn't include the email.
215pub fn extract_user_info(
216    config: &OAuthProviderConfig,
217    userinfo_json: &serde_json::Value,
218    email_json: Option<&[serde_json::Value]>,
219) -> Result<OAuthUserInfo, ServiceError> {
220    // Extract provider user ID — may be number or string depending on provider
221    let provider_user_id = match &userinfo_json[&config.field_map.id] {
222        serde_json::Value::Number(n) => n.to_string(),
223        serde_json::Value::String(s) => s.clone(),
224        _ => {
225            return Err(ServiceError::Internal(format!(
226                "OAuth userinfo missing '{}' field",
227                config.field_map.id
228            )))
229        }
230    };
231
232    let username = userinfo_json[&config.field_map.username]
233        .as_str()
234        .unwrap_or("unknown")
235        .to_string();
236
237    // Email: try userinfo first, then email_json (GitHub format: [{email, primary, verified}])
238    let email = userinfo_json[&config.field_map.email]
239        .as_str()
240        .map(|s| s.to_string())
241        .or_else(|| {
242            email_json.and_then(|emails| {
243                emails
244                    .iter()
245                    .find(|e| e["primary"].as_bool() == Some(true))
246                    .and_then(|e| e["email"].as_str())
247                    .map(|s| s.to_string())
248            })
249        });
250
251    let avatar_url = userinfo_json[&config.field_map.avatar]
252        .as_str()
253        .map(|s| s.to_string());
254
255    Ok(OAuthUserInfo {
256        provider_id: config.id.clone(),
257        provider_user_id,
258        username,
259        email,
260        avatar_url,
261    })
262}
263
264// ── Provider Presets ────────────────────────────────────────────────────────
265
266/// Create a GitHub OAuth2 provider config. Only needs client credentials.
267pub fn github_preset(client_id: String, client_secret: String) -> OAuthProviderConfig {
268    OAuthProviderConfig {
269        id: "github".into(),
270        display_name: "GitHub".into(),
271        authorize_url: "https://github.com/login/oauth/authorize".into(),
272        token_url: "https://github.com/login/oauth/access_token".into(),
273        userinfo_url: "https://api.github.com/user".into(),
274        email_url: Some("https://api.github.com/user/emails".into()),
275        client_id,
276        client_secret,
277        scopes: "read:user,user:email".into(),
278        field_map: OAuthFieldMap {
279            id: "id".into(),
280            username: "login".into(),
281            email: "email".into(),
282            avatar: "avatar_url".into(),
283        },
284        tls_skip_verify: false,
285        external_authorize_url: None,
286    }
287}
288
289/// Create a GitLab OAuth2 provider config for a given instance URL.
290///
291/// `instance_url` is the server-accessible URL (e.g. `http://gitlab:80` in Docker).
292/// `external_url` is the browser-accessible URL (e.g. `http://localhost:8929`).
293/// If `external_url` is None, `instance_url` is used for browser redirects too.
294pub fn gitlab_preset(
295    instance_url: String,
296    external_url: Option<String>,
297    client_id: String,
298    client_secret: String,
299) -> OAuthProviderConfig {
300    let base = instance_url.trim_end_matches('/');
301    let ext_base = external_url
302        .as_deref()
303        .map(|u| u.trim_end_matches('/').to_string());
304
305    OAuthProviderConfig {
306        id: "gitlab".into(),
307        display_name: "GitLab".into(),
308        authorize_url: format!("{base}/oauth/authorize"),
309        token_url: format!("{base}/oauth/token"),
310        userinfo_url: format!("{base}/api/v4/user"),
311        email_url: None, // GitLab includes email in /api/v4/user
312        client_id,
313        client_secret,
314        scopes: "read_user".into(),
315        field_map: OAuthFieldMap {
316            id: "id".into(),
317            username: "username".into(),
318            email: "email".into(),
319            avatar: "avatar_url".into(),
320        },
321        tls_skip_verify: false,
322        external_authorize_url: ext_base.map(|b| format!("{b}/oauth/authorize")),
323    }
324}
325
326// ── API Response Types ──────────────────────────────────────────────────────
327
328/// Available auth providers (returned by GET /api/auth/providers).
329#[derive(Debug, Serialize, Deserialize)]
330#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
331#[cfg_attr(feature = "ts", ts(export))]
332pub struct AuthProvidersResponse {
333    pub email_password: bool,
334    pub oauth: Vec<OAuthProviderInfo>,
335}
336
337/// Public info about an OAuth provider.
338#[derive(Debug, Serialize, Deserialize)]
339#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
340#[cfg_attr(feature = "ts", ts(export))]
341pub struct OAuthProviderInfo {
342    pub id: String,
343    pub display_name: String,
344}
345
346/// A linked OAuth provider shown in user settings.
347#[derive(Debug, Serialize, Deserialize)]
348#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
349#[cfg_attr(feature = "ts", ts(export))]
350pub struct LinkedProvider {
351    pub provider: String,
352    pub provider_username: String,
353    pub display_name: String,
354}
355
356// ── Helpers ─────────────────────────────────────────────────────────────────
357
/// Percent-encode every byte outside the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`)
/// as `%XX` with uppercase hex. Spaces become `%20`, not `+`.
fn urlencoding(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes().fold(String::with_capacity(s.len()), |mut encoded, byte| {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
        }
        encoded
    })
}
375
376fn decode_form_component(s: &str) -> String {
377    let bytes = s.as_bytes();
378    let mut out = Vec::with_capacity(bytes.len());
379    let mut i = 0usize;
380    while i < bytes.len() {
381        match bytes[i] {
382            b'+' => {
383                out.push(b' ');
384                i += 1;
385            }
386            b'%' if i + 2 < bytes.len() => {
387                let hi = hex_value(bytes[i + 1]);
388                let lo = hex_value(bytes[i + 2]);
389                if let (Some(h), Some(l)) = (hi, lo) {
390                    out.push((h << 4) | l);
391                    i += 3;
392                } else {
393                    out.push(bytes[i]);
394                    i += 1;
395                }
396            }
397            b => {
398                out.push(b);
399                i += 1;
400            }
401        }
402    }
403    String::from_utf8_lossy(&out).to_string()
404}
405
/// Value of a single ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn hex_value(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII hex digits in both cases and yields a value
    // below 16, so the narrowing conversion always succeeds.
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
414
#[cfg(test)]
mod tests {
    use super::{
        build_authorize_url, build_token_request_form_encoded, decode_form_component,
        extract_user_info, github_preset, parse_access_token_response, urlencoding,
    };

    #[test]
    fn parse_access_token_json_ok() {
        let raw = r#"{"access_token":"gho_123","scope":"read:user","token_type":"bearer"}"#;
        let token = parse_access_token_response(raw).expect("token parse");
        assert_eq!(token, "gho_123");
    }

    #[test]
    fn parse_access_token_form_ok() {
        let raw = "access_token=gho_abc&scope=read%3Auser&token_type=bearer";
        let token = parse_access_token_response(raw).expect("token parse");
        assert_eq!(token, "gho_abc");
    }

    #[test]
    fn parse_access_token_json_error_has_reason() {
        let raw = r#"{"error":"bad_verification_code","error_description":"The code passed is incorrect or expired."}"#;
        let err = parse_access_token_response(raw).expect_err("must fail");
        assert!(err.message().contains("bad_verification_code"));
    }

    #[test]
    fn parse_access_token_form_error_has_reason() {
        let raw = "error=bad_verification_code&error_description=The+code+is+incorrect";
        let err = parse_access_token_response(raw).expect_err("must fail");
        assert!(err
            .message()
            .contains("bad_verification_code: The code is incorrect"));
    }

    #[test]
    fn parse_access_token_empty_body_fails() {
        assert!(parse_access_token_response("  \n ").is_err());
    }

    #[test]
    fn build_form_encoded_contains_required_fields() {
        let provider = github_preset("cid".into(), "secret".into());
        let encoded =
            build_token_request_form_encoded(&provider, "code-1", "https://app/callback");
        assert!(encoded.contains("client_id=cid"));
        assert!(encoded.contains("client_secret=secret"));
        assert!(encoded.contains("grant_type=authorization_code"));
        assert!(encoded.contains("code=code-1"));
        assert!(encoded.contains("redirect_uri=https%3A%2F%2Fapp%2Fcallback"));
    }

    #[test]
    fn authorize_url_encodes_parameters() {
        let provider = github_preset("cid".into(), "secret".into());
        let url = build_authorize_url(&provider, "https://app/callback", "st ate");
        assert!(url.starts_with(
            "https://github.com/login/oauth/authorize?client_id=cid&redirect_uri=https%3A%2F%2Fapp%2Fcallback&state=st%20ate&scope="
        ));
        assert!(url.ends_with("&response_type=code"));
    }

    #[test]
    fn extract_user_info_falls_back_to_primary_email() {
        let provider = github_preset("cid".into(), "secret".into());
        let userinfo = serde_json::json!({
            "id": 42,
            "login": "octocat",
            "email": null,
            "avatar_url": "https://avatars/1"
        });
        let emails = [
            serde_json::json!({"email": "other@example.com", "primary": false, "verified": true}),
            serde_json::json!({"email": "octo@example.com", "primary": true, "verified": true}),
        ];
        let info =
            extract_user_info(&provider, &userinfo, Some(&emails[..])).expect("user info");
        assert_eq!(info.provider_id, "github");
        assert_eq!(info.provider_user_id, "42");
        assert_eq!(info.username, "octocat");
        assert_eq!(info.email.as_deref(), Some("octo@example.com"));
        assert_eq!(info.avatar_url.as_deref(), Some("https://avatars/1"));
    }

    #[test]
    fn form_component_helpers() {
        assert_eq!(urlencoding("a b/~"), "a%20b%2F~");
        // '+' is a space, valid escapes decode, malformed or truncated escapes stay literal.
        assert_eq!(decode_form_component("a+b%3A%zz%4"), "a b:%zz%4");
        assert_eq!(
            decode_form_component(&urlencoding("read:user user:email")),
            "read:user user:email"
        );
    }
}