Skip to main content

opensession_api/
oauth.rs

1//! Generic OAuth2 provider support.
2//!
3//! Config-driven: no provider-specific code branches. Any OAuth2 provider
4//! (GitHub, GitLab, Gitea, OIDC-compatible) can be added via configuration.
5//!
//! This module contains only types, URL builders, and response parsing (JSON and form-encoded).
7//! No HTTP calls or DB access — those live in the backend adapters.
8
9use serde::{Deserialize, Serialize};
10
11use crate::ServiceError;
12
13// ── Provider Configuration ──────────────────────────────────────────────────
14
15/// OAuth2 provider configuration. Loaded from environment variables or config file.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct OAuthProviderConfig {
18    /// Unique provider identifier: "github", "gitlab-corp", "gitea-internal"
19    pub id: String,
20    /// UI display name: "GitHub", "GitLab (Corp)"
21    pub display_name: String,
22
23    // OAuth2 endpoints
24    pub authorize_url: String,
25    pub token_url: String,
26    pub userinfo_url: String,
27    /// Optional separate email endpoint (GitHub-specific: /user/emails)
28    pub email_url: Option<String>,
29
30    pub client_id: String,
31    #[serde(skip_serializing)]
32    pub client_secret: String,
33    pub scopes: String,
34
35    /// JSON field mapping from userinfo response to internal fields
36    pub field_map: OAuthFieldMap,
37
38    /// Skip TLS verification for self-hosted instances (dev only)
39    #[serde(default)]
40    pub tls_skip_verify: bool,
41
42    /// External URL for browser redirects (may differ from token_url for Docker setups)
43    pub external_authorize_url: Option<String>,
44}
45
46/// Maps provider-specific JSON field names to our internal fields.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct OAuthFieldMap {
49    /// Field containing the user's unique ID: "id" (GitHub/GitLab) or "sub" (OIDC)
50    pub id: String,
51    /// Field containing the username: "login" (GitHub) or "username" (GitLab)
52    pub username: String,
53    /// Field containing the email: "email"
54    pub email: String,
55    /// Field containing the avatar URL: "avatar_url" or "picture"
56    pub avatar: String,
57}
58
/// Provider-agnostic view of a user, built from a provider's userinfo response.
#[derive(Clone, Debug)]
pub struct OAuthUserInfo {
    /// The `id` of the provider config this user came from (e.g. "github").
    pub provider_id: String,
    /// The provider's own identifier for the user, rendered as a string.
    pub provider_user_id: String,
    /// Login/handle reported by the provider.
    pub username: String,
    /// Email address, when the provider exposed one.
    pub email: Option<String>,
    /// Avatar image URL, when the provider exposed one.
    pub avatar_url: Option<String>,
}
70
/// Normalize OAuth config values loaded from env/secrets.
///
/// Secret managers sometimes keep the trailing newlines/spaces (or the literal
/// quotes) that were present when a value was set through a shell pipe. This:
///
/// 1. trims surrounding whitespace,
/// 2. strips one matching pair of wrapping `"…"` or `'…'` quotes,
/// 3. trims again,
///
/// and returns `None` if nothing is left. That way a provider is never enabled
/// with unusable credentials.
pub fn normalize_oauth_config_value(raw: &str) -> Option<String> {
    let outer = raw.trim();

    // A lone quote character is not a quoted value: stripping the prefix leaves
    // an empty string, which has no matching suffix, so it is kept as-is.
    let inner = ['"', '\'']
        .iter()
        .copied()
        .find_map(|quote| outer.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(outer)
        .trim();

    if inner.is_empty() {
        return None;
    }
    Some(inner.to_string())
}
97
98// ── URL Builders (pure functions, no HTTP) ──────────────────────────────────
99
100/// Build the OAuth authorize URL that the user's browser should be redirected to.
101pub fn build_authorize_url(
102    config: &OAuthProviderConfig,
103    redirect_uri: &str,
104    state: &str,
105) -> String {
106    let base = config
107        .external_authorize_url
108        .as_deref()
109        .unwrap_or(&config.authorize_url);
110
111    format!(
112        "{}?client_id={}&redirect_uri={}&state={}&scope={}&response_type=code",
113        base,
114        urlencoding(&config.client_id),
115        urlencoding(redirect_uri),
116        urlencoding(state),
117        urlencoding(&config.scopes),
118    )
119}
120
121/// Build the JSON body for the token exchange request.
122pub fn build_token_request_body(
123    config: &OAuthProviderConfig,
124    code: &str,
125    redirect_uri: &str,
126) -> serde_json::Value {
127    serde_json::json!({
128        "client_id": config.client_id,
129        "client_secret": config.client_secret,
130        "code": code,
131        "grant_type": "authorization_code",
132        "redirect_uri": redirect_uri,
133    })
134}
135
136/// Build OAuth2 token request as application/x-www-form-urlencoded pairs.
137///
138/// OAuth2 token exchange endpoints are required to support urlencoded form input.
139pub fn build_token_request_form(
140    config: &OAuthProviderConfig,
141    code: &str,
142    redirect_uri: &str,
143) -> Vec<(String, String)> {
144    vec![
145        ("client_id".into(), config.client_id.clone()),
146        ("client_secret".into(), config.client_secret.clone()),
147        ("code".into(), code.to_string()),
148        ("grant_type".into(), "authorization_code".into()),
149        ("redirect_uri".into(), redirect_uri.to_string()),
150    ]
151}
152
153/// Build OAuth2 token request as x-www-form-urlencoded string.
154pub fn build_token_request_form_encoded(
155    config: &OAuthProviderConfig,
156    code: &str,
157    redirect_uri: &str,
158) -> String {
159    build_token_request_form(config, code, redirect_uri)
160        .into_iter()
161        .map(|(k, v)| format!("{}={}", urlencoding(&k), urlencoding(&v)))
162        .collect::<Vec<_>>()
163        .join("&")
164}
165
166/// Parse access_token from OAuth token response.
167///
168/// Supports both JSON (`{\"access_token\":\"...\"}`) and query-string style
169/// (`access_token=...&scope=...`) payloads.
170pub fn parse_access_token_response(raw: &str) -> Result<String, ServiceError> {
171    let body = raw.trim();
172    if body.is_empty() {
173        return Err(ServiceError::Internal(
174            "OAuth token exchange failed: empty response body".into(),
175        ));
176    }
177
178    if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
179        if let Some(token) = json
180            .get("access_token")
181            .and_then(|v| v.as_str())
182            .map(str::trim)
183            .filter(|s| !s.is_empty())
184        {
185            return Ok(token.to_string());
186        }
187
188        let err = json.get("error").and_then(|v| v.as_str());
189        let err_desc = json
190            .get("error_description")
191            .and_then(|v| v.as_str())
192            .or_else(|| json.get("error_message").and_then(|v| v.as_str()));
193
194        let detail = match (err, err_desc) {
195            (Some(e), Some(d)) if !d.is_empty() => format!("{e}: {d}"),
196            (Some(e), _) => e.to_string(),
197            (_, Some(d)) if !d.is_empty() => d.to_string(),
198            _ => "no access_token field in JSON response".to_string(),
199        };
200
201        return Err(ServiceError::Internal(format!(
202            "OAuth token exchange failed: {detail}"
203        )));
204    }
205
206    let mut access_token: Option<String> = None;
207    let mut error: Option<String> = None;
208    let mut error_description: Option<String> = None;
209
210    for pair in body.split('&') {
211        let (k, v) = pair.split_once('=').unwrap_or((pair, ""));
212        let key = decode_form_component(k);
213        let value = decode_form_component(v);
214        match key.as_str() {
215            "access_token" if !value.trim().is_empty() => access_token = Some(value),
216            "error" if !value.trim().is_empty() => error = Some(value),
217            "error_description" if !value.trim().is_empty() => error_description = Some(value),
218            _ => {}
219        }
220    }
221
222    if let Some(token) = access_token {
223        return Ok(token);
224    }
225
226    let detail = match (error, error_description) {
227        (Some(e), Some(d)) => format!("{e}: {d}"),
228        (Some(e), None) => e,
229        (None, Some(d)) => d,
230        (None, None) => "no access_token field in response".to_string(),
231    };
232
233    Err(ServiceError::Internal(format!(
234        "OAuth token exchange failed: {detail}"
235    )))
236}
237
238/// Extract normalized user info from a provider's userinfo JSON response.
239///
240/// `email_json` is an optional array of email objects (GitHub `/user/emails` format)
241/// used when the primary userinfo endpoint doesn't include the email.
242pub fn extract_user_info(
243    config: &OAuthProviderConfig,
244    userinfo_json: &serde_json::Value,
245    email_json: Option<&[serde_json::Value]>,
246) -> Result<OAuthUserInfo, ServiceError> {
247    // Extract provider user ID — may be number or string depending on provider
248    let provider_user_id = match &userinfo_json[&config.field_map.id] {
249        serde_json::Value::Number(n) => n.to_string(),
250        serde_json::Value::String(s) => s.clone(),
251        _ => {
252            return Err(ServiceError::Internal(format!(
253                "OAuth userinfo missing '{}' field",
254                config.field_map.id
255            )))
256        }
257    };
258
259    let username = userinfo_json[&config.field_map.username]
260        .as_str()
261        .unwrap_or("unknown")
262        .to_string();
263
264    // Email: try userinfo first, then email_json (GitHub format: [{email, primary, verified}])
265    let email = userinfo_json[&config.field_map.email]
266        .as_str()
267        .map(|s| s.to_string())
268        .or_else(|| {
269            email_json.and_then(|emails| {
270                emails
271                    .iter()
272                    .find(|e| e["primary"].as_bool() == Some(true))
273                    .and_then(|e| e["email"].as_str())
274                    .map(|s| s.to_string())
275            })
276        });
277
278    let avatar_url = userinfo_json[&config.field_map.avatar]
279        .as_str()
280        .map(|s| s.to_string());
281
282    Ok(OAuthUserInfo {
283        provider_id: config.id.clone(),
284        provider_user_id,
285        username,
286        email,
287        avatar_url,
288    })
289}
290
291// ── Provider Presets ────────────────────────────────────────────────────────
292
293/// Create a GitHub OAuth2 provider config. Only needs client credentials.
294pub fn github_preset(client_id: String, client_secret: String) -> OAuthProviderConfig {
295    OAuthProviderConfig {
296        id: "github".into(),
297        display_name: "GitHub".into(),
298        authorize_url: "https://github.com/login/oauth/authorize".into(),
299        token_url: "https://github.com/login/oauth/access_token".into(),
300        userinfo_url: "https://api.github.com/user".into(),
301        email_url: Some("https://api.github.com/user/emails".into()),
302        client_id,
303        client_secret,
304        scopes: "read:user,user:email".into(),
305        field_map: OAuthFieldMap {
306            id: "id".into(),
307            username: "login".into(),
308            email: "email".into(),
309            avatar: "avatar_url".into(),
310        },
311        tls_skip_verify: false,
312        external_authorize_url: None,
313    }
314}
315
316/// Create a GitLab OAuth2 provider config for a given instance URL.
317///
318/// `instance_url` is the server-accessible URL (e.g. `http://gitlab:80` in Docker).
319/// `external_url` is the browser-accessible URL (e.g. `http://localhost:8929`).
320/// If `external_url` is None, `instance_url` is used for browser redirects too.
321pub fn gitlab_preset(
322    instance_url: String,
323    external_url: Option<String>,
324    client_id: String,
325    client_secret: String,
326) -> OAuthProviderConfig {
327    let base = instance_url.trim_end_matches('/');
328    let ext_base = external_url
329        .as_deref()
330        .map(|u| u.trim_end_matches('/').to_string());
331
332    OAuthProviderConfig {
333        id: "gitlab".into(),
334        display_name: "GitLab".into(),
335        authorize_url: format!("{base}/oauth/authorize"),
336        token_url: format!("{base}/oauth/token"),
337        userinfo_url: format!("{base}/api/v4/user"),
338        email_url: None, // GitLab includes email in /api/v4/user
339        client_id,
340        client_secret,
341        scopes: "read_user".into(),
342        field_map: OAuthFieldMap {
343            id: "id".into(),
344            username: "username".into(),
345            email: "email".into(),
346            avatar: "avatar_url".into(),
347        },
348        tls_skip_verify: false,
349        external_authorize_url: ext_base.map(|b| format!("{b}/oauth/authorize")),
350    }
351}
352
353// ── API Response Types ──────────────────────────────────────────────────────
354
355/// Available auth providers (returned by GET /api/auth/providers).
356#[derive(Debug, Serialize, Deserialize)]
357#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
358#[cfg_attr(feature = "ts", ts(export))]
359pub struct AuthProvidersResponse {
360    pub email_password: bool,
361    pub oauth: Vec<OAuthProviderInfo>,
362}
363
364/// Public info about an OAuth provider.
365#[derive(Debug, Serialize, Deserialize)]
366#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
367#[cfg_attr(feature = "ts", ts(export))]
368pub struct OAuthProviderInfo {
369    pub id: String,
370    pub display_name: String,
371}
372
373/// A linked OAuth provider shown in user settings.
374#[derive(Debug, Serialize, Deserialize)]
375#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
376#[cfg_attr(feature = "ts", ts(export))]
377pub struct LinkedProvider {
378    pub provider: String,
379    pub provider_username: String,
380    pub display_name: String,
381}
382
383// ── Helpers ─────────────────────────────────────────────────────────────────
384
/// Percent-encode `s` for use in a URL query or form body.
///
/// Only RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other UTF-8 byte becomes `%XX` with uppercase hex digits.
fn urlencoding(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
            if unreserved {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
402
403fn decode_form_component(s: &str) -> String {
404    let bytes = s.as_bytes();
405    let mut out = Vec::with_capacity(bytes.len());
406    let mut i = 0usize;
407    while i < bytes.len() {
408        match bytes[i] {
409            b'+' => {
410                out.push(b' ');
411                i += 1;
412            }
413            b'%' if i + 2 < bytes.len() => {
414                let hi = hex_value(bytes[i + 1]);
415                let lo = hex_value(bytes[i + 2]);
416                if let (Some(h), Some(l)) = (hi, lo) {
417                    out.push((h << 4) | l);
418                    i += 3;
419                } else {
420                    out.push(bytes[i]);
421                    i += 1;
422                }
423            }
424            b => {
425                out.push(b);
426                i += 1;
427            }
428        }
429    }
430    String::from_utf8_lossy(&out).to_string()
431}
432
/// Value of a single ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn hex_value(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts only ASCII hex digits, so non-ASCII bytes yield None.
    // The digit is < 16, so narrowing to u8 is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
441
#[cfg(test)]
mod tests {
    use super::{
        build_token_request_form_encoded, github_preset, normalize_oauth_config_value,
        parse_access_token_response,
    };

    #[test]
    fn parse_access_token_json_ok() {
        let token = parse_access_token_response(
            r#"{"access_token":"gho_123","scope":"read:user","token_type":"bearer"}"#,
        )
        .expect("token parse");
        assert_eq!(token, "gho_123");
    }

    #[test]
    fn parse_access_token_form_ok() {
        let token =
            parse_access_token_response("access_token=gho_abc&scope=read%3Auser&token_type=bearer")
                .expect("token parse");
        assert_eq!(token, "gho_abc");
    }

    #[test]
    fn parse_access_token_json_error_has_reason() {
        let failure = parse_access_token_response(
            r#"{"error":"bad_verification_code","error_description":"The code passed is incorrect or expired."}"#,
        )
        .expect_err("must fail");
        assert!(failure.message().contains("bad_verification_code"));
    }

    #[test]
    fn build_form_encoded_contains_required_fields() {
        let provider = github_preset("cid".into(), "secret".into());
        let encoded = build_token_request_form_encoded(&provider, "code-1", "https://app/callback");
        for expected in [
            "client_id=cid",
            "client_secret=secret",
            "grant_type=authorization_code",
            "code=code-1",
        ] {
            assert!(encoded.contains(expected), "missing `{expected}` in `{encoded}`");
        }
    }

    #[test]
    fn normalize_oauth_config_value_trims_and_rejects_empty() {
        let cases = [
            ("  value-with-spaces\t\n", Some("value-with-spaces")),
            ("   \n\t  ", None),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_oauth_config_value(input), expected.map(String::from));
        }
    }

    #[test]
    fn normalize_oauth_config_value_strips_wrapping_quotes() {
        let cases = [
            (" \"quoted-value\" ", Some("quoted-value")),
            (" 'another' ", Some("another")),
            ("  \"   \" ", None),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_oauth_config_value(input), expected.map(String::from));
        }
    }
}