systemprompt-database 0.1.22

Database abstraction layer for systemprompt.io supporting SQLite, PostgreSQL, and MySQL
Documentation
use std::sync::Arc;

use anyhow::Result;
use sqlx::Row;
use sqlx::postgres::PgPool;

use crate::models::{ColumnInfo, DatabaseInfo, IndexInfo, TableInfo};

/// Introspection service for a PostgreSQL database.
///
/// Every query issued by this type is a `SELECT` against the system catalogs,
/// `information_schema`, or a user table (for row counts).
#[derive(Debug)]
pub struct DatabaseAdminService {
    /// Shared connection pool that every query is executed against.
    pool: Arc<PgPool>,
}

impl DatabaseAdminService {
    /// Creates a service that runs its queries against the shared `pool`.
    pub const fn new(pool: Arc<PgPool>) -> Self {
        Self { pool }
    }

    /// Checks that `table_name` is non-empty and consists solely of
    /// alphanumeric characters and underscores.
    ///
    /// The failure deliberately uses the same "not found" message as a missing
    /// table, so a malformed name is reported exactly like an unknown one.
    /// An empty name is rejected explicitly: `all` is vacuously true for it,
    /// and it would otherwise reach the database as the invalid identifier `""`.
    fn ensure_valid_table_name(table_name: &str) -> Result<()> {
        anyhow::ensure!(
            !table_name.is_empty()
                && table_name.chars().all(|c| c.is_alphanumeric() || c == '_'),
            "Table '{}' not found",
            table_name
        );
        Ok(())
    }

    /// Lists every table in the `public` schema, ordered by name.
    ///
    /// `row_count` is read from `pg_stat_user_tables.n_live_tup` (a
    /// statistics-based figure, not an exact `COUNT(*)`) and `size_bytes` from
    /// `pg_total_relation_size`. The returned `columns` list is always empty;
    /// use [`Self::describe_table`] for column details.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails or a row cannot be decoded.
    pub async fn list_tables(&self) -> Result<Vec<TableInfo>> {
        // Both the size lookup and the statistics join are schema-qualified:
        // an unqualified regclass cast is resolved through `search_path`
        // (and can fail or hit another schema's table), and joining on
        // `relname` alone duplicates rows when another schema has a table
        // with the same name.
        let rows = sqlx::query(
            r"
            SELECT
                t.table_name as name,
                COALESCE(s.n_live_tup, 0) as row_count,
                COALESCE(pg_total_relation_size(format('%I.%I', t.table_schema, t.table_name)::regclass), 0) as size_bytes
            FROM information_schema.tables t
            LEFT JOIN pg_stat_user_tables s
                ON s.schemaname = t.table_schema AND s.relname = t.table_name
            WHERE t.table_schema = 'public'
            ORDER BY t.table_name
            ",
        )
        .fetch_all(&*self.pool)
        .await?;

        rows.iter()
            .map(|row| -> Result<TableInfo> {
                let name: String = row.try_get("name")?;
                let row_count: i64 = row.try_get("row_count")?;
                let size_bytes: i64 = row.try_get("size_bytes")?;
                Ok(TableInfo {
                    name,
                    row_count,
                    size_bytes,
                    columns: Vec::new(),
                })
            })
            .collect()
    }

    /// Describes the columns of `table_name` and returns them together with
    /// the table's exact row count.
    ///
    /// The table is resolved through the connection's `search_path`, the same
    /// way the unqualified name is resolved by [`Self::count_rows`], so the
    /// columns, primary-key flags and row count all refer to one relation.
    ///
    /// # Errors
    ///
    /// Returns a "not found" error if the name is malformed or no such table
    /// is visible, and propagates any database or decoding error.
    pub async fn describe_table(&self, table_name: &str) -> Result<(Vec<ColumnInfo>, i64)> {
        Self::ensure_valid_table_name(table_name)?;

        // `to_regclass` yields NULL (no rows) instead of raising when the name
        // does not resolve; `quote_ident` keeps the lookup case-sensitive.
        // Joining through pg_class/pg_namespace restricts the result to that
        // single relation instead of every schema with a same-named table.
        let column_rows = sqlx::query(
            r"
            SELECT c.column_name, c.data_type, c.is_nullable, c.column_default
            FROM information_schema.columns c
            JOIN pg_namespace n ON n.nspname = c.table_schema
            JOIN pg_class r ON r.relnamespace = n.oid AND r.relname = c.table_name
            WHERE r.oid = to_regclass(quote_ident($1))
            ORDER BY c.ordinal_position
            ",
        )
        .bind(table_name)
        .fetch_all(&*self.pool)
        .await?;

        if column_rows.is_empty() {
            return Err(anyhow::anyhow!("Table '{}' not found", table_name));
        }

        // The lookup can no longer fail merely because the name does not
        // resolve, so any remaining error is a real database failure and is
        // propagated rather than silently reported as "no primary key".
        let pk_rows = sqlx::query(
            r"
            SELECT a.attname as column_name
            FROM pg_index i
            JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
            WHERE i.indrelid = to_regclass(quote_ident($1)) AND i.indisprimary
            ",
        )
        .bind(table_name)
        .fetch_all(&*self.pool)
        .await?;

        let pk_columns = pk_rows
            .iter()
            .map(|row| row.try_get::<String, _>("column_name"))
            .collect::<Result<Vec<_>, _>>()?;

        let columns = column_rows
            .iter()
            .map(|row| -> Result<ColumnInfo> {
                let name: String = row.try_get("column_name")?;
                let data_type: String = row.try_get("data_type")?;
                let is_nullable: String = row.try_get("is_nullable")?;
                let default: Option<String> = row.try_get("column_default")?;
                let primary_key = pk_columns.contains(&name);

                Ok(ColumnInfo {
                    name,
                    data_type,
                    // information_schema reports nullability as 'YES' / 'NO'.
                    nullable: is_nullable.eq_ignore_ascii_case("YES"),
                    primary_key,
                    default,
                })
            })
            .collect::<Result<Vec<_>>>()?;

        let row_count = self.count_rows(table_name).await?;

        Ok((columns, row_count))
    }

    /// Lists the indexes defined on the ordinary table `table_name`, ordered
    /// by index name, with each index's columns in index-key order.
    ///
    /// The table is resolved through the connection's `search_path`. An
    /// unknown table, or one without indexes, yields an empty list.
    ///
    /// # Errors
    ///
    /// Returns a "not found" error if the name is malformed, and propagates
    /// any database or decoding error.
    pub async fn get_table_indexes(&self, table_name: &str) -> Result<Vec<IndexInfo>> {
        Self::ensure_valid_table_name(table_name)?;

        // Matching on the resolved relation OID (rather than `relname` alone)
        // prevents indexes of same-named tables in other schemas from being
        // mixed into the result.
        let rows = sqlx::query(
            r"
            SELECT
                i.relname as index_name,
                ix.indisunique as is_unique,
                array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns
            FROM pg_class t
            JOIN pg_index ix ON t.oid = ix.indrelid
            JOIN pg_class i ON i.oid = ix.indexrelid
            JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
            WHERE t.oid = to_regclass(quote_ident($1)) AND t.relkind = 'r'
            GROUP BY i.relname, ix.indisunique
            ORDER BY i.relname
            ",
        )
        .bind(table_name)
        .fetch_all(&*self.pool)
        .await?;

        rows.iter()
            .map(|row| -> Result<IndexInfo> {
                let name: String = row.try_get("index_name")?;
                let unique: bool = row.try_get("is_unique")?;
                let columns: Vec<String> = row.try_get("columns")?;
                Ok(IndexInfo {
                    name,
                    columns,
                    unique,
                })
            })
            .collect()
    }

    /// Returns the exact number of rows in `table_name` via `COUNT(*)`.
    ///
    /// # Errors
    ///
    /// Returns a "not found" error if the name is malformed, and propagates
    /// the database error if the table does not exist or the query fails.
    pub async fn count_rows(&self, table_name: &str) -> Result<i64> {
        Self::ensure_valid_table_name(table_name)?;

        // Identifiers cannot be bound as parameters, so the validated name is
        // quoted and interpolated into the statement text.
        let quoted_table = quote_identifier(table_name);
        let count_query = format!("SELECT COUNT(*) as count FROM {quoted_table}");
        let row_count: i64 = sqlx::query_scalar(&count_query)
            .fetch_one(&*self.pool)
            .await?;

        Ok(row_count)
    }

    /// Collects the server version string, the on-disk size of the current
    /// database, and the table listing from [`Self::list_tables`].
    ///
    /// `path` is always the literal `"PostgreSQL"`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying queries.
    pub async fn get_database_info(&self) -> Result<DatabaseInfo> {
        let version: String = sqlx::query_scalar("SELECT version()")
            .fetch_one(&*self.pool)
            .await?;

        let size: i64 = sqlx::query_scalar("SELECT pg_database_size(current_database())")
            .fetch_one(&*self.pool)
            .await?;

        let tables = self.list_tables().await?;

        Ok(DatabaseInfo {
            path: "PostgreSQL".to_string(),
            // A negative size cannot be represented as u64; fall back to 0.
            size: u64::try_from(size).unwrap_or(0),
            version,
            tables,
        })
    }

    /// Returns the hard-coded list of table names this service treats as
    /// expected. The list is not checked against the database here.
    pub fn get_expected_tables() -> Vec<&'static str> {
        vec![
            "users",
            "user_sessions",
            "user_contexts",
            "agent_tasks",
            "agent_skills",
            "task_messages",
            "task_artifacts",
            "task_execution_steps",
            "artifact_parts",
            "message_parts",
            "ai_requests",
            "ai_request_messages",
            "ai_request_tool_calls",
            "mcp_tool_executions",
            "logs",
            "analytics_events",
            "oauth_clients",
            "oauth_auth_codes",
            "oauth_refresh_tokens",
            "scheduled_jobs",
            "services",
            "markdown_content",
            "markdown_categories",
            "files",
            "content_files",
        ]
    }
}

/// Wraps `identifier` in double quotes so it can be used as a delimited SQL
/// identifier, doubling every embedded double quote so it cannot terminate
/// the identifier early.
fn quote_identifier(identifier: &str) -> String {
    // Two extra bytes for the surrounding quotes; embedded quotes grow it further.
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}