Skip to main content

systemprompt_database/admin/
introspection.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use sqlx::postgres::PgPool;
5use sqlx::Row;
6
7use crate::models::{ColumnInfo, DatabaseInfo, IndexInfo, TableInfo};
8
9#[derive(Debug)]
10pub struct DatabaseAdminService {
11    pool: Arc<PgPool>,
12}
13
14impl DatabaseAdminService {
15    pub const fn new(pool: Arc<PgPool>) -> Self {
16        Self { pool }
17    }
18
19    pub async fn list_tables(&self) -> Result<Vec<TableInfo>> {
20        let rows = sqlx::query(
21            r"
22            SELECT
23                t.table_name as name,
24                COALESCE(s.n_live_tup, 0) as row_count,
25                COALESCE(pg_total_relation_size(quote_ident(t.table_name)::regclass), 0) as size_bytes
26            FROM information_schema.tables t
27            LEFT JOIN pg_stat_user_tables s ON t.table_name = s.relname
28            WHERE t.table_schema = 'public'
29            ORDER BY t.table_name
30            ",
31        )
32        .fetch_all(&*self.pool)
33        .await?;
34
35        let tables = rows
36            .iter()
37            .map(|row| {
38                let name: String = row.get("name");
39                let row_count: i64 = row.get("row_count");
40                let size_bytes: i64 = row.get("size_bytes");
41                TableInfo {
42                    name,
43                    row_count,
44                    size_bytes,
45                    columns: vec![],
46                }
47            })
48            .collect();
49
50        Ok(tables)
51    }
52
53    pub async fn describe_table(&self, table_name: &str) -> Result<(Vec<ColumnInfo>, i64)> {
54        if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
55            return Err(anyhow::anyhow!("Table '{}' not found", table_name));
56        }
57
58        let rows = sqlx::query(
59            "SELECT column_name, data_type, is_nullable, column_default FROM \
60             information_schema.columns WHERE table_name = $1 ORDER BY ordinal_position",
61        )
62        .bind(table_name)
63        .fetch_all(&*self.pool)
64        .await?;
65
66        if rows.is_empty() {
67            return Err(anyhow::anyhow!("Table '{}' not found", table_name));
68        }
69
70        let pk_rows = sqlx::query(
71            r"
72            SELECT a.attname as column_name
73            FROM pg_index i
74            JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
75            WHERE i.indrelid = $1::regclass AND i.indisprimary
76            ",
77        )
78        .bind(table_name)
79        .fetch_all(&*self.pool)
80        .await
81        .unwrap_or_else(|_| Vec::new());
82
83        let pk_columns: Vec<String> = pk_rows
84            .iter()
85            .map(|row| row.get::<String, _>("column_name"))
86            .collect();
87
88        let columns = rows
89            .iter()
90            .map(|row| {
91                let name: String = row.get("column_name");
92                let data_type: String = row.get("data_type");
93                let nullable_str: String = row.get("is_nullable");
94                let nullable = nullable_str.to_uppercase() == "YES";
95                let default: Option<String> = row.get("column_default");
96                let primary_key = pk_columns.contains(&name);
97
98                ColumnInfo {
99                    name,
100                    data_type,
101                    nullable,
102                    primary_key,
103                    default,
104                }
105            })
106            .collect();
107
108        let row_count = self.count_rows(table_name).await?;
109
110        Ok((columns, row_count))
111    }
112
113    pub async fn get_table_indexes(&self, table_name: &str) -> Result<Vec<IndexInfo>> {
114        if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
115            return Err(anyhow::anyhow!("Table '{}' not found", table_name));
116        }
117
118        let rows = sqlx::query(
119            r"
120            SELECT
121                i.relname as index_name,
122                ix.indisunique as is_unique,
123                array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns
124            FROM pg_class t
125            JOIN pg_index ix ON t.oid = ix.indrelid
126            JOIN pg_class i ON i.oid = ix.indexrelid
127            JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
128            WHERE t.relname = $1 AND t.relkind = 'r'
129            GROUP BY i.relname, ix.indisunique
130            ORDER BY i.relname
131            ",
132        )
133        .bind(table_name)
134        .fetch_all(&*self.pool)
135        .await?;
136
137        let indexes = rows
138            .iter()
139            .map(|row| {
140                let name: String = row.get("index_name");
141                let unique: bool = row.get("is_unique");
142                let columns: Vec<String> = row.get("columns");
143                IndexInfo {
144                    name,
145                    columns,
146                    unique,
147                }
148            })
149            .collect();
150
151        Ok(indexes)
152    }
153
154    pub async fn count_rows(&self, table_name: &str) -> Result<i64> {
155        if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
156            return Err(anyhow::anyhow!("Table '{}' not found", table_name));
157        }
158
159        let quoted_table = quote_identifier(table_name);
160        let count_query = format!("SELECT COUNT(*) as count FROM {quoted_table}");
161        let row_count: i64 = sqlx::query_scalar(&count_query)
162            .fetch_one(&*self.pool)
163            .await?;
164
165        Ok(row_count)
166    }
167
168    pub async fn get_database_info(&self) -> Result<DatabaseInfo> {
169        let version: String = sqlx::query_scalar("SELECT version()")
170            .fetch_one(&*self.pool)
171            .await?;
172
173        let size: i64 = sqlx::query_scalar("SELECT pg_database_size(current_database())")
174            .fetch_one(&*self.pool)
175            .await?;
176
177        let tables = self.list_tables().await?;
178
179        Ok(DatabaseInfo {
180            path: "PostgreSQL".to_string(),
181            size: u64::try_from(size).unwrap_or(0),
182            version,
183            tables,
184        })
185    }
186
187    pub fn get_expected_tables() -> Vec<&'static str> {
188        vec![
189            "users",
190            "user_sessions",
191            "user_contexts",
192            "agent_tasks",
193            "agent_skills",
194            "task_messages",
195            "task_artifacts",
196            "task_execution_steps",
197            "artifact_parts",
198            "message_parts",
199            "ai_requests",
200            "ai_request_messages",
201            "ai_request_tool_calls",
202            "mcp_tool_executions",
203            "logs",
204            "analytics_events",
205            "oauth_clients",
206            "oauth_auth_codes",
207            "oauth_refresh_tokens",
208            "scheduled_jobs",
209            "services",
210            "markdown_content",
211            "markdown_categories",
212            "files",
213            "content_files",
214        ]
215    }
216}
217
/// Renders `identifier` as a PostgreSQL quoted identifier.
///
/// The result is wrapped in double quotes, and every embedded `"` is doubled
/// so it cannot terminate the identifier early.
fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            quoted.push_str("\"\"");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}