1use anyhow::{bail, Result};
2use bb8_postgres::{bb8::Pool, tokio_postgres::NoTls, PostgresConnectionManager};
3use oid::ObjectId;
4use rest_model::{
5 pagination::{Pagination, PaginationParams},
6 DbClient, DeleteParams, DeleteResult, Doc, PaginationResult, PatchParams, RestModel,
7 UpdateResult, UpsertResult,
8};
9use serde_json::Value;
10use tokio_postgres::types::ToSql;
11
12mod query;
13pub use query::*;
14use tracing::debug;
15mod oid;
16
17pub struct Db {
18 pub pool: Pool<PostgresConnectionManager<NoTls>>,
19}
20
21impl Db {
22 pub async fn try_new(postgres_uri: &str) -> Result<Self> {
23 let manager = PostgresConnectionManager::new_from_stringlike(postgres_uri, NoTls)?;
24 let pool = Pool::builder().max_size(10).build(manager).await?;
25 Ok(Self { pool })
26 }
27}
28
29impl<T: RestModel> DbClient<T> for Db {
30 fn generate_id(&self) -> String {
31 ObjectId::new().to_hex()
32 }
33
34 async fn init(
35 &self,
36 db_name: &str,
37 table_name: &str,
38 ) -> std::result::Result<(), anyhow::Error> {
39 let sql = format!(
40 "CREATE TABLE IF NOT EXISTS {}.{} (
41 _id VARCHAR(24) PRIMARY KEY,
42 data JSONB NOT NULL,
43 _created_at BIGINT DEFAULT (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT,
44 _updated_at BIGINT DEFAULT (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT
45 );",
46 db_name, table_name
47 );
48 self.pool
49 .get()
50 .await?
51 .execute(&sql, &[])
52 .await
53 .map(|_| ())
54 .map_err(|e| anyhow::Error::from(e))
55 }
56
57 async fn select_by_id(&self, db_name: &str, table_name: &str, id: &str) -> Result<Doc<T>> {
58 let sql = format!(
59 "SELECT * FROM {}.{} WHERE _id = '{}'",
60 db_name, table_name, id
61 );
62 let conn = self.pool.get().await?;
63 let rows = conn.query(&sql, &[]).await?;
64 if rows.is_empty() {
65 bail!("Document not found");
66 }
67 let row = rows.get(0).unwrap();
68 let data: Value = row.get("data");
69 let data: T = serde_json::from_value(data)?;
70 let doc = Doc {
71 _id: row.get("_id"),
72 data,
73 _created_at: row.get("_created_at"),
74 _updated_at: row.get("_updated_at"),
75 };
76 Ok(doc)
77 }
78
79 async fn paginate(
80 &self,
81 db_name: &str,
82 table_name: &str,
83 pagination_params: &PaginationParams,
84 ) -> Result<PaginationResult<T>> {
85 let page = pagination_params.page.unwrap_or(1).max(1);
86 let limit = pagination_params.limit.unwrap_or(10).max(1);
87 let offset = (page - 1) * limit;
88
89 let mut seq = 1u32;
90 let mut bindings = vec![];
91
92 let where_sql = if let Some(ref filter) = pagination_params.filter {
93 let sql = cond_to_sql(filter, &mut bindings, &mut seq)?;
94 if sql.is_empty() {
95 "".to_string()
96 } else {
97 format!("WHERE {}", sql)
98 }
99 } else {
100 "".to_string()
101 };
102
103 let order_sql = if let Some(sort_expr) = &pagination_params.sort {
105 sort_to_sql(sort_expr)?
106 } else {
107 "_id ASC".to_string()
108 };
109
110 let query_sql = format!(
112 "SELECT * FROM {}.{} {} ORDER BY {} LIMIT {} OFFSET {}",
113 db_name, table_name, where_sql, order_sql, limit, offset
114 );
115
116 let total_sql = format!(
118 "SELECT COUNT(*) FROM {}.{} {}",
119 db_name, table_name, where_sql
120 );
121
122 let conn = self.pool.get().await?;
123 let args: Vec<&(dyn ToSql + Sync)> =
124 bindings.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
125 debug!("args: {:?}", args);
126 debug!("total_sql: {}", total_sql);
127 let row = conn.query_one(&total_sql, &args).await?;
128 let total_count: i64 = row.get(0);
129 let total_count = total_count as u32;
130
131 let items = if total_count > 0 {
132 debug!("query_sql: {}", query_sql);
133 let row = conn.query(&query_sql, &args).await?;
134 let mut items = Vec::new();
135 for row in row {
136 let data: Value = row.get(1);
137 let data: T = serde_json::from_value(data)?;
138 let doc = Doc {
139 _id: row.get(0),
140 data,
141 _created_at: row.get(2),
142 _updated_at: row.get(3),
143 };
144 items.push(doc);
145 }
146 items
147 } else {
148 vec![]
149 };
150
151 let total_pages = total_count / limit + if total_count % limit > 0 { 1 } else { 0 };
152 Ok(PaginationResult {
153 items,
154 pagination: Pagination {
155 total_count,
156 total_pages,
157 current_page: page,
158 items_per_page: limit,
159 },
160 })
161 }
162
163 async fn upsert(
164 &self,
165 db_name: &str,
166 table_name: &str,
167 items: Vec<Doc<T>>,
168 ) -> Result<UpsertResult> {
169 if items.is_empty() {
170 return Ok(UpsertResult {
171 created_count: 0,
172 updated_count: 0,
173 });
174 }
175
176 let mut query = format!("INSERT INTO {}.{} (_id, data) VALUES ", db_name, table_name);
177
178 let mut values = Vec::new();
179 let mut args: Vec<Box<dyn ToSql + Sync>> = Vec::new();
180
181 for (i, doc) in items.iter().enumerate() {
182 values.push(format!("(${}, ${})", i * 2 + 1, i * 2 + 2,));
183
184 args.push(Box::new(doc._id.clone()));
185 args.push(Box::new(serde_json::to_value(&doc.data)?));
186 }
187
188 query.push_str(&values.join(", "));
189 query.push_str(
190 " ON CONFLICT (_id) DO UPDATE SET
191 data = EXCLUDED.data,
192 _updated_at = (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT
193 RETURNING (xmax = 0) AS inserted;",
194 );
195
196 debug!("{}", query);
197
198 let conn = self.pool.get().await?;
199 let args_refs: Vec<&(dyn ToSql + Sync)> = args.iter().map(|x| x.as_ref()).collect();
200 let rows = conn.query(&query, &args_refs[..]).await?;
201
202 let created_count = rows.iter().filter(|row| row.get::<_, bool>(0)).count() as u32;
203 let updated_count = rows.len() as u32 - created_count;
204
205 Ok(UpsertResult {
206 created_count,
207 updated_count,
208 })
209 }
210
211 async fn update(
212 &self,
213 db_name: &str,
214 table_name: &str,
215 params: &PatchParams,
216 ) -> Result<UpdateResult> {
217 let mut set_sql = "data = ".to_string();
219 let mut args: Vec<Box<dyn ToSql + Sync>> = Vec::new();
220 let mut jsonb_expr = "data".to_string(); for (key, value) in params.patch.as_object().unwrap() {
223 let path_arg_index = args.len() + 1;
224 let value_arg_index = path_arg_index + 1;
225
226 jsonb_expr = format!(
227 "jsonb_set({}, ${}, ${}, true)",
228 jsonb_expr, path_arg_index, value_arg_index
229 );
230
231 args.push(Box::new(vec![key.to_string()])); args.push(Box::new(value.clone())); }
234
235 set_sql.push_str(&jsonb_expr);
236
237 let set_sql = format!(
239 "{}, _updated_at = EXTRACT(EPOCH FROM NOW()) * 1000",
240 set_sql
241 );
242
243 let mut bindings = vec![];
245 let seq = &mut (args.len() as u32 + 1);
246 let where_sql = cond_to_sql(¶ms.filter, &mut bindings, seq)?;
247
248 let query = format!(
250 "UPDATE {}.{} SET {} {} RETURNING _id;",
251 db_name, table_name, set_sql, where_sql
252 );
253 args.append(&mut bindings);
254
255 debug!("{}", query);
257 debug!("{:?}", args);
258 let args_refs: Vec<&(dyn ToSql + Sync)> = args.iter().map(|x| x.as_ref()).collect();
259 let conn = self.pool.get().await?;
260 let rows = conn.query(&query, &args_refs[..]).await?;
261
262 Ok(UpdateResult {
264 updated_count: rows.len() as u32,
265 })
266 }
267
268 async fn delete(
269 &self,
270 db_name: &str,
271 table_name: &str,
272 params: &DeleteParams,
273 ) -> Result<DeleteResult> {
274 let bindings = &mut vec![];
276 let seq = &mut 1;
277 let where_sql = cond_to_sql(¶ms.filter, bindings, seq)?;
278
279 let query = format!(
281 "DELETE FROM {}.{} {} RETURNING _id;",
282 db_name, table_name, where_sql
283 );
284
285 let conn = self.pool.get().await?;
287 let args_refs: Vec<&(dyn ToSql + Sync)> = bindings.iter().map(|x| x.as_ref()).collect();
288 let rows = conn.query(&query, &args_refs).await?;
289
290 Ok(DeleteResult {
292 deleted_count: rows.len() as u32,
293 })
294 }
295}