Skip to main content

qail_core/transpiler/
mod.rs

1//! SQL Transpiler for QAIL AST.
2//!
3
4/// Condition-to-SQL conversion.
5pub mod conditions;
6/// DDL statement transpilation (CREATE TABLE, ALTER TABLE, etc.).
7pub mod ddl;
8/// SQL dialect selection (PostgreSQL primary; SQLite compatibility retained).
9pub mod dialect;
10/// DML statement transpilation (INSERT, UPDATE, DELETE).
11pub mod dml;
12pub(crate) mod identifier;
13/// RLS policy transpilation (CREATE POLICY).
14pub mod policy;
15/// Core SQL generation utilities.
16pub mod sql;
17/// Transpiler traits (SqlGenerator, escape_identifier).
18pub mod traits;
19
20/// NoSQL/vector transpilers.
21pub mod nosql;
22pub use nosql::dynamo::ToDynamo;
23pub use nosql::mongo::ToMongo;
24pub use nosql::qdrant::ToQdrant;
25
26#[cfg(test)]
27mod tests;
28
29use crate::ast::*;
30pub use conditions::ConditionToSql;
31pub use dialect::Dialect;
32pub use traits::SqlGenerator;
33pub use traits::{escape_identifier, escape_sql_string_literal};
34
35/// Result of transpilation with extracted parameters.
36#[derive(Debug, Clone, PartialEq, Default)]
37pub struct TranspileResult {
38    /// The SQL template with placeholders (e.g., $1, $2 or ?, ?)
39    pub sql: String,
40    /// The extracted parameter values in order
41    pub params: Vec<Value>,
42    /// Names of named parameters in order they appear (for :name → $n mapping)
43    pub named_params: Vec<String>,
44}
45
46impl TranspileResult {
47    /// Create a new TranspileResult.
48    pub fn new(sql: impl Into<String>, params: Vec<Value>) -> Self {
49        Self {
50            sql: sql.into(),
51            params,
52            named_params: vec![],
53        }
54    }
55
56    /// Create a result with no parameters.
57    pub fn sql_only(sql: impl Into<String>) -> Self {
58        Self {
59            sql: sql.into(),
60            params: Vec::new(),
61            named_params: Vec::new(),
62        }
63    }
64}
65
66/// Trait for converting AST nodes to parameterized SQL.
67pub trait ToSqlParameterized {
68    /// Convert to SQL with extracted parameters (default dialect).
69    fn to_sql_parameterized(&self) -> TranspileResult {
70        self.to_sql_parameterized_with_dialect(Dialect::default())
71    }
72    /// Convert to SQL with extracted parameters for specific dialect.
73    fn to_sql_parameterized_with_dialect(&self, dialect: Dialect) -> TranspileResult;
74}
75
76/// Trait for converting AST nodes to SQL.
77pub trait ToSql {
78    /// Convert this node to a SQL string using default dialect.
79    fn to_sql(&self) -> String {
80        self.to_sql_with_dialect(Dialect::default())
81    }
82    /// Convert this node to a SQL string with specific dialect.
83    fn to_sql_with_dialect(&self, dialect: Dialect) -> String;
84}
85
impl ToSql for Qail {
    /// Renders this command as a single SQL statement string.
    ///
    /// Most actions delegate to the `dml` / `ddl` / `policy` builders. The rest
    /// are formatted inline in this match, where names are quoted with
    /// `escape_identifier` (or `quote_single_identifier` for LISTEN/NOTIFY
    /// channels and savepoint names) unless noted otherwise.
    ///
    /// The return type is a plain `String`, not a `Result`. Actions with no SQL
    /// form, and commands missing a required part, therefore render as a
    /// `-- ...` comment or an `/* ERROR: ... */` marker instead of a statement.
    fn to_sql_with_dialect(&self, dialect: Dialect) -> String {
        match self.action {
            Action::Get => dml::select::build_select(self, dialect),
            Action::Cnt => {
                // Build a count query: SELECT COUNT(*) FROM table WHERE ...
                // Reuses the SELECT builder on a clone whose projection is
                // replaced by a single COUNT(*) aggregate.
                let mut count_ast = self.clone();
                count_ast.action = Action::Get;
                count_ast.columns = vec![Expr::Aggregate {
                    col: "*".to_string(),
                    func: AggregateFunc::Count,
                    distinct: false,
                    filter: None,
                    alias: None,
                }];
                dml::select::build_select(&count_ast, dialect)
            }
            Action::Set => dml::update::build_update(self, dialect),
            Action::Del => dml::delete::build_delete(self, dialect),
            Action::Add => dml::insert::build_insert(self, dialect),
            Action::Merge => dml::merge::build_merge(self, dialect),
            // NOTE(review): `self.table` is interpolated unescaped here; a newline
            // in it would end the `--` comment if this text were ever executed.
            Action::Gen => format!("-- gen::{}  (generates Rust struct, not SQL)", self.table),
            Action::Make => ddl::build_create_table(self, dialect),
            Action::Mod => ddl::build_alter_table(self, dialect),
            Action::Over => dml::window::build_window(self, dialect),
            Action::With => dml::cte::build_cte(self, dialect),
            Action::Index => ddl::build_create_index(self, dialect),
            Action::DropIndex => format!("DROP INDEX IF EXISTS {}", escape_identifier(&self.table)),
            Action::Alter => ddl::build_alter_add_column(self, dialect),
            Action::AlterAddConstraint => ddl::build_alter_add_check_constraint(self, dialect),
            Action::AlterDropConstraint => ddl::build_alter_drop_constraint(self, dialect),
            Action::AlterDrop => ddl::build_alter_drop_column(self, dialect),
            Action::AlterType => ddl::build_alter_column_type(self, dialect),
            // Stubs
            // Transaction control is emitted as fixed text, with a trailing `;`.
            Action::TxnStart => "BEGIN TRANSACTION;".to_string(), // Default stub
            Action::TxnCommit => "COMMIT;".to_string(),
            Action::TxnRollback => "ROLLBACK;".to_string(),
            Action::Put => dml::upsert::build_upsert(self, dialect),
            Action::Drop => format!("DROP TABLE {}", escape_identifier(&self.table)),
            Action::DropCol | Action::RenameCol => ddl::build_alter_column(self, dialect),
            // JSON features
            Action::JsonTable => dml::json_table::build_json_table(self, dialect),
            // COPY protocol (AST-native in qail-pg, generates SELECT for fallback)
            Action::Export => dml::select::build_select(self, dialect),
            // TRUNCATE TABLE
            Action::Truncate => format!("TRUNCATE TABLE {}", escape_identifier(&self.table)),
            // EXPLAIN - wrap SELECT query
            Action::Explain => format!("EXPLAIN {}", dml::select::build_select(self, dialect)),
            // EXPLAIN ANALYZE - execute and analyze query
            Action::ExplainAnalyze => format!(
                "EXPLAIN ANALYZE {}",
                dml::select::build_select(self, dialect)
            ),
            // LOCK TABLE
            Action::Lock => format!(
                "LOCK TABLE {} IN ACCESS EXCLUSIVE MODE",
                escape_identifier(&self.table)
            ),
            // CREATE MATERIALIZED VIEW - uses source_query for the view definition
            // Definition precedence: `source_query` (a nested QAIL command), then a
            // raw SQL `payload` (rejected if it contains an unquoted statement
            // delimiter or comment), then a SELECT built from this command itself.
            Action::CreateMaterializedView => {
                if let Some(source) = &self.source_query {
                    format!(
                        "CREATE MATERIALIZED VIEW {} AS {}",
                        escape_identifier(&self.table),
                        source.to_sql_with_dialect(dialect)
                    )
                } else if let Some(query) = &self.payload {
                    match checked_sql_query_fragment(query, "materialized view query") {
                        Ok(query) => format!(
                            "CREATE MATERIALIZED VIEW {} AS {}",
                            escape_identifier(&self.table),
                            query
                        ),
                        Err(err) => err,
                    }
                } else {
                    format!(
                        "CREATE MATERIALIZED VIEW {} AS {}",
                        escape_identifier(&self.table),
                        dml::select::build_select(self, dialect)
                    )
                }
            }
            // REFRESH MATERIALIZED VIEW
            Action::RefreshMaterializedView => {
                format!(
                    "REFRESH MATERIALIZED VIEW {}",
                    escape_identifier(&self.table)
                )
            }
            // DROP MATERIALIZED VIEW
            Action::DropMaterializedView => {
                format!(
                    "DROP MATERIALIZED VIEW IF EXISTS {}",
                    escape_identifier(&self.table)
                )
            }
            // LISTEN/NOTIFY (Pub/Sub)
            // The channel comes from `self.channel`. Without one, the bare keyword
            // is emitted (except UNLISTEN, which falls back to `UNLISTEN *`).
            Action::Listen => {
                if let Some(ch) = &self.channel {
                    format!("LISTEN {}", quote_single_identifier(ch))
                } else {
                    "LISTEN".to_string()
                }
            }
            Action::Notify => {
                if let Some(ch) = &self.channel {
                    // The optional message in `payload` becomes an escaped string literal.
                    if let Some(msg) = &self.payload {
                        format!(
                            "NOTIFY {}, '{}'",
                            quote_single_identifier(ch),
                            escape_sql_string_literal(msg)
                        )
                    } else {
                        format!("NOTIFY {}", quote_single_identifier(ch))
                    }
                } else {
                    "NOTIFY".to_string()
                }
            }
            Action::Unlisten => {
                if let Some(ch) = &self.channel {
                    format!("UNLISTEN {}", quote_single_identifier(ch))
                } else {
                    "UNLISTEN *".to_string()
                }
            }
            // Savepoints
            // The name comes from `savepoint_name`; without it the bare keyword is emitted.
            Action::Savepoint => {
                if let Some(name) = &self.savepoint_name {
                    format!("SAVEPOINT {}", quote_single_identifier(name))
                } else {
                    "SAVEPOINT".to_string()
                }
            }
            Action::ReleaseSavepoint => {
                if let Some(name) = &self.savepoint_name {
                    format!("RELEASE SAVEPOINT {}", quote_single_identifier(name))
                } else {
                    "RELEASE SAVEPOINT".to_string()
                }
            }
            Action::RollbackToSavepoint => {
                if let Some(name) = &self.savepoint_name {
                    format!("ROLLBACK TO SAVEPOINT {}", quote_single_identifier(name))
                } else {
                    "ROLLBACK TO SAVEPOINT".to_string()
                }
            }
            // Views
            // Same definition precedence as CREATE MATERIALIZED VIEW above.
            Action::CreateView => {
                if let Some(source) = &self.source_query {
                    format!(
                        "CREATE VIEW {} AS {}",
                        escape_identifier(&self.table),
                        source.to_sql_with_dialect(dialect)
                    )
                } else if let Some(query) = &self.payload {
                    match checked_sql_query_fragment(query, "view query") {
                        Ok(query) => {
                            format!(
                                "CREATE VIEW {} AS {}",
                                escape_identifier(&self.table),
                                query
                            )
                        }
                        Err(err) => err,
                    }
                } else {
                    format!(
                        "CREATE VIEW {} AS {}",
                        escape_identifier(&self.table),
                        dml::select::build_select(self, dialect)
                    )
                }
            }
            Action::DropView => format!("DROP VIEW IF EXISTS {}", escape_identifier(&self.table)),
            // Vector database operations - use qail-qdrant driver instead
            operators::Action::Search | operators::Action::Upsert | operators::Action::Scroll => {
                format!(
                    "-- Vector operation {:?} not supported in SQL. Use qail-qdrant driver.",
                    self.action
                )
            }
            operators::Action::CreateCollection | operators::Action::DeleteCollection => {
                format!(
                    "-- Vector DDL {:?} not supported in SQL. Use qail-qdrant driver.",
                    self.action
                )
            }
            // Function and Trigger operations
            operators::Action::CreateFunction => {
                if let Some(func) = &self.function_def {
                    // Every caller-supplied fragment is validated before anything is
                    // interpolated; the first failure returns an error marker.
                    let Some(args) = function_args_to_sql(&func.args) else {
                        return "/* ERROR: Invalid function arguments */".to_string();
                    };
                    if !is_safe_sql_type_fragment(&func.returns) {
                        return "/* ERROR: Invalid function return type */".to_string();
                    }
                    let lang = func.language.as_deref().unwrap_or("plpgsql");
                    // A missing or blank volatility is omitted; an unknown one is an error.
                    let volatility = if let Some(volatility) = func.volatility.as_deref() {
                        if volatility.trim().is_empty() {
                            String::new()
                        } else if let Some(volatility) = volatility_to_sql(volatility) {
                            format!(" {volatility}")
                        } else {
                            return "/* ERROR: Invalid function volatility */".to_string();
                        }
                    } else {
                        String::new()
                    };
                    // The body is dollar-quoted with a delimiter that does not occur in it.
                    let body = dollar_quote_block(&func.body);
                    format!(
                        "CREATE OR REPLACE FUNCTION {}({}) RETURNS {} LANGUAGE {}{} AS {}",
                        escape_identifier(&func.name),
                        args,
                        func.returns.trim(),
                        escape_identifier(lang),
                        volatility,
                        body
                    )
                } else {
                    "-- CreateFunction requires function_def".to_string()
                }
            }
            operators::Action::DropFunction => {
                // `payload`, when present, is a signature rendered by
                // `function_signature_to_sql` (presumably `name(arg types)`).
                // Otherwise `table` names a function with an empty argument list.
                if let Some(signature) = &self.payload {
                    format!(
                        "DROP FUNCTION IF EXISTS {}",
                        function_signature_to_sql(signature)
                    )
                } else {
                    format!(
                        "DROP FUNCTION IF EXISTS {}()",
                        escape_identifier(&self.table)
                    )
                }
            }
            operators::Action::CreateTrigger => {
                if let Some(trig) = &self.trigger_def {
                    let timing = match trig.timing {
                        crate::ast::TriggerTiming::Before => "BEFORE",
                        crate::ast::TriggerTiming::After => "AFTER",
                        crate::ast::TriggerTiming::InsteadOf => "INSTEAD OF",
                    };
                    // Events are joined with OR. An UPDATE event becomes
                    // `UPDATE OF col, ...` when `update_columns` is non-empty.
                    let events: Vec<String> = trig
                        .events
                        .iter()
                        .map(|e| match e {
                            crate::ast::TriggerEvent::Insert => "INSERT".to_string(),
                            crate::ast::TriggerEvent::Update if !trig.update_columns.is_empty() => {
                                format!(
                                    "UPDATE OF {}",
                                    trig.update_columns
                                        .iter()
                                        .map(|column| escape_identifier(column))
                                        .collect::<Vec<_>>()
                                        .join(", ")
                                )
                            }
                            crate::ast::TriggerEvent::Update => "UPDATE".to_string(),
                            crate::ast::TriggerEvent::Delete => "DELETE".to_string(),
                            crate::ast::TriggerEvent::Truncate => "TRUNCATE".to_string(),
                        })
                        .collect();
                    let for_each = if trig.for_each_row {
                        "FOR EACH ROW"
                    } else {
                        "FOR EACH STATEMENT"
                    };
                    format!(
                        "CREATE TRIGGER {} {} {} ON {} {} EXECUTE FUNCTION {}()",
                        escape_identifier(&trig.name),
                        timing,
                        events.join(" OR "),
                        escape_identifier(&trig.table),
                        for_each,
                        escape_identifier(&trig.execute_function)
                    )
                } else {
                    "-- CreateTrigger requires trigger_def".to_string()
                }
            }
            operators::Action::DropTrigger => {
                // `table` may be `<table>.<trigger>`: the text after the last `.` is
                // the trigger name and everything before it is the table. Without a
                // `.`, the whole value is used as the trigger name.
                if let Some((table, trigger)) = self.table.rsplit_once('.') {
                    format!(
                        "DROP TRIGGER IF EXISTS {} ON {}",
                        escape_identifier(trigger),
                        escape_identifier(table)
                    )
                } else {
                    format!("DROP TRIGGER IF EXISTS {}", escape_identifier(&self.table))
                }
            }
            // Phase 7: Extensions, Comments, Sequences
            Action::CreateExtension => ddl::build_create_extension(self, dialect),
            Action::DropExtension => ddl::build_drop_extension(self, dialect),
            Action::CommentOn => ddl::build_comment_on(self, dialect),
            Action::CreateSequence => ddl::build_create_sequence(self, dialect),
            Action::DropSequence => ddl::build_drop_sequence(self, dialect),
            Action::CreateEnum => ddl::build_create_enum(self, dialect),
            Action::DropEnum => ddl::build_drop_enum(self, dialect),
            Action::AlterEnumAddValue => ddl::build_alter_enum_add_value(self, dialect),
            // ALTER TABLE property operations (from diff engine)
            // Each of the four column-level actions below requires `columns` to be
            // exactly one non-blank `Expr::Named` column.
            Action::AlterSetNotNull => {
                let [Expr::Named(col)] = self.columns.as_slice() else {
                    return "/* ERROR: ALTER SET NOT NULL requires exactly one named column */"
                        .to_string();
                };
                if col.trim().is_empty() {
                    return "/* ERROR: ALTER SET NOT NULL column cannot be empty */".to_string();
                }
                format!(
                    "ALTER TABLE {} ALTER COLUMN {} SET NOT NULL",
                    escape_identifier(&self.table),
                    escape_identifier(col)
                )
            }
            Action::AlterDropNotNull => {
                let [Expr::Named(col)] = self.columns.as_slice() else {
                    return "/* ERROR: ALTER DROP NOT NULL requires exactly one named column */"
                        .to_string();
                };
                if col.trim().is_empty() {
                    return "/* ERROR: ALTER DROP NOT NULL column cannot be empty */".to_string();
                }
                format!(
                    "ALTER TABLE {} ALTER COLUMN {} DROP NOT NULL",
                    escape_identifier(&self.table),
                    escape_identifier(col)
                )
            }
            Action::AlterSetDefault => {
                let [Expr::Named(col)] = self.columns.as_slice() else {
                    return "/* ERROR: ALTER SET DEFAULT requires exactly one named column */"
                        .to_string();
                };
                if col.trim().is_empty() {
                    return "/* ERROR: ALTER SET DEFAULT column cannot be empty */".to_string();
                }
                // The default is a raw SQL expression taken from `payload`. It is
                // emitted verbatim (trimmed), so it is first screened for NUL bytes
                // and for unquoted `;` / comment openers.
                let Some(default_expr) = self.payload.as_deref() else {
                    return "/* ERROR: ALTER SET DEFAULT requires a default expression */"
                        .to_string();
                };
                if default_expr.trim().is_empty()
                    || default_expr.contains('\0')
                    || contains_unquoted_statement_delimiter(default_expr)
                {
                    return "/* ERROR: Invalid default expression */".to_string();
                }
                format!(
                    "ALTER TABLE {} ALTER COLUMN {} SET DEFAULT {}",
                    escape_identifier(&self.table),
                    escape_identifier(col),
                    default_expr.trim()
                )
            }
            Action::AlterDropDefault => {
                let [Expr::Named(col)] = self.columns.as_slice() else {
                    return "/* ERROR: ALTER DROP DEFAULT requires exactly one named column */"
                        .to_string();
                };
                if col.trim().is_empty() {
                    return "/* ERROR: ALTER DROP DEFAULT column cannot be empty */".to_string();
                }
                format!(
                    "ALTER TABLE {} ALTER COLUMN {} DROP DEFAULT",
                    escape_identifier(&self.table),
                    escape_identifier(col)
                )
            }
            Action::AlterEnableRls => {
                format!(
                    "ALTER TABLE {} ENABLE ROW LEVEL SECURITY",
                    escape_identifier(&self.table)
                )
            }
            Action::AlterDisableRls => {
                format!(
                    "ALTER TABLE {} DISABLE ROW LEVEL SECURITY",
                    escape_identifier(&self.table)
                )
            }
            Action::AlterForceRls => {
                format!(
                    "ALTER TABLE {} FORCE ROW LEVEL SECURITY",
                    escape_identifier(&self.table)
                )
            }
            Action::AlterNoForceRls => {
                format!(
                    "ALTER TABLE {} NO FORCE ROW LEVEL SECURITY",
                    escape_identifier(&self.table)
                )
            }
            // Session & procedural commands
            Action::Call => {
                // `table` holds the call target, optionally with an argument list.
                format!("CALL {}", call_target_to_sql(&self.table))
            }
            Action::Do => {
                // `payload` is the code block (empty if absent). `table` names the
                // language, defaulting to plpgsql when empty.
                let body = self.payload.as_deref().unwrap_or("");
                let lang = if self.table.is_empty() {
                    "plpgsql"
                } else {
                    &self.table
                };
                format!(
                    "DO {} LANGUAGE {}",
                    dollar_quote_block(body),
                    escape_identifier(lang)
                )
            }
            Action::SessionSet => {
                // `table` is the setting name. `payload` is its value, always emitted
                // as an escaped string literal (empty string if absent).
                let value = self.payload.as_deref().unwrap_or("");
                format!(
                    "SET {} = '{}'",
                    session_setting_name_to_sql(&self.table),
                    escape_sql_string_literal(value)
                )
            }
            Action::SessionShow => {
                format!("SHOW {}", session_setting_name_to_sql(&self.table))
            }
            Action::SessionReset => {
                format!("RESET {}", session_setting_name_to_sql(&self.table))
            }
            Action::CreateDatabase => {
                format!("CREATE DATABASE {}", escape_identifier(&self.table))
            }
            Action::DropDatabase => {
                format!("DROP DATABASE IF EXISTS {}", escape_identifier(&self.table))
            }
            Action::Grant => {
                // `columns` lists privilege names (each must be an `Expr::Named`
                // accepted by `privileges_to_sql`). `table` is the object and
                // `payload` the role.
                // NOTE(review): a missing role falls back to `escape_identifier("")`.
                let role = self.payload.as_deref().unwrap_or("");
                if let Some(privs) = privileges_to_sql(&self.columns) {
                    format!(
                        "GRANT {} ON {} TO {}",
                        privs,
                        escape_identifier(&self.table),
                        escape_identifier(role)
                    )
                } else {
                    "/* ERROR: Invalid privileges */".to_string()
                }
            }
            Action::Revoke => {
                // Same inputs as GRANT above.
                let role = self.payload.as_deref().unwrap_or("");
                if let Some(privs) = privileges_to_sql(&self.columns) {
                    format!(
                        "REVOKE {} ON {} FROM {}",
                        privs,
                        escape_identifier(&self.table),
                        escape_identifier(role)
                    )
                } else {
                    "/* ERROR: Invalid privileges */".to_string()
                }
            }
            Action::CreatePolicy => {
                if let Some(policy) = &self.policy_def {
                    policy::create_policy_sql(policy)
                } else {
                    "-- CreatePolicy requires policy_def".to_string()
                }
            }
            Action::DropPolicy => {
                // Prefer the full `policy_def`. Otherwise use the policy name from
                // `payload` together with `table`.
                if let Some(policy) = &self.policy_def {
                    policy::drop_policy_sql(&policy.name, &policy.table)
                } else if let Some(policy_name) = &self.payload {
                    policy::drop_policy_sql(policy_name, &self.table)
                } else {
                    "-- DropPolicy requires policy name + table".to_string()
                }
            }
        }
    }
}
563
564fn session_setting_name_to_sql(name: &str) -> String {
565    if is_valid_session_setting_name(name) {
566        name.to_string()
567    } else {
568        escape_identifier(name)
569    }
570}
571
/// Quotes `name` as exactly one SQL identifier.
///
/// The result is wrapped in double quotes and every embedded `"` is doubled.
/// Dots are not treated as schema qualifiers, so `a.b` becomes `"a.b"`.
fn quote_single_identifier(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push_str("\"\"");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
575
576fn dollar_quote_block(body: &str) -> String {
577    for idx in 0..=body.len() {
578        let tag = if idx == 0 {
579            String::new()
580        } else {
581            format!("qail_body_{idx}")
582        };
583        let delimiter = format!("${tag}$");
584        if !body.contains(&delimiter) {
585            return format!("{delimiter} {body} {delimiter}");
586        }
587    }
588
589    format!("'{}'", escape_sql_string_literal(body))
590}
591
592fn call_target_to_sql(target: &str) -> String {
593    let target = target.trim().trim_end_matches(';').trim();
594    if target.is_empty()
595        || target.contains('\0')
596        || target.contains(';')
597        || target.contains("--")
598        || target.contains("/*")
599        || target.contains("*/")
600    {
601        return escape_identifier(target);
602    }
603
604    match target.split_once('(') {
605        Some((name, args)) if args.ends_with(')') && !args[..args.len() - 1].contains('(') => {
606            format!("{}({}", escape_identifier(name.trim()), args)
607        }
608        None => escape_identifier(target),
609        _ => escape_identifier(target),
610    }
611}
612
/// Reports whether `value` could end the surrounding SQL statement, or open a
/// comment, once it is spliced into generated SQL.
///
/// Returns `true` for a NUL byte anywhere, and for a `;`, `--` or `/*` that
/// appears outside quoted text. Quoted text is recognised with PostgreSQL's
/// lexical rules:
///
/// * `'...'` standard strings, where `''` is an escaped quote;
/// * `E'...'` escape strings, which also honour backslash escapes;
/// * `"..."` quoted identifiers, where `""` is an escaped quote;
/// * `$$...$$` and `$tag$...$tag$` dollar-quoted strings.
///
/// Dollar quotes and `E'...'` escapes must be tracked. Otherwise a quote
/// character inside them (for example `$$'$$;` or `E'\'';`) desynchronises
/// the scan and hides a following `;`.
///
/// When the quote boundaries cannot be decided from the text alone, the
/// function also returns `true` so callers reject the input:
///
/// * a backslash directly before a `'` inside a standard string, because
///   whether it escapes the quote depends on `standard_conforming_strings`;
/// * an `E'` or `$...$` opener glued to a token that starts with a digit or
///   `$` (for example `1E'` or `1$$`), which PostgreSQL may or may not split
///   into a separate token.
///
/// Unterminated quoted text returns `false`, as before. Everything after the
/// opening quote is then inside the literal, and PostgreSQL rejects the input.
fn contains_unquoted_statement_delimiter(value: &str) -> bool {
    /// Outcome of scanning one quoted region.
    enum Quoted {
        /// The quote closed; holds the index just past the closing delimiter.
        Closed(usize),
        /// Input ended before the closing delimiter.
        Unterminated,
        /// The quote boundaries depend on server settings or lexer version.
        Ambiguous,
    }

    /// `[A-Za-z_\x80-\xff]`: bytes that may start an identifier or dollar tag.
    fn is_ident_start(b: u8) -> bool {
        b.is_ascii_alphabetic() || b == b'_' || b >= 0x80
    }

    /// `[A-Za-z0-9_$\x80-\xff]`: bytes that may continue an identifier.
    fn is_ident_cont(b: u8) -> bool {
        is_ident_start(b) || b.is_ascii_digit() || b == b'$'
    }

    /// Scans a `'...'` literal whose body begins at index `i`.
    fn scan_single_quoted(bytes: &[u8], mut i: usize, backslash_escapes: bool) -> Quoted {
        while i < bytes.len() {
            let next = bytes.get(i + 1).copied();
            match bytes[i] {
                b'\'' if next == Some(b'\'') => i += 2,
                b'\'' => return Quoted::Closed(i + 1),
                b'\\' if backslash_escapes => i += 2,
                // Escapes the quote only when standard_conforming_strings is off.
                b'\\' if next == Some(b'\'') => return Quoted::Ambiguous,
                _ => i += 1,
            }
        }
        Quoted::Unterminated
    }

    /// Scans a `"..."` identifier whose body begins at index `i`.
    fn scan_double_quoted(bytes: &[u8], mut i: usize) -> Quoted {
        while i < bytes.len() {
            match bytes[i] {
                b'"' if bytes.get(i + 1) == Some(&b'"') => i += 2,
                b'"' => return Quoted::Closed(i + 1),
                _ => i += 1,
            }
        }
        Quoted::Unterminated
    }

    /// Length of the `$$` / `$tag$` delimiter starting at index `i`, if any.
    fn dollar_delimiter_len(bytes: &[u8], i: usize) -> Option<usize> {
        let mut j = i + 1;
        if matches!(bytes.get(j), Some(&b) if is_ident_start(b)) {
            j += 1;
            while matches!(bytes.get(j), Some(&b) if b != b'$' && is_ident_cont(b)) {
                j += 1;
            }
        }
        (bytes.get(j) == Some(&b'$')).then_some(j + 1 - i)
    }

    /// Scans a dollar-quoted string whose opening delimiter is `bytes[start..start + len]`.
    fn scan_dollar_quoted(bytes: &[u8], start: usize, len: usize) -> Quoted {
        let delimiter = &bytes[start..start + len];
        let body_start = start + len;
        match bytes[body_start..]
            .windows(len)
            .position(|window| window == delimiter)
        {
            Some(offset) => Quoted::Closed(body_start + offset + len),
            None => Quoted::Unterminated,
        }
    }

    let bytes = value.as_bytes();
    // NUL is never valid SQL text, quoted or not.
    if bytes.contains(&0) {
        return true;
    }

    // Start of the unquoted identifier/number run that ends just before `i`.
    let mut word_start: Option<usize> = None;
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        let next = bytes.get(i + 1).copied();
        let word = word_start.map(|start| &bytes[start..i]);
        // A run starting with a digit or `$` is a number or `$n` parameter.
        // PostgreSQL may end such a token right before a trailing `E'` or `$...$`.
        let numeric_word = matches!(word, Some(w) if !is_ident_start(w[0]));

        let quoted = match b {
            b';' => return true,
            b'-' if next == Some(b'-') => return true,
            b'/' if next == Some(b'*') => return true,
            b'\'' => {
                let e_suffix = matches!(word, Some(w) if matches!(w.last(), Some(b'e' | b'E')));
                if e_suffix && numeric_word {
                    return true;
                }
                // Only a standalone `E`/`e` token makes this an escape string;
                // `xE'...'` is an identifier followed by a standard string.
                let escape_string = matches!(word, Some(w) if w.eq_ignore_ascii_case(b"e"));
                Some(scan_single_quoted(bytes, i + 1, escape_string))
            }
            b'"' => Some(scan_double_quoted(bytes, i + 1)),
            // Inside an identifier such as `a$b`, `$` is an ordinary character.
            b'$' if word.is_none() || numeric_word => match dollar_delimiter_len(bytes, i) {
                Some(_) if numeric_word => return true,
                Some(len) => Some(scan_dollar_quoted(bytes, i, len)),
                None => None,
            },
            _ => None,
        };

        match quoted {
            Some(Quoted::Closed(end)) => {
                i = end;
                word_start = None;
            }
            Some(Quoted::Unterminated) => return false,
            Some(Quoted::Ambiguous) => return true,
            None => {
                if !is_ident_cont(b) {
                    word_start = None;
                } else if word_start.is_none() {
                    word_start = Some(i);
                }
                i += 1;
            }
        }
    }

    false
}
662
663fn checked_sql_query_fragment(query: &str, context: &str) -> Result<String, String> {
664    let query = query.trim();
665    if query.is_empty() || query.contains('\0') || contains_unquoted_statement_delimiter(query) {
666        return Err(format!("/* ERROR: Invalid {context} */"));
667    }
668    Ok(query.to_string())
669}
670
/// Maps a privilege name to its canonical SQL keyword.
///
/// Matching ignores ASCII case and surrounding whitespace. `TEMP` is normalised
/// to `TEMPORARY`, and `ALL` to `ALL PRIVILEGES`. Unknown names yield `None`.
fn privilege_to_sql(privilege: &str) -> Option<&'static str> {
    const KNOWN: &[(&str, &str)] = &[
        ("SELECT", "SELECT"),
        ("INSERT", "INSERT"),
        ("UPDATE", "UPDATE"),
        ("DELETE", "DELETE"),
        ("TRUNCATE", "TRUNCATE"),
        ("REFERENCES", "REFERENCES"),
        ("TRIGGER", "TRIGGER"),
        ("USAGE", "USAGE"),
        ("CREATE", "CREATE"),
        ("CONNECT", "CONNECT"),
        ("TEMP", "TEMPORARY"),
        ("TEMPORARY", "TEMPORARY"),
        ("EXECUTE", "EXECUTE"),
        ("ALL", "ALL PRIVILEGES"),
        ("ALL PRIVILEGES", "ALL PRIVILEGES"),
    ];

    let wanted = privilege.trim();
    KNOWN
        .iter()
        .find(|(alias, _)| alias.eq_ignore_ascii_case(wanted))
        .map(|&(_, canonical)| canonical)
}
689
690fn privileges_to_sql(columns: &[Expr]) -> Option<String> {
691    if columns.is_empty() {
692        None
693    } else {
694        let mut privileges = Vec::with_capacity(columns.len());
695        for column in columns {
696            let Expr::Named(privilege) = column else {
697                return None;
698            };
699            let sql = privilege_to_sql(privilege)?;
700            privileges.push(sql);
701        }
702        Some(privileges.join(", "))
703    }
704}
705
/// Checks that `fragment` looks like a plain SQL type expression safe to
/// interpolate verbatim, such as `integer`, `numeric(10, 2)` or `text[]`.
///
/// After trimming, the fragment must be non-empty and consist only of ASCII
/// letters, digits, and the characters `_ . ( ) , [ ] % + -` and space. That
/// whitelist already rules out NUL, `;`, quotes, `/` and `*`. Because `-` is
/// allowed, the `--` comment opener is rejected explicitly.
fn is_safe_sql_type_fragment(fragment: &str) -> bool {
    const ALLOWED_PUNCTUATION: &[u8] = b"_. (),[]%+-";

    let fragment = fragment.trim();
    if fragment.is_empty() || fragment.contains("--") {
        return false;
    }
    fragment
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || ALLOWED_PUNCTUATION.contains(&b))
}
724
/// Maps a function volatility name to its SQL keyword.
///
/// Matching ignores ASCII case and surrounding whitespace. Only `VOLATILE`,
/// `STABLE` and `IMMUTABLE` are recognised; anything else yields `None`.
fn volatility_to_sql(volatility: &str) -> Option<&'static str> {
    let wanted = volatility.trim();
    ["VOLATILE", "STABLE", "IMMUTABLE"]
        .iter()
        .copied()
        .find(|keyword| keyword.eq_ignore_ascii_case(wanted))
}
733
734fn function_arg_to_sql(arg: &str) -> Option<String> {
735    let arg = arg.trim();
736    if !is_safe_sql_type_fragment(arg) {
737        return None;
738    }
739
740    let mut parts = arg.split_whitespace().collect::<Vec<_>>();
741    if parts.is_empty() {
742        return None;
743    }
744    if parts.len() == 1 {
745        return Some(parts[0].to_string());
746    }
747
748    let mode = match parts[0].to_ascii_uppercase().as_str() {
749        "IN" | "OUT" | "INOUT" | "VARIADIC" => Some(parts.remove(0).to_ascii_uppercase()),
750        _ => None,
751    };
752    if parts.len() < 2 {
753        return None;
754    }
755
756    let name = escape_identifier(parts.remove(0));
757    let type_fragment = parts.join(" ");
758    if !is_safe_sql_type_fragment(&type_fragment) {
759        return None;
760    }
761
762    let mut rendered = String::new();
763    if let Some(mode) = mode {
764        rendered.push_str(&mode);
765        rendered.push(' ');
766    }
767    rendered.push_str(&name);
768    rendered.push(' ');
769    rendered.push_str(type_fragment.trim());
770    Some(rendered)
771}
772
773fn function_args_to_sql(args: &[String]) -> Option<String> {
774    let mut rendered = Vec::with_capacity(args.len());
775    for arg in args {
776        rendered.push(function_arg_to_sql(arg)?);
777    }
778    Some(rendered.join(", "))
779}
780
/// Splits `args` on commas that are not nested inside parentheses.
///
/// Each piece is trimmed. Interior empty pieces are kept (`"a,,b"` gives
/// `["a", "", "b"]`), but an empty trailing piece is dropped.
///
/// Returns `None` when the parentheses are unbalanced, i.e. a `)` without a
/// matching `(` or an unclosed `(`.
fn split_top_level_args(args: &str) -> Option<Vec<&str>> {
    let mut pieces = Vec::new();
    let mut nesting: usize = 0;
    let mut segment_start = 0;

    // Scanning bytes is safe here: the delimiters are ASCII, and UTF-8
    // continuation bytes never collide with ASCII, so every split position
    // lands on a char boundary.
    for (pos, byte) in args.bytes().enumerate() {
        match byte {
            b'(' => nesting += 1,
            b')' => nesting = nesting.checked_sub(1)?,
            b',' if nesting == 0 => {
                pieces.push(args[segment_start..pos].trim());
                segment_start = pos + 1;
            }
            _ => {}
        }
    }

    if nesting > 0 {
        return None;
    }

    match args[segment_start..].trim() {
        "" => {}
        tail => pieces.push(tail),
    }
    Some(pieces)
}
805
806fn function_signature_to_sql(signature: &str) -> String {
807    let signature = signature.trim().trim_end_matches(';').trim();
808    if signature.is_empty()
809        || signature.contains('\0')
810        || signature.contains(';')
811        || signature.contains("--")
812        || signature.contains("/*")
813        || signature.contains("*/")
814    {
815        return escape_identifier(signature);
816    }
817
818    match signature.split_once('(') {
819        Some((name, args)) if args.ends_with(')') => {
820            let args = &args[..args.len() - 1];
821            let Some(parts) = split_top_level_args(args) else {
822                return escape_identifier(signature);
823            };
824            let mut rendered_args = Vec::new();
825            for part in parts {
826                if part.is_empty() {
827                    continue;
828                }
829                if !is_safe_sql_type_fragment(part) {
830                    return escape_identifier(signature);
831                }
832                rendered_args.push(part.trim().to_string());
833            }
834            format!(
835                "{}({})",
836                escape_identifier(name.trim()),
837                rendered_args.join(", ")
838            )
839        }
840        None => escape_identifier(signature),
841        _ => escape_identifier(signature),
842    }
843}
844
/// Reports whether `name` is a dot-separated list of ASCII identifiers.
///
/// Each segment must start with an ASCII letter or `_` and continue with ASCII
/// letters, digits or `_`. Empty segments, including an entirely empty name,
/// are rejected.
fn is_valid_session_setting_name(name: &str) -> bool {
    // `"".split('.')` yields a single empty segment, which fails the check
    // below, so an empty name needs no separate guard.
    name.split('.')
        .all(|segment| match segment.as_bytes().split_first() {
            Some((&head, tail)) => {
                (head == b'_' || head.is_ascii_alphabetic())
                    && tail.iter().all(|&b| b == b'_' || b.is_ascii_alphanumeric())
            }
            None => false,
        })
}
853
854impl ToSqlParameterized for Qail {
855    fn to_sql_parameterized_with_dialect(&self, dialect: Dialect) -> TranspileResult {
856        // Use the full ToSql implementation which handles CTEs, JOINs, etc.
857        // Then post-process to extract named parameters for binding
858        let full_sql = self.to_sql_with_dialect(dialect);
859        let (sql, named_params) = replace_named_params_outside_sql_literals(&full_sql);
860
861        TranspileResult {
862            sql,
863            params: Vec::new(), // Positional params not used, named_params provides mapping
864            named_params,
865        }
866    }
867}
868
869fn replace_named_params_outside_sql_literals(sql: &str) -> (String, Vec<String>) {
870    let mut named_params: Vec<String> = Vec::new();
871    let mut seen_params: std::collections::HashMap<String, usize> =
872        std::collections::HashMap::new();
873    let mut result = String::with_capacity(sql.len());
874    let mut param_index = 1;
875    let mut i = 0;
876    let mut state = SqlScanState::Normal;
877
878    while i < sql.len() {
879        match &state {
880            SqlScanState::Normal => {
881                if sql[i..].starts_with("--") {
882                    result.push_str("--");
883                    i += 2;
884                    state = SqlScanState::LineComment;
885                    continue;
886                }
887                if sql[i..].starts_with("/*") {
888                    result.push_str("/*");
889                    i += 2;
890                    state = SqlScanState::BlockComment;
891                    continue;
892                }
893                if sql[i..].starts_with("::") {
894                    result.push_str("::");
895                    i += 2;
896                    continue;
897                }
898                if let Some(delimiter) = sql_dollar_quote_delimiter_at(sql, i) {
899                    result.push_str(&delimiter);
900                    i += delimiter.len();
901                    state = SqlScanState::DollarQuoted(delimiter);
902                    continue;
903                }
904
905                let Some((ch, next_i)) = next_sql_char(sql, i) else {
906                    break;
907                };
908                match ch {
909                    '\'' => {
910                        result.push(ch);
911                        i = next_i;
912                        state = SqlScanState::SingleQuoted;
913                    }
914                    '"' => {
915                        result.push(ch);
916                        i = next_i;
917                        state = SqlScanState::DoubleQuoted;
918                    }
919                    ':' => {
920                        let Some((next, mut cursor)) = next_sql_char(sql, next_i) else {
921                            result.push(ch);
922                            i = next_i;
923                            continue;
924                        };
925                        if is_named_param_start(next) {
926                            let mut param_name = String::new();
927                            param_name.push(next);
928                            while let Some((candidate, candidate_next)) = next_sql_char(sql, cursor)
929                            {
930                                if is_named_param_continue(candidate) {
931                                    param_name.push(candidate);
932                                    cursor = candidate_next;
933                                } else {
934                                    break;
935                                }
936                            }
937
938                            let idx = if let Some(&existing) = seen_params.get(&param_name) {
939                                existing
940                            } else {
941                                let idx = param_index;
942                                seen_params.insert(param_name.clone(), idx);
943                                named_params.push(param_name);
944                                param_index += 1;
945                                idx
946                            };
947                            result.push('$');
948                            result.push_str(&idx.to_string());
949                            i = cursor;
950                        } else {
951                            result.push(ch);
952                            i = next_i;
953                        }
954                    }
955                    _ => {
956                        result.push(ch);
957                        i = next_i;
958                    }
959                }
960            }
961            SqlScanState::SingleQuoted => {
962                let Some((ch, next_i)) = next_sql_char(sql, i) else {
963                    break;
964                };
965                result.push(ch);
966                i = next_i;
967                if ch == '\'' {
968                    if sql[i..].starts_with('\'') {
969                        result.push('\'');
970                        i += 1;
971                    } else {
972                        state = SqlScanState::Normal;
973                    }
974                }
975            }
976            SqlScanState::DoubleQuoted => {
977                let Some((ch, next_i)) = next_sql_char(sql, i) else {
978                    break;
979                };
980                result.push(ch);
981                i = next_i;
982                if ch == '"' {
983                    if sql[i..].starts_with('"') {
984                        result.push('"');
985                        i += 1;
986                    } else {
987                        state = SqlScanState::Normal;
988                    }
989                }
990            }
991            SqlScanState::LineComment => {
992                let Some((ch, next_i)) = next_sql_char(sql, i) else {
993                    break;
994                };
995                result.push(ch);
996                i = next_i;
997                if ch == '\n' {
998                    state = SqlScanState::Normal;
999                }
1000            }
1001            SqlScanState::BlockComment => {
1002                if sql[i..].starts_with("*/") {
1003                    result.push_str("*/");
1004                    i += 2;
1005                    state = SqlScanState::Normal;
1006                    continue;
1007                }
1008                let Some((ch, next_i)) = next_sql_char(sql, i) else {
1009                    break;
1010                };
1011                result.push(ch);
1012                i = next_i;
1013            }
1014            SqlScanState::DollarQuoted(delimiter) => {
1015                if sql[i..].starts_with(delimiter) {
1016                    result.push_str(delimiter);
1017                    i += delimiter.len();
1018                    state = SqlScanState::Normal;
1019                    continue;
1020                }
1021                let Some((ch, next_i)) = next_sql_char(sql, i) else {
1022                    break;
1023                };
1024                result.push(ch);
1025                i = next_i;
1026            }
1027        }
1028    }
1029
1030    (result, named_params)
1031}
1032
/// Lexical context used by `replace_named_params_outside_sql_literals` while
/// it walks SQL text.
#[derive(Clone, Debug, Eq, PartialEq)]
enum SqlScanState {
    /// Plain SQL text, where `:name` placeholders are rewritten.
    Normal,
    /// Inside a `'...'` string literal.
    SingleQuoted,
    /// Inside a `"..."` quoted identifier.
    DoubleQuoted,
    /// Inside a `--` comment, up to the next newline.
    LineComment,
    /// Inside a `/* ... */` comment.
    BlockComment,
    /// Inside a dollar-quoted body.
    ///
    /// Holds the full opening delimiter (e.g. `$$` or `$tag$`), which must
    /// reappear to close the body.
    DollarQuoted(String),
}
1042
/// Decodes the character starting at byte offset `idx` of `sql`.
///
/// Returns the character together with the byte offset just past it. Returns
/// `None` when `idx` is at or past the end, or is not on a char boundary.
fn next_sql_char(sql: &str, idx: usize) -> Option<(char, usize)> {
    sql.get(idx..)
        .and_then(|tail| tail.chars().next())
        .map(|ch| (ch, idx + ch.len_utf8()))
}
1047
/// Reports whether `ch` may begin a `:name` placeholder (ASCII letter or `_`).
fn is_named_param_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}
1051
/// Reports whether `ch` may continue a `:name` placeholder (ASCII letter,
/// digit or `_`).
fn is_named_param_continue(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_')
}
1055
1056fn sql_dollar_quote_delimiter_at(sql: &str, idx: usize) -> Option<String> {
1057    if !sql.get(idx..)?.starts_with('$') {
1058        return None;
1059    }
1060    let rest = sql.get(idx + 1..)?;
1061    for (offset, ch) in rest.char_indices() {
1062        if ch == '$' {
1063            let tag = &rest[..offset];
1064            if tag.is_empty()
1065                || (is_named_param_start(tag.chars().next()?)
1066                    && tag.chars().all(is_named_param_continue))
1067            {
1068                return Some(sql[idx..idx + offset + 2].to_string());
1069            }
1070            return None;
1071        }
1072        if !is_named_param_continue(ch) {
1073            return None;
1074        }
1075    }
1076    None
1077}