use crate::proxy::parity;
use base64::Engine as _;
use serde::{Deserialize, Serialize};
use sha2::Digest;
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant};
// Google OAuth 2.0 / OpenID endpoints used by this module.

/// Token endpoint for the authorization-code and refresh-token grants.
const TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
/// OAuth2 v2 userinfo endpoint (default first fallback).
const USERINFO_URL_OAUTH2_V2: &str = "https://www.googleapis.com/oauth2/v2/userinfo";
/// OpenID Connect userinfo endpoint (default second fallback).
const USERINFO_URL_OPENIDCONNECT_V1: &str = "https://openidconnect.googleapis.com/v1/userinfo";
/// Token revocation endpoint.
const REVOKE_URL: &str = "https://oauth2.googleapis.com/revoke";
/// Authorization (consent screen) endpoint.
const AUTH_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
/// Space-separated scope list requested in the authorization URL.
const OAUTH_SCOPES: &str = concat!(
    "openid",
    " https://www.googleapis.com/auth/cloud-platform",
    " https://www.googleapis.com/auth/userinfo.email",
    " https://www.googleapis.com/auth/userinfo.profile",
    " https://www.googleapis.com/auth/cclog",
    " https://www.googleapis.com/auth/experimentsandconfigs",
);
/// Returns the trimmed value of the first environment variable in `keys`
/// that is set, valid Unicode, and not blank.
///
/// Variables that are unset, non-Unicode, or whitespace-only are skipped and
/// the search continues with the next key.
fn env_first(keys: &[&str]) -> Option<String> {
    keys.iter().find_map(|key| {
        let raw = std::env::var(key).ok()?;
        let trimmed = raw.trim();
        (!trimmed.is_empty()).then(|| trimmed.to_string())
    })
}
/// Resolves the Google OAuth client ID from `GOOGLE_OAUTH_CLIENT_ID`.
///
/// # Errors
/// Returns a human-readable message when the variable is unset or blank.
pub(crate) fn client_id() -> Result<String, String> {
    match env_first(&["GOOGLE_OAUTH_CLIENT_ID"]) {
        Some(id) => Ok(id),
        None => Err("Missing Google OAuth client_id. Set GOOGLE_OAUTH_CLIENT_ID.".to_string()),
    }
}
/// Reads the optional OAuth client secret from `GOOGLE_OAUTH_CLIENT_SECRET`.
///
/// `None` (unset or blank) makes the token/refresh requests in this file
/// omit the `client_secret` form parameter.
fn client_secret_optional() -> Option<String> {
    const SECRET_VARS: [&str; 1] = ["GOOGLE_OAUTH_CLIENT_SECRET"];
    env_first(&SECRET_VARS)
}
/// Returns the User-Agent for OAuth requests, which is always the global
/// `crate::constants::USER_AGENT`.
///
/// `OAUTH_USER_AGENT` is deprecated. When it holds a non-blank value that
/// differs from the global User-Agent, a warning is logged and the value is
/// ignored.
fn oauth_user_agent() -> String {
    let global = crate::constants::USER_AGENT.as_str();
    let has_conflicting_override = std::env::var("OAUTH_USER_AGENT")
        .map(|raw| {
            let candidate = raw.trim();
            !candidate.is_empty() && candidate != global
        })
        .unwrap_or(false);
    if has_conflicting_override {
        tracing::warn!(
            "Ignoring deprecated OAUTH_USER_AGENT override (OAuth uses the global User-Agent)."
        );
    }
    global.to_string()
}
pub fn generate_pkce_verifier() -> String {
let mut bytes = [0u8; 32];
rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut bytes);
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
pub fn pkce_challenge_s256(verifier: &str) -> String {
let digest = sha2::Sha256::digest(verifier.as_bytes());
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
}
/// Successful JSON body from Google's token endpoint, parsed here for both the
/// authorization-code exchange and the refresh-token grant.
///
/// NOTE(review): the derived `Debug` prints `access_token` / `refresh_token`
/// in clear text — avoid `{:?}`-logging values of this type.
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenResponse {
/// Access token issued by the server.
pub access_token: String,
/// Lifetime of `access_token` in seconds.
pub expires_in: i64,
/// Token type as reported by the server; empty string when absent.
#[serde(default)]
pub token_type: String,
/// Refresh token, if the server issued one (it may be omitted).
#[serde(default)]
pub refresh_token: Option<String>,
/// Raw ID token, if returned; `verify_identity_at` validates it when present.
#[serde(default)]
pub id_token: Option<String>,
}
/// Profile returned by a Google userinfo endpoint.
///
/// The serde aliases accept both spellings handled here (`email_verified` or
/// `verified_email`; `sub` or `id`), presumably so one type parses both the
/// OAuth2 v2 and OpenID Connect userinfo formats — confirm against responses.
#[derive(Debug, Serialize, Deserialize)]
pub struct UserInfo {
/// Account email address; required, so parsing fails when it is missing.
pub email: String,
/// Email verification flag; `None` when the response omits it.
#[serde(default, alias = "verified_email")]
pub email_verified: Option<bool>,
/// Subject identifier (`sub`, or `id` via alias).
#[serde(default, alias = "id")]
pub sub: Option<String>,
/// Full name, if provided.
pub name: Option<String>,
/// Given name, if provided.
pub given_name: Option<String>,
/// Family name, if provided.
pub family_name: Option<String>,
/// `picture` value, if provided (presumably a profile image URL).
pub picture: Option<String>,
/// `hd` value, if provided (presumably the hosted domain — confirm).
#[serde(default)]
pub hd: Option<String>,
}
/// Account identity produced by `verify_identity_at`, either from validated
/// id_token claims or from the userinfo fallback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerifiedIdentity {
/// Account email address.
pub email: String,
/// Display name, if known.
pub name: Option<String>,
/// Google subject identifier; every constructor in this file sets `Some`.
pub google_sub: Option<String>,
/// Email verification flag; the userinfo fallback only ever produces `true`.
pub email_verified: bool,
/// `hd` value from the id_token claims or the userinfo response.
pub hd: Option<String>,
}
/// Point-in-time count of OAuth refresh attempts within the last 60 seconds.
#[derive(Debug, Clone, Serialize)]
pub struct RefreshObservabilitySnapshot {
/// Attempts across all accounts.
pub refresh_attempts_last_minute: usize,
/// Attempts per account id; attempts made without an account id are keyed `"generic"`.
pub refresh_attempts_by_account_last_minute: HashMap<String, usize>,
}
/// Backing store for refresh-attempt counters, guarded by the mutex returned
/// from `refresh_observability_state`.
#[derive(Default)]
struct RefreshObservabilityState {
/// Timestamp of every recorded attempt, oldest first.
global_timestamps: VecDeque<Instant>,
/// Attempt timestamps grouped by account key, oldest first within each queue.
per_account_timestamps: HashMap<String, VecDeque<Instant>>,
}
impl UserInfo {
    /// Returns a human-readable display name.
    ///
    /// Prefers `name`. When `name` is absent or blank, joins `given_name` and
    /// `family_name` with a space, using only the components that are present
    /// and non-blank. Returns `None` when no usable component exists.
    ///
    /// Blank given/family names are skipped for the same reason a blank
    /// `name` is: otherwise callers would receive `Some("")` or a name with a
    /// stray leading/trailing space.
    pub fn get_display_name(&self) -> Option<String> {
        fn non_blank(value: &Option<String>) -> Option<&str> {
            value.as_deref().filter(|s| !s.trim().is_empty())
        }

        if let Some(name) = non_blank(&self.name) {
            return Some(name.to_string());
        }
        match (non_blank(&self.given_name), non_blank(&self.family_name)) {
            (Some(given), Some(family)) => Some(format!("{} {}", given, family)),
            (Some(only), None) | (None, Some(only)) => Some(only.to_string()),
            (None, None) => None,
        }
    }

    /// Whether the provider reported the email as verified; an absent flag
    /// counts as unverified.
    pub fn is_email_verified(&self) -> bool {
        self.email_verified.unwrap_or(false)
    }

    /// Returns an owned copy of the subject identifier, if present.
    pub fn google_sub(&self) -> Option<String> {
        self.sub.clone()
    }
}
/// Loads the stored device profile for `account_id`.
///
/// Returns `None` in three cases: no account id is given, the account cannot
/// be loaded (the error is discarded), or the account has no device profile.
fn load_account_device_profile(account_id: Option<&str>) -> Option<crate::models::DeviceProfile> {
    match account_id {
        None => None,
        Some(id) => match crate::modules::auth::account::load_account(id) {
            Ok(account) => account.device_profile,
            Err(_) => None,
        },
    }
}
/// Builds the header set for an OAuth request to `endpoint`.
///
/// The headers come from the runtime header policy in the `OAuth` scope. They
/// carry the OAuth User-Agent and the account's stored device profile (when
/// `account_id` has one). No access token or JSON content type is set, and
/// `force_connection_close` is enabled.
fn build_google_identity_headers(
    account_id: Option<&str>,
    endpoint: &str,
) -> reqwest::header::HeaderMap {
    use crate::proxy::upstream::header_policy as hp;

    let policy = hp::load_policy_from_runtime_config();
    let host = hp::host_from_url(endpoint);
    let ua = oauth_user_agent();
    let profile = load_account_device_profile(account_id);

    let ctx = hp::GoogleHeaderPolicyContext {
        endpoint,
        endpoint_host: host.as_deref(),
        scope: hp::GoogleHeaderScope::OAuth,
        user_agent: ua.as_str(),
        access_token: None,
        content_type_json: false,
        device_profile: profile.as_ref(),
        extra_headers: None,
        force_connection_close: true,
    };
    hp::build_google_headers(ctx, &policy)
}
/// Returns the userinfo endpoints to try, in order.
///
/// Uses the endpoints derived from the app config's
/// `proxy.google.userinfo_endpoint`. Falls back to the OAuth2 v2 endpoint
/// followed by the OpenID Connect endpoint when the config cannot be loaded
/// or yields no endpoints.
fn configured_userinfo_endpoints() -> Vec<&'static str> {
    crate::modules::system::config::load_app_config()
        .ok()
        .map(|cfg| {
            crate::proxy::google::endpoints::userinfo_endpoints(cfg.proxy.google.userinfo_endpoint)
        })
        .filter(|endpoints| !endpoints.is_empty())
        .unwrap_or_else(|| vec![USERINFO_URL_OAUTH2_V2, USERINFO_URL_OPENIDCONNECT_V1])
}
/// Deterministic per-account jitter, in the inclusive range `30..=120` seconds.
///
/// The account id (or `"generic-account"` when absent) is hashed with a
/// freshly created [`DefaultHasher`], so the same id always maps to the same
/// jitter within one build. Different accounts spread their refreshes apart.
fn refresh_jitter_seconds(account_id: Option<&str>) -> i64 {
    const BASE_SECS: i64 = 30;
    // 91 buckets → offsets 0..=90 → total 30..=120.
    const BUCKETS: u64 = 91;

    let key = match account_id {
        Some(id) => id,
        None => "generic-account",
    };
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    let offset = (hasher.finish() % BUCKETS) as i64;
    BASE_SECS + offset
}
/// Returns the process-wide refresh-attempt tracker, creating it (empty) on
/// first use.
fn refresh_observability_state() -> &'static Mutex<RefreshObservabilityState> {
    static STATE: OnceLock<Mutex<RefreshObservabilityState>> = OnceLock::new();
    STATE.get_or_init(Default::default)
}
/// Removes refresh-attempt timestamps more than 60 seconds older than `now`.
/// Per-account queues left empty are removed as well.
///
/// The caller must hold the state mutex (hence `_locked`).
///
/// Staleness is decided by each entry's age (`now.saturating_duration_since(ts)`)
/// rather than by comparing against `now - 60s`. `Instant::checked_sub` fails
/// when the monotonic clock is younger than the window, and the previous
/// `unwrap_or(now)` fallback then discarded every recorded attempt.
fn cleanup_refresh_observability_locked(state: &mut RefreshObservabilityState, now: Instant) {
    /// Pops expired entries from the front of `queue`.
    fn prune_expired(queue: &mut VecDeque<Instant>, now: Instant) {
        const WINDOW: Duration = Duration::from_secs(60);
        // Timestamps are appended in non-decreasing order (recorded under the
        // lock from `Instant::now()`), so expired entries form a prefix.
        while queue
            .front()
            .map_or(false, |ts| now.saturating_duration_since(*ts) > WINDOW)
        {
            queue.pop_front();
        }
    }

    prune_expired(&mut state.global_timestamps, now);
    state.per_account_timestamps.retain(|_, queue| {
        prune_expired(queue, now);
        !queue.is_empty()
    });
}
/// Records one refresh attempt at the current instant, both globally and under
/// the account key (`"generic"` when `account_id` is `None`).
///
/// Best-effort: if the state mutex is poisoned, the attempt is not recorded.
fn record_refresh_attempt(account_id: Option<&str>) {
    let mut guard = match refresh_observability_state().lock() {
        Ok(guard) => guard,
        Err(_) => return,
    };
    let stamp = Instant::now();
    cleanup_refresh_observability_locked(&mut guard, stamp);
    guard.global_timestamps.push_back(stamp);
    let bucket = account_id.unwrap_or("generic").to_owned();
    guard
        .per_account_timestamps
        .entry(bucket)
        .or_default()
        .push_back(stamp);
}
/// Returns refresh-attempt counts for the last 60 seconds, pruning expired
/// entries first.
///
/// If the state mutex is poisoned, an empty snapshot is returned.
pub fn refresh_observability_snapshot() -> RefreshObservabilitySnapshot {
    let Ok(mut guard) = refresh_observability_state().lock() else {
        return RefreshObservabilitySnapshot {
            refresh_attempts_last_minute: 0,
            refresh_attempts_by_account_last_minute: HashMap::new(),
        };
    };
    cleanup_refresh_observability_locked(&mut guard, Instant::now());
    let mut by_account = HashMap::with_capacity(guard.per_account_timestamps.len());
    for (account, stamps) in guard.per_account_timestamps.iter() {
        by_account.insert(account.clone(), stamps.len());
    }
    RefreshObservabilitySnapshot {
        refresh_attempts_last_minute: guard.global_timestamps.len(),
        refresh_attempts_by_account_last_minute: by_account,
    }
}
/// Test helper: resets all refresh-attempt counters (no-op if the lock is poisoned).
#[cfg(test)]
fn clear_refresh_observability_for_tests() {
    match refresh_observability_state().lock() {
        Ok(mut guard) => *guard = Default::default(),
        Err(_) => {}
    }
}
/// How many seconds before expiry a token becomes due for refresh: a
/// 300-second base plus the account's deterministic jitter (330–420s total).
pub fn refresh_window_seconds(account_id: Option<&str>) -> i64 {
    const BASE_WINDOW_SECS: i64 = 300;
    refresh_jitter_seconds(account_id) + BASE_WINDOW_SECS
}
/// Returns `true` when a token expiring at `expiry_timestamp` falls inside the
/// account's refresh window as seen from `now_timestamp` (both in seconds,
/// e.g. Unix timestamps as used by `ensure_fresh_token`).
pub fn should_refresh_token(
    expiry_timestamp: i64,
    now_timestamp: i64,
    account_id: Option<&str>,
) -> bool {
    let refresh_deadline = now_timestamp + refresh_window_seconds(account_id);
    refresh_deadline >= expiry_timestamp
}
/// Builds the Google consent-screen URL for the authorization-code + PKCE flow.
///
/// The URL requests `OAUTH_SCOPES` with `access_type=offline` and
/// `prompt=consent`, includes previously granted scopes, and carries `state`
/// plus the `S256` `code_challenge`.
///
/// # Errors
/// Returns an error when the client ID is not configured or the URL cannot be
/// constructed.
pub fn get_auth_url(
    redirect_uri: &str,
    state: &str,
    code_challenge: &str,
) -> Result<String, String> {
    let cid = client_id()?;
    // A fixed-size array is enough here; no heap-allocated Vec is needed.
    let params = [
        ("client_id", cid.as_str()),
        ("redirect_uri", redirect_uri),
        ("response_type", "code"),
        ("scope", OAUTH_SCOPES),
        ("access_type", "offline"),
        ("prompt", "consent"),
        ("include_granted_scopes", "true"),
        ("state", state),
        ("code_challenge", code_challenge),
        ("code_challenge_method", "S256"),
    ];
    let url = url::Url::parse_with_params(AUTH_URL, &params)
        .map_err(|e| format!("Invalid Auth URL: {}", e))?;
    Ok(url.to_string())
}
/// Exchanges an authorization `code` for tokens at Google's token endpoint.
///
/// `redirect_uri` and `code_verifier` are forwarded as form parameters.
pub async fn exchange_code(
    code: &str,
    redirect_uri: &str,
    code_verifier: &str,
) -> Result<TokenResponse, String> {
    let endpoint = TOKEN_URL;
    exchange_code_at(code, redirect_uri, code_verifier, endpoint).await
}
/// Exchanges an authorization code for tokens at `token_url`.
///
/// Sends the `authorization_code` grant with the PKCE `code_verifier`, plus
/// `client_secret` when one is configured, through the shared long client.
/// The request is recorded for parity capture whether or not it succeeds.
/// On success the function logs whether a refresh token came back.
///
/// # Errors
/// Returns a message in these cases: the client ID is missing; the request
/// fails (with a proxy hint for connect/timeout errors); the server answers
/// with a non-2xx status (body included); or the body is not a valid
/// [`TokenResponse`].
async fn exchange_code_at(
    code: &str,
    redirect_uri: &str,
    code_verifier: &str,
    token_url: &str,
) -> Result<TokenResponse, String> {
    let client = crate::utils::http::get_long_client();
    let cid = client_id()?;
    let secret = client_secret_optional();
    let mut params: Vec<(&str, String)> = vec![
        ("client_id", cid),
        ("code", code.to_string()),
        ("redirect_uri", redirect_uri.to_string()),
        ("grant_type", "authorization_code".to_string()),
        ("code_verifier", code_verifier.to_string()),
    ];
    if let Some(s) = secret {
        params.push(("client_secret", s));
    }
    let headers = build_google_identity_headers(None, token_url);
    let started_at = std::time::Instant::now();
    let response = client
        .post(token_url)
        .headers(headers.clone())
        .form(&params)
        .send()
        .await;
    // Captured before the error check so failed requests are recorded too.
    parity::capture::record_reqwest_outbound(
        "POST",
        token_url,
        &headers,
        None,
        started_at,
        response.as_ref().ok().map(|r| r.status().as_u16()),
        parity::types::RequestSource::Gephyr,
    );
    let response = response.map_err(|e| {
        if e.is_connect() || e.is_timeout() {
            format!("Token exchange request failed: {}. Please check your network proxy settings to ensure a stable connection to Google services.", e)
        } else {
            format!("Token exchange request failed: {}", e)
        }
    })?;
    if !response.status().is_success() {
        let error_text = response.text().await.unwrap_or_default();
        return Err(format!("Token exchange failed: {}", error_text));
    }
    let token_res = response
        .json::<TokenResponse>()
        .await
        .map_err(|e| format!("Token parsing failed: {}", e))?;
    // NOTE(review): this logs the first 20 characters of the access token;
    // consider logging only its length if logs are shared.
    crate::modules::system::logger::log_info(&format!(
        "Token exchange successful! access_token: {}..., refresh_token: {}",
        &token_res.access_token.chars().take(20).collect::<String>(),
        if token_res.refresh_token.is_some() {
            "✓"
        } else {
            "✗ Missing"
        }
    ));
    if token_res.refresh_token.is_none() {
        crate::modules::system::logger::log_warn(
            "Warning: Google did not return a refresh_token. Potential reasons:\n\
             1. User has previously authorized this application\n\
             2. Need to revoke access in Google Cloud Console and retry\n\
             3. OAuth parameter configuration issue",
        );
    }
    Ok(token_res)
}
/// Obtains a new access token from Google's token endpoint using `refresh_token`.
///
/// `account_id` selects the account's pooled client (when a global proxy pool
/// exists) and the account's header profile.
pub async fn refresh_access_token(
    refresh_token: &str,
    account_id: Option<&str>,
) -> Result<TokenResponse, String> {
    let endpoint = TOKEN_URL;
    refresh_access_token_at(refresh_token, account_id, endpoint).await
}
/// Revokes `refresh_token` at Google's revocation endpoint.
pub async fn revoke_refresh_token(
    refresh_token: &str,
    account_id: Option<&str>,
) -> Result<(), String> {
    let endpoint = REVOKE_URL;
    revoke_refresh_token_at(refresh_token, account_id, endpoint).await
}
/// Performs the `refresh_token` grant against `token_url`.
///
/// Every call is counted in the refresh observability window, including calls
/// that fail later. If both a global proxy pool and `account_id` are present,
/// the request uses the account's pooled client (requested with argument `60`,
/// presumably a timeout in seconds — confirm). Otherwise it uses the shared
/// long client. The outbound request is recorded for parity capture whether or
/// not it succeeds.
///
/// # Errors
/// Returns a message in these cases: the pooled client cannot be prepared; the
/// client ID is missing; the request fails (with a proxy hint for
/// connect/timeout errors); the server answers with a non-2xx status (body
/// included); or the body is not a valid [`TokenResponse`].
async fn refresh_access_token_at(
    refresh_token: &str,
    account_id: Option<&str>,
    token_url: &str,
) -> Result<TokenResponse, String> {
    record_refresh_attempt(account_id);
    let client = if let (Some(pool), Some(acc_id)) = (
        crate::proxy::proxy_pool::get_global_proxy_pool(),
        account_id,
    ) {
        pool.get_effective_client(Some(acc_id), 60)
            .await
            .map_err(|e| format!("Failed to prepare OAuth refresh client: {}", e))?
    } else {
        crate::utils::http::get_long_client()
    };
    let cid = client_id()?;
    let secret = client_secret_optional();
    let mut params: Vec<(&str, String)> = vec![
        ("client_id", cid),
        ("refresh_token", refresh_token.to_string()),
        ("grant_type", "refresh_token".to_string()),
    ];
    if let Some(s) = secret {
        params.push(("client_secret", s));
    }
    match account_id {
        Some(id) => crate::modules::system::logger::log_info(&format!(
            "Refreshing Token for account: {}...",
            id
        )),
        None => crate::modules::system::logger::log_info(
            "Refreshing Token for generic request (no account_id)...",
        ),
    }
    let headers = build_google_identity_headers(account_id, token_url);
    let started_at = std::time::Instant::now();
    let response = client
        .post(token_url)
        .headers(headers.clone())
        .form(&params)
        .send()
        .await;
    // Captured before the error check so failed requests are recorded too.
    parity::capture::record_reqwest_outbound(
        "POST",
        token_url,
        &headers,
        None,
        started_at,
        response.as_ref().ok().map(|r| r.status().as_u16()),
        parity::types::RequestSource::Gephyr,
    );
    let response = response.map_err(|e| {
        if e.is_connect() || e.is_timeout() {
            format!("Refresh request failed: {}. Unable to connect to the Google authorization server. Please check your proxy settings.", e)
        } else {
            format!("Refresh request failed: {}", e)
        }
    })?;
    if !response.status().is_success() {
        let error_text = response.text().await.unwrap_or_default();
        return Err(format!("Refresh failed: {}", error_text));
    }
    let token_data = response
        .json::<TokenResponse>()
        .await
        .map_err(|e| format!("Refresh data parsing failed: {}", e))?;
    crate::modules::system::logger::log_info(&format!(
        "Token refreshed successfully! Expires in: {} seconds",
        token_data.expires_in
    ));
    Ok(token_data)
}
/// Revokes `refresh_token` at `revoke_url`.
///
/// If both a global proxy pool and `account_id` are present, the request uses
/// the account's pooled client; otherwise it uses the shared default client.
/// The request is recorded for parity capture whether or not it succeeds.
///
/// A 2xx status and HTTP 400 both count as success. NOTE(review): 400 is
/// presumably accepted because an already-invalid or already-revoked token is
/// reported that way — confirm that other 400 causes should not surface.
///
/// # Errors
/// Returns a message when the pooled client cannot be prepared, the request
/// fails, or the server answers with any other status (status and body
/// included).
async fn revoke_refresh_token_at(
    refresh_token: &str,
    account_id: Option<&str>,
    revoke_url: &str,
) -> Result<(), String> {
    let client = if let (Some(pool), Some(acc_id)) = (
        crate::proxy::proxy_pool::get_global_proxy_pool(),
        account_id,
    ) {
        pool.get_effective_client(Some(acc_id), 15)
            .await
            .map_err(|e| format!("Failed to prepare OAuth revoke client: {}", e))?
    } else {
        crate::utils::http::get_client()
    };
    // Borrowed values produce the same form body without per-field String copies.
    let params: Vec<(&str, &str)> = vec![
        ("token", refresh_token),
        ("token_type_hint", "refresh_token"),
    ];
    let headers = build_google_identity_headers(account_id, revoke_url);
    let started_at = std::time::Instant::now();
    let response = client
        .post(revoke_url)
        .headers(headers.clone())
        .form(&params)
        .send()
        .await;
    parity::capture::record_reqwest_outbound(
        "POST",
        revoke_url,
        &headers,
        None,
        started_at,
        response.as_ref().ok().map(|r| r.status().as_u16()),
        parity::types::RequestSource::Gephyr,
    );
    let response = response.map_err(|e| format!("Revoke request failed: {}", e))?;
    let status = response.status();
    if status.is_success() || status == reqwest::StatusCode::BAD_REQUEST {
        return Ok(());
    }
    let body = response.text().await.unwrap_or_default();
    Err(format!("Revoke failed: HTTP {} - {}", status, body))
}
/// Fetches the user's profile, trying each configured userinfo endpoint in
/// order and returning the first success.
///
/// # Errors
/// Returns the last endpoint's error, or a fixed message if no endpoints are
/// configured.
pub async fn get_user_info(
    access_token: &str,
    account_id: Option<&str>,
) -> Result<UserInfo, String> {
    let mut failure: Option<String> = None;
    for endpoint in configured_userinfo_endpoints() {
        match get_user_info_at(access_token, account_id, endpoint).await {
            Ok(info) => return Ok(info),
            Err(e) => failure = Some(e),
        }
    }
    Err(failure.unwrap_or_else(|| "No configured userinfo endpoints".to_string()))
}
/// Fetches the user's profile from `userinfo_url`, authenticating with
/// `access_token` as a Bearer credential.
///
/// If both a global proxy pool and `account_id` are present, the request uses
/// the account's pooled client; otherwise it uses the shared default client.
/// The request is recorded for parity capture whether or not it succeeds.
///
/// # Errors
/// Returns a message in these cases: the access token cannot be encoded as an
/// HTTP header value; the pooled client cannot be prepared; the request fails;
/// the server answers with a non-2xx status (body included); or the body is
/// not a valid [`UserInfo`].
async fn get_user_info_at(
    access_token: &str,
    account_id: Option<&str>,
    userinfo_url: &str,
) -> Result<UserInfo, String> {
    // Fail fast on a token that cannot be sent (e.g. one containing control
    // characters). Previously a placeholder value was sent instead, which
    // wasted a round-trip and produced a confusing upstream auth error.
    let mut authorization =
        reqwest::header::HeaderValue::from_str(&format!("Bearer {}", access_token)).map_err(
            |_| {
                "User info request failed: access token is not a valid HTTP header value"
                    .to_string()
            },
        )?;
    // Keep the credential out of `Debug` output of the header map.
    authorization.set_sensitive(true);

    let client = if let (Some(pool), Some(acc_id)) = (
        crate::proxy::proxy_pool::get_global_proxy_pool(),
        account_id,
    ) {
        pool.get_effective_client(Some(acc_id), 15)
            .await
            .map_err(|e| format!("Failed to prepare userinfo client: {}", e))?
    } else {
        crate::utils::http::get_client()
    };
    let mut headers = build_google_identity_headers(account_id, userinfo_url);
    headers.insert(reqwest::header::AUTHORIZATION, authorization);
    let started_at = std::time::Instant::now();
    let response = client
        .get(userinfo_url)
        .headers(headers.clone())
        .send()
        .await;
    parity::capture::record_reqwest_outbound(
        "GET",
        userinfo_url,
        &headers,
        None,
        started_at,
        response.as_ref().ok().map(|r| r.status().as_u16()),
        parity::types::RequestSource::Gephyr,
    );
    let response = response.map_err(|e| format!("User info request failed: {}", e))?;
    if !response.status().is_success() {
        let error_text = response.text().await.unwrap_or_default();
        return Err(format!("Failed to get user info: {}", error_text));
    }
    response
        .json::<UserInfo>()
        .await
        .map_err(|e| format!("User info parsing failed: {}", e))
}
/// Establishes the account identity.
///
/// With `raw_id_token`, the token is validated and its claims are used
/// directly; `access_token` and `userinfo_url` are then not used. Without it,
/// the function falls back to `userinfo_url` and requires a verified email and
/// a non-blank subject identifier.
///
/// NOTE(review): the id_token path passes `email_verified` through unchanged,
/// while the userinfo path rejects unverified emails. Confirm whether the two
/// paths should agree.
///
/// # Errors
/// Returns a message when id_token validation fails, the userinfo request
/// fails, the email is unverified, or the subject identifier is missing.
async fn verify_identity_at(
    access_token: &str,
    raw_id_token: Option<&str>,
    account_id: Option<&str>,
    userinfo_url: &str,
) -> Result<VerifiedIdentity, String> {
    match raw_id_token {
        Some(raw) => {
            let claims = crate::modules::auth::id_token::validate_id_token(raw)
                .await
                .map_err(|e| format!("Invalid id_token: {}", e))?;
            Ok(VerifiedIdentity {
                email: claims.email,
                name: claims.name,
                google_sub: Some(claims.sub),
                email_verified: claims.email_verified,
                hd: claims.hd,
            })
        }
        None => {
            let info = get_user_info_at(access_token, account_id, userinfo_url).await?;
            if !info.is_email_verified() {
                return Err("Google userinfo rejected: email is not verified".to_string());
            }
            let google_sub = match info.google_sub() {
                Some(sub) if !sub.trim().is_empty() => sub,
                _ => {
                    return Err(
                        "Google userinfo rejected: missing subject identifier".to_string()
                    )
                }
            };
            let display_name = info.get_display_name();
            Ok(VerifiedIdentity {
                email: info.email,
                name: display_name,
                google_sub: Some(google_sub),
                email_verified: true,
                hd: info.hd,
            })
        }
    }
}
/// Establishes the account identity from `raw_id_token` when present.
/// Otherwise it tries each configured userinfo endpoint in order and returns
/// the first success.
///
/// # Errors
/// Returns the id_token error, the last userinfo endpoint's error, or a fixed
/// message when no endpoints are configured.
pub async fn verify_identity(
    access_token: &str,
    raw_id_token: Option<&str>,
    account_id: Option<&str>,
) -> Result<VerifiedIdentity, String> {
    if let Some(raw) = raw_id_token {
        // `verify_identity_at` ignores the userinfo URL when an id_token is supplied.
        return verify_identity_at(access_token, Some(raw), account_id, USERINFO_URL_OAUTH2_V2)
            .await;
    }
    let mut failure: Option<String> = None;
    for endpoint in configured_userinfo_endpoints() {
        match verify_identity_at(access_token, None, account_id, endpoint).await {
            Ok(identity) => return Ok(identity),
            Err(e) => failure = Some(e),
        }
    }
    Err(failure.unwrap_or_else(|| "No configured userinfo endpoints".to_string()))
}
/// Test-only variant of `refresh_and_verify_identity` with injectable token
/// and userinfo URLs.
#[cfg(test)]
async fn refresh_and_verify_identity_at(
    refresh_token: &str,
    account_id: Option<&str>,
    token_url: &str,
    userinfo_url: &str,
) -> Result<(TokenResponse, VerifiedIdentity), String> {
    let tokens = refresh_access_token_at(refresh_token, account_id, token_url).await?;
    let access = tokens.access_token.as_str();
    let id_token = tokens.id_token.as_deref();
    let identity = verify_identity_at(access, id_token, account_id, userinfo_url).await?;
    Ok((tokens, identity))
}
pub async fn refresh_and_verify_identity(
refresh_token: &str,
account_id: Option<&str>,
) -> Result<(TokenResponse, VerifiedIdentity), String> {
let token_res = refresh_access_token_at(refresh_token, account_id, TOKEN_URL).await?;
let identity = verify_identity(
&token_res.access_token,
token_res.id_token.as_deref(),
account_id,
)
.await?;
Ok((token_res, identity))
}
/// Returns `current_token` unchanged while it is outside the account's refresh
/// window. Otherwise it refreshes the token and returns new token data.
///
/// The refreshed data keeps the email and project id of `current_token`. If
/// the server issues a new (non-blank) refresh token, that token replaces the
/// old one, as RFC 6749 §6 requires. Otherwise the existing refresh token is
/// kept.
///
/// # Errors
/// Propagates any error from the refresh request.
pub async fn ensure_fresh_token(
    current_token: &crate::models::TokenData,
    account_id: Option<&str>,
) -> Result<crate::models::TokenData, String> {
    let now = chrono::Local::now().timestamp();
    if !should_refresh_token(current_token.expiry_timestamp, now, account_id) {
        return Ok(current_token.clone());
    }
    crate::modules::system::logger::log_info(&format!(
        "Token expiring soon for account {:?}, refreshing (window={}s)...",
        account_id,
        refresh_window_seconds(account_id)
    ));
    let response = refresh_access_token(&current_token.refresh_token, account_id).await?;
    // Adopt a rotated refresh token when one is returned; otherwise the old
    // one stays valid and is carried forward.
    let refresh_token = response
        .refresh_token
        .filter(|token| !token.trim().is_empty())
        .unwrap_or_else(|| current_token.refresh_token.clone());
    Ok(crate::models::TokenData::new(
        response.access_token,
        refresh_token,
        response.expires_in,
        current_token.email.clone(),
        current_token.project_id.clone(),
        None,
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::ScopedEnvVar;
    use axum::{
        body::Body,
        extract::State,
        http::{header, HeaderMap, HeaderValue},
        response::Response,
        routing::{get, post},
        Json, Router,
    };
    use serde_json::json;
    use std::sync::Arc;
    use tokio::net::TcpListener;
    use tokio::sync::Mutex as AsyncMutex;

    /// OAuth client ID injected by tests whose code path calls `client_id()`.
    const TEST_CLIENT_ID: &str = "test-client.apps.googleusercontent.com";

    /// Serialises tests that read or mutate process-wide environment variables.
    fn oauth_ua_test_guard() -> std::sync::MutexGuard<'static, ()> {
        crate::test_utils::lock_env()
    }

    /// Restores `key` to `previous`, or removes it when it was previously unset.
    fn restore_env_var(key: &str, previous: Option<String>) {
        match previous {
            Some(value) => std::env::set_var(key, value),
            None => std::env::remove_var(key),
        }
    }

    /// State shared by the mock OAuth server's handlers: the `User-Agent` and
    /// `Accept-Encoding` values they observed, and the `/userinfo` JSON body.
    #[derive(Clone)]
    struct MockOauthState {
        user_agents: Arc<AsyncMutex<Vec<String>>>,
        accept_encodings: Arc<AsyncMutex<Vec<String>>>,
        userinfo_response: serde_json::Value,
    }

    impl Default for MockOauthState {
        fn default() -> Self {
            Self {
                user_agents: Arc::new(AsyncMutex::new(Vec::new())),
                accept_encodings: Arc::new(AsyncMutex::new(Vec::new())),
                userinfo_response: json!({
                    "email": "ua-test@example.com",
                    "email_verified": true,
                    "sub": "sub-ua-test",
                    "name": "UA Test"
                }),
            }
        }
    }

    /// Records the request's `User-Agent` and `Accept-Encoding` headers, when
    /// present and valid visible ASCII, into `state`.
    async fn capture_request_headers(state: &MockOauthState, headers: &HeaderMap) {
        if let Some(ua) = headers
            .get(reqwest::header::USER_AGENT)
            .and_then(|value| value.to_str().ok())
        {
            state.user_agents.lock().await.push(ua.to_string());
        }
        if let Some(accept_encoding) = headers
            .get(reqwest::header::ACCEPT_ENCODING)
            .and_then(|value| value.to_str().ok())
        {
            state
                .accept_encodings
                .lock()
                .await
                .push(accept_encoding.to_string());
        }
    }

    /// `/token` handler: captures headers and returns a fixed token response as JSON.
    async fn token_capture_handler(
        State(state): State<MockOauthState>,
        headers: HeaderMap,
    ) -> Json<serde_json::Value> {
        capture_request_headers(&state, &headers).await;
        Json(json!({
            "access_token": "access-test-token",
            "expires_in": 3600,
            "token_type": "Bearer",
            "refresh_token": "refresh-test-token",
            "id_token": null
        }))
    }

    /// `/token` handler: captures headers and returns a gzip-encoded token
    /// response (decoded by the client to `access-test-token` / 3600).
    async fn token_capture_gzip_handler(
        State(state): State<MockOauthState>,
        headers: HeaderMap,
    ) -> Response {
        capture_request_headers(&state, &headers).await;
        const GZIP_TOKEN_RESPONSE_B64: &str = "H4sIAAAAAAAAClWM0QpAQBQF/+U8r1LKwz76kU3ryEZL964i+XfFUh5npuZA6z1VXZpHRtiMRaKm4nEG3JYgVBcibFWXpcFdXNoXwqJhKxQYCHuhDt8r838WurfHdZrOC4yC4FeBAAAA";
        let body = base64::engine::general_purpose::STANDARD
            .decode(GZIP_TOKEN_RESPONSE_B64)
            .expect("valid gzip fixture bytes");
        let mut response = Response::new(Body::from(body));
        response.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/json"),
        );
        response
            .headers_mut()
            .insert(header::CONTENT_ENCODING, HeaderValue::from_static("gzip"));
        response
    }

    /// `/userinfo` handler: captures headers and returns the configured JSON body.
    async fn userinfo_capture_handler(
        State(state): State<MockOauthState>,
        headers: HeaderMap,
    ) -> Json<serde_json::Value> {
        capture_request_headers(&state, &headers).await;
        Json(state.userinfo_response.clone())
    }

    /// Serves `app` on an ephemeral localhost port in a background task.
    /// Returns the base URL and the task handle, so the test can abort it.
    async fn serve_mock_router(app: Router) -> (String, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind test oauth listener");
        let addr = listener.local_addr().expect("test oauth local addr");
        let handle = tokio::spawn(async move {
            axum::serve(listener, app)
                .await
                .expect("mock oauth server should run");
        });
        (format!("http://{}", addr), handle)
    }

    async fn start_mock_oauth_server_with_userinfo(
        userinfo_response: serde_json::Value,
    ) -> (String, MockOauthState, tokio::task::JoinHandle<()>) {
        let state = MockOauthState {
            userinfo_response,
            ..MockOauthState::default()
        };
        let app = Router::new()
            .route("/token", post(token_capture_handler))
            .route("/userinfo", get(userinfo_capture_handler))
            .with_state(state.clone());
        let (base_url, handle) = serve_mock_router(app).await;
        (base_url, state, handle)
    }

    async fn start_mock_oauth_server() -> (String, MockOauthState, tokio::task::JoinHandle<()>) {
        start_mock_oauth_server_with_userinfo(MockOauthState::default().userinfo_response).await
    }

    async fn start_mock_oauth_gzip_token_server(
    ) -> (String, MockOauthState, tokio::task::JoinHandle<()>) {
        let state = MockOauthState::default();
        let app = Router::new()
            .route("/token", post(token_capture_gzip_handler))
            .route("/userinfo", get(userinfo_capture_handler))
            .with_state(state.clone());
        let (base_url, handle) = serve_mock_router(app).await;
        (base_url, state, handle)
    }

    #[test]
    fn test_get_auth_url_contains_state() {
        let _guard = oauth_ua_test_guard();
        // Scoped so the client ID does not leak into other tests' environment.
        let _cid = ScopedEnvVar::set("GOOGLE_OAUTH_CLIENT_ID", TEST_CLIENT_ID);
        let redirect_uri = "http://localhost:8080/callback";
        let state = "test-state-123456";
        let verifier = generate_pkce_verifier();
        let challenge = pkce_challenge_s256(&verifier);
        let url = get_auth_url(redirect_uri, state, &challenge).expect("auth url");
        assert!(url.contains("state=test-state-123456"));
        assert!(url.contains("redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback"));
        assert!(url.contains("response_type=code"));
    }

    #[test]
    fn refresh_jitter_is_deterministic_and_in_range() {
        let a = refresh_jitter_seconds(Some("acct-1"));
        let b = refresh_jitter_seconds(Some("acct-1"));
        let c = refresh_jitter_seconds(Some("acct-2"));
        assert_eq!(a, b);
        assert!((30..=120).contains(&a));
        assert!((30..=120).contains(&c));
    }

    #[test]
    fn should_refresh_token_respects_account_window() {
        let now = 1_700_000_000_i64;
        let window = refresh_window_seconds(Some("acct-1"));
        assert!(should_refresh_token(now + window - 1, now, Some("acct-1")));
        assert!(!should_refresh_token(now + window + 1, now, Some("acct-1")));
    }

    #[test]
    fn oauth_user_agent_uses_default_when_override_missing() {
        let _guard = oauth_ua_test_guard();
        let previous = std::env::var("OAUTH_USER_AGENT").ok();
        std::env::remove_var("OAUTH_USER_AGENT");
        let ua = oauth_user_agent();
        restore_env_var("OAUTH_USER_AGENT", previous);
        assert_eq!(ua, crate::constants::USER_AGENT.as_str());
    }

    #[test]
    fn oauth_user_agent_ignores_override_when_set() {
        let _guard = oauth_ua_test_guard();
        let previous = std::env::var("OAUTH_USER_AGENT").ok();
        std::env::set_var("OAUTH_USER_AGENT", "vscode/1.95.0 gephyr-test");
        let ua = oauth_user_agent();
        restore_env_var("OAUTH_USER_AGENT", previous);
        assert_eq!(ua, crate::constants::USER_AGENT.as_str());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn refresh_access_token_sends_user_agent_header() {
        let _parity_guard = crate::proxy::tests::acquire_security_test_lock();
        let _oauth_guard = oauth_ua_test_guard();
        let _ua = ScopedEnvVar::set("OAUTH_USER_AGENT", "ua-integration-test");
        let _cid = ScopedEnvVar::set("GOOGLE_OAUTH_CLIENT_ID", TEST_CLIENT_ID);
        crate::proxy::parity::capture::clear_capture();
        let _ = crate::proxy::parity::capture::start_capture(
            crate::proxy::parity::capture::CaptureStartConfig::default(),
        );
        let (base_url, state, server) = start_mock_oauth_server().await;
        let token_url = format!("{}/token", base_url);
        let _ = refresh_access_token_at("refresh-token", None, &token_url)
            .await
            .expect("refresh should succeed against mock server");
        let parity_snapshot = crate::proxy::parity::capture::captured_snapshot();
        let captured = state.user_agents.lock().await.clone();
        let captured_accept_encoding = state.accept_encodings.lock().await.clone();
        server.abort();
        assert!(
            captured
                .iter()
                .any(|ua| ua == "google-api-nodejs-client/10.3.0"),
            "expected OAuth refresh call to carry google-api-nodejs-client User-Agent"
        );
        assert!(
            captured_accept_encoding
                .iter()
                .any(|value| value == "gzip, deflate, br"),
            "expected OAuth refresh call to carry Accept-Encoding: gzip, deflate, br"
        );
        assert!(
            parity_snapshot
                .iter()
                .any(|fp| fp.normalized_endpoint.contains("/token")),
            "expected token refresh call to be captured by parity"
        );
        let _ = crate::proxy::parity::capture::stop_capture();
        crate::proxy::parity::capture::clear_capture();
    }

    #[tokio::test(flavor = "current_thread")]
    async fn refresh_access_token_parses_gzip_encoded_response() {
        let _guard = oauth_ua_test_guard();
        let _cid = ScopedEnvVar::set("GOOGLE_OAUTH_CLIENT_ID", TEST_CLIENT_ID);
        let (base_url, state, server) = start_mock_oauth_gzip_token_server().await;
        let token_url = format!("{}/token", base_url);
        let token = refresh_access_token_at("refresh-token", None, &token_url)
            .await
            .expect("refresh should parse gzip-encoded token response");
        let captured_accept_encoding = state.accept_encodings.lock().await.clone();
        server.abort();
        assert_eq!(token.access_token, "access-test-token");
        assert_eq!(token.expires_in, 3600);
        assert!(
            captured_accept_encoding
                .iter()
                .any(|value| value == "gzip, deflate, br"),
            "expected OAuth refresh call to carry Accept-Encoding: gzip, deflate, br"
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn get_user_info_sends_user_agent_header() {
        let _guard = oauth_ua_test_guard();
        let _ua = ScopedEnvVar::set("OAUTH_USER_AGENT", "ua-userinfo-test");
        let (base_url, state, server) = start_mock_oauth_server().await;
        let userinfo_url = format!("{}/userinfo", base_url);
        let _ = get_user_info_at("access-token", None, &userinfo_url)
            .await
            .expect("userinfo should succeed against mock server");
        let captured = state.user_agents.lock().await.clone();
        let captured_accept_encoding = state.accept_encodings.lock().await.clone();
        server.abort();
        assert!(
            captured
                .iter()
                .any(|ua| ua == "google-api-nodejs-client/10.3.0"),
            "expected OAuth userinfo call to carry google-api-nodejs-client User-Agent"
        );
        assert!(
            captured_accept_encoding
                .iter()
                .any(|value| value == "gzip, deflate, br"),
            "expected OAuth userinfo call to carry Accept-Encoding: gzip, deflate, br"
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn exchange_code_sends_accept_encoding_header() {
        let _guard = oauth_ua_test_guard();
        let _cid = ScopedEnvVar::set("GOOGLE_OAUTH_CLIENT_ID", TEST_CLIENT_ID);
        let (base_url, state, server) = start_mock_oauth_server().await;
        let token_url = format!("{}/token", base_url);
        let _ = exchange_code_at(
            "auth-code",
            "http://localhost:8080/callback",
            "verifier",
            &token_url,
        )
        .await
        .expect("token exchange should succeed against mock server");
        let captured_accept_encoding = state.accept_encodings.lock().await.clone();
        server.abort();
        assert!(
            captured_accept_encoding
                .iter()
                .any(|value| value == "gzip, deflate, br"),
            "expected OAuth token exchange call to carry Accept-Encoding: gzip, deflate, br"
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn refresh_observability_snapshot_tracks_attempts() {
        let _guard = oauth_ua_test_guard();
        clear_refresh_observability_for_tests();
        let _cid = ScopedEnvVar::set("GOOGLE_OAUTH_CLIENT_ID", TEST_CLIENT_ID);
        let (base_url, _state, server) = start_mock_oauth_server().await;
        let token_url = format!("{}/token", base_url);
        let _ = refresh_access_token_at("refresh-token-1", Some("acc-1"), &token_url)
            .await
            .expect("refresh should succeed");
        let _ = refresh_access_token_at("refresh-token-2", None, &token_url)
            .await
            .expect("refresh should succeed");
        let snapshot = refresh_observability_snapshot();
        server.abort();
        assert!(snapshot.refresh_attempts_last_minute >= 2);
        assert_eq!(
            snapshot
                .refresh_attempts_by_account_last_minute
                .get("acc-1")
                .copied()
                .unwrap_or(0),
            1
        );
        assert_eq!(
            snapshot
                .refresh_attempts_by_account_last_minute
                .get("generic")
                .copied()
                .unwrap_or(0),
            1
        );
        clear_refresh_observability_for_tests();
    }

    #[tokio::test(flavor = "current_thread")]
    async fn verify_identity_fallback_rejects_unverified_email() {
        let _guard = oauth_ua_test_guard();
        let (base_url, _state, server) = start_mock_oauth_server_with_userinfo(json!({
            "email": "unverified@example.com",
            "email_verified": false,
            "sub": "sub-unverified"
        }))
        .await;
        let userinfo_url = format!("{}/userinfo", base_url);
        let err = verify_identity_at("access-token", None, None, &userinfo_url)
            .await
            .expect_err("verify_identity fallback should fail for unverified email");
        server.abort();
        assert!(err.contains("email is not verified"));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn verify_identity_fallback_rejects_missing_subject_identifier() {
        let _guard = oauth_ua_test_guard();
        let (base_url, _state, server) = start_mock_oauth_server_with_userinfo(json!({
            "email": "nosub@example.com",
            "email_verified": true,
            "name": "No Sub"
        }))
        .await;
        let userinfo_url = format!("{}/userinfo", base_url);
        let err = verify_identity_at("access-token", None, None, &userinfo_url)
            .await
            .expect_err("verify_identity fallback should fail when subject identifier is missing");
        server.abort();
        assert!(err.contains("missing subject identifier"));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn refresh_and_verify_identity_rejects_missing_subject_identifier() {
        let _guard = oauth_ua_test_guard();
        let _cid = ScopedEnvVar::set("GOOGLE_OAUTH_CLIENT_ID", TEST_CLIENT_ID);
        let (base_url, _state, server) = start_mock_oauth_server_with_userinfo(json!({
            "email": "nosub@example.com",
            "email_verified": true,
            "name": "No Sub"
        }))
        .await;
        let token_url = format!("{}/token", base_url);
        let userinfo_url = format!("{}/userinfo", base_url);
        let err = refresh_and_verify_identity_at("refresh-token", None, &token_url, &userinfo_url)
            .await
            .expect_err(
                "refresh_and_verify_identity should fail when fallback userinfo is missing sub",
            );
        server.abort();
        assert!(err.contains("missing subject identifier"));
    }
}