systemprompt-logging 0.2.0

Core logging module for systemprompt.io OS
Documentation
use anyhow::Result;
use sqlx::PgPool;
use std::sync::Arc;

use systemprompt_identifiers::AiRequestId;

use super::models::{AuditLookupResult, AuditToolCallRow, ConversationMessage, LinkedMcpCall};

/// One row of `ai_requests` as selected by the audit lookup queries.
///
/// Columns that the queries force non-null (`"col!"`) are plain values;
/// the rest stay `Option`.
struct AuditRow {
    // Identifiers.
    id: String,
    task_id: Option<String>,
    trace_id: Option<String>,
    // Which provider/model served the request.
    provider: String,
    model: String,
    // Usage and cost figures.
    input_tokens: Option<i32>,
    output_tokens: Option<i32>,
    cost_microdollars: i64,
    latency_ms: Option<i32>,
}

/// One row of `ai_request_messages` as selected by `list_audit_messages`.
struct MsgRow {
    sequence_number: i32,
    role: String,
    content: String,
}

/// One row of `ai_request_tool_calls` as selected by `list_audit_tool_calls`.
struct ToolCallDbRow {
    sequence_number: i32,
    tool_name: String,
    tool_input: String,
}

/// One row of `mcp_tool_executions` (joined through `ai_request_tool_calls`)
/// as selected by `list_linked_mcp_calls`.
struct LinkedMcpDbRow {
    server_name: String,
    tool_name: String,
    status: String,
    // Nullable in the query (no `!` override).
    execution_time_ms: Option<i32>,
}

/// Resolves `id` (a full identifier or a prefix of one) to a single AI request
/// for auditing.
///
/// The value is tried, in order, as a request id, a task id and a trace id;
/// the first lookup that finds a row wins.
///
/// An empty `id` returns `Ok(None)`. Without this check the prefix pattern
/// would be `"%"`, which matches every row and would yield an arbitrary request.
///
/// # Errors
/// Returns an error if any of the underlying database queries fails.
pub async fn find_ai_request_for_audit(
    pool: &Arc<PgPool>,
    id: &str,
) -> Result<Option<AuditLookupResult>> {
    if id.is_empty() {
        return Ok(None);
    }

    let partial = like_prefix_pattern(id);

    if let Some(row) = find_audit_by_request_id(pool, id, &partial).await? {
        return Ok(Some(row));
    }
    if let Some(row) = find_audit_by_task_id(pool, id, &partial).await? {
        return Ok(Some(row));
    }
    find_audit_by_trace_id(pool, id, &partial).await
}

/// Builds a SQL `LIKE` pattern that matches any value starting with `prefix`.
///
/// `\`, `%` and `_` in `prefix` are prefixed with a backslash, which is
/// PostgreSQL's default `LIKE` escape character. They therefore match
/// literally instead of acting as wildcards.
fn like_prefix_pattern(prefix: &str) -> String {
    let mut pattern = String::with_capacity(prefix.len() + 1);
    for ch in prefix.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            pattern.push('\\');
        }
        pattern.push(ch);
    }
    pattern.push('%');
    pattern
}

/// Looks up an `ai_requests` row by exact `id` or by the `LIKE` pattern `partial`.
///
/// When several rows qualify, an exact match on `id` is preferred, then the
/// most recently created row. This matches the task-id and trace-id lookups.
/// Without an `ORDER BY`, `LIMIT 1` would pick an unpredictable row.
///
/// # Errors
/// Propagates any database error from the query.
async fn find_audit_by_request_id(
    pool: &Arc<PgPool>,
    id: &str,
    partial: &str,
) -> Result<Option<AuditLookupResult>> {
    let row = sqlx::query_as!(
        AuditRow,
        r#"
        SELECT id as "id!", provider as "provider!", model as "model!",
            input_tokens, output_tokens, cost_microdollars as "cost_microdollars!",
            latency_ms, task_id, trace_id
        FROM ai_requests WHERE id = $1 OR id LIKE $2
        ORDER BY (id = $1) DESC, created_at DESC LIMIT 1
        "#,
        id,
        partial
    )
    .fetch_optional(&**pool)
    .await?;

    Ok(row.map(audit_row_to_result))
}

/// Finds the newest `ai_requests` row whose `task_id` equals `id` or matches
/// the `LIKE` pattern `partial`.
///
/// # Errors
/// Propagates any database error from the query.
async fn find_audit_by_task_id(
    pool: &Arc<PgPool>,
    id: &str,
    partial: &str,
) -> Result<Option<AuditLookupResult>> {
    let db: &PgPool = pool;
    let newest = sqlx::query_as!(
        AuditRow,
        r#"
        SELECT id as "id!", provider as "provider!", model as "model!",
            input_tokens, output_tokens, cost_microdollars as "cost_microdollars!",
            latency_ms, task_id, trace_id
        FROM ai_requests WHERE task_id = $1 OR task_id LIKE $2
        ORDER BY created_at DESC LIMIT 1
        "#,
        id,
        partial
    )
    .fetch_optional(db)
    .await?
    .map(audit_row_to_result);

    Ok(newest)
}

/// Finds the newest `ai_requests` row whose `trace_id` equals `id` or matches
/// the `LIKE` pattern `partial`.
///
/// # Errors
/// Propagates any database error from the query.
async fn find_audit_by_trace_id(
    pool: &Arc<PgPool>,
    id: &str,
    partial: &str,
) -> Result<Option<AuditLookupResult>> {
    let db: &PgPool = pool;
    let newest = sqlx::query_as!(
        AuditRow,
        r#"
        SELECT id as "id!", provider as "provider!", model as "model!",
            input_tokens, output_tokens, cost_microdollars as "cost_microdollars!",
            latency_ms, task_id, trace_id
        FROM ai_requests WHERE trace_id = $1 OR trace_id LIKE $2
        ORDER BY created_at DESC LIMIT 1
        "#,
        id,
        partial
    )
    .fetch_optional(db)
    .await?
    .map(audit_row_to_result);

    Ok(newest)
}

fn audit_row_to_result(r: AuditRow) -> AuditLookupResult {
    AuditLookupResult {
        id: AiRequestId::new(r.id),
        provider: r.provider,
        model: r.model,
        input_tokens: r.input_tokens,
        output_tokens: r.output_tokens,
        cost_microdollars: r.cost_microdollars,
        latency_ms: r.latency_ms,
        task_id: r.task_id.map(Into::into),
        trace_id: r.trace_id.map(Into::into),
    }
}

pub async fn list_audit_messages(
    pool: &Arc<PgPool>,
    request_id: &str,
) -> Result<Vec<ConversationMessage>> {
    let rows = sqlx::query_as!(
        MsgRow,
        r#"
        SELECT role as "role!", content as "content!", sequence_number as "sequence_number!"
        FROM ai_request_messages WHERE request_id = $1 ORDER BY sequence_number
        "#,
        request_id
    )
    .fetch_all(&**pool)
    .await?;

    Ok(rows
        .into_iter()
        .map(|m| ConversationMessage {
            role: m.role,
            content: m.content,
            sequence_number: m.sequence_number,
        })
        .collect())
}

pub async fn list_audit_tool_calls(
    pool: &Arc<PgPool>,
    request_id: &str,
) -> Result<Vec<AuditToolCallRow>> {
    let rows = sqlx::query_as!(
        ToolCallDbRow,
        r#"
        SELECT tool_name as "tool_name!", tool_input as "tool_input!",
            sequence_number as "sequence_number!"
        FROM ai_request_tool_calls WHERE request_id = $1 ORDER BY sequence_number
        "#,
        request_id
    )
    .fetch_all(&**pool)
    .await?;

    Ok(rows
        .into_iter()
        .map(|t| AuditToolCallRow {
            tool_name: t.tool_name,
            tool_input: t.tool_input,
            sequence_number: t.sequence_number,
        })
        .collect())
}

/// Lists the MCP tool executions linked to the tool calls of `request_id`.
///
/// Executions are joined through `ai_request_tool_calls.mcp_execution_id`.
/// They are returned in tool-call `sequence_number` order, matching
/// `list_audit_tool_calls`. Without an `ORDER BY` the row order would be
/// unspecified and could change between calls.
///
/// # Errors
/// Propagates any database error from the query.
pub async fn list_linked_mcp_calls(
    pool: &Arc<PgPool>,
    request_id: &str,
) -> Result<Vec<LinkedMcpCall>> {
    let rows = sqlx::query_as!(
        LinkedMcpDbRow,
        r#"
        SELECT
            mte.tool_name as "tool_name!",
            mte.server_name as "server_name!",
            mte.status as "status!",
            mte.execution_time_ms
        FROM mcp_tool_executions mte
        JOIN ai_request_tool_calls artc ON artc.mcp_execution_id = mte.mcp_execution_id
        WHERE artc.request_id = $1
        ORDER BY artc.sequence_number
        "#,
        request_id
    )
    .fetch_all(&**pool)
    .await?;

    Ok(rows
        .into_iter()
        .map(|r| LinkedMcpCall {
            tool_name: r.tool_name,
            server_name: r.server_name,
            status: r.status,
            execution_time_ms: r.execution_time_ms,
        })
        .collect())
}