1use std::fmt::Write;
2
3use crate::conditional::{BuildCondition, Condition};
4#[cfg(feature = "mysql")]
5use crate::db_specific::mysql;
6#[cfg(feature = "postgres")]
7use crate::db_specific::postgres;
8#[cfg(feature = "sqlite")]
9use crate::db_specific::sqlite;
10use crate::error::Error;
11use crate::value::NullType;
12use crate::{DBImpl, OnConflict, Value};
13
14pub trait Update<'until_build, 'post_build> {
18 fn rollback_transaction(self) -> Self;
28
29 fn where_clause(self, condition: &'until_build Condition<'post_build>) -> Self;
33
34 fn add_update(self, column_name: &'until_build str, column_value: Value<'post_build>) -> Self;
42
43 fn build(self) -> Result<(String, Vec<Value<'post_build>>), Error>;
51}
52
53#[derive(Debug)]
57pub struct UpdateData<'until_build, 'post_build> {
58 pub(crate) model: &'until_build str,
59 pub(crate) on_conflict: OnConflict,
60 pub(crate) updates: Vec<(&'until_build str, Value<'post_build>)>,
61 pub(crate) where_clause: Option<&'until_build Condition<'post_build>>,
62 pub(crate) lookup: Vec<Value<'post_build>>,
63}
64
65#[derive(Debug)]
71pub enum UpdateImpl<'until_build, 'post_build> {
72 #[cfg(feature = "sqlite")]
76 SQLite(UpdateData<'until_build, 'post_build>),
77 #[cfg(feature = "mysql")]
81 MySQL(UpdateData<'until_build, 'post_build>),
82 #[cfg(feature = "postgres")]
86 Postgres(UpdateData<'until_build, 'post_build>),
87}
88
89impl<'until_build, 'post_build> Update<'until_build, 'post_build>
90 for UpdateImpl<'until_build, 'post_build>
91{
92 fn rollback_transaction(mut self) -> Self {
93 match self {
94 #[cfg(feature = "sqlite")]
95 UpdateImpl::SQLite(ref mut d) => d.on_conflict = OnConflict::ROLLBACK,
96 #[cfg(feature = "mysql")]
97 UpdateImpl::MySQL(ref mut d) => d.on_conflict = OnConflict::ROLLBACK,
98 #[cfg(feature = "postgres")]
99 UpdateImpl::Postgres(ref mut d) => d.on_conflict = OnConflict::ROLLBACK,
100 };
101 self
102 }
103
104 fn where_clause(mut self, condition: &'until_build Condition<'post_build>) -> Self {
105 match self {
106 #[cfg(feature = "sqlite")]
107 UpdateImpl::SQLite(ref mut d) => d.where_clause = Some(condition),
108 #[cfg(feature = "mysql")]
109 UpdateImpl::MySQL(ref mut d) => d.where_clause = Some(condition),
110 #[cfg(feature = "postgres")]
111 UpdateImpl::Postgres(ref mut d) => d.where_clause = Some(condition),
112 };
113 self
114 }
115
116 fn add_update(
117 mut self,
118 column_name: &'until_build str,
119 column_value: Value<'post_build>,
120 ) -> Self {
121 match self {
122 #[cfg(feature = "sqlite")]
123 UpdateImpl::SQLite(ref mut d) => d.updates.push((column_name, column_value)),
124 #[cfg(feature = "mysql")]
125 UpdateImpl::MySQL(ref mut d) => d.updates.push((column_name, column_value)),
126 #[cfg(feature = "postgres")]
127 UpdateImpl::Postgres(ref mut d) => d.updates.push((column_name, column_value)),
128 };
129 self
130 }
131
132 fn build(self) -> Result<(String, Vec<Value<'post_build>>), Error> {
133 match self {
134 #[cfg(feature = "sqlite")]
135 UpdateImpl::SQLite(mut d) => {
136 if d.updates.is_empty() {
137 return Err(Error::SQLBuildError(String::from(
138 "There must be at least one update in an UPDATE statement",
139 )));
140 }
141 let mut s = format!(
142 "UPDATE {}{} SET ",
143 match d.on_conflict {
144 OnConflict::ABORT => "OR ABORT ",
145 OnConflict::ROLLBACK => "OR ROLLBACK ",
146 },
147 d.model,
148 );
149
150 let update_index = d.updates.len() - 1;
151 for (idx, (name, value)) in d.updates.into_iter().enumerate() {
152 if let Value::Choice(c) = value {
153 write!(s, "{name} = {}", sqlite::fmt(c)).unwrap();
154 } else if let Value::Null(NullType::Choice) = value {
155 write!(s, "{name} = NULL").unwrap();
156 } else {
157 write!(s, "{name} = ?").unwrap();
158 d.lookup.push(value);
159 }
160 if idx != update_index {
161 write!(s, ", ").unwrap();
162 }
163 }
164
165 if let Some(condition) = d.where_clause {
166 write!(
167 s,
168 " WHERE {}",
169 condition.build(DBImpl::SQLite, &mut d.lookup)
170 )
171 .unwrap();
172 }
173
174 write!(s, ";").unwrap();
175
176 Ok((s, d.lookup))
177 }
178 #[cfg(feature = "mysql")]
179 UpdateImpl::MySQL(mut d) => {
180 if d.updates.is_empty() {
181 return Err(Error::SQLBuildError(String::from(
182 "There must be at least one update in an UPDATE statement",
183 )));
184 }
185 let mut s = format!(
186 "UPDATE {}{} SET ",
187 match d.on_conflict {
188 OnConflict::ABORT => "OR ABORT ",
189 OnConflict::ROLLBACK => "OR ROLLBACK ",
190 },
191 d.model,
192 );
193
194 let update_index = d.updates.len() - 1;
195 for (idx, (name, value)) in d.updates.into_iter().enumerate() {
196 if let Value::Choice(c) = value {
197 write!(s, "`{name}` = {}", mysql::fmt(c)).unwrap();
198 } else if let Value::Null(NullType::Choice) = value {
199 write!(s, "`{name}` = NULL").unwrap();
200 } else {
201 write!(s, "`{name}` = ?").unwrap();
202 d.lookup.push(value);
203 }
204 if idx != update_index {
205 write!(s, ", ").unwrap();
206 }
207 }
208
209 if let Some(condition) = d.where_clause {
210 write!(
211 s,
212 " WHERE {}",
213 condition.build(DBImpl::MySQL, &mut d.lookup)
214 )
215 .unwrap();
216 }
217
218 write!(s, ";").unwrap();
219
220 Ok((s, d.lookup))
221 }
222 #[cfg(feature = "postgres")]
223 UpdateImpl::Postgres(mut d) => {
224 if d.updates.is_empty() {
225 return Err(Error::SQLBuildError(String::from(
226 "There must be at least one update in an UPDATE statement",
227 )));
228 }
229 let mut s = format!("UPDATE \"{}\" SET ", d.model);
230
231 let update_index = d.updates.len() - 1;
232 for (idx, (name, value)) in d.updates.into_iter().enumerate() {
233 if let Value::Choice(c) = value {
234 write!(s, "\"{name}\" = {}", postgres::fmt(c)).unwrap();
235 } else if let Value::Null(NullType::Choice) = value {
236 write!(s, "\"{name}\" = NULL").unwrap();
237 } else {
238 d.lookup.push(value);
239 write!(s, "\"{name}\" = ${}", d.lookup.len()).unwrap();
240 }
241 if idx != update_index {
242 write!(s, ", ").unwrap();
243 }
244 }
245
246 if let Some(condition) = d.where_clause {
247 write!(
248 s,
249 " WHERE {}",
250 condition.build(DBImpl::Postgres, &mut d.lookup)
251 )
252 .unwrap();
253 }
254
255 write!(s, ";").unwrap();
256
257 Ok((s, d.lookup))
258 }
259 }
260 }
261}