rest_model_postgres/
lib.rs

1use anyhow::{bail, Result};
2use bb8_postgres::{bb8::Pool, tokio_postgres::NoTls, PostgresConnectionManager};
3use oid::ObjectId;
4use rest_model::{
5    pagination::{Pagination, PaginationParams},
6    DbClient, DeleteParams, DeleteResult, Doc, PaginationResult, PatchParams, RestModel,
7    UpdateResult, UpsertResult,
8};
9use serde_json::Value;
10use tokio_postgres::types::ToSql;
11
12mod query;
13pub use query::*;
14use tracing::debug;
15mod oid;
16
17pub struct Db {
18    pub pool: Pool<PostgresConnectionManager<NoTls>>,
19}
20
21impl Db {
22    pub async fn try_new(postgres_uri: &str) -> Result<Self> {
23        let manager = PostgresConnectionManager::new_from_stringlike(postgres_uri, NoTls)?;
24        let pool = Pool::builder().max_size(10).build(manager).await?;
25        Ok(Self { pool })
26    }
27}
28
29impl<T: RestModel> DbClient<T> for Db {
30    fn generate_id(&self) -> String {
31        ObjectId::new().to_hex()
32    }
33
34    async fn init(
35        &self,
36        db_name: &str,
37        table_name: &str,
38    ) -> std::result::Result<(), anyhow::Error> {
39        let sql = format!(
40            "CREATE TABLE IF NOT EXISTS {}.{} (
41                _id VARCHAR(24) PRIMARY KEY,
42                data JSONB NOT NULL,
43                _created_at BIGINT DEFAULT (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT,
44                _updated_at BIGINT DEFAULT (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT
45            );",
46            db_name, table_name
47        );
48        self.pool
49            .get()
50            .await?
51            .execute(&sql, &[])
52            .await
53            .map(|_| ())
54            .map_err(|e| anyhow::Error::from(e))
55    }
56
57    async fn select_by_id(&self, db_name: &str, table_name: &str, id: &str) -> Result<Doc<T>> {
58        let sql = format!(
59            "SELECT * FROM {}.{} WHERE _id = '{}'",
60            db_name, table_name, id
61        );
62        let conn = self.pool.get().await?;
63        let rows = conn.query(&sql, &[]).await?;
64        if rows.is_empty() {
65            bail!("Document not found");
66        }
67        let row = rows.get(0).unwrap();
68        let data: Value = row.get("data");
69        let data: T = serde_json::from_value(data)?;
70        let doc = Doc {
71            _id: row.get("_id"),
72            data,
73            _created_at: row.get("_created_at"),
74            _updated_at: row.get("_updated_at"),
75        };
76        Ok(doc)
77    }
78
79    async fn paginate(
80        &self,
81        db_name: &str,
82        table_name: &str,
83        pagination_params: &PaginationParams,
84    ) -> Result<PaginationResult<T>> {
85        let page = pagination_params.page.unwrap_or(1).max(1);
86        let limit = pagination_params.limit.unwrap_or(10).max(1);
87        let offset = (page - 1) * limit;
88
89        let mut seq = 1u32;
90        let mut bindings = vec![];
91
92        let where_sql = if let Some(ref filter) = pagination_params.filter {
93            let sql = cond_to_sql(filter, &mut bindings, &mut seq)?;
94            if sql.is_empty() {
95                "".to_string()
96            } else {
97                format!("WHERE {}", sql)
98            }
99        } else {
100            "".to_string()
101        };
102
103        // 处理排序
104        let order_sql = if let Some(sort_expr) = &pagination_params.sort {
105            sort_to_sql(sort_expr)?
106        } else {
107            "_id ASC".to_string()
108        };
109
110        // 查询分页数据
111        let query_sql = format!(
112            "SELECT * FROM {}.{} {} ORDER BY {} LIMIT {} OFFSET {}",
113            db_name, table_name, where_sql, order_sql, limit, offset
114        );
115
116        // 查询总数
117        let total_sql = format!(
118            "SELECT COUNT(*) FROM {}.{} {}",
119            db_name, table_name, where_sql
120        );
121
122        let conn = self.pool.get().await?;
123        let args: Vec<&(dyn ToSql + Sync)> =
124            bindings.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
125        debug!("args: {:?}", args);
126        debug!("total_sql: {}", total_sql);
127        let row = conn.query_one(&total_sql, &args).await?;
128        let total_count: i64 = row.get(0);
129        let total_count = total_count as u32;
130
131        let items = if total_count > 0 {
132            debug!("query_sql: {}", query_sql);
133            let row = conn.query(&query_sql, &args).await?;
134            let mut items = Vec::new();
135            for row in row {
136                let data: Value = row.get(1);
137                let data: T = serde_json::from_value(data)?;
138                let doc = Doc {
139                    _id: row.get(0),
140                    data,
141                    _created_at: row.get(2),
142                    _updated_at: row.get(3),
143                };
144                items.push(doc);
145            }
146            items
147        } else {
148            vec![]
149        };
150
151        let total_pages = total_count / limit + if total_count % limit > 0 { 1 } else { 0 };
152        Ok(PaginationResult {
153            items,
154            pagination: Pagination {
155                total_count,
156                total_pages,
157                current_page: page,
158                items_per_page: limit,
159            },
160        })
161    }
162
163    async fn upsert(
164        &self,
165        db_name: &str,
166        table_name: &str,
167        items: Vec<Doc<T>>,
168    ) -> Result<UpsertResult> {
169        if items.is_empty() {
170            return Ok(UpsertResult {
171                created_count: 0,
172                updated_count: 0,
173            });
174        }
175
176        let mut query = format!("INSERT INTO {}.{} (_id, data) VALUES ", db_name, table_name);
177
178        let mut values = Vec::new();
179        let mut args: Vec<Box<dyn ToSql + Sync>> = Vec::new();
180
181        for (i, doc) in items.iter().enumerate() {
182            values.push(format!("(${}, ${})", i * 2 + 1, i * 2 + 2,));
183
184            args.push(Box::new(doc._id.clone()));
185            args.push(Box::new(serde_json::to_value(&doc.data)?));
186        }
187
188        query.push_str(&values.join(", "));
189        query.push_str(
190            " ON CONFLICT (_id) DO UPDATE SET
191              data = EXCLUDED.data,
192              _updated_at = (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT
193            RETURNING (xmax = 0) AS inserted;",
194        );
195
196        debug!("{}", query);
197
198        let conn = self.pool.get().await?;
199        let args_refs: Vec<&(dyn ToSql + Sync)> = args.iter().map(|x| x.as_ref()).collect();
200        let rows = conn.query(&query, &args_refs[..]).await?;
201
202        let created_count = rows.iter().filter(|row| row.get::<_, bool>(0)).count() as u32;
203        let updated_count = rows.len() as u32 - created_count;
204
205        Ok(UpsertResult {
206            created_count,
207            updated_count,
208        })
209    }
210
211    async fn update(
212        &self,
213        db_name: &str,
214        table_name: &str,
215        params: &PatchParams,
216    ) -> Result<UpdateResult> {
217        // 2️⃣ 解析 `patch` 生成 `JSONB SET` 语句
218        let mut set_sql = "data = ".to_string();
219        let mut args: Vec<Box<dyn ToSql + Sync>> = Vec::new();
220        let mut jsonb_expr = "data".to_string(); // 初始值为 `data`
221
222        for (key, value) in params.patch.as_object().unwrap() {
223            let path_arg_index = args.len() + 1;
224            let value_arg_index = path_arg_index + 1;
225
226            jsonb_expr = format!(
227                "jsonb_set({}, ${}, ${}, true)",
228                jsonb_expr, path_arg_index, value_arg_index
229            );
230
231            args.push(Box::new(vec![key.to_string()])); // JSON 路径,必须是 `TEXT[]`
232            args.push(Box::new(value.clone())); // JSON 值
233        }
234
235        set_sql.push_str(&jsonb_expr);
236
237        // 更新 `_updated_at`
238        let set_sql = format!(
239            "{}, _updated_at = EXTRACT(EPOCH FROM NOW()) * 1000",
240            set_sql
241        );
242
243        // 1️⃣ 解析 `filter` 生成 `WHERE` 语句
244        let mut bindings = vec![];
245        let seq = &mut (args.len() as u32 + 1);
246        let where_sql = cond_to_sql(&params.filter, &mut bindings, seq)?;
247
248        // 3️⃣ 生成 SQL
249        let query = format!(
250            "UPDATE {}.{} SET {} {} RETURNING _id;",
251            db_name, table_name, set_sql, where_sql
252        );
253        args.append(&mut bindings);
254
255        // 4️⃣ 执行 SQL
256        debug!("{}", query);
257        debug!("{:?}", args);
258        let args_refs: Vec<&(dyn ToSql + Sync)> = args.iter().map(|x| x.as_ref()).collect();
259        let conn = self.pool.get().await?;
260        let rows = conn.query(&query, &args_refs[..]).await?;
261
262        // 5️⃣ 返回更新的行数
263        Ok(UpdateResult {
264            updated_count: rows.len() as u32,
265        })
266    }
267
268    async fn delete(
269        &self,
270        db_name: &str,
271        table_name: &str,
272        params: &DeleteParams,
273    ) -> Result<DeleteResult> {
274        // 1️⃣ 解析 `filter` 生成 `WHERE` 语句
275        let bindings = &mut vec![];
276        let seq = &mut 1;
277        let where_sql = cond_to_sql(&params.filter, bindings, seq)?;
278
279        // 2️⃣ 生成 SQL
280        let query = format!(
281            "DELETE FROM {}.{} {} RETURNING _id;",
282            db_name, table_name, where_sql
283        );
284
285        // 3️⃣ 执行 SQL
286        let conn = self.pool.get().await?;
287        let args_refs: Vec<&(dyn ToSql + Sync)> = bindings.iter().map(|x| x.as_ref()).collect();
288        let rows = conn.query(&query, &args_refs).await?;
289
290        // 4️⃣ 返回删除的行数
291        Ok(DeleteResult {
292            deleted_count: rows.len() as u32,
293        })
294    }
295}