// cyfs_task_manager/db_helper.rs

1use std::marker::PhantomData;
2use std::str::FromStr;
3use std::time::Duration;
4use cyfs_base::{BuckyError, BuckyErrorCode};
5use sqlx::{Transaction, Connection, Executor, ConnectOptions};
6use log::LevelFilter;
7use sqlx::pool::PoolConnection;
8use sqlx::Execute;
9pub use sqlx as cyfs_sql;
10pub use sqlx::Row as SqlRow;
11
/// Converts a driver error into a caller-chosen error type.
///
/// Implementors act as type-level markers (the `EM` parameter of the pool and
/// connection wrappers); `map` is an associated function, so no value of the
/// implementing type is ever required.
pub trait ErrorMap {
    /// Error type consumed by [`ErrorMap::map`].
    type InError;
    /// Error type produced by [`ErrorMap::map`].
    type OutError;

    /// Converts `err`, attaching `msg` as human-readable context (callers in this
    /// module pass a source line number together with the SQL text or URI).
    fn map(err: Self::InError, msg: &str) -> Self::OutError;
}
17
18pub struct DefaultToBuckyError;
19
20impl ErrorMap for DefaultToBuckyError {
21    type OutError = BuckyError;
22    type InError = sqlx::Error;
23
24    fn map(e: sqlx::Error, msg: &str) -> BuckyError {
25        match e {
26            sqlx::Error::RowNotFound => {
27                let msg = format!("not found, {}", msg);
28                BuckyError::new(BuckyErrorCode::NotFound, msg)
29            }
30            _ => {
31                let msg = format!("sqlite error: {:?} info:{}", e, msg);
32                if cfg!(test) {
33                    println!("{}", msg);
34                } else {
35                    log::error!("{}", msg);
36                }
37                BuckyError::new(BuckyErrorCode::SqliteError, msg)
38            }
39        }
40    }
41}
42
43pub type SqlResult = <sqlx::Any as sqlx::Database>::QueryResult;
44pub type SqlRowObject = <sqlx::Any as sqlx::Database>::Row;
45pub type SqlTransaction<'a> = sqlx::Transaction<'a, sqlx::Any>;
46pub type SqlQuery<'a> = sqlx::query::Query<'a, sqlx::Any, <sqlx::Any as sqlx::database::HasArguments<'a>>::Arguments>;
47pub type RawSqlPool = sqlx::AnyPool;
48
/// Forwards to `sqlx::query!` through this crate's `cyfs_sql` re-export.
///
/// The path is `$crate`-qualified so the exported macro also expands correctly in
/// code that has not brought `cyfs_sql` into scope itself.
#[macro_export]
macro_rules! sql_query {
    ($query:expr) => ({
        $crate::cyfs_sql::query!($query)
    });

    ($query:expr, $($args:tt)*) => ({
        $crate::cyfs_sql::query!($query, $($args)*)
    })
}
59
60#[derive(Clone)]
61pub struct SqlPool<EM: ErrorMap<InError = sqlx::Error> = DefaultToBuckyError> {
62    pool: sqlx::AnyPool,
63    uri: String,
64    _em: PhantomData<EM>,
65}
66
67impl<EM: ErrorMap<InError = sqlx::Error>> SqlPool<EM> {
68    pub fn from_raw_pool(pool: RawSqlPool) -> Self {
69        Self { pool, uri: "".to_string(), _em: Default::default() }
70    }
71
72    pub async fn open(uri: &str, max_connections: u32) -> Result<Self, EM::OutError> {
73        log::info!("open pool {} max_connections {}", uri, max_connections);
74        let pool_options = sqlx::any::AnyPoolOptions::new()
75            .max_connections(max_connections)
76            .connect_timeout(Duration::from_secs(300))
77            .min_connections(0)
78            .idle_timeout(Duration::from_secs(300));
79        let kind = sqlx::any::AnyKind::from_str(uri).map_err(|e| {
80            EM::map(e, format!("[{} {}]", line!(), uri).as_str())
81        })?;
82        let pool = match kind {
83            sqlx::any::AnyKind::Sqlite => {
84                let mut options = sqlx::sqlite::SqliteConnectOptions::from_str(uri).map_err(|e| {
85                    EM::map(e, format!("[{} {}]", line!(), uri).as_str())
86                })?
87                    .busy_timeout(Duration::from_secs(300))
88                    .create_if_missing(true);
89                #[cfg(target_os = "ios")]
90                {
91                    options = options.serialized(true);
92                }
93
94                options.log_statements(LevelFilter::Off)
95                    .log_slow_statements(LevelFilter::Off, Duration::from_secs(10));
96                pool_options.connect_with(sqlx::any::AnyConnectOptions::from(options)).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
97            },
98            _ => {
99                pool_options.connect(uri).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
100            }
101        };
102        Ok(Self {
103            pool,
104            uri: uri.to_string(),
105            _em: Default::default()
106        })
107    }
108
109    pub async fn raw_pool(&self) -> RawSqlPool {
110        self.pool.clone()
111    }
112
113    pub async fn get_conn(&self) -> Result<SqlConnection, EM::OutError> {
114        let conn = self.pool.acquire().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), self.uri.as_str()).as_str()))?;
115        Ok(SqlConnection::from(conn))
116    }
117}
118
119pub fn sql_query(sql: &str) -> SqlQuery<'_> {
120    sqlx::query::<sqlx::Any>(sql)
121}
122
123pub enum SqlConnectionType {
124    PoolConn(PoolConnection<sqlx::Any>),
125    Conn(sqlx::AnyConnection),
126}
127pub struct SqlConnection<EM: ErrorMap<InError = sqlx::Error> = DefaultToBuckyError> {
128    conn: SqlConnectionType,
129    trans: Option<Transaction<'static, sqlx::Any>>,
130    _em: PhantomData<EM>,
131}
132
133impl From<sqlx::pool::PoolConnection<sqlx::Any>> for SqlConnection {
134    fn from(conn: sqlx::pool::PoolConnection<sqlx::Any>) -> Self {
135        Self { conn: SqlConnectionType::PoolConn(conn), _em: Default::default(), trans: None }
136    }
137}
138
139impl<EM: 'static + ErrorMap<InError = sqlx::Error>> SqlConnection<EM> {
140    pub async fn open(uri: &str) -> Result<Self, EM::OutError> {
141        let kind = sqlx::any::AnyKind::from_str(uri).map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?;
142        let conn = match kind {
143            sqlx::any::AnyKind::Sqlite => {
144                let mut options = sqlx::sqlite::SqliteConnectOptions::from_str(uri).map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
145                    .busy_timeout(Duration::from_secs(300));
146                #[cfg(target_os = "ios")]
147                {
148                    options = options.serialized(true);
149                }
150
151                options.log_statements(LevelFilter::Off)
152                    .log_slow_statements(LevelFilter::Off, Duration::from_secs(10));
153                sqlx::any::AnyConnectOptions::from(options).connect().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
154            },
155            _ => {
156                sqlx::any::AnyConnection::connect(uri).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
157            }
158        };
159
160        Ok(Self {
161            conn: SqlConnectionType::Conn(conn),
162            _em: Default::default(),
163            trans: None
164        })
165    }
166
167    pub async fn execute_sql(&mut self, query: SqlQuery<'_>) -> Result<SqlResult, EM::OutError> {
168        let sql = query.sql();
169        log::debug!("sql {}", sql);
170        if self.trans.is_none() {
171            match &mut self.conn {
172                SqlConnectionType::PoolConn(conn) => {
173                    conn.execute(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
174                },
175                SqlConnectionType::Conn(conn) => {
176                    conn.execute(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
177                }
178            }
179        } else {
180            self.trans.as_mut().unwrap().execute(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
181        }
182    }
183
184    pub async fn query_one(&mut self, query: SqlQuery<'_>) -> Result<SqlRowObject, EM::OutError> {
185        let sql = query.sql();
186        log::debug!("sql {}", sql);
187        if self.trans.is_none() {
188            match &mut self.conn {
189                SqlConnectionType::PoolConn(conn) => {
190                    conn.fetch_one(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
191                },
192                SqlConnectionType::Conn(conn) => {
193                    conn.fetch_one(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
194                }
195            }
196        } else {
197            self.trans.as_mut().unwrap().fetch_one(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
198        }
199    }
200
201    pub async fn query_all(&mut self, query: SqlQuery<'_>) -> Result<Vec<SqlRowObject>, EM::OutError> {
202        let sql = query.sql();
203        log::debug!("sql {}", sql);
204        if self.trans.is_none() {
205            match &mut self.conn {
206                SqlConnectionType::PoolConn(conn) => {
207                    conn.fetch_all(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
208                },
209                SqlConnectionType::Conn(conn) => {
210                    conn.fetch_all(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
211                }
212            }
213        } else {
214            self.trans.as_mut().unwrap().fetch_all(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
215        }
216    }
217
218    pub async fn begin_transaction(&mut self) -> Result<(), EM::OutError> {
219        let this: &'static mut Self = unsafe {std::mem::transmute(self)};
220        let trans = match &mut this.conn {
221            SqlConnectionType::PoolConn(conn) => {
222                conn.begin().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "begin trans").as_str()))
223            },
224            SqlConnectionType::Conn(conn) => {
225                conn.begin().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "begin trans").as_str()))
226            }
227        }?;
228        this.trans = Some(trans);
229        Ok(())
230    }
231
232    pub async fn rollback_transaction(&mut self) -> Result<(), EM::OutError> {
233        if self.trans.is_none() {
234            return Ok(())
235        } else {
236            self.trans.take().unwrap().rollback().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "rollback trans").as_str()))
237        }
238    }
239
240    pub async fn commit_transaction(&mut self) -> Result<(), EM::OutError> {
241        if self.trans.is_none() {
242            return Ok(())
243        } else {
244            self.trans.take().unwrap().commit().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "commit trans").as_str()))
245        }
246    }
247}
248
249impl<EM: ErrorMap<InError=sqlx::Error>> Drop for SqlConnection<EM> {
250    fn drop(&mut self) {
251        if self.trans.is_some() {
252            let trans = self.trans.take().unwrap();
253            async_std::task::block_on(async move {
254                let _ = trans.rollback().await;
255            });
256        }
257    }
258}
259
#[cfg(test)]
mod test_connection {
    use sqlx::Row;
    use crate::*;

    /// Schema of the table the tests read and write.
    const CREATE_TABLE: &str = r#"CREATE TABLE IF NOT EXISTS desc_extra (
        "obj_id" char(45) PRIMARY KEY NOT NULL UNIQUE,
        "rent_arrears" INTEGER,
        "rent_arrears_count" INTEGER,
        "rent_value" INTEGER,
        "coin_id" INTEGER,
        "data_len" INTEGER,
        "other_charge_balance" INTEGER);"#;

    /// Inserts the single row checked by `assert_inserted_row`.
    ///
    /// The id is a single-quoted SQL string literal. In standard SQL, double quotes
    /// denote identifiers and only act as strings through SQLite's legacy fallback.
    const INSERT_ROW: &str = r#"insert into desc_extra (obj_id,
        rent_arrears,
        rent_arrears_count,
        rent_value,
        coin_id,
        data_len,
        other_charge_balance) values (
        'test', 1, 1, 2, 3, 4, 5)"#;

    /// Looks a row up by primary key; bind the id as the only parameter.
    const SELECT_BY_ID: &str = "select * from desc_extra where obj_id = ?";

    /// Checks that `row` is the one written by `INSERT_ROW`.
    fn assert_inserted_row(row: &SqlRowObject) {
        let id: String = row.get("obj_id");
        assert_eq!(id, "test");
        let coin_id: i32 = row.get("coin_id");
        assert_eq!(coin_id, 3);
    }

    #[test]
    fn test() {
        async_std::task::block_on(async {
            let mut conn = SqlConnection::<DefaultToBuckyError>::open("sqlite://:memory:").await.unwrap();
            conn.begin_transaction().await.unwrap();
            conn.execute_sql(sql_query(CREATE_TABLE)).await.unwrap();
            conn.execute_sql(sql_query(INSERT_ROW)).await.unwrap();
            conn.commit_transaction().await.unwrap();

            let row = conn.query_one(sql_query(SELECT_BY_ID).bind("test")).await.unwrap();
            assert_inserted_row(&row);

            let row = conn.query_one(sqlx::query(SELECT_BY_ID).bind("test")).await.unwrap();
            assert_inserted_row(&row);
        })
    }

    #[test]
    fn test_pool() {
        async_std::task::block_on(async {
            let pool = SqlPool::<DefaultToBuckyError>::open("sqlite://:memory:", 5).await.unwrap();

            // All three connections below stay checked out until the end of the test.

            // A rolled-back insert must not be visible from another connection.
            let mut setup_conn = pool.get_conn().await.unwrap();
            setup_conn.execute_sql(sql_query(CREATE_TABLE)).await.unwrap();
            setup_conn.begin_transaction().await.unwrap();
            setup_conn.execute_sql(sql_query(INSERT_ROW)).await.unwrap();
            setup_conn.rollback_transaction().await.unwrap();

            let mut reader = pool.get_conn().await.unwrap();
            let rows = reader.query_all(sqlx::query(SELECT_BY_ID).bind("test")).await.unwrap();
            assert_eq!(rows.len(), 0);

            // A committed insert is visible afterwards.
            let mut writer = pool.get_conn().await.unwrap();
            writer.begin_transaction().await.unwrap();
            writer.execute_sql(sql_query(INSERT_ROW)).await.unwrap();
            writer.commit_transaction().await.unwrap();

            let row = writer.query_one(sqlx::query(SELECT_BY_ID).bind("test")).await.unwrap();
            assert_inserted_row(&row);

            let row = writer.query_one(sqlx::query(SELECT_BY_ID).bind("test")).await.unwrap();
            assert_inserted_row(&row);
        })
    }
}
388}