qail_core/transformer/patterns/
insert.rs1use sqlparser::ast::{SetExpr, Statement};
4
5use crate::transformer::traits::*;
6
7pub struct InsertPattern;
9
10impl SqlPattern for InsertPattern {
11 fn id(&self) -> &'static str {
12 "insert"
13 }
14
15 fn priority(&self) -> u32 {
16 100
17 }
18
19 fn matches(&self, stmt: &Statement, _ctx: &MatchContext) -> bool {
20 matches!(stmt, Statement::Insert(_))
21 }
22
23 fn extract(&self, stmt: &Statement, ctx: &MatchContext) -> Result<PatternData, ExtractError> {
24 let Statement::Insert(insert) = stmt else {
25 return Err(ExtractError {
26 message: "Expected INSERT statement".to_string(),
27 });
28 };
29
30 let table = insert.table.to_string();
31 let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
32
33 let mut values = Vec::new();
34 if let Some(source) = &insert.source
35 && let SetExpr::Values(v) = source.body.as_ref()
36 {
37 for row in &v.rows {
38 for expr in row.iter() {
39 let value = crate::transformer::clauses::expr_to_value(expr);
40 if let ValueData::Param(n) = &value
41 && let Some(bind) = ctx.binds.get(*n - 1)
42 {
43 values.push(ValueData::Literal(bind.clone()));
44 continue;
45 }
46 values.push(value);
47 }
48 }
49 }
50
51 let returning = insert.returning.as_ref().map(|items| {
52 items
53 .iter()
54 .map(|item| match item {
55 sqlparser::ast::SelectItem::UnnamedExpr(e) => e.to_string(),
56 sqlparser::ast::SelectItem::Wildcard(_) => "*".to_string(),
57 _ => item.to_string(),
58 })
59 .collect()
60 });
61
62 Ok(PatternData::Insert {
63 table,
64 columns,
65 values,
66 returning,
67 })
68 }
69
70 fn transform(&self, data: &PatternData, ctx: &TransformContext) -> Result<String, TransformError> {
71 let PatternData::Insert {
72 table,
73 columns,
74 values,
75 returning,
76 } = data
77 else {
78 return Err(TransformError {
79 message: "Expected Insert data".to_string(),
80 });
81 };
82
83 let mut lines = Vec::new();
84
85 if ctx.include_imports {
86 lines.push("use qail_core::ast::Qail;".to_string());
87 lines.push(String::new());
88 }
89
90 let mut chain = format!("let cmd = Qail::add(\"{}\")", table);
91
92 for (i, col) in columns.iter().enumerate() {
93 let value = values.get(i).map(|v| format_value(v, &ctx.binds)).unwrap_or_else(|| "None".to_string());
94 chain.push_str(&format!("\n .set_value(\"{}\", {})", col, value));
95 }
96
97 if let Some(ret) = returning {
98 if ret.contains(&"*".to_string()) {
99 chain.push_str("\n .returning([\"*\"])");
100 } else {
101 let cols: Vec<String> = ret.iter().map(|c| format!("\"{}\"", c)).collect();
102 chain.push_str(&format!("\n .returning([{}])", cols.join(", ")));
103 }
104 }
105
106 chain.push(';');
107 lines.push(chain);
108
109 lines.push(String::new());
110 if returning.is_some() {
111 let default_row_type = format!("{}Row", to_pascal_case(table));
112 let row_type = ctx.return_type.as_deref().unwrap_or(&default_row_type);
113 lines.push(format!(
114 "let row: {} = driver.query_one(&cmd).await?;",
115 row_type
116 ));
117 } else {
118 lines.push("driver.execute(&cmd).await?;".to_string());
119 }
120
121 Ok(lines.join("\n"))
122 }
123}
124
125fn format_value(value: &ValueData, binds: &[String]) -> String {
126 match value {
127 ValueData::Param(n) => binds
128 .get(*n - 1)
129 .cloned()
130 .unwrap_or_else(|| format!("param_{}", n)),
131 ValueData::Literal(s) => s.clone(),
132 ValueData::Column(c) => format!("\"{}\"", c),
133 ValueData::Null => "None".to_string(),
134 }
135}
136
/// Converts a table name into a PascalCase identifier fragment, e.g.
/// `user_profiles` -> `UserProfiles`.
///
/// Every non-alphanumeric character acts as a word separator. This covers
/// `_`, the `.` in schema-qualified names and the quotes of quoted identifiers,
/// so `public.users` -> `PublicUsers` stays a valid Rust type-name fragment.
/// Empty segments produced by repeated separators are dropped.
fn to_pascal_case(s: &str) -> String {
    s.split(|c: char| !c.is_alphanumeric())
        .map(|part| {
            let mut chars = part.chars();
            match chars.next() {
                Some(first) => first.to_uppercase().chain(chars).collect::<String>(),
                None => String::new(),
            }
        })
        .collect()
}