async_diesel/
lib.rs

1use async_std::task;
2use async_trait::async_trait;
3use diesel::{
4    connection::SimpleConnection,
5    dsl::Limit,
6    query_dsl::{
7        methods::{ExecuteDsl, LimitDsl, LoadQuery},
8        RunQueryDsl,
9    },
10    r2d2::{ConnectionManager, Pool},
11    result::QueryResult,
12    Connection,
13};
14use std::{error::Error as StdError, fmt};
15
16pub type AsyncResult<R> = Result<R, AsyncError>;
17
18#[derive(Debug)]
19pub enum AsyncError {
20    // Failed to checkout a connection
21    Checkout(r2d2::Error),
22
23    // The query failed in some way
24    Error(diesel::result::Error),
25}
26
27pub trait OptionalExtension<T> {
28    fn optional(self) -> Result<Option<T>, AsyncError>;
29}
30
31impl<T> OptionalExtension<T> for AsyncResult<T> {
32    fn optional(self) -> Result<Option<T>, AsyncError> {
33        match self {
34            Ok(value) => Ok(Some(value)),
35            Err(AsyncError::Error(diesel::result::Error::NotFound)) => Ok(None),
36            Err(e) => Err(e),
37        }
38    }
39}
40
41impl fmt::Display for AsyncError {
42    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
43        match *self {
44            AsyncError::Checkout(ref err) => err.fmt(f),
45            AsyncError::Error(ref err) => err.fmt(f),
46        }
47    }
48}
49
50impl StdError for AsyncError {
51    fn source(&self) -> Option<&(dyn StdError + 'static)> {
52        match *self {
53            AsyncError::Checkout(ref err) => Some(err),
54            AsyncError::Error(ref err) => Some(err),
55        }
56    }
57}
58
59#[async_trait]
60pub trait AsyncSimpleConnection<Conn>
61where
62    Conn: 'static + SimpleConnection,
63{
64    async fn batch_execute_async(&self, query: &str) -> AsyncResult<()>;
65}
66
67#[async_trait]
68impl<Conn> AsyncSimpleConnection<Conn> for Pool<ConnectionManager<Conn>>
69where
70    Conn: 'static + Connection,
71{
72    #[inline]
73    async fn batch_execute_async(&self, query: &str) -> AsyncResult<()> {
74        let self_ = self.clone();
75        let query = query.to_string();
76        task::spawn_blocking(move || {
77            let conn = self_.get().map_err(AsyncError::Checkout)?;
78            conn.batch_execute(&query).map_err(AsyncError::Error)
79        })
80        .await
81    }
82}
83
84#[async_trait]
85pub trait AsyncConnection<Conn>: AsyncSimpleConnection<Conn>
86where
87    Conn: 'static + Connection,
88{
89    async fn run<R, Func>(&self, f: Func) -> AsyncResult<R>
90    where
91        R: 'static + Send,
92        Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send;
93
94    async fn transaction<R, Func>(&self, f: Func) -> AsyncResult<R>
95    where
96        R: 'static + Send,
97        Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send;
98}
99
100#[async_trait]
101impl<Conn> AsyncConnection<Conn> for Pool<ConnectionManager<Conn>>
102where
103    Conn: 'static + Connection,
104{
105    #[inline]
106    async fn run<R, Func>(&self, f: Func) -> AsyncResult<R>
107    where
108        R: 'static + Send,
109        Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send,
110    {
111        let self_ = self.clone();
112        task::spawn_blocking(move || {
113            let conn = self_.get().map_err(AsyncError::Checkout)?;
114            f(&*conn).map_err(AsyncError::Error)
115        })
116        .await
117    }
118
119    #[inline]
120    async fn transaction<R, Func>(&self, f: Func) -> AsyncResult<R>
121    where
122        R: 'static + Send,
123        Func: 'static + FnOnce(&Conn) -> QueryResult<R> + Send,
124    {
125        let self_ = self.clone();
126        task::spawn_blocking(move || {
127            let conn = self_.get().map_err(AsyncError::Checkout)?;
128            conn.transaction(|| f(&*conn)).map_err(AsyncError::Error)
129        })
130        .await
131    }
132}
133
134#[async_trait]
135pub trait AsyncRunQueryDsl<Conn, AsyncConn>
136where
137    Conn: 'static + Connection,
138{
139    async fn execute_async(self, asc: &AsyncConn) -> AsyncResult<usize>
140    where
141        Self: ExecuteDsl<Conn>;
142
143    async fn load_async<U>(self, asc: &AsyncConn) -> AsyncResult<Vec<U>>
144    where
145        U: 'static + Send,
146        Self: LoadQuery<Conn, U>;
147
148    async fn get_result_async<U>(self, asc: &AsyncConn) -> AsyncResult<U>
149    where
150        U: 'static + Send,
151        Self: LoadQuery<Conn, U>;
152
153    async fn get_results_async<U>(self, asc: &AsyncConn) -> AsyncResult<Vec<U>>
154    where
155        U: 'static + Send,
156        Self: LoadQuery<Conn, U>;
157
158    async fn first_async<U>(self, asc: &AsyncConn) -> AsyncResult<U>
159    where
160        U: 'static + Send,
161        Self: LimitDsl,
162        Limit<Self>: LoadQuery<Conn, U>;
163}
164
165#[async_trait]
166impl<T, Conn> AsyncRunQueryDsl<Conn, Pool<ConnectionManager<Conn>>> for T
167where
168    T: 'static + Send + RunQueryDsl<Conn>,
169    Conn: 'static + Connection,
170{
171    async fn execute_async(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<usize>
172    where
173        Self: ExecuteDsl<Conn>,
174    {
175        asc.run(|conn| self.execute(&*conn)).await
176    }
177
178    async fn load_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<Vec<U>>
179    where
180        U: 'static + Send,
181        Self: LoadQuery<Conn, U>,
182    {
183        asc.run(|conn| self.load(&*conn)).await
184    }
185
186    async fn get_result_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<U>
187    where
188        U: 'static + Send,
189        Self: LoadQuery<Conn, U>,
190    {
191        asc.run(|conn| self.get_result(&*conn)).await
192    }
193
194    async fn get_results_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<Vec<U>>
195    where
196        U: 'static + Send,
197        Self: LoadQuery<Conn, U>,
198    {
199        asc.run(|conn| self.get_results(&*conn)).await
200    }
201
202    async fn first_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<U>
203    where
204        U: 'static + Send,
205        Self: LimitDsl,
206        Limit<Self>: LoadQuery<Conn, U>,
207    {
208        asc.run(|conn| self.first(&*conn)).await
209    }
210}