qail_core/transformer/
registry.rs1use sqlparser::ast::Statement;
4use sqlparser::dialect::PostgreSqlDialect;
5use sqlparser::parser::Parser;
6
7use super::traits::*;
8use super::patterns::*;
9
10pub struct PatternRegistry {
12 patterns: Vec<Box<dyn SqlPattern>>,
13}
14
15impl Default for PatternRegistry {
16 fn default() -> Self {
17 Self::new()
18 }
19}
20
21impl PatternRegistry {
22 pub fn new() -> Self {
24 let mut registry = Self {
25 patterns: Vec::new(),
26 };
27
28 registry.register(Box::new(SelectPattern));
30 registry.register(Box::new(InsertPattern));
31 registry.register(Box::new(UpdatePattern));
32 registry.register(Box::new(DeletePattern));
33
34 registry
35 }
36
37 pub fn register(&mut self, pattern: Box<dyn SqlPattern>) {
39 self.patterns.push(pattern);
40 self.patterns.sort_by_key(|p| std::cmp::Reverse(p.priority()));
42 }
43
44 pub fn find_pattern(&self, stmt: &Statement, ctx: &MatchContext) -> Option<&dyn SqlPattern> {
46 for pattern in &self.patterns {
47 if pattern.matches(stmt, ctx) {
48 return Some(pattern.as_ref());
49 }
50 }
51 None
52 }
53
54 pub fn transform_sql(&self, sql: &str, ctx: &TransformContext) -> Result<String, String> {
56 let dialect = PostgreSqlDialect {};
57 let ast = Parser::parse_sql(&dialect, sql)
58 .map_err(|e| format!("Parse error: {}", e))?;
59
60 if ast.is_empty() {
61 return Err("Empty SQL".to_string());
62 }
63
64 let stmt = &ast[0];
65 let match_ctx = MatchContext::default();
66
67 let pattern = self
68 .find_pattern(stmt, &match_ctx)
69 .ok_or_else(|| "No matching pattern found".to_string())?;
70
71 let data = pattern
72 .extract(stmt, &match_ctx)
73 .map_err(|e| e.to_string())?;
74
75 pattern
76 .transform(&data, ctx)
77 .map_err(|e| e.to_string())
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_select() {
        let registry = PatternRegistry::new();
        let ctx = TransformContext {
            include_imports: true,
            ..Default::default()
        };

        // Surface the actual error text on failure instead of a bare
        // `assert!(result.is_ok())`.
        let code = registry
            .transform_sql("SELECT id, name FROM users WHERE id = $1", &ctx)
            .unwrap_or_else(|e| panic!("transform_sql failed: {}", e));

        assert!(code.contains("Qail::get"), "unexpected output: {}", code);
        assert!(code.contains("users"), "unexpected output: {}", code);
    }

    #[test]
    fn test_registry_empty_sql() {
        let registry = PatternRegistry::new();
        let ctx = TransformContext::default();

        assert_eq!(
            registry.transform_sql("", &ctx),
            Err("Empty SQL".to_string())
        );
    }

    #[test]
    fn test_registry_parse_error() {
        let registry = PatternRegistry::new();
        let ctx = TransformContext::default();

        let err = registry
            .transform_sql("SELEC id FROM users", &ctx)
            .unwrap_err();
        assert!(err.starts_with("Parse error:"), "unexpected error: {}", err);
    }
}