qail_core/transformer/patterns/update.rs

1//! UPDATE pattern implementation
2
3use sqlparser::ast::Statement;
4
5use crate::transformer::traits::*;
6use crate::transformer::clauses::*;
7
/// SQL pattern that recognizes `UPDATE` statements and rewrites them as a
/// `Qail::set(...)` builder chain.
///
/// Stateless marker type; it carries no data, so it is freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct UpdatePattern;
10
11impl SqlPattern for UpdatePattern {
12    fn id(&self) -> &'static str {
13        "update"
14    }
15
16    fn priority(&self) -> u32 {
17        100
18    }
19
20    fn matches(&self, stmt: &Statement, _ctx: &MatchContext) -> bool {
21        matches!(stmt, Statement::Update(_))
22    }
23
24    fn extract(&self, stmt: &Statement, _ctx: &MatchContext) -> Result<PatternData, ExtractError> {
25        let Statement::Update(update) = stmt else {
26            return Err(ExtractError {
27                message: "Expected UPDATE statement".to_string(),
28            });
29        };
30
31        let table = extract_table_from_factor(&update.table.relation)?;
32
33        let set_values: Vec<SetValueData> = update
34            .assignments
35            .iter()
36            .map(|a| {
37                let column = a.target.to_string();
38                let value = expr_to_value(&a.value);
39                SetValueData { column, value }
40            })
41            .collect();
42
43        let filter = update
44            .selection
45            .as_ref()
46            .map(extract_filter)
47            .transpose()?;
48
49        let returning = update.returning.as_ref().map(|items| {
50            items
51                .iter()
52                .map(|item| match item {
53                    sqlparser::ast::SelectItem::UnnamedExpr(e) => e.to_string(),
54                    sqlparser::ast::SelectItem::Wildcard(_) => "*".to_string(),
55                    _ => item.to_string(),
56                })
57                .collect()
58        });
59
60        Ok(PatternData::Update {
61            table,
62            set_values,
63            filter,
64            returning,
65        })
66    }
67
68    fn transform(&self, data: &PatternData, ctx: &TransformContext) -> Result<String, TransformError> {
69        let PatternData::Update {
70            table,
71            set_values,
72            filter,
73            returning,
74        } = data
75        else {
76            return Err(TransformError {
77                message: "Expected Update data".to_string(),
78            });
79        };
80
81        let mut lines = Vec::new();
82
83        if ctx.include_imports {
84            lines.push("use qail_core::ast::{Qail, Operator};".to_string());
85            lines.push(String::new());
86        }
87
88        let mut chain = format!("let cmd = Qail::set(\"{}\")", table);
89
90        for sv in set_values {
91            let value = format_value(&sv.value, &ctx.binds);
92            chain.push_str(&format!("\n    .set_value(\"{}\", {})", sv.column, value));
93        }
94
95        if let Some(f) = filter {
96            let value = format_value(&f.value, &ctx.binds);
97            chain.push_str(&format!(
98                "\n    .filter(\"{}\", {}, {})",
99                f.column, f.operator, value
100            ));
101        }
102
103        if let Some(ret) = returning {
104            if ret.contains(&"*".to_string()) {
105                chain.push_str("\n    .returning([\"*\"])");
106            } else {
107                let cols: Vec<String> = ret.iter().map(|c| format!("\"{}\"", c)).collect();
108                chain.push_str(&format!("\n    .returning([{}])", cols.join(", ")));
109            }
110        }
111
112        chain.push(';');
113        lines.push(chain);
114
115        lines.push(String::new());
116        if returning.is_some() {
117            let default_row_type = format!("{}Row", to_pascal_case(table));
118            let row_type = ctx.return_type.as_deref().unwrap_or(&default_row_type);
119            lines.push(format!(
120                "let row: {} = driver.query_one(&cmd).await?;",
121                row_type
122            ));
123        } else {
124            lines.push("driver.execute(&cmd).await?;".to_string());
125        }
126
127        Ok(lines.join("\n"))
128    }
129}
130
131fn format_value(value: &ValueData, binds: &[String]) -> String {
132    match value {
133        ValueData::Param(n) => binds
134            .get(*n - 1)
135            .cloned()
136            .unwrap_or_else(|| format!("param_{}", n)),
137        ValueData::Literal(s) => s.clone(),
138        ValueData::Column(c) => format!("\"{}\"", c),
139        ValueData::Null => "None".to_string(),
140    }
141}
142
/// Converts a table name into a PascalCase identifier fragment
/// (`user_profiles` → `UserProfiles`).
///
/// Every non-alphanumeric character (`_`, `.`, `-`, spaces, quotes, …) acts as a
/// word separator and is dropped. A schema-qualified name such as `public.users`
/// therefore yields `PublicUsers` rather than the invalid identifier
/// `Public.users`. Empty segments (e.g. from `a__b`) contribute nothing.
/// Characters after the first of each word are kept as-is.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split(|c: char| !c.is_alphanumeric()) {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}