systemprompt-mcp 0.11.2

Native Model Context Protocol (MCP) implementation for systemprompt.io. Orchestration, per-server OAuth2, RBAC middleware, and tool-call governance — the core of the AI governance pipeline.
Documentation
//! Tool-usage repository — persists each MCP tool execution and aggregates
//! stats.

mod stats;

use crate::error::McpDomainResult;
use chrono::Utc;
use sqlx::PgPool;
use std::sync::Arc;
use systemprompt_database::DbPool;
use systemprompt_identifiers::{AiToolCallId, ContextId, McpExecutionId, UserId};
use uuid::Uuid;

use crate::models::{
    ExecutionStatus, ToolExecution, ToolExecutionRequest, ToolExecutionResult, ToolStats,
};
use systemprompt_models::RequestContext;

/// Returns the request's trace id as an owned string, or `None` when the
/// trace id is the empty string.
fn extract_trace_id(ctx: &RequestContext) -> Option<String> {
    let raw = ctx.trace_id();
    if raw.as_str().is_empty() {
        None
    } else {
        Some(raw.to_string())
    }
}

/// Repository for the `mcp_tool_executions` table: records tool executions,
/// looks them up, and exposes aggregated tool statistics. It also bumps
/// `user_contexts.updated_at` on request.
#[derive(Debug)]
pub struct ToolUsageRepository {
    /// Pool used by the lookup queries (`find_*`) and the stats aggregation.
    pool: Arc<PgPool>,
    /// Pool used by every `INSERT`/`UPDATE` statement issued by this type.
    write_pool: Arc<PgPool>,
}

impl ToolUsageRepository {
    pub fn new(db: &DbPool) -> McpDomainResult<Self> {
        let pool = db.pool_arc().map_err(|e| {
            crate::error::McpDomainError::Internal(format!("Database must be PostgreSQL: {e}"))
        })?;
        let write_pool = db.write_pool_arc().map_err(|e| {
            crate::error::McpDomainError::Internal(format!("Database must be PostgreSQL: {e}"))
        })?;
        Ok(Self { pool, write_pool })
    }

    pub async fn start_execution(
        &self,
        request: &ToolExecutionRequest,
    ) -> McpDomainResult<McpExecutionId> {
        if let Some(existing_id) = self.find_existing_execution(request).await? {
            return Ok(existing_id);
        }

        let id = Uuid::new_v4().to_string();
        let mcp_execution_id = McpExecutionId::new(id.clone());
        let context_id = request.context.context_id().to_string();
        let user_id = request.context.user_id().to_string();
        let ai_tool_call_id = request.ai_tool_call_id.as_ref().map(ToString::to_string);
        let input_str = serde_json::to_string(&request.input)?;
        let task_id = request.context.task_id().map(ToString::to_string);
        let session_id = request.context.session_id().to_string();
        let trace_id = extract_trace_id(&request.context);
        let status = ExecutionStatus::Pending.as_str();

        sqlx::query!(
            r#"
            INSERT INTO mcp_tool_executions (
                mcp_execution_id, tool_name, server_name, context_id, ai_tool_call_id,
                user_id, task_id, session_id, trace_id, status, input, started_at
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
            "#,
            id,
            request.tool_name,
            request.server_name,
            context_id,
            ai_tool_call_id,
            user_id,
            task_id,
            session_id,
            trace_id,
            status,
            input_str,
            request.started_at
        )
        .execute(&*self.write_pool)
        .await?;

        Ok(mcp_execution_id)
    }

    pub async fn complete_execution(
        &self,
        mcp_execution_id: &McpExecutionId,
        result: &ToolExecutionResult,
    ) -> McpDomainResult<()> {
        let id = mcp_execution_id.as_str();
        let duration_ms = (result.completed_at - result.started_at).num_milliseconds() as i32;
        let output_str = result.output.as_ref().and_then(|v| {
            serde_json::to_string(v)
                .map_err(|e| {
                    tracing::error!(
                        mcp_execution_id = %id,
                        error = %e,
                        "Failed to serialize tool execution output"
                    );
                    e
                })
                .ok()
        });

        sqlx::query!(
            r#"
            UPDATE mcp_tool_executions
            SET status = $1, output = $2, error_message = $3, execution_time_ms = $4, completed_at = $5
            WHERE mcp_execution_id = $6
            "#,
            result.status,
            output_str,
            result.error_message,
            duration_ms,
            result.completed_at,
            id
        )
        .execute(&*self.write_pool)
        .await?;

        Ok(())
    }

    pub async fn log_execution_sync(
        &self,
        request: &ToolExecutionRequest,
        result: &ToolExecutionResult,
    ) -> McpDomainResult<McpExecutionId> {
        let id = Uuid::new_v4().to_string();
        let mcp_execution_id = McpExecutionId::new(id.clone());
        let status = ExecutionStatus::from_error(result.error_message.is_some()).as_str();
        let context_id = request.context.context_id().to_string();
        let user_id = request.context.user_id().to_string();
        let task_id = request.context.task_id().map(ToString::to_string);
        let session_id = request.context.session_id().to_string();
        let trace_id = extract_trace_id(&request.context);
        let duration_ms = (result.completed_at - request.started_at).num_milliseconds() as i32;
        let input_str = serde_json::to_string(&request.input)?;
        let output_str = result
            .output
            .as_ref()
            .and_then(|v| serde_json::to_string(v).ok());

        sqlx::query!(
            r#"
            INSERT INTO mcp_tool_executions (
                mcp_execution_id, tool_name, server_name, context_id, user_id, task_id,
                session_id, trace_id, status, input, output, error_message, execution_time_ms,
                started_at, completed_at
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
            "#,
            id,
            request.tool_name,
            request.server_name,
            context_id,
            user_id,
            task_id,
            session_id,
            trace_id,
            status,
            input_str,
            output_str,
            result.error_message,
            duration_ms,
            request.started_at,
            result.completed_at
        )
        .execute(&*self.write_pool)
        .await?;

        Ok(mcp_execution_id)
    }

    pub async fn find_by_id(&self, id: &McpExecutionId) -> McpDomainResult<Option<ToolExecution>> {
        let id_str = id.as_str();
        let row = sqlx::query!(
            r#"SELECT
                mcp_execution_id as "mcp_execution_id!",
                tool_name as "tool_name!",
                server_name as "server_name!",
                context_id,
                ai_tool_call_id,
                user_id as "user_id!",
                status as "status!",
                input as "input!",
                output,
                error_message,
                execution_time_ms,
                started_at as "started_at!",
                completed_at
            FROM mcp_tool_executions
            WHERE mcp_execution_id = $1"#,
            id_str
        )
        .fetch_optional(&*self.pool)
        .await?;

        Ok(row.map(|r| ToolExecution {
            mcp_execution_id: McpExecutionId::new(r.mcp_execution_id),
            tool_name: r.tool_name,
            server_name: r.server_name,
            context_id: r.context_id.and_then(|s| {
                ContextId::try_new(&s)
                    .map_err(|e| {
                        tracing::warn!(error = %e, raw = %s, "Skipping non-UUID context_id from mcp_tool_executions row");
                        e
                    })
                    .ok()
            }),
            ai_tool_call_id: r.ai_tool_call_id.map(AiToolCallId::new),
            user_id: UserId::new(r.user_id),
            status: r.status,
            input: r.input,
            output: r.output,
            error_message: r.error_message,
            execution_time_ms: r.execution_time_ms,
            started_at: r.started_at,
            completed_at: r.completed_at,
        }))
    }

    pub async fn find_by_ai_call_id(
        &self,
        ai_tool_call_id: &AiToolCallId,
    ) -> McpDomainResult<Option<McpExecutionId>> {
        let id_str = ai_tool_call_id.as_str();
        let result = sqlx::query_scalar!(
            r#"SELECT mcp_execution_id as "mcp_execution_id!" FROM mcp_tool_executions WHERE ai_tool_call_id = $1"#,
            id_str
        )
        .fetch_optional(&*self.pool)
        .await?;
        Ok(result.map(McpExecutionId::new))
    }

    async fn find_existing_execution(
        &self,
        request: &ToolExecutionRequest,
    ) -> McpDomainResult<Option<McpExecutionId>> {
        let Some(ai_call_id) = &request.ai_tool_call_id else {
            return Ok(None);
        };
        self.find_by_ai_call_id(ai_call_id).await
    }

    pub async fn find_context_id(
        &self,
        execution_id: &McpExecutionId,
    ) -> McpDomainResult<Option<ContextId>> {
        let id_str = execution_id.as_str();
        let result = sqlx::query_scalar!(
            "SELECT context_id FROM mcp_tool_executions WHERE mcp_execution_id = $1",
            id_str
        )
        .fetch_optional(&*self.pool)
        .await?;
        Ok(result.flatten().map(ContextId::new))
    }

    pub async fn list_tool_stats(&self, limit: i64) -> McpDomainResult<Vec<ToolStats>> {
        stats::list_tool_stats(&self.pool, limit).await
    }

    pub async fn update_context_timestamp(&self, context_id: &ContextId) -> McpDomainResult<()> {
        let now = Utc::now();
        let context_id_str = context_id.to_string();
        sqlx::query!(
            "UPDATE user_contexts SET updated_at = $1 WHERE context_id = $2",
            now,
            context_id_str
        )
        .execute(&*self.write_pool)
        .await?;
        Ok(())
    }
}