use std::sync::Arc;
use crate::auth;
use crate::core::{ApiError, ConfigProvider, HttpRequest, HttpResponse, Transport};
use crate::render::percent_encode;
/// A [`Transport`] decorator that recovers from expired OAuth access tokens.
///
/// When a bearer-authenticated request comes back `401`, the refresh token
/// stored in the configuration is exchanged for a new access token and the
/// request is sent again once with the new token.
pub struct RefreshingTransport {
    /// Transport that performs the actual HTTP requests, including the
    /// token-refresh call itself.
    inner: Arc<dyn Transport>,
    /// Configuration store holding the per-host OAuth tokens and client
    /// credentials that are read and updated during a refresh.
    config: Arc<dyn ConfigProvider>,
}
impl RefreshingTransport {
    /// Wraps `inner` so that OAuth bearer requests rejected with `401` are
    /// retried once after refreshing the access token stored in `config`.
    pub fn new(inner: Arc<dyn Transport>, config: Arc<dyn ConfigProvider>) -> Self {
        Self { inner, config }
    }

    /// Returns the configured host whose stored access token is exactly
    /// `failed_bearer`, if any.
    ///
    /// Only hosts whose `auth_type` equals [`auth::OAUTH`] are considered, so a
    /// bearer that matches a non-OAuth host's token is never refreshed.
    fn host_owning(&self, failed_bearer: &str) -> Option<String> {
        self.config.hosts().into_iter().find(|host| {
            self.config.get(host, "auth_type").as_deref() == Some(auth::OAUTH)
                && self.config.get(host, "token").as_deref() == Some(failed_bearer)
        })
    }

    /// Exchanges the stored refresh token of the host owning `failed_bearer`
    /// for a new access token, persists the result, and returns the new token.
    ///
    /// Returns `None` without touching the config when no OAuth host owns the
    /// token, when the refresh token or client credentials are missing, or when
    /// the token endpoint request fails (the failure is printed to stderr).
    ///
    /// Failures to store or save the new tokens are reported on stderr but do
    /// not prevent the fresh access token from being returned for the retry.
    fn try_refresh(&self, failed_bearer: &str) -> Option<String> {
        let host = self.host_owning(failed_bearer)?;
        let refresh_token = self.config.get(&host, "refresh_token")?;
        let client_id = self.config.get(&host, "oauth_client_id")?;
        let client_secret = self.config.get(&host, "oauth_client_secret")?;

        let body = format!(
            "grant_type=refresh_token&refresh_token={}",
            percent_encode(&refresh_token)
        );
        // Client credentials travel in an HTTP Basic header, not in the form body.
        let basic = auth::basic_header(&client_id, &client_secret);
        let token: auth::TokenResponse =
            match auth::post_form(self.inner.as_ref(), auth::TOKEN_URL, &body, &basic) {
                Ok(t) => t,
                Err(e) => {
                    eprintln!("bb: OAuth token refresh failed: {e}");
                    return None;
                }
            };

        self.set_or_warn(&host, "token", &token.access_token);
        // A refresh response may carry a replacement refresh token; if storing it
        // fails, the user must be told rather than left with a stale value.
        if let Some(rt) = &token.refresh_token {
            self.set_or_warn(&host, "refresh_token", rt);
        }
        if let Err(e) = self.config.save() {
            eprintln!("bb: refreshed the OAuth token but could not save it: {e}");
        }
        Some(token.access_token)
    }

    /// Best-effort `config.set`: a failure is reported on stderr instead of
    /// being discarded, and does not abort the refresh.
    fn set_or_warn(&self, host: &str, key: &str, value: &str) {
        if let Err(e) = self.config.set(host, key, value) {
            eprintln!("bb: refreshed the OAuth token but could not store `{key}` for {host}: {e}");
        }
    }
}
impl Transport for RefreshingTransport {
    /// Sends `req` through the inner transport, transparently recovering from
    /// an expired OAuth access token.
    ///
    /// Requests without a `Bearer ` `Authorization` header are forwarded
    /// untouched. For bearer requests, a `401` response triggers one refresh
    /// attempt. If that yields a new token, the request is re-sent exactly once
    /// with the new header and that response is returned. Otherwise the
    /// original `401` response is returned.
    fn execute(&self, req: HttpRequest) -> Result<HttpResponse, ApiError> {
        let bearer = match req.headers.get("Authorization") {
            Some(value) => value.strip_prefix("Bearer ").map(String::from),
            None => None,
        };
        let Some(bearer) = bearer else {
            return self.inner.execute(req);
        };

        // `inner.execute` consumes the request, so keep a copy for a possible retry.
        let mut retry = req.clone();
        let first = self.inner.execute(req)?;
        if first.status != 401 {
            return Ok(first);
        }

        if let Some(fresh) = self.try_refresh(&bearer) {
            retry
                .headers
                .insert("Authorization".to_owned(), auth::bearer_header(&fresh));
            self.inner.execute(retry)
        } else {
            Ok(first)
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use crate::config::FileConfig;
    use crate::core::Method;

    use super::*;

    /// Fake transport that replays canned responses in order and records every
    /// request it receives. `Vec::remove(0)` panics once the script runs out,
    /// which doubles as a guard against unexpected extra requests.
    struct ScriptedTransport {
        responses: Mutex<Vec<HttpResponse>>,
        seen: Mutex<Vec<HttpRequest>>,
    }

    impl ScriptedTransport {
        fn new(responses: Vec<HttpResponse>) -> Self {
            Self {
                responses: Mutex::new(responses),
                seen: Mutex::new(Vec::new()),
            }
        }
    }

    impl Transport for ScriptedTransport {
        fn execute(&self, req: HttpRequest) -> Result<HttpResponse, ApiError> {
            self.seen.lock().unwrap().push(req);
            Ok(self.responses.lock().unwrap().remove(0))
        }
    }

    fn resp(status: u16, body: &str) -> HttpResponse {
        HttpResponse {
            status,
            headers: std::collections::BTreeMap::new(),
            body: body.as_bytes().to_vec(),
        }
    }

    fn bearer_get(token: &str) -> HttpRequest {
        HttpRequest::new(Method::Get, "https://api.bitbucket.org/2.0/user")
            .header("Authorization", auth::bearer_header(token))
    }

    /// Config with a fully-populated OAuth login for `bitbucket.org`.
    /// The returned `TempDir` must be kept alive for the duration of the test.
    fn oauth_config() -> (Arc<dyn ConfigProvider>, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let cfg = FileConfig::load_from(dir.path().to_path_buf()).unwrap();
        cfg.set("bitbucket.org", "auth_type", "oauth").unwrap();
        cfg.set("bitbucket.org", "token", "old-access").unwrap();
        cfg.set("bitbucket.org", "refresh_token", "rt-1").unwrap();
        cfg.set("bitbucket.org", "oauth_client_id", "cid").unwrap();
        cfg.set("bitbucket.org", "oauth_client_secret", "csec").unwrap();
        (Arc::new(cfg), dir)
    }

    #[test]
    fn refreshes_on_401_then_retries_and_persists() {
        let inner = Arc::new(ScriptedTransport::new(vec![
            resp(401, r#"{"type":"error"}"#),
            resp(200, r#"{"access_token":"new-access","refresh_token":"rt-2"}"#),
            resp(200, r#"{"username":"davidd"}"#),
        ]));
        let (config, _dir) = oauth_config();
        let t = RefreshingTransport::new(inner.clone(), config.clone());
        let out = t.execute(bearer_get("old-access")).unwrap();
        assert_eq!(out.status, 200);
        assert_eq!(out.body_str(), r#"{"username":"davidd"}"#);
        assert_eq!(
            config.get("bitbucket.org", "token").as_deref(),
            Some("new-access")
        );
        assert_eq!(
            config.get("bitbucket.org", "refresh_token").as_deref(),
            Some("rt-2")
        );
        let seen = inner.seen.lock().unwrap();
        assert_eq!(seen.len(), 3);
        assert_eq!(seen[1].url, auth::TOKEN_URL);
        assert_eq!(
            seen[2].headers.get("Authorization").map(String::as_str),
            Some("Bearer new-access")
        );
    }

    #[test]
    fn non_401_passes_through_without_refresh() {
        let inner = Arc::new(ScriptedTransport::new(vec![resp(200, "{}")]));
        let (config, _dir) = oauth_config();
        let t = RefreshingTransport::new(inner.clone(), config);
        let out = t.execute(bearer_get("old-access")).unwrap();
        assert_eq!(out.status, 200);
        assert_eq!(inner.seen.lock().unwrap().len(), 1);
    }

    #[test]
    fn surfaces_401_when_no_refresh_token() {
        let dir = tempfile::tempdir().unwrap();
        let cfg = FileConfig::load_from(dir.path().to_path_buf()).unwrap();
        cfg.set("bitbucket.org", "auth_type", "oauth").unwrap();
        cfg.set("bitbucket.org", "token", "old").unwrap();
        let config: Arc<dyn ConfigProvider> = Arc::new(cfg);
        let inner = Arc::new(ScriptedTransport::new(vec![resp(401, "{}")]));
        let t = RefreshingTransport::new(inner.clone(), config);
        let out = t.execute(bearer_get("old")).unwrap();
        assert_eq!(out.status, 401);
        assert_eq!(inner.seen.lock().unwrap().len(), 1);
    }

    #[test]
    fn basic_auth_401_is_not_refreshed() {
        let inner = Arc::new(ScriptedTransport::new(vec![resp(401, "{}")]));
        let (config, _dir) = oauth_config();
        let t = RefreshingTransport::new(inner.clone(), config);
        let req = HttpRequest::new(Method::Get, "https://api.bitbucket.org/2.0/user")
            .header("Authorization", auth::basic_header("u", "p"));
        let out = t.execute(req).unwrap();
        assert_eq!(out.status, 401);
        assert_eq!(inner.seen.lock().unwrap().len(), 1);
    }

    #[test]
    fn refreshes_against_a_non_default_host() {
        let dir = tempfile::tempdir().unwrap();
        let cfg = FileConfig::load_from(dir.path().to_path_buf()).unwrap();
        let host = "git.example.com";
        cfg.set(host, "auth_type", "oauth").unwrap();
        cfg.set(host, "token", "old-access").unwrap();
        cfg.set(host, "refresh_token", "rt-1").unwrap();
        cfg.set(host, "oauth_client_id", "cid").unwrap();
        cfg.set(host, "oauth_client_secret", "csec").unwrap();
        let config: Arc<dyn ConfigProvider> = Arc::new(cfg);
        let inner = Arc::new(ScriptedTransport::new(vec![
            resp(401, r#"{"type":"error"}"#),
            resp(200, r#"{"access_token":"new-access","refresh_token":"rt-2"}"#),
            resp(200, r#"{"username":"davidd"}"#),
        ]));
        let t = RefreshingTransport::new(inner.clone(), config.clone());
        let out = t.execute(bearer_get("old-access")).unwrap();
        assert_eq!(out.status, 200);
        assert_eq!(config.get(host, "token").as_deref(), Some("new-access"));
        assert_eq!(inner.seen.lock().unwrap().len(), 3);
    }

    #[test]
    fn bearer_not_matching_stored_token_is_not_refreshed() {
        let inner = Arc::new(ScriptedTransport::new(vec![resp(401, "{}")]));
        let (config, _dir) = oauth_config();
        let t = RefreshingTransport::new(inner.clone(), config.clone());
        let out = t.execute(bearer_get("some-other-token")).unwrap();
        assert_eq!(out.status, 401);
        assert_eq!(inner.seen.lock().unwrap().len(), 1);
        assert_eq!(
            config.get("bitbucket.org", "token").as_deref(),
            Some("old-access")
        );
        assert_eq!(
            config.get("bitbucket.org", "refresh_token").as_deref(),
            Some("rt-1")
        );
    }

    #[test]
    fn refreshed_token_still_rejected_is_retried_only_once() {
        // Exactly three scripted responses: a second refresh or retry would
        // panic in `ScriptedTransport::execute`.
        let inner = Arc::new(ScriptedTransport::new(vec![
            resp(401, r#"{"type":"error"}"#),
            resp(200, r#"{"access_token":"new-access","refresh_token":"rt-2"}"#),
            resp(401, r#"{"type":"error"}"#),
        ]));
        let (config, _dir) = oauth_config();
        let t = RefreshingTransport::new(inner.clone(), config);
        let out = t.execute(bearer_get("old-access")).unwrap();
        assert_eq!(out.status, 401);
        assert_eq!(inner.seen.lock().unwrap().len(), 3);
    }

    #[test]
    fn non_oauth_host_with_matching_token_is_not_refreshed() {
        let dir = tempfile::tempdir().unwrap();
        let cfg = FileConfig::load_from(dir.path().to_path_buf()).unwrap();
        cfg.set("bitbucket.org", "auth_type", "basic").unwrap();
        cfg.set("bitbucket.org", "token", "old-access").unwrap();
        cfg.set("bitbucket.org", "refresh_token", "rt-1").unwrap();
        cfg.set("bitbucket.org", "oauth_client_id", "cid").unwrap();
        cfg.set("bitbucket.org", "oauth_client_secret", "csec").unwrap();
        let config: Arc<dyn ConfigProvider> = Arc::new(cfg);
        let inner = Arc::new(ScriptedTransport::new(vec![resp(401, "{}")]));
        let t = RefreshingTransport::new(inner.clone(), config.clone());
        let out = t.execute(bearer_get("old-access")).unwrap();
        assert_eq!(out.status, 401);
        assert_eq!(inner.seen.lock().unwrap().len(), 1);
        assert_eq!(
            config.get("bitbucket.org", "token").as_deref(),
            Some("old-access")
        );
    }
}