1pub mod auth;
2pub mod storage;
3
4use auth::{Auth, AuthConfig};
5use axum::{
6 Json, Router,
7 body::Body,
8 extract::{Path, State},
9 http::{HeaderMap, HeaderValue, Method, StatusCode, header},
10 response::Response,
11 routing::{get, post},
12};
13use sqlx::{PgPool, Row, SqlitePool, postgres::PgPoolOptions, sqlite::SqlitePoolOptions};
14use std::net::SocketAddr;
15use std::path::PathBuf;
16use std::sync::Arc;
17use tower_http::cors::{AllowOrigin, CorsLayer};
18use trace_weft_core::{BlobHash, SpanRecord};
19use trace_weft_recorder::TraceStore;
20
21#[derive(Clone)]
22pub enum DbPool {
23 Sqlite(SqlitePool),
24 Postgres(PgPool),
25}
26
27#[derive(Clone)]
28pub struct AppState {
29 pub pool: DbPool,
30 pub blob_store: Arc<dyn trace_weft_core::BlobStore>,
31 pub trace_store: Arc<dyn TraceStore>,
32 pub clickhouse: Option<Arc<storage::analytics::ClickHouseAnalytics>>,
33 pub auth: Arc<AuthConfig>,
34}
35
36pub async fn start_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
41 start_server_with_shutdown(
42 db_url,
43 port,
44 blob_dir,
45 AuthConfig::from_env(),
46 std::future::pending::<()>(),
47 )
48 .await
49}
50
51pub async fn start_dev_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
55 start_server_with_shutdown(
56 db_url,
57 port,
58 blob_dir,
59 AuthConfig::from_env_local_first(),
60 std::future::pending::<()>(),
61 )
62 .await
63}
64
65pub async fn start_server_with_shutdown(
69 db_url: &str,
70 port: u16,
71 blob_dir: PathBuf,
72 auth: AuthConfig,
73 shutdown: impl std::future::Future<Output = ()> + Send + 'static,
74) -> anyhow::Result<()> {
75 let pool = if db_url.starts_with("postgres://") || db_url.starts_with("postgresql://") {
76 let pg_pool = PgPoolOptions::new().connect(db_url).await?;
77 DbPool::Postgres(pg_pool)
78 } else {
79 let url = if db_url.starts_with("sqlite://") {
81 db_url.to_string()
82 } else {
83 if let Some(parent) = std::path::Path::new(db_url).parent() {
84 tokio::fs::create_dir_all(parent).await?;
85 }
86 format!("sqlite://{}?mode=rwc", db_url)
87 };
88 let sq_pool = SqlitePoolOptions::new().connect(&url).await?;
89 DbPool::Sqlite(sq_pool)
90 };
91
92 let blob_store = Arc::new(storage::blob::LocalBlobStore::new(blob_dir));
93
94 let trace_store: Arc<dyn TraceStore> = match &pool {
95 DbPool::Postgres(pg_pool) => {
96 Arc::new(storage::postgres::PostgresRecorder::from_pool(pg_pool.clone()).await?)
97 }
98 DbPool::Sqlite(sq_pool) => {
99 Arc::new(trace_weft_recorder::sqlite::SqliteRecorder::from_pool(sq_pool.clone()).await?)
100 }
101 };
102
103 let clickhouse = if let Ok(ch_url) = std::env::var("TRACE_WEFT_CH_URL") {
105 tracing::info!("Initializing ClickHouse analytics connected to {}", ch_url);
106 Some(Arc::new(storage::analytics::ClickHouseAnalytics::new(
107 &ch_url, "default", "", "default",
108 )))
109 } else {
110 None
111 };
112
113 let state = AppState {
114 pool,
115 blob_store,
116 trace_store,
117 clickhouse,
118 auth: Arc::new(auth),
119 };
120
121 let app = build_router(state);
122
123 let addr = SocketAddr::from(([127, 0, 0, 1], port));
124 tracing::info!("Server listening on http://{}", addr);
125
126 let listener = tokio::net::TcpListener::bind(addr).await?;
127 axum::serve(listener, app)
128 .with_graceful_shutdown(shutdown)
129 .await?;
130
131 Ok(())
132}
133
134pub fn build_router(state: AppState) -> Router {
136 Router::new()
137 .route("/api/traces", get(list_traces))
138 .route("/api/traces/{trace_id}", get(get_trace))
139 .route("/api/traces/{trace_id}/events", get(get_trace_events))
140 .route(
141 "/api/traces/{trace_id}/replay-plan/{span_id}",
142 get(get_replay_plan),
143 )
144 .route("/api/diff/{trace_a}/{trace_b}", get(get_trace_diff))
145 .route("/api/blobs/{hash}", get(get_blob))
146 .route("/api/openapi.json", get(openapi_contract))
147 .route("/api/evals", get(list_evals))
148 .route("/api/v1/batch", post(batch_ingest))
149 .route("/v1/traces", post(otlp_traces_ingest))
150 .route("/api/replay/config", post(generate_replay_config))
151 .route("/api/hitl/pending", get(get_pending_approvals))
152 .route("/api/hitl/resolve", post(resolve_approval))
153 .layer(local_cors())
154 .with_state(state)
155}
156
157fn local_cors() -> CorsLayer {
163 CorsLayer::new()
164 .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
165 .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE])
166 .allow_origin(AllowOrigin::predicate(|origin: &HeaderValue, _req| {
167 origin.to_str().map(is_allowed_origin).unwrap_or(false)
168 }))
169}
170
/// Returns `true` if `origin` (a CORS `Origin` header value) may call this API.
///
/// Accepted origins:
/// - the Tauri desktop shells `tauri://localhost` and `http://tauri.localhost`;
/// - plain-HTTP loopback origins `http://localhost` and `http://127.0.0.1`, either bare
///   or followed by `:PORT`, where `PORT` is a decimal TCP port (0–65535).
///
/// Everything else is rejected: HTTPS variants, other or look-alike hosts such as
/// `localhost.evil.com`, and malformed ports such as `localhost:`, `localhost:abc` or
/// `localhost:1@evil.com`.
fn is_allowed_origin(origin: &str) -> bool {
    if matches!(origin, "tauri://localhost" | "http://tauri.localhost") {
        return true;
    }
    ["http://localhost", "http://127.0.0.1"]
        .iter()
        .any(|&host| match origin.strip_prefix(host) {
            Some("") => true,
            Some(rest) => rest.strip_prefix(':').is_some_and(is_valid_port),
            None => false,
        })
}

/// Returns `true` if `port` is a non-empty run of ASCII digits that fits in a `u16`.
///
/// The digit check runs first because `u16::from_str` would also accept a leading `+`.
fn is_valid_port(port: &str) -> bool {
    !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()) && port.parse::<u16>().is_ok()
}
181
182fn authorize(state: &AppState, headers: &HeaderMap) -> Result<Auth, StatusCode> {
185 state
186 .auth
187 .authenticate(headers)
188 .ok_or(StatusCode::UNAUTHORIZED)
189}
190
191async fn batch_ingest(
192 headers: HeaderMap,
193 State(state): State<AppState>,
194 Json(mut spans): Json<Vec<SpanRecord>>,
195) -> Result<StatusCode, StatusCode> {
196 let auth = authorize(&state, &headers)?;
197 let project_id = auth.project().map(|p| p.to_string());
200 for span in &mut spans {
201 span.project_id = project_id.clone();
202 }
203
204 tracing::info!(
205 "Received batch of {} spans for project {:?}",
206 spans.len(),
207 project_id
208 );
209
210 for span in &spans {
212 if let Err(e) = state.trace_store.record_span(span.clone()).await {
214 tracing::error!("Failed to record span: {}", e);
215 return Err(StatusCode::INTERNAL_SERVER_ERROR);
216 }
217 }
218
219 if let Some(ch) = &state.clickhouse
221 && let Err(e) = ch.ingest_batch(&spans).await
222 {
223 tracing::warn!("Failed to stream to ClickHouse: {}", e);
224 }
225
226 Ok(StatusCode::ACCEPTED)
227}
228
229async fn otlp_traces_ingest(
235 headers: HeaderMap,
236 State(state): State<AppState>,
237 body: axum::body::Bytes,
238) -> Result<Json<serde_json::Value>, StatusCode> {
239 let auth = authorize(&state, &headers)?;
240 let project_id = auth.project().map(|p| p.to_string());
241
242 let mut spans = trace_weft_ingest::records_from_otlp_json(&body).map_err(|e| {
243 tracing::warn!("rejecting malformed OTLP payload: {e}");
244 StatusCode::BAD_REQUEST
245 })?;
246 for span in &mut spans {
247 span.project_id = project_id.clone();
248 }
249
250 tracing::info!(
251 "Received OTLP export of {} spans for project {:?}",
252 spans.len(),
253 project_id
254 );
255
256 for span in &spans {
257 if let Err(e) = state.trace_store.record_span(span.clone()).await {
258 tracing::error!("Failed to record OTLP span: {}", e);
259 return Err(StatusCode::INTERNAL_SERVER_ERROR);
260 }
261 }
262
263 if let Some(ch) = &state.clickhouse
264 && let Err(e) = ch.ingest_batch(&spans).await
265 {
266 tracing::warn!("Failed to stream OTLP spans to ClickHouse: {}", e);
267 }
268
269 Ok(Json(serde_json::json!({})))
270}
271
272fn db_error<E: std::fmt::Display>(e: E) -> StatusCode {
275 tracing::error!("database query failed: {e}");
276 StatusCode::INTERNAL_SERVER_ERROR
277}
278
279fn parse_json_column(raw: &str) -> Result<serde_json::Value, StatusCode> {
283 serde_json::from_str(raw).map_err(|e| {
284 tracing::error!("corrupt JSON in spans column: {e}");
285 StatusCode::INTERNAL_SERVER_ERROR
286 })
287}
288
289fn parse_opt_json_column(raw: Option<String>) -> Result<serde_json::Value, StatusCode> {
291 match raw {
292 Some(s) => parse_json_column(&s),
293 None => Ok(serde_json::Value::Null),
294 }
295}
296
/// Converts one row of `LIST_TRACES_SQL_*` into a trace-summary JSON object.
///
/// The integer `has_error` aggregate is exposed as `"status": "error"` when non-zero
/// and `"ok"` otherwise.
///
/// Columns are read with `Row::try_get`, not `Row::get`, so a missing column or a type
/// mismatch is logged via `db_error` and becomes `500`; `get` would panic inside the
/// handler. The expansion therefore uses `?` and must be invoked where the enclosing
/// function or closure returns `Result<_, StatusCode>`.
macro_rules! trace_summary_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let span_count: i64 = row.try_get("span_count").map_err(db_error)?;
        let has_error: i64 = row.try_get("has_error").map_err(db_error)?;
        let root_name: Option<String> = row.try_get("root_name").map_err(db_error)?;
        let root_span_kind: Option<String> = row.try_get("root_span_kind").map_err(db_error)?;
        let model_provider: Option<String> = row.try_get("model_provider").map_err(db_error)?;
        let model_name: Option<String> = row.try_get("model_name").map_err(db_error)?;
        let error_summary: Option<String> = row.try_get("error_summary").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "run_id": run_id,
            "start_time": start_time,
            "end_time": end_time,
            "span_count": span_count,
            "root_name": root_name,
            "root_span_kind": root_span_kind,
            "model_provider": model_provider,
            "model_name": model_name,
            "error_summary": error_summary,
            "status": if has_error != 0 { "error" } else { "ok" },
        })
    }};
}
333
/// Converts one row of `LIST_EVALS_SQL_*` into an evaluator-span JSON object.
///
/// The `attributes` column is parsed from its stored JSON text.
///
/// Columns are read with `Row::try_get`, so a missing column or a type mismatch is
/// logged via `db_error` and becomes `500` instead of panicking. The expansion uses `?`
/// and must be invoked where the enclosing function or closure returns
/// `Result<_, StatusCode>`.
macro_rules! eval_row_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "name": name,
            "start_time": start_time,
            "status": status,
            "attributes": parse_json_column(&attributes)?,
        })
    }};
}
354
/// Converts one full `spans` row (`SELECT *`) into the span JSON served by the API.
///
/// Columns that store serialized JSON are parsed back into JSON values. Non-null ones:
/// `attributes`, `otel_attributes`, `openinference_attributes`,
/// `retrieved_document_refs`. Nullable ones, which become `null` when absent:
/// `memory_state`, `input_ref`, `output_ref`, `token_usage`, `cost_estimate`.
///
/// Columns are read with `Row::try_get`, so a missing column or a type mismatch (for
/// example after schema drift) is logged via `db_error` and becomes `500` instead of
/// panicking. The expansion uses `?` and must be invoked where the enclosing function or
/// closure returns `Result<_, StatusCode>`.
macro_rules! span_detail_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let parent_span_id: Option<String> = row.try_get("parent_span_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let session_id: Option<String> = row.try_get("session_id").map_err(db_error)?;
        let user_id_hash: Option<String> = row.try_get("user_id_hash").map_err(db_error)?;
        let project_id: Option<String> = row.try_get("project_id").map_err(db_error)?;
        let span_kind: String = row.try_get("span_kind").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        let status_message: Option<String> = row.try_get("status_message").map_err(db_error)?;
        let error_type: Option<String> = row.try_get("error_type").map_err(db_error)?;
        let error_message_redacted: Option<String> =
            row.try_get("error_message_redacted").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        let otel_attributes: String = row.try_get("otel_attributes").map_err(db_error)?;
        let openinference_attributes: String =
            row.try_get("openinference_attributes").map_err(db_error)?;
        let memory_state: Option<String> = row.try_get("memory_state").map_err(db_error)?;
        let latency_ms: Option<i64> = row.try_get("latency_ms").map_err(db_error)?;
        let input_ref: Option<String> = row.try_get("input_ref").map_err(db_error)?;
        let output_ref: Option<String> = row.try_get("output_ref").map_err(db_error)?;
        let prompt_template_id: Option<String> =
            row.try_get("prompt_template_id").map_err(db_error)?;
        let prompt_version: Option<String> = row.try_get("prompt_version").map_err(db_error)?;
        let model_provider: Option<String> = row.try_get("model_provider").map_err(db_error)?;
        let model_name: Option<String> = row.try_get("model_name").map_err(db_error)?;
        let tool_name: Option<String> = row.try_get("tool_name").map_err(db_error)?;
        let tool_schema_hash: Option<String> =
            row.try_get("tool_schema_hash").map_err(db_error)?;
        let retrieval_query_hash: Option<String> =
            row.try_get("retrieval_query_hash").map_err(db_error)?;
        let retrieved_document_refs: String =
            row.try_get("retrieved_document_refs").map_err(db_error)?;
        let token_usage: Option<String> = row.try_get("token_usage").map_err(db_error)?;
        let cost_estimate: Option<String> = row.try_get("cost_estimate").map_err(db_error)?;
        let retry_count: Option<i64> = row.try_get("retry_count").map_err(db_error)?;
        let cache_hit: Option<bool> = row.try_get("cache_hit").map_err(db_error)?;
        let redaction_policy: String = row.try_get("redaction_policy").map_err(db_error)?;
        let schema_version: String = row.try_get("schema_version").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "parent_span_id": parent_span_id,
            "run_id": run_id,
            "session_id": session_id,
            "user_id_hash": user_id_hash,
            "project_id": project_id,
            "span_kind": span_kind,
            "name": name,
            "start_time": start_time,
            "end_time": end_time,
            "status": status,
            "status_message": status_message,
            "error_type": error_type,
            "error_message_redacted": error_message_redacted,
            "attributes": parse_json_column(&attributes)?,
            "otel_attributes": parse_json_column(&otel_attributes)?,
            "openinference_attributes": parse_json_column(&openinference_attributes)?,
            "memory_state": parse_opt_json_column(memory_state)?,
            "latency_ms": latency_ms,
            "input_ref": parse_opt_json_column(input_ref)?,
            "output_ref": parse_opt_json_column(output_ref)?,
            "prompt_template_id": prompt_template_id,
            "prompt_version": prompt_version,
            "model_provider": model_provider,
            "model_name": model_name,
            "tool_name": tool_name,
            "tool_schema_hash": tool_schema_hash,
            "retrieval_query_hash": retrieval_query_hash,
            "retrieved_document_refs": parse_json_column(&retrieved_document_refs)?,
            "token_usage": parse_opt_json_column(token_usage)?,
            "cost_estimate": parse_opt_json_column(cost_estimate)?,
            "retry_count": retry_count,
            "cache_hit": cache_hit,
            "redaction_policy": redaction_policy,
            "schema_version": schema_version,
        })
    }};
}
435
/// Converts one `events` row into the event JSON served by the API.
///
/// The `attributes` column is parsed from its stored JSON text.
///
/// Columns are read with `Row::try_get`, so a missing column or a type mismatch is
/// logged via `db_error` and becomes `500` instead of panicking. The expansion uses `?`
/// and must be invoked where the enclosing function or closure returns
/// `Result<_, StatusCode>`.
macro_rules! event_detail_json {
    ($row:expr) => {{
        let row = $row;
        let event_id: String = row.try_get("event_id").map_err(db_error)?;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let parent_span_id: Option<String> = row.try_get("parent_span_id").map_err(db_error)?;
        let seq: i64 = row.try_get("seq").map_err(db_error)?;
        let event_kind: String = row.try_get("event_kind").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let timestamp: i64 = row.try_get("timestamp").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        let schema_version: String = row.try_get("schema_version").map_err(db_error)?;
        serde_json::json!({
            "event_id": event_id,
            "trace_id": trace_id,
            "run_id": run_id,
            "parent_span_id": parent_span_id,
            "seq": seq,
            "event_kind": event_kind,
            "name": name,
            "timestamp": timestamp,
            "attributes": parse_json_column(&attributes)?,
            "schema_version": schema_version,
        })
    }};
}
464
/// SQLite query behind `GET /api/traces`: one summary row per trace.
///
/// Groups `spans` by `trace_id` and reports, per trace:
/// - the smallest `run_id` and earliest `start_time`, the latest `end_time`, and the span count;
/// - `has_error`: 1 if any span has `status = 'error'`, else 0;
/// - `root_name` / `root_span_kind`: taken from a span with no parent, falling back to the
///   smallest name/kind when every span has a parent;
/// - the maximum `model_provider` / `model_name`;
/// - the smallest redacted error message among failed spans.
///
/// Rows are ordered by each trace's earliest start, most recent first, and capped at 50.
/// The optional project id is bound twice because `?` placeholders are positional; a NULL
/// project disables the project filter.
const LIST_TRACES_SQL_SQLITE: &str = r#"
    SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
           MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
           CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error,
           COALESCE(
               MIN(CASE WHEN parent_span_id IS NULL THEN name END),
               MIN(name)
           ) AS root_name,
           COALESCE(
               MIN(CASE WHEN parent_span_id IS NULL THEN span_kind END),
               MIN(span_kind)
           ) AS root_span_kind,
           MAX(model_provider) AS model_provider,
           MAX(model_name) AS model_name,
           MIN(CASE WHEN status = 'error' THEN error_message_redacted END) AS error_summary
    FROM spans
    WHERE (project_id = ? OR ? IS NULL)
    GROUP BY trace_id
    ORDER BY start_time DESC
    LIMIT 50
"#;

/// Postgres form of `LIST_TRACES_SQL_SQLITE`. The aggregation is identical; the optional
/// project id is bound once as `$1` and referenced twice.
const LIST_TRACES_SQL_PG: &str = r#"
    SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
           MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
           CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error,
           COALESCE(
               MIN(CASE WHEN parent_span_id IS NULL THEN name END),
               MIN(name)
           ) AS root_name,
           COALESCE(
               MIN(CASE WHEN parent_span_id IS NULL THEN span_kind END),
               MIN(span_kind)
           ) AS root_span_kind,
           MAX(model_provider) AS model_provider,
           MAX(model_name) AS model_name,
           MIN(CASE WHEN status = 'error' THEN error_message_redacted END) AS error_summary
    FROM spans
    WHERE (project_id = $1 OR $1 IS NULL)
    GROUP BY trace_id
    ORDER BY start_time DESC
    LIMIT 50
"#;
518
/// SQLite query behind `GET /api/evals`: the 50 most recently started evaluator spans.
///
/// `span_kind` is matched in both spellings (`'evaluator'` / `'Evaluator'`). The
/// optional project id is bound twice (positional `?`); NULL disables the filter.
const LIST_EVALS_SQL_SQLITE: &str = r#"
    SELECT trace_id, span_id, name, start_time, status, attributes
    FROM spans
    WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
      AND (project_id = ? OR ? IS NULL)
    ORDER BY start_time DESC
    LIMIT 50
"#;

/// Postgres form of `LIST_EVALS_SQL_SQLITE`; the project id is bound once as `$1`.
const LIST_EVALS_SQL_PG: &str = r#"
    SELECT trace_id, span_id, name, start_time, status, attributes
    FROM spans
    WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
      AND (project_id = $1 OR $1 IS NULL)
    ORDER BY start_time DESC
    LIMIT 50
"#;
536
/// SQLite query for every span of one trace, oldest first. Binds the trace id, then the
/// optional project id twice (positional `?`); a NULL project disables the filter.
const GET_TRACE_SQL_SQLITE: &str = "SELECT * FROM spans WHERE trace_id = ? AND (project_id = ? OR ? IS NULL) ORDER BY start_time ASC";

/// Postgres form of `GET_TRACE_SQL_SQLITE`: `$1` is the trace id, `$2` the optional project.
const GET_TRACE_SQL_PG: &str = "SELECT * FROM spans WHERE trace_id = $1 AND (project_id = $2 OR $2 IS NULL) ORDER BY start_time ASC";
540
/// SQLite query for all events of one trace, ordered by `timestamp` then `seq`.
///
/// Events carry no project column of their own. Access is instead gated on the trace
/// having at least one span that passes the project filter (a NULL project matches any
/// span). NOTE(review): because the `EXISTS` needs at least one span row, events of a
/// trace with no recorded spans are never returned, even with no project filter.
/// Binds the trace id, then the optional project id twice (positional `?`).
const GET_TRACE_EVENTS_SQL_SQLITE: &str = r#"
    SELECT e.*
    FROM events e
    WHERE e.trace_id = ?
      AND EXISTS (
          SELECT 1 FROM spans s
          WHERE s.trace_id = e.trace_id
            AND (s.project_id = ? OR ? IS NULL)
      )
    ORDER BY e.timestamp ASC, e.seq ASC
"#;

/// Postgres form of `GET_TRACE_EVENTS_SQL_SQLITE`: `$1` is the trace id, `$2` the
/// optional project id.
const GET_TRACE_EVENTS_SQL_PG: &str = r#"
    SELECT e.*
    FROM events e
    WHERE e.trace_id = $1
      AND EXISTS (
          SELECT 1 FROM spans s
          WHERE s.trace_id = e.trace_id
            AND (s.project_id = $2 OR $2 IS NULL)
      )
    ORDER BY e.timestamp ASC, e.seq ASC
"#;
564
565async fn list_traces(
566 headers: HeaderMap,
567 State(state): State<AppState>,
568) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
569 let project = authorize(&state, &headers)?.project().map(str::to_string);
570 let mut traces = Vec::new();
571 match &state.pool {
572 DbPool::Sqlite(pool) => {
573 let rows = sqlx::query(LIST_TRACES_SQL_SQLITE)
574 .bind(project.clone())
575 .bind(project)
576 .fetch_all(pool)
577 .await
578 .map_err(db_error)?;
579 for row in &rows {
580 traces.push(trace_summary_json!(row));
581 }
582 }
583 DbPool::Postgres(pool) => {
584 let rows = sqlx::query(LIST_TRACES_SQL_PG)
585 .bind(project)
586 .fetch_all(pool)
587 .await
588 .map_err(db_error)?;
589 for row in &rows {
590 traces.push(trace_summary_json!(row));
591 }
592 }
593 }
594 Ok(Json(traces))
595}
596
597async fn list_evals(
598 headers: HeaderMap,
599 State(state): State<AppState>,
600) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
601 let project = authorize(&state, &headers)?.project().map(str::to_string);
602 let mut evals = Vec::new();
603 match &state.pool {
604 DbPool::Sqlite(pool) => {
605 let rows = sqlx::query(LIST_EVALS_SQL_SQLITE)
606 .bind(project.clone())
607 .bind(project)
608 .fetch_all(pool)
609 .await
610 .map_err(db_error)?;
611 for row in &rows {
612 evals.push(eval_row_json!(row));
613 }
614 }
615 DbPool::Postgres(pool) => {
616 let rows = sqlx::query(LIST_EVALS_SQL_PG)
617 .bind(project)
618 .fetch_all(pool)
619 .await
620 .map_err(db_error)?;
621 for row in &rows {
622 evals.push(eval_row_json!(row));
623 }
624 }
625 }
626 Ok(Json(evals))
627}
628
629async fn get_trace(
630 Path(trace_id): Path<String>,
631 headers: HeaderMap,
632 State(state): State<AppState>,
633) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
634 let project = authorize(&state, &headers)?.project().map(str::to_string);
635 Ok(Json(
636 trace_rows_for_project(&state, &trace_id, project).await?,
637 ))
638}
639
640async fn trace_rows_for_project(
641 state: &AppState,
642 trace_id: &str,
643 project: Option<String>,
644) -> Result<Vec<serde_json::Value>, StatusCode> {
645 let mut spans = Vec::new();
646 match &state.pool {
647 DbPool::Sqlite(pool) => {
648 let rows = sqlx::query(GET_TRACE_SQL_SQLITE)
649 .bind(trace_id)
650 .bind(project.clone())
651 .bind(project)
652 .fetch_all(pool)
653 .await
654 .map_err(db_error)?;
655 for row in &rows {
656 spans.push(span_detail_json!(row));
657 }
658 }
659 DbPool::Postgres(pool) => {
660 let rows = sqlx::query(GET_TRACE_SQL_PG)
661 .bind(trace_id)
662 .bind(project)
663 .fetch_all(pool)
664 .await
665 .map_err(db_error)?;
666 for row in &rows {
667 spans.push(span_detail_json!(row));
668 }
669 }
670 }
671 Ok(spans)
672}
673
674async fn get_trace_events(
675 Path(trace_id): Path<String>,
676 headers: HeaderMap,
677 State(state): State<AppState>,
678) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
679 let project = authorize(&state, &headers)?.project().map(str::to_string);
680 let mut events = Vec::new();
681 match &state.pool {
682 DbPool::Sqlite(pool) => {
683 let rows = sqlx::query(GET_TRACE_EVENTS_SQL_SQLITE)
684 .bind(trace_id)
685 .bind(project.clone())
686 .bind(project)
687 .fetch_all(pool)
688 .await
689 .map_err(db_error)?;
690 for row in &rows {
691 events.push(event_detail_json!(row));
692 }
693 }
694 DbPool::Postgres(pool) => {
695 let rows = sqlx::query(GET_TRACE_EVENTS_SQL_PG)
696 .bind(trace_id)
697 .bind(project)
698 .fetch_all(pool)
699 .await
700 .map_err(db_error)?;
701 for row in &rows {
702 events.push(event_detail_json!(row));
703 }
704 }
705 }
706 Ok(Json(events))
707}
708
709async fn openapi_contract() -> Result<Json<serde_json::Value>, StatusCode> {
710 serde_json::from_str(include_str!("../openapi/trace-weft.openapi.json"))
711 .map(Json)
712 .map_err(|e| {
713 tracing::error!("embedded OpenAPI contract is invalid JSON: {e}");
714 StatusCode::INTERNAL_SERVER_ERROR
715 })
716}
717
718fn trace_span_key(span: &serde_json::Value) -> String {
719 format!(
720 "{}::{}",
721 span.get("span_kind")
722 .and_then(serde_json::Value::as_str)
723 .unwrap_or("unknown"),
724 span.get("name")
725 .and_then(serde_json::Value::as_str)
726 .unwrap_or("unnamed")
727 )
728}
729
730fn field_changed(a: &serde_json::Value, b: &serde_json::Value, key: &str) -> Option<String> {
731 (a.get(key) != b.get(key)).then(|| key.to_string())
732}
733
734fn diff_rows(
735 spans_a: Vec<serde_json::Value>,
736 spans_b: Vec<serde_json::Value>,
737) -> (Vec<serde_json::Value>, serde_json::Value) {
738 let mut used_b = std::collections::HashSet::new();
739 let mut b_by_key: std::collections::HashMap<String, Vec<usize>> =
740 std::collections::HashMap::new();
741 for (idx, span) in spans_b.iter().enumerate() {
742 b_by_key.entry(trace_span_key(span)).or_default().push(idx);
743 }
744
745 let mut rows = Vec::new();
746 let mut changed = 0usize;
747 let mut removed = 0usize;
748 let mut matched = 0usize;
749
750 for (idx_a, span_a) in spans_a.iter().enumerate() {
751 let key = trace_span_key(span_a);
752 let match_idx = b_by_key
753 .get(&key)
754 .and_then(|candidates| candidates.iter().find(|idx| !used_b.contains(*idx)))
755 .copied();
756
757 if let Some(idx_b) = match_idx {
758 used_b.insert(idx_b);
759 matched += 1;
760 let span_b = &spans_b[idx_b];
761 let changed_fields: Vec<String> = [
762 "status",
763 "latency_ms",
764 "attributes",
765 "token_usage",
766 "cost_estimate",
767 "prompt_version",
768 "model_name",
769 "retrieval_query_hash",
770 ]
771 .iter()
772 .filter_map(|field| field_changed(span_a, span_b, field))
773 .collect();
774 if changed_fields.is_empty() {
775 rows.push(serde_json::json!({
776 "key": format!("pair-{idx_a}-{idx_b}"),
777 "change": "unchanged",
778 "changed_fields": changed_fields,
779 "a": span_a,
780 "b": span_b,
781 }));
782 } else {
783 changed += 1;
784 rows.push(serde_json::json!({
785 "key": format!("pair-{idx_a}-{idx_b}"),
786 "change": "changed",
787 "changed_fields": changed_fields,
788 "a": span_a,
789 "b": span_b,
790 }));
791 }
792 } else {
793 removed += 1;
794 rows.push(serde_json::json!({
795 "key": format!("a-{idx_a}"),
796 "change": "removed",
797 "changed_fields": [],
798 "a": span_a,
799 "b": null,
800 }));
801 }
802 }
803
804 let mut added = 0usize;
805 for (idx_b, span_b) in spans_b.iter().enumerate() {
806 if !used_b.contains(&idx_b) {
807 added += 1;
808 rows.push(serde_json::json!({
809 "key": format!("b-{idx_b}"),
810 "change": "added",
811 "changed_fields": [],
812 "a": null,
813 "b": span_b,
814 }));
815 }
816 }
817
818 (
819 rows,
820 serde_json::json!({
821 "changed": changed,
822 "added": added,
823 "removed": removed,
824 "matched": matched,
825 }),
826 )
827}
828
829async fn get_trace_diff(
830 Path((trace_a, trace_b)): Path<(String, String)>,
831 headers: HeaderMap,
832 State(state): State<AppState>,
833) -> Result<Json<serde_json::Value>, StatusCode> {
834 let project = authorize(&state, &headers)?.project().map(str::to_string);
835 let spans_a = trace_rows_for_project(&state, &trace_a, project.clone()).await?;
836 let spans_b = trace_rows_for_project(&state, &trace_b, project).await?;
837 let (rows, summary) = diff_rows(spans_a, spans_b);
838 Ok(Json(serde_json::json!({
839 "trace_a": trace_a,
840 "trace_b": trace_b,
841 "summary": summary,
842 "rows": rows,
843 })))
844}
845
846async fn get_replay_plan(
847 Path((trace_id, span_id)): Path<(String, String)>,
848 headers: HeaderMap,
849 State(state): State<AppState>,
850) -> Result<Json<serde_json::Value>, StatusCode> {
851 let project = authorize(&state, &headers)?.project().map(str::to_string);
852 let spans = trace_rows_for_project(&state, &trace_id, project).await?;
853 let Some(target) = spans.iter().find(|span| {
854 span.get("span_id")
855 .and_then(serde_json::Value::as_str)
856 .is_some_and(|id| id == span_id)
857 }) else {
858 return Err(StatusCode::NOT_FOUND);
859 };
860
861 let span_name = target
862 .get("name")
863 .and_then(serde_json::Value::as_str)
864 .unwrap_or("span");
865 let mut mocked_span_ids = serde_json::Map::new();
866 mocked_span_ids.insert(
867 span_id.clone(),
868 target
869 .get("output_ref")
870 .cloned()
871 .unwrap_or(serde_json::Value::Null),
872 );
873 Ok(Json(serde_json::json!({
874 "trace_id": trace_id,
875 "target_span": target,
876 "config_template": {
877 "mocked_spans": {},
878 "mocked_span_ids": mocked_span_ids,
879 "block_side_effects": true
880 },
881 "command": format!("TRACE_WEFT_REPLAY_FILE=replay_config_{span_name}.json cargo run"),
882 })))
883}
884
885async fn get_blob(
886 Path(hash): Path<String>,
887 headers: HeaderMap,
888 State(state): State<AppState>,
889) -> Result<Response<Body>, StatusCode> {
890 authorize(&state, &headers)?;
891 let hash = BlobHash(hash);
892 let Some(bytes) = state.blob_store.get_blob(&hash).await.map_err(db_error)? else {
893 return Err(StatusCode::NOT_FOUND);
894 };
895
896 Response::builder()
897 .status(StatusCode::OK)
898 .header(header::CONTENT_TYPE, "application/octet-stream")
899 .header(header::CACHE_CONTROL, "no-store")
900 .body(Body::from(bytes))
901 .map_err(|e| {
902 tracing::error!("failed to build blob response: {e}");
903 StatusCode::INTERNAL_SERVER_ERROR
904 })
905}
906
907use serde::{Deserialize, Serialize};
908use trace_weft::hitl::HitlResponse;
909
910#[derive(Deserialize)]
911struct ReplayConfigRequest {
912 span_id: String,
913 span_name: String,
914 mocked_output: serde_json::Value,
915 #[serde(default = "default_block_side_effects")]
916 block_side_effects: bool,
917}
918
919#[derive(Serialize)]
920struct ReplayConfigResponse {
921 file_name: String,
922 command: String,
923 config: trace_weft::ReplayConfig,
924}
925
/// Serde default for `ReplayConfigRequest::block_side_effects`.
///
/// Side effects are blocked unless the request explicitly opts out.
fn default_block_side_effects() -> bool {
    let blocked_by_default = true;
    blocked_by_default
}
929
930async fn generate_replay_config(
931 headers: HeaderMap,
932 State(state): State<AppState>,
933 Json(req): Json<ReplayConfigRequest>,
934) -> Result<Json<ReplayConfigResponse>, StatusCode> {
935 authorize(&state, &headers)?;
936 let mut config = trace_weft::ReplayConfig::default();
937 config
938 .mocked_span_ids
939 .insert(req.span_id.clone(), req.mocked_output);
940 config.block_side_effects = req.block_side_effects;
941
942 let safe_name = req
943 .span_name
944 .chars()
945 .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
946 .collect::<String>();
947
948 Ok(Json(ReplayConfigResponse {
949 file_name: format!("replay_config_{safe_name}.json"),
950 command: format!("TRACE_WEFT_REPLAY_FILE=replay_config_{safe_name}.json cargo run"),
951 config,
952 }))
953}
954
955async fn get_pending_approvals(
956 headers: HeaderMap,
957 State(state): State<AppState>,
958) -> Result<Json<Vec<String>>, StatusCode> {
959 authorize(&state, &headers)?;
960 Ok(Json(trace_weft::hitl::get_pending_approvals()))
961}
962
963#[derive(Deserialize)]
964struct ResolveRequest {
965 span_id: String,
966 action: String,
967 value: Option<serde_json::Value>,
968 reason: Option<String>,
969}
970
971async fn resolve_approval(
972 headers: HeaderMap,
973 State(state): State<AppState>,
974 Json(req): Json<ResolveRequest>,
975) -> Result<StatusCode, StatusCode> {
976 authorize(&state, &headers)?;
977 let response = if req.action == "approve" {
978 HitlResponse::Approved(req.value.unwrap_or(serde_json::json!({})))
979 } else {
980 HitlResponse::Rejected(req.reason.unwrap_or_else(|| "Rejected by user".to_string()))
981 };
982
983 if trace_weft::hitl::resolve_approval(&req.span_id, response).is_ok() {
984 Ok(StatusCode::OK)
985 } else {
986 Err(StatusCode::NOT_FOUND)
987 }
988}
989
#[cfg(test)]
mod cors_tests {
    use super::is_allowed_origin;

    /// Origins sent by local dev servers and the Tauri desktop shell.
    const ALLOWED: [&str; 7] = [
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:3000",
        "http://localhost",
        "http://127.0.0.1",
        "tauri://localhost",
        "http://tauri.localhost",
    ];

    /// Remote hosts, look-alike hostnames, HTTPS variants and the opaque `null` origin.
    const REJECTED: [&str; 6] = [
        "https://evil.example.com",
        "http://localhost.evil.com",
        "http://127.0.0.1.evil.com",
        "https://localhost:5173",
        "http://evil.com",
        "null",
    ];

    #[test]
    fn allows_local_ui_and_tauri_origins() {
        ALLOWED.iter().for_each(|origin| {
            assert!(is_allowed_origin(origin), "{origin} should be allowed");
        });
    }

    #[test]
    fn rejects_external_and_lookalike_origins() {
        REJECTED.iter().for_each(|origin| {
            assert!(!is_allowed_origin(origin), "{origin} should be rejected");
        });
    }
}