1use serde::{Deserialize, Serialize};
10
11use crate::ServiceError;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct OAuthProviderConfig {
18 pub id: String,
20 pub display_name: String,
22
23 pub authorize_url: String,
25 pub token_url: String,
26 pub userinfo_url: String,
27 pub email_url: Option<String>,
29
30 pub client_id: String,
31 #[serde(skip_serializing)]
32 pub client_secret: String,
33 pub scopes: String,
34
35 pub field_map: OAuthFieldMap,
37
38 #[serde(default)]
40 pub tls_skip_verify: bool,
41
42 pub external_authorize_url: Option<String>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct OAuthFieldMap {
49 pub id: String,
51 pub username: String,
53 pub email: String,
55 pub avatar: String,
57}
58
/// Provider-independent user identity produced by `extract_user_info`.
///
/// Derives `PartialEq`/`Eq` so extracted identities can be compared directly
/// (e.g. in tests or when checking whether a linked account changed).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthUserInfo {
    /// `OAuthProviderConfig::id` of the provider this identity came from.
    pub provider_id: String,
    /// The user's id at the provider, as a string (numeric ids rendered in decimal).
    pub provider_user_id: String,
    /// The user's login / handle at the provider.
    pub username: String,
    /// Email address, when one could be determined.
    pub email: Option<String>,
    /// Avatar image URL, when provided.
    pub avatar_url: Option<String>,
}
70
71pub fn build_authorize_url(
75 config: &OAuthProviderConfig,
76 redirect_uri: &str,
77 state: &str,
78) -> String {
79 let base = config
80 .external_authorize_url
81 .as_deref()
82 .unwrap_or(&config.authorize_url);
83
84 format!(
85 "{}?client_id={}&redirect_uri={}&state={}&scope={}&response_type=code",
86 base,
87 urlencoding(&config.client_id),
88 urlencoding(redirect_uri),
89 urlencoding(state),
90 urlencoding(&config.scopes),
91 )
92}
93
94pub fn build_token_request_body(
96 config: &OAuthProviderConfig,
97 code: &str,
98 redirect_uri: &str,
99) -> serde_json::Value {
100 serde_json::json!({
101 "client_id": config.client_id,
102 "client_secret": config.client_secret,
103 "code": code,
104 "grant_type": "authorization_code",
105 "redirect_uri": redirect_uri,
106 })
107}
108
109pub fn build_token_request_form(
113 config: &OAuthProviderConfig,
114 code: &str,
115 redirect_uri: &str,
116) -> Vec<(String, String)> {
117 vec![
118 ("client_id".into(), config.client_id.clone()),
119 ("client_secret".into(), config.client_secret.clone()),
120 ("code".into(), code.to_string()),
121 ("grant_type".into(), "authorization_code".into()),
122 ("redirect_uri".into(), redirect_uri.to_string()),
123 ]
124}
125
126pub fn build_token_request_form_encoded(
128 config: &OAuthProviderConfig,
129 code: &str,
130 redirect_uri: &str,
131) -> String {
132 build_token_request_form(config, code, redirect_uri)
133 .into_iter()
134 .map(|(k, v)| format!("{}={}", urlencoding(&k), urlencoding(&v)))
135 .collect::<Vec<_>>()
136 .join("&")
137}
138
139pub fn parse_access_token_response(raw: &str) -> Result<String, ServiceError> {
144 let body = raw.trim();
145 if body.is_empty() {
146 return Err(ServiceError::Internal(
147 "OAuth token exchange failed: empty response body".into(),
148 ));
149 }
150
151 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
152 if let Some(token) = json
153 .get("access_token")
154 .and_then(|v| v.as_str())
155 .map(str::trim)
156 .filter(|s| !s.is_empty())
157 {
158 return Ok(token.to_string());
159 }
160
161 let err = json.get("error").and_then(|v| v.as_str());
162 let err_desc = json
163 .get("error_description")
164 .and_then(|v| v.as_str())
165 .or_else(|| json.get("error_message").and_then(|v| v.as_str()));
166
167 let detail = match (err, err_desc) {
168 (Some(e), Some(d)) if !d.is_empty() => format!("{e}: {d}"),
169 (Some(e), _) => e.to_string(),
170 (_, Some(d)) if !d.is_empty() => d.to_string(),
171 _ => "no access_token field in JSON response".to_string(),
172 };
173
174 return Err(ServiceError::Internal(format!(
175 "OAuth token exchange failed: {detail}"
176 )));
177 }
178
179 let mut access_token: Option<String> = None;
180 let mut error: Option<String> = None;
181 let mut error_description: Option<String> = None;
182
183 for pair in body.split('&') {
184 let (k, v) = pair.split_once('=').unwrap_or((pair, ""));
185 let key = decode_form_component(k);
186 let value = decode_form_component(v);
187 match key.as_str() {
188 "access_token" if !value.trim().is_empty() => access_token = Some(value),
189 "error" if !value.trim().is_empty() => error = Some(value),
190 "error_description" if !value.trim().is_empty() => error_description = Some(value),
191 _ => {}
192 }
193 }
194
195 if let Some(token) = access_token {
196 return Ok(token);
197 }
198
199 let detail = match (error, error_description) {
200 (Some(e), Some(d)) => format!("{e}: {d}"),
201 (Some(e), None) => e,
202 (None, Some(d)) => d,
203 (None, None) => "no access_token field in response".to_string(),
204 };
205
206 Err(ServiceError::Internal(format!(
207 "OAuth token exchange failed: {detail}"
208 )))
209}
210
211pub fn extract_user_info(
216 config: &OAuthProviderConfig,
217 userinfo_json: &serde_json::Value,
218 email_json: Option<&[serde_json::Value]>,
219) -> Result<OAuthUserInfo, ServiceError> {
220 let provider_user_id = match &userinfo_json[&config.field_map.id] {
222 serde_json::Value::Number(n) => n.to_string(),
223 serde_json::Value::String(s) => s.clone(),
224 _ => {
225 return Err(ServiceError::Internal(format!(
226 "OAuth userinfo missing '{}' field",
227 config.field_map.id
228 )))
229 }
230 };
231
232 let username = userinfo_json[&config.field_map.username]
233 .as_str()
234 .unwrap_or("unknown")
235 .to_string();
236
237 let email = userinfo_json[&config.field_map.email]
239 .as_str()
240 .map(|s| s.to_string())
241 .or_else(|| {
242 email_json.and_then(|emails| {
243 emails
244 .iter()
245 .find(|e| e["primary"].as_bool() == Some(true))
246 .and_then(|e| e["email"].as_str())
247 .map(|s| s.to_string())
248 })
249 });
250
251 let avatar_url = userinfo_json[&config.field_map.avatar]
252 .as_str()
253 .map(|s| s.to_string());
254
255 Ok(OAuthUserInfo {
256 provider_id: config.id.clone(),
257 provider_user_id,
258 username,
259 email,
260 avatar_url,
261 })
262}
263
264pub fn github_preset(client_id: String, client_secret: String) -> OAuthProviderConfig {
268 OAuthProviderConfig {
269 id: "github".into(),
270 display_name: "GitHub".into(),
271 authorize_url: "https://github.com/login/oauth/authorize".into(),
272 token_url: "https://github.com/login/oauth/access_token".into(),
273 userinfo_url: "https://api.github.com/user".into(),
274 email_url: Some("https://api.github.com/user/emails".into()),
275 client_id,
276 client_secret,
277 scopes: "read:user,user:email".into(),
278 field_map: OAuthFieldMap {
279 id: "id".into(),
280 username: "login".into(),
281 email: "email".into(),
282 avatar: "avatar_url".into(),
283 },
284 tls_skip_verify: false,
285 external_authorize_url: None,
286 }
287}
288
289pub fn gitlab_preset(
295 instance_url: String,
296 external_url: Option<String>,
297 client_id: String,
298 client_secret: String,
299) -> OAuthProviderConfig {
300 let base = instance_url.trim_end_matches('/');
301 let ext_base = external_url
302 .as_deref()
303 .map(|u| u.trim_end_matches('/').to_string());
304
305 OAuthProviderConfig {
306 id: "gitlab".into(),
307 display_name: "GitLab".into(),
308 authorize_url: format!("{base}/oauth/authorize"),
309 token_url: format!("{base}/oauth/token"),
310 userinfo_url: format!("{base}/api/v4/user"),
311 email_url: None, client_id,
313 client_secret,
314 scopes: "read_user".into(),
315 field_map: OAuthFieldMap {
316 id: "id".into(),
317 username: "username".into(),
318 email: "email".into(),
319 avatar: "avatar_url".into(),
320 },
321 tls_skip_verify: false,
322 external_authorize_url: ext_base.map(|b| format!("{b}/oauth/authorize")),
323 }
324}
325
326#[derive(Debug, Serialize, Deserialize)]
330#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
331#[cfg_attr(feature = "ts", ts(export))]
332pub struct AuthProvidersResponse {
333 pub email_password: bool,
334 pub oauth: Vec<OAuthProviderInfo>,
335}
336
337#[derive(Debug, Serialize, Deserialize)]
339#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
340#[cfg_attr(feature = "ts", ts(export))]
341pub struct OAuthProviderInfo {
342 pub id: String,
343 pub display_name: String,
344}
345
346#[derive(Debug, Serialize, Deserialize)]
348#[cfg_attr(feature = "ts", derive(ts_rs::TS))]
349#[cfg_attr(feature = "ts", ts(export))]
350pub struct LinkedProvider {
351 pub provider: String,
352 pub provider_username: String,
353 pub display_name: String,
354}
355
/// Percent-encodes `s` for use in a URL query component or form body.
///
/// Every UTF-8 byte outside the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) is
/// written as `%XX` with uppercase hex digits, so a space becomes `%20`, never `+`.
fn urlencoding(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved = byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
            continue;
        }
        encoded.push('%');
        encoded.push(char::from(HEX[usize::from(byte >> 4)]));
        encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
    }
    encoded
}
375
376fn decode_form_component(s: &str) -> String {
377 let bytes = s.as_bytes();
378 let mut out = Vec::with_capacity(bytes.len());
379 let mut i = 0usize;
380 while i < bytes.len() {
381 match bytes[i] {
382 b'+' => {
383 out.push(b' ');
384 i += 1;
385 }
386 b'%' if i + 2 < bytes.len() => {
387 let hi = hex_value(bytes[i + 1]);
388 let lo = hex_value(bytes[i + 2]);
389 if let (Some(h), Some(l)) = (hi, lo) {
390 out.push((h << 4) | l);
391 i += 3;
392 } else {
393 out.push(bytes[i]);
394 i += 1;
395 }
396 }
397 b => {
398 out.push(b);
399 i += 1;
400 }
401 }
402 }
403 String::from_utf8_lossy(&out).to_string()
404}
405
/// Returns the numeric value (0–15) of an ASCII hex digit, or `None` for any other byte.
fn hex_value(b: u8) -> Option<u8> {
    // Offset to subtract so that '0' -> 0, 'a'/'A' -> 10, and so on.
    let offset = match b {
        b'0'..=b'9' => b'0',
        b'a'..=b'f' => b'a' - 10,
        b'A'..=b'F' => b'A' - 10,
        _ => return None,
    };
    Some(b - offset)
}
414
#[cfg(test)]
mod tests {
    use super::{github_preset, parse_access_token_response};

    #[test]
    fn parse_access_token_json_ok() {
        let raw = r#"{"access_token":"gho_123","scope":"read:user","token_type":"bearer"}"#;
        let token = parse_access_token_response(raw).expect("token parse");
        assert_eq!(token, "gho_123");
    }

    #[test]
    fn parse_access_token_form_ok() {
        let raw = "access_token=gho_abc&scope=read%3Auser&token_type=bearer";
        let token = parse_access_token_response(raw).expect("token parse");
        assert_eq!(token, "gho_abc");
    }

    #[test]
    fn parse_access_token_form_decodes_token() {
        let raw = "access_token=a%2Bb%2Fc&token_type=bearer";
        let token = parse_access_token_response(raw).expect("token parse");
        assert_eq!(token, "a+b/c");
    }

    #[test]
    fn parse_access_token_json_error_has_reason() {
        let raw = r#"{"error":"bad_verification_code","error_description":"The code passed is incorrect or expired."}"#;
        let err = parse_access_token_response(raw).expect_err("must fail");
        assert!(err.message().contains("bad_verification_code"));
    }

    #[test]
    fn parse_access_token_form_error_has_reason() {
        let raw = "error=bad_verification_code&error_description=The+code+passed+is+incorrect+or+expired.";
        let err = parse_access_token_response(raw).expect_err("must fail");
        assert!(err
            .message()
            .contains("bad_verification_code: The code passed is incorrect or expired."));
    }

    #[test]
    fn parse_access_token_empty_body_is_error() {
        let err = parse_access_token_response("  \n ").expect_err("must fail");
        assert!(err.message().contains("empty response body"));
    }

    #[test]
    fn parse_access_token_json_without_token_or_error() {
        let raw = r#"{"token_type":"bearer","access_token":"  "}"#;
        let err = parse_access_token_response(raw).expect_err("must fail");
        assert!(err.message().contains("no access_token field in JSON response"));
    }

    #[test]
    fn build_form_encoded_contains_required_fields() {
        let provider = github_preset("cid".into(), "secret".into());
        let encoded =
            super::build_token_request_form_encoded(&provider, "code-1", "https://app/callback");
        assert!(encoded.contains("client_id=cid"));
        assert!(encoded.contains("client_secret=secret"));
        assert!(encoded.contains("grant_type=authorization_code"));
        assert!(encoded.contains("code=code-1"));
    }

    #[test]
    fn build_form_encoded_percent_encodes_values() {
        let provider = github_preset("cid".into(), "s&cret".into());
        let encoded =
            super::build_token_request_form_encoded(&provider, "code-1", "https://app/callback");
        assert!(encoded.contains("client_secret=s%26cret"));
        assert!(encoded.contains("redirect_uri=https%3A%2F%2Fapp%2Fcallback"));
    }
}