// rquery_orm/infrastructure/generic_repository.rs

1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::db::{DatabaseRef, DbKind};
7use crate::mapping::{Entity, FromRowNamed, Persistable, Validatable};
8use crate::query::{DualQuery, Expr, PlaceholderStyle, Query, SqlParam};
9use crate::repository::{Crud, QueryExecutor, Repository};
10use anyhow::{anyhow, Result};
11
12pub struct GenericRepository<T> {
13    db: Arc<DatabaseRef>,
14    _t: PhantomData<T>,
15}
16
17impl<T> GenericRepository<T> {
18    pub fn new(db: DatabaseRef) -> Self {
19        Self {
20            db: Arc::new(db),
21            _t: PhantomData,
22        }
23    }
24}
25
26impl<T> Clone for GenericRepository<T> {
27    fn clone(&self) -> Self {
28        Self {
29            db: self.db.clone(),
30            _t: PhantomData,
31        }
32    }
33}
34
35#[async_trait]
36impl<T> QueryExecutor<T> for GenericRepository<T>
37where
38    T: Entity + FromRowNamed + Validatable + Persistable + Send + Sync,
39{
40    fn Select(&self) -> Query<T> {
41        let style = match self.db.as_ref().kind() {
42            DbKind::Mssql => PlaceholderStyle::AtP,
43            DbKind::Postgres => PlaceholderStyle::Dollar,
44        };
45        Query::new(T::table().name, style).with_db(self.db.clone())
46    }
47
48    async fn get_by_key_async(&self, key: SqlParam) -> Result<Option<T>> {
49        let table = T::table();
50        let pk = table
51            .keys
52            .first()
53            .ok_or_else(|| anyhow!("no primary key metadata"))?;
54        let expr = Expr::Col(format!("{}.{}", table.name, pk.column)).eq(Expr::Param(key));
55        self.Select().Where(expr).to_single_async().await
56    }
57}
58
59impl<T, U> GenericRepository<(T, U)>
60where
61    T: Entity + crate::mapping::FromRowWithPrefix + Send + Sync,
62    U: Entity + crate::mapping::FromRowWithPrefix + Send + Sync,
63{
64    pub fn Select(&self) -> DualQuery<T, U> {
65        let style = match self.db.as_ref().kind() {
66            DbKind::Mssql => PlaceholderStyle::AtP,
67            DbKind::Postgres => PlaceholderStyle::Dollar,
68        };
69        DualQuery::<T, U>::new(style).with_db(self.db.clone())
70    }
71}
72
73#[async_trait]
74impl<T> Crud<T> for GenericRepository<T>
75where
76    T: Entity + FromRowNamed + Validatable + Persistable + Send + Sync,
77{
78    async fn insert_async(&self, entity: &T) -> Result<()> {
79        entity.validate().map_err(|e| anyhow!(e.join(", ")))?;
80        let style = match self.db.as_ref().kind() {
81            DbKind::Mssql => PlaceholderStyle::AtP,
82            DbKind::Postgres => PlaceholderStyle::Dollar,
83        };
84        let (sql, params, _has_identity) = entity.build_insert(style);
85        execute(&self.db, &sql, &params).await.map(|_| ())
86    }
87
88    async fn update_async(&self, entity: &T) -> Result<()> {
89        entity.validate().map_err(|e| anyhow!(e.join(", ")))?;
90        let style = match self.db.as_ref().kind() {
91            DbKind::Mssql => PlaceholderStyle::AtP,
92            DbKind::Postgres => PlaceholderStyle::Dollar,
93        };
94        let (sql, params) = entity.build_update(style);
95        execute(&self.db, &sql, &params).await.map(|_| ())
96    }
97
98    async fn delete_by_entity_async(&self, entity: &T) -> Result<()> {
99        let style = match self.db.as_ref().kind() {
100            DbKind::Mssql => PlaceholderStyle::AtP,
101            DbKind::Postgres => PlaceholderStyle::Dollar,
102        };
103        let (sql, params) = entity.build_delete(style);
104        execute(&self.db, &sql, &params).await.map(|_| ())
105    }
106
107    async fn delete_by_key_async(&self, key: SqlParam) -> Result<()> {
108        let style = match self.db.as_ref().kind() {
109            DbKind::Mssql => PlaceholderStyle::AtP,
110            DbKind::Postgres => PlaceholderStyle::Dollar,
111        };
112        let (sql, params) = T::build_delete_by_key(key, style);
113        execute(&self.db, &sql, &params).await.map(|_| ())
114    }
115}
116
117impl<T> Repository<T> for GenericRepository<T> where
118    T: Entity + FromRowNamed + Validatable + Persistable + Send + Sync
119{
120}
121
122async fn execute(db: &Arc<DatabaseRef>, sql: &str, params: &[SqlParam]) -> Result<u64> {
123    match db.as_ref() {
124        DatabaseRef::Mssql(conn) => {
125            let mut guard = conn.lock().await;
126            let mut boxed: Vec<Box<dyn tiberius::ToSql + Send + Sync>> = Vec::new();
127            for p in params {
128                let b: Box<dyn tiberius::ToSql + Send + Sync> = match p {
129                    SqlParam::I32(v) => Box::new(*v),
130                    SqlParam::I64(v) => Box::new(*v),
131                    SqlParam::Bool(v) => Box::new(*v),
132                    SqlParam::Text(v) => Box::new(v.clone()),
133                    SqlParam::Uuid(v) => Box::new(*v),
134                    SqlParam::Decimal(v) => Box::new(v.to_string()),
135                    SqlParam::DateTime(v) => Box::new(*v),
136                    SqlParam::Bytes(v) => Box::new(v.clone()),
137                    SqlParam::Null => Box::new(Option::<i32>::None),
138                };
139                boxed.push(b);
140            }
141            let refs: Vec<&dyn tiberius::ToSql> =
142                boxed.iter().map(|b| &**b as &dyn tiberius::ToSql).collect();
143            let res = guard.execute(sql, &refs[..]).await?;
144            Ok(res.total())
145        }
146        DatabaseRef::Postgres(pg) => {
147            let mut boxed: Vec<Box<dyn tokio_postgres::types::ToSql + Send + Sync>> = Vec::new();
148            for p in params {
149                let b: Box<dyn tokio_postgres::types::ToSql + Send + Sync> = match p {
150                    SqlParam::I32(v) => Box::new(*v),
151                    SqlParam::I64(v) => Box::new(*v),
152                    SqlParam::Bool(v) => Box::new(*v),
153                    SqlParam::Text(v) => Box::new(v.clone()),
154                    SqlParam::Uuid(v) => Box::new(*v),
155                    SqlParam::Decimal(v) => Box::new(v.to_string()),
156                    SqlParam::DateTime(v) => Box::new(*v),
157                    SqlParam::Bytes(v) => Box::new(v.clone()),
158                    SqlParam::Null => Box::new(Option::<i32>::None),
159                };
160                boxed.push(b);
161            }
162            let refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
163                boxed.iter().map(|b| &**b as _).collect();
164            let res = pg.execute(sql, &refs[..]).await?;
165            Ok(res)
166        }
167    }
168}