use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::sync::RwLock;
use crate::secrets::{CreateSecretParams, SecretsStore};
pub struct OAuthCredentials {
pub client_id: &'static str,
pub client_secret: &'static str,
}
const GOOGLE_CLIENT_ID: &str = match option_env!("IRONCLAW_GOOGLE_CLIENT_ID") {
Some(v) => v,
None => "564604149681-efo25d43rs85v0tibdepsmdv5dsrhhr0.apps.googleusercontent.com",
};
const GOOGLE_CLIENT_SECRET: &str = match option_env!("IRONCLAW_GOOGLE_CLIENT_SECRET") {
Some(v) => v,
None => "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",
};
pub fn builtin_credentials(secret_name: &str) -> Option<OAuthCredentials> {
match secret_name {
"google_oauth_token" => Some(OAuthCredentials {
client_id: GOOGLE_CLIENT_ID,
client_secret: GOOGLE_CLIENT_SECRET,
}),
_ => None,
}
}
pub fn builtin_client_id_override_env(secret_name: &str) -> Option<&'static str> {
match secret_name {
"google_oauth_token" => Some("IRONCLAW_GOOGLE_CLIENT_ID"),
_ => None,
}
}
pub fn hosted_proxy_client_secret(
client_secret: &Option<String>,
builtin: Option<&OAuthCredentials>,
exchange_proxy_configured: bool,
) -> Option<String> {
if !exchange_proxy_configured {
return client_secret.clone();
}
let builtin_secret = builtin.map(|credentials| credentials.client_secret);
match (client_secret, builtin_secret) {
(Some(resolved), Some(baked_in)) if resolved == baked_in => None,
_ => client_secret.clone(),
}
}
pub use crate::llm::oauth_helpers::{
OAUTH_CALLBACK_PORT, OAuthCallbackError, bind_callback_listener, callback_host, callback_url,
is_loopback_host, landing_html, wait_for_callback,
};
pub struct OAuthTokenResponse {
pub access_token: String,
pub refresh_token: Option<String>,
pub expires_in: Option<u64>,
}
pub struct OAuthUrlResult {
pub url: String,
pub code_verifier: Option<String>,
pub state: String,
}
pub fn build_oauth_url(
authorization_url: &str,
client_id: &str,
redirect_uri: &str,
scopes: &[String],
use_pkce: bool,
extra_params: &HashMap<String, String>,
) -> OAuthUrlResult {
let (code_verifier, code_challenge) = if use_pkce {
let mut verifier_bytes = [0u8; 32];
rand::rngs::OsRng.fill_bytes(&mut verifier_bytes);
let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
(Some(verifier), Some(challenge))
} else {
(None, None)
};
let mut state_bytes = [0u8; 32];
rand::rngs::OsRng.fill_bytes(&mut state_bytes);
let state = URL_SAFE_NO_PAD.encode(state_bytes);
let mut auth_url = format!(
"{}?client_id={}&response_type=code&redirect_uri={}&state={}",
authorization_url,
urlencoding::encode(client_id),
urlencoding::encode(redirect_uri),
urlencoding::encode(&state),
);
if !scopes.is_empty() {
auth_url.push_str(&format!(
"&scope={}",
urlencoding::encode(&scopes.join(" "))
));
}
if let Some(ref challenge) = code_challenge {
auth_url.push_str(&format!(
"&code_challenge={}&code_challenge_method=S256",
challenge
));
}
for (key, value) in extra_params {
auth_url.push_str(&format!(
"&{}={}",
urlencoding::encode(key),
urlencoding::encode(value)
));
}
OAuthUrlResult {
url: auth_url,
code_verifier,
state,
}
}
pub async fn exchange_oauth_code(
token_url: &str,
client_id: &str,
client_secret: Option<&str>,
code: &str,
redirect_uri: &str,
code_verifier: Option<&str>,
access_token_field: &str,
) -> Result<OAuthTokenResponse, OAuthCallbackError> {
let extra_token_params = HashMap::new();
exchange_oauth_code_with_params(
token_url,
client_id,
client_secret,
code,
redirect_uri,
code_verifier,
access_token_field,
&extra_token_params,
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn exchange_oauth_code_with_params(
token_url: &str,
client_id: &str,
client_secret: Option<&str>,
code: &str,
redirect_uri: &str,
code_verifier: Option<&str>,
access_token_field: &str,
extra_token_params: &HashMap<String, String>,
) -> Result<OAuthTokenResponse, OAuthCallbackError> {
let client = reqwest::Client::new();
let mut token_params = vec![
("grant_type", "authorization_code".to_string()),
("code", code.to_string()),
("redirect_uri", redirect_uri.to_string()),
];
if let Some(verifier) = code_verifier {
token_params.push(("code_verifier", verifier.to_string()));
}
for (key, value) in extra_token_params {
token_params.push((key.as_str(), value.clone()));
}
let mut request = client.post(token_url);
if let Some(secret) = client_secret {
request = request.basic_auth(client_id, Some(secret));
} else {
token_params.push(("client_id", client_id.to_string()));
}
let token_response = request
.form(&token_params)
.send()
.await
.map_err(|e| OAuthCallbackError::Io(format!("Token exchange request failed: {}", e)))?;
if !token_response.status().is_success() {
let status = token_response.status();
let body = token_response.text().await.unwrap_or_default();
return Err(OAuthCallbackError::Io(format!(
"Token exchange failed: {} - {}",
status, body
)));
}
let token_data: serde_json::Value = token_response
.json()
.await
.map_err(|e| OAuthCallbackError::Io(format!("Failed to parse token response: {}", e)))?;
let access_token = token_data
.get(access_token_field)
.and_then(|v| v.as_str())
.ok_or_else(|| {
let fields: Vec<&str> = token_data
.as_object()
.map(|o| o.keys().map(|k| k.as_str()).collect())
.unwrap_or_default();
OAuthCallbackError::Io(format!(
"No '{}' field in token response (fields present: {:?})",
access_token_field, fields
))
})?
.to_string();
let refresh_token = token_data
.get("refresh_token")
.and_then(|v| v.as_str())
.map(String::from);
let expires_in = token_data.get("expires_in").and_then(|v| v.as_u64());
Ok(OAuthTokenResponse {
access_token,
refresh_token,
expires_in,
})
}
#[allow(clippy::too_many_arguments)]
pub async fn exchange_oauth_code_with_resource(
token_url: &str,
client_id: &str,
client_secret: Option<&str>,
code: &str,
redirect_uri: &str,
code_verifier: Option<&str>,
access_token_field: &str,
resource: Option<&str>,
) -> Result<OAuthTokenResponse, OAuthCallbackError> {
let mut extra_token_params = HashMap::new();
if let Some(resource) = resource {
extra_token_params.insert("resource".to_string(), resource.to_string());
}
exchange_oauth_code_with_params(
token_url,
client_id,
client_secret,
code,
redirect_uri,
code_verifier,
access_token_field,
&extra_token_params,
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn store_oauth_tokens(
store: &(dyn SecretsStore + Send + Sync),
user_id: &str,
secret_name: &str,
provider: Option<&str>,
access_token: &str,
refresh_token: Option<&str>,
expires_in: Option<u64>,
scopes: &[String],
) -> Result<(), OAuthCallbackError> {
let mut params = CreateSecretParams::new(secret_name, access_token);
if let Some(prov) = provider {
params = params.with_provider(prov);
}
if let Some(secs) = expires_in {
let expires_at = chrono::Utc::now() + chrono::Duration::seconds(secs as i64);
params = params.with_expiry(expires_at);
}
store
.create(user_id, params)
.await
.map_err(|e| OAuthCallbackError::Io(format!("Failed to save token: {}", e)))?;
if let Some(rt) = refresh_token {
let refresh_name = format!("{}_refresh_token", secret_name);
let mut refresh_params = CreateSecretParams::new(&refresh_name, rt);
if let Some(prov) = provider {
refresh_params = refresh_params.with_provider(prov);
}
store
.create(user_id, refresh_params)
.await
.map_err(|e| OAuthCallbackError::Io(format!("Failed to save refresh token: {}", e)))?;
}
if !scopes.is_empty() {
let scopes_name = format!("{}_scopes", secret_name);
let scopes_value = scopes.join(" ");
let scopes_params = CreateSecretParams::new(&scopes_name, &scopes_value);
let _ = store.create(user_id, scopes_params).await;
}
Ok(())
}
pub async fn validate_oauth_token(
token: &str,
validation: &crate::tools::wasm::ValidationEndpointSchema,
) -> Result<(), OAuthCallbackError> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.map_err(|e| OAuthCallbackError::Io(format!("Failed to build HTTP client: {}", e)))?;
let request = match validation.method.to_uppercase().as_str() {
"POST" => client.post(&validation.url),
_ => client.get(&validation.url),
};
let mut request = request.header("Authorization", format!("Bearer {}", token));
for (key, value) in &validation.headers {
request = request.header(key, value);
}
let response = request
.send()
.await
.map_err(|e| OAuthCallbackError::Io(format!("Validation request failed: {}", e)))?;
if response.status().as_u16() == validation.success_status {
Ok(())
} else {
let status = response.status();
let body = response.text().await.unwrap_or_default();
let truncated: String = if body.len() > 200 {
let mut end = 200;
while end > 0 && !body.is_char_boundary(end) {
end -= 1;
}
format!("{}...", &body[..end])
} else {
body
};
Err(OAuthCallbackError::Io(format!(
"Token validation failed: HTTP {} (expected {}): {}",
status, validation.success_status, truncated
)))
}
}
pub struct PendingOAuthFlow {
pub extension_name: String,
pub display_name: String,
pub token_url: String,
pub client_id: String,
pub client_secret: Option<String>,
pub redirect_uri: String,
pub code_verifier: Option<String>,
pub access_token_field: String,
pub secret_name: String,
pub provider: Option<String>,
pub validation_endpoint: Option<crate::tools::wasm::ValidationEndpointSchema>,
pub scopes: Vec<String>,
pub user_id: String,
pub secrets: Arc<dyn SecretsStore + Send + Sync>,
pub sse_manager: Option<Arc<crate::channels::web::sse::SseManager>>,
pub gateway_token: Option<String>,
pub token_exchange_extra_params: HashMap<String, String>,
pub client_id_secret_name: Option<String>,
pub created_at: std::time::Instant,
}
impl std::fmt::Debug for PendingOAuthFlow {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PendingOAuthFlow")
.field("extension_name", &self.extension_name)
.field("display_name", &self.display_name)
.field("secret_name", &self.secret_name)
.field("created_at", &self.created_at)
.finish_non_exhaustive()
}
}
pub type PendingOAuthRegistry = Arc<RwLock<HashMap<String, PendingOAuthFlow>>>;
pub fn new_pending_oauth_registry() -> PendingOAuthRegistry {
Arc::new(RwLock::new(HashMap::new()))
}
pub fn use_gateway_callback() -> bool {
crate::config::helpers::env_or_override("IRONCLAW_OAUTH_CALLBACK_URL")
.map(|raw| {
url::Url::parse(&raw)
.ok()
.and_then(|u| u.host_str().map(String::from))
.map(|host| !is_loopback_host(&host))
.unwrap_or(false)
})
.unwrap_or(false)
}
pub fn exchange_proxy_url() -> Option<String> {
crate::config::helpers::env_or_override("IRONCLAW_OAUTH_EXCHANGE_URL")
.map(|url| url.trim().to_string())
.filter(|url| !url.is_empty())
}
pub const OAUTH_FLOW_EXPIRY: Duration = Duration::from_secs(300);
pub async fn sweep_expired_flows(registry: &PendingOAuthRegistry) {
let mut flows = registry.write().await;
flows.retain(|_, flow| flow.created_at.elapsed() < OAUTH_FLOW_EXPIRY);
}
const HOSTED_STATE_PREFIX: &str = "ic2";
const HOSTED_STATE_CHECKSUM_BYTES: usize = 12;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodedHostedOAuthState {
pub flow_id: String,
pub instance_name: Option<String>,
pub is_legacy: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HostedOAuthStatePayload {
flow_id: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
instance_name: Option<String>,
issued_at: u64,
}
fn current_instance_name() -> Option<String> {
crate::config::helpers::env_or_override("IRONCLAW_INSTANCE_NAME")
.or_else(|| crate::config::helpers::env_or_override("OPENCLAW_INSTANCE_NAME"))
.filter(|v| !v.is_empty())
}
fn hosted_state_checksum(payload_bytes: &[u8]) -> String {
let digest = Sha256::digest(payload_bytes);
URL_SAFE_NO_PAD.encode(&digest[..HOSTED_STATE_CHECKSUM_BYTES])
}
pub fn encode_hosted_oauth_state(flow_id: &str, instance_name: Option<&str>) -> String {
let payload = HostedOAuthStatePayload {
flow_id: flow_id.to_string(),
instance_name: instance_name
.map(str::trim)
.filter(|v| !v.is_empty())
.map(str::to_string),
issued_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
let payload_json = match serde_json::to_vec(&payload) {
Ok(payload_json) => payload_json,
Err(error) => {
tracing::warn!(%error, flow_id, "Failed to serialize hosted OAuth state payload");
return payload.flow_id;
}
};
let payload = URL_SAFE_NO_PAD.encode(&payload_json);
let checksum = hosted_state_checksum(&payload_json);
format!("{HOSTED_STATE_PREFIX}.{payload}.{checksum}")
}
pub fn decode_hosted_oauth_state(state: &str) -> Result<DecodedHostedOAuthState, String> {
if let Some(rest) = state.strip_prefix(&format!("{HOSTED_STATE_PREFIX}.")) {
let (payload_b64, checksum) = rest
.rsplit_once('.')
.ok_or("Hosted OAuth versioned state missing checksum separator")?;
let payload_json = URL_SAFE_NO_PAD
.decode(payload_b64)
.map_err(|e| format!("Hosted OAuth versioned state base64 decode failed: {e}"))?;
let expected_checksum = hosted_state_checksum(&payload_json);
if checksum != expected_checksum {
return Err("Hosted OAuth state checksum mismatch".to_string());
}
let payload: HostedOAuthStatePayload = serde_json::from_slice(&payload_json)
.map_err(|e| format!("Hosted OAuth versioned state JSON parse failed: {e}"))?;
if payload.flow_id.trim().is_empty() {
return Err("Hosted OAuth versioned state has empty flow_id".to_string());
}
return Ok(DecodedHostedOAuthState {
flow_id: payload.flow_id,
instance_name: payload.instance_name.filter(|v| !v.is_empty()),
is_legacy: false,
});
}
if let Some((instance_name, flow_id)) = state.split_once(':') {
if flow_id.is_empty() {
return Err("Hosted OAuth legacy state is missing flow_id".to_string());
}
return Ok(DecodedHostedOAuthState {
flow_id: flow_id.to_string(),
instance_name: if instance_name.is_empty() {
None
} else {
Some(instance_name.to_string())
},
is_legacy: true,
});
}
if state.is_empty() {
return Err("Hosted OAuth state is empty".to_string());
}
Ok(DecodedHostedOAuthState {
flow_id: state.to_string(),
instance_name: None,
is_legacy: true,
})
}
pub fn build_platform_state(nonce: &str) -> String {
encode_hosted_oauth_state(nonce, current_instance_name().as_deref())
}
pub fn strip_instance_prefix(state: &str) -> &str {
state
.split_once(':')
.map(|(_, nonce)| nonce)
.unwrap_or(state)
}
pub struct ProxyTokenExchangeRequest<'a> {
pub proxy_url: &'a str,
pub gateway_token: &'a str,
pub token_url: &'a str,
pub client_id: &'a str,
pub client_secret: Option<&'a str>,
pub code: &'a str,
pub redirect_uri: &'a str,
pub code_verifier: Option<&'a str>,
pub access_token_field: &'a str,
pub extra_token_params: &'a HashMap<String, String>,
}
pub struct ProxyRefreshTokenRequest<'a> {
pub proxy_url: &'a str,
pub gateway_token: &'a str,
pub token_url: &'a str,
pub client_id: &'a str,
pub client_secret: Option<&'a str>,
pub refresh_token: &'a str,
pub provider: Option<&'a str>,
}
fn oauth_token_response_from_json(
token_data: serde_json::Value,
access_token_field: &str,
) -> Result<OAuthTokenResponse, OAuthCallbackError> {
let access_token = token_data
.get(access_token_field)
.and_then(|v| v.as_str())
.ok_or_else(|| {
let fields: Vec<&str> = token_data
.as_object()
.map(|o| o.keys().map(|k| k.as_str()).collect())
.unwrap_or_default();
OAuthCallbackError::Io(format!(
"No '{}' field in proxy response (fields present: {:?})",
access_token_field, fields
))
})?
.to_string();
let refresh_token = token_data
.get("refresh_token")
.and_then(|v| v.as_str())
.map(String::from);
let expires_in = token_data.get("expires_in").and_then(|v| v.as_u64());
Ok(OAuthTokenResponse {
access_token,
refresh_token,
expires_in,
})
}
pub async fn exchange_via_proxy(
request: ProxyTokenExchangeRequest<'_>,
) -> Result<OAuthTokenResponse, OAuthCallbackError> {
if request.gateway_token.is_empty() {
return Err(OAuthCallbackError::Io(
"Gateway auth token is required for proxy token exchange".to_string(),
));
}
let exchange_url = format!("{}/oauth/exchange", request.proxy_url.trim_end_matches('/'));
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(60))
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|e| OAuthCallbackError::Io(format!("Failed to build HTTP client: {}", e)))?;
let mut params = vec![
("code", request.code.to_string()),
("redirect_uri", request.redirect_uri.to_string()),
("token_url", request.token_url.to_string()),
("client_id", request.client_id.to_string()),
("access_token_field", request.access_token_field.to_string()),
];
if let Some(verifier) = request.code_verifier {
params.push(("code_verifier", verifier.to_string()));
}
if let Some(secret) = request.client_secret {
params.push(("client_secret", secret.to_string()));
}
for (key, value) in request.extra_token_params {
params.push((key.as_str(), value.clone()));
}
let response = client
.post(&exchange_url)
.bearer_auth(request.gateway_token)
.form(¶ms)
.send()
.await
.map_err(|e| {
OAuthCallbackError::Io(format!("Token exchange proxy request failed: {}", e))
})?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(OAuthCallbackError::Io(format!(
"Token exchange proxy failed: {} - {}",
status, body
)));
}
let token_data: serde_json::Value = response
.json()
.await
.map_err(|e| OAuthCallbackError::Io(format!("Failed to parse proxy response: {}", e)))?;
oauth_token_response_from_json(token_data, request.access_token_field)
}
pub async fn refresh_token_via_proxy(
request: ProxyRefreshTokenRequest<'_>,
) -> Result<OAuthTokenResponse, OAuthCallbackError> {
if request.gateway_token.is_empty() {
return Err(OAuthCallbackError::Io(
"Gateway auth token is required for proxy token refresh".to_string(),
));
}
let refresh_url = format!("{}/oauth/refresh", request.proxy_url.trim_end_matches('/'));
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(15))
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|e| OAuthCallbackError::Io(format!("Failed to build HTTP client: {}", e)))?;
let mut params = vec![
("refresh_token", request.refresh_token.to_string()),
("token_url", request.token_url.to_string()),
("client_id", request.client_id.to_string()),
];
if let Some(secret) = request.client_secret {
params.push(("client_secret", secret.to_string()));
}
if let Some(provider) = request.provider {
params.push(("provider", provider.to_string()));
}
let response = client
.post(&refresh_url)
.bearer_auth(request.gateway_token)
.form(¶ms)
.send()
.await
.map_err(|e| {
OAuthCallbackError::Io(format!("Token refresh proxy request failed: {}", e))
})?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(OAuthCallbackError::Io(format!(
"Token refresh proxy failed: {} - {}",
status, body
)));
}
let token_data: serde_json::Value = response
.json()
.await
.map_err(|e| OAuthCallbackError::Io(format!("Failed to parse proxy response: {}", e)))?;
oauth_token_response_from_json(token_data, "access_token")
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use axum::extract::{Form, State};
use axum::http::HeaderMap;
use axum::response::Redirect;
use axum::routing::post;
use axum::{Json, Router};
use serde_json::json;
use tokio::net::TcpListener;
use tokio::sync::{Mutex, oneshot};
use crate::cli::oauth_defaults::{
builtin_credentials, callback_host, callback_url, is_loopback_host, landing_html,
};
use crate::config::helpers::lock_env;
use crate::testing::credentials::{TEST_OAUTH_CLIENT_ID, TEST_OAUTH_CLIENT_SECRET};
#[derive(Clone, Debug, PartialEq, Eq)]
struct RecordedProxyRequest {
authorization: Option<String>,
form: HashMap<String, String>,
}
#[derive(Clone)]
struct MockProxyState {
requests: Arc<Mutex<Vec<RecordedProxyRequest>>>,
exchange_redirect_target: String,
refresh_redirect_target: String,
}
struct MockProxyServer {
addr: SocketAddr,
requests: Arc<Mutex<Vec<RecordedProxyRequest>>>,
shutdown_tx: Option<oneshot::Sender<()>>,
server_task: Option<tokio::task::JoinHandle<()>>,
}
impl MockProxyServer {
async fn start() -> Self {
async fn exchange_handler(
State(state): State<MockProxyState>,
headers: HeaderMap,
Form(form): Form<HashMap<String, String>>,
) -> Json<serde_json::Value> {
state.requests.lock().await.push(RecordedProxyRequest {
authorization: headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.map(str::to_string),
form,
});
Json(json!({
"access_token": "proxy-access-token",
"refresh_token": "proxy-refresh-token",
"expires_in": 7200
}))
}
async fn refresh_handler(
State(state): State<MockProxyState>,
headers: HeaderMap,
Form(form): Form<HashMap<String, String>>,
) -> Json<serde_json::Value> {
state.requests.lock().await.push(RecordedProxyRequest {
authorization: headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.map(str::to_string),
form,
});
Json(json!({
"access_token": "proxy-access-token",
"refresh_token": "proxy-refresh-token",
"expires_in": 7200
}))
}
async fn exchange_redirect_handler(State(state): State<MockProxyState>) -> Redirect {
Redirect::temporary(&state.exchange_redirect_target)
}
async fn refresh_redirect_handler(State(state): State<MockProxyState>) -> Redirect {
Redirect::temporary(&state.refresh_redirect_target)
}
let requests = Arc::new(Mutex::new(Vec::new()));
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("bind mock proxy");
let addr = listener.local_addr().expect("read mock proxy addr");
let exchange_redirect_target = format!("http://{addr}/oauth/exchange");
let refresh_redirect_target = format!("http://{addr}/oauth/refresh");
let app = Router::new()
.route("/oauth/exchange", post(exchange_handler))
.route("/oauth/refresh", post(refresh_handler))
.route("/redirect/oauth/exchange", post(exchange_redirect_handler))
.route("/redirect/oauth/refresh", post(refresh_redirect_handler))
.with_state(MockProxyState {
requests: Arc::clone(&requests),
exchange_redirect_target,
refresh_redirect_target,
});
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let server_task = tokio::spawn(async move {
let _ = axum::serve(listener, app)
.with_graceful_shutdown(async {
let _ = shutdown_rx.await;
})
.await;
});
Self {
addr,
requests,
shutdown_tx: Some(shutdown_tx),
server_task: Some(server_task),
}
}
fn base_url(&self) -> String {
format!("http://{}", self.addr)
}
fn redirecting_base_url(&self) -> String {
format!("{}/redirect", self.base_url())
}
async fn requests(&self) -> Vec<RecordedProxyRequest> {
self.requests.lock().await.clone()
}
async fn shutdown(mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
if let Some(task) = self.server_task.take() {
let _ = task.await;
}
}
}
impl Drop for MockProxyServer {
fn drop(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
if let Some(task) = self.server_task.take() {
task.abort();
}
}
}
#[test]
fn test_hosted_proxy_client_secret_suppresses_builtin_secret() {
let builtin = builtin_credentials("google_oauth_token").expect("google builtin creds");
let client_secret = Some(builtin.client_secret.to_string());
let result = super::hosted_proxy_client_secret(&client_secret, Some(&builtin), true);
assert_eq!(result, None);
}
#[test]
fn test_hosted_proxy_client_secret_preserves_explicit_secret() {
let builtin = builtin_credentials("google_oauth_token").expect("google builtin creds");
let client_secret = Some("hosted-server-secret".to_string());
let result = super::hosted_proxy_client_secret(&client_secret, Some(&builtin), true);
assert_eq!(result, client_secret);
}
#[tokio::test]
async fn test_refresh_token_via_proxy_sends_auth_and_form() {
let server = MockProxyServer::start().await;
let response = super::refresh_token_via_proxy(super::ProxyRefreshTokenRequest {
proxy_url: &server.base_url(),
gateway_token: "gateway-test-token",
token_url: "https://oauth2.googleapis.com/token",
client_id: TEST_OAUTH_CLIENT_ID,
client_secret: Some(TEST_OAUTH_CLIENT_SECRET),
refresh_token: "refresh-token-123",
provider: Some("google"),
})
.await
.expect("proxy refresh succeeds");
assert_eq!(response.access_token, "proxy-access-token");
assert_eq!(
response.refresh_token.as_deref(),
Some("proxy-refresh-token")
);
assert_eq!(response.expires_in, Some(7200));
let requests = server.requests().await;
assert_eq!(requests.len(), 1);
assert_eq!(
requests[0].authorization.as_deref(),
Some("Bearer gateway-test-token")
);
assert_eq!(
requests[0].form.get("token_url").map(String::as_str),
Some("https://oauth2.googleapis.com/token")
);
assert_eq!(
requests[0].form.get("client_id").map(String::as_str),
Some(TEST_OAUTH_CLIENT_ID)
);
assert_eq!(
requests[0].form.get("client_secret").map(String::as_str),
Some(TEST_OAUTH_CLIENT_SECRET)
);
assert_eq!(
requests[0].form.get("refresh_token").map(String::as_str),
Some("refresh-token-123")
);
assert_eq!(
requests[0].form.get("provider").map(String::as_str),
Some("google")
);
server.shutdown().await;
}
#[tokio::test]
async fn test_exchange_via_proxy_does_not_follow_redirects() {
let server = MockProxyServer::start().await;
let error = match super::exchange_via_proxy(super::ProxyTokenExchangeRequest {
proxy_url: &server.redirecting_base_url(),
gateway_token: "gateway-test-token",
code: "auth-code-123",
redirect_uri: "http://localhost:3000/oauth/callback",
token_url: "https://oauth2.googleapis.com/token",
client_id: TEST_OAUTH_CLIENT_ID,
client_secret: Some(TEST_OAUTH_CLIENT_SECRET),
access_token_field: "access_token",
code_verifier: Some("code-verifier-123"),
extra_token_params: &HashMap::new(),
})
.await
{
Ok(_) => panic!("redirected proxy exchange should fail"),
Err(error) => error,
};
assert!(error.to_string().contains("307"));
assert!(server.requests().await.is_empty());
server.shutdown().await;
}
#[tokio::test]
async fn test_refresh_token_via_proxy_does_not_follow_redirects() {
let server = MockProxyServer::start().await;
let error = match super::refresh_token_via_proxy(super::ProxyRefreshTokenRequest {
proxy_url: &server.redirecting_base_url(),
gateway_token: "gateway-test-token",
token_url: "https://oauth2.googleapis.com/token",
client_id: TEST_OAUTH_CLIENT_ID,
client_secret: Some(TEST_OAUTH_CLIENT_SECRET),
refresh_token: "refresh-token-123",
provider: Some("google"),
})
.await
{
Ok(_) => panic!("redirected proxy refresh should fail"),
Err(error) => error,
};
assert!(error.to_string().contains("307"));
assert!(server.requests().await.is_empty());
server.shutdown().await;
}
#[test]
fn test_is_loopback_host() {
assert!(is_loopback_host("127.0.0.1"));
assert!(is_loopback_host("127.0.0.2")); assert!(is_loopback_host("127.255.255.254"));
assert!(is_loopback_host("::1"));
assert!(is_loopback_host("localhost"));
assert!(is_loopback_host("LOCALHOST"));
assert!(!is_loopback_host("203.0.113.10"));
assert!(!is_loopback_host("my-server.example.com"));
assert!(!is_loopback_host("0.0.0.0"));
}
#[test]
fn test_callback_host_default() {
let _guard = lock_env();
let original = std::env::var("OAUTH_CALLBACK_HOST").ok();
unsafe {
std::env::remove_var("OAUTH_CALLBACK_HOST");
}
assert_eq!(callback_host(), "127.0.0.1");
unsafe {
if let Some(val) = original {
std::env::set_var("OAUTH_CALLBACK_HOST", val);
}
}
}
#[test]
fn test_callback_host_env_override() {
let _guard = lock_env();
let original_host = std::env::var("OAUTH_CALLBACK_HOST").ok();
let original_url = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok();
unsafe {
std::env::set_var("OAUTH_CALLBACK_HOST", "203.0.113.10");
std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL");
}
assert_eq!(callback_host(), "203.0.113.10");
let url = callback_url();
assert!(url.contains("203.0.113.10"), "url was: {url}");
unsafe {
if let Some(val) = original_host {
std::env::set_var("OAUTH_CALLBACK_HOST", val);
} else {
std::env::remove_var("OAUTH_CALLBACK_HOST");
}
if let Some(val) = original_url {
std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val);
}
}
}
#[test]
fn test_callback_url_default() {
let _guard = lock_env();
let original_url = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok();
let original_host = std::env::var("OAUTH_CALLBACK_HOST").ok();
unsafe {
std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL");
std::env::remove_var("OAUTH_CALLBACK_HOST");
}
let url = callback_url();
assert_eq!(url, "http://127.0.0.1:9876");
unsafe {
if let Some(val) = original_url {
std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val);
}
if let Some(val) = original_host {
std::env::set_var("OAUTH_CALLBACK_HOST", val);
}
}
}
#[test]
fn test_callback_url_env_override() {
let _guard = lock_env();
let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok();
unsafe {
std::env::set_var(
"IRONCLAW_OAUTH_CALLBACK_URL",
"https://myserver.example.com:9876",
);
}
let url = callback_url();
assert_eq!(url, "https://myserver.example.com:9876");
unsafe {
if let Some(val) = original {
std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val);
} else {
std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL");
}
}
}
#[test]
fn test_unknown_provider_returns_none() {
assert!(builtin_credentials("unknown_token").is_none());
}
#[test]
fn test_google_returns_based_on_compile_env() {
let creds = builtin_credentials("google_oauth_token");
assert!(creds.is_some());
let creds = creds.unwrap();
assert!(!creds.client_id.is_empty());
assert!(!creds.client_secret.is_empty());
}
#[test]
fn test_landing_html_success_contains_key_elements() {
let html = landing_html("Google", true);
assert!(html.contains("Google Connected"));
assert!(html.contains("charset"));
assert!(html.contains("IronClaw"));
assert!(html.contains("#22c55e")); assert!(!html.contains("Failed"));
}
#[test]
fn test_landing_html_escapes_provider_name() {
let html = landing_html("<script>alert(1)</script>", true);
assert!(!html.contains("<script>"));
assert!(html.contains("<script>"));
}
#[test]
fn test_landing_html_error_contains_key_elements() {
let html = landing_html("Notion", false);
assert!(html.contains("Authorization Failed"));
assert!(html.contains("charset"));
assert!(html.contains("IronClaw"));
assert!(html.contains("#ef4444")); assert!(!html.contains("Connected"));
}
#[test]
fn test_build_oauth_url_basic() {
use std::collections::HashMap;
use crate::cli::oauth_defaults::build_oauth_url;
let result = build_oauth_url(
"https://accounts.google.com/o/oauth2/auth",
"my-client-id",
"http://localhost:9876/callback",
&["openid".to_string(), "email".to_string()],
false,
&HashMap::new(),
);
assert!(
result
.url
.starts_with("https://accounts.google.com/o/oauth2/auth?")
);
assert!(result.url.contains("client_id=my-client-id"));
assert!(result.url.contains("response_type=code"));
assert!(result.url.contains("redirect_uri="));
assert!(result.url.contains("scope=openid%20email"));
assert!(result.url.contains("state="));
assert!(result.code_verifier.is_none());
assert!(!result.state.is_empty());
}
#[test]
fn test_build_oauth_url_with_pkce() {
use std::collections::HashMap;
use crate::cli::oauth_defaults::build_oauth_url;
let result = build_oauth_url(
"https://auth.example.com/authorize",
"client-123",
"http://localhost:9876/callback",
&[],
true,
&HashMap::new(),
);
assert!(result.url.contains("code_challenge="));
assert!(result.url.contains("code_challenge_method=S256"));
assert!(result.code_verifier.is_some());
let verifier = result.code_verifier.unwrap();
assert!(!verifier.is_empty());
}
#[test]
fn test_build_oauth_url_with_extra_params() {
use std::collections::HashMap;
use crate::cli::oauth_defaults::build_oauth_url;
let mut extra = HashMap::new();
extra.insert("access_type".to_string(), "offline".to_string());
extra.insert("prompt".to_string(), "consent".to_string());
let result = build_oauth_url(
"https://auth.example.com/authorize",
"client-123",
"http://localhost:9876/callback",
&["read".to_string()],
false,
&extra,
);
assert!(result.url.contains("access_type=offline"));
assert!(result.url.contains("prompt=consent"));
}
#[test]
fn test_build_oauth_url_state_is_unique() {
use std::collections::HashMap;
use crate::cli::oauth_defaults::build_oauth_url;
let result1 = build_oauth_url(
"https://auth.example.com/authorize",
"client",
"http://localhost:9876/callback",
&[],
false,
&HashMap::new(),
);
let result2 = build_oauth_url(
"https://auth.example.com/authorize",
"client",
"http://localhost:9876/callback",
&[],
false,
&HashMap::new(),
);
assert_ne!(result1.state, result2.state);
}
#[test]
fn test_use_gateway_callback_false_by_default() {
let _guard = lock_env();
let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok();
unsafe {
std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL");
}
assert!(!crate::cli::oauth_defaults::use_gateway_callback());
unsafe {
if let Some(val) = original {
std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val);
}
}
}
#[test]
fn test_use_gateway_callback_true_for_hosted() {
let _guard = lock_env();
let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok();
unsafe {
std::env::set_var(
"IRONCLAW_OAUTH_CALLBACK_URL",
"https://kind-deer.agent1.near.ai",
);
}
assert!(crate::cli::oauth_defaults::use_gateway_callback());
unsafe {
if let Some(val) = original {
std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val);
} else {
std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL");
}
}
}
#[test]
fn test_use_gateway_callback_false_for_localhost() {
let _guard = lock_env();
let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok();
unsafe {
std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", "http://127.0.0.1:3001");
}
assert!(!crate::cli::oauth_defaults::use_gateway_callback());
unsafe {
if let Some(val) = original {
std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val);
} else {
std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL");
}
}
}
#[test]
fn test_use_gateway_callback_false_for_empty() {
let _guard = lock_env();
let original = std::env::var("IRONCLAW_OAUTH_CALLBACK_URL").ok();
unsafe {
std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", "");
}
assert!(!crate::cli::oauth_defaults::use_gateway_callback());
unsafe {
if let Some(val) = original {
std::env::set_var("IRONCLAW_OAUTH_CALLBACK_URL", val);
} else {
std::env::remove_var("IRONCLAW_OAUTH_CALLBACK_URL");
}
}
}
#[test]
fn test_build_platform_state_with_instance() {
use crate::cli::oauth_defaults::{build_platform_state, decode_hosted_oauth_state};
let _guard = lock_env();
let original = std::env::var("IRONCLAW_INSTANCE_NAME").ok();
unsafe {
std::env::set_var("IRONCLAW_INSTANCE_NAME", "kind-deer");
}
let encoded = build_platform_state("abc123");
let decoded = decode_hosted_oauth_state(&encoded).expect("decode hosted state");
assert_eq!(decoded.flow_id, "abc123");
assert_eq!(decoded.instance_name.as_deref(), Some("kind-deer"));
assert!(!decoded.is_legacy);
unsafe {
if let Some(val) = original {
std::env::set_var("IRONCLAW_INSTANCE_NAME", val);
} else {
std::env::remove_var("IRONCLAW_INSTANCE_NAME");
}
}
}
#[test]
fn test_build_platform_state_without_instance() {
use crate::cli::oauth_defaults::{build_platform_state, decode_hosted_oauth_state};
let _guard = lock_env();
let original = std::env::var("IRONCLAW_INSTANCE_NAME").ok();
let original_oc = std::env::var("OPENCLAW_INSTANCE_NAME").ok();
unsafe {
std::env::remove_var("IRONCLAW_INSTANCE_NAME");
std::env::remove_var("OPENCLAW_INSTANCE_NAME");
}
let encoded = build_platform_state("abc123");
let decoded = decode_hosted_oauth_state(&encoded).expect("decode hosted state");
assert_eq!(decoded.flow_id, "abc123");
assert_eq!(decoded.instance_name, None);
assert!(!decoded.is_legacy);
unsafe {
if let Some(val) = original {
std::env::set_var("IRONCLAW_INSTANCE_NAME", val);
}
if let Some(val) = original_oc {
std::env::set_var("OPENCLAW_INSTANCE_NAME", val);
}
}
}
#[test]
fn test_build_platform_state_with_openclaw_instance() {
use crate::cli::oauth_defaults::{build_platform_state, decode_hosted_oauth_state};
let _guard = lock_env();
let original_ic = std::env::var("IRONCLAW_INSTANCE_NAME").ok();
let original_oc = std::env::var("OPENCLAW_INSTANCE_NAME").ok();
unsafe {
std::env::remove_var("IRONCLAW_INSTANCE_NAME");
std::env::set_var("OPENCLAW_INSTANCE_NAME", "quiet-lion");
}
let encoded = build_platform_state("xyz789");
let decoded = decode_hosted_oauth_state(&encoded).expect("decode hosted state");
assert_eq!(decoded.flow_id, "xyz789");
assert_eq!(decoded.instance_name.as_deref(), Some("quiet-lion"));
assert!(!decoded.is_legacy);
unsafe {
if let Some(val) = original_ic {
std::env::set_var("IRONCLAW_INSTANCE_NAME", val);
}
if let Some(val) = original_oc {
std::env::set_var("OPENCLAW_INSTANCE_NAME", val);
} else {
std::env::remove_var("OPENCLAW_INSTANCE_NAME");
}
}
}
#[test]
fn test_strip_instance_prefix_with_colon() {
use crate::cli::oauth_defaults::strip_instance_prefix;
assert_eq!(strip_instance_prefix("kind-deer:abc123"), "abc123");
assert_eq!(strip_instance_prefix("my-instance:xyz"), "xyz");
}
#[test]
fn test_strip_instance_prefix_without_colon() {
use crate::cli::oauth_defaults::strip_instance_prefix;
assert_eq!(strip_instance_prefix("abc123"), "abc123");
assert_eq!(strip_instance_prefix(""), "");
}
#[test]
fn test_decode_hosted_oauth_state_accepts_legacy_formats() {
use crate::cli::oauth_defaults::decode_hosted_oauth_state;
let decoded = decode_hosted_oauth_state("kind-deer:abc123").expect("legacy prefixed");
assert_eq!(decoded.flow_id, "abc123");
assert_eq!(decoded.instance_name.as_deref(), Some("kind-deer"));
assert!(decoded.is_legacy);
let decoded = decode_hosted_oauth_state("abc123").expect("legacy raw");
assert_eq!(decoded.flow_id, "abc123");
assert_eq!(decoded.instance_name, None);
assert!(decoded.is_legacy);
}
#[test]
fn test_decode_hosted_oauth_state_rejects_non_envelope_ic2_prefix() {
use crate::cli::oauth_defaults::decode_hosted_oauth_state;
decode_hosted_oauth_state("ic2.provider-owned-state")
.expect_err("ic2-prefixed non-envelope state should fail");
}
#[test]
fn test_decode_hosted_oauth_state_rejects_tampered_checksum() {
use crate::cli::oauth_defaults::{decode_hosted_oauth_state, encode_hosted_oauth_state};
let encoded = encode_hosted_oauth_state("abc123", Some("kind-deer"));
let tampered = format!("{encoded}broken");
let err = decode_hosted_oauth_state(&tampered).expect_err("tampered state should fail");
assert!(err.contains("checksum"), "unexpected error: {err}");
}
#[test]
fn test_build_oauth_url_includes_resource_via_extra_params() {
use std::collections::HashMap;
use crate::cli::oauth_defaults::build_oauth_url;
let mut extra = HashMap::new();
extra.insert(
"resource".to_string(),
"https://mcp.example.com".to_string(),
);
let result = build_oauth_url(
"https://auth.example.com/authorize",
"client-123",
"https://gateway.example.com/oauth/callback",
&["read".to_string()],
true,
&extra,
);
assert!(
result
.url
.contains("resource=https%3A%2F%2Fmcp.example.com"),
"Expected resource param in URL: {}",
result.url
);
assert!(result.url.contains("state="));
assert!(result.url.contains("code_challenge="));
assert!(result.code_verifier.is_some());
}
#[test]
fn test_decode_versioned_state_rejects_malformed_envelopes() {
use crate::cli::oauth_defaults::decode_hosted_oauth_state;
let err =
decode_hosted_oauth_state("ic2.nodots").expect_err("missing separator should fail");
assert!(
err.contains("checksum separator"),
"unexpected error: {err}"
);
let err = decode_hosted_oauth_state("ic2.!!!badbase64!!!.fakechecksum")
.expect_err("bad base64 should fail");
assert!(err.contains("base64"), "unexpected error: {err}");
use base64::Engine;
use sha2::Digest;
let not_json_bytes = b"not json";
let not_json_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(not_json_bytes);
let digest = sha2::Sha256::digest(not_json_bytes);
let checksum = base64::engine::general_purpose::URL_SAFE_NO_PAD
.encode(&digest[..super::HOSTED_STATE_CHECKSUM_BYTES]);
let err = decode_hosted_oauth_state(&format!("ic2.{not_json_b64}.{checksum}"))
.expect_err("non-JSON payload should fail with JSON parse error");
assert!(
err.contains("JSON"),
"unexpected error (expected JSON parse failure): {err}"
);
}
#[test]
fn test_oauth_flow_key_round_trip_consistency() {
use crate::cli::oauth_defaults::{decode_hosted_oauth_state, encode_hosted_oauth_state};
let nonce = "test-nonce-abc123";
let encoded = encode_hosted_oauth_state(nonce, Some("my-instance"));
let decoded = decode_hosted_oauth_state(&encoded).expect("round-trip decode");
assert_eq!(
decoded.flow_id, nonce,
"flow_id must match the original nonce"
);
assert_eq!(decoded.instance_name.as_deref(), Some("my-instance"));
assert!(!decoded.is_legacy);
let encoded_no_instance = encode_hosted_oauth_state(nonce, None);
let decoded_no_instance =
decode_hosted_oauth_state(&encoded_no_instance).expect("round-trip without instance");
assert_eq!(decoded_no_instance.flow_id, nonce);
assert_eq!(decoded_no_instance.instance_name, None);
assert!(!decoded_no_instance.is_legacy);
}
}