1use anyhow::{bail, Result};
2use bb8_postgres::{bb8::Pool, tokio_postgres::NoTls, PostgresConnectionManager};
3use oid::ObjectId;
4use rest_model::{
5 pagination::{Pagination, PaginationParams},
6 DbClient, DeleteParams, DeleteResult, Doc, PaginationResult, PatchParams, RestModel,
7 UpdateResult, UpsertResult,
8};
9use serde_json::Value;
10use tokio_postgres::types::ToSql;
11
12mod query;
13pub use query::*;
14use tracing::debug;
15mod oid;
16
17#[derive(Debug, Clone)]
18pub struct Db {
19 pub pool: Pool<PostgresConnectionManager<NoTls>>,
20}
21
22impl Db {
23 pub async fn try_new(postgres_uri: &str) -> Result<Self> {
24 let manager = PostgresConnectionManager::new_from_stringlike(postgres_uri, NoTls)?;
25 let pool = Pool::builder().max_size(10).build(manager).await?;
26 Ok(Self { pool })
27 }
28}
29
30impl<T: RestModel> DbClient<T> for Db {
31 fn generate_id(&self) -> String {
32 ObjectId::new().to_hex()
33 }
34
35 async fn init(
36 &self,
37 db_name: &str,
38 table_name: &str,
39 ) -> std::result::Result<(), anyhow::Error> {
40 let sql = format!(
41 "CREATE TABLE IF NOT EXISTS {}.{} (
42 _id VARCHAR(24) PRIMARY KEY,
43 data JSONB NOT NULL,
44 _created_at BIGINT DEFAULT (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT,
45 _updated_at BIGINT DEFAULT (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT
46 );",
47 db_name, table_name
48 );
49 self.pool
50 .get()
51 .await?
52 .execute(&sql, &[])
53 .await
54 .map(|_| ())
55 .map_err(|e| anyhow::Error::from(e))
56 }
57
58 async fn select_by_id(&self, db_name: &str, table_name: &str, id: &str) -> Result<Doc<T>> {
59 let sql = format!(
60 "SELECT * FROM {}.{} WHERE _id = '{}'",
61 db_name, table_name, id
62 );
63 let conn = self.pool.get().await?;
64 let rows = conn.query(&sql, &[]).await?;
65 if rows.is_empty() {
66 bail!("Document not found");
67 }
68 let row = rows.get(0).unwrap();
69 let data: Value = row.get("data");
70 let data: T = serde_json::from_value(data)?;
71 let doc = Doc {
72 _id: row.get("_id"),
73 data,
74 _created_at: row.get("_created_at"),
75 _updated_at: row.get("_updated_at"),
76 };
77 Ok(doc)
78 }
79
80 async fn paginate(
81 &self,
82 db_name: &str,
83 table_name: &str,
84 pagination_params: &PaginationParams,
85 ) -> Result<PaginationResult<T>> {
86 let page = pagination_params.page.unwrap_or(1).max(1);
87 let limit = pagination_params.limit.unwrap_or(10).max(1);
88 let offset = (page - 1) * limit;
89
90 let mut seq = 1u32;
91 let mut bindings = vec![];
92
93 let where_sql = if let Some(ref filter) = pagination_params.filter {
94 let sql = cond_to_sql(filter, &mut bindings, &mut seq)?;
95 if sql.is_empty() {
96 "".to_string()
97 } else {
98 format!("WHERE {}", sql)
99 }
100 } else {
101 "".to_string()
102 };
103
104 let order_sql = if let Some(sort_expr) = &pagination_params.sort {
106 sort_to_sql(sort_expr)?
107 } else {
108 "_id ASC".to_string()
109 };
110
111 let query_sql = format!(
113 "SELECT * FROM {}.{} {} ORDER BY {} LIMIT {} OFFSET {}",
114 db_name, table_name, where_sql, order_sql, limit, offset
115 );
116
117 let total_sql = format!(
119 "SELECT COUNT(*) FROM {}.{} {}",
120 db_name, table_name, where_sql
121 );
122
123 let conn = self.pool.get().await?;
124 let args: Vec<&(dyn ToSql + Sync)> =
125 bindings.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
126 debug!("args: {:?}", args);
127 debug!("total_sql: {}", total_sql);
128 let row = conn.query_one(&total_sql, &args).await?;
129 let total_count: i64 = row.get(0);
130 let total_count = total_count as u32;
131
132 let items = if total_count > 0 {
133 debug!("query_sql: {}", query_sql);
134 let row = conn.query(&query_sql, &args).await?;
135 let mut items = Vec::new();
136 for row in row {
137 let data: Value = row.get(1);
138 let data: T = serde_json::from_value(data)?;
139 let doc = Doc {
140 _id: row.get(0),
141 data,
142 _created_at: row.get(2),
143 _updated_at: row.get(3),
144 };
145 items.push(doc);
146 }
147 items
148 } else {
149 vec![]
150 };
151
152 let total_pages = total_count / limit + if total_count % limit > 0 { 1 } else { 0 };
153 Ok(PaginationResult {
154 items,
155 pagination: Pagination {
156 total_count,
157 total_pages,
158 current_page: page,
159 items_per_page: limit,
160 },
161 })
162 }
163
164 async fn upsert(
165 &self,
166 db_name: &str,
167 table_name: &str,
168 items: &[Doc<T>],
169 ) -> Result<UpsertResult> {
170 if items.is_empty() {
171 return Ok(UpsertResult {
172 created_count: 0,
173 updated_count: 0,
174 });
175 }
176
177 let mut query = format!("INSERT INTO {}.{} (_id, data) VALUES ", db_name, table_name);
178
179 let mut values = Vec::new();
180 let mut args: Vec<Box<dyn ToSql + Sync>> = Vec::new();
181
182 for (i, doc) in items.iter().enumerate() {
183 values.push(format!("(${}, ${})", i * 2 + 1, i * 2 + 2,));
184
185 args.push(Box::new(doc._id.clone()));
186 args.push(Box::new(serde_json::to_value(&doc.data)?));
187 }
188
189 query.push_str(&values.join(", "));
190 query.push_str(
191 " ON CONFLICT (_id) DO UPDATE SET
192 data = EXCLUDED.data,
193 _updated_at = (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT
194 RETURNING (xmax = 0) AS inserted;",
195 );
196
197 debug!("{}", query);
198
199 let conn = self.pool.get().await?;
200 let args_refs: Vec<&(dyn ToSql + Sync)> = args.iter().map(|x| x.as_ref()).collect();
201 let rows = conn.query(&query, &args_refs[..]).await?;
202
203 let created_count = rows.iter().filter(|row| row.get::<_, bool>(0)).count() as u32;
204 let updated_count = rows.len() as u32 - created_count;
205
206 Ok(UpsertResult {
207 created_count,
208 updated_count,
209 })
210 }
211
212 async fn update(
213 &self,
214 db_name: &str,
215 table_name: &str,
216 params: &PatchParams,
217 ) -> Result<UpdateResult> {
218 let mut set_sql = "data = ".to_string();
220 let mut args: Vec<Box<dyn ToSql + Sync>> = Vec::new();
221 let mut jsonb_expr = "data".to_string(); for (key, value) in params.patch.as_object().unwrap() {
224 let path_arg_index = args.len() + 1;
225 let value_arg_index = path_arg_index + 1;
226
227 jsonb_expr = format!(
228 "jsonb_set({}, ${}, ${}, true)",
229 jsonb_expr, path_arg_index, value_arg_index
230 );
231
232 args.push(Box::new(vec![key.to_string()])); args.push(Box::new(value.clone())); }
235
236 set_sql.push_str(&jsonb_expr);
237
238 let set_sql = format!(
240 "{}, _updated_at = EXTRACT(EPOCH FROM NOW()) * 1000",
241 set_sql
242 );
243
244 let mut bindings = vec![];
246 let seq = &mut (args.len() as u32 + 1);
247 let where_sql = cond_to_sql(¶ms.filter, &mut bindings, seq)?;
248
249 let query = format!(
251 "UPDATE {}.{} SET {} {} RETURNING _id;",
252 db_name, table_name, set_sql, where_sql
253 );
254 args.append(&mut bindings);
255
256 debug!("{}", query);
258 debug!("{:?}", args);
259 let args_refs: Vec<&(dyn ToSql + Sync)> = args.iter().map(|x| x.as_ref()).collect();
260 let conn = self.pool.get().await?;
261 let rows = conn.query(&query, &args_refs[..]).await?;
262
263 Ok(UpdateResult {
265 updated_count: rows.len() as u32,
266 })
267 }
268
269 async fn delete(
270 &self,
271 db_name: &str,
272 table_name: &str,
273 params: &DeleteParams,
274 ) -> Result<DeleteResult> {
275 let bindings = &mut vec![];
277 let seq = &mut 1;
278 let where_sql = cond_to_sql(¶ms.filter, bindings, seq)?;
279
280 let query = format!(
282 "DELETE FROM {}.{} {} RETURNING _id;",
283 db_name, table_name, where_sql
284 );
285
286 let conn = self.pool.get().await?;
288 let args_refs: Vec<&(dyn ToSql + Sync)> = bindings.iter().map(|x| x.as_ref()).collect();
289 let rows = conn.query(&query, &args_refs).await?;
290
291 Ok(DeleteResult {
293 deleted_count: rows.len() as u32,
294 })
295 }
296}