/// Revocation reason that marks a session whose refresh token was exchanged for a new one.
/// Presenting a token revoked for this reason is treated as refresh-token reuse.
const ROTATED_REASON: &str = "rotated";
/// Placeholder reason used when a revoked session carries no `revoked_reason` at all.
const UNKNOWN_REASON: &str = "unspecified";
/// Session cap passed to `revoke_oldest_active_sessions` after every successful refresh.
const MAX_SESSIONS_PER_USER: u32 = 100;
use axum::{
extract::State,
http::{header, HeaderMap},
response::IntoResponse,
Json,
};
use chrono::{Duration, Utc};
use std::sync::Arc;
use crate::callback::AuthCallback;
use crate::errors::AppError;
use crate::models::{RefreshRequest, RefreshResponse};
use crate::repositories::SessionEntity;
use crate::services::EmailService;
use crate::utils::{
build_json_response_with_cookies, extract_client_ip_with_fallback, extract_cookie,
get_default_org_context, hash_refresh_token, PeerIp,
};
use crate::AppState;
pub async fn refresh<C: AuthCallback, E: EmailService>(
State(state): State<Arc<AppState<C, E>>>,
headers: HeaderMap,
PeerIp(peer_ip): PeerIp,
maybe_req: Option<Json<RefreshRequest>>,
) -> Result<impl IntoResponse, AppError> {
let refresh_token = maybe_req.and_then(|Json(req)| req.refresh_token);
let token = if let Some(token) = refresh_token {
token
} else if state.config.cookie.enabled {
extract_cookie(&headers, &state.config.cookie.refresh_cookie_name)
.ok_or(AppError::InvalidToken)?
} else {
return Err(AppError::InvalidToken);
};
let token_hash = hash_refresh_token(&token, &state.config.jwt.secret);
let session = state
.session_repo
.find_by_refresh_token(&token_hash)
.await?
.ok_or(AppError::InvalidToken)?;
if session.revoked_at.is_some() {
let reason = session.revoked_reason.as_deref().unwrap_or(UNKNOWN_REASON);
if !is_token_reuse_reason(reason) {
if !is_known_non_reuse_reason(reason) {
tracing::warn!(reason = %reason, "Unknown revoked_reason; treating as non-reuse");
}
return Err(AppError::TokenExpired);
}
state
.session_repo
.revoke_all_for_user_with_reason(session.user_id, "token_reuse")
.await?;
let ip_address =
extract_client_ip_with_fallback(&headers, state.config.server.trust_proxy, peer_ip);
let user_agent = headers
.get(header::USER_AGENT)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let _ = state
.comms_service
.notify_token_reuse(
session.user_id,
ip_address.as_deref(),
user_agent.as_deref(),
)
.await;
return Err(AppError::TokenExpired);
}
if session.expires_at <= Utc::now() {
return Err(AppError::TokenExpired);
}
let was_revoked = state
.session_repo
.revoke_if_valid_with_reason(session.id, "rotated")
.await?;
if !was_revoked {
state
.session_repo
.revoke_all_for_user_with_reason(session.user_id, "token_reuse")
.await?;
let ip_address =
extract_client_ip_with_fallback(&headers, state.config.server.trust_proxy, peer_ip);
let user_agent = headers
.get(header::USER_AGENT)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let _ = state
.comms_service
.notify_token_reuse(
session.user_id,
ip_address.as_deref(),
user_agent.as_deref(),
)
.await;
return Err(AppError::TokenExpired);
}
let user = state
.user_repo
.find_by_id(session.user_id)
.await?
.ok_or(AppError::InvalidToken)?;
let memberships = state.membership_repo.find_by_user(session.user_id).await?;
let org_ids: Vec<_> = memberships.iter().map(|m| m.org_id).collect();
let orgs = state.org_repo.find_by_ids(&org_ids).await?;
let orgs_by_id: std::collections::HashMap<_, _> = orgs.into_iter().map(|o| (o.id, o)).collect();
let token_context = get_default_org_context(&memberships, &orgs_by_id, user.is_system_admin);
let new_session_id = uuid::Uuid::new_v4();
let token_pair = state.jwt_service.generate_token_pair_with_context(
session.user_id,
new_session_id,
&token_context,
)?;
let refresh_expiry =
Utc::now() + Duration::seconds(state.jwt_service.refresh_expiry_secs() as i64);
let current_ip =
extract_client_ip_with_fallback(&headers, state.config.server.trust_proxy, peer_ip);
let current_user_agent = headers
.get(header::USER_AGENT)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
if let (Some(old_ip), Some(ref new_ip)) = (&session.ip_address, ¤t_ip) {
if old_ip != new_ip {
tracing::info!(
user_id = %session.user_id,
session_id = %session.id,
old_ip = %old_ip,
new_ip = %new_ip,
"SEC-007: IP address changed during token refresh"
);
}
}
let new_session = SessionEntity::new_with_id(
new_session_id,
session.user_id,
hash_refresh_token(&token_pair.refresh_token, &state.config.jwt.secret),
refresh_expiry,
current_ip,
current_user_agent,
);
state.session_repo.create(new_session).await?;
if let Err(e) = state
.session_repo
.revoke_oldest_active_sessions(session.user_id, MAX_SESSIONS_PER_USER)
.await
{
tracing::warn!(
user_id = %session.user_id,
error = %e,
"Failed to enforce session limit during refresh - cleanup will be retried"
);
}
let response_tokens = if state.config.cookie.enabled {
None
} else {
Some(token_pair.clone())
};
let response = RefreshResponse {
tokens: response_tokens,
};
Ok(build_json_response_with_cookies(
&state.config.cookie,
&token_pair,
state.jwt_service.refresh_expiry_secs(),
response,
))
}
/// Reports whether a session revoked with `reason` signals refresh-token reuse.
///
/// Only sessions revoked as [`ROTATED_REASON`] qualify: their token was already exchanged,
/// so presenting it again means an old token is being replayed. The comparison is exact
/// and case-sensitive.
fn is_token_reuse_reason(reason: &str) -> bool {
    ROTATED_REASON == reason
}
/// Reports whether `reason` is a recognised revocation reason that does *not* indicate
/// token reuse.
///
/// Within this file the result only decides whether an "Unknown revoked_reason" warning is
/// logged; every non-reuse reason is otherwise handled identically by `refresh`.
///
/// `"token_reuse"` is listed because `refresh` itself writes that reason when it revokes all
/// of a user's sessions. Presenting one of those tokens afterwards is expected, and should not
/// be reported as an unknown reason.
fn is_known_non_reuse_reason(reason: &str) -> bool {
    matches!(
        reason,
        "logout"
            | "logout_all"
            | "user_revoke_other_sessions"
            | "password_reset"
            | "org_switch"
            | "org_switch_cleanup"
            | "session_limit"
            | "token_reuse"
            | UNKNOWN_REASON
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of which revocation reasons count as token reuse.
    #[test]
    fn test_token_reuse_reason() {
        let cases = [("rotated", true), ("logout", false)];
        for (reason, expected) in cases {
            assert_eq!(is_token_reuse_reason(reason), expected, "reason = {}", reason);
        }
    }

    /// Table-driven check of the recognised non-reuse revocation reasons.
    #[test]
    fn test_known_non_reuse_reason() {
        let cases = [("logout", true), (UNKNOWN_REASON, true), ("new_reason", false)];
        for (reason, expected) in cases {
            assert_eq!(is_known_non_reuse_reason(reason), expected, "reason = {}", reason);
        }
    }
}