use axum::http::HeaderMap;
use crate::handlers::{AppError, AppState};
/// The kind of credential a request authenticated with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PrincipalKind {
    /// The single admin token from the server configuration.
    Admin,
    /// An OAuth access token issued to `client_id`.
    OAuth {
        // Not read directly in this file; the same value is embedded in
        // `Principal::display_name`, hence the dead-code allowance.
        #[allow(dead_code)]
        client_id: String,
    },
    /// A managed (database-stored) API token, identified by its `name`.
    Managed {
        // Not read directly in this file; the same value is embedded in
        // `Principal::display_name`, hence the dead-code allowance.
        #[allow(dead_code)]
        name: String,
    },
}

/// An authenticated caller, as produced by `check_auth` / `check_auth_token`.
#[derive(Debug, Clone)]
pub struct Principal {
    /// Which credential type authenticated this caller.
    pub kind: PrincipalKind,
    /// Scopes granted to an OAuth token. Admin and managed principals are not
    /// scope-restricted, and are constructed with an empty list.
    pub scopes: Vec<String>,
    /// Human-readable identity, e.g. `admin`, `oauth:<client_id>`, `managed:<name>`.
    pub display_name: String,
}

impl Principal {
    /// Returns `true` if this principal may act with `scope`.
    ///
    /// Admin and managed principals are not scope-restricted and always pass.
    /// OAuth principals pass only when `scope` was explicitly granted.
    ///
    /// An OAuth token with no scopes therefore has no permissions. Previously an
    /// empty list meant "everything", which gave scope-less OAuth tokens full access.
    pub fn has_scope(&self, scope: &str) -> bool {
        match self.kind {
            PrincipalKind::Admin | PrincipalKind::Managed { .. } => true,
            PrincipalKind::OAuth { .. } => self.scopes.iter().any(|s| s == scope),
        }
    }

    /// Returns `true` only for the configured admin token.
    pub fn is_admin(&self) -> bool {
        matches!(self.kind, PrincipalKind::Admin)
    }

    /// Returns `true` if this principal may perform write operations.
    ///
    /// Admin and managed principals always may. OAuth principals need the
    /// `mcp:tools` scope.
    pub fn can_write(&self) -> bool {
        self.has_scope("mcp:tools")
    }
}
/// Authenticates a request using its `Authorization` bearer token.
///
/// Fails with `AppError::Unauthorized` when `extract_bearer` finds no token.
/// Otherwise the outcome is whatever `check_auth_token` decides.
pub async fn check_auth(state: &AppState, headers: &HeaderMap) -> Result<Principal, AppError> {
    match extract_bearer(headers) {
        Some(token) => check_auth_token(state, token).await,
        None => Err(AppError::Unauthorized),
    }
}
pub async fn check_auth_token(state: &AppState, provided: &str) -> Result<Principal, AppError> {
if constant_time_eq(provided.as_bytes(), state.config.token.as_bytes()) {
return Ok(Principal {
kind: PrincipalKind::Admin,
scopes: vec![],
display_name: "admin".to_string(),
});
}
let provided_for_oauth = provided.to_string();
let db_oauth = state.db.clone();
let oauth_result =
tokio::task::spawn_blocking(move || db_oauth.get_access_token(&provided_for_oauth))
.await
.map_err(|e| AppError::Internal(format!("Auth task failed: {e}")))?;
match oauth_result {
Ok(Some(record)) => {
let now = crate::helpers::chrono_now();
if record.expires_at.as_str() >= now.as_str() {
let client_id = record.client_id.clone();
let scopes: Vec<String> = record
.scope
.as_deref()
.map(|s| s.split_whitespace().map(|t| t.to_string()).collect())
.unwrap_or_default();
let display_name = format!("oauth:{client_id}");
return Ok(Principal {
kind: PrincipalKind::OAuth { client_id },
scopes,
display_name,
});
}
}
Ok(None) => {} Err(e) => {
tracing::warn!(error = %e, "Failed to look up OAuth access token");
}
}
let prefix: String = provided.chars().take(8).collect();
let prefix_for_lookup = prefix.clone();
let db_prefix = state.db.clone();
let candidate =
tokio::task::spawn_blocking(move || db_prefix.get_token_by_prefix(&prefix_for_lookup))
.await
.map_err(|e| AppError::Internal(format!("Auth task failed: {e}")))?
.map_err(|_| AppError::Internal("Failed to check tokens".to_string()))?;
if let Some(token_record) = candidate {
let provided_owned = provided.to_string();
let hash_owned = token_record.hash.clone();
let verified = tokio::task::spawn_blocking(move || {
crate::helpers::verify_password(&provided_owned, &hash_owned)
})
.await
.map_err(|e| AppError::Internal(format!("Auth task failed: {e}")))?;
if verified {
let now = crate::helpers::chrono_now();
let id_owned = token_record.id.clone();
let now_owned = now.clone();
let db_touch = state.db.clone();
let _ =
tokio::task::spawn_blocking(move || db_touch.touch_token(&id_owned, &now_owned))
.await;
let name = token_record.name.clone();
return Ok(Principal {
display_name: format!("managed:{name}"),
kind: PrincipalKind::Managed { name },
scopes: vec![],
});
}
}
let db_legacy = state.db.clone();
let legacy_tokens = tokio::task::spawn_blocking(move || db_legacy.get_legacy_active_tokens())
.await
.map_err(|e| AppError::Internal(format!("Auth task failed: {e}")))?
.map_err(|_| AppError::Internal("Failed to check tokens".to_string()))?;
if !legacy_tokens.is_empty() {
let provided_owned = provided.to_string();
let result = tokio::task::spawn_blocking(move || {
for token_record in &legacy_tokens {
if crate::helpers::verify_password(&provided_owned, &token_record.hash) {
return Some((token_record.id.clone(), token_record.name.clone()));
}
}
None
})
.await
.map_err(|e| AppError::Internal(format!("Auth task failed: {e}")))?;
if let Some((id, name)) = result {
let now = crate::helpers::chrono_now();
let db_touch2 = state.db.clone();
let _ = tokio::task::spawn_blocking(move || db_touch2.touch_token(&id, &now)).await;
return Ok(Principal {
display_name: format!("managed:{name}"),
kind: PrincipalKind::Managed { name },
scopes: vec![],
});
}
}
Err(AppError::Unauthorized)
}
/// Extracts the bearer token from the request's `Authorization` header.
///
/// The `Bearer` scheme name is matched case-insensitively, as HTTP requires for
/// auth-schemes. Any extra spaces before the token are skipped.
///
/// Returns `None` in any of these cases:
/// * the header is missing;
/// * the header is not valid visible ASCII;
/// * the header uses a different scheme;
/// * the token is empty.
///
/// Rejecting empty tokens here keeps `Bearer ` from ever reaching credential checks.
pub fn extract_bearer(headers: &HeaderMap) -> Option<&str> {
    let value = headers.get("authorization")?.to_str().ok()?;
    let (scheme, token) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = token.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
/// Compares two byte strings without short-circuiting on the first differing byte.
///
/// Inputs of different lengths are rejected up front. Timing can therefore
/// reveal whether the lengths match, but not where the contents differ.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    use subtle::ConstantTimeEq;
    a.len() == b.len() && bool::from(a.ct_eq(b))
}