qail_core/transformer/patterns/
update.rs1use sqlparser::ast::Statement;
4
5use crate::transformer::traits::*;
6use crate::transformer::clauses::*;
7
/// SQL pattern that recognises `UPDATE` statements and turns them into a
/// `Qail::set(...)` builder chain (see the `SqlPattern` impl below).
///
/// Stateless marker type, so it is cheap to copy and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct UpdatePattern;

11impl SqlPattern for UpdatePattern {
12 fn id(&self) -> &'static str {
13 "update"
14 }
15
16 fn priority(&self) -> u32 {
17 100
18 }
19
20 fn matches(&self, stmt: &Statement, _ctx: &MatchContext) -> bool {
21 matches!(stmt, Statement::Update(_))
22 }
23
24 fn extract(&self, stmt: &Statement, _ctx: &MatchContext) -> Result<PatternData, ExtractError> {
25 let Statement::Update(update) = stmt else {
26 return Err(ExtractError {
27 message: "Expected UPDATE statement".to_string(),
28 });
29 };
30
31 let table = extract_table_from_factor(&update.table.relation)?;
32
33 let set_values: Vec<SetValueData> = update
34 .assignments
35 .iter()
36 .map(|a| {
37 let column = a.target.to_string();
38 let value = expr_to_value(&a.value);
39 SetValueData { column, value }
40 })
41 .collect();
42
43 let filter = update
44 .selection
45 .as_ref()
46 .map(extract_filter)
47 .transpose()?;
48
49 let returning = update.returning.as_ref().map(|items| {
50 items
51 .iter()
52 .map(|item| match item {
53 sqlparser::ast::SelectItem::UnnamedExpr(e) => e.to_string(),
54 sqlparser::ast::SelectItem::Wildcard(_) => "*".to_string(),
55 _ => item.to_string(),
56 })
57 .collect()
58 });
59
60 Ok(PatternData::Update {
61 table,
62 set_values,
63 filter,
64 returning,
65 })
66 }
67
68 fn transform(&self, data: &PatternData, ctx: &TransformContext) -> Result<String, TransformError> {
69 let PatternData::Update {
70 table,
71 set_values,
72 filter,
73 returning,
74 } = data
75 else {
76 return Err(TransformError {
77 message: "Expected Update data".to_string(),
78 });
79 };
80
81 let mut lines = Vec::new();
82
83 if ctx.include_imports {
84 lines.push("use qail_core::ast::{Qail, Operator};".to_string());
85 lines.push(String::new());
86 }
87
88 let mut chain = format!("let cmd = Qail::set(\"{}\")", table);
89
90 for sv in set_values {
91 let value = format_value(&sv.value, &ctx.binds);
92 chain.push_str(&format!("\n .set_value(\"{}\", {})", sv.column, value));
93 }
94
95 if let Some(f) = filter {
96 let value = format_value(&f.value, &ctx.binds);
97 chain.push_str(&format!(
98 "\n .filter(\"{}\", {}, {})",
99 f.column, f.operator, value
100 ));
101 }
102
103 if let Some(ret) = returning {
104 if ret.contains(&"*".to_string()) {
105 chain.push_str("\n .returning([\"*\"])");
106 } else {
107 let cols: Vec<String> = ret.iter().map(|c| format!("\"{}\"", c)).collect();
108 chain.push_str(&format!("\n .returning([{}])", cols.join(", ")));
109 }
110 }
111
112 chain.push(';');
113 lines.push(chain);
114
115 lines.push(String::new());
116 if returning.is_some() {
117 let default_row_type = format!("{}Row", to_pascal_case(table));
118 let row_type = ctx.return_type.as_deref().unwrap_or(&default_row_type);
119 lines.push(format!(
120 "let row: {} = driver.query_one(&cmd).await?;",
121 row_type
122 ));
123 } else {
124 lines.push("driver.execute(&cmd).await?;".to_string());
125 }
126
127 Ok(lines.join("\n"))
128 }
129}
130
131fn format_value(value: &ValueData, binds: &[String]) -> String {
132 match value {
133 ValueData::Param(n) => binds
134 .get(*n - 1)
135 .cloned()
136 .unwrap_or_else(|| format!("param_{}", n)),
137 ValueData::Literal(s) => s.clone(),
138 ValueData::Column(c) => format!("\"{}\"", c),
139 ValueData::Null => "None".to_string(),
140 }
141}
142
/// Converts a `snake_case` name to `PascalCase` (e.g. `user_accounts` → `UserAccounts`).
///
/// The input is split on `_`. Each non-empty segment has its first character
/// upper-cased (which may expand to several characters, e.g. `ß` → `SS`), and
/// the rest of the segment is kept unchanged. Empty segments produced by
/// leading, trailing or repeated underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}