Skip to main content

systemprompt_database/admin/
introspection.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use sqlx::Row;
5use sqlx::postgres::PgPool;
6
7use crate::admin::identifier::SafeIdentifier;
8use crate::models::{ColumnInfo, DatabaseInfo, IndexInfo, TableInfo};
9
10#[derive(Debug)]
11pub struct DatabaseAdminService {
12    pool: Arc<PgPool>,
13}
14
15impl DatabaseAdminService {
16    pub const fn new(pool: Arc<PgPool>) -> Self {
17        Self { pool }
18    }
19
20    pub async fn list_tables(&self) -> Result<Vec<TableInfo>> {
21        let rows = sqlx::query(
22            r"
23            SELECT
24                t.table_name as name,
25                COALESCE(s.n_live_tup, 0) as row_count,
26                COALESCE(pg_total_relation_size(quote_ident(t.table_name)::regclass), 0) as size_bytes
27            FROM information_schema.tables t
28            LEFT JOIN pg_stat_user_tables s ON t.table_name = s.relname
29            WHERE t.table_schema = 'public'
30            ORDER BY t.table_name
31            ",
32        )
33        .fetch_all(&*self.pool)
34        .await?;
35
36        let tables = rows
37            .iter()
38            .map(|row| {
39                let name: String = row.get("name");
40                let row_count: i64 = row.get("row_count");
41                let size_bytes: i64 = row.get("size_bytes");
42                TableInfo {
43                    name,
44                    row_count,
45                    size_bytes,
46                    columns: vec![],
47                }
48            })
49            .collect();
50
51        Ok(tables)
52    }
53
54    pub async fn describe_table(
55        &self,
56        table_name: &SafeIdentifier,
57    ) -> Result<(Vec<ColumnInfo>, i64)> {
58        let rows = sqlx::query(
59            "SELECT column_name, data_type, is_nullable, column_default FROM \
60             information_schema.columns WHERE table_name = $1 ORDER BY ordinal_position",
61        )
62        .bind(table_name.as_str())
63        .fetch_all(&*self.pool)
64        .await?;
65
66        if rows.is_empty() {
67            return Err(anyhow::anyhow!("Table '{table_name}' not found"));
68        }
69
70        let pk_rows = sqlx::query(
71            r"
72            SELECT a.attname as column_name
73            FROM pg_index i
74            JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
75            WHERE i.indrelid = $1::regclass AND i.indisprimary
76            ",
77        )
78        .bind(table_name.as_str())
79        .fetch_all(&*self.pool)
80        .await?;
81
82        let pk_columns: Vec<String> = pk_rows
83            .iter()
84            .map(|row| row.get::<String, _>("column_name"))
85            .collect();
86
87        let columns = rows
88            .iter()
89            .map(|row| {
90                let name: String = row.get("column_name");
91                let data_type: String = row.get("data_type");
92                let nullable_str: String = row.get("is_nullable");
93                let nullable = nullable_str.to_uppercase() == "YES";
94                let default: Option<String> = row.get("column_default");
95                let primary_key = pk_columns.contains(&name);
96
97                ColumnInfo {
98                    name,
99                    data_type,
100                    nullable,
101                    primary_key,
102                    default,
103                }
104            })
105            .collect();
106
107        let row_count = self.count_rows(table_name).await?;
108
109        Ok((columns, row_count))
110    }
111
112    pub async fn get_table_indexes(&self, table_name: &SafeIdentifier) -> Result<Vec<IndexInfo>> {
113        let rows = sqlx::query(
114            r"
115            SELECT
116                i.relname as index_name,
117                ix.indisunique as is_unique,
118                array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns
119            FROM pg_class t
120            JOIN pg_index ix ON t.oid = ix.indrelid
121            JOIN pg_class i ON i.oid = ix.indexrelid
122            JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
123            WHERE t.relname = $1 AND t.relkind = 'r'
124            GROUP BY i.relname, ix.indisunique
125            ORDER BY i.relname
126            ",
127        )
128        .bind(table_name.as_str())
129        .fetch_all(&*self.pool)
130        .await?;
131
132        let indexes = rows
133            .iter()
134            .map(|row| {
135                let name: String = row.get("index_name");
136                let unique: bool = row.get("is_unique");
137                let columns: Vec<String> = row.get("columns");
138                IndexInfo {
139                    name,
140                    columns,
141                    unique,
142                }
143            })
144            .collect();
145
146        Ok(indexes)
147    }
148
149    pub async fn count_rows(&self, table_name: &SafeIdentifier) -> Result<i64> {
150        let quoted_table = quote_identifier(table_name.as_str());
151        let count_query = format!("SELECT COUNT(*) as count FROM {quoted_table}");
152        let row_count: i64 = sqlx::query_scalar(&count_query)
153            .fetch_one(&*self.pool)
154            .await?;
155
156        Ok(row_count)
157    }
158
159    pub async fn get_database_info(&self) -> Result<DatabaseInfo> {
160        let version: String = sqlx::query_scalar("SELECT version()")
161            .fetch_one(&*self.pool)
162            .await?;
163
164        let size: i64 = sqlx::query_scalar("SELECT pg_database_size(current_database())")
165            .fetch_one(&*self.pool)
166            .await?;
167
168        let size = u64::try_from(size)
169            .map_err(|_| anyhow::anyhow!("pg_database_size returned negative value: {size}"))?;
170
171        let tables = self.list_tables().await?;
172
173        Ok(DatabaseInfo {
174            path: "PostgreSQL".to_string(),
175            size,
176            version,
177            tables,
178        })
179    }
180
181    pub fn get_expected_tables() -> Vec<&'static str> {
182        vec![
183            "users",
184            "user_sessions",
185            "user_contexts",
186            "agent_tasks",
187            "agent_skills",
188            "task_messages",
189            "task_artifacts",
190            "task_execution_steps",
191            "artifact_parts",
192            "message_parts",
193            "ai_requests",
194            "ai_request_messages",
195            "ai_request_tool_calls",
196            "mcp_tool_executions",
197            "logs",
198            "analytics_events",
199            "oauth_clients",
200            "oauth_auth_codes",
201            "oauth_refresh_tokens",
202            "scheduled_jobs",
203            "services",
204            "markdown_content",
205            "markdown_categories",
206            "files",
207            "content_files",
208        ]
209    }
210}
211
/// Renders `identifier` as a PostgreSQL quoted identifier.
///
/// The text is wrapped in double quotes and every embedded `"` is doubled, so
/// the result always parses as exactly one identifier, whatever the input.
fn quote_identifier(identifier: &str) -> String {
    let mut out = String::with_capacity(identifier.len() + 2);
    out.push('"');
    for ch in identifier.chars() {
        match ch {
            '"' => out.push_str("\"\""),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}