1use crate::Result;
2use bytes::Buf;
3use std::fmt::Display;
4use std::ops::Deref;
5use tokio::task::JoinHandle;
6use tokio_postgres::row::RowIndex;
7use tokio_postgres::types::FromSqlOwned;
8use tokio_postgres::{Client, CopyInSink, CopyOutStream, NoTls, Row};
9use tracing::instrument;
10
11pub struct PostgresClientWrapper {
13 client: PostgresClient,
15 version: i32,
17 connection_string: String,
19}
20
21impl PostgresClientWrapper {
22 #[instrument(skip_all)]
27 pub async fn new(connection_string: &str) -> Result<Self> {
28 let client = PostgresClient::new(connection_string).await?;
29
30 let version = match &client
31 .client
32 .simple_query("SHOW server_version_num;")
33 .await?[0]
34 {
35 tokio_postgres::SimpleQueryMessage::Row(row) => {
36 let version: i32 = row
37 .get(0)
38 .expect("failed to get version from row")
39 .parse()
40 .expect("failed to parse version");
41 if version < 120000 {
42 return Err(crate::ElefantToolsError::UnsupportedPostgresVersion(
43 version,
44 ));
45 }
46 version / 1000
47 }
48 _ => return Err(crate::ElefantToolsError::InvalidPostgresVersionResponse),
49 };
50
51 Ok(PostgresClientWrapper {
52 client,
53 version,
54 connection_string: connection_string.to_string(),
55 })
56 }
57
58 pub fn version(&self) -> i32 {
60 self.version
61 }
62
63 pub async fn create_another_connection(&self) -> Result<Self> {
65 let client = PostgresClient::new(&self.connection_string).await?;
66 Ok(PostgresClientWrapper {
67 client,
68 version: self.version,
69 connection_string: self.connection_string.clone(),
70 })
71 }
72
73 #[cfg(test)]
74 pub(crate) fn underlying_connection(&self) -> &Client {
75 &self.client.client
76 }
77}
78
79impl Deref for PostgresClientWrapper {
80 type Target = PostgresClient;
81
82 fn deref(&self) -> &Self::Target {
83 &self.client
84 }
85}
86
87pub struct PostgresClient {
89 client: Client,
90 join_handle: JoinHandle<Result<()>>,
91}
92
93impl PostgresClient {
94 pub async fn new(connection_string: &str) -> Result<Self> {
98 let (client, connection) = tokio_postgres::connect(connection_string, NoTls).await?;
99
100 let join_handle = tokio::spawn(async move {
103 match connection.await {
104 Err(e) => Err(crate::ElefantToolsError::PostgresError(e)),
105 Ok(_) => Ok(()),
106 }
107 });
108
109 Ok(PostgresClient {
110 client,
111 join_handle,
112 })
113 }
114
115 pub async fn execute_non_query(&self, sql: &str) -> Result {
117 self.client.batch_execute(sql).await.map_err(|e| {
118 crate::ElefantToolsError::PostgresErrorWithQuery {
119 source: e,
120 query: sql.to_string(),
121 }
122 })?;
123
124 Ok(())
125 }
126
127 pub async fn get_results<T: FromRow>(&self, sql: &str) -> Result<Vec<T>> {
129 let query_results = self.client.query(sql, &[]).await.map_err(|e| {
130 crate::ElefantToolsError::PostgresErrorWithQuery {
131 source: e,
132 query: sql.to_string(),
133 }
134 })?;
135
136 let mut output = Vec::with_capacity(query_results.len());
137
138 for row in query_results.into_iter() {
139 output.push(T::from_row(row)?);
140 }
141
142 Ok(output)
143 }
144
145 pub async fn get_result<T: FromRow>(&self, sql: &str) -> Result<T> {
147 let results = self.get_results(sql).await?;
148 if results.len() != 1 {
149 return Err(crate::ElefantToolsError::InvalidNumberOfResults {
150 actual: results.len(),
151 expected: 1,
152 });
153 }
154
155 let r = results.into_iter().next().unwrap();
157
158 Ok(r)
159 }
160
161 pub async fn get_single_results<T: FromSqlOwned>(&self, sql: &str) -> Result<Vec<T>> {
163 let r = self
164 .get_results::<(T,)>(sql)
165 .await?
166 .into_iter()
167 .map(|t| t.0)
168 .collect();
169
170 Ok(r)
171 }
172
173 pub async fn get_single_result<T: FromSqlOwned>(&self, sql: &str) -> Result<T> {
175 let result = self.get_result::<(T,)>(sql).await?;
176 Ok(result.0)
177 }
178
179 pub async fn copy_in<U>(&self, sql: &str) -> Result<CopyInSink<U>>
181 where
182 U: Buf + Send + 'static,
183 {
184 let sink = self.client.copy_in(sql).await?;
185 Ok(sink)
186 }
187
188 pub async fn copy_out(&self, sql: &str) -> Result<CopyOutStream> {
190 let stream = self.client.copy_out(sql).await?;
191 Ok(stream)
192 }
193}
194
195impl Drop for PostgresClient {
196 fn drop(&mut self) {
197 self.join_handle.abort();
198 }
199}
200
201pub trait FromRow: Sized {
204 fn from_row(row: Row) -> Result<Self>;
205}
206
207impl<T1: FromSqlOwned> FromRow for (T1,) {
208 fn from_row(row: Row) -> Result<Self> {
209 Ok((row.try_get(0)?,))
210 }
211}
212
213impl<T1: FromSqlOwned, T2: FromSqlOwned> FromRow for (T1, T2) {
214 fn from_row(row: Row) -> Result<Self> {
215 Ok((row.try_get(0)?, row.try_get(1)?))
216 }
217}
218
219impl<T1: FromSqlOwned, T2: FromSqlOwned, T3: FromSqlOwned> FromRow for (T1, T2, T3) {
220 fn from_row(row: Row) -> Result<Self> {
221 Ok((row.try_get(0)?, row.try_get(1)?, row.try_get(2)?))
222 }
223}
224
225impl<T1: FromSqlOwned, T2: FromSqlOwned, T3: FromSqlOwned, T4: FromSqlOwned> FromRow
226 for (T1, T2, T3, T4)
227{
228 fn from_row(row: Row) -> Result<Self> {
229 Ok((
230 row.try_get(0)?,
231 row.try_get(1)?,
232 row.try_get(2)?,
233 row.try_get(3)?,
234 ))
235 }
236}
237
238impl<T1: FromSqlOwned, T2: FromSqlOwned, T3: FromSqlOwned, T4: FromSqlOwned, T5: FromSqlOwned>
239 FromRow for (T1, T2, T3, T4, T5)
240{
241 fn from_row(row: Row) -> Result<Self> {
242 Ok((
243 row.try_get(0)?,
244 row.try_get(1)?,
245 row.try_get(2)?,
246 row.try_get(3)?,
247 row.try_get(4)?,
248 ))
249 }
250}
251
252impl<
253 T1: FromSqlOwned,
254 T2: FromSqlOwned,
255 T3: FromSqlOwned,
256 T4: FromSqlOwned,
257 T5: FromSqlOwned,
258 T6: FromSqlOwned,
259 > FromRow for (T1, T2, T3, T4, T5, T6)
260{
261 fn from_row(row: Row) -> Result<Self> {
262 Ok((
263 row.try_get(0)?,
264 row.try_get(1)?,
265 row.try_get(2)?,
266 row.try_get(3)?,
267 row.try_get(4)?,
268 row.try_get(5)?,
269 ))
270 }
271}
272
273pub(crate) trait FromPgChar: Sized {
275 fn from_pg_char(c: char) -> std::result::Result<Self, crate::ElefantToolsError>;
276}
277
278pub(crate) trait RowEnumExt {
280 fn try_get_enum_value<T: FromPgChar, I: RowIndex + Display>(&self, idx: I) -> Result<T>;
282 fn try_get_opt_enum_value<T: FromPgChar, I: RowIndex + Display>(
284 &self,
285 idx: I,
286 ) -> Result<Option<T>>;
287}
288
289impl RowEnumExt for Row {
290 fn try_get_enum_value<T: FromPgChar, I: RowIndex + Display>(&self, idx: I) -> Result<T> {
291 let value: i8 = self.try_get(idx)?;
292 let c = value as u8 as char;
293 T::from_pg_char(c)
294 }
295
296 fn try_get_opt_enum_value<T: FromPgChar, I: RowIndex + Display>(
297 &self,
298 idx: I,
299 ) -> Result<Option<T>> {
300 let value: Option<i8> = self.try_get(idx)?;
301 match value {
302 Some(value) => {
303 let c = value as u8 as char;
304 Ok(Some(T::from_pg_char(c)?))
305 }
306 None => Ok(None),
307 }
308 }
309}