r2_data2/db/
mod.rs

1mod mysql;
2mod pg;
3
4use crate::{config::DatabaseConfig, error::AppError};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use sqlparser::{ast, dialect::GenericDialect, parser::Parser};
8use sqlx::{MySqlPool, PgPool};
9use std::{cmp::min, convert::Infallible, str::FromStr, time::Duration};
10
/// Default row limit for queries. Not referenced in this module; presumably used by the
/// `pg`/`mysql` handlers when no limit is requested — confirm there.
const DEFAULT_LIMIT: usize = 500;
/// Upper bound applied to `LIMIT` values in `PoolHandler::sanitize_query`.
const MAX_LIMIT: usize = 5_000;
13
14#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
15#[non_exhaustive]
16#[serde(rename_all = "lowercase")]
17pub enum DatabaseType {
18    Postgres,
19    Mysql,
20}
21
22#[derive(Debug)]
23pub struct PgPoolHandler(PgPool);
24
25#[derive(Debug)]
26pub struct MySqlPoolHandler(MySqlPool);
27
28#[derive(Debug)]
29pub enum DbPool {
30    Postgres(PgPoolHandler),
31    MySql(MySqlPoolHandler),
32    // Add other pool types here if needed
33}
34
35pub trait PoolHandler: Sized {
36    /// Create a new pool handler
37    async fn try_new(db_config: &DatabaseConfig) -> Result<Self, AppError>;
38    /// List all tables in the database
39    async fn list_tables(&self) -> Result<Vec<TableInfo>, AppError>;
40    /// Get the schema of a table
41    async fn get_table_schema(&self, table_name: &str) -> Result<TableSchema, AppError>;
42    /// Sanitize the query and rewrite it to CTE format
43    async fn sanitize_query(&self, query: &str, limit: usize) -> Result<String, AppError> {
44        let dialect = GenericDialect {};
45        let ast = Parser::parse_sql(&dialect, query)
46            .map_err(|e| AppError::BadRequest(format!("SQL parsing error: {}", e)))?;
47        if ast.len() != 1 {
48            return Err(AppError::BadRequest(
49                "Only single SQL statements are allowed".to_string(),
50            ));
51        }
52
53        let mut stmt = ast.into_iter().next().unwrap();
54
55        let has_limit = match stmt {
56            ast::Statement::Query(ref mut query) => {
57                // Check query type
58                match &*query.body {
59                    ast::SetExpr::Select(_)
60                    | ast::SetExpr::Values(_)
61                    | ast::SetExpr::Query(_)
62                    | ast::SetExpr::Table(_) => {
63                        // Valid query type
64                    }
65                    _ => {
66                        return Err(AppError::BadRequest(
67                            "Only SELECT-like queries are allowed.".to_string(),
68                        ));
69                    }
70                }
71
72                match &mut query.limit {
73                    Some(ast::Expr::Value(ast::ValueWithSpan {
74                        value: ast::Value::Number(s, _),
75                        ..
76                    })) => {
77                        let existing_limit = s.parse::<usize>().unwrap_or(0);
78                        if existing_limit < limit {
79                            // do nothing
80                        } else {
81                            *s = min(existing_limit, MAX_LIMIT).to_string();
82                        }
83                        true
84                    }
85                    _ => false,
86                }
87            }
88            _ => {
89                return Err(AppError::BadRequest(
90                    "Only SELECT queries are allowed".to_string(),
91                ));
92            }
93        };
94        let mut sql = stmt.to_string();
95        if !has_limit {
96            sql = format!("{} LIMIT {}", sql, limit);
97        }
98        Ok(sql)
99    }
100
101    /// Execute the query and return the result along with execution time
102    async fn execute_query(
103        &self,
104        query: &str,
105        limit: Option<usize>,
106    ) -> Result<QueryResult, AppError>;
107}
108
109// Response structure for the /api/databases endpoint
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct DatabaseInfo {
112    pub name: String,
113    #[serde(rename = "type")]
114    pub db_type: String, // Use String representation for JSON response
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
118#[serde(rename_all = "snake_case")]
119pub enum TableType {
120    Table,
121    View,
122    MaterializedView,
123}
124// Response structure for the /api/databases/{dbName}/tables endpoint
125#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] // Derive FromRow for sqlx query mapping
126pub struct TableInfo {
127    pub name: String,
128    #[sqlx(rename = "type", try_from = "String")]
129    #[serde(rename = "type")]
130    pub table_type: TableType, // e.g., "BASE TABLE", "VIEW"
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
134pub enum ColumnType {
135    // Numeric types
136    SmallInt,
137    Integer,
138    BigInt,
139    Decimal,
140    Numeric,
141    Real,
142    DoublePrecision,
143    // Character types
144    Char,
145    Varchar,
146    Text,
147    // Binary types
148    Bytea,
149    // Boolean
150    Boolean,
151    // Date/Time types
152    Date,
153    Time,
154    Timestamp,
155    TimestampTz,
156    Interval,
157    // JSON types
158    Json,
159    Jsonb,
160    // Network types
161    Inet,
162    Cidr,
163    MacAddr,
164    // UUID
165    Uuid,
166    // Geometric types
167    Point,
168    Line,
169    Lseg,
170    Box,
171    Path,
172    Polygon,
173    Circle,
174    // Array types
175    Array,
176    // Range types
177    Int4Range,
178    Int8Range,
179    NumRange,
180    TsRange,
181    TstzRange,
182    DateRange,
183    // Bit string types
184    Bit,
185    Varbit,
186    // Text search types
187    TsVector,
188    TsQuery,
189    // XML
190    Xml,
191    // Money
192    Money,
193    // Other
194    Other(String),
195}
196
197// Structures for /api/.../schema endpoint
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct ColumnInfo {
200    pub name: String,
201    pub data_type: ColumnType,
202    pub is_nullable: bool,
203    // Add constraint fields
204    #[serde(default)]
205    pub is_pk: bool,
206    #[serde(default)]
207    pub is_unique: bool,
208    #[serde(skip_serializing_if = "Option::is_none")]
209    pub fk_table: Option<String>,
210    #[serde(skip_serializing_if = "Option::is_none")]
211    pub fk_column: Option<String>,
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct TableSchema {
216    pub table_name: String,
217    pub columns: Vec<ColumnInfo>,
218    // Optional: Add constraints, indexes later if needed
219    // pub constraints: Option<Vec<ConstraintInfo>>,
220    // pub indexes: Option<Vec<IndexInfo>>,
221}
222
223// Struct to hold the query result and execution time
224#[derive(Debug, Serialize)]
225pub struct QueryResult {
226    pub data: Value,
227    pub execution_time: Duration,
228    #[serde(skip_serializing_if = "Option::is_none")]
229    pub plan: Option<Value>,
230}
231
232#[derive(sqlx::FromRow)]
233pub struct JsonResult {
234    pub data: Value,
235}
236
237impl FromStr for TableType {
238    type Err = Infallible;
239
240    fn from_str(s: &str) -> Result<Self, Self::Err> {
241        match s {
242            "table" => Ok(TableType::Table),
243            "view" => Ok(TableType::View),
244            "materialized_view" => Ok(TableType::MaterializedView),
245            _ => unreachable!(),
246        }
247    }
248}
249
250impl From<String> for TableType {
251    fn from(s: String) -> Self {
252        TableType::from_str(&s).unwrap()
253    }
254}
255
256impl TableType {
257    pub fn as_str(&self) -> &str {
258        match self {
259            TableType::Table => "table",
260            TableType::View => "view",
261            TableType::MaterializedView => "materialized_view",
262        }
263    }
264}
265
266impl FromStr for ColumnType {
267    type Err = Infallible;
268
269    // TODO: verify mysql types
270    fn from_str(s: &str) -> Result<Self, Self::Err> {
271        match s {
272            "smallint" => Ok(ColumnType::SmallInt),
273            "integer" => Ok(ColumnType::Integer),
274            "bigint" => Ok(ColumnType::BigInt),
275            "decimal" => Ok(ColumnType::Decimal),
276            "numeric" => Ok(ColumnType::Numeric),
277            "real" => Ok(ColumnType::Real),
278            "double precision" => Ok(ColumnType::DoublePrecision),
279            "money" => Ok(ColumnType::Money),
280            "text" => Ok(ColumnType::Text),
281            "char" => Ok(ColumnType::Char),
282            "character" => Ok(ColumnType::Char),
283            "varchar" => Ok(ColumnType::Varchar),
284            "character varying" => Ok(ColumnType::Varchar),
285            "boolean" => Ok(ColumnType::Boolean),
286            "json" => Ok(ColumnType::Json),
287            "jsonb" => Ok(ColumnType::Jsonb),
288            "bytea" => Ok(ColumnType::Bytea),
289            "uuid" => Ok(ColumnType::Uuid),
290            "inet" => Ok(ColumnType::Inet),
291            "cidr" => Ok(ColumnType::Cidr),
292            "macaddr" => Ok(ColumnType::MacAddr),
293            "point" => Ok(ColumnType::Point),
294            "line" => Ok(ColumnType::Line),
295            "lseg" => Ok(ColumnType::Lseg),
296            "box" => Ok(ColumnType::Box),
297            "path" => Ok(ColumnType::Path),
298            "polygon" => Ok(ColumnType::Polygon),
299            "circle" => Ok(ColumnType::Circle),
300            "array" => Ok(ColumnType::Array),
301            "int4range" => Ok(ColumnType::Int4Range),
302            "int8range" => Ok(ColumnType::Int8Range),
303            "numrange" => Ok(ColumnType::NumRange),
304            "tsrange" => Ok(ColumnType::TsRange),
305            "tstzrange" => Ok(ColumnType::TstzRange),
306            "date" => Ok(ColumnType::Date),
307            "datetime" => Ok(ColumnType::Timestamp),
308            "time" => Ok(ColumnType::Time),
309            "timestamp" => Ok(ColumnType::Timestamp),
310            "timestamp with time zone" => Ok(ColumnType::TimestampTz),
311            "interval" => Ok(ColumnType::Interval),
312            "daterange" => Ok(ColumnType::DateRange),
313            "bit" => Ok(ColumnType::Bit),
314            "varbit" => Ok(ColumnType::Varbit),
315            "tsvector" => Ok(ColumnType::TsVector),
316            "tsquery" => Ok(ColumnType::TsQuery),
317            "xml" => Ok(ColumnType::Xml),
318            v => Ok(ColumnType::Other(v.to_string())),
319        }
320    }
321}
322
323impl From<String> for ColumnType {
324    fn from(s: String) -> Self {
325        ColumnType::from_str(&s).unwrap_or_else(|_| {
326            tracing::warn!("Unsupported database type string encountered: {}", s);
327            // Decide on a default/fallback type if needed, or keep panicking/unreachable
328            // For now, stick with unreachable! based on previous code
329            unreachable!("unsupported type: {}", s)
330        })
331    }
332}
333
334impl sqlx::Type<sqlx::Postgres> for ColumnType {
335    fn type_info() -> sqlx::postgres::PgTypeInfo {
336        sqlx::postgres::PgTypeInfo::with_name("TEXT") // Treat as TEXT for SQLx binding/fetching
337    }
338}
339
340impl<'r> sqlx::Decode<'r, sqlx::Postgres> for ColumnType {
341    fn decode(
342        value: sqlx::postgres::PgValueRef<'r>,
343    ) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
344        let s = <String as sqlx::Decode<sqlx::Postgres>>::decode(value)?;
345        Ok(ColumnType::from_str(&s).map_err(|_| "Invalid ColumnType string")?)
346    }
347}
348
349impl PoolHandler for DbPool {
350    async fn try_new(db_config: &DatabaseConfig) -> Result<Self, AppError> {
351        match db_config.db_type {
352            DatabaseType::Postgres => {
353                let pool = PgPoolHandler::try_new(db_config).await?;
354                Ok(DbPool::Postgres(pool))
355            }
356            DatabaseType::Mysql => {
357                let pool = MySqlPoolHandler::try_new(db_config).await?;
358                Ok(DbPool::MySql(pool))
359            }
360            #[allow(unreachable_patterns)]
361            _ => Err(AppError::UnsupportedDatabaseType(
362                db_config.db_type.to_string(),
363            )),
364        }
365    }
366
367    async fn list_tables(&self) -> Result<Vec<TableInfo>, AppError> {
368        match self {
369            DbPool::Postgres(pg_pool) => pg_pool.list_tables().await,
370            DbPool::MySql(mysql_pool) => mysql_pool.list_tables().await,
371        }
372    }
373
374    // Add method signature for getting table schema
375    async fn get_table_schema(&self, table_name: &str) -> Result<TableSchema, AppError> {
376        match self {
377            DbPool::Postgres(pg_pool) => pg_pool.get_table_schema(table_name).await,
378            DbPool::MySql(mysql_pool) => mysql_pool.get_table_schema(table_name).await,
379        }
380    }
381
382    async fn sanitize_query(&self, query: &str, limit: usize) -> Result<String, AppError> {
383        match self {
384            DbPool::Postgres(pg_pool) => pg_pool.sanitize_query(query, limit).await,
385            DbPool::MySql(mysql_pool) => mysql_pool.sanitize_query(query, limit).await,
386        }
387    }
388
389    async fn execute_query(
390        &self,
391        query: &str,
392        limit: Option<usize>,
393    ) -> Result<QueryResult, AppError> {
394        match self {
395            DbPool::Postgres(pg_pool) => pg_pool.execute_query(query, limit).await,
396            DbPool::MySql(mysql_pool) => mysql_pool.execute_query(query, limit).await,
397        }
398    }
399}