use crate::circuit_breaker::CircuitBreaker;
use crate::cost::CostEstimator;
use crate::provider::{self, ParsedResponse};
use crate::streaming::{StreamingAccumulator, StreamingOutputMonitor, StreamingSecurityMonitor};
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, Request, Response, StatusCode};
use bytes::Bytes;
use chrono::Utc;
use futures_util::StreamExt;
use llmtrace_core::{
AgentAction, AnalysisContext, LLMProvider, ProxyConfig, SecurityAnalyzer, SecurityFinding,
Storage, TenantId, TraceEvent, TraceSpan,
};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
/// Status of the optional ML detection models, as reported by `health_handler`.
#[derive(Debug, Clone)]
pub enum MlModelStatus {
/// Reported by the health endpoint as `"status": "disabled"`.
Disabled,
/// Reported by the health endpoint as `"status": "loaded"`; each flag is
/// echoed per model in the health payload.
Loaded {
/// Counted as an injection detector by the health endpoint when `true`.
prompt_injection: bool,
/// Reported as `ner_model`; not counted as an injection detector.
ner: bool,
/// Counted as an injection detector by the health endpoint when `true`.
injecguard: bool,
/// Counted as an injection detector by the health endpoint when `true`.
piguard: bool,
/// Presumably the model load duration in milliseconds — TODO confirm at the producer.
load_time_ms: u64,
},
/// Reported by the health endpoint as `"status": "failed"` with the error text.
Failed {
/// Error text surfaced verbatim in the health payload.
error: String,
},
}
/// Shared state handed to every handler (wrapped in an `Arc` by the handlers in this file).
pub struct AppState {
/// Proxy configuration (upstream URL, auth, size limits, analysis toggles, ...).
pub config: ProxyConfig,
/// HTTP client used by `proxy_handler` to send requests upstream.
pub client: Client,
/// Storage backends; this file uses its `traces`, `metadata` and `cache` members.
pub storage: Storage,
/// Analyzer used for post-response interaction analysis and passed to enforcement.
pub security: Arc<dyn SecurityAnalyzer>,
/// Second analyzer passed to enforcement alongside `security`.
pub fast_analyzer: Arc<dyn SecurityAnalyzer>,
/// Circuit breaker consulted before storing traces.
pub storage_breaker: Arc<CircuitBreaker>,
/// Circuit breaker consulted before running security analysis.
pub security_breaker: Arc<CircuitBreaker>,
/// Estimates request cost from provider, model and token counts.
pub cost_estimator: CostEstimator,
/// Optional alert sink; findings are forwarded to it when present.
pub alert_engine: Option<crate::alerts::AlertEngine>,
/// Optional budget/token cap tracker; checked before proxying and updated afterwards.
pub cost_tracker: Option<crate::cost_caps::CostTracker>,
/// Optional anomaly detector fed after each completed request.
pub anomaly_detector: Option<crate::anomaly::AnomalyDetector>,
/// Compliance report store (not used by the handlers in this file).
pub report_store: crate::compliance::ReportStore,
/// Optional per-tenant rate limiter checked at the start of `proxy_handler`.
pub rate_limiter: Option<crate::rate_limit::RateLimiter>,
/// ML model status surfaced by `health_handler`.
pub ml_status: MlModelStatus,
/// Shutdown coordinator; `proxy_handler` registers its background task with it.
pub shutdown: crate::shutdown::ShutdownCoordinator,
/// Metrics registry (active connections, latencies, tokens, cost, breaker states).
pub metrics: crate::metrics::Metrics,
/// Set to `true` by `health_handler` the first time all health checks pass.
pub ready: Arc<AtomicBool>,
}
impl AppState {
    /// Borrows the metadata repository held by the storage bundle.
    pub fn metadata(&self) -> &dyn llmtrace_core::MetadataRepository {
        let repository: &dyn llmtrace_core::MetadataRepository = self.storage.metadata.as_ref();
        repository
    }
}
/// Minimal view of an incoming LLM request body.
///
/// Only the fields listed here are captured; every field is optional on the
/// wire thanks to `#[serde(default)]`, and any other JSON keys in the request
/// are not retained by this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LLMRequestBody {
/// Requested model name; empty when absent.
#[serde(default)]
model: String,
/// Chat-style messages; empty when absent.
#[serde(default)]
messages: Vec<ChatMessage>,
/// Completion-style prompt, used when `messages` is empty.
#[serde(default)]
prompt: Option<String>,
/// Whether the client asked for a streamed response.
#[serde(default)]
stream: Option<bool>,
}
/// A single chat message from the request body; both fields are required strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ChatMessage {
/// Speaker role as sent by the client (the tests in this file use "system" and "user").
role: String,
/// Message text.
content: String,
}
/// Returns the value of the `x-llmtrace-agent-id` header, if it is present
/// and can be read as a string.
pub(crate) fn extract_agent_id(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get("x-llmtrace-agent-id")?;
    match raw.to_str() {
        Ok(agent) => Some(agent.to_owned()),
        Err(_) => None,
    }
}
/// Extracts the token from an `Authorization: Bearer <token>` header.
///
/// Returns `None` when the header is missing, is not readable as a string,
/// or does not start with the exact prefix `"Bearer "`.
fn extract_api_key(headers: &HeaderMap) -> Option<String> {
    let value = headers.get("authorization")?.to_str().ok()?;
    let token = value.strip_prefix("Bearer ")?;
    Some(token.to_owned())
}
pub(crate) fn resolve_tenant(headers: &HeaderMap) -> Option<TenantId> {
if let Some(raw) = headers.get("x-llmtrace-tenant-id") {
if let Ok(s) = raw.to_str() {
if let Ok(uuid) = Uuid::parse_str(s) {
return Some(TenantId(uuid));
}
}
}
if let Some(key) = extract_api_key(headers) {
let ns = Uuid::NAMESPACE_URL;
return Some(TenantId(Uuid::new_v5(&ns, key.as_bytes())));
}
None
}
/// Renders chat messages as newline-separated `role: content` lines.
///
/// An empty slice yields an empty string.
fn messages_to_prompt_text(messages: &[ChatMessage]) -> String {
    let mut transcript = String::new();
    for (index, message) in messages.iter().enumerate() {
        if index > 0 {
            transcript.push('\n');
        }
        transcript.push_str(&message.role);
        transcript.push_str(": ");
        transcript.push_str(&message.content);
    }
    transcript
}
/// Joins the message contents with newlines, omitting role prefixes.
///
/// An empty slice yields an empty string.
fn messages_to_analysis_text(messages: &[ChatMessage]) -> String {
    let mut combined = String::new();
    let mut remaining = messages.iter();
    if let Some(first) = remaining.next() {
        combined.push_str(&first.content);
        for message in remaining {
            combined.push('\n');
            combined.push_str(&message.content);
        }
    }
    combined
}
/// Builds the upstream URL from the configured base, the request path and
/// an optional query string.
///
/// Trailing slashes on the configured base are removed before `path` is
/// appended verbatim; the query, when present, is appended after a `?`.
fn build_upstream_url(config: &ProxyConfig, path: &str, query: Option<&str>) -> String {
    let mut url = String::from(config.upstream_url.trim_end_matches('/'));
    url.push_str(path);
    if let Some(q) = query {
        url.push('?');
        url.push_str(q);
    }
    url
}
/// Transparent proxy entry point.
///
/// Flow:
/// 1. Resolve the tenant (rejecting with 401 when auth is enabled and no
///    tenant could be authenticated) and apply the per-tenant rate limit.
/// 2. Read and (best-effort) parse the request body, then apply token and
///    budget caps and pre-request security enforcement.
/// 3. Forward the request upstream and stream the upstream response back to
///    the client through a channel.
/// 4. In a background task tracked by the shutdown coordinator: monitor the
///    stream, run security/anomaly analysis, record spend, store the trace
///    and update metrics.
///
/// The `active_connections` gauge is incremented on entry and decremented
/// either on every early-return path or at the end of the background task.
pub async fn proxy_handler(
    State(state): State<Arc<AppState>>,
    req: Request<Body>,
) -> Response<Body> {
    state.metrics.active_connections.inc();
    let start_time = Utc::now();
    let trace_id = Uuid::new_v4();
    let method = req.method().clone();
    let uri = req.uri().clone();
    let path = uri.path().to_string();
    let query = uri.query().map(|q| q.to_string());
    let headers = req.headers().clone();

    // --- Tenant resolution -------------------------------------------------
    let (tenant_id_opt, _) = crate::auth::resolve_authenticated_tenant(&headers, req.extensions());
    let tenant_id = match tenant_id_opt {
        Some(id) if !id.0.is_nil() => id,
        _ => {
            if state.config.auth.enabled {
                warn!(%trace_id, "Missing authenticated tenant when auth is enabled");
                state.metrics.active_connections.dec();
                return error_response(StatusCode::UNAUTHORIZED, "Authentication required");
            }
            // Auth disabled: fall back to a fixed, deterministic tenant id.
            TenantId(Uuid::new_v5(&Uuid::NAMESPACE_OID, b"Unknown"))
        }
    };
    let api_key = extract_api_key(&headers);
    let agent_id = extract_agent_id(&headers);
    let detected_provider = provider::detect_provider(&headers, &state.config.upstream_url, &path);

    // A failed or missing tenant config falls back to hybrid monitoring.
    let tenant_config = state
        .metadata()
        .get_tenant_config(tenant_id)
        .await
        .ok()
        .flatten();
    let monitoring_scope = tenant_config
        .as_ref()
        .map(|c| c.monitoring_scope)
        .unwrap_or(llmtrace_core::MonitoringScope::Hybrid);

    // Make sure the tenant record exists, without delaying the request.
    if !state.config.auth.enabled || tenant_id_opt.is_some() {
        let state_ac = Arc::clone(&state);
        let name = if tenant_id_opt.is_some() {
            api_key
                .as_deref()
                .map(|k| {
                    // Slicing by bytes is safe here: the key came from
                    // `HeaderValue::to_str`, which only yields ASCII.
                    let prefix_len = k.len().min(8);
                    format!("key-{}", &k[..prefix_len])
                })
                .unwrap_or_else(|| format!("tenant-{}", tenant_id.0))
        } else {
            "Unknown".to_string()
        };
        tokio::spawn(async move {
            crate::tenant_api::ensure_tenant_exists(&state_ac, tenant_id, &name).await;
        });
    }

    // --- Rate limiting -----------------------------------------------------
    if let Some(ref limiter) = state.rate_limiter {
        match limiter.check(tenant_id).await {
            crate::rate_limit::RateLimitResult::Exceeded {
                retry_after_secs,
                limit,
                tenant_id: tid,
            } => {
                warn!(
                    %trace_id,
                    %tid,
                    limit,
                    retry_after_secs,
                    "Rate limit exceeded"
                );
                state.metrics.active_connections.dec();
                return rate_limit_response(tid, limit, retry_after_secs);
            }
            crate::rate_limit::RateLimitResult::Allowed => {}
        }
    }
    debug!(
        %trace_id,
        %tenant_id,
        %method,
        %path,
        provider = ?detected_provider,
        "Proxying request"
    );

    // --- Request body ------------------------------------------------------
    let body_bytes = match axum::body::to_bytes(
        req.into_body(),
        state.config.max_request_size_bytes as usize,
    )
    .await
    {
        Ok(b) => b,
        Err(e) => {
            warn!(%trace_id, "Failed to read request body: {}", e);
            state.metrics.active_connections.dec();
            return error_response(StatusCode::BAD_REQUEST, "Failed to read request body");
        }
    };
    // Parsing is best-effort: non-JSON or unexpected bodies are still proxied.
    let llm_body: Option<LLMRequestBody> = serde_json::from_slice(&body_bytes).ok();
    let model_name = llm_body
        .as_ref()
        .map(|b| b.model.clone())
        .unwrap_or_default();
    let prompt_text = llm_body
        .as_ref()
        .map(|b| {
            if !b.messages.is_empty() {
                messages_to_prompt_text(&b.messages)
            } else {
                b.prompt.clone().unwrap_or_default()
            }
        })
        .unwrap_or_default();
    let analysis_text = llm_body
        .as_ref()
        .map(|b| {
            if !b.messages.is_empty() {
                messages_to_analysis_text(&b.messages)
            } else {
                b.prompt.clone().unwrap_or_default()
            }
        })
        .unwrap_or_default();

    // --- Cost and token caps ----------------------------------------------
    if let Some(ref tracker) = state.cost_tracker {
        // `LLMRequestBody` does not carry `max_tokens`, so read it from the
        // raw JSON. Only attempted when the body parsed as an LLM request.
        // Values beyond `u32::MAX` saturate instead of wrapping, so an
        // oversized request cannot slip under a cap.
        let req_max_tokens: Option<u32> = llm_body
            .as_ref()
            .and_then(|_| serde_json::from_slice::<serde_json::Value>(&body_bytes).ok())
            .and_then(|v| v.get("max_tokens").and_then(|t| t.as_u64()))
            .map(|t| u32::try_from(t).unwrap_or(u32::MAX));
        let token_result = tracker.check_token_caps(
            agent_id.as_deref(),
            None,
            req_max_tokens,
            None,
        );
        if let crate::cost_caps::CapCheckResult::TokenCapExceeded { reason } = token_result {
            warn!(%trace_id, %reason, "Token cap exceeded — rejecting request");
            state.metrics.active_connections.dec();
            return cap_rejected_response(&reason, 0);
        }
        let budget_result = tracker
            .check_budget_caps(tenant_id, agent_id.as_deref())
            .await;
        match budget_result {
            crate::cost_caps::CapCheckResult::Rejected {
                window,
                current_spend_usd,
                hard_limit_usd,
                retry_after_secs,
            } => {
                let msg = format!(
                    "{window} budget exceeded: ${current_spend_usd:.4} / ${hard_limit_usd:.2}"
                );
                warn!(%trace_id, %msg, "Budget cap exceeded — rejecting request");
                state.metrics.active_connections.dec();
                return cap_rejected_response(&msg, retry_after_secs);
            }
            crate::cost_caps::CapCheckResult::AllowedWithWarning { warnings } => {
                for w in &warnings {
                    info!(%trace_id, warning = %w, "Cost cap warning");
                }
                if let Some(ref engine) = state.alert_engine {
                    let alert_findings: Vec<llmtrace_core::SecurityFinding> = warnings
                        .iter()
                        .map(|w| {
                            llmtrace_core::SecurityFinding::new(
                                llmtrace_core::SecuritySeverity::Medium,
                                "cost_cap_warning".to_string(),
                                w.clone(),
                                0.9,
                            )
                            .with_alert_required(true)
                        })
                        .collect();
                    engine.check_and_alert(trace_id, tenant_id, &alert_findings);
                }
            }
            _ => {}
        }
    }

    // --- Pre-request security enforcement ---------------------------------
    let mut flagged_findings: Vec<SecurityFinding> = Vec::new();
    if state.config.enable_security_analysis {
        let enf_context = AnalysisContext {
            tenant_id,
            trace_id,
            span_id: Uuid::new_v4(),
            provider: detected_provider.clone(),
            model_name: model_name.clone(),
            parameters: std::collections::HashMap::new(),
        };
        let decision = crate::enforcement::run_enforcement(
            &analysis_text,
            &enf_context,
            &state.config.enforcement,
            &state.security,
            &state.fast_analyzer,
        )
        .await;
        match decision {
            crate::enforcement::EnforcementDecision::Block { reason, findings } => {
                warn!(%trace_id, %reason, "Security enforcement blocked request");
                state.metrics.active_connections.dec();
                return crate::enforcement::blocked_response(&reason, &findings);
            }
            crate::enforcement::EnforcementDecision::Flag { findings } => {
                info!(%trace_id, count = findings.len(), "Security enforcement flagged request");
                flagged_findings = findings;
            }
            crate::enforcement::EnforcementDecision::Allow => {}
        }
    }

    // --- Upstream request --------------------------------------------------
    let upstream_url = build_upstream_url(&state.config, &path, query.as_deref());
    let mut upstream_req = state.client.request(
        reqwest::Method::from_bytes(method.as_str().as_bytes()).unwrap_or(reqwest::Method::POST),
        &upstream_url,
    );
    let mut forwarded_headers = reqwest::header::HeaderMap::new();
    for (name, value) in headers.iter() {
        // `host` and `accept-encoding` are deliberately not forwarded.
        if name == "host" || name == "accept-encoding" {
            continue;
        }
        if let Ok(rname) = reqwest::header::HeaderName::from_bytes(name.as_str().as_bytes()) {
            if let Ok(rval) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) {
                // `append` (not `insert`) so repeated headers keep every value.
                forwarded_headers.append(rname, rval);
            }
        }
    }
    upstream_req = upstream_req.headers(forwarded_headers);
    upstream_req = upstream_req.body(body_bytes.to_vec());
    let upstream_response = match upstream_req.send().await {
        Ok(resp) => resp,
        Err(e) => {
            error!(%trace_id, "Upstream request failed: {}", e);
            state.metrics.active_connections.dec();
            return error_response(StatusCode::BAD_GATEWAY, "Upstream request failed");
        }
    };
    let response_status = upstream_response.status();
    let response_headers = upstream_response.headers().clone();
    debug!(
        %trace_id,
        status = %response_status,
        "Upstream responded"
    );

    // --- Response plumbing -------------------------------------------------
    // The background task reads the upstream stream and pushes chunks through
    // this channel; the client-facing body is fed from the receiving end.
    let response_stream = upstream_response.bytes_stream();
    let (body_sender, body_receiver) = tokio::sync::mpsc::channel::<Result<Bytes, String>>(64);
    let response_body_stream = async_stream::stream! {
        let mut rx = tokio_stream::wrappers::ReceiverStream::new(body_receiver);
        while let Some(item) = rx.next().await {
            match item {
                Ok(bytes) => yield Ok::<_, std::io::Error>(bytes),
                Err(e) => yield Err(std::io::Error::other(e)),
            }
        }
    };
    let is_streaming = llm_body.as_ref().and_then(|b| b.stream).unwrap_or(false);
    let state_bg = Arc::clone(&state);
    let prompt_text_bg = prompt_text;
    let analysis_text_bg = analysis_text;
    let model_name_bg = model_name;
    let provider_bg = detected_provider;
    let agent_id_bg = agent_id;
    let scope_bg = monitoring_scope;
    let task_guard = state.shutdown.track_task();

    tokio::spawn(async move {
        // Held for the lifetime of the task so shutdown can account for it.
        let _guard = task_guard;
        let mut stream = response_stream;
        let mut sse_accumulator = if is_streaming {
            Some(StreamingAccumulator::new())
        } else {
            None
        };
        let mut streaming_monitor =
            if is_streaming && scope_bg != llmtrace_core::MonitoringScope::OutputOnly {
                StreamingSecurityMonitor::new(&state_bg.config.streaming_analysis)
            } else {
                None
            };
        let mut output_monitor =
            if is_streaming && scope_bg != llmtrace_core::MonitoringScope::InputOnly {
                StreamingOutputMonitor::new(
                    &state_bg.config.streaming_analysis,
                    &state_bg.config.output_safety,
                )
            } else {
                None
            };
        let mut raw_collected = Vec::new();
        let mut ttft_ms: Option<u64> = None;

        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(bytes) => {
                    if let Some(ref mut acc) = sse_accumulator {
                        let is_first_token = acc.process_chunk(&bytes);
                        if is_first_token {
                            let elapsed = Utc::now().signed_duration_since(start_time);
                            ttft_ms = Some(elapsed.num_milliseconds().max(0) as u64);
                        }
                        if let Some(ref mut monitor) = streaming_monitor {
                            if monitor.should_analyze(acc.completion_token_count) {
                                let new_findings = monitor
                                    .analyze_incremental(&acc.content, acc.completion_token_count);
                                if !new_findings.is_empty() {
                                    info!(
                                        %trace_id,
                                        count = new_findings.len(),
                                        tokens = acc.completion_token_count,
                                        "Streaming security findings detected mid-stream"
                                    );
                                    if let Some(ref engine) = state_bg.alert_engine {
                                        engine.check_and_alert(trace_id, tenant_id, &new_findings);
                                    }
                                }
                            }
                        }
                        if let Some(ref mut out_mon) = output_monitor {
                            if out_mon.should_analyze(acc.completion_token_count) {
                                let new_findings = out_mon
                                    .analyze_incremental(&acc.content, acc.completion_token_count);
                                if !new_findings.is_empty() {
                                    info!(
                                        %trace_id,
                                        count = new_findings.len(),
                                        tokens = acc.completion_token_count,
                                        "Streaming output safety findings detected mid-stream"
                                    );
                                    if let Some(ref engine) = state_bg.alert_engine {
                                        engine.check_and_alert(trace_id, tenant_id, &new_findings);
                                    }
                                }
                            }
                            if out_mon.should_early_stop() {
                                warn!(
                                    %trace_id,
                                    "Critical output safety issue detected — early stopping stream"
                                );
                                // The offending chunk is withheld; the client
                                // receives the early-stop event instead.
                                let warning = StreamingOutputMonitor::early_stop_sse_event();
                                let _ = body_sender.send(Ok(Bytes::from(warning))).await;
                                break;
                            }
                        }
                    }
                    raw_collected.extend_from_slice(&bytes);
                    // A send error means the client went away; stop reading.
                    if body_sender.send(Ok(bytes)).await.is_err() {
                        break;
                    }
                }
                Err(e) => {
                    let err_msg = e.to_string();
                    let _ = body_sender.send(Err(err_msg)).await;
                    break;
                }
            }
        }
        // Close the client-facing body before the slower analysis work starts.
        drop(body_sender);

        // Final flush of both streaming monitors over the full accumulated text.
        if let (Some(ref acc), Some(ref mut monitor)) = (&sse_accumulator, &mut streaming_monitor) {
            let final_findings =
                monitor.analyze_incremental(&acc.content, acc.completion_token_count);
            if !final_findings.is_empty() {
                info!(
                    %trace_id,
                    count = final_findings.len(),
                    "Streaming security findings in final flush"
                );
                if let Some(ref engine) = state_bg.alert_engine {
                    engine.check_and_alert(trace_id, tenant_id, &final_findings);
                }
            }
        }
        if let (Some(ref acc), Some(ref mut out_mon)) = (&sse_accumulator, &mut output_monitor) {
            let final_findings =
                out_mon.analyze_incremental(&acc.content, acc.completion_token_count);
            if !final_findings.is_empty() {
                info!(
                    %trace_id,
                    count = final_findings.len(),
                    "Streaming output safety findings in final flush"
                );
                if let Some(ref engine) = state_bg.alert_engine {
                    engine.check_and_alert(trace_id, tenant_id, &final_findings);
                }
            }
        }
        let mut streaming_findings: Vec<SecurityFinding> = streaming_monitor
            .as_mut()
            .map(|m| m.take_findings())
            .unwrap_or_default();
        if let Some(ref mut out_mon) = output_monitor {
            streaming_findings.extend(out_mon.take_findings());
        }

        // Response text and token usage: from the SSE accumulator when
        // streaming, otherwise parsed from the collected body.
        let (response_text, prompt_tokens, completion_tokens, total_tokens) =
            if let Some(acc) = sse_accumulator {
                (
                    acc.content.clone(),
                    acc.prompt_tokens(),
                    Some(acc.final_completion_tokens()),
                    acc.total_tokens(),
                )
            } else {
                let ParsedResponse { text, usage } =
                    provider::parse_response(&provider_bg, &raw_collected);
                let response_str =
                    text.unwrap_or_else(|| String::from_utf8_lossy(&raw_collected).to_string());
                (
                    response_str,
                    usage.prompt_tokens,
                    usage.completion_tokens,
                    usage.total_tokens,
                )
            };
        // Tool calls are only extracted from non-streaming responses.
        let auto_actions = if is_streaming {
            Vec::new()
        } else {
            provider::extract_tool_calls(&provider_bg, &raw_collected)
        };
        let captured = CapturedInteraction {
            trace_id,
            tenant_id,
            provider: provider_bg,
            model_name: model_name_bg,
            prompt_text: prompt_text_bg,
            analysis_text: analysis_text_bg,
            response_text,
            status_code: response_status.as_u16(),
            start_time,
            is_streaming,
            time_to_first_token_ms: ttft_ms,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            agent_actions: auto_actions,
            monitoring_scope: scope_bg,
        };

        if let Some(ref tracker) = state_bg.cost_tracker {
            let estimated = state_bg.cost_estimator.estimate_cost(
                &captured.provider,
                &captured.model_name,
                captured.prompt_tokens,
                captured.completion_tokens,
            );
            if let Some(cost) = estimated {
                tracker
                    .record_spend(captured.tenant_id, agent_id_bg.as_deref(), cost)
                    .await;
            }
        }

        let security_start = std::time::Instant::now();
        let mut security_findings = run_security_analysis(&state_bg, &captured).await;
        let security_ms = security_start.elapsed().as_millis() as u64;
        state_bg
            .metrics
            .record_detector_latency("ensemble", security_ms);
        security_findings.extend(streaming_findings);

        if let Some(ref detector) = state_bg.anomaly_detector {
            let anomaly_findings = detector
                .record_and_check(
                    captured.tenant_id,
                    state_bg.cost_estimator.estimate_cost(
                        &captured.provider,
                        &captured.model_name,
                        captured.prompt_tokens,
                        captured.completion_tokens,
                    ),
                    captured.total_tokens,
                    // Elapsed wall-clock time since the request started,
                    // clamped at zero in case the clock moved backwards.
                    Utc::now()
                        .signed_duration_since(captured.start_time)
                        .num_milliseconds()
                        .max(0)
                        .try_into()
                        .ok(),
                )
                .await;
            if !anomaly_findings.is_empty() {
                info!(
                    trace_id = %captured.trace_id,
                    count = anomaly_findings.len(),
                    "Anomaly findings detected"
                );
                security_findings.extend(anomaly_findings);
            }
        }
        if let Some(ref engine) = state_bg.alert_engine {
            engine.check_and_alert(captured.trace_id, captured.tenant_id, &security_findings);
        }
        run_trace_capture(&state_bg, &captured, &security_findings).await;

        // Metrics, then release the connection slot taken at handler entry.
        {
            let provider_lbl = crate::metrics::provider_label(&captured.provider);
            let model_lbl = &captured.model_name;
            let duration_secs = Utc::now()
                .signed_duration_since(captured.start_time)
                .num_milliseconds()
                .max(0) as f64
                / 1000.0;
            state_bg.metrics.record_request(
                provider_lbl,
                model_lbl,
                captured.status_code,
                duration_secs,
            );
            state_bg.metrics.record_tokens(
                provider_lbl,
                model_lbl,
                captured.prompt_tokens,
                captured.completion_tokens,
            );
            state_bg
                .metrics
                .record_security_findings(&security_findings);
            state_bg.metrics.record_anomalies(&security_findings);
            if let Some(cost) = state_bg.cost_estimator.estimate_cost(
                &captured.provider,
                &captured.model_name,
                captured.prompt_tokens,
                captured.completion_tokens,
            ) {
                state_bg
                    .metrics
                    .record_cost(&captured.tenant_id.0.to_string(), model_lbl, cost);
            }
            state_bg.metrics.active_connections.dec();
        }
    });

    // --- Client response ---------------------------------------------------
    let mut builder = Response::builder()
        .status(StatusCode::from_u16(response_status.as_u16()).unwrap_or(StatusCode::OK));
    for (name, value) in response_headers.iter() {
        if let Ok(hname) = axum::http::HeaderName::from_bytes(name.as_str().as_bytes()) {
            if let Ok(hval) = axum::http::HeaderValue::from_bytes(value.as_bytes()) {
                builder = builder.header(hname, hval);
            }
        }
    }
    if !flagged_findings.is_empty() {
        builder = builder.header("x-llmtrace-flagged", "true");
        let summary = crate::enforcement::findings_header_value(&flagged_findings);
        if let Ok(hval) = axum::http::HeaderValue::from_str(&summary) {
            builder = builder.header("x-llmtrace-findings", hval);
        }
    }
    builder
        .body(Body::from_stream(response_body_stream))
        .unwrap_or_else(|_| {
            error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to build response",
            )
        })
}
/// Everything recorded about one proxied request/response pair, assembled by
/// the background task in `proxy_handler` and consumed by
/// `run_security_analysis` and `run_trace_capture`.
struct CapturedInteraction {
/// Trace id generated per request.
trace_id: Uuid,
/// Tenant the request was attributed to.
tenant_id: TenantId,
/// Provider detected for the request.
provider: LLMProvider,
/// Model name from the request body (empty when not parseable).
model_name: String,
/// Prompt with `role: ` prefixes; stored on the trace span.
prompt_text: String,
/// Prompt without role prefixes; fed to security analysis.
analysis_text: String,
/// Response text (accumulated stream content or parsed/lossy-decoded body).
response_text: String,
/// Upstream HTTP status code.
status_code: u16,
/// Time the handler started processing the request.
start_time: chrono::DateTime<Utc>,
/// Whether the request asked for a streamed response.
is_streaming: bool,
/// Milliseconds until the first streamed token, when streaming.
time_to_first_token_ms: Option<u64>,
/// Prompt token count, when reported.
prompt_tokens: Option<u32>,
/// Completion token count, when reported.
completion_tokens: Option<u32>,
/// Total token count, when reported.
total_tokens: Option<u32>,
/// Tool calls extracted from non-streaming responses.
agent_actions: Vec<AgentAction>,
/// Which side(s) of the interaction are analyzed.
monitoring_scope: llmtrace_core::MonitoringScope,
}
async fn run_security_analysis(
state: &Arc<AppState>,
captured: &CapturedInteraction,
) -> Vec<SecurityFinding> {
if !state.config.enable_security_analysis {
return Vec::new();
}
if !state.security_breaker.allow().await {
debug!(trace_id = %captured.trace_id, "Security circuit breaker open — skipping analysis");
state.metrics.set_circuit_breaker_state("security", "open");
return Vec::new();
}
let context = AnalysisContext {
tenant_id: captured.tenant_id,
trace_id: captured.trace_id,
span_id: Uuid::new_v4(),
provider: captured.provider.clone(),
model_name: captured.model_name.clone(),
parameters: std::collections::HashMap::new(),
};
let timeout = std::time::Duration::from_millis(state.config.security_analysis_timeout_ms);
let prompt = if captured.monitoring_scope == llmtrace_core::MonitoringScope::OutputOnly {
""
} else {
&captured.analysis_text
};
let response = if captured.monitoring_scope == llmtrace_core::MonitoringScope::InputOnly {
""
} else {
&captured.response_text
};
let analysis_result = tokio::time::timeout(
timeout,
state
.security
.analyze_interaction(prompt, response, &context),
)
.await;
let mut all_findings = match analysis_result {
Ok(Ok(findings)) => {
state.security_breaker.record_success().await;
let cb_state = state.security_breaker.state().await;
state
.metrics
.set_circuit_breaker_state("security", circuit_breaker_state_label(cb_state));
if findings.is_empty() {
debug!(trace_id = %captured.trace_id, "Security analysis: no findings");
} else {
info!(
trace_id = %captured.trace_id,
finding_count = findings.len(),
"Security findings detected"
);
}
findings
}
Ok(Err(e)) => {
state.security_breaker.record_failure().await;
let cb_state = state.security_breaker.state().await;
state
.metrics
.set_circuit_breaker_state("security", circuit_breaker_state_label(cb_state));
error!(trace_id = %captured.trace_id, "Security analysis failed: {}", e);
Vec::new()
}
Err(_elapsed) => {
state.security_breaker.record_failure().await;
let cb_state = state.security_breaker.state().await;
state
.metrics
.set_circuit_breaker_state("security", circuit_breaker_state_label(cb_state));
warn!(
trace_id = %captured.trace_id,
timeout_ms = state.config.security_analysis_timeout_ms,
"Security analysis timed out"
);
Vec::new()
}
};
if state.config.output_safety.enabled
&& !captured.response_text.is_empty()
&& captured.monitoring_scope != llmtrace_core::MonitoringScope::InputOnly
{
let output_analyzer =
llmtrace_security::OutputAnalyzer::new_with_fallback(&state.config.output_safety);
let result = output_analyzer.analyze_output(&captured.response_text);
if !result.findings.is_empty() {
info!(
trace_id = %captured.trace_id,
finding_count = result.findings.len(),
has_critical = result.has_critical_toxicity,
"Output safety findings detected"
);
all_findings.extend(result.findings);
}
}
all_findings
}
async fn run_trace_capture(
state: &Arc<AppState>,
captured: &CapturedInteraction,
security_findings: &[SecurityFinding],
) {
if !state.config.enable_trace_storage {
return;
}
if !state.storage_breaker.allow().await {
debug!(trace_id = %captured.trace_id, "Storage circuit breaker open — skipping trace capture");
state.metrics.set_circuit_breaker_state("storage", "open");
return;
}
let operation = if captured.is_streaming {
"chat_completion_stream"
} else {
"chat_completion"
};
let mut span = TraceSpan::new(
captured.trace_id,
captured.tenant_id,
operation.to_string(),
captured.provider.clone(),
captured.model_name.clone(),
captured.prompt_text.clone(),
)
.finish_with_response(captured.response_text.clone());
span.status_code = Some(captured.status_code);
span.prompt_tokens = captured.prompt_tokens;
span.completion_tokens = captured.completion_tokens;
span.total_tokens = captured.total_tokens;
span.time_to_first_token_ms = captured.time_to_first_token_ms;
span.estimated_cost_usd = state.cost_estimator.estimate_cost(
&captured.provider,
&captured.model_name,
captured.prompt_tokens,
captured.completion_tokens,
);
let end_time = Utc::now();
let duration = end_time.signed_duration_since(captured.start_time);
span.duration_ms = Some(duration.num_milliseconds().max(0) as u64);
for action in &captured.agent_actions {
span.add_agent_action(action.clone());
}
if !captured.agent_actions.is_empty() {
if let Ok(analyzer) = llmtrace_security::RegexSecurityAnalyzer::new() {
let action_findings = analyzer.analyze_agent_actions(&captured.agent_actions);
for finding in action_findings {
span.add_security_finding(finding);
}
}
}
for finding in security_findings {
span.add_security_finding(finding.clone());
}
let trace = TraceEvent {
trace_id: captured.trace_id,
tenant_id: captured.tenant_id,
spans: vec![span],
created_at: captured.start_time,
};
match state.storage.traces.store_trace(&trace).await {
Ok(()) => {
state.storage_breaker.record_success().await;
state.metrics.record_storage_operation("store_trace", true);
let cb_state = state.storage_breaker.state().await;
state
.metrics
.set_circuit_breaker_state("storage", circuit_breaker_state_label(cb_state));
info!(trace_id = %captured.trace_id, "Trace stored successfully");
}
Err(e) => {
state.storage_breaker.record_failure().await;
state.metrics.record_storage_operation("store_trace", false);
let cb_state = state.storage_breaker.state().await;
state
.metrics
.set_circuit_breaker_state("storage", circuit_breaker_state_label(cb_state));
error!(trace_id = %captured.trace_id, "Failed to store trace: {}", e);
}
}
}
/// Health endpoint: probes storage and the security analyzer, reports
/// circuit-breaker and ML model status, and tracks readiness.
///
/// Responds 503 `"starting"` until every probe has passed at least once;
/// afterwards responds 200 with `"healthy"` or `"degraded"`.
pub async fn health_handler(State(state): State<Arc<AppState>>) -> Response<Body> {
    let storage = &state.storage;
    let traces_ok = storage.traces.health_check().await.is_ok();
    let metadata_ok = storage.metadata.health_check().await.is_ok();
    let cache_ok = storage.cache.health_check().await.is_ok();
    let security_ok = state.security.health_check().await.is_ok();
    let storage_circuit = state.storage_breaker.state().await;
    let security_circuit = state.security_breaker.state().await;

    let ml_status = match &state.ml_status {
        MlModelStatus::Failed { error } => serde_json::json!({
            "status": "failed",
            "error": error,
        }),
        MlModelStatus::Disabled => serde_json::json!({
            "status": "disabled",
        }),
        MlModelStatus::Loaded {
            prompt_injection,
            ner,
            injecguard,
            piguard,
            load_time_ms,
        } => {
            // One detector is always counted, plus one per loaded injection model.
            let injection_detectors: u8 = 1
                + u8::from(*prompt_injection)
                + u8::from(*injecguard)
                + u8::from(*piguard);
            let voting_mode = match injection_detectors {
                0..=2 => "union",
                _ => "majority",
            };
            serde_json::json!({
                "status": "loaded",
                "prompt_injection_model": prompt_injection,
                "ner_model": ner,
                "injecguard_model": injecguard,
                "piguard_model": piguard,
                "load_time_ms": load_time_ms,
                "injection_detector_count": injection_detectors,
                "voting_mode": voting_mode,
            })
        }
    };

    let all_healthy = traces_ok && metadata_ok && cache_ok && security_ok;
    // Readiness latches: once set it stays set even if a probe later fails.
    let is_ready = if state.ready.load(Ordering::Acquire) {
        true
    } else if all_healthy {
        state.ready.store(true, Ordering::Release);
        true
    } else {
        false
    };
    let (status_label, http_status) = match (is_ready, all_healthy) {
        (false, _) => ("starting", StatusCode::SERVICE_UNAVAILABLE),
        (true, true) => ("healthy", StatusCode::OK),
        (true, false) => ("degraded", StatusCode::OK),
    };

    let payload = serde_json::json!({
        "status": status_label,
        "starting": !is_ready,
        "storage": {
            "traces": { "healthy": traces_ok },
            "metadata": { "healthy": metadata_ok },
            "cache": { "healthy": cache_ok },
            "circuit_breaker": format!("{:?}", storage_circuit),
        },
        "security": {
            "healthy": security_ok,
            "circuit_breaker": format!("{:?}", security_circuit),
        },
        "ml": ml_status,
    });
    Response::builder()
        .status(http_status)
        .header("content-type", "application/json")
        .body(Body::from(payload.to_string()))
        .unwrap()
}
/// Builds the 429 JSON response returned when a tenant exceeds its rate limit.
///
/// Sets `retry-after`, `x-ratelimit-limit` and `x-ratelimit-remaining: 0`.
fn rate_limit_response(tenant_id: TenantId, limit: u32, retry_after_secs: u64) -> Response<Body> {
    let payload = serde_json::json!({
        "error": {
            "message": format!("Rate limit exceeded for tenant {tenant_id}"),
            "type": "rate_limit_exceeded",
            "tenant_id": tenant_id.0.to_string(),
            "limit_requests_per_second": limit,
            "retry_after_secs": retry_after_secs,
        }
    })
    .to_string();

    Response::builder()
        .status(StatusCode::TOO_MANY_REQUESTS)
        .header("content-type", "application/json")
        .header("retry-after", retry_after_secs.to_string())
        .header("x-ratelimit-limit", limit.to_string())
        .header("x-ratelimit-remaining", "0")
        .body(Body::from(payload))
        .unwrap()
}
/// Builds the 429 JSON response returned when a cost or token cap rejects a
/// request. A `retry-after` header is added only when `retry_after_secs > 0`.
fn cap_rejected_response(message: &str, retry_after_secs: u64) -> Response<Body> {
    let payload = serde_json::json!({
        "error": {
            "message": message,
            "type": "cost_cap_exceeded",
            "retry_after_secs": retry_after_secs,
        }
    })
    .to_string();

    let base = Response::builder()
        .status(StatusCode::TOO_MANY_REQUESTS)
        .header("content-type", "application/json");
    let builder = match retry_after_secs {
        0 => base,
        secs => base.header("retry-after", secs.to_string()),
    };
    builder.body(Body::from(payload)).unwrap()
}
fn circuit_breaker_state_label(state: crate::circuit_breaker::CircuitState) -> &'static str {
match state {
crate::circuit_breaker::CircuitState::Closed => "closed",
crate::circuit_breaker::CircuitState::Open => "open",
crate::circuit_breaker::CircuitState::HalfOpen => "half_open",
}
}
/// Builds a JSON error response of type `proxy_error` with the given status
/// and message.
fn error_response(status: StatusCode, message: &str) -> Response<Body> {
    let payload = serde_json::json!({
        "error": {
            "message": message,
            "type": "proxy_error",
        }
    })
    .to_string();

    Response::builder()
        .status(status)
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .unwrap()
}
// Unit tests for the header helpers, URL builder, prompt rendering and the
// canned error responses defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// A `Bearer` authorization header yields the token without the prefix.
#[test]
fn test_extract_api_key_bearer() {
let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer sk-test-key-123".parse().unwrap());
assert_eq!(
extract_api_key(&headers),
Some("sk-test-key-123".to_string())
);
}
// No authorization header means no API key.
#[test]
fn test_extract_api_key_missing() {
let headers = HeaderMap::new();
assert_eq!(extract_api_key(&headers), None);
}
// Non-Bearer schemes (here: Basic) are not treated as API keys.
#[test]
fn test_extract_api_key_no_bearer_prefix() {
let mut headers = HeaderMap::new();
headers.insert("authorization", "Basic dXNlcjpwYXNz".parse().unwrap());
assert_eq!(extract_api_key(&headers), None);
}
// An explicit tenant-id header is used verbatim.
#[test]
fn test_resolve_tenant_from_header() {
let mut headers = HeaderMap::new();
let tenant_uuid = Uuid::new_v4();
headers.insert(
"x-llmtrace-tenant-id",
tenant_uuid.to_string().parse().unwrap(),
);
let tenant = resolve_tenant(&headers).unwrap();
assert_eq!(tenant.0, tenant_uuid);
}
// Without a tenant header, the tenant is a v5 UUID derived from the API key.
#[test]
fn test_resolve_tenant_from_api_key() {
let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer sk-my-key".parse().unwrap());
let tenant = resolve_tenant(&headers).unwrap();
let expected = Uuid::new_v5(&Uuid::NAMESPACE_URL, b"sk-my-key");
assert_eq!(tenant.0, expected);
}
// With neither header present there is no tenant.
#[test]
fn test_resolve_tenant_fallback() {
let headers = HeaderMap::new();
let tenant = resolve_tenant(&headers);
assert!(tenant.is_none());
}
// The agent-id header value is returned as-is.
#[test]
fn test_extract_agent_id_present() {
let mut headers = HeaderMap::new();
headers.insert("x-llmtrace-agent-id", "my-agent".parse().unwrap());
assert_eq!(extract_agent_id(&headers), Some("my-agent".to_string()));
}
// No agent-id header means no agent id.
#[test]
fn test_extract_agent_id_missing() {
let headers = HeaderMap::new();
assert_eq!(extract_agent_id(&headers), None);
}
// Cap rejections are 429s that carry the retry-after value.
#[test]
fn test_cap_rejected_response_format() {
let resp = cap_rejected_response("budget exceeded", 3600);
assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
assert_eq!(
resp.headers().get("retry-after").unwrap().to_str().unwrap(),
"3600"
);
}
// Base URL without trailing slash and no query string.
#[test]
fn test_build_upstream_url_no_query() {
let config = ProxyConfig {
upstream_url: "http://localhost:11434".to_string(),
..ProxyConfig::default()
};
assert_eq!(
build_upstream_url(&config, "/v1/chat/completions", None),
"http://localhost:11434/v1/chat/completions"
);
}
// A trailing slash on the base is trimmed and the query string appended.
#[test]
fn test_build_upstream_url_with_query() {
let config = ProxyConfig {
upstream_url: "http://localhost:11434/".to_string(),
..ProxyConfig::default()
};
assert_eq!(
build_upstream_url(&config, "/v1/models", Some("format=json")),
"http://localhost:11434/v1/models?format=json"
);
}
// Prompt text carries `role: content` for each message.
#[test]
fn test_messages_to_prompt_text() {
let msgs = vec![
ChatMessage {
role: "system".to_string(),
content: "You are helpful.".to_string(),
},
ChatMessage {
role: "user".to_string(),
content: "Hello!".to_string(),
},
];
let text = messages_to_prompt_text(&msgs);
assert!(text.contains("system: You are helpful."));
assert!(text.contains("user: Hello!"));
}
// No messages produce an empty prompt.
#[test]
fn test_messages_to_prompt_text_empty() {
let text = messages_to_prompt_text(&[]);
assert!(text.is_empty());
}
// Analysis text contains message contents only, without role prefixes.
#[test]
fn test_messages_to_analysis_text() {
let msgs = vec![
ChatMessage {
role: "system".to_string(),
content: "You are helpful.".to_string(),
},
ChatMessage {
role: "user".to_string(),
content: "Hello!".to_string(),
},
];
let text = messages_to_analysis_text(&msgs);
assert!(text.contains("You are helpful."));
assert!(text.contains("Hello!"));
assert!(
!text.contains("user:"),
"analysis text must not include role prefixes"
);
assert!(
!text.contains("system:"),
"analysis text must not include role prefixes"
);
}
// No messages produce empty analysis text.
#[test]
fn test_messages_to_analysis_text_empty() {
let text = messages_to_analysis_text(&[]);
assert!(text.is_empty());
}
// The error response keeps the status code it was given.
#[test]
fn test_error_response_format() {
let resp = error_response(StatusCode::BAD_GATEWAY, "upstream down");
assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
}
}