use axum::extract::Request;
use axum::http::header;
use axum::middleware::Next;
use axum::response::Response;
use std::sync::Arc;
use systemprompt_analytics::AnalyticsService;
use systemprompt_identifiers::{AgentName, ClientId, ContextId, SessionId, SessionSource, UserId};
use systemprompt_models::api::ApiError;
use systemprompt_models::auth::UserType;
use systemprompt_models::execution::context::RequestContext;
use systemprompt_models::modules::ApiPaths;
use systemprompt_oauth::services::{
CreateAnonymousSessionInput, SessionCreationError, SessionCreationService,
};
use systemprompt_runtime::AppContext;
use systemprompt_security::{HeaderExtractor, TokenExtractor};
use systemprompt_traits::AnalyticsProvider;
use systemprompt_users::{UserProviderImpl, UserService};
use uuid::Uuid;
use super::jwt::JwtExtractor;
/// Middleware state that attaches a session-aware `RequestContext` to each
/// incoming request (see `SessionMiddleware::handle`).
///
/// Every field is an `Arc`, so cloning the middleware only bumps reference
/// counts and all clones share the same underlying services.
#[derive(Debug, Clone)]
pub struct SessionMiddleware {
    /// Verifies browser JWTs and extracts the user/session they reference.
    jwt_extractor: Arc<JwtExtractor>,
    /// Used for request analytics, bot detection and session lookups.
    analytics_service: Arc<AnalyticsService>,
    /// Creates anonymous and authenticated sessions.
    session_creation_service: Arc<SessionCreationService>,
}
impl SessionMiddleware {
pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
let user_service = UserService::new(ctx.db_pool())?;
let session_creation_service = Arc::new(SessionCreationService::new(
ctx.analytics_service().clone(),
Arc::new(UserProviderImpl::new(user_service)),
));
Ok(Self {
jwt_extractor,
analytics_service: ctx.analytics_service().clone(),
session_creation_service,
})
}
pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
let headers = request.headers();
let uri = request.uri().clone();
let method = request.method().clone();
let should_skip = Self::should_skip_session_tracking(uri.path());
tracing::debug!(
path = %uri.path(),
should_skip = should_skip,
"Session middleware evaluating request"
);
let trace_id = HeaderExtractor::extract_trace_id(headers);
let (req_ctx, jwt_cookie) = if should_skip {
let ctx = RequestContext::new(
SessionId::new(format!("untracked_{}", Uuid::new_v4())),
trace_id,
ContextId::new(String::new()),
AgentName::system(),
)
.with_user_id(UserId::new("anonymous".to_string()))
.with_user_type(UserType::Anon)
.with_tracked(false);
(ctx, None)
} else {
let analytics = self
.analytics_service
.extract_analytics(headers, Some(&uri));
let is_bot = AnalyticsService::is_bot(&analytics);
tracing::debug!(
path = %uri.path(),
is_bot = is_bot,
user_agent = ?analytics.user_agent,
"Session middleware bot check"
);
if is_bot {
let ctx = RequestContext::new(
SessionId::new(format!("bot_{}", Uuid::new_v4())),
trace_id,
ContextId::new(String::new()),
AgentName::system(),
)
.with_user_id(UserId::new("bot".to_string()))
.with_user_type(UserType::Anon)
.with_tracked(false);
(ctx, None)
} else {
let token_result = TokenExtractor::browser_only().extract(headers).ok();
let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = if let Some(
token,
) =
token_result
{
if let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) {
let session_exists = self
.analytics_service
.find_session_by_id(&jwt_context.session_id)
.await
.ok()
.flatten()
.is_some();
if session_exists {
(
jwt_context.session_id,
jwt_context.user_id,
token,
None,
None,
)
} else {
tracing::info!(
old_session_id = %jwt_context.session_id,
user_id = %jwt_context.user_id,
"JWT valid but session missing, refreshing with new session"
);
match self
.refresh_session_for_user(&jwt_context.user_id, headers, &uri)
.await
{
Ok((sid, uid, new_token, _, fp)) => {
(sid, uid, new_token.clone(), Some(new_token), Some(fp))
},
Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
tracing::warn!(
user_id = %jwt_context.user_id,
"JWT references non-existent user, creating new anonymous session"
);
let (sid, uid, token, _, fp) =
self.create_new_session(headers, &uri, &method).await?;
(sid, uid, token.clone(), Some(token), Some(fp))
},
Err(e) => return Err(e),
}
}
} else {
let (sid, uid, token, is_new, fp) =
self.create_new_session(headers, &uri, &method).await?;
let jwt_cookie = if is_new { Some(token.clone()) } else { None };
(sid, uid, token, jwt_cookie, Some(fp))
}
} else {
let (sid, uid, token, is_new, fp) =
self.create_new_session(headers, &uri, &method).await?;
let jwt_cookie = if is_new { Some(token.clone()) } else { None };
(sid, uid, token, jwt_cookie, Some(fp))
};
let mut ctx = RequestContext::new(
session_id,
trace_id,
ContextId::new(String::new()),
AgentName::system(),
)
.with_user_id(user_id)
.with_auth_token(jwt_token)
.with_user_type(UserType::Anon)
.with_tracked(true);
if let Some(fp) = fingerprint_hash {
ctx = ctx.with_fingerprint_hash(fp);
}
(ctx, jwt_cookie)
}
};
tracing::debug!(
path = %uri.path(),
session_id = %req_ctx.session_id(),
"Session middleware setting context"
);
request.extensions_mut().insert(req_ctx);
let mut response = next.run(request).await;
if let Some(token) = jwt_cookie {
let cookie =
format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
if let Ok(cookie_value) = cookie.parse() {
response
.headers_mut()
.insert(header::SET_COOKIE, cookie_value);
}
}
Ok(response)
}
async fn create_new_session(
&self,
headers: &http::HeaderMap,
uri: &http::Uri,
_method: &http::Method,
) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
let client_id = ClientId::new("sp_web".to_string());
let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
tracing::error!(error = %e, "Failed to get JWT secret during session creation");
ApiError::internal_error("Failed to initialize session")
})?;
self.session_creation_service
.create_anonymous_session(CreateAnonymousSessionInput {
headers,
uri: Some(uri),
client_id: &client_id,
jwt_secret,
session_source: SessionSource::Web,
})
.await
.map(|session_info| {
(
session_info.session_id,
session_info.user_id,
session_info.jwt_token,
session_info.is_new,
session_info.fingerprint_hash,
)
})
.map_err(|e| {
tracing::error!(error = %e, "Failed to create anonymous session");
ApiError::internal_error("Service temporarily unavailable")
})
}
async fn refresh_session_for_user(
&self,
user_id: &UserId,
headers: &http::HeaderMap,
uri: &http::Uri,
) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
let session_id = self
.session_creation_service
.create_authenticated_session(user_id, headers, SessionSource::Web)
.await
.map_err(|e| match e {
SessionCreationError::UserNotFound { ref user_id } => {
ApiError::not_found(format!("User not found: {}", user_id))
.with_error_key("user_not_found")
}
SessionCreationError::Internal(ref msg) => {
tracing::error!(error = %msg, user_id = %user_id, "Failed to create session for user");
ApiError::internal_error("Failed to refresh session")
}
})?;
let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
tracing::error!(error = %e, "Failed to get JWT secret during session refresh");
ApiError::internal_error("Failed to refresh session")
})?;
let config = systemprompt_models::Config::get().map_err(|e| {
tracing::error!(error = %e, "Failed to get config during session refresh");
ApiError::internal_error("Failed to refresh session")
})?;
let token = systemprompt_oauth::services::generation::generate_anonymous_jwt(
user_id.as_str(),
session_id.as_str(),
&ClientId::new("sp_web".to_string()),
&systemprompt_oauth::services::JwtSigningParams {
secret: jwt_secret,
issuer: &config.jwt_issuer,
},
)
.map_err(|e| {
tracing::error!(error = %e, "Failed to generate JWT during session refresh");
ApiError::internal_error("Failed to refresh session")
})?;
let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
let fingerprint = AnalyticsService::compute_fingerprint(&analytics);
Ok((session_id, user_id.clone(), token, true, fingerprint))
}
fn should_skip_session_tracking(path: &str) -> bool {
if path.starts_with(ApiPaths::TRACK_BASE) {
return false;
}
if path.starts_with(ApiPaths::MCP_BASE) {
return true;
}
if path.starts_with(ApiPaths::API_BASE) {
return true;
}
if path.starts_with(ApiPaths::NEXT_BASE) {
return true;
}
if path.starts_with(ApiPaths::STATIC_BASE)
|| path.starts_with(ApiPaths::ASSETS_BASE)
|| path.starts_with(ApiPaths::IMAGES_BASE)
{
return true;
}
if path == "/health" || path == "/ready" || path == "/healthz" {
return true;
}
if path == "/favicon.ico"
|| path == "/robots.txt"
|| path == "/sitemap.xml"
|| path == "/manifest.json"
{
return true;
}
if let Some(last_segment) = path.rsplit('/').next() {
if last_segment.contains('.') {
let extension = last_segment.rsplit('.').next().unwrap_or("");
match extension {
"html" | "htm" => {},
_ => return true,
}
}
}
false
}
}