// qail_core/transformer/patterns/insert.rs

//! INSERT pattern implementation: translates SQL `INSERT` statements into
//! QAIL builder code.

3use sqlparser::ast::{SetExpr, Statement};
4
5use crate::transformer::traits::*;
6
/// Pattern that recognizes SQL `INSERT` statements and rewrites them as
/// `Qail::add(...)` builder chains.
#[derive(Debug, Clone, Copy, Default)]
pub struct InsertPattern;

10impl SqlPattern for InsertPattern {
11    fn id(&self) -> &'static str {
12        "insert"
13    }
14
15    fn priority(&self) -> u32 {
16        100
17    }
18
19    fn matches(&self, stmt: &Statement, _ctx: &MatchContext) -> bool {
20        matches!(stmt, Statement::Insert(_))
21    }
22
23    fn extract(&self, stmt: &Statement, ctx: &MatchContext) -> Result<PatternData, ExtractError> {
24        let Statement::Insert(insert) = stmt else {
25            return Err(ExtractError {
26                message: "Expected INSERT statement".to_string(),
27            });
28        };
29
30        let table = insert.table.to_string();
31        let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
32
33        let mut values = Vec::new();
34        if let Some(source) = &insert.source
35            && let SetExpr::Values(v) = source.body.as_ref()
36        {
37            for row in &v.rows {
38                for expr in row.iter() {
39                    let value = crate::transformer::clauses::expr_to_value(expr);
40                    if let ValueData::Param(n) = &value
41                        && let Some(bind) = ctx.binds.get(*n - 1)
42                    {
43                        values.push(ValueData::Literal(bind.clone()));
44                        continue;
45                    }
46                    values.push(value);
47                }
48            }
49        }
50
51        let returning = insert.returning.as_ref().map(|items| {
52            items
53                .iter()
54                .map(|item| match item {
55                    sqlparser::ast::SelectItem::UnnamedExpr(e) => e.to_string(),
56                    sqlparser::ast::SelectItem::Wildcard(_) => "*".to_string(),
57                    _ => item.to_string(),
58                })
59                .collect()
60        });
61
62        Ok(PatternData::Insert {
63            table,
64            columns,
65            values,
66            returning,
67        })
68    }
69
70    fn transform(&self, data: &PatternData, ctx: &TransformContext) -> Result<String, TransformError> {
71        let PatternData::Insert {
72            table,
73            columns,
74            values,
75            returning,
76        } = data
77        else {
78            return Err(TransformError {
79                message: "Expected Insert data".to_string(),
80            });
81        };
82
83        let mut lines = Vec::new();
84
85        if ctx.include_imports {
86            lines.push("use qail_core::ast::Qail;".to_string());
87            lines.push(String::new());
88        }
89
90        let mut chain = format!("let cmd = Qail::add(\"{}\")", table);
91
92        for (i, col) in columns.iter().enumerate() {
93            let value = values.get(i).map(|v| format_value(v, &ctx.binds)).unwrap_or_else(|| "None".to_string());
94            chain.push_str(&format!("\n    .set_value(\"{}\", {})", col, value));
95        }
96
97        if let Some(ret) = returning {
98            if ret.contains(&"*".to_string()) {
99                chain.push_str("\n    .returning([\"*\"])");
100            } else {
101                let cols: Vec<String> = ret.iter().map(|c| format!("\"{}\"", c)).collect();
102                chain.push_str(&format!("\n    .returning([{}])", cols.join(", ")));
103            }
104        }
105
106        chain.push(';');
107        lines.push(chain);
108
109        lines.push(String::new());
110        if returning.is_some() {
111            let default_row_type = format!("{}Row", to_pascal_case(table));
112            let row_type = ctx.return_type.as_deref().unwrap_or(&default_row_type);
113            lines.push(format!(
114                "let row: {} = driver.query_one(&cmd).await?;",
115                row_type
116            ));
117        } else {
118            lines.push("driver.execute(&cmd).await?;".to_string());
119        }
120
121        Ok(lines.join("\n"))
122    }
123}
124
125fn format_value(value: &ValueData, binds: &[String]) -> String {
126    match value {
127        ValueData::Param(n) => binds
128            .get(*n - 1)
129            .cloned()
130            .unwrap_or_else(|| format!("param_{}", n)),
131        ValueData::Literal(s) => s.clone(),
132        ValueData::Column(c) => format!("\"{}\"", c),
133        ValueData::Null => "None".to_string(),
134    }
135}
136
/// Converts a table name into an `UpperCamelCase` identifier fragment.
///
/// The name is split on every non-alphanumeric character. That covers
/// underscores, the `.` of schema-qualified names, and the quotes of quoted
/// identifiers. The first character of each piece is upper-cased and the rest
/// is kept as-is.
///
/// Examples: `user_accounts` → `UserAccounts`, `public.users` → `PublicUsers`,
/// `"Users"` → `Users`.
fn to_pascal_case(s: &str) -> String {
    s.split(|c: char| !c.is_alphanumeric())
        .map(|part| {
            let mut chars = part.chars();
            match chars.next() {
                Some(c) => c.to_uppercase().chain(chars).collect::<String>(),
                None => String::new(),
            }
        })
        .collect()
}