#[cfg(feature = "portal")]
use axum::{
extract::{Json, Path, State},
http::StatusCode,
response::IntoResponse,
};
#[cfg(feature = "portal")]
use chrono::{Duration, Utc};
#[cfg(feature = "portal")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "portal")]
use sha2::{Digest, Sha256};
#[cfg(feature = "portal")]
use uuid::Uuid;
#[cfg(feature = "portal")]
use crate::portal::auth_db::PortalState;
#[cfg(feature = "portal")]
use crate::portal::db::{models::CreateApiKey, queries::ApiKeyRepository, DbError};
#[cfg(feature = "portal")]
use crate::portal::middleware::AuthClaims;
/// Maximum number of API keys a single user may hold.
///
/// `create_key` compares this against `ApiKeyRepository::count_for_user` and
/// refuses to issue another key once the count reaches it. Whether revoked keys
/// are included in that count depends on the repository — TODO confirm.
#[cfg(feature = "portal")]
const MAX_API_KEYS_PER_USER: i64 = 10;
/// Public view of a stored API key, as returned by `list_keys`.
///
/// Contains only metadata and the short display `prefix`; the secret key itself
/// is never part of this type. Timestamps are RFC 3339 strings.
#[cfg(feature = "portal")]
#[derive(Debug, Serialize)]
pub struct ApiKeyResponse {
    /// Key identifier (UUID rendered as a string).
    pub id: String,
    /// Human-readable label chosen by the user.
    pub name: String,
    /// Leading characters of the key (`rk_` plus 8 random characters) so users
    /// can tell keys apart without seeing the secret.
    pub prefix: String,
    /// Granted scopes (values produced by `validate_scopes`: read/write/admin).
    pub scopes: Vec<String>,
    /// Per-key rate limit as stored; `None` if none was set.
    pub rate_limit_rpm: Option<i32>,
    /// Last time the key was used, refreshed by `validate_api_key`.
    pub last_used_at: Option<String>,
    /// Expiry time; `None` means the key does not expire.
    pub expires_at: Option<String>,
    /// Creation time.
    pub created_at: String,
}
/// Response body returned exactly once, when an API key is created or rotated.
///
/// Unlike `ApiKeyResponse`, it carries the plaintext `key`. The handlers in this
/// module persist only the key's SHA-256 hash, so this response is the caller's
/// only chance to copy the key.
///
/// `Debug` is implemented by hand rather than derived. This keeps the secret
/// redacted if the value is ever formatted with `{:?}` (e.g. in logs).
#[cfg(feature = "portal")]
#[derive(Serialize)]
pub struct ApiKeyCreatedResponse {
    /// Key identifier (UUID rendered as a string).
    pub id: String,
    /// Human-readable label.
    pub name: String,
    /// Full plaintext API key (`rk_...`). Secret — it is not shown again.
    pub key: String,
    /// First 11 characters of `key`, safe to display.
    pub prefix: String,
    /// Granted scopes.
    pub scopes: Vec<String>,
    /// Per-key rate limit, if one was set.
    pub rate_limit_rpm: Option<i32>,
    /// Expiry time as an RFC 3339 string; `None` means the key does not expire.
    pub expires_at: Option<String>,
    /// Creation time as an RFC 3339 string.
    pub created_at: String,
    /// Reminder to the user that the key will not be displayed again.
    pub warning: String,
}

#[cfg(feature = "portal")]
impl std::fmt::Debug for ApiKeyCreatedResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKeyCreatedResponse")
            .field("id", &self.id)
            .field("name", &self.name)
            // Never print the plaintext secret.
            .field("key", &format_args!("<redacted>"))
            .field("prefix", &self.prefix)
            .field("scopes", &self.scopes)
            .field("rate_limit_rpm", &self.rate_limit_rpm)
            .field("expires_at", &self.expires_at)
            .field("created_at", &self.created_at)
            .field("warning", &self.warning)
            .finish()
    }
}
/// JSON body accepted by `create_key`.
#[cfg(feature = "portal")]
#[derive(Debug, Deserialize)]
pub struct CreateApiKeyRequest {
    /// Human-readable label for the key; its length is validated by `create_key`.
    pub name: String,
    /// Requested scopes. Defaults to an empty list when omitted, which
    /// `validate_scopes` turns into `["read"]`.
    #[serde(default)]
    pub scopes: Vec<String>,
    /// Optional per-key rate limit (the name suggests requests per minute —
    /// confirm against the rate limiter).
    pub rate_limit_rpm: Option<i32>,
    /// Optional lifetime in days, counted from creation; `None` means no expiry.
    pub expires_in_days: Option<u32>,
}
/// Generates a fresh plaintext API key: `rk_` followed by 32 random alphanumerics.
///
/// Each character is drawn uniformly from `[a-zA-Z0-9]` using `rand::rng()`,
/// which gives about 190 bits of randomness. The result is always 35 ASCII
/// characters. Callers rely on that when byte-slicing the first 11 characters as
/// a display prefix.
#[cfg(feature = "portal")]
fn generate_api_key() -> String {
    use rand::Rng;

    // Alphabet for the random part; indexing by byte keeps the draw O(1) and
    // ties the range bound to the alphabet length instead of a magic `62`.
    const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    // Number of random characters after the prefix.
    const RANDOM_LEN: usize = 32;
    const PREFIX: &str = "rk_";

    let mut rng = rand::rng();
    let mut key = String::with_capacity(PREFIX.len() + RANDOM_LEN);
    key.push_str(PREFIX);
    key.extend(
        (0..RANDOM_LEN).map(|_| char::from(CHARSET[rng.random_range(0..CHARSET.len())])),
    );
    key
}
/// Returns the lowercase hexadecimal SHA-256 digest of `key`.
///
/// This digest is what gets stored for a key. `validate_api_key` hashes the
/// presented key the same way and looks the record up by the result.
#[cfg(feature = "portal")]
fn hash_api_key(key: &str) -> String {
    let digest = Sha256::digest(key.as_bytes());
    format!("{digest:x}")
}
/// Normalises and validates the scopes requested for a new API key.
///
/// Each entry is lowercased and must then be one of `read`, `write` or `admin`.
/// Duplicates are dropped, keeping first-seen order. An empty request defaults
/// to `["read"]`.
///
/// # Errors
/// Returns a static message as soon as an unknown scope is encountered.
#[cfg(feature = "portal")]
fn validate_scopes(scopes: &[String]) -> Result<Vec<String>, &'static str> {
    const ALLOWED: [&str; 3] = ["read", "write", "admin"];

    let mut normalized: Vec<String> = Vec::new();
    for candidate in scopes.iter().map(|raw| raw.to_lowercase()) {
        if !ALLOWED.contains(&candidate.as_str()) {
            return Err("Invalid scope. Allowed: read, write, admin");
        }
        if !normalized.iter().any(|seen| *seen == candidate) {
            normalized.push(candidate);
        }
    }

    if normalized.is_empty() {
        normalized.push(String::from("read"));
    }
    Ok(normalized)
}
/// Handler that lists every API key belonging to the authenticated user.
///
/// Responds `200 OK` with `{"keys": [ApiKeyResponse, ...]}`. Only metadata and
/// the display prefix are included, never the key itself. An unparsable user id
/// in the auth claims yields `400`; a repository failure yields `500`.
#[cfg(feature = "portal")]
pub async fn list_keys(
    State(state): State<PortalState>,
    AuthClaims(claims): AuthClaims,
) -> impl IntoResponse {
    // The claims' subject is expected to be the user's UUID.
    let Ok(user_id) = Uuid::parse_str(&claims.sub) else {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({"error": "Invalid user ID"})),
        );
    };

    let repo = ApiKeyRepository::new(state.db.pool());
    let records = match repo.list_for_user(user_id).await {
        Ok(records) => records,
        Err(err) => {
            tracing::error!("Database error: {}", err);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error": "Failed to fetch API keys"})),
            );
        }
    };

    let mut summaries: Vec<ApiKeyResponse> = Vec::new();
    for record in records {
        summaries.push(ApiKeyResponse {
            id: record.id.to_string(),
            name: record.name,
            prefix: record.key_prefix,
            scopes: record.scopes,
            rate_limit_rpm: record.rate_limit_rpm,
            last_used_at: record.last_used_at.map(|t| t.to_rfc3339()),
            expires_at: record.expires_at.map(|t| t.to_rfc3339()),
            created_at: record.created_at.to_rfc3339(),
        });
    }

    (StatusCode::OK, Json(serde_json::json!({ "keys": summaries })))
}
/// Handler that issues a new API key for the authenticated user.
///
/// Validates the request and enforces `MAX_API_KEYS_PER_USER`. It persists only
/// the SHA-256 hash of the generated key and returns the plaintext key exactly
/// once in an `ApiKeyCreatedResponse`.
///
/// Responses:
/// - `201 Created` with the new key on success;
/// - `400 Bad Request` for an unparsable user id, an invalid name, scope, rate
///   limit or expiry, or when the per-user key limit has been reached;
/// - `500 Internal Server Error` when a database call fails.
#[cfg(feature = "portal")]
pub async fn create_key(
    State(state): State<PortalState>,
    AuthClaims(claims): AuthClaims,
    Json(req): Json<CreateApiKeyRequest>,
) -> impl IntoResponse {
    let user_id = match Uuid::parse_str(&claims.sub) {
        Ok(id) => id,
        Err(_) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({"error": "Invalid user ID"})),
            );
        }
    };

    // The limit is advertised in characters, so count `char`s rather than UTF-8
    // bytes (`len()`), and treat a whitespace-only name as empty.
    if req.name.trim().is_empty() || req.name.chars().count() > 100 {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({"error": "Name must be 1-100 characters"})),
        );
    }

    let scopes = match validate_scopes(&req.scopes) {
        Ok(s) => s,
        Err(e) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({"error": e})),
            );
        }
    };

    // A negative rate limit is meaningless; reject it instead of storing it.
    if matches!(req.rate_limit_rpm, Some(rpm) if rpm < 0) {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({"error": "rate_limit_rpm must not be negative"})),
        );
    }

    // `expires_in_days` is client-controlled (up to u32::MAX days), and
    // `DateTime + Duration` panics on overflow, so add it checked.
    let expires_at = match req.expires_in_days {
        None => None,
        Some(days) => match Utc::now().checked_add_signed(Duration::days(i64::from(days))) {
            Some(at) => Some(at),
            None => {
                return (
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({"error": "expires_in_days is too large"})),
                );
            }
        },
    };

    let key_repo = ApiKeyRepository::new(state.db.pool());
    // NOTE(review): count-then-insert is not atomic, so concurrent requests can
    // push a user past the limit; a database-side constraint would close the race.
    match key_repo.count_for_user(user_id).await {
        Ok(count) if count >= MAX_API_KEYS_PER_USER => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": format!("Maximum {} API keys allowed per user", MAX_API_KEYS_PER_USER)
                })),
            );
        }
        Ok(_) => {}
        Err(e) => {
            tracing::error!("Database error: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error": "Failed to create API key"})),
            );
        }
    }

    let raw_key = generate_api_key();
    let key_hash = hash_api_key(&raw_key);
    // Generated keys are ASCII, so this byte slice is "rk_" plus 8 random chars.
    let key_prefix = raw_key[..11].to_string();
    let new_key = CreateApiKey {
        user_id,
        name: req.name,
        key_prefix: key_prefix.clone(),
        key_hash,
        scopes: scopes.clone(),
        rate_limit_rpm: req.rate_limit_rpm,
        expires_at,
    };

    match key_repo.create(new_key).await {
        Ok(key) => {
            tracing::info!("API key created for user {}: {}", user_id, key.id);
            (
                StatusCode::CREATED,
                Json(serde_json::json!(ApiKeyCreatedResponse {
                    id: key.id.to_string(),
                    name: key.name,
                    key: raw_key,
                    prefix: key_prefix,
                    scopes,
                    rate_limit_rpm: key.rate_limit_rpm,
                    expires_at: key.expires_at.map(|t| t.to_rfc3339()),
                    created_at: key.created_at.to_rfc3339(),
                    warning: "Store this key securely. It will not be shown again.".to_string(),
                })),
            )
        }
        Err(e) => {
            tracing::error!("Database error: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error": "Failed to create API key"})),
            )
        }
    }
}
/// Handler that revokes one of the authenticated user's API keys.
///
/// `key_id` comes from the request path. Both it and the caller's user id are
/// passed to `ApiKeyRepository::revoke`, presumably so that ownership is
/// enforced there (confirm in the repository).
///
/// Responses:
/// - `200 OK` with `{"success": true, "revoked_key_id": ...}`;
/// - `400` for an unparsable user id;
/// - `404` when the repository reports `DbError::NotFound`;
/// - `500` for any other database error.
#[cfg(feature = "portal")]
pub async fn revoke_key(
    State(state): State<PortalState>,
    AuthClaims(claims): AuthClaims,
    Path(key_id): Path<Uuid>,
) -> impl IntoResponse {
    let Ok(user_id) = Uuid::parse_str(&claims.sub) else {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({"error": "Invalid user ID"})),
        );
    };

    let repo = ApiKeyRepository::new(state.db.pool());
    match repo.revoke(key_id, user_id).await {
        Err(DbError::NotFound) => (
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({"error": "API key not found"})),
        ),
        Err(other) => {
            tracing::error!("Database error: {}", other);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error": "Failed to revoke API key"})),
            )
        }
        Ok(()) => {
            tracing::info!("API key {} revoked by user {}", key_id, user_id);
            let body = serde_json::json!({
                "success": true,
                "revoked_key_id": key_id.to_string()
            });
            (StatusCode::OK, Json(body))
        }
    }
}
/// Handler that replaces one of the authenticated user's API keys with a new one.
///
/// The new key inherits the old key's scopes, rate limit and absolute expiry
/// time. Its name carries a single " (rotated)" suffix. The plaintext of the new
/// key is returned once, under `new_key`, as an `ApiKeyCreatedResponse`.
///
/// The replacement is created *before* the old key is revoked, so a failure
/// while creating never leaves the user without a working key. If revoking the
/// old key then fails, the new key is revoked again (best effort) and an error
/// is returned. This keeps the user from silently ending up with two active keys.
///
/// Responses:
/// - `200 OK` on success;
/// - `400` for an unparsable user id or an already-expired key;
/// - `404` if the key is not among the user's keys;
/// - `500` on database failures.
#[cfg(feature = "portal")]
pub async fn rotate_key(
    State(state): State<PortalState>,
    AuthClaims(claims): AuthClaims,
    Path(key_id): Path<Uuid>,
) -> impl IntoResponse {
    // Suffix appended to a key's name when it is rotated.
    const ROTATED_SUFFIX: &str = " (rotated)";

    let user_id = match Uuid::parse_str(&claims.sub) {
        Ok(id) => id,
        Err(_) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({"error": "Invalid user ID"})),
            );
        }
    };
    let key_repo = ApiKeyRepository::new(state.db.pool());

    // Search only the caller's own keys, so another user's key id is a 404.
    let keys = match key_repo.list_for_user(user_id).await {
        Ok(k) => k,
        Err(e) => {
            tracing::error!("Database error: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error": "Failed to rotate API key"})),
            );
        }
    };
    let old_key = match keys.into_iter().find(|k| k.id == key_id) {
        Some(k) => k,
        None => {
            return (
                StatusCode::NOT_FOUND,
                Json(serde_json::json!({"error": "API key not found"})),
            );
        }
    };

    // The replacement inherits `expires_at`. If that is already in the past,
    // `require_api_key` would reject the new key immediately.
    if matches!(old_key.expires_at, Some(expires_at) if expires_at < Utc::now()) {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({"error": "Cannot rotate an expired API key"})),
        );
    }

    // Repeated rotations keep a single suffix instead of stacking them.
    let new_name = if old_key.name.ends_with(ROTATED_SUFFIX) {
        old_key.name
    } else {
        format!("{}{}", old_key.name, ROTATED_SUFFIX)
    };

    let raw_key = generate_api_key();
    let key_hash = hash_api_key(&raw_key);
    // Generated keys are ASCII, so this byte slice is "rk_" plus 8 random chars.
    let key_prefix = raw_key[..11].to_string();
    let replacement = CreateApiKey {
        user_id,
        name: new_name,
        key_prefix: key_prefix.clone(),
        key_hash,
        scopes: old_key.scopes.clone(),
        rate_limit_rpm: old_key.rate_limit_rpm,
        expires_at: old_key.expires_at,
    };

    // Create first: if this fails, the old key is untouched and still valid.
    let new_key = match key_repo.create(replacement).await {
        Ok(key) => key,
        Err(e) => {
            tracing::error!("Database error: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error": "Failed to create new API key"})),
            );
        }
    };

    if let Err(e) = key_repo.revoke(key_id, user_id).await {
        tracing::error!("Failed to revoke old key: {}", e);
        // Undo the creation so the user is not left with two active keys.
        if let Err(undo_err) = key_repo.revoke(new_key.id, user_id).await {
            tracing::error!(
                "Failed to revoke replacement key {} after aborted rotation: {}",
                new_key.id,
                undo_err
            );
        }
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"error": "Failed to rotate API key"})),
        );
    }

    tracing::info!(
        "API key {} rotated to {} for user {}",
        key_id,
        new_key.id,
        user_id
    );
    let created = ApiKeyCreatedResponse {
        id: new_key.id.to_string(),
        name: new_key.name,
        key: raw_key,
        prefix: key_prefix,
        scopes: old_key.scopes,
        rate_limit_rpm: new_key.rate_limit_rpm,
        expires_at: new_key.expires_at.map(|t| t.to_rfc3339()),
        created_at: new_key.created_at.to_rfc3339(),
        warning: "Store this key securely. It will not be shown again.".to_string(),
    };
    (
        StatusCode::OK,
        Json(serde_json::json!({
            "success": true,
            "old_key_id": key_id.to_string(),
            "new_key": created
        })),
    )
}
/// Looks up the stored API key record matching `raw_key` by its SHA-256 hash.
///
/// On a hit, the key's last-used timestamp is refreshed on a best-effort basis.
/// A failure there is logged but does not reject an otherwise valid key. This
/// function does not check `expires_at`; `middleware::require_api_key` does that
/// after calling it.
///
/// # Errors
/// Propagates any error from `ApiKeyRepository::find_by_hash`. The middleware
/// treats `DbError::NotFound` as an invalid key.
#[cfg(feature = "portal")]
pub async fn validate_api_key(
    pool: &sqlx::PgPool,
    raw_key: &str,
) -> Result<crate::portal::db::models::ApiKeyRecord, DbError> {
    let key_hash = hash_api_key(raw_key);
    let key_repo = ApiKeyRepository::new(pool);
    let key = key_repo.find_by_hash(&key_hash).await?;
    // Best effort: a failed bookkeeping write must not fail authentication, but
    // it should not disappear silently either.
    if let Err(e) = key_repo.update_last_used(key.id).await {
        tracing::warn!("Failed to update last_used_at for API key {}: {}", key.id, e);
    }
    Ok(key)
}
/// Axum middleware for authenticating requests with portal API keys.
///
/// `require_api_key` resolves the presented key and stores its `ApiKeyRecord`
/// in the request extensions. `require_api_scope` must run after it and checks
/// that stored record for a scope.
#[cfg(feature = "portal")]
pub mod middleware {
    use super::*;
    use axum::{extract::Request, http::header, middleware::Next, response::Response};

    // Prefix carried by every key produced by `generate_api_key`.
    const API_KEY_PREFIX: &str = "rk_";

    /// Strips a case-insensitive `Bearer ` scheme from an `Authorization` value.
    ///
    /// HTTP authentication scheme names are case-insensitive (RFC 9110 §11.1),
    /// so `bearer rk_...` is accepted as well as `Bearer rk_...`.
    fn strip_bearer_scheme(value: &str) -> Option<&str> {
        const SCHEME: &str = "Bearer ";
        let scheme = value.get(..SCHEME.len())?;
        if scheme.eq_ignore_ascii_case(SCHEME) {
            value.get(SCHEME.len()..)
        } else {
            None
        }
    }

    /// Returns the raw API key presented by the client, if any.
    ///
    /// Checks `Authorization: Bearer rk_...` first, then `X-API-Key: rk_...`.
    /// Values without the `rk_` prefix (e.g. a different bearer token) are
    /// ignored, so other auth schemes can share the `Authorization` header.
    fn extract_api_key(req: &Request) -> Option<&str> {
        let headers = req.headers();
        let from_authorization = headers
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .and_then(strip_bearer_scheme)
            .filter(|token| token.starts_with(API_KEY_PREFIX));
        from_authorization.or_else(move || {
            headers
                .get("X-API-Key")
                .and_then(|value| value.to_str().ok())
                .filter(|value| value.starts_with(API_KEY_PREFIX))
        })
    }

    /// Middleware that authenticates a request by API key.
    ///
    /// On success the matching `ApiKeyRecord` is inserted into the request
    /// extensions for downstream handlers and [`require_api_scope`].
    ///
    /// # Errors
    /// - `401 MISSING_API_KEY` if no key is presented;
    /// - `401 INVALID_API_KEY` if lookup reports `DbError::NotFound`;
    /// - `401 EXPIRED_API_KEY` if the key's `expires_at` is in the past;
    /// - `500 VALIDATION_ERROR` on any other lookup failure.
    pub async fn require_api_key(
        State(state): State<PortalState>,
        mut req: Request,
        next: Next,
    ) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
        let api_key = extract_api_key(&req).ok_or_else(|| {
            (
                StatusCode::UNAUTHORIZED,
                Json(serde_json::json!({
                    "error": "Missing API key",
                    "code": "MISSING_API_KEY"
                })),
            )
        })?;
        let key_record = validate_api_key(state.db.pool(), api_key)
            .await
            .map_err(|e| match e {
                DbError::NotFound => (
                    StatusCode::UNAUTHORIZED,
                    Json(serde_json::json!({
                        "error": "Invalid API key",
                        "code": "INVALID_API_KEY"
                    })),
                ),
                _ => (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": "API key validation failed",
                        "code": "VALIDATION_ERROR"
                    })),
                ),
            })?;
        if let Some(expires_at) = key_record.expires_at {
            if expires_at < Utc::now() {
                return Err((
                    StatusCode::UNAUTHORIZED,
                    Json(serde_json::json!({
                        "error": "API key expired",
                        "code": "EXPIRED_API_KEY"
                    })),
                ));
            }
        }
        req.extensions_mut().insert(key_record);
        Ok(next.run(req).await)
    }

    /// Builds a middleware function that requires the authenticated API key to
    /// carry `required_scope`. An `admin` scope satisfies any requirement.
    ///
    /// It must run after [`require_api_key`], which stores the `ApiKeyRecord`
    /// in the request extensions. The returned closure has the
    /// `(Request, Next)` shape, presumably for `axum::middleware::from_fn`.
    ///
    /// # Errors
    /// - `401 NO_API_KEY` if no `ApiKeyRecord` is present in the extensions;
    /// - `403 INSUFFICIENT_SCOPE` if the key lacks the required scope.
    pub fn require_api_scope(
        required_scope: &'static str,
    ) -> impl Fn(
        Request,
        Next,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<Response, (StatusCode, Json<serde_json::Value>)>,
                > + Send,
        >,
    > + Clone {
        move |req: Request, next: Next| {
            Box::pin(async move {
                let key = req
                    .extensions()
                    .get::<crate::portal::db::models::ApiKeyRecord>()
                    .ok_or_else(|| {
                        (
                            StatusCode::UNAUTHORIZED,
                            Json(serde_json::json!({
                                "error": "No API key found",
                                "code": "NO_API_KEY"
                            })),
                        )
                    })?;
                if !key
                    .scopes
                    .iter()
                    .any(|s| s == required_scope || s == "admin")
                {
                    return Err((
                        StatusCode::FORBIDDEN,
                        Json(serde_json::json!({
                            "error": format!("Missing required scope: {}", required_scope),
                            "code": "INSUFFICIENT_SCOPE"
                        })),
                    ));
                }
                Ok(next.run(req).await)
            })
        }
    }
}