1use std::marker::PhantomData;
2use std::str::FromStr;
3use std::time::Duration;
4use cyfs_base::{BuckyError, BuckyErrorCode};
5use sqlx::{Transaction, Connection, Executor, ConnectOptions};
6use log::LevelFilter;
7use sqlx::pool::PoolConnection;
8use sqlx::Execute;
9pub use sqlx as cyfs_sql;
10pub use sqlx::Row as SqlRow;
11
/// Strategy for converting a low-level error into the error type returned by this crate's API.
///
/// `map` is an associated function (it takes no `self`), so the conversion is selected purely
/// by type — e.g. through the `EM` parameter of `SqlPool` and `SqlConnection`.
pub trait ErrorMap {
    /// Error type accepted by [`ErrorMap::map`].
    type InError;
    /// Error type produced by [`ErrorMap::map`].
    type OutError;

    /// Converts `err` into [`ErrorMap::OutError`], attaching the context string `msg`.
    fn map(err: Self::InError, msg: &str) -> Self::OutError;
}
17
18pub struct DefaultToBuckyError;
19
20impl ErrorMap for DefaultToBuckyError {
21 type OutError = BuckyError;
22 type InError = sqlx::Error;
23
24 fn map(e: sqlx::Error, msg: &str) -> BuckyError {
25 match e {
26 sqlx::Error::RowNotFound => {
27 let msg = format!("not found, {}", msg);
28 BuckyError::new(BuckyErrorCode::NotFound, msg)
29 }
30 _ => {
31 let msg = format!("sqlite error: {:?} info:{}", e, msg);
32 if cfg!(test) {
33 println!("{}", msg);
34 } else {
35 log::error!("{}", msg);
36 }
37 BuckyError::new(BuckyErrorCode::SqliteError, msg)
38 }
39 }
40 }
41}
42
/// Outcome of a statement run through the `Any` driver (returned by `SqlConnection::execute_sql`).
pub type SqlResult = <sqlx::Any as sqlx::Database>::QueryResult;
/// A row fetched through the `Any` driver (returned by `SqlConnection::query_one` / `query_all`).
pub type SqlRowObject = <sqlx::Any as sqlx::Database>::Row;
/// A transaction on an `Any` connection, parameterized by the lifetime `'a` of the connection borrow.
pub type SqlTransaction<'a> = sqlx::Transaction<'a, sqlx::Any>;
/// An `Any`-driver query plus its bound arguments, as built by the `sql_query` function;
/// `'a` is the lifetime of the borrowed SQL text.
pub type SqlQuery<'a> = sqlx::query::Query<'a, sqlx::Any, <sqlx::Any as sqlx::database::HasArguments<'a>>::Arguments>;
/// The raw sqlx pool type wrapped by `SqlPool`.
pub type RawSqlPool = sqlx::AnyPool;
48
/// Forwards to sqlx's `query!` macro through this crate's `cyfs_sql` re-export.
///
/// The path is spelled `$crate::cyfs_sql` so the expansion resolves in any crate that invokes
/// the macro; a bare `cyfs_sql::...` only works where the caller has `cyfs_sql` in scope.
#[macro_export]
macro_rules! sql_query {
    ($query:expr) => {{
        $crate::cyfs_sql::query!($query)
    }};

    ($query:expr, $($args:tt)*) => {{
        $crate::cyfs_sql::query!($query, $($args)*)
    }};
}
59
60#[derive(Clone)]
61pub struct SqlPool<EM: ErrorMap<InError = sqlx::Error> = DefaultToBuckyError> {
62 pool: sqlx::AnyPool,
63 uri: String,
64 _em: PhantomData<EM>,
65}
66
67impl<EM: ErrorMap<InError = sqlx::Error>> SqlPool<EM> {
68 pub fn from_raw_pool(pool: RawSqlPool) -> Self {
69 Self { pool, uri: "".to_string(), _em: Default::default() }
70 }
71
72 pub async fn open(uri: &str, max_connections: u32) -> Result<Self, EM::OutError> {
73 log::info!("open pool {} max_connections {}", uri, max_connections);
74 let pool_options = sqlx::any::AnyPoolOptions::new()
75 .max_connections(max_connections)
76 .connect_timeout(Duration::from_secs(300))
77 .min_connections(0)
78 .idle_timeout(Duration::from_secs(300));
79 let kind = sqlx::any::AnyKind::from_str(uri).map_err(|e| {
80 EM::map(e, format!("[{} {}]", line!(), uri).as_str())
81 })?;
82 let pool = match kind {
83 sqlx::any::AnyKind::Sqlite => {
84 let mut options = sqlx::sqlite::SqliteConnectOptions::from_str(uri).map_err(|e| {
85 EM::map(e, format!("[{} {}]", line!(), uri).as_str())
86 })?
87 .busy_timeout(Duration::from_secs(300))
88 .create_if_missing(true);
89 #[cfg(target_os = "ios")]
90 {
91 options = options.serialized(true);
92 }
93
94 options.log_statements(LevelFilter::Off)
95 .log_slow_statements(LevelFilter::Off, Duration::from_secs(10));
96 pool_options.connect_with(sqlx::any::AnyConnectOptions::from(options)).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
97 },
98 _ => {
99 pool_options.connect(uri).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
100 }
101 };
102 Ok(Self {
103 pool,
104 uri: uri.to_string(),
105 _em: Default::default()
106 })
107 }
108
109 pub async fn raw_pool(&self) -> RawSqlPool {
110 self.pool.clone()
111 }
112
113 pub async fn get_conn(&self) -> Result<SqlConnection, EM::OutError> {
114 let conn = self.pool.acquire().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), self.uri.as_str()).as_str()))?;
115 Ok(SqlConnection::from(conn))
116 }
117}
118
119pub fn sql_query(sql: &str) -> SqlQuery<'_> {
120 sqlx::query::<sqlx::Any>(sql)
121}
122
/// The underlying connection held by a `SqlConnection`.
pub enum SqlConnectionType {
    /// A connection checked out of a pool (built by `SqlConnection::from`, which
    /// `SqlPool::get_conn` uses).
    PoolConn(PoolConnection<sqlx::Any>),
    /// A stand-alone connection opened by `SqlConnection::open`.
    Conn(sqlx::AnyConnection),
}
127pub struct SqlConnection<EM: ErrorMap<InError = sqlx::Error> = DefaultToBuckyError> {
128 conn: SqlConnectionType,
129 trans: Option<Transaction<'static, sqlx::Any>>,
130 _em: PhantomData<EM>,
131}
132
133impl From<sqlx::pool::PoolConnection<sqlx::Any>> for SqlConnection {
134 fn from(conn: sqlx::pool::PoolConnection<sqlx::Any>) -> Self {
135 Self { conn: SqlConnectionType::PoolConn(conn), _em: Default::default(), trans: None }
136 }
137}
138
139impl<EM: 'static + ErrorMap<InError = sqlx::Error>> SqlConnection<EM> {
140 pub async fn open(uri: &str) -> Result<Self, EM::OutError> {
141 let kind = sqlx::any::AnyKind::from_str(uri).map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?;
142 let conn = match kind {
143 sqlx::any::AnyKind::Sqlite => {
144 let mut options = sqlx::sqlite::SqliteConnectOptions::from_str(uri).map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
145 .busy_timeout(Duration::from_secs(300));
146 #[cfg(target_os = "ios")]
147 {
148 options = options.serialized(true);
149 }
150
151 options.log_statements(LevelFilter::Off)
152 .log_slow_statements(LevelFilter::Off, Duration::from_secs(10));
153 sqlx::any::AnyConnectOptions::from(options).connect().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
154 },
155 _ => {
156 sqlx::any::AnyConnection::connect(uri).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), uri).as_str()))?
157 }
158 };
159
160 Ok(Self {
161 conn: SqlConnectionType::Conn(conn),
162 _em: Default::default(),
163 trans: None
164 })
165 }
166
167 pub async fn execute_sql(&mut self, query: SqlQuery<'_>) -> Result<SqlResult, EM::OutError> {
168 let sql = query.sql();
169 log::debug!("sql {}", sql);
170 if self.trans.is_none() {
171 match &mut self.conn {
172 SqlConnectionType::PoolConn(conn) => {
173 conn.execute(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
174 },
175 SqlConnectionType::Conn(conn) => {
176 conn.execute(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
177 }
178 }
179 } else {
180 self.trans.as_mut().unwrap().execute(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
181 }
182 }
183
184 pub async fn query_one(&mut self, query: SqlQuery<'_>) -> Result<SqlRowObject, EM::OutError> {
185 let sql = query.sql();
186 log::debug!("sql {}", sql);
187 if self.trans.is_none() {
188 match &mut self.conn {
189 SqlConnectionType::PoolConn(conn) => {
190 conn.fetch_one(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
191 },
192 SqlConnectionType::Conn(conn) => {
193 conn.fetch_one(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
194 }
195 }
196 } else {
197 self.trans.as_mut().unwrap().fetch_one(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
198 }
199 }
200
201 pub async fn query_all(&mut self, query: SqlQuery<'_>) -> Result<Vec<SqlRowObject>, EM::OutError> {
202 let sql = query.sql();
203 log::debug!("sql {}", sql);
204 if self.trans.is_none() {
205 match &mut self.conn {
206 SqlConnectionType::PoolConn(conn) => {
207 conn.fetch_all(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
208 },
209 SqlConnectionType::Conn(conn) => {
210 conn.fetch_all(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
211 }
212 }
213 } else {
214 self.trans.as_mut().unwrap().fetch_all(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
215 }
216 }
217
218 pub async fn begin_transaction(&mut self) -> Result<(), EM::OutError> {
219 let this: &'static mut Self = unsafe {std::mem::transmute(self)};
220 let trans = match &mut this.conn {
221 SqlConnectionType::PoolConn(conn) => {
222 conn.begin().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "begin trans").as_str()))
223 },
224 SqlConnectionType::Conn(conn) => {
225 conn.begin().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "begin trans").as_str()))
226 }
227 }?;
228 this.trans = Some(trans);
229 Ok(())
230 }
231
232 pub async fn rollback_transaction(&mut self) -> Result<(), EM::OutError> {
233 if self.trans.is_none() {
234 return Ok(())
235 } else {
236 self.trans.take().unwrap().rollback().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "rollback trans").as_str()))
237 }
238 }
239
240 pub async fn commit_transaction(&mut self) -> Result<(), EM::OutError> {
241 if self.trans.is_none() {
242 return Ok(())
243 } else {
244 self.trans.take().unwrap().commit().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "commit trans").as_str()))
245 }
246 }
247}
248
249impl<EM: ErrorMap<InError=sqlx::Error>> Drop for SqlConnection<EM> {
250 fn drop(&mut self) {
251 if self.trans.is_some() {
252 let trans = self.trans.take().unwrap();
253 async_std::task::block_on(async move {
254 let _ = trans.rollback().await;
255 });
256 }
257 }
258}
259
#[cfg(test)]
mod test_connection {
    use sqlx::Row;
    use crate::*;

    // Schema shared by both tests (double quotes here are identifiers, which is correct SQL).
    const CREATE_TABLE: &str = r#"CREATE TABLE IF NOT EXISTS desc_extra (
        "obj_id" char(45) PRIMARY KEY NOT NULL UNIQUE,
        "rent_arrears" INTEGER,
        "rent_arrears_count" INTEGER,
        "rent_value" INTEGER,
        "coin_id" INTEGER,
        "data_len" INTEGER,
        "other_charge_balance" INTEGER);"#;

    // The string value is single-quoted. A double-quoted "test" is an identifier in SQL and is
    // only accepted as a string through SQLite's legacy double-quoted-string fallback, which
    // builds compiled with SQLITE_DQS=0 reject.
    const INSERT_TEST_ROW: &str = r#"insert into desc_extra (obj_id,
        rent_arrears,
        rent_arrears_count,
        rent_value,
        coin_id,
        data_len,
        other_charge_balance) values (
        'test', 1, 1, 2, 3, 4, 5)"#;

    const SELECT_BY_ID: &str = "select * from desc_extra where obj_id = ?";

    /// Asserts that `row` is the row inserted by `INSERT_TEST_ROW`.
    fn assert_test_row(row: &SqlRowObject) {
        let id: String = row.get("obj_id");
        assert_eq!(id, "test".to_owned());
        let coin_id: i32 = row.get("coin_id");
        assert_eq!(coin_id, 3);
    }

    #[test]
    fn test() {
        async_std::task::block_on(async {
            let mut sqlx_conn = SqlConnection::<DefaultToBuckyError>::open("sqlite://:memory:")
                .await
                .unwrap();
            sqlx_conn.begin_transaction().await.unwrap();
            sqlx_conn.execute_sql(sql_query(CREATE_TABLE)).await.unwrap();
            sqlx_conn.execute_sql(sql_query(INSERT_TEST_ROW)).await.unwrap();
            sqlx_conn.commit_transaction().await.unwrap();

            // Same lookup built through the crate helper and through sqlx directly.
            let row = sqlx_conn
                .query_one(sql_query(SELECT_BY_ID).bind("test"))
                .await
                .unwrap();
            assert_test_row(&row);

            let row = sqlx_conn
                .query_one(sqlx::query(SELECT_BY_ID).bind("test"))
                .await
                .unwrap();
            assert_test_row(&row);
        })
    }

    #[test]
    fn test_pool() {
        async_std::task::block_on(async {
            let pool = SqlPool::<DefaultToBuckyError>::open("sqlite://:memory:", 5)
                .await
                .unwrap();

            // None of the connections below is dropped before the end of the test.
            let mut setup_conn = pool.get_conn().await.unwrap();
            setup_conn.execute_sql(sql_query(CREATE_TABLE)).await.unwrap();

            // A rolled-back insert must not be visible afterwards.
            setup_conn.begin_transaction().await.unwrap();
            setup_conn.execute_sql(sql_query(INSERT_TEST_ROW)).await.unwrap();
            setup_conn.rollback_transaction().await.unwrap();

            let mut check_conn = pool.get_conn().await.unwrap();
            let rows = check_conn
                .query_all(sqlx::query(SELECT_BY_ID).bind("test"))
                .await
                .unwrap();
            assert_eq!(rows.len(), 0);

            // A committed insert is visible.
            let mut commit_conn = pool.get_conn().await.unwrap();
            commit_conn.begin_transaction().await.unwrap();
            commit_conn.execute_sql(sql_query(INSERT_TEST_ROW)).await.unwrap();
            commit_conn.commit_transaction().await.unwrap();

            let row = commit_conn
                .query_one(sqlx::query(SELECT_BY_ID).bind("test"))
                .await
                .unwrap();
            assert_test_row(&row);

            let row = commit_conn
                .query_one(sqlx::query(SELECT_BY_ID).bind("test"))
                .await
                .unwrap();
            assert_test_row(&row);
        })
    }
}