Skip to main content

we_trust/
schema.rs

//! Schema reflection module.
//!
//! Reads the structure of existing tables from the database.
4
5use std::collections::HashMap;
6use crate::{DatabaseBackend, DatabaseConnection, DatabaseResult};
7
/// Storage class of a column, inferred from its backend-specific SQL type name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColumnType {
    /// Integer types (`INTEGER`, `BIGINT`, `SERIAL`, ...).
    Integer,
    /// Floating-point and fixed-point numeric types (`REAL`, `DOUBLE`, `DECIMAL`, ...).
    Real,
    /// Textual types; also the fallback for any unrecognised type name.
    Text,
    /// Binary types (`BLOB`, `BYTEA`, ...).
    Blob,
}

impl ColumnType {
    /// Parses a SQL type name as reported by a database catalog.
    ///
    /// Matching is case-insensitive and tolerant of the decorations catalogs
    /// commonly attach to a type name:
    /// - a length/precision list such as `VARCHAR(255)` or `DECIMAL(10, 2)`,
    /// - a trailing `UNSIGNED` modifier (`INT(11) UNSIGNED`),
    /// - leading/trailing or repeated inner whitespace (`double   precision`).
    ///
    /// Unrecognised names map to [`ColumnType::Text`].
    pub fn from_sql(sql: &str) -> Self {
        match Self::normalize_type_name(sql).as_str() {
            "INTEGER" | "INT" | "BIGINT" | "SMALLINT" | "TINYINT" | "MEDIUMINT" | "SERIAL"
            | "BIGSERIAL" | "SMALLSERIAL" | "INT2" | "INT4" | "INT8" => ColumnType::Integer,
            "REAL" | "FLOAT" | "DOUBLE" | "NUMERIC" | "DECIMAL" | "DOUBLE PRECISION" | "FLOAT4"
            | "FLOAT8" => ColumnType::Real,
            "TEXT" | "VARCHAR" | "CHAR" | "STRING" | "TEXT[]" => ColumnType::Text,
            "BLOB" | "BINARY" | "VARBINARY" | "BYTEA" | "TINYBLOB" | "MEDIUMBLOB" | "LONGBLOB" => {
                ColumnType::Blob
            }
            _ => ColumnType::Text,
        }
    }

    /// Reduces a type name to an upper-case canonical form: the parenthesised
    /// parameter list and a trailing `UNSIGNED` are dropped and whitespace is
    /// collapsed to single spaces.
    fn normalize_type_name(sql: &str) -> String {
        // `VARCHAR(255)` -> `VARCHAR`, `INT(11) UNSIGNED` -> `INT UNSIGNED`.
        let without_params = match (sql.find('('), sql.rfind(')')) {
            (Some(open), Some(close)) if open < close => {
                format!("{}{}", &sql[..open], &sql[close + 1..])
            }
            _ => sql.to_string(),
        };
        let upper = without_params.to_uppercase();
        let mut words: Vec<&str> = upper.split_whitespace().collect();
        if words.len() > 1 && words.last() == Some(&"UNSIGNED") {
            words.pop();
        }
        words.join(" ")
    }
}
33
34/// 列定义
35#[derive(Debug, Clone)]
36pub struct ColumnDef {
37    /// 列名
38    pub name: String,
39    /// 列类型
40    pub col_type: ColumnType,
41    /// 是否可空
42    pub nullable: bool,
43    /// 是否主键
44    pub primary_key: bool,
45    /// 是否自增
46    pub auto_increment: bool,
47    /// 默认值
48    pub default_value: Option<String>,
49    /// 是否唯一
50    pub unique: bool,
51}
52
/// Definition of an index on a table.
///
/// Derives `PartialEq`/`Eq` so reflected indexes can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IndexDef {
    /// Index name.
    pub name: String,
    /// Name of the table the index belongs to.
    pub table_name: String,
    /// Indexed column names, in index order.
    pub columns: Vec<String>,
    /// Whether the index enforces uniqueness.
    pub unique: bool,
}
65
/// Action a foreign key takes when the referenced row is updated or deleted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferentialAction {
    /// `NO ACTION` (also the fallback for unrecognised input).
    NoAction,
    /// `RESTRICT`.
    Restrict,
    /// `CASCADE`.
    Cascade,
    /// `SET NULL`.
    SetNull,
    /// `SET DEFAULT`.
    SetDefault,
}

impl ReferentialAction {
    /// Parses a referential action name as reported by a database catalog,
    /// e.g. `"CASCADE"`, `"SET NULL"` or `"NO ACTION"`.
    ///
    /// Matching is case-insensitive and ignores leading, trailing and repeated
    /// inner whitespace, so `" set  null "` parses as [`ReferentialAction::SetNull`].
    /// Anything unrecognised, including `"NO ACTION"` and the empty string,
    /// yields [`ReferentialAction::NoAction`].
    // Kept as an infallible inherent method (not a `std::str::FromStr` impl)
    // because it is public API and existing callers use it by this name.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        let normalized = s
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ")
            .to_uppercase();
        match normalized.as_str() {
            "CASCADE" => ReferentialAction::Cascade,
            "RESTRICT" => ReferentialAction::Restrict,
            "SET NULL" => ReferentialAction::SetNull,
            "SET DEFAULT" => ReferentialAction::SetDefault,
            _ => ReferentialAction::NoAction,
        }
    }
}
93
94/// 外键定义
95#[derive(Debug, Clone)]
96pub struct ForeignKeyDef {
97    /// 外键名称
98    pub name: String,
99    /// 本表列名
100    pub column: String,
101    /// 引用表名
102    pub ref_table: String,
103    /// 引用列名
104    pub ref_column: String,
105    /// 更新行为
106    pub on_update: ReferentialAction,
107    /// 删除行为
108    pub on_delete: ReferentialAction,
109}
110
111/// 表结构定义
112#[derive(Debug, Clone)]
113pub struct TableSchema {
114    /// 表名
115    pub name: String,
116    /// 列定义列表
117    pub columns: Vec<ColumnDef>,
118    /// 索引定义列表
119    pub indexes: Vec<IndexDef>,
120    /// 外键约束列表
121    pub foreign_keys: Vec<ForeignKeyDef>,
122}
123
124/// Schema 反射器
125pub struct SchemaReflector<'a> {
126    conn: &'a dyn DatabaseConnection,
127}
128
129impl<'a> SchemaReflector<'a> {
130    /// 创建新的反射器
131    pub fn new(conn: &'a dyn DatabaseConnection) -> Self {
132        Self { conn }
133    }
134
135    /// 获取所有表名
136    pub async fn get_table_names(&self) -> DatabaseResult<Vec<String>> {
137        match self.conn.backend() {
138            DatabaseBackend::Limbo => self.get_table_names_limbo().await,
139            DatabaseBackend::Postgres => self.get_table_names_postgres().await,
140            DatabaseBackend::MySql => self.get_table_names_mysql().await,
141        }
142    }
143
144    /// 获取表的完整结构
145    pub async fn get_table_schema(&self, table_name: &str) -> DatabaseResult<TableSchema> {
146        let columns = self.get_columns(table_name).await?;
147        let indexes = self.get_indexes(table_name).await?;
148        let foreign_keys = self.get_foreign_keys(table_name).await.unwrap_or_default();
149
150        Ok(TableSchema {
151            name: table_name.to_string(),
152            columns,
153            indexes,
154            foreign_keys,
155        })
156    }
157
158    /// 获取表的所有外键约束
159    pub async fn get_foreign_keys(&self, table_name: &str) -> DatabaseResult<Vec<ForeignKeyDef>> {
160        match self.conn.backend() {
161            DatabaseBackend::Limbo => self.get_foreign_keys_limbo(table_name).await,
162            DatabaseBackend::Postgres => self.get_foreign_keys_postgres(table_name).await,
163            DatabaseBackend::MySql => self.get_foreign_keys_mysql(table_name).await,
164        }
165    }
166
167    /// 获取表的所有列
168    pub async fn get_columns(&self, table_name: &str) -> DatabaseResult<Vec<ColumnDef>> {
169        match self.conn.backend() {
170            DatabaseBackend::Limbo => self.get_columns_limbo(table_name).await,
171            DatabaseBackend::Postgres => self.get_columns_postgres(table_name).await,
172            DatabaseBackend::MySql => self.get_columns_mysql(table_name).await,
173        }
174    }
175
176    /// 获取表的所有索引
177    pub async fn get_indexes(&self, table_name: &str) -> DatabaseResult<Vec<IndexDef>> {
178        match self.conn.backend() {
179            DatabaseBackend::Limbo => self.get_indexes_limbo(table_name).await,
180            DatabaseBackend::Postgres => self.get_indexes_postgres(table_name).await,
181            DatabaseBackend::MySql => self.get_indexes_mysql(table_name).await,
182        }
183    }
184
185    /// 获取所有表的 Schema
186    pub async fn get_all_schemas(&self) -> DatabaseResult<HashMap<String, TableSchema>> {
187        let table_names = self.get_table_names().await?;
188        let mut schemas = HashMap::new();
189
190        for name in table_names {
191            let schema = self.get_table_schema(&name).await?;
192            schemas.insert(name, schema);
193        }
194
195        Ok(schemas)
196    }
197
198    /// 检查表是否存在
199    pub async fn table_exists(&self, table_name: &str) -> DatabaseResult<bool> {
200        match self.conn.backend() {
201            DatabaseBackend::Limbo => self.table_exists_limbo(table_name).await,
202            DatabaseBackend::Postgres => self.table_exists_postgres(table_name).await,
203            DatabaseBackend::MySql => self.table_exists_mysql(table_name).await,
204        }
205    }
206
207    async fn get_table_names_limbo(&self) -> DatabaseResult<Vec<String>> {
208        let sql = "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name != '_migrations'";
209        let mut rows = self.conn.query(sql).await?;
210
211        let mut tables = Vec::new();
212        while let Some(row) = rows.next().await? {
213            tables.push(row.get_string(0)?);
214        }
215
216        Ok(tables)
217    }
218
219    async fn get_table_names_postgres(&self) -> DatabaseResult<Vec<String>> {
220        let sql = "SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename != '_migrations'";
221        let mut rows = self.conn.query(sql).await?;
222
223        let mut tables = Vec::new();
224        while let Some(row) = rows.next().await? {
225            tables.push(row.get_string(0)?);
226        }
227
228        Ok(tables)
229    }
230
231    async fn get_table_names_mysql(&self) -> DatabaseResult<Vec<String>> {
232        let sql = "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name != '_migrations'";
233        let mut rows = self.conn.query(sql).await?;
234
235        let mut tables = Vec::new();
236        while let Some(row) = rows.next().await? {
237            tables.push(row.get_string(0)?);
238        }
239
240        Ok(tables)
241    }
242
243    async fn get_columns_limbo(&self, table_name: &str) -> DatabaseResult<Vec<ColumnDef>> {
244        let sql = format!("PRAGMA table_info({})", table_name);
245        let mut rows = self.conn.query(&sql).await?;
246
247        let mut columns = Vec::new();
248        while let Some(row) = rows.next().await? {
249            let name = row.get_string(1)?;
250            let type_str = row.get_string(2)?;
251            let not_null = row.get_i64(3)? != 0;
252            let default_value = row.get_option_string(4)?;
253            let is_pk = row.get_i64(5)? != 0;
254
255            let col_type = ColumnType::from_sql(&type_str);
256
257            columns.push(ColumnDef {
258                name,
259                col_type,
260                nullable: !not_null,
261                primary_key: is_pk,
262                auto_increment: false,
263                default_value,
264                unique: false,
265            });
266        }
267
268        Ok(columns)
269    }
270
271    async fn get_columns_postgres(&self, table_name: &str) -> DatabaseResult<Vec<ColumnDef>> {
272        let sql = format!(
273            "SELECT column_name, data_type, is_nullable, column_default, \
274             EXISTS (SELECT 1 FROM information_schema.key_column_usage k \
275             JOIN information_schema.table_constraints t ON k.constraint_name = t.constraint_name \
276             WHERE k.table_name = '{}' AND k.column_name = c.column_name AND t.constraint_type = 'PRIMARY KEY') AS is_pk, \
277             EXISTS (SELECT 1 FROM information_schema.key_column_usage k \
278             JOIN information_schema.table_constraints t ON k.constraint_name = t.constraint_name \
279             WHERE k.table_name = '{}' AND k.column_name = c.column_name AND t.constraint_type = 'UNIQUE') AS is_unique \
280             FROM information_schema.columns c \
281             WHERE table_name = '{}'",
282            table_name, table_name, table_name
283        );
284        let mut rows = self.conn.query(&sql).await?;
285
286        let mut columns = Vec::new();
287        while let Some(row) = rows.next().await? {
288            let name = row.get_string(0)?;
289            let type_str = row.get_string(1)?;
290            let is_nullable = row.get_string(2)? == "YES";
291            let default_value = row.get_option_string(3)?;
292            let is_pk = row.get_bool(4)?;
293            let is_unique = row.get_bool(5)?;
294
295            let col_type = ColumnType::from_sql(&type_str);
296
297            columns.push(ColumnDef {
298                name,
299                col_type,
300                nullable: is_nullable,
301                primary_key: is_pk,
302                auto_increment: false,
303                default_value,
304                unique: is_unique,
305            });
306        }
307
308        Ok(columns)
309    }
310
311    async fn get_columns_mysql(&self, table_name: &str) -> DatabaseResult<Vec<ColumnDef>> {
312        let sql = format!(
313            "SELECT column_name, data_type, is_nullable, column_default, column_key = 'PRI' AS is_pk, \
314             column_key = 'UNI' AS is_unique, extra LIKE '%auto_increment%' AS is_auto_inc \
315             FROM information_schema.columns WHERE table_name = '{}' AND table_schema = DATABASE()",
316            table_name
317        );
318        let mut rows = self.conn.query(&sql).await?;
319
320        let mut columns = Vec::new();
321        while let Some(row) = rows.next().await? {
322            let name = row.get_string(0)?;
323            let type_str = row.get_string(1)?;
324            let is_nullable = row.get_string(2)? == "YES";
325            let default_value = row.get_option_string(3)?;
326            let is_pk = row.get_bool(4)?;
327            let is_unique = row.get_bool(5)?;
328            let is_auto_inc = row.get_bool(6)?;
329
330            let col_type = ColumnType::from_sql(&type_str);
331
332            columns.push(ColumnDef {
333                name,
334                col_type,
335                nullable: is_nullable,
336                primary_key: is_pk,
337                auto_increment: is_auto_inc,
338                default_value,
339                unique: is_unique,
340            });
341        }
342
343        Ok(columns)
344    }
345
346    async fn get_indexes_limbo(&self, table_name: &str) -> DatabaseResult<Vec<IndexDef>> {
347        let sql = format!("PRAGMA index_list({})", table_name);
348        let mut rows = self.conn.query(&sql).await?;
349
350        let mut indexes = Vec::new();
351        while let Some(row) = rows.next().await? {
352            let index_name = row.get_string(1)?;
353            let unique = row.get_i64(2)? != 0;
354
355            let columns = self.get_index_columns_limbo(&index_name).await?;
356
357            indexes.push(IndexDef {
358                name: index_name,
359                table_name: table_name.to_string(),
360                columns,
361                unique,
362            });
363        }
364
365        Ok(indexes)
366    }
367
368    async fn get_indexes_postgres(&self, table_name: &str) -> DatabaseResult<Vec<IndexDef>> {
369        let sql = format!(
370            "SELECT i.relname AS index_name, ix.indisunique AS is_unique, \
371             array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) AS columns \
372             FROM pg_index ix \
373             JOIN pg_class i ON i.oid = ix.indexrelid \
374             JOIN pg_class t ON t.oid = ix.indrelid \
375             JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey) \
376             WHERE t.relname = '{}' \
377             GROUP BY i.relname, ix.indisunique",
378            table_name
379        );
380        let mut rows = self.conn.query(&sql).await?;
381
382        let mut indexes = Vec::new();
383        while let Some(row) = rows.next().await? {
384            let index_name = row.get_string(0)?;
385            let unique = row.get_bool(1)?;
386            let columns_str = row.get_string(2)?;
387            let columns: Vec<String> = columns_str
388                .trim_matches(|c| c == '{' || c == '}')
389                .split(',')
390                .map(|s| s.trim().trim_matches('"').to_string())
391                .collect();
392
393            indexes.push(IndexDef {
394                name: index_name,
395                table_name: table_name.to_string(),
396                columns,
397                unique,
398            });
399        }
400
401        Ok(indexes)
402    }
403
404    async fn get_indexes_mysql(&self, table_name: &str) -> DatabaseResult<Vec<IndexDef>> {
405        let sql = format!(
406            "SELECT index_name, non_unique = 0 AS is_unique, \
407             GROUP_CONCAT(column_name ORDER BY seq_in_index SEPARATOR ',') AS columns \
408             FROM information_schema.statistics \
409             WHERE table_name = '{}' AND table_schema = DATABASE() \
410             GROUP BY index_name, non_unique",
411            table_name
412        );
413        let mut rows = self.conn.query(&sql).await?;
414
415        let mut indexes = Vec::new();
416        while let Some(row) = rows.next().await? {
417            let index_name = row.get_string(0)?;
418            let unique = row.get_bool(1)?;
419            let columns_str = row.get_string(2)?;
420            let columns: Vec<String> = columns_str.split(',').map(|s| s.trim().to_string()).collect();
421
422            indexes.push(IndexDef {
423                name: index_name,
424                table_name: table_name.to_string(),
425                columns,
426                unique,
427            });
428        }
429
430        Ok(indexes)
431    }
432
433    async fn get_foreign_keys_limbo(&self, table_name: &str) -> DatabaseResult<Vec<ForeignKeyDef>> {
434        let sql = format!("PRAGMA foreign_key_list({})", table_name);
435        let mut rows = self.conn.query(&sql).await?;
436
437        let mut foreign_keys = Vec::new();
438        while let Some(row) = rows.next().await? {
439            let id = row.get_i64(0)?;
440            let seq = row.get_i64(1)?;
441            if seq != 0 {
442                continue;
443            }
444            let ref_table = row.get_string(2)?;
445            let column = row.get_string(3)?;
446            let ref_column = row.get_string(4)?;
447            let on_update_str = row.get_string(5)?;
448            let on_delete_str = row.get_string(6)?;
449
450            let on_update = ReferentialAction::from_str(&on_update_str);
451            let on_delete = ReferentialAction::from_str(&on_delete_str);
452
453            let fk_name = format!("fk_{}_{}", table_name, column);
454
455            foreign_keys.push(ForeignKeyDef {
456                name: fk_name,
457                column,
458                ref_table,
459                ref_column,
460                on_update,
461                on_delete,
462            });
463        }
464
465        Ok(foreign_keys)
466    }
467
468    async fn get_foreign_keys_postgres(&self, table_name: &str) -> DatabaseResult<Vec<ForeignKeyDef>> {
469        let sql = format!(
470            "SELECT tc.constraint_name, kcu.column_name, ccu.table_name AS foreign_table_name, \
471             ccu.column_name AS foreign_column_name, rc.update_rule, rc.delete_rule \
472             FROM information_schema.table_constraints tc \
473             JOIN information_schema.key_column_usage kcu \
474             ON tc.constraint_name = kcu.constraint_name \
475             JOIN information_schema.constraint_column_usage ccu \
476             ON ccu.constraint_name = tc.constraint_name \
477             JOIN information_schema.referential_constraints rc \
478             ON tc.constraint_name = rc.constraint_name \
479             WHERE tc.table_name = '{}' AND tc.constraint_type = 'FOREIGN KEY'",
480            table_name
481        );
482        let mut rows = self.conn.query(&sql).await?;
483
484        let mut foreign_keys = Vec::new();
485        while let Some(row) = rows.next().await? {
486            let name = row.get_string(0)?;
487            let column = row.get_string(1)?;
488            let ref_table = row.get_string(2)?;
489            let ref_column = row.get_string(3)?;
490            let on_update_str = row.get_string(4)?;
491            let on_delete_str = row.get_string(5)?;
492
493            let on_update = ReferentialAction::from_str(&on_update_str);
494            let on_delete = ReferentialAction::from_str(&on_delete_str);
495
496            foreign_keys.push(ForeignKeyDef {
497                name,
498                column,
499                ref_table,
500                ref_column,
501                on_update,
502                on_delete,
503            });
504        }
505
506        Ok(foreign_keys)
507    }
508
509    async fn get_foreign_keys_mysql(&self, table_name: &str) -> DatabaseResult<Vec<ForeignKeyDef>> {
510        let sql = format!(
511            "SELECT kcu.constraint_name, kcu.column_name, kcu.referenced_table_name, kcu.referenced_column_name, \
512             rc.update_rule, rc.delete_rule \
513             FROM information_schema.key_column_usage kcu \
514             JOIN information_schema.referential_constraints rc \
515             ON kcu.constraint_name = rc.constraint_name \
516             WHERE kcu.table_name = '{}' AND kcu.table_schema = DATABASE() AND kcu.referenced_table_name IS NOT NULL",
517            table_name
518        );
519        let mut rows = self.conn.query(&sql).await?;
520
521        let mut foreign_keys = Vec::new();
522        while let Some(row) = rows.next().await? {
523            let name = row.get_string(0)?;
524            let column = row.get_string(1)?;
525            let ref_table = row.get_string(2)?;
526            let ref_column = row.get_string(3)?;
527            let on_update_str = row.get_string(4)?;
528            let on_delete_str = row.get_string(5)?;
529
530            let on_update = ReferentialAction::from_str(&on_update_str);
531            let on_delete = ReferentialAction::from_str(&on_delete_str);
532
533            foreign_keys.push(ForeignKeyDef {
534                name,
535                column,
536                ref_table,
537                ref_column,
538                on_update,
539                on_delete,
540            });
541        }
542
543        Ok(foreign_keys)
544    }
545
546    async fn table_exists_limbo(&self, table_name: &str) -> DatabaseResult<bool> {
547        let sql = format!("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='{}'", table_name);
548        let mut rows = self.conn.query(&sql).await?;
549
550        if let Some(row) = rows.next().await? {
551            let count = row.get_i64(0)?;
552            Ok(count > 0)
553        }
554        else {
555            Ok(false)
556        }
557    }
558
559    async fn table_exists_postgres(&self, table_name: &str) -> DatabaseResult<bool> {
560        let sql = format!("SELECT COUNT(*) FROM pg_tables WHERE schemaname = 'public' AND tablename = '{}'", table_name);
561        let mut rows = self.conn.query(&sql).await?;
562
563        if let Some(row) = rows.next().await? {
564            let count = row.get_i64(0)?;
565            Ok(count > 0)
566        }
567        else {
568            Ok(false)
569        }
570    }
571
572    async fn table_exists_mysql(&self, table_name: &str) -> DatabaseResult<bool> {
573        let sql = format!(
574            "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '{}'",
575            table_name
576        );
577        let mut rows = self.conn.query(&sql).await?;
578
579        if let Some(row) = rows.next().await? {
580            let count = row.get_i64(0)?;
581            Ok(count > 0)
582        }
583        else {
584            Ok(false)
585        }
586    }
587
588    async fn get_index_columns_limbo(&self, index_name: &str) -> DatabaseResult<Vec<String>> {
589        let sql = format!("PRAGMA index_info({})", index_name);
590        let mut rows = self.conn.query(&sql).await?;
591
592        let mut columns = Vec::new();
593        while let Some(row) = rows.next().await? {
594            columns.push(row.get_string(2)?);
595        }
596
597        Ok(columns)
598    }
599}