1mod mysql;
2mod pg;
3
4use crate::{config::DatabaseConfig, error::AppError};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use sqlparser::{ast, dialect::GenericDialect, parser::Parser};
8use sqlx::{MySqlPool, PgPool};
9use std::{cmp::min, convert::Infallible, str::FromStr, time::Duration};
10
/// Default row limit for queries that do not request one.
///
/// NOTE(review): not referenced in this module's visible code; presumably read by the
/// engine-specific handlers in `pg`/`mysql` — confirm.
const DEFAULT_LIMIT: usize = 500;
/// Upper bound that `PoolHandler::sanitize_query` applies to row limits.
const MAX_LIMIT: usize = 5_000;
13
14#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
15#[non_exhaustive]
16#[serde(rename_all = "lowercase")]
17pub enum DatabaseType {
18 Postgres,
19 Mysql,
20}
21
22#[derive(Debug)]
23pub struct PgPoolHandler(PgPool);
24
25#[derive(Debug)]
26pub struct MySqlPoolHandler(MySqlPool);
27
28#[derive(Debug)]
29pub enum DbPool {
30 Postgres(PgPoolHandler),
31 MySql(MySqlPoolHandler),
32 }
34
35pub trait PoolHandler: Sized {
36 async fn try_new(db_config: &DatabaseConfig) -> Result<Self, AppError>;
38 async fn list_tables(&self) -> Result<Vec<TableInfo>, AppError>;
40 async fn get_table_schema(&self, table_name: &str) -> Result<TableSchema, AppError>;
42 async fn sanitize_query(&self, query: &str, limit: usize) -> Result<String, AppError> {
44 let dialect = GenericDialect {};
45 let ast = Parser::parse_sql(&dialect, query)
46 .map_err(|e| AppError::BadRequest(format!("SQL parsing error: {}", e)))?;
47 if ast.len() != 1 {
48 return Err(AppError::BadRequest(
49 "Only single SQL statements are allowed".to_string(),
50 ));
51 }
52
53 let mut stmt = ast.into_iter().next().unwrap();
54
55 let has_limit = match stmt {
56 ast::Statement::Query(ref mut query) => {
57 match &*query.body {
59 ast::SetExpr::Select(_)
60 | ast::SetExpr::Values(_)
61 | ast::SetExpr::Query(_)
62 | ast::SetExpr::Table(_) => {
63 }
65 _ => {
66 return Err(AppError::BadRequest(
67 "Only SELECT-like queries are allowed.".to_string(),
68 ));
69 }
70 }
71
72 match &mut query.limit {
73 Some(ast::Expr::Value(ast::ValueWithSpan {
74 value: ast::Value::Number(s, _),
75 ..
76 })) => {
77 let existing_limit = s.parse::<usize>().unwrap_or(0);
78 if existing_limit < limit {
79 } else {
81 *s = min(existing_limit, MAX_LIMIT).to_string();
82 }
83 true
84 }
85 _ => false,
86 }
87 }
88 _ => {
89 return Err(AppError::BadRequest(
90 "Only SELECT queries are allowed".to_string(),
91 ));
92 }
93 };
94 let mut sql = stmt.to_string();
95 if !has_limit {
96 sql = format!("{} LIMIT {}", sql, limit);
97 }
98 Ok(sql)
99 }
100
101 async fn execute_query(
103 &self,
104 query: &str,
105 limit: Option<usize>,
106 ) -> Result<QueryResult, AppError>;
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct DatabaseInfo {
112 pub name: String,
113 #[serde(rename = "type")]
114 pub db_type: String, }
116
117#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
118#[serde(rename_all = "snake_case")]
119pub enum TableType {
120 Table,
121 View,
122 MaterializedView,
123}
124#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] pub struct TableInfo {
127 pub name: String,
128 #[sqlx(rename = "type", try_from = "String")]
129 #[serde(rename = "type")]
130 pub table_type: TableType, }
132
133#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
134pub enum ColumnType {
135 SmallInt,
137 Integer,
138 BigInt,
139 Decimal,
140 Numeric,
141 Real,
142 DoublePrecision,
143 Char,
145 Varchar,
146 Text,
147 Bytea,
149 Boolean,
151 Date,
153 Time,
154 Timestamp,
155 TimestampTz,
156 Interval,
157 Json,
159 Jsonb,
160 Inet,
162 Cidr,
163 MacAddr,
164 Uuid,
166 Point,
168 Line,
169 Lseg,
170 Box,
171 Path,
172 Polygon,
173 Circle,
174 Array,
176 Int4Range,
178 Int8Range,
179 NumRange,
180 TsRange,
181 TstzRange,
182 DateRange,
183 Bit,
185 Varbit,
186 TsVector,
188 TsQuery,
189 Xml,
191 Money,
193 Other(String),
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct ColumnInfo {
200 pub name: String,
201 pub data_type: ColumnType,
202 pub is_nullable: bool,
203 #[serde(default)]
205 pub is_pk: bool,
206 #[serde(default)]
207 pub is_unique: bool,
208 #[serde(skip_serializing_if = "Option::is_none")]
209 pub fk_table: Option<String>,
210 #[serde(skip_serializing_if = "Option::is_none")]
211 pub fk_column: Option<String>,
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct TableSchema {
216 pub table_name: String,
217 pub columns: Vec<ColumnInfo>,
218 }
222
223#[derive(Debug, Serialize)]
225pub struct QueryResult {
226 pub data: Value,
227 pub execution_time: Duration,
228 #[serde(skip_serializing_if = "Option::is_none")]
229 pub plan: Option<Value>,
230}
231
232#[derive(sqlx::FromRow)]
233pub struct JsonResult {
234 pub data: Value,
235}
236
237impl FromStr for TableType {
238 type Err = Infallible;
239
240 fn from_str(s: &str) -> Result<Self, Self::Err> {
241 match s {
242 "table" => Ok(TableType::Table),
243 "view" => Ok(TableType::View),
244 "materialized_view" => Ok(TableType::MaterializedView),
245 _ => unreachable!(),
246 }
247 }
248}
249
250impl From<String> for TableType {
251 fn from(s: String) -> Self {
252 TableType::from_str(&s).unwrap()
253 }
254}
255
256impl TableType {
257 pub fn as_str(&self) -> &str {
258 match self {
259 TableType::Table => "table",
260 TableType::View => "view",
261 TableType::MaterializedView => "materialized_view",
262 }
263 }
264}
265
266impl FromStr for ColumnType {
267 type Err = Infallible;
268
269 fn from_str(s: &str) -> Result<Self, Self::Err> {
271 match s {
272 "smallint" => Ok(ColumnType::SmallInt),
273 "integer" => Ok(ColumnType::Integer),
274 "bigint" => Ok(ColumnType::BigInt),
275 "decimal" => Ok(ColumnType::Decimal),
276 "numeric" => Ok(ColumnType::Numeric),
277 "real" => Ok(ColumnType::Real),
278 "double precision" => Ok(ColumnType::DoublePrecision),
279 "money" => Ok(ColumnType::Money),
280 "text" => Ok(ColumnType::Text),
281 "char" => Ok(ColumnType::Char),
282 "character" => Ok(ColumnType::Char),
283 "varchar" => Ok(ColumnType::Varchar),
284 "character varying" => Ok(ColumnType::Varchar),
285 "boolean" => Ok(ColumnType::Boolean),
286 "json" => Ok(ColumnType::Json),
287 "jsonb" => Ok(ColumnType::Jsonb),
288 "bytea" => Ok(ColumnType::Bytea),
289 "uuid" => Ok(ColumnType::Uuid),
290 "inet" => Ok(ColumnType::Inet),
291 "cidr" => Ok(ColumnType::Cidr),
292 "macaddr" => Ok(ColumnType::MacAddr),
293 "point" => Ok(ColumnType::Point),
294 "line" => Ok(ColumnType::Line),
295 "lseg" => Ok(ColumnType::Lseg),
296 "box" => Ok(ColumnType::Box),
297 "path" => Ok(ColumnType::Path),
298 "polygon" => Ok(ColumnType::Polygon),
299 "circle" => Ok(ColumnType::Circle),
300 "array" => Ok(ColumnType::Array),
301 "int4range" => Ok(ColumnType::Int4Range),
302 "int8range" => Ok(ColumnType::Int8Range),
303 "numrange" => Ok(ColumnType::NumRange),
304 "tsrange" => Ok(ColumnType::TsRange),
305 "tstzrange" => Ok(ColumnType::TstzRange),
306 "date" => Ok(ColumnType::Date),
307 "datetime" => Ok(ColumnType::Timestamp),
308 "time" => Ok(ColumnType::Time),
309 "timestamp" => Ok(ColumnType::Timestamp),
310 "timestamp with time zone" => Ok(ColumnType::TimestampTz),
311 "interval" => Ok(ColumnType::Interval),
312 "daterange" => Ok(ColumnType::DateRange),
313 "bit" => Ok(ColumnType::Bit),
314 "varbit" => Ok(ColumnType::Varbit),
315 "tsvector" => Ok(ColumnType::TsVector),
316 "tsquery" => Ok(ColumnType::TsQuery),
317 "xml" => Ok(ColumnType::Xml),
318 v => Ok(ColumnType::Other(v.to_string())),
319 }
320 }
321}
322
323impl From<String> for ColumnType {
324 fn from(s: String) -> Self {
325 ColumnType::from_str(&s).unwrap_or_else(|_| {
326 tracing::warn!("Unsupported database type string encountered: {}", s);
327 unreachable!("unsupported type: {}", s)
330 })
331 }
332}
333
334impl sqlx::Type<sqlx::Postgres> for ColumnType {
335 fn type_info() -> sqlx::postgres::PgTypeInfo {
336 sqlx::postgres::PgTypeInfo::with_name("TEXT") }
338}
339
340impl<'r> sqlx::Decode<'r, sqlx::Postgres> for ColumnType {
341 fn decode(
342 value: sqlx::postgres::PgValueRef<'r>,
343 ) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
344 let s = <String as sqlx::Decode<sqlx::Postgres>>::decode(value)?;
345 Ok(ColumnType::from_str(&s).map_err(|_| "Invalid ColumnType string")?)
346 }
347}
348
349impl PoolHandler for DbPool {
350 async fn try_new(db_config: &DatabaseConfig) -> Result<Self, AppError> {
351 match db_config.db_type {
352 DatabaseType::Postgres => {
353 let pool = PgPoolHandler::try_new(db_config).await?;
354 Ok(DbPool::Postgres(pool))
355 }
356 DatabaseType::Mysql => {
357 let pool = MySqlPoolHandler::try_new(db_config).await?;
358 Ok(DbPool::MySql(pool))
359 }
360 #[allow(unreachable_patterns)]
361 _ => Err(AppError::UnsupportedDatabaseType(
362 db_config.db_type.to_string(),
363 )),
364 }
365 }
366
367 async fn list_tables(&self) -> Result<Vec<TableInfo>, AppError> {
368 match self {
369 DbPool::Postgres(pg_pool) => pg_pool.list_tables().await,
370 DbPool::MySql(mysql_pool) => mysql_pool.list_tables().await,
371 }
372 }
373
374 async fn get_table_schema(&self, table_name: &str) -> Result<TableSchema, AppError> {
376 match self {
377 DbPool::Postgres(pg_pool) => pg_pool.get_table_schema(table_name).await,
378 DbPool::MySql(mysql_pool) => mysql_pool.get_table_schema(table_name).await,
379 }
380 }
381
382 async fn sanitize_query(&self, query: &str, limit: usize) -> Result<String, AppError> {
383 match self {
384 DbPool::Postgres(pg_pool) => pg_pool.sanitize_query(query, limit).await,
385 DbPool::MySql(mysql_pool) => mysql_pool.sanitize_query(query, limit).await,
386 }
387 }
388
389 async fn execute_query(
390 &self,
391 query: &str,
392 limit: Option<usize>,
393 ) -> Result<QueryResult, AppError> {
394 match self {
395 DbPool::Postgres(pg_pool) => pg_pool.execute_query(query, limit).await,
396 DbPool::MySql(mysql_pool) => mysql_pool.execute_query(query, limit).await,
397 }
398 }
399}