trace-weft-server 0.3.5

Axum API and query layer for TraceWeft (SQLite, Postgres, auth, OTLP ingest)
Documentation
use anyhow::Result;
use sqlx::{PgPool, Postgres, postgres::PgArguments, postgres::PgPoolOptions, query::Query};
use trace_weft_core::{EventRecord, SpanRecord};
use trace_weft_recorder::TraceStore;

/// [`TraceStore`] implementation that persists spans and events to Postgres.
pub struct PostgresRecorder {
    /// Connection pool used both for schema creation and for every insert.
    pub pool: PgPool,
}

impl PostgresRecorder {
    pub async fn new(db_url: &str) -> Result<Self> {
        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect(db_url)
            .await?;

        Self::from_pool(pool).await
    }

    /// Wrap an existing pool, creating the schema on first use. The server
    /// constructs the recorder from its own pool (so it shares connection
    /// settings), so schema creation must live here rather than only in
    /// [`new`] — otherwise a fresh Postgres has no tables.
    pub async fn from_pool(pool: PgPool) -> Result<Self> {
        // `raw_sql` runs the whole multi-statement block unprepared; `query`
        // would prepare it and Postgres rejects multiple commands in one
        // prepared statement ("cannot insert multiple commands…").
        sqlx::raw_sql(
            r#"
            CREATE TABLE IF NOT EXISTS spans (
                trace_id TEXT NOT NULL,
                span_id TEXT NOT NULL PRIMARY KEY,
                parent_span_id TEXT,
                run_id TEXT NOT NULL,
                session_id TEXT,
                user_id_hash TEXT,
                span_kind TEXT NOT NULL,
                name TEXT NOT NULL,
                start_time BIGINT NOT NULL,
                end_time BIGINT,
                status TEXT NOT NULL,
                status_message TEXT,
                error_type TEXT,
                error_message_redacted TEXT,
                attributes TEXT NOT NULL,
                otel_attributes TEXT NOT NULL,
                openinference_attributes TEXT NOT NULL,
                memory_state TEXT,
                input_ref TEXT,
                output_ref TEXT,
                prompt_template_id TEXT,
                prompt_version TEXT,
                model_provider TEXT,
                model_name TEXT,
                tool_name TEXT,
                tool_schema_hash TEXT,
                retrieval_query_hash TEXT,
                retrieved_document_refs TEXT NOT NULL,
                token_usage TEXT,
                cost_estimate TEXT,
                latency_ms BIGINT,
                retry_count BIGINT,
                cache_hit BOOLEAN,
                redaction_policy TEXT NOT NULL,
                schema_version TEXT NOT NULL,
                project_id TEXT
            );
            CREATE INDEX IF NOT EXISTS idx_spans_trace_id ON spans(trace_id);
            CREATE INDEX IF NOT EXISTS idx_spans_run_id ON spans(run_id);
            CREATE INDEX IF NOT EXISTS idx_spans_project_id ON spans(project_id);

            CREATE TABLE IF NOT EXISTS events (
                event_id TEXT NOT NULL PRIMARY KEY,
                trace_id TEXT NOT NULL,
                run_id TEXT NOT NULL,
                parent_span_id TEXT,
                seq BIGINT NOT NULL,
                event_kind TEXT NOT NULL,
                name TEXT NOT NULL,
                timestamp BIGINT NOT NULL,
                attributes TEXT NOT NULL,
                schema_version TEXT NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_events_trace_id ON events(trace_id);
            CREATE INDEX IF NOT EXISTS idx_events_parent_span_id ON events(parent_span_id);
            "#,
        )
        .execute(&pool)
        .await?;

        Ok(Self { pool })
    }
}

#[async_trait::async_trait]
impl TraceStore for PostgresRecorder {
    async fn record_span(&self, span: SpanRecord) -> Result<()> {
        let trace_id = span.trace_id.0.to_string();
        let span_id = span.span_id.0.to_string();
        let parent_span_id = span.parent_span_id.map(|id| id.0.to_string());
        let run_id = span.run_id.0.to_string();
        let session_id = span.session_id.map(|id| id.0.to_string());
        let span_kind = serde_json::to_string(&span.span_kind)?
            .trim_matches('"')
            .to_string();
        let status = serde_json::to_string(&span.status)?
            .trim_matches('"')
            .to_string();

        let attributes = serde_json::to_string(&span.attributes)?;
        let otel_attributes = serde_json::to_string(&span.otel_attributes)?;
        let openinference_attributes = serde_json::to_string(&span.openinference_attributes)?;
        let memory_state = span
            .memory_state
            .map(|s| serde_json::to_string(&s).unwrap());

        let input_ref = span.input_ref.map(|r| serde_json::to_string(&r).unwrap());
        let output_ref = span.output_ref.map(|r| serde_json::to_string(&r).unwrap());
        let retrieved_document_refs = serde_json::to_string(&span.retrieved_document_refs)?;
        let token_usage = span.token_usage.map(|u| serde_json::to_string(&u).unwrap());
        let cost_estimate = span
            .cost_estimate
            .map(|c| serde_json::to_string(&c).unwrap());
        let redaction_policy = serde_json::to_string(&span.redaction_policy)?
            .trim_matches('"')
            .to_string();

        let q = sqlx::query(
            r#"
            INSERT INTO spans (
                trace_id, span_id, parent_span_id, run_id, session_id, user_id_hash,
                span_kind, name, start_time, end_time, status, status_message, error_type, error_message_redacted,
                attributes, otel_attributes, openinference_attributes, memory_state,
                input_ref, output_ref, prompt_template_id, prompt_version,
                model_provider, model_name, tool_name, tool_schema_hash, retrieval_query_hash,
                retrieved_document_refs, token_usage, cost_estimate, latency_ms, retry_count, cache_hit,
                redaction_policy, schema_version, project_id
            ) VALUES (
                $1, $2, $3, $4, $5, $6,
                $7, $8, $9, $10, $11, $12, $13, $14,
                $15, $16, $17, $18,
                $19, $20, $21, $22,
                $23, $24, $25, $26, $27,
                $28, $29, $30, $31, $32, $33,
                $34, $35, $36
            )
            -- A span may be recorded twice with the same span_id (e.g. a HITL
            -- breakpoint: first PendingApproval, then Ok once resolved). Upsert
            -- so the resolved state replaces the pending row; `DO NOTHING` would
            -- silently discard it. For ordinary single-write spans the conflict
            -- arm never fires.
            ON CONFLICT (span_id) DO UPDATE SET
                trace_id=EXCLUDED.trace_id, parent_span_id=EXCLUDED.parent_span_id,
                run_id=EXCLUDED.run_id, session_id=EXCLUDED.session_id,
                user_id_hash=EXCLUDED.user_id_hash, span_kind=EXCLUDED.span_kind,
                name=EXCLUDED.name, start_time=EXCLUDED.start_time, end_time=EXCLUDED.end_time,
                status=EXCLUDED.status, status_message=EXCLUDED.status_message,
                error_type=EXCLUDED.error_type, error_message_redacted=EXCLUDED.error_message_redacted,
                attributes=EXCLUDED.attributes, otel_attributes=EXCLUDED.otel_attributes,
                openinference_attributes=EXCLUDED.openinference_attributes,
                memory_state=EXCLUDED.memory_state, input_ref=EXCLUDED.input_ref,
                output_ref=EXCLUDED.output_ref, prompt_template_id=EXCLUDED.prompt_template_id,
                prompt_version=EXCLUDED.prompt_version, model_provider=EXCLUDED.model_provider,
                model_name=EXCLUDED.model_name, tool_name=EXCLUDED.tool_name,
                tool_schema_hash=EXCLUDED.tool_schema_hash,
                retrieval_query_hash=EXCLUDED.retrieval_query_hash,
                retrieved_document_refs=EXCLUDED.retrieved_document_refs,
                token_usage=EXCLUDED.token_usage, cost_estimate=EXCLUDED.cost_estimate,
                latency_ms=EXCLUDED.latency_ms, retry_count=EXCLUDED.retry_count,
                cache_hit=EXCLUDED.cache_hit, redaction_policy=EXCLUDED.redaction_policy,
                schema_version=EXCLUDED.schema_version, project_id=EXCLUDED.project_id
            "#,
        );

        let q: Query<'_, Postgres, PgArguments> = q;
        q.bind(trace_id)
            .bind(span_id)
            .bind(parent_span_id)
            .bind(run_id)
            .bind(session_id)
            .bind(span.user_id_hash)
            .bind(span_kind)
            .bind(span.name)
            .bind(span.start_time as i64)
            .bind(span.end_time.map(|t| t as i64))
            .bind(status)
            .bind(span.status_message)
            .bind(span.error_type)
            .bind(span.error_message_redacted)
            .bind(attributes)
            .bind(otel_attributes)
            .bind(openinference_attributes)
            .bind(memory_state)
            .bind(input_ref)
            .bind(output_ref)
            .bind(span.prompt_template_id)
            .bind(span.prompt_version)
            .bind(span.model_provider)
            .bind(span.model_name)
            .bind(span.tool_name)
            .bind(span.tool_schema_hash)
            .bind(span.retrieval_query_hash)
            .bind(retrieved_document_refs)
            .bind(token_usage)
            .bind(cost_estimate)
            .bind(span.latency_ms.map(|t| t as i64))
            .bind(span.retry_count.map(|c| c as i64))
            .bind(span.cache_hit)
            .bind(redaction_policy)
            .bind(span.schema_version)
            .bind(span.project_id)
            .execute(&self.pool)
            .await?;

        Ok(())
    }

    async fn record_event(&self, event: EventRecord) -> Result<()> {
        let event_kind = serde_json::to_string(&event.event_kind)?
            .trim_matches('"')
            .to_string();
        let attributes = serde_json::to_string(&event.attributes)?;

        let q = sqlx::query(
            r#"
            INSERT INTO events (
                event_id, trace_id, run_id, parent_span_id, seq,
                event_kind, name, timestamp, attributes, schema_version
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
            ON CONFLICT (event_id) DO NOTHING
            "#,
        );

        let q: Query<'_, Postgres, PgArguments> = q;
        q.bind(event.event_id.0.to_string())
            .bind(event.trace_id.0.to_string())
            .bind(event.run_id.0.to_string())
            .bind(event.parent_span_id.map(|id| id.0.to_string()))
            .bind(event.seq as i64)
            .bind(event_kind)
            .bind(event.name)
            .bind(event.timestamp as i64)
            .bind(attributes)
            .bind(event.schema_version)
            .execute(&self.pool)
            .await?;

        Ok(())
    }
}