use std::time::UNIX_EPOCH;
use awaken_ext_observability::trace_store::{
ReferenceKind, RunSummary, TraceFilter, TraceStoreError,
};
use axum::Json;
use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum::http::{StatusCode, header};
use axum::response::{IntoResponse, Response};
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::app::TraceRoutesState;
use crate::error::ApiError;
/// JSON wire form of a [`RunSummary`], built through the `From<RunSummary>`
/// impl in this file.
///
/// Fields annotated with `skip_serializing_if = "Option::is_none"` are left
/// out of the serialized output entirely when they are `None`; `ended_at`
/// carries no such attribute.
#[derive(Debug, Serialize)]
pub struct RunSummaryWire {
pub run_id: String,
pub agent_id: String,
/// Start time in whole seconds since the Unix epoch.
pub started_at: u64,
/// End time in whole seconds since the Unix epoch, when available.
pub ended_at: Option<u64>,
pub prompt_ids: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub experiment_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub variant_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub final_status: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub judge_score: Option<f32>,
}
impl From<RunSummary> for RunSummaryWire {
    /// Converts a stored [`RunSummary`] into its wire representation.
    ///
    /// Timestamps are flattened to whole seconds since the Unix epoch. A
    /// `started_at` that precedes the epoch is reported as `0`; an `ended_at`
    /// that is absent or precedes the epoch is reported as `None`. Every
    /// other field is moved across unchanged.
    fn from(summary: RunSummary) -> Self {
        let started_at = match summary.started_at.duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            // The wire field is unsigned, so a pre-epoch start collapses to zero.
            Err(_) => 0,
        };
        let ended_at = summary
            .ended_at
            .and_then(|finished| finished.duration_since(UNIX_EPOCH).ok())
            .map(|elapsed| elapsed.as_secs());

        Self {
            run_id: summary.run_id,
            agent_id: summary.agent_id,
            started_at,
            ended_at,
            prompt_ids: summary.prompt_ids,
            experiment_id: summary.experiment_id,
            variant_name: summary.variant_name,
            final_status: summary.final_status,
            judge_score: summary.judge_score,
        }
    }
}
/// JSON response body produced by `list_traces`: the matching runs wrapped
/// under a single `runs` key.
#[derive(Debug, Serialize)]
pub struct ListTracesResponse {
pub runs: Vec<RunSummaryWire>,
}
/// Query-string parameters accepted by `list_traces`.
///
/// Unknown parameters are rejected during deserialization
/// (`deny_unknown_fields`). Every field is optional; apart from `since`, each
/// one is copied unchanged into the `TraceFilter` handed to the trace store.
#[derive(Debug, Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct ListTracesQuery {
pub agent_id: Option<String>,
pub prompt_id: Option<String>,
pub experiment_id: Option<String>,
pub variant_name: Option<String>,
/// Forwarded as the filter's `limit`; `list_traces` rejects `0` as a bad request.
pub limit: Option<usize>,
/// RFC 3339 timestamp; `list_traces` parses it and rejects anything else
/// with a `400 Bad Request`.
pub since: Option<String>,
}
/// Translates a [`TraceStoreError`] into the [`ApiError`] returned to clients.
///
/// - `InvalidRunId` becomes [`ApiError::BadRequest`].
/// - `NotFound` becomes [`ApiError::NotFound`].
/// - Any other variant becomes [`ApiError::Internal`] carrying the error's
///   `Display` text.
fn map_trace_store_error(err: TraceStoreError) -> ApiError {
    match err {
        TraceStoreError::InvalidRunId(bad_id) => {
            ApiError::BadRequest(format!("invalid run id: {bad_id}"))
        }
        TraceStoreError::NotFound { run_id: missing } => {
            ApiError::NotFound(format!("trace not found: {missing}"))
        }
        unexpected => ApiError::Internal(unexpected.to_string()),
    }
}
/// Lists run summaries from the trace store, narrowed by the query filters.
///
/// The caller must pass the admin auth check. On success the response is a
/// JSON-encoded [`ListTracesResponse`].
///
/// # Errors
///
/// - Whatever error `ensure_admin_auth` produces when the check fails.
/// - [`ApiError::BadRequest`] when `limit` is `0`.
/// - Store failures, translated by [`map_trace_store_error`].
///
/// A `since` value that is not a valid RFC 3339 timestamp is not surfaced as
/// an `Err`: it yields an `Ok` response with status `400 Bad Request` and a
/// JSON body holding `error` and `detail` keys.
#[tracing::instrument(skip_all, fields(agent_id = ?params.agent_id))]
pub async fn list_traces(
    State(state): State<TraceRoutesState>,
    headers: axum::http::HeaderMap,
    Query(params): Query<ListTracesQuery>,
) -> Result<Response, ApiError> {
    crate::config_routes::ensure_admin_auth(&state.admin, &headers)?;
    let store = state.trace.trace_store.clone();

    // Parse `since` first so that, when both `since` and `limit` are invalid,
    // the `since` failure is the one reported.
    let parsed_since = params
        .since
        .as_deref()
        .map(chrono::DateTime::parse_from_rfc3339)
        .transpose();
    let since = match parsed_since {
        Ok(timestamp) => timestamp.map(std::time::SystemTime::from),
        Err(err) => {
            let body = json!({
                "error": "invalid `since` query parameter; expected RFC 3339 timestamp",
                "detail": err.to_string(),
            });
            return Ok((StatusCode::BAD_REQUEST, Json(body)).into_response());
        }
    };

    if params.limit == Some(0) {
        return Err(ApiError::BadRequest("`limit` must be >= 1".into()));
    }

    let filter = TraceFilter {
        agent_id: params.agent_id,
        prompt_id: params.prompt_id,
        experiment_id: params.experiment_id,
        variant_name: params.variant_name,
        since,
        limit: params.limit,
    };

    let runs = store
        .list(&filter)
        .map_err(map_trace_store_error)?
        .into_iter()
        .map(RunSummaryWire::from)
        .collect();
    Ok(Json(ListTracesResponse { runs }).into_response())
}
/// Number of trace events `get_trace` returns per page when the request gives
/// no `limit`; `get_trace` also clamps any requested `limit` to this value.
const DEFAULT_TRACE_EVENT_PAGE: usize = 1_000;
/// Query-string parameters accepted by `get_trace` for paging through a run's
/// events.
///
/// Unknown parameters are rejected during deserialization
/// (`deny_unknown_fields`).
#[derive(Debug, Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct GetTraceQuery {
/// Index of the first event to return; `get_trace` uses `0` when absent.
pub offset: Option<usize>,
/// Maximum number of events to return. `get_trace` falls back to
/// `DEFAULT_TRACE_EVENT_PAGE` when absent, clamps larger values to it, and
/// rejects `0` as a bad request.
pub limit: Option<usize>,
}
#[tracing::instrument(skip_all, fields(run_id = %run_id))]
pub async fn get_trace(
State(state): State<TraceRoutesState>,
headers: axum::http::HeaderMap,
Path(run_id): Path<String>,
Query(params): Query<GetTraceQuery>,
) -> Result<Response, ApiError> {
crate::config_routes::ensure_admin_auth(&state.admin, &headers)?;
let store = state.trace.trace_store.clone();
let offset = params.offset.unwrap_or(0);
let raw_limit = params.limit.unwrap_or(DEFAULT_TRACE_EVENT_PAGE);
if raw_limit == 0 {
return Err(ApiError::BadRequest("`limit` must be >= 1".into()));
}
let limit = raw_limit.min(DEFAULT_TRACE_EVENT_PAGE);
let events = store.read(&run_id).map_err(map_trace_store_error)?;
let total = events.len();
let end = offset.saturating_add(limit).min(total);
let page = events.get(offset..end).unwrap_or(&[]);
let mut buf = String::new();
for event in page {
match serde_json::to_string(event) {
Ok(line) => {
buf.push_str(&line);
buf.push('\n');
}
Err(err) => {
tracing::warn!(run_id = %run_id, error = %err, "failed to serialise trace event");
}
}
}
let mut builder = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/x-ndjson")
.header("x-trace-total-events", total.to_string());
if end < total {
builder = builder.header("x-trace-next-offset", end.to_string());
}
let resp = builder
.body(Body::from(buf))
.map_err(|e| ApiError::Internal(format!("response build failed: {e}")))?;
Ok(resp)
}
#[tracing::instrument(skip_all, fields(run_id = %run_id))]
pub async fn pin_trace(
State(state): State<TraceRoutesState>,
headers: axum::http::HeaderMap,
Path(run_id): Path<String>,
) -> Result<Response, ApiError> {
crate::config_routes::ensure_admin_auth(&state.admin, &headers)?;
let store = state.trace.trace_store.clone();
store
.mark_referenced(&run_id, ReferenceKind::OperatorPin)
.map_err(map_trace_store_error)?;
Ok(StatusCode::NO_CONTENT.into_response())
}