Skip to main content

trace_weft_server/
lib.rs

1pub mod auth;
2pub mod storage;
3
4use auth::{Auth, AuthConfig};
5use axum::{
6    Json, Router,
7    body::Body,
8    extract::{Path, State},
9    http::{HeaderMap, HeaderValue, Method, StatusCode, header},
10    response::Response,
11    routing::{get, post},
12};
13use sqlx::{PgPool, Row, SqlitePool, postgres::PgPoolOptions, sqlite::SqlitePoolOptions};
14use std::net::SocketAddr;
15use std::path::PathBuf;
16use std::sync::Arc;
17use tower_http::cors::{AllowOrigin, CorsLayer};
18use trace_weft_core::{BlobHash, SpanRecord};
19use trace_weft_recorder::TraceStore;
20
21#[derive(Clone)]
22pub enum DbPool {
23    Sqlite(SqlitePool),
24    Postgres(PgPool),
25}
26
/// Shared application state attached to the router with `with_state` and read
/// by handlers through `State<AppState>`.
///
/// Every field is either a sqlx pool handle or an `Arc`, so cloning the state
/// shares the underlying resources rather than duplicating them.
#[derive(Clone)]
pub struct AppState {
    /// Relational pool used by the read handlers (`list_traces`, `get_trace`,
    /// `list_evals`, `get_trace_events`).
    pub pool: DbPool,
    /// Blob storage; `start_server_with_shutdown` installs a `LocalBlobStore`.
    /// NOTE(review): presumably served by hash at `/api/blobs/{hash}` — confirm
    /// against `get_blob`.
    pub blob_store: Arc<dyn trace_weft_core::BlobStore>,
    /// Write path for ingested spans (`record_span`), backed by the same
    /// database as `pool`.
    pub trace_store: Arc<dyn TraceStore>,
    /// Optional analytics sink; `None` unless `TRACE_WEFT_CH_URL` was set at
    /// startup.
    pub clickhouse: Option<Arc<storage::analytics::ClickHouseAnalytics>>,
    /// API-key / dev-bypass policy consulted by `authorize`.
    pub auth: Arc<AuthConfig>,
}
35
36/// Start the server with the **production-secure** auth default
37/// ([`AuthConfig::from_env`]): unauthenticated requests are rejected unless
38/// `TRACE_WEFT_API_KEYS`/`TRACE_WEFT_DEV_MODE` are configured. Runs until the
39/// process ends.
40pub async fn start_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
41    start_server_with_shutdown(
42        db_url,
43        port,
44        blob_dir,
45        AuthConfig::from_env(),
46        std::future::pending::<()>(),
47    )
48    .await
49}
50
51/// Start a **local-first** dev server: the auth bypass defaults on when no keys
52/// are configured (see [`AuthConfig::from_env_local_first`]), so the local UI
53/// works without keys. Used by `trace-weft dev`.
54pub async fn start_dev_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
55    start_server_with_shutdown(
56        db_url,
57        port,
58        blob_dir,
59        AuthConfig::from_env_local_first(),
60        std::future::pending::<()>(),
61    )
62    .await
63}
64
65/// Start the server with an explicit [`AuthConfig`], stopping gracefully when
66/// `shutdown` resolves. Used by the desktop app to start/stop the embedded
67/// server on demand and to drain it cleanly on app exit.
68pub async fn start_server_with_shutdown(
69    db_url: &str,
70    port: u16,
71    blob_dir: PathBuf,
72    auth: AuthConfig,
73    shutdown: impl std::future::Future<Output = ()> + Send + 'static,
74) -> anyhow::Result<()> {
75    let pool = if db_url.starts_with("postgres://") || db_url.starts_with("postgresql://") {
76        let pg_pool = PgPoolOptions::new().connect(db_url).await?;
77        DbPool::Postgres(pg_pool)
78    } else {
79        // Assume sqlite file path or sqlite:// url
80        let url = if db_url.starts_with("sqlite://") {
81            db_url.to_string()
82        } else {
83            if let Some(parent) = std::path::Path::new(db_url).parent() {
84                tokio::fs::create_dir_all(parent).await?;
85            }
86            format!("sqlite://{}?mode=rwc", db_url)
87        };
88        let sq_pool = SqlitePoolOptions::new().connect(&url).await?;
89        DbPool::Sqlite(sq_pool)
90    };
91
92    let blob_store = Arc::new(storage::blob::LocalBlobStore::new(blob_dir));
93
94    let trace_store: Arc<dyn TraceStore> = match &pool {
95        DbPool::Postgres(pg_pool) => {
96            Arc::new(storage::postgres::PostgresRecorder::from_pool(pg_pool.clone()).await?)
97        }
98        DbPool::Sqlite(sq_pool) => {
99            Arc::new(trace_weft_recorder::sqlite::SqliteRecorder::from_pool(sq_pool.clone()).await?)
100        }
101    };
102
103    // Enterprise Analytics (Stubbed connection if env var is present)
104    let clickhouse = if let Ok(ch_url) = std::env::var("TRACE_WEFT_CH_URL") {
105        tracing::info!("Initializing ClickHouse analytics connected to {}", ch_url);
106        Some(Arc::new(storage::analytics::ClickHouseAnalytics::new(
107            &ch_url, "default", "", "default",
108        )))
109    } else {
110        None
111    };
112
113    let state = AppState {
114        pool,
115        blob_store,
116        trace_store,
117        clickhouse,
118        auth: Arc::new(auth),
119    };
120
121    let app = build_router(state);
122
123    let addr = SocketAddr::from(([127, 0, 0, 1], port));
124    tracing::info!("Server listening on http://{}", addr);
125
126    let listener = tokio::net::TcpListener::bind(addr).await?;
127    axum::serve(listener, app)
128        .with_graceful_shutdown(shutdown)
129        .await?;
130
131    Ok(())
132}
133
134/// Build the TraceWeft API router over the given application state.
135pub fn build_router(state: AppState) -> Router {
136    Router::new()
137        .route("/api/traces", get(list_traces))
138        .route("/api/traces/{trace_id}", get(get_trace))
139        .route("/api/traces/{trace_id}/events", get(get_trace_events))
140        .route(
141            "/api/traces/{trace_id}/replay-plan/{span_id}",
142            get(get_replay_plan),
143        )
144        .route("/api/diff/{trace_a}/{trace_b}", get(get_trace_diff))
145        .route("/api/blobs/{hash}", get(get_blob))
146        .route("/api/openapi.json", get(openapi_contract))
147        .route("/api/evals", get(list_evals))
148        .route("/api/v1/batch", post(batch_ingest))
149        .route("/v1/traces", post(otlp_traces_ingest))
150        .route("/api/replay/config", post(generate_replay_config))
151        .route("/api/hitl/pending", get(get_pending_approvals))
152        .route("/api/hitl/resolve", post(resolve_approval))
153        .layer(local_cors())
154        .with_state(state)
155}
156
157/// CORS for a local-first server: only the local dev UI and the desktop webview
158/// may read API responses. A permissive policy would let any website the user
159/// visits script `127.0.0.1:<port>` and exfiltrate locally-stored prompts and
160/// tool outputs (and, for JSON `POST`s, drive HITL/ingest via CSRF). Restricting
161/// the allowed origins makes the browser block both.
162fn local_cors() -> CorsLayer {
163    CorsLayer::new()
164        .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
165        .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE])
166        .allow_origin(AllowOrigin::predicate(|origin: &HeaderValue, _req| {
167            origin.to_str().map(is_allowed_origin).unwrap_or(false)
168        }))
169}
170
/// Allow the Tauri webview origins and loopback HTTP origins; reject
/// everything else.
///
/// "Loopback" means `http://localhost` or `http://127.0.0.1`, either bare or
/// followed by `:<port>`. Any port is allowed, because the Vite dev server and
/// direct browser access use different ones. The port must be one or more
/// ASCII digits. This rejects malformed origins such as `http://localhost:` or
/// `http://localhost:x.evil`, which a bare `starts_with("…:")` check would
/// have accepted.
fn is_allowed_origin(origin: &str) -> bool {
    const TAURI_ORIGINS: [&str; 2] = ["tauri://localhost", "http://tauri.localhost"];
    const LOOPBACK_BASES: [&str; 2] = ["http://localhost", "http://127.0.0.1"];

    if TAURI_ORIGINS.iter().any(|allowed| *allowed == origin) {
        return true;
    }
    LOOPBACK_BASES
        .iter()
        .any(|base| match origin.strip_prefix(*base) {
            // Exactly the base origin, no explicit port.
            Some("") => true,
            // Must be `:<digits>`; anything else (e.g. `.evil.com`) is rejected.
            Some(rest) => rest.strip_prefix(':').is_some_and(|port| {
                !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit())
            }),
            None => false,
        })
}
181
182/// Resolve the request's API key to a tenant, or `401` when none is valid and
183/// the dev bypass is off.
184fn authorize(state: &AppState, headers: &HeaderMap) -> Result<Auth, StatusCode> {
185    state
186        .auth
187        .authenticate(headers)
188        .ok_or(StatusCode::UNAUTHORIZED)
189}
190
191async fn batch_ingest(
192    headers: HeaderMap,
193    State(state): State<AppState>,
194    Json(mut spans): Json<Vec<SpanRecord>>,
195) -> Result<StatusCode, StatusCode> {
196    let auth = authorize(&state, &headers)?;
197    // The server is authoritative on tenancy: stamp the authenticated project
198    // onto every span so a client cannot assert someone else's project_id.
199    let project_id = auth.project().map(|p| p.to_string());
200    for span in &mut spans {
201        span.project_id = project_id.clone();
202    }
203
204    tracing::info!(
205        "Received batch of {} spans for project {:?}",
206        spans.len(),
207        project_id
208    );
209
210    // 1. Ingest metadata into Postgres
211    for span in &spans {
212        // In a real app, this should be a bulk insert
213        if let Err(e) = state.trace_store.record_span(span.clone()).await {
214            tracing::error!("Failed to record span: {}", e);
215            return Err(StatusCode::INTERNAL_SERVER_ERROR);
216        }
217    }
218
219    // 2. Stream to ClickHouse for analytics
220    if let Some(ch) = &state.clickhouse
221        && let Err(e) = ch.ingest_batch(&spans).await
222    {
223        tracing::warn!("Failed to stream to ClickHouse: {}", e);
224    }
225
226    Ok(StatusCode::ACCEPTED)
227}
228
229/// OTLP/HTTP JSON trace ingestion at `/v1/traces`. Decodes the export request
230/// via `trace-weft-ingest`, then — exactly like [`batch_ingest`] — stamps the
231/// authenticated project onto every span before persisting so a client cannot
232/// assert another tenant's `project_id`. Returns an empty OTLP
233/// `ExportTraceServiceResponse` (`{}`) on success, `400` for a malformed body.
234async fn otlp_traces_ingest(
235    headers: HeaderMap,
236    State(state): State<AppState>,
237    body: axum::body::Bytes,
238) -> Result<Json<serde_json::Value>, StatusCode> {
239    let auth = authorize(&state, &headers)?;
240    let project_id = auth.project().map(|p| p.to_string());
241
242    let mut spans = trace_weft_ingest::records_from_otlp_json(&body).map_err(|e| {
243        tracing::warn!("rejecting malformed OTLP payload: {e}");
244        StatusCode::BAD_REQUEST
245    })?;
246    for span in &mut spans {
247        span.project_id = project_id.clone();
248    }
249
250    tracing::info!(
251        "Received OTLP export of {} spans for project {:?}",
252        spans.len(),
253        project_id
254    );
255
256    for span in &spans {
257        if let Err(e) = state.trace_store.record_span(span.clone()).await {
258            tracing::error!("Failed to record OTLP span: {}", e);
259            return Err(StatusCode::INTERNAL_SERVER_ERROR);
260        }
261    }
262
263    if let Some(ch) = &state.clickhouse
264        && let Err(e) = ch.ingest_batch(&spans).await
265    {
266        tracing::warn!("Failed to stream OTLP spans to ClickHouse: {}", e);
267    }
268
269    Ok(Json(serde_json::json!({})))
270}
271
272/// Log a database error and surface it as a 500. Used by every query handler so
273/// failures are recorded rather than silently flattened to an empty body.
274fn db_error<E: std::fmt::Display>(e: E) -> StatusCode {
275    tracing::error!("database query failed: {e}");
276    StatusCode::INTERNAL_SERVER_ERROR
277}
278
279/// Decode a JSON column we wrote ourselves. A parse failure means the row is
280/// corrupt, so we surface a 500 instead of masking it with an empty object —
281/// silently substituting `{}` would hide data loss from the caller.
282fn parse_json_column(raw: &str) -> Result<serde_json::Value, StatusCode> {
283    serde_json::from_str(raw).map_err(|e| {
284        tracing::error!("corrupt JSON in spans column: {e}");
285        StatusCode::INTERNAL_SERVER_ERROR
286    })
287}
288
289/// Decode a nullable JSON column, preserving SQL `NULL` as JSON `null`.
290fn parse_opt_json_column(raw: Option<String>) -> Result<serde_json::Value, StatusCode> {
291    match raw {
292        Some(s) => parse_json_column(&s),
293        None => Ok(serde_json::Value::Null),
294    }
295}
296
// The SQLite and Postgres `spans` tables share an identical column layout, so a
// single row shape maps to JSON for either backend. These macros expand the
// same extraction against `SqliteRow` or `PgRow`, keeping the two dialects from
// drifting apart. Any `?` inside a macro propagates to the calling handler,
// which must return `Result<_, StatusCode>`.

/// One row of the trace-summary aggregate (see `list_traces`).
///
/// Columns are read with `Row::try_get`. A missing column or a value that
/// cannot be decoded (e.g. an unexpected `NULL`) is logged via `db_error` and
/// surfaces as a `500`, instead of panicking the request task as `Row::get`
/// would.
macro_rules! trace_summary_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let span_count: i64 = row.try_get("span_count").map_err(db_error)?;
        let has_error: i64 = row.try_get("has_error").map_err(db_error)?;
        let root_name: Option<String> = row.try_get("root_name").map_err(db_error)?;
        let root_span_kind: Option<String> = row.try_get("root_span_kind").map_err(db_error)?;
        let model_provider: Option<String> = row.try_get("model_provider").map_err(db_error)?;
        let model_name: Option<String> = row.try_get("model_name").map_err(db_error)?;
        let error_summary: Option<String> = row.try_get("error_summary").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "run_id": run_id,
            "start_time": start_time,
            "end_time": end_time,
            "span_count": span_count,
            "root_name": root_name,
            "root_span_kind": root_span_kind,
            "model_provider": model_provider,
            "model_name": model_name,
            "error_summary": error_summary,
            // A trace is errored if any of its spans errored, otherwise ok.
            "status": if has_error != 0 { "error" } else { "ok" },
        })
    }};
}
333
/// One evaluator span row (see `list_evals`).
///
/// Columns are read with `Row::try_get`. A missing or undecodable column is
/// logged via `db_error` and propagates as a `500` through the calling
/// handler's `?`, instead of panicking the request task as `Row::get` would.
/// `attributes` must hold valid JSON (see `parse_json_column`).
macro_rules! eval_row_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "name": name,
            "start_time": start_time,
            "status": status,
            "attributes": parse_json_column(&attributes)?,
        })
    }};
}
354
/// One full span row (see `get_trace`).
///
/// Columns are read with `Row::try_get`. A missing or undecodable column is
/// logged via `db_error` and propagates as a `500` through the calling
/// handler's `?`, instead of panicking the request task as `Row::get` would.
/// JSON-encoded text columns are decoded with `parse_json_column`. Nullable
/// ones use `parse_opt_json_column`, which maps SQL `NULL` to JSON `null`.
macro_rules! span_detail_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let parent_span_id: Option<String> = row.try_get("parent_span_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let session_id: Option<String> = row.try_get("session_id").map_err(db_error)?;
        let user_id_hash: Option<String> = row.try_get("user_id_hash").map_err(db_error)?;
        let project_id: Option<String> = row.try_get("project_id").map_err(db_error)?;
        let span_kind: String = row.try_get("span_kind").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        let status_message: Option<String> = row.try_get("status_message").map_err(db_error)?;
        let error_type: Option<String> = row.try_get("error_type").map_err(db_error)?;
        let error_message_redacted: Option<String> =
            row.try_get("error_message_redacted").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        let otel_attributes: String = row.try_get("otel_attributes").map_err(db_error)?;
        let openinference_attributes: String =
            row.try_get("openinference_attributes").map_err(db_error)?;
        let memory_state: Option<String> = row.try_get("memory_state").map_err(db_error)?;
        let latency_ms: Option<i64> = row.try_get("latency_ms").map_err(db_error)?;
        let input_ref: Option<String> = row.try_get("input_ref").map_err(db_error)?;
        let output_ref: Option<String> = row.try_get("output_ref").map_err(db_error)?;
        let prompt_template_id: Option<String> =
            row.try_get("prompt_template_id").map_err(db_error)?;
        let prompt_version: Option<String> = row.try_get("prompt_version").map_err(db_error)?;
        let model_provider: Option<String> = row.try_get("model_provider").map_err(db_error)?;
        let model_name: Option<String> = row.try_get("model_name").map_err(db_error)?;
        let tool_name: Option<String> = row.try_get("tool_name").map_err(db_error)?;
        let tool_schema_hash: Option<String> =
            row.try_get("tool_schema_hash").map_err(db_error)?;
        let retrieval_query_hash: Option<String> =
            row.try_get("retrieval_query_hash").map_err(db_error)?;
        let retrieved_document_refs: String =
            row.try_get("retrieved_document_refs").map_err(db_error)?;
        let token_usage: Option<String> = row.try_get("token_usage").map_err(db_error)?;
        let cost_estimate: Option<String> = row.try_get("cost_estimate").map_err(db_error)?;
        let retry_count: Option<i64> = row.try_get("retry_count").map_err(db_error)?;
        let cache_hit: Option<bool> = row.try_get("cache_hit").map_err(db_error)?;
        let redaction_policy: String = row.try_get("redaction_policy").map_err(db_error)?;
        let schema_version: String = row.try_get("schema_version").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "parent_span_id": parent_span_id,
            "run_id": run_id,
            "session_id": session_id,
            "user_id_hash": user_id_hash,
            "project_id": project_id,
            "span_kind": span_kind,
            "name": name,
            "start_time": start_time,
            "end_time": end_time,
            "status": status,
            "status_message": status_message,
            "error_type": error_type,
            "error_message_redacted": error_message_redacted,
            "attributes": parse_json_column(&attributes)?,
            "otel_attributes": parse_json_column(&otel_attributes)?,
            "openinference_attributes": parse_json_column(&openinference_attributes)?,
            "memory_state": parse_opt_json_column(memory_state)?,
            "latency_ms": latency_ms,
            "input_ref": parse_opt_json_column(input_ref)?,
            "output_ref": parse_opt_json_column(output_ref)?,
            "prompt_template_id": prompt_template_id,
            "prompt_version": prompt_version,
            "model_provider": model_provider,
            "model_name": model_name,
            "tool_name": tool_name,
            "tool_schema_hash": tool_schema_hash,
            "retrieval_query_hash": retrieval_query_hash,
            "retrieved_document_refs": parse_json_column(&retrieved_document_refs)?,
            "token_usage": parse_opt_json_column(token_usage)?,
            "cost_estimate": parse_opt_json_column(cost_estimate)?,
            "retry_count": retry_count,
            "cache_hit": cache_hit,
            "redaction_policy": redaction_policy,
            "schema_version": schema_version,
        })
    }};
}
435
/// One event row (see `get_trace_events`).
///
/// Columns are read with `Row::try_get`. A missing or undecodable column is
/// logged via `db_error` and propagates as a `500` through the calling
/// handler's `?`, instead of panicking the request task as `Row::get` would.
/// `attributes` must hold valid JSON (see `parse_json_column`).
macro_rules! event_detail_json {
    ($row:expr) => {{
        let row = $row;
        let event_id: String = row.try_get("event_id").map_err(db_error)?;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let parent_span_id: Option<String> = row.try_get("parent_span_id").map_err(db_error)?;
        let seq: i64 = row.try_get("seq").map_err(db_error)?;
        let event_kind: String = row.try_get("event_kind").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let timestamp: i64 = row.try_get("timestamp").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        let schema_version: String = row.try_get("schema_version").map_err(db_error)?;
        serde_json::json!({
            "event_id": event_id,
            "trace_id": trace_id,
            "run_id": run_id,
            "parent_span_id": parent_span_id,
            "seq": seq,
            "event_kind": event_kind,
            "name": name,
            "timestamp": timestamp,
            "attributes": parse_json_column(&attributes)?,
            "schema_version": schema_version,
        })
    }};
}
464
// Project scoping: each query filters `project_id` against the bound `project`
// value. A real tenant binds its project id. The dev bypass binds SQL `NULL`,
// and the `OR <param> IS NULL` arm then matches every row, so local-first runs
// see all traces. Postgres reuses one `$1`; SQLite repeats the positional `?`,
// so the project value is bound twice there.
//
// The aggregate is portable across both engines:
// - every span of a trace shares a run_id, so MIN(run_id) is deterministic;
// - the error rollup is CAST to BIGINT, so both engines decode it as i64;
// - only grouped/aggregated columns are selected, which Postgres requires
//   (it rejects bare columns under GROUP BY).
//
// The `ORDER BY … LIMIT` carries `trace_id` as a unique tiebreaker after the
// timestamp. Traces with equal start times then come back in a stable order,
// and the `LIMIT 50` cut-off does not pick a different subset from call to
// call.
const LIST_TRACES_SQL_SQLITE: &str = r#"
    SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
           MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
           CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error,
           COALESCE(
             MIN(CASE WHEN parent_span_id IS NULL THEN name END),
             MIN(name)
           ) AS root_name,
           COALESCE(
             MIN(CASE WHEN parent_span_id IS NULL THEN span_kind END),
             MIN(span_kind)
           ) AS root_span_kind,
           MAX(model_provider) AS model_provider,
           MAX(model_name) AS model_name,
           MIN(CASE WHEN status = 'error' THEN error_message_redacted END) AS error_summary
    FROM spans
    WHERE (project_id = ? OR ? IS NULL)
    GROUP BY trace_id
    ORDER BY start_time DESC, trace_id ASC
    LIMIT 50
"#;

const LIST_TRACES_SQL_PG: &str = r#"
    SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
           MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
           CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error,
           COALESCE(
             MIN(CASE WHEN parent_span_id IS NULL THEN name END),
             MIN(name)
           ) AS root_name,
           COALESCE(
             MIN(CASE WHEN parent_span_id IS NULL THEN span_kind END),
             MIN(span_kind)
           ) AS root_span_kind,
           MAX(model_provider) AS model_provider,
           MAX(model_name) AS model_name,
           MIN(CASE WHEN status = 'error' THEN error_message_redacted END) AS error_summary
    FROM spans
    WHERE (project_id = $1 OR $1 IS NULL)
    GROUP BY trace_id
    ORDER BY start_time DESC, trace_id ASC
    LIMIT 50
"#;
518
/// The 50 most recent evaluator spans visible to the bound project (or to all
/// projects when the bound value is `NULL`). Both capitalisations of the
/// evaluator kind are accepted.
///
/// `trace_id, span_id` break start-time ties so the order, and therefore which
/// rows survive the `LIMIT`, is stable across calls.
const LIST_EVALS_SQL_SQLITE: &str = r#"
    SELECT trace_id, span_id, name, start_time, status, attributes
    FROM spans
    WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
      AND (project_id = ? OR ? IS NULL)
    ORDER BY start_time DESC, trace_id ASC, span_id ASC
    LIMIT 50
"#;

/// Postgres twin of `LIST_EVALS_SQL_SQLITE`; `$1` is the project, bound once.
const LIST_EVALS_SQL_PG: &str = r#"
    SELECT trace_id, span_id, name, start_time, status, attributes
    FROM spans
    WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
      AND (project_id = $1 OR $1 IS NULL)
    ORDER BY start_time DESC, trace_id ASC, span_id ASC
    LIMIT 50
"#;
536
/// All spans of one trace within the caller's project scope, oldest first.
///
/// `span_id` breaks start-time ties so the order is stable across calls.
/// `diff_rows` pairs spans positionally, so an unstable order would make diff
/// pairings and row keys change between requests for the same two traces.
/// Bind order (SQLite): trace id, then the project twice.
const GET_TRACE_SQL_SQLITE: &str = "SELECT * FROM spans WHERE trace_id = ? AND (project_id = ? OR ? IS NULL) ORDER BY start_time ASC, span_id ASC";

/// Postgres twin of `GET_TRACE_SQL_SQLITE`: `$1` is the trace id, `$2` the project.
const GET_TRACE_SQL_PG: &str = "SELECT * FROM spans WHERE trace_id = $1 AND (project_id = $2 OR $2 IS NULL) ORDER BY start_time ASC, span_id ASC";
540
/// All events of one trace, ordered by timestamp and then `seq`.
///
/// Events carry no project of their own here. Tenancy is enforced through the
/// `EXISTS` subquery, which requires at least one span of the same trace to be
/// visible to the bound project (or the bound project to be `NULL`, i.e. the
/// dev bypass). Bind order (SQLite): trace id, then the project twice.
const GET_TRACE_EVENTS_SQL_SQLITE: &str = r#"
    SELECT e.*
    FROM events e
    WHERE e.trace_id = ?
      AND EXISTS (
        SELECT 1 FROM spans s
        WHERE s.trace_id = e.trace_id
          AND (s.project_id = ? OR ? IS NULL)
      )
    ORDER BY e.timestamp ASC, e.seq ASC
"#;

/// Postgres twin of `GET_TRACE_EVENTS_SQL_SQLITE`: `$1` is the trace id, `$2`
/// the project.
const GET_TRACE_EVENTS_SQL_PG: &str = r#"
    SELECT e.*
    FROM events e
    WHERE e.trace_id = $1
      AND EXISTS (
        SELECT 1 FROM spans s
        WHERE s.trace_id = e.trace_id
          AND (s.project_id = $2 OR $2 IS NULL)
      )
    ORDER BY e.timestamp ASC, e.seq ASC
"#;
564
565async fn list_traces(
566    headers: HeaderMap,
567    State(state): State<AppState>,
568) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
569    let project = authorize(&state, &headers)?.project().map(str::to_string);
570    let mut traces = Vec::new();
571    match &state.pool {
572        DbPool::Sqlite(pool) => {
573            let rows = sqlx::query(LIST_TRACES_SQL_SQLITE)
574                .bind(project.clone())
575                .bind(project)
576                .fetch_all(pool)
577                .await
578                .map_err(db_error)?;
579            for row in &rows {
580                traces.push(trace_summary_json!(row));
581            }
582        }
583        DbPool::Postgres(pool) => {
584            let rows = sqlx::query(LIST_TRACES_SQL_PG)
585                .bind(project)
586                .fetch_all(pool)
587                .await
588                .map_err(db_error)?;
589            for row in &rows {
590                traces.push(trace_summary_json!(row));
591            }
592        }
593    }
594    Ok(Json(traces))
595}
596
597async fn list_evals(
598    headers: HeaderMap,
599    State(state): State<AppState>,
600) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
601    let project = authorize(&state, &headers)?.project().map(str::to_string);
602    let mut evals = Vec::new();
603    match &state.pool {
604        DbPool::Sqlite(pool) => {
605            let rows = sqlx::query(LIST_EVALS_SQL_SQLITE)
606                .bind(project.clone())
607                .bind(project)
608                .fetch_all(pool)
609                .await
610                .map_err(db_error)?;
611            for row in &rows {
612                evals.push(eval_row_json!(row));
613            }
614        }
615        DbPool::Postgres(pool) => {
616            let rows = sqlx::query(LIST_EVALS_SQL_PG)
617                .bind(project)
618                .fetch_all(pool)
619                .await
620                .map_err(db_error)?;
621            for row in &rows {
622                evals.push(eval_row_json!(row));
623            }
624        }
625    }
626    Ok(Json(evals))
627}
628
629async fn get_trace(
630    Path(trace_id): Path<String>,
631    headers: HeaderMap,
632    State(state): State<AppState>,
633) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
634    let project = authorize(&state, &headers)?.project().map(str::to_string);
635    Ok(Json(
636        trace_rows_for_project(&state, &trace_id, project).await?,
637    ))
638}
639
640async fn trace_rows_for_project(
641    state: &AppState,
642    trace_id: &str,
643    project: Option<String>,
644) -> Result<Vec<serde_json::Value>, StatusCode> {
645    let mut spans = Vec::new();
646    match &state.pool {
647        DbPool::Sqlite(pool) => {
648            let rows = sqlx::query(GET_TRACE_SQL_SQLITE)
649                .bind(trace_id)
650                .bind(project.clone())
651                .bind(project)
652                .fetch_all(pool)
653                .await
654                .map_err(db_error)?;
655            for row in &rows {
656                spans.push(span_detail_json!(row));
657            }
658        }
659        DbPool::Postgres(pool) => {
660            let rows = sqlx::query(GET_TRACE_SQL_PG)
661                .bind(trace_id)
662                .bind(project)
663                .fetch_all(pool)
664                .await
665                .map_err(db_error)?;
666            for row in &rows {
667                spans.push(span_detail_json!(row));
668            }
669        }
670    }
671    Ok(spans)
672}
673
674async fn get_trace_events(
675    Path(trace_id): Path<String>,
676    headers: HeaderMap,
677    State(state): State<AppState>,
678) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
679    let project = authorize(&state, &headers)?.project().map(str::to_string);
680    let mut events = Vec::new();
681    match &state.pool {
682        DbPool::Sqlite(pool) => {
683            let rows = sqlx::query(GET_TRACE_EVENTS_SQL_SQLITE)
684                .bind(trace_id)
685                .bind(project.clone())
686                .bind(project)
687                .fetch_all(pool)
688                .await
689                .map_err(db_error)?;
690            for row in &rows {
691                events.push(event_detail_json!(row));
692            }
693        }
694        DbPool::Postgres(pool) => {
695            let rows = sqlx::query(GET_TRACE_EVENTS_SQL_PG)
696                .bind(trace_id)
697                .bind(project)
698                .fetch_all(pool)
699                .await
700                .map_err(db_error)?;
701            for row in &rows {
702                events.push(event_detail_json!(row));
703            }
704        }
705    }
706    Ok(Json(events))
707}
708
709async fn openapi_contract() -> Result<Json<serde_json::Value>, StatusCode> {
710    serde_json::from_str(include_str!("../openapi/trace-weft.openapi.json"))
711        .map(Json)
712        .map_err(|e| {
713            tracing::error!("embedded OpenAPI contract is invalid JSON: {e}");
714            StatusCode::INTERNAL_SERVER_ERROR
715        })
716}
717
718fn trace_span_key(span: &serde_json::Value) -> String {
719    format!(
720        "{}::{}",
721        span.get("span_kind")
722            .and_then(serde_json::Value::as_str)
723            .unwrap_or("unknown"),
724        span.get("name")
725            .and_then(serde_json::Value::as_str)
726            .unwrap_or("unnamed")
727    )
728}
729
730fn field_changed(a: &serde_json::Value, b: &serde_json::Value, key: &str) -> Option<String> {
731    (a.get(key) != b.get(key)).then(|| key.to_string())
732}
733
734fn diff_rows(
735    spans_a: Vec<serde_json::Value>,
736    spans_b: Vec<serde_json::Value>,
737) -> (Vec<serde_json::Value>, serde_json::Value) {
738    let mut used_b = std::collections::HashSet::new();
739    let mut b_by_key: std::collections::HashMap<String, Vec<usize>> =
740        std::collections::HashMap::new();
741    for (idx, span) in spans_b.iter().enumerate() {
742        b_by_key.entry(trace_span_key(span)).or_default().push(idx);
743    }
744
745    let mut rows = Vec::new();
746    let mut changed = 0usize;
747    let mut removed = 0usize;
748    let mut matched = 0usize;
749
750    for (idx_a, span_a) in spans_a.iter().enumerate() {
751        let key = trace_span_key(span_a);
752        let match_idx = b_by_key
753            .get(&key)
754            .and_then(|candidates| candidates.iter().find(|idx| !used_b.contains(*idx)))
755            .copied();
756
757        if let Some(idx_b) = match_idx {
758            used_b.insert(idx_b);
759            matched += 1;
760            let span_b = &spans_b[idx_b];
761            let changed_fields: Vec<String> = [
762                "status",
763                "latency_ms",
764                "attributes",
765                "token_usage",
766                "cost_estimate",
767                "prompt_version",
768                "model_name",
769                "retrieval_query_hash",
770            ]
771            .iter()
772            .filter_map(|field| field_changed(span_a, span_b, field))
773            .collect();
774            if changed_fields.is_empty() {
775                rows.push(serde_json::json!({
776                    "key": format!("pair-{idx_a}-{idx_b}"),
777                    "change": "unchanged",
778                    "changed_fields": changed_fields,
779                    "a": span_a,
780                    "b": span_b,
781                }));
782            } else {
783                changed += 1;
784                rows.push(serde_json::json!({
785                    "key": format!("pair-{idx_a}-{idx_b}"),
786                    "change": "changed",
787                    "changed_fields": changed_fields,
788                    "a": span_a,
789                    "b": span_b,
790                }));
791            }
792        } else {
793            removed += 1;
794            rows.push(serde_json::json!({
795                "key": format!("a-{idx_a}"),
796                "change": "removed",
797                "changed_fields": [],
798                "a": span_a,
799                "b": null,
800            }));
801        }
802    }
803
804    let mut added = 0usize;
805    for (idx_b, span_b) in spans_b.iter().enumerate() {
806        if !used_b.contains(&idx_b) {
807            added += 1;
808            rows.push(serde_json::json!({
809                "key": format!("b-{idx_b}"),
810                "change": "added",
811                "changed_fields": [],
812                "a": null,
813                "b": span_b,
814            }));
815        }
816    }
817
818    (
819        rows,
820        serde_json::json!({
821            "changed": changed,
822            "added": added,
823            "removed": removed,
824            "matched": matched,
825        }),
826    )
827}
828
829async fn get_trace_diff(
830    Path((trace_a, trace_b)): Path<(String, String)>,
831    headers: HeaderMap,
832    State(state): State<AppState>,
833) -> Result<Json<serde_json::Value>, StatusCode> {
834    let project = authorize(&state, &headers)?.project().map(str::to_string);
835    let spans_a = trace_rows_for_project(&state, &trace_a, project.clone()).await?;
836    let spans_b = trace_rows_for_project(&state, &trace_b, project).await?;
837    let (rows, summary) = diff_rows(spans_a, spans_b);
838    Ok(Json(serde_json::json!({
839        "trace_a": trace_a,
840        "trace_b": trace_b,
841        "summary": summary,
842        "rows": rows,
843    })))
844}
845
846async fn get_replay_plan(
847    Path((trace_id, span_id)): Path<(String, String)>,
848    headers: HeaderMap,
849    State(state): State<AppState>,
850) -> Result<Json<serde_json::Value>, StatusCode> {
851    let project = authorize(&state, &headers)?.project().map(str::to_string);
852    let spans = trace_rows_for_project(&state, &trace_id, project).await?;
853    let Some(target) = spans.iter().find(|span| {
854        span.get("span_id")
855            .and_then(serde_json::Value::as_str)
856            .is_some_and(|id| id == span_id)
857    }) else {
858        return Err(StatusCode::NOT_FOUND);
859    };
860
861    let span_name = target
862        .get("name")
863        .and_then(serde_json::Value::as_str)
864        .unwrap_or("span");
865    let mut mocked_span_ids = serde_json::Map::new();
866    mocked_span_ids.insert(
867        span_id.clone(),
868        target
869            .get("output_ref")
870            .cloned()
871            .unwrap_or(serde_json::Value::Null),
872    );
873    Ok(Json(serde_json::json!({
874        "trace_id": trace_id,
875        "target_span": target,
876        "config_template": {
877            "mocked_spans": {},
878            "mocked_span_ids": mocked_span_ids,
879            "block_side_effects": true
880        },
881        "command": format!("TRACE_WEFT_REPLAY_FILE=replay_config_{span_name}.json cargo run"),
882    })))
883}
884
885async fn get_blob(
886    Path(hash): Path<String>,
887    headers: HeaderMap,
888    State(state): State<AppState>,
889) -> Result<Response<Body>, StatusCode> {
890    authorize(&state, &headers)?;
891    let hash = BlobHash(hash);
892    let Some(bytes) = state.blob_store.get_blob(&hash).await.map_err(db_error)? else {
893        return Err(StatusCode::NOT_FOUND);
894    };
895
896    Response::builder()
897        .status(StatusCode::OK)
898        .header(header::CONTENT_TYPE, "application/octet-stream")
899        .header(header::CACHE_CONTROL, "no-store")
900        .body(Body::from(bytes))
901        .map_err(|e| {
902            tracing::error!("failed to build blob response: {e}");
903            StatusCode::INTERNAL_SERVER_ERROR
904        })
905}
906
907use serde::{Deserialize, Serialize};
908use trace_weft::hitl::HitlResponse;
909
/// Request body for `generate_replay_config`: mock the output of one span in a replay config.
#[derive(Deserialize)]
struct ReplayConfigRequest {
    /// Id of the span to mock; used as the key in the config's `mocked_span_ids`.
    span_id: String,
    /// Span name; only used, after sanitization to ASCII alphanumerics, to build the
    /// suggested config file name and run command.
    span_name: String,
    /// Value stored under `span_id` in the generated config's `mocked_span_ids`.
    mocked_output: serde_json::Value,
    /// Copied into the generated config; `true` when the field is omitted from the request.
    #[serde(default = "default_block_side_effects")]
    block_side_effects: bool,
}
918
/// Response body of `generate_replay_config`.
// Field order is the serialized JSON order; keep it stable.
#[derive(Serialize)]
struct ReplayConfigResponse {
    /// Suggested file name: `replay_config_<sanitized span name>.json`.
    file_name: String,
    /// Shell command running `cargo run` with `TRACE_WEFT_REPLAY_FILE` set to `file_name`.
    command: String,
    /// The generated replay configuration (presumably saved by the client as `file_name`).
    config: trace_weft::ReplayConfig,
}
925
/// Serde default for `ReplayConfigRequest::block_side_effects`.
///
/// Side effects stay blocked unless the request explicitly sends `false`.
fn default_block_side_effects() -> bool {
    const BLOCK_BY_DEFAULT: bool = true;
    BLOCK_BY_DEFAULT
}
929
930async fn generate_replay_config(
931    headers: HeaderMap,
932    State(state): State<AppState>,
933    Json(req): Json<ReplayConfigRequest>,
934) -> Result<Json<ReplayConfigResponse>, StatusCode> {
935    authorize(&state, &headers)?;
936    let mut config = trace_weft::ReplayConfig::default();
937    config
938        .mocked_span_ids
939        .insert(req.span_id.clone(), req.mocked_output);
940    config.block_side_effects = req.block_side_effects;
941
942    let safe_name = req
943        .span_name
944        .chars()
945        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
946        .collect::<String>();
947
948    Ok(Json(ReplayConfigResponse {
949        file_name: format!("replay_config_{safe_name}.json"),
950        command: format!("TRACE_WEFT_REPLAY_FILE=replay_config_{safe_name}.json cargo run"),
951        config,
952    }))
953}
954
955async fn get_pending_approvals(
956    headers: HeaderMap,
957    State(state): State<AppState>,
958) -> Result<Json<Vec<String>>, StatusCode> {
959    authorize(&state, &headers)?;
960    Ok(Json(trace_weft::hitl::get_pending_approvals()))
961}
962
/// Request body for `resolve_approval`.
#[derive(Deserialize)]
struct ResolveRequest {
    /// Span whose pending approval is being resolved.
    span_id: String,
    /// `"approve"` approves; any other value rejects.
    // NOTE(review): stringly typed — a typo such as "Approve" silently rejects.
    action: String,
    /// Approval payload; `{}` is used when approving without one.
    value: Option<serde_json::Value>,
    /// Rejection reason; `"Rejected by user"` is used when rejecting without one.
    reason: Option<String>,
}
970
971async fn resolve_approval(
972    headers: HeaderMap,
973    State(state): State<AppState>,
974    Json(req): Json<ResolveRequest>,
975) -> Result<StatusCode, StatusCode> {
976    authorize(&state, &headers)?;
977    let response = if req.action == "approve" {
978        HitlResponse::Approved(req.value.unwrap_or(serde_json::json!({})))
979    } else {
980        HitlResponse::Rejected(req.reason.unwrap_or_else(|| "Rejected by user".to_string()))
981    };
982
983    if trace_weft::hitl::resolve_approval(&req.span_id, response).is_ok() {
984        Ok(StatusCode::OK)
985    } else {
986        Err(StatusCode::NOT_FOUND)
987    }
988}
989
#[cfg(test)]
mod cors_tests {
    use super::is_allowed_origin;

    /// Loopback origins for the local UI (with and without a port) and the Tauri origins.
    const ALLOWED_ORIGINS: &[&str] = &[
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:3000",
        "http://localhost",
        "http://127.0.0.1",
        "tauri://localhost",
        "http://tauri.localhost",
    ];

    /// Remote hosts, hosts that merely start with a loopback name, HTTPS on localhost, and
    /// the opaque `null` origin.
    const REJECTED_ORIGINS: &[&str] = &[
        "https://evil.example.com",
        "http://localhost.evil.com",
        "http://127.0.0.1.evil.com",
        "https://localhost:5173",
        "http://evil.com",
        "null",
    ];

    #[test]
    fn allows_local_ui_and_tauri_origins() {
        for &origin in ALLOWED_ORIGINS {
            assert!(is_allowed_origin(origin), "{origin} should be allowed");
        }
    }

    #[test]
    fn rejects_external_and_lookalike_origins() {
        for &origin in REJECTED_ORIGINS {
            assert!(!is_allowed_origin(origin), "{origin} should be rejected");
        }
    }
}