1use async_std::task;
2use async_trait::async_trait;
3use diesel::{
4 connection::SimpleConnection,
5 dsl::Limit,
6 query_dsl::{
7 methods::{ExecuteDsl, LimitDsl, LoadQuery},
8 RunQueryDsl,
9 },
10 r2d2::{ConnectionManager, Pool},
11 result::QueryResult,
12 Connection,
13};
14use std::{error::Error as StdError, fmt};
15
16pub type AsyncResult<R> = Result<R, AsyncError>;
17
18#[derive(Debug)]
19pub enum AsyncError {
20 Checkout(r2d2::Error),
22
23 Error(diesel::result::Error),
25}
26
27pub trait OptionalExtension<T> {
28 fn optional(self) -> Result<Option<T>, AsyncError>;
29}
30
31impl<T> OptionalExtension<T> for AsyncResult<T> {
32 fn optional(self) -> Result<Option<T>, AsyncError> {
33 match self {
34 Ok(value) => Ok(Some(value)),
35 Err(AsyncError::Error(diesel::result::Error::NotFound)) => Ok(None),
36 Err(e) => Err(e),
37 }
38 }
39}
40
41impl fmt::Display for AsyncError {
42 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
43 match *self {
44 AsyncError::Checkout(ref err) => err.fmt(f),
45 AsyncError::Error(ref err) => err.fmt(f),
46 }
47 }
48}
49
50impl StdError for AsyncError {
51 fn source(&self) -> Option<&(dyn StdError + 'static)> {
52 match *self {
53 AsyncError::Checkout(ref err) => Some(err),
54 AsyncError::Error(ref err) => Some(err),
55 }
56 }
57}
58
59#[async_trait]
60pub trait AsyncSimpleConnection<Conn>
61where
62 Conn: 'static + SimpleConnection,
63{
64 async fn batch_execute_async(&self, query: &str) -> AsyncResult<()>;
65}
66
67#[async_trait]
68impl<Conn> AsyncSimpleConnection<Conn> for Pool<ConnectionManager<Conn>>
69where
70 Conn: 'static + Connection,
71{
72 #[inline]
73 async fn batch_execute_async(&self, query: &str) -> AsyncResult<()> {
74 let self_ = self.clone();
75 let query = query.to_string();
76 task::spawn_blocking(move || {
77 let conn = self_.get().map_err(AsyncError::Checkout)?;
78 conn.batch_execute(&query).map_err(AsyncError::Error)
79 })
80 .await
81 }
82}
83
84#[async_trait]
85pub trait AsyncConnection<Conn>: AsyncSimpleConnection<Conn>
86where
87 Conn: 'static + Connection,
88{
89 async fn run<R, Func>(&self, f: Func) -> AsyncResult<R>
90 where
91 R: 'static + Send,
92 Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send;
93
94 async fn transaction<R, Func>(&self, f: Func) -> AsyncResult<R>
95 where
96 R: 'static + Send,
97 Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send;
98}
99
100#[async_trait]
101impl<Conn> AsyncConnection<Conn> for Pool<ConnectionManager<Conn>>
102where
103 Conn: 'static + Connection,
104{
105 #[inline]
106 async fn run<R, Func>(&self, f: Func) -> AsyncResult<R>
107 where
108 R: 'static + Send,
109 Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send,
110 {
111 let self_ = self.clone();
112 task::spawn_blocking(move || {
113 let conn = self_.get().map_err(AsyncError::Checkout)?;
114 f(&*conn).map_err(AsyncError::Error)
115 })
116 .await
117 }
118
119 #[inline]
120 async fn transaction<R, Func>(&self, f: Func) -> AsyncResult<R>
121 where
122 R: 'static + Send,
123 Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send,
124 {
125 let self_ = self.clone();
126 task::spawn_blocking(move || {
127 let conn = self_.get().map_err(AsyncError::Checkout)?;
128 conn.transaction(|| f(&*conn)).map_err(AsyncError::Error)
129 })
130 .await
131 }
132}
133
134#[async_trait]
135pub trait AsyncRunQueryDsl<Conn, AsyncConn>
136where
137 Conn: 'static + Connection,
138{
139 async fn execute_async(self, asc: &AsyncConn) -> AsyncResult<usize>
140 where
141 Self: ExecuteDsl<Conn>;
142
143 async fn load_async<U>(self, asc: &AsyncConn) -> AsyncResult<Vec<U>>
144 where
145 U: 'static + Send,
146 Self: LoadQuery<Conn, U>;
147
148 async fn get_result_async<U>(self, asc: &AsyncConn) -> AsyncResult<U>
149 where
150 U: 'static + Send,
151 Self: LoadQuery<Conn, U>;
152
153 async fn get_results_async<U>(self, asc: &AsyncConn) -> AsyncResult<Vec<U>>
154 where
155 U: 'static + Send,
156 Self: LoadQuery<Conn, U>;
157
158 async fn first_async<U>(self, asc: &AsyncConn) -> AsyncResult<U>
159 where
160 U: 'static + Send,
161 Self: LimitDsl,
162 Limit<Self>: LoadQuery<Conn, U>;
163}
164
165#[async_trait]
166impl<T, Conn> AsyncRunQueryDsl<Conn, Pool<ConnectionManager<Conn>>> for T
167where
168 T: 'static + Send + RunQueryDsl<Conn>,
169 Conn: 'static + Connection,
170{
171 async fn execute_async(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<usize>
172 where
173 Self: ExecuteDsl<Conn>,
174 {
175 asc.run(|conn| self.execute(&*conn)).await
176 }
177
178 async fn load_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<Vec<U>>
179 where
180 U: 'static + Send,
181 Self: LoadQuery<Conn, U>,
182 {
183 asc.run(|conn| self.load(&*conn)).await
184 }
185
186 async fn get_result_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<U>
187 where
188 U: 'static + Send,
189 Self: LoadQuery<Conn, U>,
190 {
191 asc.run(|conn| self.get_result(&*conn)).await
192 }
193
194 async fn get_results_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<Vec<U>>
195 where
196 U: 'static + Send,
197 Self: LoadQuery<Conn, U>,
198 {
199 asc.run(|conn| self.get_results(&*conn)).await
200 }
201
202 async fn first_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<U>
203 where
204 U: 'static + Send,
205 Self: LimitDsl,
206 Limit<Self>: LoadQuery<Conn, U>,
207 {
208 asc.run(|conn| self.first(&*conn)).await
209 }
210}