elefant_tools/
postgres_client_wrapper.rs

1use crate::Result;
2use bytes::Buf;
3use std::fmt::Display;
4use std::ops::Deref;
5use tokio::task::JoinHandle;
6use tokio_postgres::row::RowIndex;
7use tokio_postgres::types::FromSqlOwned;
8use tokio_postgres::{Client, CopyInSink, CopyOutStream, NoTls, Row};
9use tracing::instrument;
10
11/// A wrapper around tokio_postgres::Client, which provides a more convenient interface for working with the client.
12pub struct PostgresClientWrapper {
13    /// The actual client
14    client: PostgresClient,
15    /// The version of the postgres server, reduced by 1000. For example, version 15.0 is represented as 150.
16    version: i32,
17    /// The connection string used to connect to the server
18    connection_string: String,
19}
20
21impl PostgresClientWrapper {
22    /// Create a new PostgresClientWrapper.
23    ///
24    /// This will connect to the postgres server to figure out the version of the server.
25    /// If the version is less than 12, an error is returned.
26    #[instrument(skip_all)]
27    pub async fn new(connection_string: &str) -> Result<Self> {
28        let client = PostgresClient::new(connection_string).await?;
29
30        let version = match &client
31            .client
32            .simple_query("SHOW server_version_num;")
33            .await?[0]
34        {
35            tokio_postgres::SimpleQueryMessage::Row(row) => {
36                let version: i32 = row
37                    .get(0)
38                    .expect("failed to get version from row")
39                    .parse()
40                    .expect("failed to parse version");
41                if version < 120000 {
42                    return Err(crate::ElefantToolsError::UnsupportedPostgresVersion(
43                        version,
44                    ));
45                }
46                version / 1000
47            }
48            _ => return Err(crate::ElefantToolsError::InvalidPostgresVersionResponse),
49        };
50
51        Ok(PostgresClientWrapper {
52            client,
53            version,
54            connection_string: connection_string.to_string(),
55        })
56    }
57
58    /// Get the version of the postgres server
59    pub fn version(&self) -> i32 {
60        self.version
61    }
62
63    /// Create another connection to the same server
64    pub async fn create_another_connection(&self) -> Result<Self> {
65        let client = PostgresClient::new(&self.connection_string).await?;
66        Ok(PostgresClientWrapper {
67            client,
68            version: self.version,
69            connection_string: self.connection_string.clone(),
70        })
71    }
72
73    #[cfg(test)]
74    pub(crate) fn underlying_connection(&self) -> &Client {
75        &self.client.client
76    }
77}
78
79impl Deref for PostgresClientWrapper {
80    type Target = PostgresClient;
81
82    fn deref(&self) -> &Self::Target {
83        &self.client
84    }
85}
86
87/// A wrapper around tokio_postgres::Client, which provides a more convenient interface for working with the client.
88pub struct PostgresClient {
89    client: Client,
90    join_handle: JoinHandle<Result<()>>,
91}
92
93impl PostgresClient {
94    /// Create a new PostgresClient.
95    ///
96    /// This will establish a connection to the postgres server.
97    pub async fn new(connection_string: &str) -> Result<Self> {
98        let (client, connection) = tokio_postgres::connect(connection_string, NoTls).await?;
99
100        // The connection object performs the actual communication with the database,
101        // so spawn it off to run on its own.
102        let join_handle = tokio::spawn(async move {
103            match connection.await {
104                Err(e) => Err(crate::ElefantToolsError::PostgresError(e)),
105                Ok(_) => Ok(()),
106            }
107        });
108
109        Ok(PostgresClient {
110            client,
111            join_handle,
112        })
113    }
114
115    /// Execute a query that does not return any results.
116    pub async fn execute_non_query(&self, sql: &str) -> Result {
117        self.client.batch_execute(sql).await.map_err(|e| {
118            crate::ElefantToolsError::PostgresErrorWithQuery {
119                source: e,
120                query: sql.to_string(),
121            }
122        })?;
123
124        Ok(())
125    }
126
127    /// Execute a query that returns results.
128    pub async fn get_results<T: FromRow>(&self, sql: &str) -> Result<Vec<T>> {
129        let query_results = self.client.query(sql, &[]).await.map_err(|e| {
130            crate::ElefantToolsError::PostgresErrorWithQuery {
131                source: e,
132                query: sql.to_string(),
133            }
134        })?;
135
136        let mut output = Vec::with_capacity(query_results.len());
137
138        for row in query_results.into_iter() {
139            output.push(T::from_row(row)?);
140        }
141
142        Ok(output)
143    }
144
145    /// Execute a query that returns a single result.
146    pub async fn get_result<T: FromRow>(&self, sql: &str) -> Result<T> {
147        let results = self.get_results(sql).await?;
148        if results.len() != 1 {
149            return Err(crate::ElefantToolsError::InvalidNumberOfResults {
150                actual: results.len(),
151                expected: 1,
152            });
153        }
154
155        // Safe, we have just checked the length of the vector
156        let r = results.into_iter().next().unwrap();
157
158        Ok(r)
159    }
160
161    /// Execute a query that returns a single column of results.
162    pub async fn get_single_results<T: FromSqlOwned>(&self, sql: &str) -> Result<Vec<T>> {
163        let r = self
164            .get_results::<(T,)>(sql)
165            .await?
166            .into_iter()
167            .map(|t| t.0)
168            .collect();
169
170        Ok(r)
171    }
172
173    /// Execute a query that returns a single column of a single row of results.
174    pub async fn get_single_result<T: FromSqlOwned>(&self, sql: &str) -> Result<T> {
175        let result = self.get_result::<(T,)>(sql).await?;
176        Ok(result.0)
177    }
178
179    /// Starts a COPY IN operation.
180    pub async fn copy_in<U>(&self, sql: &str) -> Result<CopyInSink<U>>
181    where
182        U: Buf + Send + 'static,
183    {
184        let sink = self.client.copy_in(sql).await?;
185        Ok(sink)
186    }
187
188    /// Starts a COPY OUT operation.
189    pub async fn copy_out(&self, sql: &str) -> Result<CopyOutStream> {
190        let stream = self.client.copy_out(sql).await?;
191        Ok(stream)
192    }
193}
194
195impl Drop for PostgresClient {
196    fn drop(&mut self) {
197        self.join_handle.abort();
198    }
199}
200
201/// Provides a more convenient way of reading an
202/// entire row from a tokio_postgres::Row into a type.
203pub trait FromRow: Sized {
204    fn from_row(row: Row) -> Result<Self>;
205}
206
207impl<T1: FromSqlOwned> FromRow for (T1,) {
208    fn from_row(row: Row) -> Result<Self> {
209        Ok((row.try_get(0)?,))
210    }
211}
212
213impl<T1: FromSqlOwned, T2: FromSqlOwned> FromRow for (T1, T2) {
214    fn from_row(row: Row) -> Result<Self> {
215        Ok((row.try_get(0)?, row.try_get(1)?))
216    }
217}
218
219impl<T1: FromSqlOwned, T2: FromSqlOwned, T3: FromSqlOwned> FromRow for (T1, T2, T3) {
220    fn from_row(row: Row) -> Result<Self> {
221        Ok((row.try_get(0)?, row.try_get(1)?, row.try_get(2)?))
222    }
223}
224
225impl<T1: FromSqlOwned, T2: FromSqlOwned, T3: FromSqlOwned, T4: FromSqlOwned> FromRow
226    for (T1, T2, T3, T4)
227{
228    fn from_row(row: Row) -> Result<Self> {
229        Ok((
230            row.try_get(0)?,
231            row.try_get(1)?,
232            row.try_get(2)?,
233            row.try_get(3)?,
234        ))
235    }
236}
237
238impl<T1: FromSqlOwned, T2: FromSqlOwned, T3: FromSqlOwned, T4: FromSqlOwned, T5: FromSqlOwned>
239    FromRow for (T1, T2, T3, T4, T5)
240{
241    fn from_row(row: Row) -> Result<Self> {
242        Ok((
243            row.try_get(0)?,
244            row.try_get(1)?,
245            row.try_get(2)?,
246            row.try_get(3)?,
247            row.try_get(4)?,
248        ))
249    }
250}
251
252impl<
253        T1: FromSqlOwned,
254        T2: FromSqlOwned,
255        T3: FromSqlOwned,
256        T4: FromSqlOwned,
257        T5: FromSqlOwned,
258        T6: FromSqlOwned,
259    > FromRow for (T1, T2, T3, T4, T5, T6)
260{
261    fn from_row(row: Row) -> Result<Self> {
262        Ok((
263            row.try_get(0)?,
264            row.try_get(1)?,
265            row.try_get(2)?,
266            row.try_get(3)?,
267            row.try_get(4)?,
268            row.try_get(5)?,
269        ))
270    }
271}
272
273/// A trait for converting a postgres char to a Rust type.
274pub(crate) trait FromPgChar: Sized {
275    fn from_pg_char(c: char) -> std::result::Result<Self, crate::ElefantToolsError>;
276}
277
278/// Provides extension methods on tokio_postgres::Row for working with enums that implements FromPgChar.
279pub(crate) trait RowEnumExt {
280    /// Get an enum value from a row.
281    fn try_get_enum_value<T: FromPgChar, I: RowIndex + Display>(&self, idx: I) -> Result<T>;
282    /// Get an optional enum value from a row, aka `Option<T>`.
283    fn try_get_opt_enum_value<T: FromPgChar, I: RowIndex + Display>(
284        &self,
285        idx: I,
286    ) -> Result<Option<T>>;
287}
288
289impl RowEnumExt for Row {
290    fn try_get_enum_value<T: FromPgChar, I: RowIndex + Display>(&self, idx: I) -> Result<T> {
291        let value: i8 = self.try_get(idx)?;
292        let c = value as u8 as char;
293        T::from_pg_char(c)
294    }
295
296    fn try_get_opt_enum_value<T: FromPgChar, I: RowIndex + Display>(
297        &self,
298        idx: I,
299    ) -> Result<Option<T>> {
300        let value: Option<i8> = self.try_get(idx)?;
301        match value {
302            Some(value) => {
303                let c = value as u8 as char;
304                Ok(Some(T::from_pg_char(c)?))
305            }
306            None => Ok(None),
307        }
308    }
309}