1pub use sql_composer;
7
8use sql_composer::composer::ComposedSql;
9
10#[derive(Debug, thiserror::Error)]
12pub enum Error {
13 #[error("composer error: {0}")]
15 Composer(#[from] sql_composer::Error),
16
17 #[error("sqlx error: {0}")]
19 Sqlx(#[from] sqlx::Error),
20
21 #[error("SQL syntax error: {0}")]
23 Syntax(String),
24}
25
26pub type Result<T> = std::result::Result<T, Error>;
28
/// Checks that every statement can be prepared by a live PostgreSQL server.
///
/// Connects to `database_url` through a pool capped at one connection, then
/// for each entry of `statements` issues
/// `PREPARE _sqlc_verify_<index> AS <sql>` followed by
/// `DEALLOCATE _sqlc_verify_<index>`, where `<index>` is the entry's position
/// in the slice.
///
/// # Errors
///
/// Returns [`Error::Sqlx`] when the connection cannot be established or when
/// any `PREPARE` / `DEALLOCATE` is rejected. Verification stops at the first
/// failure. The pool is closed before returning on both the success and the
/// failure path.
#[cfg(feature = "postgres")]
pub async fn verify_postgres(database_url: &str, statements: &[&ComposedSql]) -> Result<()> {
    use sqlx::postgres::PgPoolOptions;
    use sqlx::Executor;

    // NOTE(review): the single-connection cap presumably keeps each PREPARE and
    // its DEALLOCATE on the same session — confirm against sqlx pool behavior.
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect(database_url)
        .await?;

    // The checks run inside their own future so that an early `?` exit still
    // falls through to `pool.close()` below instead of skipping it.
    let outcome = async {
        for (i, stmt) in statements.iter().enumerate() {
            pool.execute(sqlx::query(&format!(
                "PREPARE _sqlc_verify_{i} AS {}",
                stmt.sql
            )))
            .await?;

            pool.execute(sqlx::query(&format!("DEALLOCATE _sqlc_verify_{i}")))
                .await?;
        }

        // Explicit types: `?` alone does not pin the block's error type.
        Ok::<(), Error>(())
    }
    .await;

    pool.close().await;
    outcome
}
58
/// Checks that `sql` parses under the `sqlparser` dialect matching `dialect`.
///
/// Bind placeholders are rewritten by `normalize_placeholders` before the
/// text is handed to the parser. The parsed statements themselves are
/// discarded; only success or failure is reported.
///
/// # Errors
///
/// Returns [`Error::Syntax`] containing the parser's message when parsing
/// fails.
#[cfg(feature = "validate")]
pub fn validate_syntax(sql: &str, dialect: sql_composer::Dialect) -> Result<()> {
    use sqlparser::dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect};
    use sqlparser::parser::Parser;

    let literal_sql = normalize_placeholders(sql);

    // Each arm parses with its own concrete dialect value.
    let parsed = match dialect {
        sql_composer::Dialect::Postgres => {
            Parser::parse_sql(&PostgreSqlDialect {}, &literal_sql)
        }
        sql_composer::Dialect::Mysql => Parser::parse_sql(&MySqlDialect {}, &literal_sql),
        sql_composer::Dialect::Sqlite => Parser::parse_sql(&SQLiteDialect {}, &literal_sql),
    };

    match parsed {
        Ok(_) => Ok(()),
        Err(cause) => Err(Error::Syntax(cause.to_string())),
    }
}
80
81#[cfg(feature = "validate")]
/// Rewrites bind placeholders in `sql` to the literal `1`.
///
/// Recognised forms, each replaced by a single `1`:
/// - `$` followed by one or more ASCII digits (`$1`, `$10`),
/// - `?` on its own or followed by ASCII digits (`?`, `?2`).
///
/// A `$` that is not followed by a digit is copied through unchanged, as is
/// every other character. The scan does not track quoting or comments, so
/// the substitution applies wherever these characters occur.
fn normalize_placeholders(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut rest = sql.chars().peekable();

    while let Some(current) = rest.next() {
        if !matches!(current, '$' | '?') {
            out.push(current);
            continue;
        }

        // Consume the numeric index (if any) that trails the marker.
        let mut saw_index = false;
        while rest.next_if(char::is_ascii_digit).is_some() {
            saw_index = true;
        }

        // `?` is a placeholder even when bare; `$` only counts with an index.
        let is_placeholder = current == '?' || saw_index;
        out.push(if is_placeholder { '1' } else { current });
    }

    out
}
111
#[cfg(test)]
mod tests {
    // Everything below exercises items that are only compiled with the
    // `validate` feature, so the whole group is gated on it.
    #[cfg(feature = "validate")]
    mod validate_tests {
        use crate::normalize_placeholders;
        use crate::validate_syntax;
        use sql_composer::Dialect;

        /// Asserts that each `(input, expected)` pair round-trips through
        /// `normalize_placeholders` as stated.
        fn assert_normalized(cases: &[(&str, &str)]) {
            for &(input, expected) in cases {
                assert_eq!(normalize_placeholders(input), expected);
            }
        }

        #[test]
        fn test_validate_syntax_postgres() {
            let outcome = validate_syntax("SELECT 1", Dialect::Postgres);
            outcome.unwrap();
        }

        #[test]
        fn test_validate_syntax_mysql() {
            let outcome = validate_syntax("SELECT 1", Dialect::Mysql);
            outcome.unwrap();
        }

        #[test]
        fn test_validate_syntax_sqlite() {
            let outcome = validate_syntax("SELECT 1", Dialect::Sqlite);
            outcome.unwrap();
        }

        #[test]
        fn test_validate_syntax_invalid() {
            // Misspelled keywords must be rejected.
            assert!(validate_syntax("SELECTT 1 FROMM", Dialect::Postgres).is_err());
        }

        #[test]
        fn test_validate_syntax_with_placeholders() {
            // A numbered placeholder must not trip the parser.
            let outcome =
                validate_syntax("SELECT * FROM users WHERE id = $1", Dialect::Postgres);
            outcome.unwrap();
        }

        #[test]
        fn test_normalize_placeholders_postgres() {
            assert_normalized(&[
                ("$1", "1"),
                ("$10", "1"),
                ("WHERE a = $1 AND b = $2", "WHERE a = 1 AND b = 1"),
            ]);
        }

        #[test]
        fn test_normalize_placeholders_mysql() {
            assert_normalized(&[
                ("?", "1"),
                ("WHERE a = ? AND b = ?", "WHERE a = 1 AND b = 1"),
            ]);
        }

        #[test]
        fn test_normalize_placeholders_sqlite() {
            assert_normalized(&[
                ("?1", "1"),
                ("WHERE a = ?1 AND b = ?2", "WHERE a = 1 AND b = 1"),
            ]);
        }

        #[test]
        fn test_normalize_preserves_dollar_without_digits() {
            // A lone `$` is not a placeholder and must survive untouched.
            assert_normalized(&[("$", "$")]);
        }
    }
}