1use crate::{Client, QueryExecutor, Row, ToSqlParam, Transaction};
2
3pub struct PgRow(pub tokio_postgres::Row);
6
7impl Row for PgRow {
8 fn get_i16(&self, idx: usize) -> i16 { self.0.get(idx) }
9 fn get_i32(&self, idx: usize) -> i32 { self.0.get(idx) }
10 fn get_i64(&self, idx: usize) -> i64 { self.0.get(idx) }
11 fn get_f32(&self, idx: usize) -> f32 { self.0.get(idx) }
12 fn get_f64(&self, idx: usize) -> f64 { self.0.get(idx) }
13 fn get_string(&self, idx: usize) -> String { self.0.get(idx) }
14 fn get_bool(&self, idx: usize) -> bool { self.0.get(idx) }
15 fn get_uuid(&self, idx: usize) -> uuid::Uuid { self.0.get(idx) }
16 fn get_timestamp(&self, idx: usize) -> chrono::NaiveDateTime { self.0.get(idx) }
17 fn get_timestamptz(&self, idx: usize) -> chrono::DateTime<chrono::Utc> { self.0.get(idx) }
18 fn get_date(&self, idx: usize) -> chrono::NaiveDate { self.0.get(idx) }
19 fn get_decimal(&self, idx: usize) -> rust_decimal::Decimal { self.0.get(idx) }
20 fn get_json(&self, idx: usize) -> serde_json::Value { self.0.get(idx) }
21
22 fn get_opt_i16(&self, idx: usize) -> Option<i16> { self.0.get(idx) }
23 fn get_opt_i32(&self, idx: usize) -> Option<i32> { self.0.get(idx) }
24 fn get_opt_i64(&self, idx: usize) -> Option<i64> { self.0.get(idx) }
25 fn get_opt_f32(&self, idx: usize) -> Option<f32> { self.0.get(idx) }
26 fn get_opt_f64(&self, idx: usize) -> Option<f64> { self.0.get(idx) }
27 fn get_opt_string(&self, idx: usize) -> Option<String> { self.0.get(idx) }
28 fn get_opt_bool(&self, idx: usize) -> Option<bool> { self.0.get(idx) }
29 fn get_opt_uuid(&self, idx: usize) -> Option<uuid::Uuid> { self.0.get(idx) }
30 fn get_opt_timestamp(&self, idx: usize) -> Option<chrono::NaiveDateTime> { self.0.get(idx) }
31 fn get_opt_timestamptz(&self, idx: usize) -> Option<chrono::DateTime<chrono::Utc>> { self.0.get(idx) }
32 fn get_opt_date(&self, idx: usize) -> Option<chrono::NaiveDate> { self.0.get(idx) }
33 fn get_opt_decimal(&self, idx: usize) -> Option<rust_decimal::Decimal> { self.0.get(idx) }
34 fn get_opt_json(&self, idx: usize) -> Option<serde_json::Value> { self.0.get(idx) }
35
36 fn get_vec_i16(&self, idx: usize) -> Vec<i16> { self.0.get(idx) }
37 fn get_vec_i32(&self, idx: usize) -> Vec<i32> { self.0.get(idx) }
38 fn get_vec_i64(&self, idx: usize) -> Vec<i64> { self.0.get(idx) }
39 fn get_vec_f32(&self, idx: usize) -> Vec<f32> { self.0.get(idx) }
40 fn get_vec_f64(&self, idx: usize) -> Vec<f64> { self.0.get(idx) }
41 fn get_vec_string(&self, idx: usize) -> Vec<String> { self.0.get(idx) }
42 fn get_vec_bool(&self, idx: usize) -> Vec<bool> { self.0.get(idx) }
43 fn get_vec_uuid(&self, idx: usize) -> Vec<uuid::Uuid> { self.0.get(idx) }
44}
45
46#[derive(Debug)]
51pub enum OwnedParam {
52 I16(i16),
53 I32(i32),
54 I64(i64),
55 F32(f32),
56 F64(f64),
57 Bool(bool),
58 String(String),
59 Uuid(uuid::Uuid),
60 NaiveDateTime(chrono::NaiveDateTime),
61 DateTimeUtc(chrono::DateTime<chrono::Utc>),
62 NaiveDate(chrono::NaiveDate),
63 Decimal(rust_decimal::Decimal),
64 Json(serde_json::Value),
65 OptI16(Option<i16>),
66 OptI32(Option<i32>),
67 OptI64(Option<i64>),
68 OptF32(Option<f32>),
69 OptF64(Option<f64>),
70 OptBool(Option<bool>),
71 OptString(Option<String>),
72 OptUuid(Option<uuid::Uuid>),
73 OptNaiveDateTime(Option<chrono::NaiveDateTime>),
74 OptDateTimeUtc(Option<chrono::DateTime<chrono::Utc>>),
75 OptNaiveDate(Option<chrono::NaiveDate>),
76 OptDecimal(Option<rust_decimal::Decimal>),
77 OptJson(Option<serde_json::Value>),
78 VecI16(Vec<i16>),
79 VecI32(Vec<i32>),
80 VecI64(Vec<i64>),
81 VecF32(Vec<f32>),
82 VecF64(Vec<f64>),
83 VecBool(Vec<bool>),
84 VecString(Vec<String>),
85 VecUuid(Vec<uuid::Uuid>),
86}
87
88impl tokio_postgres::types::ToSql for OwnedParam {
89 fn to_sql(
90 &self,
91 ty: &tokio_postgres::types::Type,
92 out: &mut bytes::BytesMut,
93 ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
94 match self {
95 OwnedParam::I16(v) => v.to_sql(ty, out),
96 OwnedParam::I32(v) => v.to_sql(ty, out),
97 OwnedParam::I64(v) => v.to_sql(ty, out),
98 OwnedParam::F32(v) => v.to_sql(ty, out),
99 OwnedParam::F64(v) => v.to_sql(ty, out),
100 OwnedParam::Bool(v) => v.to_sql(ty, out),
101 OwnedParam::String(v) => v.to_sql(ty, out),
102 OwnedParam::Uuid(v) => v.to_sql(ty, out),
103 OwnedParam::NaiveDateTime(v) => v.to_sql(ty, out),
104 OwnedParam::DateTimeUtc(v) => v.to_sql(ty, out),
105 OwnedParam::NaiveDate(v) => v.to_sql(ty, out),
106 OwnedParam::Decimal(v) => v.to_sql(ty, out),
107 OwnedParam::Json(v) => v.to_sql(ty, out),
108 OwnedParam::OptI16(v) => v.to_sql(ty, out),
109 OwnedParam::OptI32(v) => v.to_sql(ty, out),
110 OwnedParam::OptI64(v) => v.to_sql(ty, out),
111 OwnedParam::OptF32(v) => v.to_sql(ty, out),
112 OwnedParam::OptF64(v) => v.to_sql(ty, out),
113 OwnedParam::OptBool(v) => v.to_sql(ty, out),
114 OwnedParam::OptString(v) => v.to_sql(ty, out),
115 OwnedParam::OptUuid(v) => v.to_sql(ty, out),
116 OwnedParam::OptNaiveDateTime(v) => v.to_sql(ty, out),
117 OwnedParam::OptDateTimeUtc(v) => v.to_sql(ty, out),
118 OwnedParam::OptNaiveDate(v) => v.to_sql(ty, out),
119 OwnedParam::OptDecimal(v) => v.to_sql(ty, out),
120 OwnedParam::OptJson(v) => v.to_sql(ty, out),
121 OwnedParam::VecI16(v) => v.to_sql(ty, out),
122 OwnedParam::VecI32(v) => v.to_sql(ty, out),
123 OwnedParam::VecI64(v) => v.to_sql(ty, out),
124 OwnedParam::VecF32(v) => v.to_sql(ty, out),
125 OwnedParam::VecF64(v) => v.to_sql(ty, out),
126 OwnedParam::VecBool(v) => v.to_sql(ty, out),
127 OwnedParam::VecString(v) => v.to_sql(ty, out),
128 OwnedParam::VecUuid(v) => v.to_sql(ty, out),
129 }
130 }
131
132 fn accepts(ty: &tokio_postgres::types::Type) -> bool {
133 i16::accepts(ty)
134 || i32::accepts(ty)
135 || i64::accepts(ty)
136 || f32::accepts(ty)
137 || f64::accepts(ty)
138 || bool::accepts(ty)
139 || String::accepts(ty)
140 || uuid::Uuid::accepts(ty)
141 || chrono::NaiveDateTime::accepts(ty)
142 || chrono::DateTime::<chrono::Utc>::accepts(ty)
143 || chrono::NaiveDate::accepts(ty)
144 || rust_decimal::Decimal::accepts(ty)
145 || serde_json::Value::accepts(ty)
146 || <Option<i16>>::accepts(ty)
147 || <Option<i32>>::accepts(ty)
148 || <Option<i64>>::accepts(ty)
149 || <Option<f32>>::accepts(ty)
150 || <Option<f64>>::accepts(ty)
151 || <Option<bool>>::accepts(ty)
152 || <Option<String>>::accepts(ty)
153 || <Option<uuid::Uuid>>::accepts(ty)
154 || <Option<chrono::NaiveDateTime>>::accepts(ty)
155 || <Option<chrono::DateTime<chrono::Utc>>>::accepts(ty)
156 || <Option<chrono::NaiveDate>>::accepts(ty)
157 || <Option<rust_decimal::Decimal>>::accepts(ty)
158 || <Option<serde_json::Value>>::accepts(ty)
159 || <Vec<i16>>::accepts(ty)
160 || <Vec<i32>>::accepts(ty)
161 || <Vec<i64>>::accepts(ty)
162 || <Vec<f32>>::accepts(ty)
163 || <Vec<f64>>::accepts(ty)
164 || <Vec<bool>>::accepts(ty)
165 || <Vec<String>>::accepts(ty)
166 || <Vec<uuid::Uuid>>::accepts(ty)
167 }
168
169 tokio_postgres::types::to_sql_checked!();
170}
171
172pub fn to_owned_param(param: &dyn ToSqlParam) -> OwnedParam {
174 let any = param.as_any();
175 if let Some(v) = any.downcast_ref::<i16>() { return OwnedParam::I16(*v); }
177 if let Some(v) = any.downcast_ref::<i32>() { return OwnedParam::I32(*v); }
178 if let Some(v) = any.downcast_ref::<i64>() { return OwnedParam::I64(*v); }
179 if let Some(v) = any.downcast_ref::<f32>() { return OwnedParam::F32(*v); }
180 if let Some(v) = any.downcast_ref::<f64>() { return OwnedParam::F64(*v); }
181 if let Some(v) = any.downcast_ref::<bool>() { return OwnedParam::Bool(*v); }
182 if let Some(v) = any.downcast_ref::<String>() { return OwnedParam::String(v.clone()); }
183 if let Some(v) = any.downcast_ref::<uuid::Uuid>() { return OwnedParam::Uuid(*v); }
184 if let Some(v) = any.downcast_ref::<chrono::NaiveDateTime>() { return OwnedParam::NaiveDateTime(*v); }
185 if let Some(v) = any.downcast_ref::<chrono::DateTime<chrono::Utc>>() { return OwnedParam::DateTimeUtc(*v); }
186 if let Some(v) = any.downcast_ref::<chrono::NaiveDate>() { return OwnedParam::NaiveDate(*v); }
187 if let Some(v) = any.downcast_ref::<rust_decimal::Decimal>() { return OwnedParam::Decimal(*v); }
188 if let Some(v) = any.downcast_ref::<serde_json::Value>() { return OwnedParam::Json(v.clone()); }
189 if let Some(v) = any.downcast_ref::<Option<i16>>() { return OwnedParam::OptI16(*v); }
191 if let Some(v) = any.downcast_ref::<Option<i32>>() { return OwnedParam::OptI32(*v); }
192 if let Some(v) = any.downcast_ref::<Option<i64>>() { return OwnedParam::OptI64(*v); }
193 if let Some(v) = any.downcast_ref::<Option<f32>>() { return OwnedParam::OptF32(*v); }
194 if let Some(v) = any.downcast_ref::<Option<f64>>() { return OwnedParam::OptF64(*v); }
195 if let Some(v) = any.downcast_ref::<Option<bool>>() { return OwnedParam::OptBool(*v); }
196 if let Some(v) = any.downcast_ref::<Option<String>>() { return OwnedParam::OptString(v.clone()); }
197 if let Some(v) = any.downcast_ref::<Option<uuid::Uuid>>() { return OwnedParam::OptUuid(*v); }
198 if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDateTime>>() { return OwnedParam::OptNaiveDateTime(*v); }
199 if let Some(v) = any.downcast_ref::<Option<chrono::DateTime<chrono::Utc>>>() { return OwnedParam::OptDateTimeUtc(*v); }
200 if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDate>>() { return OwnedParam::OptNaiveDate(*v); }
201 if let Some(v) = any.downcast_ref::<Option<rust_decimal::Decimal>>() { return OwnedParam::OptDecimal(*v); }
202 if let Some(v) = any.downcast_ref::<Option<serde_json::Value>>() { return OwnedParam::OptJson(v.clone()); }
203 if let Some(v) = any.downcast_ref::<Vec<i16>>() { return OwnedParam::VecI16(v.clone()); }
205 if let Some(v) = any.downcast_ref::<Vec<i32>>() { return OwnedParam::VecI32(v.clone()); }
206 if let Some(v) = any.downcast_ref::<Vec<i64>>() { return OwnedParam::VecI64(v.clone()); }
207 if let Some(v) = any.downcast_ref::<Vec<f32>>() { return OwnedParam::VecF32(v.clone()); }
208 if let Some(v) = any.downcast_ref::<Vec<f64>>() { return OwnedParam::VecF64(v.clone()); }
209 if let Some(v) = any.downcast_ref::<Vec<bool>>() { return OwnedParam::VecBool(v.clone()); }
210 if let Some(v) = any.downcast_ref::<Vec<String>>() { return OwnedParam::VecString(v.clone()); }
211 if let Some(v) = any.downcast_ref::<Vec<uuid::Uuid>>() { return OwnedParam::VecUuid(v.clone()); }
212 panic!("Unsupported parameter type for tokio-postgres backend");
213}
214
215pub fn convert_params(params: &[&dyn ToSqlParam]) -> Vec<OwnedParam> {
217 params.iter().map(|p| to_owned_param(*p)).collect()
218}
219
220pub struct PgClient {
224 inner: tokio_postgres::Client,
225}
226
227impl PgClient {
228 pub fn new(client: tokio_postgres::Client) -> Self {
229 Self { inner: client }
230 }
231
232 pub fn inner(&self) -> &tokio_postgres::Client {
234 &self.inner
235 }
236
237 pub fn into_inner(self) -> tokio_postgres::Client {
239 self.inner
240 }
241}
242
243impl QueryExecutor for PgClient {
244 type Error = tokio_postgres::Error;
245 type Row = PgRow;
246
247 fn query(
248 &self,
249 sql: &str,
250 params: &[&dyn ToSqlParam],
251 ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
252 let owned_params = convert_params(params);
253 let sql = sql.to_owned();
254 let client = &self.inner;
255 async move {
256 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
257 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
258 let rows = client.query(&sql, ¶m_refs).await?;
259 Ok(rows.into_iter().map(PgRow).collect())
260 }
261 }
262
263 fn execute(
264 &self,
265 sql: &str,
266 params: &[&dyn ToSqlParam],
267 ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
268 let owned_params = convert_params(params);
269 let sql = sql.to_owned();
270 let client = &self.inner;
271 async move {
272 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
273 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
274 client.execute(&sql, ¶m_refs).await
275 }
276 }
277}
278
279impl Client for PgClient {
280 type Transaction<'a> = PgTransaction<'a>;
281
282 fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Transaction<'_>, Self::Error>> + Send {
283 async move {
284 let tx = self.inner.transaction().await?;
285 Ok(PgTransaction { inner: tx })
286 }
287 }
288}
289
290pub struct PgTransaction<'a> {
296 inner: tokio_postgres::Transaction<'a>,
297}
298
299impl<'a> PgTransaction<'a> {
300 pub fn new(tx: tokio_postgres::Transaction<'a>) -> Self {
301 Self { inner: tx }
302 }
303}
304
305impl QueryExecutor for PgTransaction<'_> {
306 type Error = tokio_postgres::Error;
307 type Row = PgRow;
308
309 fn query(
310 &self,
311 sql: &str,
312 params: &[&dyn ToSqlParam],
313 ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
314 let owned_params = convert_params(params);
315 let sql = sql.to_owned();
316 let client = &self.inner;
317 async move {
318 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
319 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
320 let rows = client.query(&sql, ¶m_refs).await?;
321 Ok(rows.into_iter().map(PgRow).collect())
322 }
323 }
324
325 fn execute(
326 &self,
327 sql: &str,
328 params: &[&dyn ToSqlParam],
329 ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
330 let owned_params = convert_params(params);
331 let sql = sql.to_owned();
332 let client = &self.inner;
333 async move {
334 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
335 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
336 client.execute(&sql, ¶m_refs).await
337 }
338 }
339}
340
341impl Transaction for PgTransaction<'_> {
342 type Nested<'a> = PgTransaction<'a> where Self: 'a;
343
344 fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Nested<'_>, Self::Error>> + Send {
345 async move {
346 let tx = self.inner.transaction().await?;
347 Ok(PgTransaction { inner: tx })
348 }
349 }
350
351 fn commit(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
352 async move { self.inner.commit().await }
353 }
354
355 fn rollback(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
356 async move { self.inner.rollback().await }
357 }
358}