pub mod auth;
pub mod storage;
use auth::{Auth, AuthConfig};
use axum::{
Json, Router,
body::Body,
extract::{Path, State},
http::{HeaderMap, HeaderValue, Method, StatusCode, header},
response::Response,
routing::{get, post},
};
use sqlx::{PgPool, Row, SqlitePool, postgres::PgPoolOptions, sqlite::SqlitePoolOptions};
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tower_http::cors::{AllowOrigin, CorsLayer};
use trace_weft_core::{BlobHash, SpanRecord};
use trace_weft_recorder::TraceStore;
/// Database connection pool for whichever backend `db_url` selected at startup.
///
/// Read handlers match on this to choose dialect-specific SQL: `?` placeholders for SQLite,
/// `$n` placeholders for Postgres.
#[derive(Clone, Debug)]
pub enum DbPool {
    /// SQLite database (used for every `db_url` that is not a Postgres URL).
    Sqlite(SqlitePool),
    /// PostgreSQL database reached through a `postgres://` / `postgresql://` URL.
    Postgres(PgPool),
}
/// Shared state handed to every axum handler through the `State` extractor.
///
/// Every field is a pool or an `Arc`, so the derived `Clone` only copies handles.
#[derive(Clone)]
pub struct AppState {
    /// Raw database pool; read handlers run backend-specific SQL against it directly.
    pub pool: DbPool,
    /// Blob storage looked up by `BlobHash` (served by `GET /api/blobs/{hash}`).
    pub blob_store: Arc<dyn trace_weft_core::BlobStore>,
    /// Recorder through which the ingestion handlers persist each span.
    pub trace_store: Arc<dyn TraceStore>,
    /// Optional analytics mirror. `start_server_with_shutdown` sets it when
    /// `TRACE_WEFT_CH_URL` is present; ingestion only logs failures against it.
    pub clickhouse: Option<Arc<storage::analytics::ClickHouseAnalytics>>,
    /// Authentication policy consulted by `authorize` on every API request.
    pub auth: Arc<AuthConfig>,
}
/// Runs the API server with authentication configured by `AuthConfig::from_env()`.
///
/// No shutdown signal is wired up: the server runs until the process exits or an error
/// occurs. See `start_server_with_shutdown` for how `db_url`, `port` and `blob_dir` are used.
pub async fn start_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
    let auth = AuthConfig::from_env();
    let run_forever = std::future::pending::<()>();
    start_server_with_shutdown(db_url, port, blob_dir, auth, run_forever).await
}
/// Runs the API server with authentication configured by
/// `AuthConfig::from_env_local_first()`, intended for local development.
///
/// Like `start_server`, it never initiates a graceful shutdown on its own.
pub async fn start_dev_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
    let auth = AuthConfig::from_env_local_first();
    let run_forever = std::future::pending::<()>();
    start_server_with_shutdown(db_url, port, blob_dir, auth, run_forever).await
}
pub async fn start_server_with_shutdown(
db_url: &str,
port: u16,
blob_dir: PathBuf,
auth: AuthConfig,
shutdown: impl std::future::Future<Output = ()> + Send + 'static,
) -> anyhow::Result<()> {
let pool = if db_url.starts_with("postgres://") || db_url.starts_with("postgresql://") {
let pg_pool = PgPoolOptions::new().connect(db_url).await?;
DbPool::Postgres(pg_pool)
} else {
let url = if db_url.starts_with("sqlite://") {
db_url.to_string()
} else {
if let Some(parent) = std::path::Path::new(db_url).parent() {
tokio::fs::create_dir_all(parent).await?;
}
format!("sqlite://{}?mode=rwc", db_url)
};
let sq_pool = SqlitePoolOptions::new().connect(&url).await?;
DbPool::Sqlite(sq_pool)
};
let blob_store = Arc::new(storage::blob::LocalBlobStore::new(blob_dir));
let trace_store: Arc<dyn TraceStore> = match &pool {
DbPool::Postgres(pg_pool) => {
Arc::new(storage::postgres::PostgresRecorder::from_pool(pg_pool.clone()).await?)
}
DbPool::Sqlite(sq_pool) => {
Arc::new(trace_weft_recorder::sqlite::SqliteRecorder::from_pool(sq_pool.clone()).await?)
}
};
let clickhouse = if let Ok(ch_url) = std::env::var("TRACE_WEFT_CH_URL") {
tracing::info!("Initializing ClickHouse analytics connected to {}", ch_url);
Some(Arc::new(storage::analytics::ClickHouseAnalytics::new(
&ch_url, "default", "", "default",
)))
} else {
None
};
let state = AppState {
pool,
blob_store,
trace_store,
clickhouse,
auth: Arc::new(auth),
};
let app = build_router(state);
let addr = SocketAddr::from(([127, 0, 0, 1], port));
tracing::info!("Server listening on http://{}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app)
.with_graceful_shutdown(shutdown)
.await?;
Ok(())
}
pub fn build_router(state: AppState) -> Router {
Router::new()
.route("/api/traces", get(list_traces))
.route("/api/traces/{trace_id}", get(get_trace))
.route("/api/traces/{trace_id}/events", get(get_trace_events))
.route(
"/api/traces/{trace_id}/replay-plan/{span_id}",
get(get_replay_plan),
)
.route("/api/diff/{trace_a}/{trace_b}", get(get_trace_diff))
.route("/api/blobs/{hash}", get(get_blob))
.route("/api/openapi.json", get(openapi_contract))
.route("/api/evals", get(list_evals))
.route("/api/v1/batch", post(batch_ingest))
.route("/v1/traces", post(otlp_traces_ingest))
.route("/api/replay/config", post(generate_replay_config))
.route("/api/hitl/pending", get(get_pending_approvals))
.route("/api/hitl/resolve", post(resolve_approval))
.layer(local_cors())
.with_state(state)
}
/// CORS policy for the local UI: GET/POST/OPTIONS with `Authorization` and `Content-Type`
/// request headers, limited to origins accepted by `is_allowed_origin`.
fn local_cors() -> CorsLayer {
    let origin_filter = AllowOrigin::predicate(|origin: &HeaderValue, _parts| {
        // A header value that is not visible ASCII cannot be a permitted origin.
        origin.to_str().is_ok_and(is_allowed_origin)
    });
    CorsLayer::new()
        .allow_origin(origin_filter)
        .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE])
        .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
}
/// Returns `true` when `origin` belongs to the local web UI or the Tauri desktop shell.
///
/// Accepted origins:
/// * `tauri://localhost` and `http://tauri.localhost`, exactly;
/// * `http://localhost` and `http://127.0.0.1`, either bare or followed by `:<port>`, where
///   the port is a decimal number in `0..=65535`.
///
/// Everything else is rejected: other schemes, lookalike hosts such as
/// `http://localhost.evil.com`, and malformed ports such as `http://localhost:` or
/// `http://localhost:5173.evil.com`.
fn is_allowed_origin(origin: &str) -> bool {
    if origin == "tauri://localhost" || origin == "http://tauri.localhost" {
        return true;
    }
    ["http://localhost", "http://127.0.0.1"]
        .into_iter()
        .any(|host| match origin.strip_prefix(host) {
            Some("") => true,
            Some(rest) => rest.strip_prefix(':').is_some_and(|port| {
                // `u16::from_str` alone would also accept a leading `+`, so require digits.
                port.bytes().all(|b| b.is_ascii_digit()) && port.parse::<u16>().is_ok()
            }),
            None => false,
        })
}
/// Authenticates the request headers against the configured `AuthConfig`.
///
/// # Errors
/// Returns `401 Unauthorized` when the headers do not authenticate.
fn authorize(state: &AppState, headers: &HeaderMap) -> Result<Auth, StatusCode> {
    match state.auth.authenticate(headers) {
        Some(auth) => Ok(auth),
        None => Err(StatusCode::UNAUTHORIZED),
    }
}
/// `POST /api/v1/batch`: ingests a JSON array of `SpanRecord`s.
///
/// Every span is stamped with the caller's project, overwriting any client-supplied
/// `project_id`. The spans are then written to the trace store in order and mirrored to
/// ClickHouse on a best-effort basis.
///
/// Returns `202 Accepted` on success and `401` when unauthenticated. Returns `500` at the
/// first span the store fails to record; spans recorded before that failure are kept.
async fn batch_ingest(
    headers: HeaderMap,
    State(state): State<AppState>,
    Json(mut spans): Json<Vec<SpanRecord>>,
) -> Result<StatusCode, StatusCode> {
    let auth = authorize(&state, &headers)?;
    let project_id: Option<String> = auth.project().map(str::to_string);
    spans
        .iter_mut()
        .for_each(|span| span.project_id.clone_from(&project_id));
    tracing::info!(
        "Received batch of {} spans for project {:?}",
        spans.len(),
        project_id
    );
    for span in spans.iter() {
        state
            .trace_store
            .record_span(span.clone())
            .await
            .map_err(|e| {
                tracing::error!("Failed to record span: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })?;
    }
    // ClickHouse is an optional analytics mirror: failures are logged, never surfaced.
    if let Some(ch) = &state.clickhouse
        && let Err(e) = ch.ingest_batch(&spans).await
    {
        tracing::warn!("Failed to stream to ClickHouse: {}", e);
    }
    Ok(StatusCode::ACCEPTED)
}
/// `POST /v1/traces`: ingests an OTLP/JSON trace export.
///
/// The payload is converted to `SpanRecord`s, each stamped with the caller's project, then
/// persisted and mirrored to ClickHouse the same way as `POST /api/v1/batch`.
///
/// Responds `400` for a payload that cannot be converted, `401` when unauthenticated, and
/// `500` when the trace store rejects a span. On success the body is the empty JSON object.
async fn otlp_traces_ingest(
    headers: HeaderMap,
    State(state): State<AppState>,
    body: axum::body::Bytes,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let auth = authorize(&state, &headers)?;
    let project_id: Option<String> = auth.project().map(str::to_string);
    let mut spans = match trace_weft_ingest::records_from_otlp_json(&body) {
        Ok(records) => records,
        Err(e) => {
            tracing::warn!("rejecting malformed OTLP payload: {e}");
            return Err(StatusCode::BAD_REQUEST);
        }
    };
    spans
        .iter_mut()
        .for_each(|span| span.project_id.clone_from(&project_id));
    tracing::info!(
        "Received OTLP export of {} spans for project {:?}",
        spans.len(),
        project_id
    );
    for span in spans.iter() {
        state
            .trace_store
            .record_span(span.clone())
            .await
            .map_err(|e| {
                tracing::error!("Failed to record OTLP span: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })?;
    }
    // Best-effort analytics mirror; a failure here does not fail the export.
    if let Some(ch) = &state.clickhouse
        && let Err(e) = ch.ingest_batch(&spans).await
    {
        tracing::warn!("Failed to stream OTLP spans to ClickHouse: {}", e);
    }
    Ok(Json(serde_json::json!({})))
}
/// Logs a storage-layer failure and maps it to `500 Internal Server Error`.
///
/// Shaped for `map_err(db_error)` so any `Display` error can be converted in place.
fn db_error<E>(e: E) -> StatusCode
where
    E: std::fmt::Display,
{
    tracing::error!("database query failed: {e}");
    StatusCode::INTERNAL_SERVER_ERROR
}
/// Parses a text column that stores JSON.
///
/// # Errors
/// Invalid JSON means the stored data is corrupt, so it is logged and reported as
/// `500 Internal Server Error`.
fn parse_json_column(raw: &str) -> Result<serde_json::Value, StatusCode> {
    match serde_json::from_str::<serde_json::Value>(raw) {
        Ok(value) => Ok(value),
        Err(e) => {
            tracing::error!("corrupt JSON in spans column: {e}");
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }
    }
}
fn parse_opt_json_column(raw: Option<String>) -> Result<serde_json::Value, StatusCode> {
match raw {
Some(s) => parse_json_column(&s),
None => Ok(serde_json::Value::Null),
}
}
/// Converts one row produced by `LIST_TRACES_SQL_*` into a trace-summary JSON object.
///
/// Must be expanded inside a function or closure returning `Result<_, StatusCode>`. A column
/// that is missing or cannot be decoded is logged through `db_error` and becomes a `500`,
/// instead of panicking the handler as `Row::get` would. `status` is `"error"` when the row's
/// `has_error` flag is non-zero and `"ok"` otherwise.
macro_rules! trace_summary_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let span_count: i64 = row.try_get("span_count").map_err(db_error)?;
        let has_error: i64 = row.try_get("has_error").map_err(db_error)?;
        let root_name: Option<String> = row.try_get("root_name").map_err(db_error)?;
        let root_span_kind: Option<String> = row.try_get("root_span_kind").map_err(db_error)?;
        let model_provider: Option<String> = row.try_get("model_provider").map_err(db_error)?;
        let model_name: Option<String> = row.try_get("model_name").map_err(db_error)?;
        let error_summary: Option<String> = row.try_get("error_summary").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "run_id": run_id,
            "start_time": start_time,
            "end_time": end_time,
            "span_count": span_count,
            "root_name": root_name,
            "root_span_kind": root_span_kind,
            "model_provider": model_provider,
            "model_name": model_name,
            "error_summary": error_summary,
            "status": if has_error != 0 { "error" } else { "ok" },
        })
    }};
}
/// Converts one row produced by `LIST_EVALS_SQL_*` into an evaluator-span JSON object, with
/// the `attributes` column parsed as JSON.
///
/// Must be expanded inside a function or closure returning `Result<_, StatusCode>`. Missing or
/// undecodable columns and corrupt JSON become a logged `500` rather than a panic.
macro_rules! eval_row_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "name": name,
            "start_time": start_time,
            "status": status,
            "attributes": parse_json_column(&attributes)?,
        })
    }};
}
/// Converts one full `spans` row (`SELECT *`) into the span-detail JSON object.
///
/// JSON-bearing text columns are parsed. Nullable JSON columns (`memory_state`, `input_ref`,
/// `output_ref`, `token_usage`, `cost_estimate`) map SQL `NULL` to JSON `null`.
///
/// Must be expanded inside a function or closure returning `Result<_, StatusCode>`. Missing or
/// undecodable columns and corrupt JSON become a logged `500` rather than a panic.
macro_rules! span_detail_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let parent_span_id: Option<String> = row.try_get("parent_span_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let session_id: Option<String> = row.try_get("session_id").map_err(db_error)?;
        let user_id_hash: Option<String> = row.try_get("user_id_hash").map_err(db_error)?;
        let project_id: Option<String> = row.try_get("project_id").map_err(db_error)?;
        let span_kind: String = row.try_get("span_kind").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        let status_message: Option<String> = row.try_get("status_message").map_err(db_error)?;
        let error_type: Option<String> = row.try_get("error_type").map_err(db_error)?;
        let error_message_redacted: Option<String> =
            row.try_get("error_message_redacted").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        let otel_attributes: String = row.try_get("otel_attributes").map_err(db_error)?;
        let openinference_attributes: String =
            row.try_get("openinference_attributes").map_err(db_error)?;
        let memory_state: Option<String> = row.try_get("memory_state").map_err(db_error)?;
        let latency_ms: Option<i64> = row.try_get("latency_ms").map_err(db_error)?;
        let input_ref: Option<String> = row.try_get("input_ref").map_err(db_error)?;
        let output_ref: Option<String> = row.try_get("output_ref").map_err(db_error)?;
        let prompt_template_id: Option<String> =
            row.try_get("prompt_template_id").map_err(db_error)?;
        let prompt_version: Option<String> = row.try_get("prompt_version").map_err(db_error)?;
        let model_provider: Option<String> = row.try_get("model_provider").map_err(db_error)?;
        let model_name: Option<String> = row.try_get("model_name").map_err(db_error)?;
        let tool_name: Option<String> = row.try_get("tool_name").map_err(db_error)?;
        let tool_schema_hash: Option<String> =
            row.try_get("tool_schema_hash").map_err(db_error)?;
        let retrieval_query_hash: Option<String> =
            row.try_get("retrieval_query_hash").map_err(db_error)?;
        let retrieved_document_refs: String =
            row.try_get("retrieved_document_refs").map_err(db_error)?;
        let token_usage: Option<String> = row.try_get("token_usage").map_err(db_error)?;
        let cost_estimate: Option<String> = row.try_get("cost_estimate").map_err(db_error)?;
        let retry_count: Option<i64> = row.try_get("retry_count").map_err(db_error)?;
        let cache_hit: Option<bool> = row.try_get("cache_hit").map_err(db_error)?;
        let redaction_policy: String = row.try_get("redaction_policy").map_err(db_error)?;
        let schema_version: String = row.try_get("schema_version").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "parent_span_id": parent_span_id,
            "run_id": run_id,
            "session_id": session_id,
            "user_id_hash": user_id_hash,
            "project_id": project_id,
            "span_kind": span_kind,
            "name": name,
            "start_time": start_time,
            "end_time": end_time,
            "status": status,
            "status_message": status_message,
            "error_type": error_type,
            "error_message_redacted": error_message_redacted,
            "attributes": parse_json_column(&attributes)?,
            "otel_attributes": parse_json_column(&otel_attributes)?,
            "openinference_attributes": parse_json_column(&openinference_attributes)?,
            "memory_state": parse_opt_json_column(memory_state)?,
            "latency_ms": latency_ms,
            "input_ref": parse_opt_json_column(input_ref)?,
            "output_ref": parse_opt_json_column(output_ref)?,
            "prompt_template_id": prompt_template_id,
            "prompt_version": prompt_version,
            "model_provider": model_provider,
            "model_name": model_name,
            "tool_name": tool_name,
            "tool_schema_hash": tool_schema_hash,
            "retrieval_query_hash": retrieval_query_hash,
            "retrieved_document_refs": parse_json_column(&retrieved_document_refs)?,
            "token_usage": parse_opt_json_column(token_usage)?,
            "cost_estimate": parse_opt_json_column(cost_estimate)?,
            "retry_count": retry_count,
            "cache_hit": cache_hit,
            "redaction_policy": redaction_policy,
            "schema_version": schema_version,
        })
    }};
}
/// Converts one `events` row (`SELECT e.*`) into the event JSON object, with the
/// `attributes` column parsed as JSON.
///
/// Must be expanded inside a function or closure returning `Result<_, StatusCode>`. Missing or
/// undecodable columns and corrupt JSON become a logged `500` rather than a panic.
macro_rules! event_detail_json {
    ($row:expr) => {{
        let row = $row;
        let event_id: String = row.try_get("event_id").map_err(db_error)?;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let parent_span_id: Option<String> = row.try_get("parent_span_id").map_err(db_error)?;
        let seq: i64 = row.try_get("seq").map_err(db_error)?;
        let event_kind: String = row.try_get("event_kind").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let timestamp: i64 = row.try_get("timestamp").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        let schema_version: String = row.try_get("schema_version").map_err(db_error)?;
        serde_json::json!({
            "event_id": event_id,
            "trace_id": trace_id,
            "run_id": run_id,
            "parent_span_id": parent_span_id,
            "seq": seq,
            "event_kind": event_kind,
            "name": name,
            "timestamp": timestamp,
            "attributes": parse_json_column(&attributes)?,
            "schema_version": schema_version,
        })
    }};
}
/// Trace list (SQLite dialect): one summary row per `trace_id`, ordered by the trace's
/// earliest span `start_time` (newest first) and capped at 50 rows.
///
/// Both `?` placeholders receive the caller's project id; binding `NULL` disables the project
/// filter. `root_name` / `root_span_kind` come from the span without a parent, falling back to
/// the smallest value in the trace. `has_error` is 1 when any span has status `'error'`, and
/// `error_summary` is the smallest redacted error message among those spans.
const LIST_TRACES_SQL_SQLITE: &str = r#"
SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error,
COALESCE(
MIN(CASE WHEN parent_span_id IS NULL THEN name END),
MIN(name)
) AS root_name,
COALESCE(
MIN(CASE WHEN parent_span_id IS NULL THEN span_kind END),
MIN(span_kind)
) AS root_span_kind,
MAX(model_provider) AS model_provider,
MAX(model_name) AS model_name,
MIN(CASE WHEN status = 'error' THEN error_message_redacted END) AS error_summary
FROM spans
WHERE (project_id = ? OR ? IS NULL)
GROUP BY trace_id
ORDER BY start_time DESC
LIMIT 50
"#;
/// Postgres twin of `LIST_TRACES_SQL_SQLITE`: the same columns, ordering and 50-row cap, with
/// the optional project id bound once as `$1` (`NULL` disables the filter).
const LIST_TRACES_SQL_PG: &str = r#"
SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error,
COALESCE(
MIN(CASE WHEN parent_span_id IS NULL THEN name END),
MIN(name)
) AS root_name,
COALESCE(
MIN(CASE WHEN parent_span_id IS NULL THEN span_kind END),
MIN(span_kind)
) AS root_span_kind,
MAX(model_provider) AS model_provider,
MAX(model_name) AS model_name,
MIN(CASE WHEN status = 'error' THEN error_message_redacted END) AS error_summary
FROM spans
WHERE (project_id = $1 OR $1 IS NULL)
GROUP BY trace_id
ORDER BY start_time DESC
LIMIT 50
"#;
/// Evaluator spans (SQLite dialect): rows whose `span_kind` is `'evaluator'` or
/// `'Evaluator'`, newest `start_time` first, capped at 50. Both `?` placeholders take the
/// optional project id; `NULL` disables the project filter.
const LIST_EVALS_SQL_SQLITE: &str = r#"
SELECT trace_id, span_id, name, start_time, status, attributes
FROM spans
WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
AND (project_id = ? OR ? IS NULL)
ORDER BY start_time DESC
LIMIT 50
"#;
/// Postgres twin of `LIST_EVALS_SQL_SQLITE`; `$1` is the optional project id.
const LIST_EVALS_SQL_PG: &str = r#"
SELECT trace_id, span_id, name, start_time, status, attributes
FROM spans
WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
AND (project_id = $1 OR $1 IS NULL)
ORDER BY start_time DESC
LIMIT 50
"#;
/// Every column of every span in one trace, oldest `start_time` first (SQLite dialect).
/// Placeholders: trace id, then the optional project id twice (`NULL` disables the filter).
const GET_TRACE_SQL_SQLITE: &str = "SELECT * FROM spans WHERE trace_id = ? AND (project_id = ? OR ? IS NULL) ORDER BY start_time ASC";
/// Postgres twin of `GET_TRACE_SQL_SQLITE`: `$1` is the trace id, `$2` the optional project id.
const GET_TRACE_SQL_PG: &str = "SELECT * FROM spans WHERE trace_id = $1 AND (project_id = $2 OR $2 IS NULL) ORDER BY start_time ASC";
/// Events of one trace ordered by `timestamp`, then `seq` (SQLite dialect).
///
/// Events carry no project column, so they are returned only when the trace has at least one
/// span passing the project filter. Placeholders: trace id, then the optional project id twice
/// (`NULL` disables the filter).
const GET_TRACE_EVENTS_SQL_SQLITE: &str = r#"
SELECT e.*
FROM events e
WHERE e.trace_id = ?
AND EXISTS (
SELECT 1 FROM spans s
WHERE s.trace_id = e.trace_id
AND (s.project_id = ? OR ? IS NULL)
)
ORDER BY e.timestamp ASC, e.seq ASC
"#;
/// Postgres twin of `GET_TRACE_EVENTS_SQL_SQLITE`: `$1` is the trace id, `$2` the optional
/// project id.
const GET_TRACE_EVENTS_SQL_PG: &str = r#"
SELECT e.*
FROM events e
WHERE e.trace_id = $1
AND EXISTS (
SELECT 1 FROM spans s
WHERE s.trace_id = e.trace_id
AND (s.project_id = $2 OR $2 IS NULL)
)
ORDER BY e.timestamp ASC, e.seq ASC
"#;
/// `GET /api/traces`: up to 50 trace summaries, newest first, limited to the caller's project
/// when the credentials carry one.
async fn list_traces(
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
    let project = authorize(&state, &headers)?.project().map(str::to_string);
    let traces: Vec<serde_json::Value> = match &state.pool {
        DbPool::Postgres(pg) => {
            let rows = sqlx::query(LIST_TRACES_SQL_PG)
                .bind(project)
                .fetch_all(pg)
                .await
                .map_err(db_error)?;
            rows.iter()
                .map(|row| -> Result<serde_json::Value, StatusCode> {
                    Ok(trace_summary_json!(row))
                })
                .collect::<Result<Vec<_>, StatusCode>>()?
        }
        DbPool::Sqlite(sq) => {
            // SQLite `?` placeholders are positional, so the project is bound once per use.
            let rows = sqlx::query(LIST_TRACES_SQL_SQLITE)
                .bind(project.clone())
                .bind(project)
                .fetch_all(sq)
                .await
                .map_err(db_error)?;
            rows.iter()
                .map(|row| -> Result<serde_json::Value, StatusCode> {
                    Ok(trace_summary_json!(row))
                })
                .collect::<Result<Vec<_>, StatusCode>>()?
        }
    };
    Ok(Json(traces))
}
/// `GET /api/evals`: up to 50 evaluator spans, newest first, limited to the caller's project
/// when the credentials carry one.
async fn list_evals(
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
    let project = authorize(&state, &headers)?.project().map(str::to_string);
    let evals: Vec<serde_json::Value> = match &state.pool {
        DbPool::Postgres(pg) => {
            let rows = sqlx::query(LIST_EVALS_SQL_PG)
                .bind(project)
                .fetch_all(pg)
                .await
                .map_err(db_error)?;
            rows.iter()
                .map(|row| -> Result<serde_json::Value, StatusCode> { Ok(eval_row_json!(row)) })
                .collect::<Result<Vec<_>, StatusCode>>()?
        }
        DbPool::Sqlite(sq) => {
            // SQLite `?` placeholders are positional, so the project is bound once per use.
            let rows = sqlx::query(LIST_EVALS_SQL_SQLITE)
                .bind(project.clone())
                .bind(project)
                .fetch_all(sq)
                .await
                .map_err(db_error)?;
            rows.iter()
                .map(|row| -> Result<serde_json::Value, StatusCode> { Ok(eval_row_json!(row)) })
                .collect::<Result<Vec<_>, StatusCode>>()?
        }
    };
    Ok(Json(evals))
}
async fn get_trace(
Path(trace_id): Path<String>,
headers: HeaderMap,
State(state): State<AppState>,
) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
let project = authorize(&state, &headers)?.project().map(str::to_string);
Ok(Json(
trace_rows_for_project(&state, &trace_id, project).await?,
))
}
/// Loads every span of `trace_id` as a span-detail JSON object, oldest `start_time` first.
///
/// `Some(project)` restricts results to spans stored under that project; `None` applies no
/// project filter. A trace with no visible spans yields an empty vector, not an error.
async fn trace_rows_for_project(
    state: &AppState,
    trace_id: &str,
    project: Option<String>,
) -> Result<Vec<serde_json::Value>, StatusCode> {
    let spans: Vec<serde_json::Value> = match &state.pool {
        DbPool::Postgres(pg) => {
            let rows = sqlx::query(GET_TRACE_SQL_PG)
                .bind(trace_id)
                .bind(project)
                .fetch_all(pg)
                .await
                .map_err(db_error)?;
            rows.iter()
                .map(|row| -> Result<serde_json::Value, StatusCode> {
                    Ok(span_detail_json!(row))
                })
                .collect::<Result<Vec<_>, StatusCode>>()?
        }
        DbPool::Sqlite(sq) => {
            let rows = sqlx::query(GET_TRACE_SQL_SQLITE)
                .bind(trace_id)
                .bind(project.clone())
                .bind(project)
                .fetch_all(sq)
                .await
                .map_err(db_error)?;
            rows.iter()
                .map(|row| -> Result<serde_json::Value, StatusCode> {
                    Ok(span_detail_json!(row))
                })
                .collect::<Result<Vec<_>, StatusCode>>()?
        }
    };
    Ok(spans)
}
/// `GET /api/traces/{trace_id}/events`: the trace's events ordered by `timestamp`, then `seq`.
///
/// Events are returned only when the trace has a span visible to the caller's project (or
/// always, when the credentials carry no project).
async fn get_trace_events(
    Path(trace_id): Path<String>,
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
    let project = authorize(&state, &headers)?.project().map(str::to_string);
    let events: Vec<serde_json::Value> = match &state.pool {
        DbPool::Postgres(pg) => {
            let rows = sqlx::query(GET_TRACE_EVENTS_SQL_PG)
                .bind(trace_id)
                .bind(project)
                .fetch_all(pg)
                .await
                .map_err(db_error)?;
            rows.iter()
                .map(|row| -> Result<serde_json::Value, StatusCode> {
                    Ok(event_detail_json!(row))
                })
                .collect::<Result<Vec<_>, StatusCode>>()?
        }
        DbPool::Sqlite(sq) => {
            let rows = sqlx::query(GET_TRACE_EVENTS_SQL_SQLITE)
                .bind(trace_id)
                .bind(project.clone())
                .bind(project)
                .fetch_all(sq)
                .await
                .map_err(db_error)?;
            rows.iter()
                .map(|row| -> Result<serde_json::Value, StatusCode> {
                    Ok(event_detail_json!(row))
                })
                .collect::<Result<Vec<_>, StatusCode>>()?
        }
    };
    Ok(Json(events))
}
async fn openapi_contract() -> Result<Json<serde_json::Value>, StatusCode> {
serde_json::from_str(include_str!("../openapi/trace-weft.openapi.json"))
.map(Json)
.map_err(|e| {
tracing::error!("embedded OpenAPI contract is invalid JSON: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})
}
fn trace_span_key(span: &serde_json::Value) -> String {
format!(
"{}::{}",
span.get("span_kind")
.and_then(serde_json::Value::as_str)
.unwrap_or("unknown"),
span.get("name")
.and_then(serde_json::Value::as_str)
.unwrap_or("unnamed")
)
}
/// Returns `Some(key)` when `a` and `b` differ at `key`, with absence counting as a distinct
/// value; returns `None` when they agree.
fn field_changed(a: &serde_json::Value, b: &serde_json::Value, key: &str) -> Option<String> {
    if a.get(key) == b.get(key) {
        None
    } else {
        Some(key.to_owned())
    }
}
/// Aligns the spans of two traces and classifies every span.
///
/// Spans are matched by `trace_span_key` (`span_kind::name`). Each span in `spans_a` pairs
/// with the earliest not-yet-matched span in `spans_b` that has the same key. Rows are
/// produced as follows:
/// * matched pairs become `"changed"` (listing the differing compared fields) or
///   `"unchanged"`, keyed `pair-<idx_a>-<idx_b>`;
/// * unmatched A spans become `"removed"` (`a-<idx_a>`);
/// * unmatched B spans become `"added"` (`b-<idx_b>`) and are appended after all A rows.
///
/// Returns the rows plus a summary object with `changed`, `added`, `removed` and `matched`
/// counts.
fn diff_rows(
    spans_a: Vec<serde_json::Value>,
    spans_b: Vec<serde_json::Value>,
) -> (Vec<serde_json::Value>, serde_json::Value) {
    /// Span fields whose difference marks a matched pair as `"changed"`.
    const COMPARED_FIELDS: [&str; 8] = [
        "status",
        "latency_ms",
        "attributes",
        "token_usage",
        "cost_estimate",
        "prompt_version",
        "model_name",
        "retrieval_query_hash",
    ];

    // Per key, the B indices still available for matching, in ascending order. A match always
    // takes the earliest unused candidate, so the used candidates of a key are always a prefix
    // of its list. Popping from the front is therefore equivalent to scanning for the first
    // unused index, without O(k^2) work when many spans share a key.
    let mut b_by_key: std::collections::HashMap<String, std::collections::VecDeque<usize>> =
        std::collections::HashMap::new();
    for (idx, span) in spans_b.iter().enumerate() {
        b_by_key
            .entry(trace_span_key(span))
            .or_default()
            .push_back(idx);
    }
    let mut used_b = vec![false; spans_b.len()];

    let mut rows = Vec::new();
    let mut changed = 0usize;
    let mut removed = 0usize;
    let mut matched = 0usize;
    for (idx_a, span_a) in spans_a.iter().enumerate() {
        let candidate = b_by_key
            .get_mut(&trace_span_key(span_a))
            .and_then(|queue| queue.pop_front());
        let Some(idx_b) = candidate else {
            removed += 1;
            rows.push(serde_json::json!({
                "key": format!("a-{idx_a}"),
                "change": "removed",
                "changed_fields": [],
                "a": span_a,
                "b": null,
            }));
            continue;
        };
        used_b[idx_b] = true;
        matched += 1;
        let span_b = &spans_b[idx_b];
        let changed_fields: Vec<String> = COMPARED_FIELDS
            .iter()
            .filter_map(|field| field_changed(span_a, span_b, field))
            .collect();
        let change = if changed_fields.is_empty() {
            "unchanged"
        } else {
            changed += 1;
            "changed"
        };
        rows.push(serde_json::json!({
            "key": format!("pair-{idx_a}-{idx_b}"),
            "change": change,
            "changed_fields": changed_fields,
            "a": span_a,
            "b": span_b,
        }));
    }

    let mut added = 0usize;
    for (idx_b, span_b) in spans_b.iter().enumerate() {
        if used_b[idx_b] {
            continue;
        }
        added += 1;
        rows.push(serde_json::json!({
            "key": format!("b-{idx_b}"),
            "change": "added",
            "changed_fields": [],
            "a": null,
            "b": span_b,
        }));
    }

    (
        rows,
        serde_json::json!({
            "changed": changed,
            "added": added,
            "removed": removed,
            "matched": matched,
        }),
    )
}
/// `GET /api/diff/{trace_a}/{trace_b}`: span-by-span comparison of two traces.
///
/// Both traces are loaded under the caller's project filter and aligned by `diff_rows`. The
/// response echoes both ids alongside the summary counts and the per-span rows.
async fn get_trace_diff(
    Path((trace_a, trace_b)): Path<(String, String)>,
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let project = authorize(&state, &headers)?.project().map(str::to_string);
    let left = trace_rows_for_project(&state, &trace_a, project.clone()).await?;
    let right = trace_rows_for_project(&state, &trace_b, project).await?;
    let (rows, summary) = diff_rows(left, right);
    let body = serde_json::json!({
        "trace_a": trace_a,
        "trace_b": trace_b,
        "summary": summary,
        "rows": rows,
    });
    Ok(Json(body))
}
/// `GET /api/traces/{trace_id}/replay-plan/{span_id}`: a starter replay configuration that
/// mocks the target span with its recorded `output_ref`.
///
/// Responds `404` when the span is not part of the trace, or not visible under the caller's
/// project. The suggested shell command embeds the span name reduced to ASCII alphanumerics
/// (every other character becomes `_`), the same sanitization `POST /api/replay/config`
/// applies. Span names come from ingested data, and embedding them raw would let a crafted
/// name inject shell syntax into a command the user is invited to copy and run.
async fn get_replay_plan(
    Path((trace_id, span_id)): Path<(String, String)>,
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let project = authorize(&state, &headers)?.project().map(str::to_string);
    let spans = trace_rows_for_project(&state, &trace_id, project).await?;
    let Some(target) = spans.iter().find(|span| {
        span.get("span_id")
            .and_then(serde_json::Value::as_str)
            .is_some_and(|id| id == span_id)
    }) else {
        return Err(StatusCode::NOT_FOUND);
    };
    let safe_name: String = target
        .get("name")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("span")
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect();
    let mut mocked_span_ids = serde_json::Map::new();
    mocked_span_ids.insert(
        span_id.clone(),
        target
            .get("output_ref")
            .cloned()
            .unwrap_or(serde_json::Value::Null),
    );
    Ok(Json(serde_json::json!({
        "trace_id": trace_id,
        "target_span": target,
        "config_template": {
            "mocked_spans": {},
            "mocked_span_ids": mocked_span_ids,
            "block_side_effects": true
        },
        "command": format!("TRACE_WEFT_REPLAY_FILE=replay_config_{safe_name}.json cargo run"),
    })))
}
/// `GET /api/blobs/{hash}`: the raw bytes of a stored blob as `application/octet-stream`,
/// sent with `Cache-Control: no-store`.
///
/// Responds `400` when `hash` could not be a content hash (see `is_safe_blob_key`), `404`
/// when the store has no such blob, and `500` when the store or the response builder fails.
async fn get_blob(
    Path(hash): Path<String>,
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Response<Body>, StatusCode> {
    authorize(&state, &headers)?;
    if !is_safe_blob_key(&hash) {
        tracing::warn!("rejecting malformed blob hash {hash:?}");
        return Err(StatusCode::BAD_REQUEST);
    }
    let hash = BlobHash(hash);
    let Some(bytes) = state.blob_store.get_blob(&hash).await.map_err(db_error)? else {
        return Err(StatusCode::NOT_FOUND);
    };
    Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "application/octet-stream")
        .header(header::CACHE_CONTROL, "no-store")
        .body(Body::from(bytes))
        .map_err(|e| {
            tracing::error!("failed to build blob response: {e}");
            StatusCode::INTERNAL_SERVER_ERROR
        })
}

/// Rejects blob keys that no content hash contains: empty strings, path separators, NUL,
/// and `..`.
///
/// axum percent-decodes the path segment (so `%2F` arrives as `/`), and the key is handed to
/// a directory-backed `LocalBlobStore`. Refusing these characters keeps a request from
/// steering the lookup outside the blob directory, whatever the store does internally.
fn is_safe_blob_key(hash: &str) -> bool {
    !hash.is_empty() && !hash.contains(['/', '\\', '\0']) && !hash.contains("..")
}
use serde::{Deserialize, Serialize};
use trace_weft::hitl::HitlResponse;
/// Body of `POST /api/replay/config`.
#[derive(Deserialize)]
struct ReplayConfigRequest {
    /// Span whose output should be mocked during replay.
    span_id: String,
    /// Human-readable span name; sanitized before it is used in the file name and command.
    span_name: String,
    /// Value stored as the span's mocked output in the generated config.
    mocked_output: serde_json::Value,
    /// Copied into the generated config; `true` when the field is omitted.
    #[serde(default = "default_block_side_effects")]
    block_side_effects: bool,
}
/// Response of `POST /api/replay/config`. Field order here is the JSON key order.
#[derive(Serialize)]
struct ReplayConfigResponse {
    /// Suggested file name: `replay_config_<sanitized span name>.json`.
    file_name: String,
    /// Shell command that runs `cargo run` with `TRACE_WEFT_REPLAY_FILE` set to `file_name`.
    command: String,
    /// The generated replay configuration.
    config: trace_weft::ReplayConfig,
}
/// Serde default for `ReplayConfigRequest::block_side_effects`: a request that omits the
/// field gets `true`.
fn default_block_side_effects() -> bool {
    true
}
async fn generate_replay_config(
headers: HeaderMap,
State(state): State<AppState>,
Json(req): Json<ReplayConfigRequest>,
) -> Result<Json<ReplayConfigResponse>, StatusCode> {
authorize(&state, &headers)?;
let mut config = trace_weft::ReplayConfig::default();
config
.mocked_span_ids
.insert(req.span_id.clone(), req.mocked_output);
config.block_side_effects = req.block_side_effects;
let safe_name = req
.span_name
.chars()
.map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
.collect::<String>();
Ok(Json(ReplayConfigResponse {
file_name: format!("replay_config_{safe_name}.json"),
command: format!("TRACE_WEFT_REPLAY_FILE=replay_config_{safe_name}.json cargo run"),
config,
}))
}
/// `GET /api/hitl/pending`: the list returned by `trace_weft::hitl::get_pending_approvals()`
/// for human-in-the-loop review.
async fn get_pending_approvals(
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<Vec<String>>, StatusCode> {
    if let Err(status) = authorize(&state, &headers) {
        return Err(status);
    }
    let pending = trace_weft::hitl::get_pending_approvals();
    Ok(Json(pending))
}
/// Body of `POST /api/hitl/resolve`.
#[derive(Deserialize)]
struct ResolveRequest {
    /// Span whose pending approval is being resolved.
    span_id: String,
    /// `"approve"` approves; the handler treats every other value as a rejection.
    action: String,
    /// Payload passed along with an approval (`{}` when omitted).
    value: Option<serde_json::Value>,
    /// Reason attached to a rejection (`"Rejected by user"` when omitted).
    reason: Option<String>,
}
async fn resolve_approval(
headers: HeaderMap,
State(state): State<AppState>,
Json(req): Json<ResolveRequest>,
) -> Result<StatusCode, StatusCode> {
authorize(&state, &headers)?;
let response = if req.action == "approve" {
HitlResponse::Approved(req.value.unwrap_or(serde_json::json!({})))
} else {
HitlResponse::Rejected(req.reason.unwrap_or_else(|| "Rejected by user".to_string()))
};
if trace_weft::hitl::resolve_approval(&req.span_id, response).is_ok() {
Ok(StatusCode::OK)
} else {
Err(StatusCode::NOT_FOUND)
}
}
#[cfg(test)]
mod cors_tests {
    use super::is_allowed_origin;

    /// Local web-UI origins (bare or with a port) and the Tauri webview origins.
    const LOCAL_ORIGINS: [&str; 7] = [
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:3000",
        "http://localhost",
        "http://127.0.0.1",
        "tauri://localhost",
        "http://tauri.localhost",
    ];

    /// Foreign hosts, the wrong scheme, hosts that merely begin with a local name, and the
    /// opaque `null` origin.
    const FOREIGN_ORIGINS: [&str; 6] = [
        "https://evil.example.com",
        "http://localhost.evil.com",
        "http://127.0.0.1.evil.com",
        "https://localhost:5173",
        "http://evil.com",
        "null",
    ];

    #[test]
    fn allows_local_ui_and_tauri_origins() {
        LOCAL_ORIGINS.iter().for_each(|&origin| {
            assert!(is_allowed_origin(origin), "{origin} should be allowed");
        });
    }

    #[test]
    fn rejects_external_and_lookalike_origins() {
        FOREIGN_ORIGINS.iter().for_each(|&origin| {
            assert!(!is_allowed_origin(origin), "{origin} should be rejected");
        });
    }
}