use std::sync::Arc;
use axum::Json;
use axum::extract::State;
use axum::http::HeaderMap;
use axum::response::IntoResponse;
use forge_core::AuthContext;
use forge_core::signals::{
DiagnosticReport, PageViewPayload, SignalEvent, SignalEventBatch, SignalEventType,
SignalPayload, SignalResponse, UtmParams,
};
use serde_json::Value;
use sqlx::PgPool;
use uuid::Uuid;
use super::bot;
use super::collector::SignalsCollector;
use super::device;
use super::rate_limit::SignalRateLimiter;
use super::session;
use super::visitor;
/// Maximum number of entries accepted in a single request. It applies to the
/// `events` of a track batch and to the `errors` of a diagnostic report.
/// Larger payloads are rejected with an `ok: false` response.
const MAX_BATCH_SIZE: usize = 50;
/// Returns `true` when the request carries a privacy opt-out signal.
///
/// A request counts as opted out if either the legacy `DNT` header or the
/// Global Privacy Control `Sec-GPC` header has the exact value `"1"`. The
/// `DNT` header is checked first; the `Sec-GPC` header is only read when
/// `DNT` does not opt out.
fn dnt_opted_out(headers: &HeaderMap) -> bool {
    const OPT_OUT_HEADERS: [&str; 2] = ["dnt", "sec-gpc"];
    OPT_OUT_HEADERS
        .iter()
        .any(|&name| extract_header(headers, name).as_deref() == Some("1"))
}
/// Shared state for the signals endpoint, handed to `signal_handler` as
/// `State<Arc<SignalsState>>`.
#[derive(Clone)]
pub struct SignalsState {
    /// Receives every built `SignalEvent` through `try_send`.
    pub collector: SignalsCollector,
    /// Database pool passed to `session::upsert_session`.
    pub pool: PgPool,
    /// Secret passed to `visitor::generate_visitor_id`.
    pub server_secret: String,
    /// When `true`, the client IP is dropped from the request context, so it
    /// is not passed to session upserts or events. It is still used for the
    /// visitor id and the GeoIP lookup.
    pub anonymize_ip: bool,
    /// Optional GeoIP resolver. Without it, the geo fields fall back to the
    /// lookup result's `Default`.
    pub geoip: Option<super::geoip::GeoIpResolver>,
    /// Limiter checked by `signal_handler` before any payload is processed.
    pub rate_limiter: Arc<SignalRateLimiter>,
}
/// Picks the key used for rate limiting: the client IP resolved by the
/// gateway, or `None` when no `ResolvedClientIp` extension is present or it
/// carries no IP.
///
/// The `_headers` parameter is currently unused. Only the gateway-resolved
/// IP is consulted.
fn resolve_rate_limit_ip(
    resolved_ip: &Option<axum::Extension<crate::gateway::ResolvedClientIp>>,
    _headers: &HeaderMap,
) -> Option<String> {
    let axum::Extension(client) = resolved_ip.as_ref()?;
    client.0.clone()
}
/// Builds the `ok: false` response (with no session id) returned when a
/// request is rejected. This covers both the rate limiter and oversized
/// batches or reports.
fn rate_limited_response() -> Json<SignalResponse> {
    let rejected = SignalResponse {
        session_id: None,
        ok: false,
    };
    Json(rejected)
}
/// Axum entry point for the signals endpoint.
///
/// The request is first checked against `state.rate_limiter`, keyed on the
/// IP from `resolve_rate_limit_ip`. A rejected request gets an `ok: false`
/// response and nothing else runs. An accepted payload is dispatched by
/// variant to the event, page-view, or diagnostic-report handler. Each of
/// those receives the gateway IP, the optional auth context and the headers.
pub async fn signal_handler(
    State(state): State<Arc<SignalsState>>,
    resolved_ip: Option<axum::Extension<crate::gateway::ResolvedClientIp>>,
    auth: Option<axum::Extension<AuthContext>>,
    headers: HeaderMap,
    Json(payload): Json<SignalPayload>,
) -> impl IntoResponse {
    let limiter_key = resolve_rate_limit_ip(&resolved_ip, &headers);
    let allowed = state.rate_limiter.check(limiter_key.as_deref());
    if !allowed {
        return rate_limited_response();
    }

    let state: &SignalsState = &state;
    match payload {
        SignalPayload::View(view) => {
            handle_view(state, resolved_ip, &auth, &headers, view).await
        }
        SignalPayload::Event(batch) => {
            handle_event(state, resolved_ip, &auth, &headers, batch).await
        }
        SignalPayload::Report(report) => {
            handle_report(state, resolved_ip, &auth, &headers, report).await
        }
    }
}
/// Handles a batch of custom `track` events.
///
/// - Opted-out clients (`DNT: 1` or `Sec-GPC: 1`) get `ok: true` with no
///   session id, and nothing is recorded.
/// - Batches with more than [`MAX_BATCH_SIZE`] events are rejected with
///   `ok: false`.
/// - Otherwise the session id from the batch context (if it parses as a UUID)
///   is passed to `session::upsert_session`. One `Track` signal per event is
///   handed to the collector via `try_send` (its result is not checked). The
///   session id returned by the upsert is echoed back.
///
/// Event timestamps are supplied by the client and are untrusted. Past
/// timestamps are kept, since events may be buffered and flushed late.
/// Timestamps ahead of the server clock are clamped to the time the batch was
/// received, so a skewed or forged client clock cannot place events in the
/// future. Events without a timestamp also get that receive time.
async fn handle_event(
    state: &SignalsState,
    resolved_ip: Option<axum::Extension<crate::gateway::ResolvedClientIp>>,
    auth: &Option<axum::Extension<AuthContext>>,
    headers: &HeaderMap,
    batch: SignalEventBatch,
) -> Json<SignalResponse> {
    if dnt_opted_out(headers) {
        return Json(SignalResponse {
            ok: true,
            session_id: None,
        });
    }
    // Oversized batches share the rate-limit rejection body (`ok: false`).
    if batch.events.len() > MAX_BATCH_SIZE {
        return rate_limited_response();
    }
    let ctx = extract_request_ctx(
        headers,
        resolved_ip.and_then(|r| r.0.0.clone()),
        auth,
        &state.server_secret,
        state.anonymize_ip,
        state.geoip.as_ref(),
    );
    let session_id =
        resolve_session_id(batch.context.as_ref().and_then(|c| c.session_id.as_deref()));
    let page_url = batch.context.as_ref().and_then(|c| c.page_url.clone());
    let session_id = session::upsert_session(
        &state.pool,
        session_id,
        &ctx.visitor_id,
        ctx.user_id,
        ctx.tenant_id,
        page_url.as_deref(),
        batch.context.as_ref().and_then(|c| c.referrer.as_deref()),
        ctx.user_agent.as_deref(),
        ctx.client_ip.as_deref(),
        ctx.is_bot,
        "track",
        ctx.device_type.as_deref(),
        ctx.browser.as_deref(),
        ctx.os.as_deref(),
    )
    .await;
    // One server reference time for the whole batch. It is the fallback for
    // missing timestamps and the upper bound for client-supplied ones.
    let received_at = chrono::Utc::now();
    for event in batch.events {
        let timestamp = event
            .timestamp
            .map_or(received_at, |sent_at| sent_at.min(received_at));
        let signal = SignalEvent {
            event_type: SignalEventType::Track,
            event_name: Some(event.event),
            correlation_id: event.correlation_id,
            session_id,
            visitor_id: Some(ctx.visitor_id.clone()),
            user_id: ctx.user_id,
            tenant_id: ctx.tenant_id,
            properties: event.properties,
            page_url: page_url.clone(),
            referrer: None,
            function_name: None,
            function_kind: None,
            duration_ms: None,
            status: None,
            error_message: None,
            error_stack: None,
            error_context: None,
            client_ip: ctx.client_ip.clone(),
            country: ctx.country.clone(),
            city: ctx.city.clone(),
            user_agent: ctx.user_agent.clone(),
            device_type: ctx.device_type.clone(),
            browser: ctx.browser.clone(),
            os: ctx.os.clone(),
            utm: None,
            is_bot: ctx.is_bot,
            timestamp,
        };
        state.collector.try_send(signal);
    }
    Json(SignalResponse {
        ok: true,
        session_id,
    })
}
/// Records a page view.
///
/// Opted-out clients (`DNT: 1` or `Sec-GPC: 1`) get `ok: true` with no
/// session id, and nothing is recorded. Otherwise:
/// - the session named by the `x-session-id` header (if it parses as a UUID)
///   is upserted with the viewed URL and referrer;
/// - a `PageView` signal is handed to the collector via `try_send`;
/// - the session id returned by the upsert is echoed back.
///
/// UTM parameters are attached only when at least one of `utm_source`,
/// `utm_medium` or `utm_campaign` is present. `utm_term` and `utm_content`
/// ride along in that case but never produce a `UtmParams` on their own.
/// NOTE(review): confirm that dropping term/content-only UTM sets is intended.
async fn handle_view(
    state: &SignalsState,
    resolved_ip: Option<axum::Extension<crate::gateway::ResolvedClientIp>>,
    auth: &Option<axum::Extension<AuthContext>>,
    headers: &HeaderMap,
    payload: PageViewPayload,
) -> Json<SignalResponse> {
    if dnt_opted_out(headers) {
        return Json(SignalResponse {
            ok: true,
            session_id: None,
        });
    }

    let gateway_ip = resolved_ip.and_then(|ext| ext.0.0.clone());
    let ctx = extract_request_ctx(
        headers,
        gateway_ip,
        auth,
        &state.server_secret,
        state.anonymize_ip,
        state.geoip.as_ref(),
    );

    let requested_session =
        resolve_session_id(extract_header(headers, "x-session-id").as_deref());
    let session_id = session::upsert_session(
        &state.pool,
        requested_session,
        &ctx.visitor_id,
        ctx.user_id,
        ctx.tenant_id,
        Some(&payload.url),
        payload.referrer.as_deref(),
        ctx.user_agent.as_deref(),
        ctx.client_ip.as_deref(),
        ctx.is_bot,
        "page_view",
        ctx.device_type.as_deref(),
        ctx.browser.as_deref(),
        ctx.os.as_deref(),
    )
    .await;

    let lacks_campaign = payload.utm_source.is_none()
        && payload.utm_medium.is_none()
        && payload.utm_campaign.is_none();
    let utm = if lacks_campaign {
        None
    } else {
        Some(UtmParams {
            source: payload.utm_source,
            medium: payload.utm_medium,
            campaign: payload.utm_campaign,
            term: payload.utm_term,
            content: payload.utm_content,
        })
    };

    let page_view = SignalEvent {
        event_type: SignalEventType::PageView,
        event_name: payload.title,
        correlation_id: payload.correlation_id,
        session_id,
        visitor_id: Some(ctx.visitor_id),
        user_id: ctx.user_id,
        tenant_id: ctx.tenant_id,
        properties: Value::Object(serde_json::Map::new()),
        page_url: Some(payload.url),
        referrer: payload.referrer,
        function_name: None,
        function_kind: None,
        duration_ms: None,
        status: None,
        error_message: None,
        error_stack: None,
        error_context: None,
        client_ip: ctx.client_ip,
        country: ctx.country,
        city: ctx.city,
        user_agent: ctx.user_agent,
        device_type: ctx.device_type,
        browser: ctx.browser,
        os: ctx.os,
        utm,
        is_bot: ctx.is_bot,
        timestamp: chrono::Utc::now(),
    };
    state.collector.try_send(page_view);

    Json(SignalResponse {
        ok: true,
        session_id,
    })
}
/// Records client-side diagnostic errors.
///
/// Reports with more than [`MAX_BATCH_SIZE`] errors are rejected with
/// `ok: false`. A session upsert happens only when the `x-session-id` header
/// carries a valid UUID. Each error is handed to the collector via `try_send`
/// as an `Error` signal tagged with that header's session id (or `None`).
/// That same id is echoed back.
///
/// NOTE(review): unlike `handle_event` and `handle_view`, this path does not
/// consult `dnt_opted_out`. Opted-out clients therefore still have errors
/// recorded with visitor id, IP and geo; confirm this is intentional. The
/// value returned by `session::upsert_session` is also discarded here, while
/// the other handlers adopt it.
async fn handle_report(
    state: &SignalsState,
    resolved_ip: Option<axum::Extension<crate::gateway::ResolvedClientIp>>,
    auth: &Option<axum::Extension<AuthContext>>,
    headers: &HeaderMap,
    report: DiagnosticReport,
) -> Json<SignalResponse> {
    let too_many_errors = report.errors.len() > MAX_BATCH_SIZE;
    if too_many_errors {
        return rate_limited_response();
    }

    let gateway_ip = resolved_ip.and_then(|ext| ext.0.0.clone());
    let ctx = extract_request_ctx(
        headers,
        gateway_ip,
        auth,
        &state.server_secret,
        state.anonymize_ip,
        state.geoip.as_ref(),
    );

    let session_id = resolve_session_id(extract_header(headers, "x-session-id").as_deref());
    if session_id.is_some() {
        session::upsert_session(
            &state.pool,
            session_id,
            &ctx.visitor_id,
            ctx.user_id,
            ctx.tenant_id,
            None,
            None,
            ctx.user_agent.as_deref(),
            ctx.client_ip.as_deref(),
            ctx.is_bot,
            "error",
            ctx.device_type.as_deref(),
            ctx.browser.as_deref(),
            ctx.os.as_deref(),
        )
        .await;
    }

    for diag in report.errors {
        // The message is used both as the event name and as the error text.
        let summary = diag.message.clone();
        let error_signal = SignalEvent {
            event_type: SignalEventType::Error,
            event_name: Some(summary),
            correlation_id: diag.correlation_id,
            session_id,
            visitor_id: Some(ctx.visitor_id.clone()),
            user_id: ctx.user_id,
            tenant_id: ctx.tenant_id,
            properties: Value::Object(serde_json::Map::new()),
            page_url: diag.page_url,
            referrer: None,
            function_name: None,
            function_kind: None,
            duration_ms: None,
            status: None,
            error_message: Some(diag.message),
            error_stack: diag.stack,
            error_context: diag.context,
            client_ip: ctx.client_ip.clone(),
            country: ctx.country.clone(),
            city: ctx.city.clone(),
            user_agent: ctx.user_agent.clone(),
            device_type: ctx.device_type.clone(),
            browser: ctx.browser.clone(),
            os: ctx.os.clone(),
            utm: None,
            is_bot: ctx.is_bot,
            timestamp: chrono::Utc::now(),
        };
        state.collector.try_send(error_signal);
    }

    Json(SignalResponse {
        ok: true,
        session_id,
    })
}
/// Per-request facts shared by every signal handler, built by
/// `extract_request_ctx`.
struct RequestCtx {
    // --- identity ---
    /// Visitor id produced by `visitor::generate_visitor_id`.
    visitor_id: String,
    /// User id from the auth context, if any.
    user_id: Option<Uuid>,
    /// Tenant id from the auth context, if any.
    tenant_id: Option<Uuid>,

    // --- network ---
    /// Gateway-resolved client IP. `None` when no IP was resolved or IP
    /// anonymization is enabled.
    client_ip: Option<String>,
    /// Country from the GeoIP lookup.
    country: Option<String>,
    /// City from the GeoIP lookup.
    city: Option<String>,

    // --- client software ---
    /// Raw `User-Agent` header value.
    user_agent: Option<String>,
    /// Result of `bot::is_bot_lower` on the lowercased user agent.
    is_bot: bool,
    /// Device type from `device::parse_lowered`.
    device_type: Option<String>,
    /// Browser from `device::parse_lowered`.
    browser: Option<String>,
    /// Operating system from `device::parse_lowered`.
    os: Option<String>,
}
/// Derives the per-request context shared by all signal handlers.
///
/// - `resolved_ip`: client IP from the gateway. It feeds the visitor id and
///   the GeoIP lookup, and is copied into `client_ip` unless `anonymize_ip`
///   is set.
/// - `auth`: optional auth extension that supplies the user and tenant ids.
/// - `server_secret`: passed through to `visitor::generate_visitor_id`.
/// - `geoip`: optional resolver. Without it, or without an IP, the geo fields
///   take the lookup result's `Default`.
///
/// Bot detection and device parsing both work on an ASCII-lowercased copy of
/// the `User-Agent` header, which is empty when the header is absent. Device
/// parsing also receives the `x-forge-platform` header.
fn extract_request_ctx(
    headers: &HeaderMap,
    resolved_ip: Option<String>,
    auth: &Option<axum::Extension<AuthContext>>,
    server_secret: &str,
    anonymize_ip: bool,
    geoip: Option<&super::geoip::GeoIpResolver>,
) -> RequestCtx {
    let user_agent = extract_header(headers, "user-agent");
    let platform = extract_header(headers, "x-forge-platform");

    // Lowercase once; both the bot check and the device parser take it.
    let lowered_ua = match user_agent.as_deref() {
        Some(ua) => ua.to_ascii_lowercase(),
        None => String::new(),
    };
    let is_bot = bot::is_bot_lower(&lowered_ua);

    let visitor_id = visitor::generate_visitor_id(
        resolved_ip.as_deref(),
        user_agent.as_deref(),
        server_secret,
    );

    let identity = auth.as_ref();
    let user_id = identity.and_then(|a| a.user_id());
    let tenant_id = identity.and_then(|a| a.tenant_id());

    let device = device::parse_lowered(platform.as_deref(), &lowered_ua);

    let geo = match (geoip, resolved_ip.as_deref()) {
        (Some(resolver), Some(ip)) => resolver.lookup(ip),
        _ => Default::default(),
    };

    // Anonymization only hides the stored IP; it was already used above.
    let client_ip = resolved_ip.filter(|_| !anonymize_ip);

    RequestCtx {
        visitor_id,
        user_id,
        tenant_id,
        client_ip,
        country: geo.country,
        city: geo.city,
        user_agent,
        is_bot,
        device_type: device.device_type,
        browser: device.browser,
        os: device.os,
    }
}
fn extract_header(headers: &HeaderMap, name: &str) -> Option<String> {
crate::gateway::extract_header(headers, name)
}
/// Parses an optional client-supplied session id. Absent or malformed ids
/// both yield `None`.
fn resolve_session_id(raw: Option<&str>) -> Option<Uuid> {
    let candidate = raw?;
    Uuid::parse_str(candidate).ok()
}
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap on a known-valid UUID literal in a test
mod tests {
    use axum::http::{HeaderMap, HeaderValue};
    use uuid::Uuid;

    use super::{dnt_opted_out, extract_header, resolve_session_id};

    // Every helper under test is synchronous, so plain `#[test]` is enough and
    // no async runtime needs to be started per test.

    #[test]
    fn extract_header_returns_value() {
        let mut headers = HeaderMap::new();
        headers.insert("x-custom", HeaderValue::from_static("hello"));
        assert_eq!(extract_header(&headers, "x-custom"), Some("hello".into()));
    }

    #[test]
    fn extract_header_returns_none_for_missing() {
        let headers = HeaderMap::new();
        assert_eq!(extract_header(&headers, "x-custom"), None);
    }

    #[test]
    fn extract_header_returns_none_for_empty_value() {
        let mut headers = HeaderMap::new();
        headers.insert("x-custom", HeaderValue::from_static(""));
        assert_eq!(extract_header(&headers, "x-custom"), None);
    }

    #[test]
    fn resolve_session_id_parses_valid_uuid() {
        let raw = "550e8400-e29b-41d4-a716-446655440000";
        let expected = Uuid::parse_str(raw).unwrap();
        assert_eq!(resolve_session_id(Some(raw)), Some(expected));
    }

    #[test]
    fn resolve_session_id_returns_none_for_garbage() {
        assert_eq!(resolve_session_id(Some("not-a-uuid")), None);
    }

    #[test]
    fn resolve_session_id_returns_none_for_none() {
        assert_eq!(resolve_session_id(None), None);
    }

    #[test]
    fn dnt_opted_out_honours_dnt_header() {
        let mut headers = HeaderMap::new();
        headers.insert("dnt", HeaderValue::from_static("1"));
        assert!(dnt_opted_out(&headers));
    }

    #[test]
    fn dnt_opted_out_honours_gpc_header() {
        let mut headers = HeaderMap::new();
        headers.insert("sec-gpc", HeaderValue::from_static("1"));
        assert!(dnt_opted_out(&headers));
    }

    #[test]
    fn dnt_opted_out_false_without_opt_out_signal() {
        assert!(!dnt_opted_out(&HeaderMap::new()));

        let mut headers = HeaderMap::new();
        headers.insert("dnt", HeaderValue::from_static("0"));
        headers.insert("sec-gpc", HeaderValue::from_static("0"));
        assert!(!dnt_opted_out(&headers));
    }
}