rest_model_postgres/
lib.rs

1use anyhow::{bail, Result};
2use bb8_postgres::{bb8::Pool, tokio_postgres::NoTls, PostgresConnectionManager};
3use oid::ObjectId;
4use rest_model::{
5    pagination::{Pagination, PaginationParams},
6    DbClient, DeleteParams, DeleteResult, Doc, PaginationResult, PatchParams, RestModel,
7    UpdateResult, UpsertResult,
8};
9use serde_json::Value;
10use tokio_postgres::types::ToSql;
11
12mod query;
13pub use query::*;
14use tracing::debug;
15mod oid;
16
17#[derive(Debug, Clone)]
18pub struct Db {
19    pub pool: Pool<PostgresConnectionManager<NoTls>>,
20}
21
22impl Db {
23    pub async fn try_new(postgres_uri: &str) -> Result<Self> {
24        let manager = PostgresConnectionManager::new_from_stringlike(postgres_uri, NoTls)?;
25        let pool = Pool::builder().max_size(10).build(manager).await?;
26        Ok(Self { pool })
27    }
28}
29
30impl<T: RestModel> DbClient<T> for Db {
31    fn generate_id(&self) -> String {
32        ObjectId::new().to_hex()
33    }
34
35    async fn init(
36        &self,
37        db_name: &str,
38        table_name: &str,
39    ) -> std::result::Result<(), anyhow::Error> {
40        let sql = format!(
41            "CREATE TABLE IF NOT EXISTS {}.{} (
42                _id VARCHAR(24) PRIMARY KEY,
43                data JSONB NOT NULL,
44                _created_at BIGINT DEFAULT (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT,
45                _updated_at BIGINT DEFAULT (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT
46            );",
47            db_name, table_name
48        );
49        self.pool
50            .get()
51            .await?
52            .execute(&sql, &[])
53            .await
54            .map(|_| ())
55            .map_err(|e| anyhow::Error::from(e))
56    }
57
58    async fn select_by_id(&self, db_name: &str, table_name: &str, id: &str) -> Result<Doc<T>> {
59        let sql = format!(
60            "SELECT * FROM {}.{} WHERE _id = '{}'",
61            db_name, table_name, id
62        );
63        let conn = self.pool.get().await?;
64        let rows = conn.query(&sql, &[]).await?;
65        if rows.is_empty() {
66            bail!("Document not found");
67        }
68        let row = rows.get(0).unwrap();
69        let data: Value = row.get("data");
70        let data: T = serde_json::from_value(data)?;
71        let doc = Doc {
72            _id: row.get("_id"),
73            data,
74            _created_at: row.get("_created_at"),
75            _updated_at: row.get("_updated_at"),
76        };
77        Ok(doc)
78    }
79
80    async fn paginate(
81        &self,
82        db_name: &str,
83        table_name: &str,
84        pagination_params: &PaginationParams,
85    ) -> Result<PaginationResult<T>> {
86        let page = pagination_params.page.unwrap_or(1).max(1);
87        let limit = pagination_params.limit.unwrap_or(10).max(1);
88        let offset = (page - 1) * limit;
89
90        let mut seq = 1u32;
91        let mut bindings = vec![];
92
93        let where_sql = if let Some(ref filter) = pagination_params.filter {
94            let sql = cond_to_sql(filter, &mut bindings, &mut seq)?;
95            if sql.is_empty() {
96                "".to_string()
97            } else {
98                format!("WHERE {}", sql)
99            }
100        } else {
101            "".to_string()
102        };
103
104        // 处理排序
105        let order_sql = if let Some(sort_expr) = &pagination_params.sort {
106            sort_to_sql(sort_expr)?
107        } else {
108            "_id ASC".to_string()
109        };
110
111        // 查询分页数据
112        let query_sql = format!(
113            "SELECT * FROM {}.{} {} ORDER BY {} LIMIT {} OFFSET {}",
114            db_name, table_name, where_sql, order_sql, limit, offset
115        );
116
117        // 查询总数
118        let total_sql = format!(
119            "SELECT COUNT(*) FROM {}.{} {}",
120            db_name, table_name, where_sql
121        );
122
123        let conn = self.pool.get().await?;
124        let args: Vec<&(dyn ToSql + Sync)> =
125            bindings.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
126        debug!("args: {:?}", args);
127        debug!("total_sql: {}", total_sql);
128        let row = conn.query_one(&total_sql, &args).await?;
129        let total_count: i64 = row.get(0);
130        let total_count = total_count as u32;
131
132        let items = if total_count > 0 {
133            debug!("query_sql: {}", query_sql);
134            let row = conn.query(&query_sql, &args).await?;
135            let mut items = Vec::new();
136            for row in row {
137                let data: Value = row.get(1);
138                let data: T = serde_json::from_value(data)?;
139                let doc = Doc {
140                    _id: row.get(0),
141                    data,
142                    _created_at: row.get(2),
143                    _updated_at: row.get(3),
144                };
145                items.push(doc);
146            }
147            items
148        } else {
149            vec![]
150        };
151
152        let total_pages = total_count / limit + if total_count % limit > 0 { 1 } else { 0 };
153        Ok(PaginationResult {
154            items,
155            pagination: Pagination {
156                total_count,
157                total_pages,
158                current_page: page,
159                items_per_page: limit,
160            },
161        })
162    }
163
164    async fn upsert(
165        &self,
166        db_name: &str,
167        table_name: &str,
168        items: &[Doc<T>],
169    ) -> Result<UpsertResult> {
170        if items.is_empty() {
171            return Ok(UpsertResult {
172                created_count: 0,
173                updated_count: 0,
174            });
175        }
176
177        let mut query = format!("INSERT INTO {}.{} (_id, data) VALUES ", db_name, table_name);
178
179        let mut values = Vec::new();
180        let mut args: Vec<Box<dyn ToSql + Sync>> = Vec::new();
181
182        for (i, doc) in items.iter().enumerate() {
183            values.push(format!("(${}, ${})", i * 2 + 1, i * 2 + 2,));
184
185            args.push(Box::new(doc._id.clone()));
186            args.push(Box::new(serde_json::to_value(&doc.data)?));
187        }
188
189        query.push_str(&values.join(", "));
190        query.push_str(
191            " ON CONFLICT (_id) DO UPDATE SET
192              data = EXCLUDED.data,
193              _updated_at = (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT
194            RETURNING (xmax = 0) AS inserted;",
195        );
196
197        debug!("{}", query);
198
199        let conn = self.pool.get().await?;
200        let args_refs: Vec<&(dyn ToSql + Sync)> = args.iter().map(|x| x.as_ref()).collect();
201        let rows = conn.query(&query, &args_refs[..]).await?;
202
203        let created_count = rows.iter().filter(|row| row.get::<_, bool>(0)).count() as u32;
204        let updated_count = rows.len() as u32 - created_count;
205
206        Ok(UpsertResult {
207            created_count,
208            updated_count,
209        })
210    }
211
212    async fn update(
213        &self,
214        db_name: &str,
215        table_name: &str,
216        params: &PatchParams,
217    ) -> Result<UpdateResult> {
218        // 2️⃣ 解析 `patch` 生成 `JSONB SET` 语句
219        let mut set_sql = "data = ".to_string();
220        let mut args: Vec<Box<dyn ToSql + Sync>> = Vec::new();
221        let mut jsonb_expr = "data".to_string(); // 初始值为 `data`
222
223        for (key, value) in params.patch.as_object().unwrap() {
224            let path_arg_index = args.len() + 1;
225            let value_arg_index = path_arg_index + 1;
226
227            jsonb_expr = format!(
228                "jsonb_set({}, ${}, ${}, true)",
229                jsonb_expr, path_arg_index, value_arg_index
230            );
231
232            args.push(Box::new(vec![key.to_string()])); // JSON 路径,必须是 `TEXT[]`
233            args.push(Box::new(value.clone())); // JSON 值
234        }
235
236        set_sql.push_str(&jsonb_expr);
237
238        // 更新 `_updated_at`
239        let set_sql = format!(
240            "{}, _updated_at = EXTRACT(EPOCH FROM NOW()) * 1000",
241            set_sql
242        );
243
244        // 1️⃣ 解析 `filter` 生成 `WHERE` 语句
245        let mut bindings = vec![];
246        let seq = &mut (args.len() as u32 + 1);
247        let where_sql = cond_to_sql(&params.filter, &mut bindings, seq)?;
248
249        // 3️⃣ 生成 SQL
250        let query = format!(
251            "UPDATE {}.{} SET {} {} RETURNING _id;",
252            db_name, table_name, set_sql, where_sql
253        );
254        args.append(&mut bindings);
255
256        // 4️⃣ 执行 SQL
257        debug!("{}", query);
258        debug!("{:?}", args);
259        let args_refs: Vec<&(dyn ToSql + Sync)> = args.iter().map(|x| x.as_ref()).collect();
260        let conn = self.pool.get().await?;
261        let rows = conn.query(&query, &args_refs[..]).await?;
262
263        // 5️⃣ 返回更新的行数
264        Ok(UpdateResult {
265            updated_count: rows.len() as u32,
266        })
267    }
268
269    async fn delete(
270        &self,
271        db_name: &str,
272        table_name: &str,
273        params: &DeleteParams,
274    ) -> Result<DeleteResult> {
275        // 1️⃣ 解析 `filter` 生成 `WHERE` 语句
276        let bindings = &mut vec![];
277        let seq = &mut 1;
278        let where_sql = cond_to_sql(&params.filter, bindings, seq)?;
279
280        // 2️⃣ 生成 SQL
281        let query = format!(
282            "DELETE FROM {}.{} {} RETURNING _id;",
283            db_name, table_name, where_sql
284        );
285
286        // 3️⃣ 执行 SQL
287        let conn = self.pool.get().await?;
288        let args_refs: Vec<&(dyn ToSql + Sync)> = bindings.iter().map(|x| x.as_ref()).collect();
289        let rows = conn.query(&query, &args_refs).await?;
290
291        // 4️⃣ 返回删除的行数
292        Ok(DeleteResult {
293            deleted_count: rows.len() as u32,
294        })
295    }
296}