1use deadpool_postgres::{Object, Pool};
2
3use super::column::ColumnType;
4use super::util::{
5 data_into_sql_params, info_data_to_sql, quote, rows_into_data,
6};
7use super::{ColumnData, Info, TableTemplate};
8
9use crate::query::{Query, SqlBuilder, UpdateParams};
10use crate::{Error, Result};
11
12use std::borrow::Borrow;
13use std::marker::PhantomData;
14use std::sync::Arc;
/// Logs, at `debug` level through `tracing`, the SQL a table operation is
/// about to run.
///
/// Arguments, in order: the operation label, the table name, and the SQL
/// text. Each must implement `Display`.
macro_rules! debug_sql {
    ($op:expr, $table:expr, $statement:expr) => {
        tracing::debug!("sql: {} {} with {}", $op, $table, $statement);
    };
}
27
/// SQL and column metadata computed once in `Table::new` and shared between
/// clones of a `Table` through an `Arc`.
#[derive(Debug)]
struct TableMeta {
    // Column information returned by `T::table_info()`.
    info: Info,
    // `SELECT "col", ... FROM "table"` over every column.
    select: String,
    // `INSERT INTO "table" (...) VALUES ($1, ...)` with one placeholder per column.
    insert: String,
    // `"col" = <param>, ...` assigning every column; the `SET` clause used by
    // `Table::update_full`.
    update_full: SqlBuilder,
    // Every column name double-quoted and comma-separated (`"a", "b"`).
    names_for_select: String,
}
36
/// Handle to a PostgreSQL table whose rows are mapped to and from `T`.
///
/// `Clone` is implemented by hand in this file so that `T` itself does not
/// need to be `Clone`.
#[derive(Debug)]
pub struct Table<T>
where
    T: TableTemplate,
{
    // Pool from which a connection is checked out for every operation.
    pool: Pool,
    // Table name; interpolated (double-quoted) into the generated SQL.
    name: &'static str,
    // Precomputed SQL and column info, shared by all clones.
    meta: Arc<TableMeta>,
    // Ties the handle to `T` without storing a `T`.
    phantom: PhantomData<T>,
}
47
48impl<T> Table<T>
49where
50 T: TableTemplate,
51{
52 pub(crate) fn new(pool: Pool, name: &'static str) -> Self {
53 let info = T::table_info();
54 let meta = TableMeta {
55 select: Self::create_select_sql(&info, name),
56 insert: Self::create_insert_sql(&info, name),
57 update_full: Self::create_update_full(&info),
58 names_for_select: Self::create_names_for_select(&info),
59 info,
60 };
61
62 Self {
63 pool,
64 name,
65 meta: Arc::new(meta),
66 phantom: PhantomData,
67 }
68 }
69
70 pub fn name(&self) -> &'static str {
71 self.name
72 }
73
74 pub fn names_for_select(&self) -> &str {
77 &self.meta.names_for_select
78 }
79
80 pub fn info(&self) -> &Info {
81 &self.meta.info
82 }
83
84 async fn get_client(&self) -> Result<Object> {
85 self.pool
86 .get()
87 .await
88 .map_err(|e| Error::Postgres(e.to_string()))
89 }
90
91 fn create_names_for_select(info: &Info) -> String {
92 format!("\"{}\"", info.names().join("\", \""))
93 }
94
95 fn create_select_sql(info: &Info, name: &str) -> String {
96 let names = info.names();
97 format!("SELECT \"{}\" FROM \"{}\"", names.join("\", \""), name)
98 }
99
100 fn create_insert_sql(info: &Info, name: &str) -> String {
101 let mut names = vec![];
102 let mut vals = vec![];
103 for (i, col) in info.data().iter().enumerate() {
104 names.push(quote(col.name));
105 vals.push(format!("${}", i + 1));
106 }
107
108 format!(
110 "INSERT INTO \"{}\" ({}) VALUES ({})",
111 name,
112 names.join(", "),
113 vals.join(", ")
114 )
115 }
116
117 fn create_update_full(info: &Info) -> SqlBuilder {
121 let mut sql = SqlBuilder::new();
122
123 let last = info.data().len() - 1;
124 for (i, col) in info.data().iter().enumerate() {
125 sql.space_after(format!("\"{}\" =", col.name));
126 sql.param();
127
128 if i != last {
129 sql.space_after(",");
130 }
131 }
132
133 sql
134 }
135
136 pub async fn try_create(&self) -> Result<()> {
138 let sql = info_data_to_sql(self.name, self.meta.info.data());
139
140 debug_sql!("create", self.name, sql);
141
142 self.get_client()
143 .await?
144 .batch_execute(sql.as_str())
145 .await
146 .map_err(Into::into)
147 }
148
149 pub async fn create(self) -> Self {
152 self.try_create().await.expect("could not create table");
153 self
154 }
155
156 pub async fn insert_one(&self, input: &T) -> Result<()> {
177 let sql = &self.meta.insert;
178 debug_sql!("insert_one", self.name, sql);
179
180 let cl = self.get_client().await?;
181
182 let data = input.to_data();
183 let params = data_into_sql_params(&data);
184
185 cl.execute(sql, params.as_slice()).await?;
187 Ok(())
188 }
189
190 pub async fn insert_many<B, I>(&self, input: I) -> Result<()>
191 where
192 B: Borrow<T>,
193 I: Iterator<Item = B>,
194 {
195 let sql = &self.meta.insert;
196 debug_sql!("insert_many", self.name, sql);
197
198 let mut cl = self.get_client().await?;
201 let ts = cl.transaction().await?;
202
203 let stmt = ts.prepare(sql).await?;
204
205 for input in input {
206 let data = input.borrow().to_data();
207 let params = data_into_sql_params(&data);
208
209 ts.execute(&stmt, params.as_slice()).await?;
210 }
211
212 ts.commit().await?;
213
214 Ok(())
215 }
216
217 pub async fn find_all(&self) -> Result<Vec<T>> {
221 let sql = &self.meta.select;
222 debug_sql!("find_all", self.name, sql);
223
224 let rows = {
225 let cl = self.get_client().await?;
226 cl.query(sql, &[]).await?
227 };
228
229 rows_into_data(rows)
230 }
231
232 pub async fn find_many(&self, where_query: Query<'_>) -> Result<Vec<T>> {
233 let mut query = Query::from_sql_str(self.meta.select.clone());
234
235 self.meta.info.validate_params(where_query.params())?;
236 query.sql.space("WHERE");
237 query.append(where_query);
238
239 let sql = query.sql().to_string();
240 debug_sql!("find_many", self.name, sql);
241 let params = query.to_sql_params();
242
243 let rows = {
244 let cl = self.get_client().await?;
245 cl.query(&sql, params.as_slice()).await?
246 };
247
248 rows_into_data(rows)
249 }
250
251 pub async fn find_one(
252 &self,
253 mut where_query: Query<'_>,
254 ) -> Result<Option<T>> {
255 where_query.sql.space_before("LIMIT 1");
256 let res = self.find_many(where_query).await?;
257
258 debug_assert!(res.len() <= 1);
259
260 Ok(res.into_iter().next())
261 }
262
263 pub async fn find_many_raw(&self, sql: &str) -> Result<Vec<T>> {
266 debug_sql!("find_many_raw", self.name, sql);
267
268 let rows = {
269 let cl = self.get_client().await?;
270 cl.query(sql, &[]).await?
271 };
272
273 rows_into_data(rows)
274 }
275
276 pub async fn count_many<'a>(&self, where_query: Query<'a>) -> Result<u32> {
277 let mut query = Query::from_sql_str(format!(
278 "SELECT COUNT(*) FROM \"{}\"",
279 self.name
280 ));
281
282 if !where_query.is_empty() {
283 self.meta.info.validate_params(where_query.params())?;
284 query.sql.space("WHERE");
285 query.append(where_query);
286 }
287
288 let sql = query.sql().to_string();
289 debug_sql!("count_many", self.name, sql);
290 let params = query.to_sql_params();
291
292 let row = {
293 let cl = self.get_client().await?;
294 cl.query_one(&sql, params.as_slice()).await?
295 };
296
297 let data: ColumnData = row.try_get(0)?;
298
299 u32::from_data(data).map_err(Into::into)
300 }
301
302 pub async fn update<'a>(
304 &self,
305 where_query: Query<'a>,
306 update_query: UpdateParams<'a>,
307 ) -> Result<()> {
308 let mut query = update_query.into_query();
310 query.sql.space("WHERE");
311 query.append(where_query);
312
313 self.meta.info.validate_params(query.params())?;
314
315 let sql = format!("UPDATE \"{}\" SET {}", self.name, query.sql());
316 debug_sql!("update", self.name, sql);
317 let params = query.to_sql_params();
318
319 let cl = self.get_client().await?;
320 cl.execute(&sql, params.as_slice()).await?;
321
322 Ok(())
323 }
324
325 pub async fn update_full<'a>(
326 &self,
327 where_query: Query<'a>,
328 input: &'a T,
329 ) -> Result<()> {
330 let mut sql = self.meta.update_full.clone();
331
332 self.meta.info.validate_params(where_query.params())?;
333
334 sql.space("WHERE");
335 sql.append(where_query.sql);
336
337 let sql = format!("UPDATE \"{}\" SET {}", self.name, sql);
338 debug_sql!("update_full", self.name, sql);
339
340 let mut data = input.to_data();
341 for param in where_query.params {
342 data.push(param.data);
343 }
344 let params = data_into_sql_params(&data);
345
346 let cl = self.get_client().await?;
347 cl.execute(&sql, params.as_slice()).await?;
348
349 Ok(())
350 }
351
352 pub async fn delete(&self, where_query: Query<'_>) -> Result<()> {
354 self.meta.info.validate_params(where_query.params())?;
355
356 let sql =
357 format!("DELETE FROM \"{}\" WHERE {}", self.name, where_query.sql);
358 debug_sql!("delete_many", self.name, sql);
359 let params = where_query.to_sql_params();
360
361 let cl = self.get_client().await?;
362 cl.execute(&sql, params.as_slice()).await?;
363
364 Ok(())
365 }
366
367 pub async fn execute_raw(
369 &self,
370 sql: SqlBuilder,
371 data: &[ColumnData<'_>],
372 ) -> Result<()> {
373 let sql = sql.to_string();
374 debug_sql!("execute_raw", self.name, sql);
375
376 let params = data_into_sql_params(data);
377
378 let cl = self.get_client().await?;
379 cl.execute(&sql, params.as_slice()).await?;
380
381 Ok(())
382 }
383}
384
385impl<T> Clone for Table<T>
386where
387 T: TableTemplate,
388{
389 fn clone(&self) -> Self {
390 Self {
391 pool: self.pool.clone(),
392 name: self.name,
393 meta: self.meta.clone(),
394 phantom: PhantomData,
395 }
396 }
397}