use base64::Engine;
use jerrycan_core::{Error, Result};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// An owned credential string, such as an OAuth client secret.
///
/// Intentionally implements neither `Debug` nor `Display`, so the value cannot
/// be formatted into logs or error messages by accident. The only read access
/// is the module-private `expose` method.
#[derive(Clone)]
pub struct Secret(String);
impl Secret {
    /// Wraps anything convertible into a `String` as a secret.
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        Secret(owned)
    }

    /// Borrows the raw secret text. Private, so only this module can read it.
    fn expose(&self) -> &str {
        self.0.as_str()
    }
}
impl From<&str> for Secret {
    /// Copies a borrowed string into a new secret.
    fn from(s: &str) -> Self {
        Secret::new(s)
    }
}
impl From<String> for Secret {
    /// Takes ownership of `s` as a secret.
    fn from(s: String) -> Self {
        Secret::new(s)
    }
}
/// Static endpoint configuration for an OAuth 2.0 authorization server.
///
/// The fields are public, so callers can describe providers that have no
/// built-in constructor below.
#[derive(Clone, Debug)]
pub struct Provider {
    /// Authorization endpoint the user's browser is sent to.
    pub auth_url: &'static str,
    /// Token endpoint used for code exchange and refresh.
    pub token_url: &'static str,
    /// Scopes requested when a caller supplies an empty scope list.
    pub default_scopes: &'static [&'static str],
}
impl Provider {
    /// Google endpoints, defaulting to the `openid email profile` scopes.
    pub fn google() -> Self {
        Self::from_static(
            "https://accounts.google.com/o/oauth2/v2/auth",
            "https://oauth2.googleapis.com/token",
            &["openid", "email", "profile"],
        )
    }

    /// GitHub endpoints, defaulting to the `read:user user:email` scopes.
    pub fn github() -> Self {
        Self::from_static(
            "https://github.com/login/oauth/authorize",
            "https://github.com/login/oauth/access_token",
            &["read:user", "user:email"],
        )
    }

    /// HubSpot endpoints, defaulting to the `oauth` scope.
    pub fn hubspot() -> Self {
        Self::from_static(
            "https://app.hubspot.com/oauth/authorize",
            "https://api.hubapi.com/oauth/v1/token",
            &["oauth"],
        )
    }

    /// Salesforce endpoints, defaulting to the `api refresh_token` scopes.
    pub fn salesforce() -> Self {
        Self::from_static(
            "https://login.salesforce.com/services/oauth2/authorize",
            "https://login.salesforce.com/services/oauth2/token",
            &["api", "refresh_token"],
        )
    }

    /// Assembles a provider from compile-time constant endpoints and scopes.
    fn from_static(
        auth_url: &'static str,
        token_url: &'static str,
        default_scopes: &'static [&'static str],
    ) -> Self {
        Provider {
            auth_url,
            token_url,
            default_scopes,
        }
    }
}
/// Parsed token-endpoint response (field names follow RFC 6749 §5.1).
///
/// Serializable so it can be stored and reloaded; `None` fields are omitted
/// when serializing. `Debug` is deliberately not derived: it is implemented
/// by hand so the tokens are redacted.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TokenResponse {
/// Credential presented to the provider's APIs.
pub access_token: String,
/// Token type; filled with `"Bearer"` when the provider omits the field.
#[serde(default = "default_token_type")]
pub token_type: String,
/// Present only when the provider issued a refresh token.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub refresh_token: Option<String>,
/// Access-token lifetime as reported by the provider (RFC 6749 specifies
/// seconds). Deserialized as `u64`, so a JSON string such as `"3600"` is
/// rejected rather than coerced.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub expires_in: Option<u64>,
/// Granted scope string exactly as returned by the provider, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
}
impl std::fmt::Debug for TokenResponse {
    /// Formats the response with both tokens redacted.
    ///
    /// `access_token` is always shown as `"***"`. `refresh_token` is shown
    /// only as `"<present>"` or `"<absent>"`. The remaining fields carry no
    /// credentials and are printed verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let refresh_marker = match self.refresh_token {
            Some(_) => "<present>",
            None => "<absent>",
        };
        let mut out = f.debug_struct("TokenResponse");
        out.field("access_token", &"***");
        out.field("token_type", &self.token_type);
        out.field("refresh_token", &refresh_marker);
        out.field("expires_in", &self.expires_in);
        out.field("scope", &self.scope);
        out.finish()
    }
}
/// Serde default for `TokenResponse::token_type` when the field is missing.
fn default_token_type() -> String {
    String::from("Bearer")
}
/// Boxed, `Send` future returned by [`TokenTransport::post_form`]. It may
/// borrow the transport, URL and form for `'a`.
pub type TokenFuture<'a> = Pin<Box<dyn Future<Output = Result<TokenResponse>> + Send + 'a>>;
/// Sends a form to the token endpoint and parses the reply.
///
/// [`HttpTransport`] is the network implementation. `OAuthClient::with_transport`
/// lets callers (for example, tests) substitute their own. The `Send + Sync`
/// supertraits let an `Arc<dyn TokenTransport>` be shared across threads.
pub trait TokenTransport: Send + Sync {
/// Sends `form` to `url` and resolves to the parsed token response or an error.
///
/// The pairs are raw, un-encoded values; `OAuthClient` passes them as-is and
/// each implementation is responsible for encoding them.
fn post_form<'a>(&'a self, url: &'a str, form: &'a [(&'a str, &'a str)]) -> TokenFuture<'a>;
}
pub fn parse_token_body(body: &[u8]) -> Result<TokenResponse> {
let value: serde_json::Value = serde_json::from_slice(body)
.map_err(|_| Error::bad_request("oauth: token endpoint returned a non-JSON body"))?;
if let Some(err) = value.get("error").and_then(|e| e.as_str()) {
let message = match value.get("error_description").and_then(|d| d.as_str()) {
Some(desc) => format!("oauth provider error: {err}: {desc}"),
None => format!("oauth provider error: {err}"),
};
return Err(Error::bad_request(message));
}
serde_json::from_value(value)
.map_err(|e| Error::bad_request(format!("oauth: malformed token response: {e}")))
}
/// A PKCE (RFC 7636) code verifier, able to derive its `S256` challenge.
///
/// Like `Secret`, it has no `Debug` implementation.
#[derive(Clone)]
pub struct PkceVerifier(String);
impl PkceVerifier {
    /// Draws 32 bytes from the OS RNG and base64url-encodes them without
    /// padding, yielding a 43-character verifier.
    pub fn generate() -> Self {
        use rand::RngCore;
        let mut entropy = [0u8; 32];
        rand::rngs::OsRng.fill_bytes(&mut entropy);
        let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(entropy);
        PkceVerifier(verifier)
    }

    /// The verifier text to send as `code_verifier` during code exchange.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// The `S256` code challenge: unpadded base64url of SHA-256(verifier).
    pub fn challenge(&self) -> String {
        let digest = Sha256::digest(self.0.as_bytes());
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
    }
}
/// OAuth 2.0 authorization-code client for one [`Provider`] and one registered
/// application (client id, client secret, redirect URI).
///
/// Cloning shares the transport through its `Arc` and copies the string fields.
#[derive(Clone)]
pub struct OAuthClient {
/// Endpoints and default scopes of the authorization server.
provider: Provider,
/// Sent in both the authorize URL and token requests.
client_id: String,
/// Sent only in token-request bodies. Wrapped in `Secret` so it cannot be
/// formatted via `Debug`/`Display`.
client_secret: Secret,
/// Sent in both the authorize URL and the code-exchange request.
redirect_uri: String,
/// Token-endpoint transport: `HttpTransport` unless replaced via `with_transport`.
http: Arc<dyn TokenTransport>,
}
impl OAuthClient {
    /// Creates a client whose token requests go through a fresh [`HttpTransport`].
    ///
    /// `client_secret` accepts `&str` or `String` through `Into<Secret>`.
    pub fn new(
        provider: Provider,
        client_id: impl Into<String>,
        client_secret: impl Into<Secret>,
        redirect_uri: impl Into<String>,
    ) -> Self {
        Self {
            provider,
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            redirect_uri: redirect_uri.into(),
            http: Arc::new(HttpTransport::new()),
        }
    }

    /// Replaces the token-endpoint transport, for example with a test double.
    pub fn with_transport(mut self, transport: Arc<dyn TokenTransport>) -> Self {
        self.http = transport;
        self
    }

    /// Builds the URL to send the user to for an authorization-code grant.
    ///
    /// Appends `response_type=code`, `client_id`, `redirect_uri`, `scope` and
    /// `state`, each percent-encoded. An empty `scopes` slice falls back to the
    /// provider's `default_scopes`. If `auth_url` already has a query
    /// component, the new parameters extend it rather than starting a second
    /// one.
    pub fn authorize_url(&self, state: &str, scopes: &[&str]) -> String {
        let base = self.provider.auth_url;
        let scope = self.scope_string(scopes);
        // A custom provider may carry fixed parameters in `auth_url`, which
        // RFC 6749 §3.1 says must be retained. Blindly appending `?` would
        // produce `...?a=b?response_type=...` and corrupt that query.
        let separator = if !base.contains('?') {
            "?"
        } else if base.ends_with('?') || base.ends_with('&') {
            ""
        } else {
            "&"
        };
        format!(
            "{base}{separator}response_type=code&client_id={}&redirect_uri={}&scope={}&state={}",
            encode(&self.client_id),
            encode(&self.redirect_uri),
            encode(&scope),
            encode(state),
        )
    }

    /// Like [`OAuthClient::authorize_url`], but adds a freshly generated PKCE
    /// `S256` challenge.
    ///
    /// Returns the URL together with the verifier. The caller must keep the
    /// verifier and pass it to [`OAuthClient::exchange_code`].
    pub fn authorize_url_pkce(&self, state: &str, scopes: &[&str]) -> (String, PkceVerifier) {
        let verifier = PkceVerifier::generate();
        let base = self.authorize_url(state, scopes);
        let url = format!(
            "{base}&code_challenge={}&code_challenge_method=S256",
            encode(&verifier.challenge()),
        );
        (url, verifier)
    }

    /// Exchanges an authorization `code` for tokens at the provider's token endpoint.
    ///
    /// Sends the client credentials in the form body, plus `code_verifier`
    /// when `pkce` is given.
    ///
    /// # Errors
    /// Propagates whatever the transport returns, such as transport failures
    /// or provider `error` responses.
    pub async fn exchange_code(
        &self,
        code: &str,
        pkce: Option<&PkceVerifier>,
    ) -> Result<TokenResponse> {
        let mut form: Vec<(&str, &str)> = vec![
            ("grant_type", "authorization_code"),
            ("code", code),
            ("redirect_uri", &self.redirect_uri),
            ("client_id", &self.client_id),
            ("client_secret", self.client_secret.expose()),
        ];
        if let Some(verifier) = pkce {
            form.push(("code_verifier", verifier.as_str()));
        }
        self.http.post_form(self.provider.token_url, &form).await
    }

    /// Obtains a new access token using `refresh_token`.
    ///
    /// The result is returned exactly as the provider sent it. If the response
    /// has no `refresh_token`, this method does not carry the input token
    /// forward.
    ///
    /// # Errors
    /// Propagates whatever the transport returns.
    pub async fn refresh(&self, refresh_token: &str) -> Result<TokenResponse> {
        let form: Vec<(&str, &str)> = vec![
            ("grant_type", "refresh_token"),
            ("refresh_token", refresh_token),
            ("client_id", &self.client_id),
            ("client_secret", self.client_secret.expose()),
        ];
        self.http.post_form(self.provider.token_url, &form).await
    }

    /// Joins the requested scopes with spaces, or the provider's
    /// `default_scopes` when `scopes` is empty.
    fn scope_string(&self, scopes: &[&str]) -> String {
        if scopes.is_empty() {
            self.provider.default_scopes.join(" ")
        } else {
            scopes.join(" ")
        }
    }
}
/// Percent-encodes `s` following RFC 3986.
///
/// Unreserved ASCII bytes (`A-Z a-z 0-9 - _ . ~`) pass through unchanged.
/// Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex
/// digits, so a space becomes `%20` (never `+`).
fn encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
        }
    }
    encoded
}
/// Network [`TokenTransport`]: a hyper legacy client speaking HTTP/1 over
/// rustls, using the `ring` crypto provider and webpki root certificates.
///
/// The connector permits both `https` and `http`. Plaintext is restricted to
/// loopback hosts by the check in this type's `post_form`.
#[derive(Clone)]
pub struct HttpTransport {
    client: hyper_util::client::legacy::Client<
        hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
        http_body_util::Full<bytes::Bytes>,
    >,
}
impl Default for HttpTransport {
    /// Equivalent to [`HttpTransport::new`].
    fn default() -> Self {
        HttpTransport::new()
    }
}
impl HttpTransport {
    /// Builds the TLS connector and a Tokio-driven client.
    ///
    /// # Panics
    /// Only if rustls rejects the `ring` provider for its safe default
    /// protocol versions, which is treated as an impossible configuration.
    pub fn new() -> Self {
        let crypto = rustls::crypto::ring::default_provider();
        let https = hyper_rustls::HttpsConnectorBuilder::new()
            .with_provider_and_webpki_roots(crypto)
            .expect("ring provider supports rustls' safe default protocol versions")
            .https_or_http()
            .enable_http1()
            .build();
        let executor = hyper_util::rt::TokioExecutor::new();
        let client = hyper_util::client::legacy::Client::builder(executor).build(https);
        HttpTransport { client }
    }
}
impl TokenTransport for HttpTransport {
    /// POSTs `form` as `application/x-www-form-urlencoded` to `url` and parses
    /// the reply with [`parse_token_body`].
    ///
    /// Plaintext `http://` is refused unless the host is loopback, so
    /// credentials never leave the machine unencrypted. The HTTP status code is
    /// not inspected; the body alone decides success or failure. The response
    /// body is buffered only up to `MAX_TOKEN_RESPONSE_BYTES`; anything longer
    /// is reported as a read failure instead of being buffered without bound.
    /// Underlying hyper errors are replaced by fixed messages and not forwarded.
    fn post_form<'a>(&'a self, url: &'a str, form: &'a [(&'a str, &'a str)]) -> TokenFuture<'a> {
        // A token reply is expected to be a small JSON document. 1 MiB is far
        // above any legitimate size while still bounding memory if the
        // endpoint misbehaves or streams an endless body.
        const MAX_TOKEN_RESPONSE_BYTES: usize = 1024 * 1024;
        Box::pin(async move {
            use http_body_util::BodyExt;
            if !is_loopback_http_ok(url) {
                return Err(Error::internal(
                    "oauth: refusing plaintext http:// to a non-loopback token endpoint",
                ));
            }
            let body = encode_form(form);
            let request = hyper::Request::builder()
                .method(hyper::Method::POST)
                .uri(url)
                .header(
                    hyper::header::CONTENT_TYPE,
                    "application/x-www-form-urlencoded",
                )
                .header(hyper::header::ACCEPT, "application/json")
                .body(http_body_util::Full::new(bytes::Bytes::from(body)))
                .map_err(|e| {
                    Error::internal(format!("oauth: building token request failed: {e}"))
                })?;
            let response = self
                .client
                .request(request)
                .await
                .map_err(|_| Error::internal("oauth: token endpoint request failed"))?;
            // `Limited` fails the collect once more than the cap has been read,
            // instead of growing the buffer indefinitely.
            let limited =
                http_body_util::Limited::new(response.into_body(), MAX_TOKEN_RESPONSE_BYTES);
            let bytes = limited
                .collect()
                .await
                .map_err(|_| Error::internal("oauth: reading token response body failed"))?
                .to_bytes();
            parse_token_body(&bytes)
        })
    }
}
/// Decides whether credentials may be sent to `url`.
///
/// Accepted:
/// - every `https://` URL;
/// - an `http://` URL whose host is `localhost` (case-insensitive), an IPv4
///   loopback address (`127.0.0.0/8`), or a bracketed IPv6 loopback address
///   such as `[::1]`.
///
/// Everything else is rejected: other schemes, plaintext to any other host,
/// URLs carrying userinfo (`user@host`), malformed bracketed hosts, and
/// strings without `://`. The check is deliberately conservative; anything
/// it cannot classify is treated as non-loopback.
fn is_loopback_http_ok(url: &str) -> bool {
    let Some((scheme, rest)) = url.split_once("://") else {
        return false;
    };
    if scheme.eq_ignore_ascii_case("https") {
        return true;
    }
    if !scheme.eq_ignore_ascii_case("http") {
        return false;
    }
    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or(rest);
    // Userinfo makes the apparent host misleading (`127.0.0.1@evil.com`
    // connects to evil.com), so it is refused outright.
    if authority.contains('@') {
        return false;
    }
    if let Some(bracketed) = authority.strip_prefix('[') {
        // IPv6 literal: `[addr]`, optionally followed by `:port` and nothing else.
        let Some((inner, after)) = bracketed.split_once(']') else {
            return false;
        };
        if !(after.is_empty() || after.starts_with(':')) {
            return false;
        }
        return matches!(inner.parse::<std::net::Ipv6Addr>(), Ok(addr) if addr.is_loopback());
    }
    // Without brackets only a name or an IPv4 address is valid; strip any port.
    let host = authority.rsplit_once(':').map_or(authority, |(h, _)| h);
    host.eq_ignore_ascii_case("localhost")
        || matches!(host.parse::<std::net::Ipv4Addr>(), Ok(addr) if addr.is_loopback())
}
/// Serializes `form` into an `application/x-www-form-urlencoded` body.
///
/// Each key and value is percent-encoded with `encode`, and the pairs are
/// joined as `k=v&k=v`. An empty slice yields an empty string.
fn encode_form(form: &[(&str, &str)]) -> String {
    form.iter()
        .map(|&(key, value)| format!("{}={}", encode(key), encode(value)))
        .collect::<Vec<_>>()
        .join("&")
}
// Unit tests for the OAuth client. Token-endpoint calls go through stub
// transports (MockIdp or the local ErrorTransport), except
// `real_http_transport_rejects_non_loopback_plaintext_without_a_network_call`,
// which relies on the loopback guard failing before any request is built.
#[cfg(test)]
mod tests {
use super::*;
use crate::mock_idp::MockIdp;
// Checks every required authorize-URL parameter and that reserved characters
// in redirect_uri/state/scope are percent-encoded.
#[test]
fn authorize_url_contains_required_params_and_encodes_them() {
let client = OAuthClient::new(
Provider::google(),
"client-123",
"topsecret",
"https://app.example.com/callback",
);
let url = client.authorize_url("st@te/with spaces&x", &["email", "profile"]);
assert!(url.starts_with("https://accounts.google.com/o/oauth2/v2/auth?"));
assert!(url.contains("response_type=code"));
assert!(url.contains("client_id=client-123"));
assert!(
url.contains("redirect_uri=https%3A%2F%2Fapp.example.com%2Fcallback"),
"redirect_uri not encoded: {url}"
);
// Spaces must be %20, not `+`.
assert!(url.contains("scope=email%20profile"), "scope wrong: {url}");
assert!(
url.contains("state=st%40te%2Fwith%20spaces%26x"),
"state not encoded: {url}"
);
assert!(!url.contains("st@te"), "raw state leaked: {url}");
}
// An empty scope slice must fall back to Provider::default_scopes.
#[test]
fn authorize_url_empty_scopes_uses_provider_defaults() {
let client = OAuthClient::new(Provider::github(), "id", "sec", "https://x/cb");
let url = client.authorize_url("s", &[]);
assert!(
url.contains("scope=read%3Auser%20user%3Aemail"),
"default scopes not applied: {url}"
);
}
// The PKCE URL must carry the S256 challenge derived from the returned verifier.
#[test]
fn pkce_url_carries_s256_challenge_matching_the_verifier() {
let client = OAuthClient::new(Provider::google(), "id", "sec", "https://x/cb");
let (url, verifier) = client.authorize_url_pkce("state-1", &["openid"]);
assert!(
url.contains("code_challenge_method=S256"),
"missing method: {url}"
);
let expected = encode(&verifier.challenge());
assert!(
url.contains(&format!("code_challenge={expected}")),
"challenge does not match verifier: {url}"
);
// The challenge is a hash of the verifier, never the verifier itself.
assert_ne!(verifier.challenge(), verifier.as_str());
}
// End-to-end exchange + refresh against the in-crate mock IdP.
// NOTE(review): MockIdp lives in crate::mock_idp (not shown); presumably
// issue_code registers the code and returns (access_token, refresh_token).
#[tokio::test]
async fn exchange_code_and_refresh_happy_path_via_mock_transport() {
let idp = MockIdp::new();
let (access, refresh) = idp.issue_code("auth-code-1");
let client = OAuthClient::new(
Provider::google(),
"id",
"sec",
"https://app.example.com/cb",
)
.with_transport(idp.token_transport());
let token = client.exchange_code("auth-code-1", None).await.unwrap();
assert_eq!(token.access_token, access);
assert_eq!(token.refresh_token.as_deref(), Some(refresh.as_str()));
let refreshed = client.refresh(&refresh).await.unwrap();
// A refresh must mint a different, non-empty access token.
assert_ne!(refreshed.access_token, access);
assert!(!refreshed.access_token.is_empty());
}
// Provider `error` bodies must map to a 400 (JC0400) that carries the
// provider's reason but never the client secret.
#[tokio::test]
async fn oauth_error_response_is_non_500_and_never_leaks_the_secret() {
// Transport stub that always answers with an RFC 6749 §5.2 error body.
struct ErrorTransport;
impl TokenTransport for ErrorTransport {
fn post_form<'a>(
&'a self,
_url: &'a str,
_form: &'a [(&'a str, &'a str)],
) -> TokenFuture<'a> {
Box::pin(async {
parse_token_body(
br#"{"error":"invalid_grant","error_description":"code expired"}"#,
)
})
}
}
let client = OAuthClient::new(
Provider::google(),
"id",
"super-secret-value-xyz",
"https://x/cb",
)
.with_transport(Arc::new(ErrorTransport));
let err = client.exchange_code("dead", None).await.unwrap_err();
assert_eq!(err.status().as_u16(), 400, "must not be a 500");
assert_eq!(err.code(), "JC0400");
assert!(
err.message().contains("invalid_grant"),
"reason missing: {err}"
);
assert!(err.message().contains("code expired"));
let rendered = err.to_string();
assert!(
!rendered.contains("super-secret-value-xyz"),
"client_secret leaked into error: {rendered}"
);
}
// Serde round-trip (used for persisting tokens) plus defaults for a minimal body.
#[test]
fn token_response_round_trips_through_serde_for_at_rest_encryption() {
let token = TokenResponse {
access_token: "at".into(),
token_type: "Bearer".into(),
refresh_token: Some("rt".into()),
expires_in: Some(3600),
scope: Some("email".into()),
};
let json = serde_json::to_string(&token).unwrap();
let back: TokenResponse = serde_json::from_str(&json).unwrap();
assert_eq!(token, back);
// Only access_token is required; token_type defaults to "Bearer".
let minimal: TokenResponse = serde_json::from_str(r#"{"access_token":"only"}"#).unwrap();
assert_eq!(minimal.access_token, "only");
assert_eq!(minimal.token_type, "Bearer");
assert!(minimal.refresh_token.is_none());
}
// The hand-written Debug impl must hide both token values.
#[test]
fn token_response_debug_redacts_access_and_refresh_tokens() {
let token = TokenResponse {
access_token: "ACCESS-SECRET-abc123".into(),
token_type: "Bearer".into(),
refresh_token: Some("REFRESH-SECRET-xyz789".into()),
expires_in: Some(3600),
scope: Some("email".into()),
};
let rendered = format!("{token:?}");
assert!(
!rendered.contains("ACCESS-SECRET-abc123"),
"access_token leaked through Debug: {rendered}"
);
assert!(
!rendered.contains("REFRESH-SECRET-xyz789"),
"refresh_token leaked through Debug: {rendered}"
);
assert!(
rendered.contains("access_token: \"***\""),
"got: {rendered}"
);
assert!(rendered.contains("<present>"), "got: {rendered}");
assert!(rendered.contains("Bearer"), "got: {rendered}");
// Without a refresh token the marker switches to <absent>.
let no_refresh = TokenResponse {
access_token: "ANOTHER-ACCESS-SECRET".into(),
token_type: "Bearer".into(),
refresh_token: None,
expires_in: None,
scope: None,
};
let rendered = format!("{no_refresh:?}");
assert!(
!rendered.contains("ANOTHER-ACCESS-SECRET"),
"got: {rendered}"
);
assert!(rendered.contains("<absent>"), "got: {rendered}");
}
// Table of accepted/rejected URLs for the plaintext guard, including
// host-confusion tricks (suffix domains, userinfo `@`).
#[test]
fn loopback_http_guard_allows_only_loopback_plaintext_and_all_https() {
assert!(is_loopback_http_ok("https://oauth2.googleapis.com/token"));
assert!(is_loopback_http_ok("HTTPS://EVIL.example.com/token"));
assert!(is_loopback_http_ok("http://127.0.0.1:8080/token"));
assert!(is_loopback_http_ok("http://localhost/token"));
assert!(is_loopback_http_ok("http://LocalHost:3000/token"));
assert!(is_loopback_http_ok("http://[::1]:9000/token"));
assert!(!is_loopback_http_ok("http://evil.example.com/token"));
assert!(!is_loopback_http_ok("http://oauth2.googleapis.com/token"));
assert!(!is_loopback_http_ok("http://localhost.evil.com/token"));
assert!(!is_loopback_http_ok("http://127.0.0.1.evil.com/token"));
assert!(!is_loopback_http_ok("http://127.0.0.1@evil.com/token"));
assert!(!is_loopback_http_ok("http://localhost@evil.com/token"));
assert!(!is_loopback_http_ok("http://evil.com@127.0.0.1/token"));
assert!(!is_loopback_http_ok("ftp://127.0.0.1/token"));
assert!(!is_loopback_http_ok("not-a-url"));
}
// The real transport must hit the guard before building or sending a request,
// and the guard's error must not echo the client secret.
#[tokio::test]
async fn real_http_transport_rejects_non_loopback_plaintext_without_a_network_call() {
let provider = Provider {
auth_url: "http://evil.example.com/authorize",
token_url: "http://evil.example.com/token",
default_scopes: &["openid"],
};
let client = OAuthClient::new(provider, "id", "super-secret-value", "http://app/cb")
.with_transport(Arc::new(HttpTransport::new()));
let err = client.exchange_code("any-code", None).await.unwrap_err();
assert!(
err.message().contains("refusing plaintext http"),
"expected the loopback guard, got: {err}"
);
assert!(
!err.to_string().contains("super-secret-value"),
"secret leaked into error: {err}"
);
}
// Only exercises construction and the private accessor. The absence of
// Debug/Display named in the test is enforced by the type's derive list,
// not asserted here (that would need a compile-fail test).
#[test]
fn secret_does_not_implement_debug_or_display() {
let s = Secret::new("abc");
assert_eq!(s.expose(), "abc");
}
}