1use std::collections::BTreeMap;
2use std::fs;
3use std::io;
4use std::path::PathBuf;
5
6use serde::{Deserialize, Serialize};
7use serde_json::{Map, Value};
8use sha2::{Digest, Sha256};
9
10use crate::config::OAuthConfig;
11
/// OAuth tokens as handled by this module.
///
/// Persisted by `save_oauth_credentials` (OS keyring first, credentials file as fallback) and
/// returned by `load_oauth_credentials`. The on-disk/keyring JSON shape is defined by the
/// private `StoredOAuthCredentials` mirror type, which uses camelCase keys.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthTokenSet {
    /// The access token.
    pub access_token: String,
    /// Refresh token, if one was issued; used to build an `OAuthRefreshRequest`.
    pub refresh_token: Option<String>,
    /// Expiry of `access_token`, if known.
    /// NOTE(review): unit/epoch is not visible in this file — presumably Unix seconds; confirm.
    pub expires_at: Option<u64>,
    /// Scopes associated with this token set.
    pub scopes: Vec<String>,
}
19
/// A PKCE (RFC 7636) verifier together with its derived challenge.
///
/// Produced by `generate_pkce_pair`. The `challenge` is copied into the authorization request
/// by `OAuthAuthorizationRequest::from_config`; the `verifier` is what the caller passes to
/// `OAuthTokenExchangeRequest::from_config` as the `code_verifier`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkceCodePair {
    /// Secret random string kept by the client (base64url, no padding).
    pub verifier: String,
    /// Value derived from `verifier` according to `challenge_method`.
    pub challenge: String,
    /// Transformation used to derive `challenge` from `verifier`.
    pub challenge_method: PkceChallengeMethod,
}
26
/// PKCE code-challenge transformation (RFC 7636 §4.2). Only `S256` is supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PkceChallengeMethod {
    /// Challenge is the base64url (unpadded) SHA-256 digest of the verifier.
    S256,
}

impl PkceChallengeMethod {
    /// Wire value for the `code_challenge_method` request parameter.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            PkceChallengeMethod::S256 => "S256",
        }
    }
}
40
/// Parameters for the authorization-endpoint redirect (authorization-code flow with PKCE).
///
/// Built with `from_config`, optionally extended with `with_extra_param`, and rendered into a
/// URL by `build_url`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthAuthorizationRequest {
    /// Base authorization endpoint; may already carry a query string.
    pub authorize_url: String,
    pub client_id: String,
    pub redirect_uri: String,
    /// Requested scopes; joined with spaces into the `scope` parameter.
    pub scopes: Vec<String>,
    /// Opaque value echoed back on the callback (see `generate_state`).
    pub state: String,
    /// PKCE challenge taken from a `PkceCodePair`.
    pub code_challenge: String,
    pub code_challenge_method: PkceChallengeMethod,
    /// Additional query parameters appended after the standard ones, in key order.
    pub extra_params: BTreeMap<String, String>,
}
52
/// Form body for exchanging an authorization code at the token endpoint.
///
/// `from_config` sets `grant_type` to `"authorization_code"`; `form_params` renders the fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthTokenExchangeRequest {
    pub grant_type: &'static str,
    /// Authorization code received on the callback.
    pub code: String,
    pub redirect_uri: String,
    pub client_id: String,
    /// PKCE verifier matching the challenge sent in the authorization request.
    pub code_verifier: String,
    /// NOTE(review): `state` is not one of the RFC 6749 §4.1.3 token-request parameters;
    /// presumably required by the target provider — confirm before removing.
    pub state: String,
}
62
/// Form body for a refresh-token grant at the token endpoint.
///
/// `from_config` sets `grant_type` to `"refresh_token"`; `form_params` renders the fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthRefreshRequest {
    pub grant_type: &'static str,
    pub refresh_token: String,
    pub client_id: String,
    /// Scopes to request; `from_config` falls back to the configured scopes when none are given.
    pub scopes: Vec<String>,
}
70
/// Query parameters received on the OAuth redirect, as parsed by `parse_oauth_callback_query`.
///
/// Each field is `None` when the corresponding parameter was absent; values are percent-decoded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthCallbackParams {
    pub code: Option<String>,
    pub state: Option<String>,
    pub error: Option<String>,
    pub error_description: Option<String>,
}
78
/// Persisted JSON form of [`OAuthTokenSet`], used for both the keyring payload and the
/// `oauth` entry of the credentials file.
///
/// Keys are camelCase (`accessToken`, `refreshToken`, `expiresAt`, `scopes`). Every field except
/// `accessToken` is `#[serde(default)]`, so records that omit them still deserialize (as `None`
/// or an empty list).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StoredOAuthCredentials {
    access_token: String,
    #[serde(default)]
    refresh_token: Option<String>,
    #[serde(default)]
    expires_at: Option<u64>,
    #[serde(default)]
    scopes: Vec<String>,
}
90
91impl From<OAuthTokenSet> for StoredOAuthCredentials {
92 fn from(value: OAuthTokenSet) -> Self {
93 Self {
94 access_token: value.access_token,
95 refresh_token: value.refresh_token,
96 expires_at: value.expires_at,
97 scopes: value.scopes,
98 }
99 }
100}
101
102impl From<StoredOAuthCredentials> for OAuthTokenSet {
103 fn from(value: StoredOAuthCredentials) -> Self {
104 Self {
105 access_token: value.access_token,
106 refresh_token: value.refresh_token,
107 expires_at: value.expires_at,
108 scopes: value.scopes,
109 }
110 }
111}
112
113impl OAuthAuthorizationRequest {
114 #[must_use]
115 pub fn from_config(
116 config: &OAuthConfig,
117 redirect_uri: impl Into<String>,
118 state: impl Into<String>,
119 pkce: &PkceCodePair,
120 ) -> Self {
121 Self {
122 authorize_url: config.authorize_url.clone(),
123 client_id: config.client_id.clone(),
124 redirect_uri: redirect_uri.into(),
125 scopes: config.scopes.clone(),
126 state: state.into(),
127 code_challenge: pkce.challenge.clone(),
128 code_challenge_method: pkce.challenge_method,
129 extra_params: BTreeMap::new(),
130 }
131 }
132
133 #[must_use]
134 pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
135 self.extra_params.insert(key.into(), value.into());
136 self
137 }
138
139 #[must_use]
140 pub fn build_url(&self) -> String {
141 let mut params = vec![
142 ("response_type", "code".to_string()),
143 ("client_id", self.client_id.clone()),
144 ("redirect_uri", self.redirect_uri.clone()),
145 ("scope", self.scopes.join(" ")),
146 ("state", self.state.clone()),
147 ("code_challenge", self.code_challenge.clone()),
148 (
149 "code_challenge_method",
150 self.code_challenge_method.as_str().to_string(),
151 ),
152 ];
153 params.extend(
154 self.extra_params
155 .iter()
156 .map(|(key, value)| (key.as_str(), value.clone())),
157 );
158 let query = params
159 .into_iter()
160 .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
161 .collect::<Vec<_>>()
162 .join("&");
163 format!(
164 "{}{}{}",
165 self.authorize_url,
166 if self.authorize_url.contains('?') {
167 '&'
168 } else {
169 '?'
170 },
171 query
172 )
173 }
174}
175
176impl OAuthTokenExchangeRequest {
177 #[must_use]
178 pub fn from_config(
179 config: &OAuthConfig,
180 code: impl Into<String>,
181 state: impl Into<String>,
182 verifier: impl Into<String>,
183 redirect_uri: impl Into<String>,
184 ) -> Self {
185 Self {
186 grant_type: "authorization_code",
187 code: code.into(),
188 redirect_uri: redirect_uri.into(),
189 client_id: config.client_id.clone(),
190 code_verifier: verifier.into(),
191 state: state.into(),
192 }
193 }
194
195 #[must_use]
196 pub fn form_params(&self) -> BTreeMap<&str, String> {
197 BTreeMap::from([
198 ("grant_type", self.grant_type.to_string()),
199 ("code", self.code.clone()),
200 ("redirect_uri", self.redirect_uri.clone()),
201 ("client_id", self.client_id.clone()),
202 ("code_verifier", self.code_verifier.clone()),
203 ("state", self.state.clone()),
204 ])
205 }
206}
207
208impl OAuthRefreshRequest {
209 #[must_use]
210 pub fn from_config(
211 config: &OAuthConfig,
212 refresh_token: impl Into<String>,
213 scopes: Option<Vec<String>>,
214 ) -> Self {
215 Self {
216 grant_type: "refresh_token",
217 refresh_token: refresh_token.into(),
218 client_id: config.client_id.clone(),
219 scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
220 }
221 }
222
223 #[must_use]
224 pub fn form_params(&self) -> BTreeMap<&str, String> {
225 BTreeMap::from([
226 ("grant_type", self.grant_type.to_string()),
227 ("refresh_token", self.refresh_token.clone()),
228 ("client_id", self.client_id.clone()),
229 ("scope", self.scopes.join(" ")),
230 ])
231 }
232}
233
234pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
235 let verifier = generate_random_token(32)?;
236 Ok(PkceCodePair {
237 challenge: code_challenge_s256(&verifier),
238 verifier,
239 challenge_method: PkceChallengeMethod::S256,
240 })
241}
242
243pub fn generate_state() -> io::Result<String> {
244 generate_random_token(32)
245}
246
247#[must_use]
248pub fn code_challenge_s256(verifier: &str) -> String {
249 let digest = Sha256::digest(verifier.as_bytes());
250 base64url_encode(&digest)
251}
252
/// Redirect URI for the local loopback callback listener on `port`.
///
/// The path is always `/callback`, which is the only path accepted by
/// `parse_oauth_callback_request_target`.
#[must_use]
pub fn loopback_redirect_uri(port: u16) -> String {
    format!("http://localhost:{port}/callback")
}
257
258pub fn credentials_path() -> io::Result<PathBuf> {
259 Ok(credentials_home_dir()?.join("credentials.json"))
260}
261
262const KEYRING_SERVICE: &str = "codineer";
263const KEYRING_USER: &str = "oauth";
264
265fn keyring_entry() -> Option<keyring::Entry> {
266 keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).ok()
267}
268
269fn load_from_keyring() -> Option<OAuthTokenSet> {
270 let entry = keyring_entry()?;
271 let json = entry.get_password().ok()?;
272 let stored: StoredOAuthCredentials = serde_json::from_str(&json).ok()?;
273 Some(stored.into())
274}
275
276fn save_to_keyring(token_set: &OAuthTokenSet) -> bool {
277 let Some(entry) = keyring_entry() else {
278 return false;
279 };
280 let stored = StoredOAuthCredentials::from(token_set.clone());
281 let Ok(json) = serde_json::to_string(&stored) else {
282 return false;
283 };
284 entry.set_password(&json).is_ok()
285}
286
287fn clear_from_keyring() {
288 if let Some(entry) = keyring_entry() {
289 let _ = entry.delete_credential();
290 }
291}
292
293pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
294 if let Some(token_set) = load_from_keyring() {
295 return Ok(Some(token_set));
296 }
297
298 let path = match credentials_path() {
299 Ok(path) => path,
300 Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
301 Err(error) => return Err(error),
302 };
303 let root = read_credentials_root(&path)?;
304 let Some(oauth) = root.get("oauth") else {
305 return Ok(None);
306 };
307 if oauth.is_null() {
308 return Ok(None);
309 }
310 let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
311 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
312 let token_set: OAuthTokenSet = stored.into();
313
314 if save_to_keyring(&token_set) {
315 let mut migrated_root = root;
316 migrated_root.remove("oauth");
317 let _ = write_credentials_root(&path, &migrated_root);
318 }
319
320 Ok(Some(token_set))
321}
322
323pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
324 if save_to_keyring(token_set) {
325 return Ok(());
326 }
327
328 let path = credentials_path()?;
329 let mut root = read_credentials_root(&path)?;
330 root.insert(
331 "oauth".to_string(),
332 serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
333 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
334 );
335 write_credentials_root(&path, &root)
336}
337
338pub fn clear_oauth_credentials() -> io::Result<()> {
339 clear_from_keyring();
340
341 let path = credentials_path()?;
342 let mut root = read_credentials_root(&path)?;
343 root.remove("oauth");
344 write_credentials_root(&path, &root)
345}
346
347pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
348 let (path, query) = target
349 .split_once('?')
350 .map_or((target, ""), |(path, query)| (path, query));
351 if path != "/callback" {
352 return Err(format!("unexpected callback path: {path}"));
353 }
354 parse_oauth_callback_query(query)
355}
356
357pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
358 let mut params = BTreeMap::new();
359 for pair in query.split('&').filter(|pair| !pair.is_empty()) {
360 let (key, value) = pair
361 .split_once('=')
362 .map_or((pair, ""), |(key, value)| (key, value));
363 params.insert(percent_decode(key)?, percent_decode(value)?);
364 }
365 Ok(OAuthCallbackParams {
366 code: params.get("code").cloned(),
367 state: params.get("state").cloned(),
368 error: params.get("error").cloned(),
369 error_description: params.get("error_description").cloned(),
370 })
371}
372
373fn generate_random_token(bytes: usize) -> io::Result<String> {
374 let mut buffer = vec![0_u8; bytes];
375 getrandom::getrandom(&mut buffer).map_err(|e| io::Error::other(e.to_string()))?;
376 Ok(base64url_encode(&buffer))
377}
378
/// Resolves the directory holding `credentials.json`.
///
/// Lookup order:
/// 1. `CODINEER_CONFIG_HOME`, used as-is.
/// 2. `$HOME/.codineer`.
/// 3. `$USERPROFILE/.codineer`.
///
/// Variables set to an empty string are treated as unset. Otherwise an empty base path would
/// resolve the credentials file against the current working directory.
///
/// # Errors
/// Returns `NotFound` when none of the variables yields a non-empty value.
fn credentials_home_dir() -> io::Result<PathBuf> {
    /// Reads `key` from the environment, treating an empty value as absent.
    fn non_empty_var(key: &str) -> Option<std::ffi::OsString> {
        std::env::var_os(key).filter(|value| !value.is_empty())
    }

    if let Some(path) = non_empty_var("CODINEER_CONFIG_HOME") {
        return Ok(PathBuf::from(path));
    }
    for key in ["HOME", "USERPROFILE"] {
        if let Some(home) = non_empty_var(key) {
            return Ok(PathBuf::from(home).join(".codineer"));
        }
    }
    Err(io::Error::new(
        io::ErrorKind::NotFound,
        "home directory not found (neither HOME nor USERPROFILE is set)",
    ))
}
393
394fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
395 match fs::read_to_string(path) {
396 Ok(contents) => {
397 if contents.trim().is_empty() {
398 return Ok(Map::new());
399 }
400 serde_json::from_str::<Value>(&contents)
401 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
402 .as_object()
403 .cloned()
404 .ok_or_else(|| {
405 io::Error::new(
406 io::ErrorKind::InvalidData,
407 "credentials file must contain a JSON object",
408 )
409 })
410 }
411 Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
412 Err(error) => Err(error),
413 }
414}
415
416fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
417 if let Some(parent) = path.parent() {
418 fs::create_dir_all(parent)?;
419 }
420 let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
421 .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
422 let temp_path = path.with_extension("json.tmp");
423 fs::write(&temp_path, format!("{rendered}\n"))?;
424 set_file_permissions_owner_only(&temp_path);
425 fs::rename(temp_path, path)
426}
427
/// Restricts `path` to read/write for its owner only (mode `0o600`).
///
/// Best effort: a failure to change permissions is ignored.
#[cfg(unix)]
fn set_file_permissions_owner_only(path: &std::path::Path) {
    use std::os::unix::fs::PermissionsExt;
    let owner_read_write = fs::Permissions::from_mode(0o600);
    let _: io::Result<()> = fs::set_permissions(path, owner_read_write);
}

/// No-op on platforms without Unix permission bits.
#[cfg(not(unix))]
fn set_file_permissions_owner_only(_path: &std::path::Path) {}
436
/// Encodes `bytes` as base64url (RFC 4648 §5) without `=` padding.
fn base64url_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    // Picks the 6-bit group of a 24-bit block starting `shift` bits from the right.
    let sextet = |group: u32, shift: u32| char::from(ALPHABET[((group >> shift) & 0x3F) as usize]);

    let mut encoded = String::with_capacity(bytes.len() / 3 * 4 + 3);
    let mut chunks = bytes.chunks_exact(3);
    for chunk in chunks.by_ref() {
        let group =
            (u32::from(chunk[0]) << 16) | (u32::from(chunk[1]) << 8) | u32::from(chunk[2]);
        for shift in [18, 12, 6, 0] {
            encoded.push(sextet(group, shift));
        }
    }
    // A trailing 1 or 2 bytes produce 2 or 3 characters respectively (no padding).
    match chunks.remainder() {
        &[first] => {
            let group = u32::from(first) << 16;
            for shift in [18, 12] {
                encoded.push(sextet(group, shift));
            }
        }
        &[first, second] => {
            let group = (u32::from(first) << 16) | (u32::from(second) << 8);
            for shift in [18, 12, 6] {
                encoded.push(sextet(group, shift));
            }
        }
        _ => {}
    }
    encoded
}
467
/// Percent-encodes every byte of `value` except RFC 3986 unreserved characters
/// (`A-Z a-z 0-9 - _ . ~`), using uppercase hex. Spaces become `%20`.
fn percent_encode(value: &str) -> String {
    use std::fmt::Write as _;

    value
        .bytes()
        .fold(String::with_capacity(value.len()), |mut encoded, byte| {
            let is_unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
            if is_unreserved {
                encoded.push(char::from(byte));
            } else {
                // Writing into a `String` cannot fail.
                let _ = write!(encoded, "%{byte:02X}");
            }
            encoded
        })
}
483
/// Decodes an `application/x-www-form-urlencoded` component: `%XX` escapes and `+` as space.
///
/// A `%` not followed by at least two more bytes is kept literally.
///
/// # Errors
/// Returns an error if a `%` is followed by two bytes that are not both hex digits, or if the
/// decoded bytes are not valid UTF-8.
fn percent_decode(value: &str) -> Result<String, String> {
    let input = value.as_bytes();
    let mut output = Vec::with_capacity(input.len());
    let mut cursor = 0;
    while let Some(&current) = input.get(cursor) {
        let consumed = match current {
            b'%' if cursor + 2 < input.len() => {
                let high = decode_hex(input[cursor + 1])?;
                let low = decode_hex(input[cursor + 2])?;
                output.push((high << 4) | low);
                3
            }
            b'+' => {
                output.push(b' ');
                1
            }
            other => {
                output.push(other);
                1
            }
        };
        cursor += consumed;
    }
    String::from_utf8(output).map_err(|error| error.to_string())
}

/// Converts one ASCII hex digit (either case) to its value `0..=15`.
fn decode_hex(byte: u8) -> Result<u8, String> {
    char::from(byte)
        .to_digit(16)
        // `to_digit(16)` only returns values below 16, so the cast cannot truncate.
        .map(|digit| digit as u8)
        .ok_or_else(|| format!("invalid percent-encoding byte: {byte}"))
}
517
// Unit tests for PKCE helpers, request builders, credential persistence and callback parsing.
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use super::{
        clear_from_keyring, clear_oauth_credentials, code_challenge_s256, credentials_path,
        generate_pkce_pair, generate_state, load_from_keyring, load_oauth_credentials,
        loopback_redirect_uri, parse_oauth_callback_query, parse_oauth_callback_request_target,
        save_oauth_credentials, OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest,
        OAuthTokenExchangeRequest, OAuthTokenSet,
    };

    // Fixed config shared by the request-builder tests.
    fn sample_config() -> OAuthConfig {
        OAuthConfig {
            client_id: "runtime-client".to_string(),
            authorize_url: "https://console.test/oauth/authorize".to_string(),
            token_url: "https://console.test/oauth/token".to_string(),
            callback_port: Some(4545),
            manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
            scopes: vec!["org:read".to_string(), "user:write".to_string()],
        }
    }

    // Crate-wide guard; presumably serializes tests that mutate process environment variables.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        crate::test_env_lock()
    }

    // Unique per-process, per-call directory under the system temp dir.
    fn temp_config_home() -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "runtime-oauth-test-{}-{}",
            std::process::id(),
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("time")
                .as_nanos()
        ))
    }

    // RFC 7636 Appendix B verifier/challenge test vector.
    #[test]
    fn s256_challenge_matches_expected_vector() {
        assert_eq!(
            code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn generates_pkce_pair_and_state() {
        let pair = generate_pkce_pair().expect("pkce pair");
        let state = generate_state().expect("state");
        assert!(!pair.verifier.is_empty());
        assert!(!pair.challenge.is_empty());
        assert!(!state.is_empty());
    }

    #[test]
    fn builds_authorize_url_and_form_requests() {
        let config = sample_config();
        let pair = generate_pkce_pair().expect("pkce");
        let url = OAuthAuthorizationRequest::from_config(
            &config,
            loopback_redirect_uri(4545),
            "state-123",
            &pair,
        )
        .with_extra_param("login_hint", "user@example.com")
        .build_url();
        assert!(url.starts_with("https://console.test/oauth/authorize?"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("client_id=runtime-client"));
        // Scopes are space-joined, then percent-encoded (`:` -> %3A, space -> %20).
        assert!(url.contains("scope=org%3Aread%20user%3Awrite"));
        assert!(url.contains("login_hint=user%40example.com"));

        let exchange = OAuthTokenExchangeRequest::from_config(
            &config,
            "auth-code",
            "state-123",
            pair.verifier,
            loopback_redirect_uri(4545),
        );
        assert_eq!(
            exchange.form_params().get("grant_type").map(String::as_str),
            Some("authorization_code")
        );

        // `None` scopes fall back to the configured ones.
        let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
        assert_eq!(
            refresh.form_params().get("scope").map(String::as_str),
            Some("org:read user:write")
        );
    }

    // NOTE(review): besides the temp credentials file, this test reads, writes and deletes the
    // OS keyring entry used by this module whenever a keyring backend is available.
    #[test]
    fn oauth_credentials_round_trip_and_clear() {
        let _guard = env_lock();
        // Redirect file storage into a throwaway directory.
        let config_home = temp_config_home();
        std::env::set_var("CODINEER_CONFIG_HOME", &config_home);
        let path = credentials_path().expect("credentials path");
        std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent");
        // Unrelated key that save/clear must preserve.
        std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");

        let token_set = OAuthTokenSet {
            access_token: "access-token".to_string(),
            refresh_token: Some("refresh-token".to_string()),
            expires_at: Some(123),
            scopes: vec!["scope:a".to_string()],
        };
        save_oauth_credentials(&token_set).expect("save credentials");
        assert_eq!(
            load_oauth_credentials().expect("load credentials"),
            Some(token_set)
        );

        // Where the token landed depends on whether a keyring backend is usable here.
        let keyring_available = load_from_keyring().is_some();
        let saved = std::fs::read_to_string(&path).expect("read saved file");
        assert!(saved.contains("\"other\""));
        if !keyring_available {
            assert!(saved.contains("\"oauth\""));
        }

        clear_oauth_credentials().expect("clear credentials");
        assert_eq!(load_oauth_credentials().expect("load cleared"), None);
        let cleared = std::fs::read_to_string(&path).expect("read cleared file");
        assert!(cleared.contains("\"other\""));
        assert!(!cleared.contains("\"oauth\""));

        // Cleanup: keyring entry, env override and temp directory.
        clear_from_keyring();
        std::env::remove_var("CODINEER_CONFIG_HOME");
        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
    }

    #[test]
    fn parses_callback_query_and_target() {
        let params =
            parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
                .expect("parse query");
        assert_eq!(params.code.as_deref(), Some("abc123"));
        assert_eq!(params.state.as_deref(), Some("state-1"));
        assert_eq!(params.error_description.as_deref(), Some("needs login"));

        let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
            .expect("parse callback target");
        assert_eq!(params.code.as_deref(), Some("abc"));
        assert_eq!(params.state.as_deref(), Some("xyz"));
        // Any path other than `/callback` is rejected.
        assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err());
    }
}