1use crate::{Client, Row, ToSqlParam, Transaction, UpdateField};
2
/// Implements [`ToSqlParam`] for each listed type by returning `self` as a
/// type-erased `Any` reference.
///
/// Takes a comma-separated list of types; a trailing comma is accepted so the
/// list can be formatted one type per line.
macro_rules! impl_to_sql_param {
    ($($ty:ty),* $(,)?) => {
        $(
            impl ToSqlParam for $ty {
                fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
                    self
                }
            }
        )*
    };
}
16
17impl_to_sql_param!(
18 i16, i32, i64, f32, f64, bool, String,
19 uuid::Uuid, chrono::NaiveDateTime, chrono::NaiveDate,
20 rust_decimal::Decimal
21);
22
23impl ToSqlParam for chrono::DateTime<chrono::Utc> {
24 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
25 self
26 }
27}
28
29impl ToSqlParam for serde_json::Value {
30 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
31 self
32 }
33}
34
35impl ToSqlParam for &'static str {
36 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
37 self
38 }
39}
40
41impl<T: ToSqlParam + 'static> ToSqlParam for Option<T> {
42 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
43 self
44 }
45}
46
47impl<T: ToSqlParam + 'static> ToSqlParam for Vec<T> {
48 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
49 self
50 }
51}
52
53impl<T: ToSqlParam + 'static> ToSqlParam for UpdateField<T> {
54 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
55 match self {
56 UpdateField::Set(v) => v.as_any(),
57 UpdateField::Unchanged => panic!("Unchanged fields should not be passed as SQL params"),
58 }
59 }
60}
61
/// Wrapper around a `tokio_postgres::Row` that implements this crate's `Row` trait.
///
/// The wrapped row is `pub`, so callers can reach the raw `tokio_postgres::Row`
/// directly.
pub struct PgRow(pub tokio_postgres::Row);
65
66impl Row for PgRow {
67 fn get_i16(&self, idx: usize) -> i16 { self.0.get(idx) }
68 fn get_i32(&self, idx: usize) -> i32 { self.0.get(idx) }
69 fn get_i64(&self, idx: usize) -> i64 { self.0.get(idx) }
70 fn get_f32(&self, idx: usize) -> f32 { self.0.get(idx) }
71 fn get_f64(&self, idx: usize) -> f64 { self.0.get(idx) }
72 fn get_string(&self, idx: usize) -> String { self.0.get(idx) }
73 fn get_bool(&self, idx: usize) -> bool { self.0.get(idx) }
74 fn get_uuid(&self, idx: usize) -> uuid::Uuid { self.0.get(idx) }
75 fn get_timestamp(&self, idx: usize) -> chrono::NaiveDateTime { self.0.get(idx) }
76 fn get_timestamptz(&self, idx: usize) -> chrono::DateTime<chrono::Utc> { self.0.get(idx) }
77 fn get_date(&self, idx: usize) -> chrono::NaiveDate { self.0.get(idx) }
78 fn get_decimal(&self, idx: usize) -> rust_decimal::Decimal { self.0.get(idx) }
79 fn get_json(&self, idx: usize) -> serde_json::Value { self.0.get(idx) }
80
81 fn get_opt_i16(&self, idx: usize) -> Option<i16> { self.0.get(idx) }
82 fn get_opt_i32(&self, idx: usize) -> Option<i32> { self.0.get(idx) }
83 fn get_opt_i64(&self, idx: usize) -> Option<i64> { self.0.get(idx) }
84 fn get_opt_f32(&self, idx: usize) -> Option<f32> { self.0.get(idx) }
85 fn get_opt_f64(&self, idx: usize) -> Option<f64> { self.0.get(idx) }
86 fn get_opt_string(&self, idx: usize) -> Option<String> { self.0.get(idx) }
87 fn get_opt_bool(&self, idx: usize) -> Option<bool> { self.0.get(idx) }
88 fn get_opt_uuid(&self, idx: usize) -> Option<uuid::Uuid> { self.0.get(idx) }
89 fn get_opt_timestamp(&self, idx: usize) -> Option<chrono::NaiveDateTime> { self.0.get(idx) }
90 fn get_opt_timestamptz(&self, idx: usize) -> Option<chrono::DateTime<chrono::Utc>> { self.0.get(idx) }
91 fn get_opt_date(&self, idx: usize) -> Option<chrono::NaiveDate> { self.0.get(idx) }
92 fn get_opt_decimal(&self, idx: usize) -> Option<rust_decimal::Decimal> { self.0.get(idx) }
93 fn get_opt_json(&self, idx: usize) -> Option<serde_json::Value> { self.0.get(idx) }
94
95 fn get_vec_i16(&self, idx: usize) -> Vec<i16> { self.0.get(idx) }
96 fn get_vec_i32(&self, idx: usize) -> Vec<i32> { self.0.get(idx) }
97 fn get_vec_i64(&self, idx: usize) -> Vec<i64> { self.0.get(idx) }
98 fn get_vec_f32(&self, idx: usize) -> Vec<f32> { self.0.get(idx) }
99 fn get_vec_f64(&self, idx: usize) -> Vec<f64> { self.0.get(idx) }
100 fn get_vec_string(&self, idx: usize) -> Vec<String> { self.0.get(idx) }
101 fn get_vec_bool(&self, idx: usize) -> Vec<bool> { self.0.get(idx) }
102 fn get_vec_uuid(&self, idx: usize) -> Vec<uuid::Uuid> { self.0.get(idx) }
103}
104
105#[derive(Debug)]
110pub enum OwnedParam {
111 I16(i16),
112 I32(i32),
113 I64(i64),
114 F32(f32),
115 F64(f64),
116 Bool(bool),
117 String(String),
118 Uuid(uuid::Uuid),
119 NaiveDateTime(chrono::NaiveDateTime),
120 DateTimeUtc(chrono::DateTime<chrono::Utc>),
121 NaiveDate(chrono::NaiveDate),
122 Decimal(rust_decimal::Decimal),
123 Json(serde_json::Value),
124 OptI16(Option<i16>),
125 OptI32(Option<i32>),
126 OptI64(Option<i64>),
127 OptF32(Option<f32>),
128 OptF64(Option<f64>),
129 OptBool(Option<bool>),
130 OptString(Option<String>),
131 OptUuid(Option<uuid::Uuid>),
132 OptNaiveDateTime(Option<chrono::NaiveDateTime>),
133 OptDateTimeUtc(Option<chrono::DateTime<chrono::Utc>>),
134 OptNaiveDate(Option<chrono::NaiveDate>),
135 OptDecimal(Option<rust_decimal::Decimal>),
136 OptJson(Option<serde_json::Value>),
137 VecI16(Vec<i16>),
138 VecI32(Vec<i32>),
139 VecI64(Vec<i64>),
140 VecF32(Vec<f32>),
141 VecF64(Vec<f64>),
142 VecBool(Vec<bool>),
143 VecString(Vec<String>),
144 VecUuid(Vec<uuid::Uuid>),
145}
146
147impl tokio_postgres::types::ToSql for OwnedParam {
148 fn to_sql(
149 &self,
150 ty: &tokio_postgres::types::Type,
151 out: &mut bytes::BytesMut,
152 ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
153 match self {
154 OwnedParam::I16(v) => v.to_sql(ty, out),
155 OwnedParam::I32(v) => v.to_sql(ty, out),
156 OwnedParam::I64(v) => v.to_sql(ty, out),
157 OwnedParam::F32(v) => v.to_sql(ty, out),
158 OwnedParam::F64(v) => v.to_sql(ty, out),
159 OwnedParam::Bool(v) => v.to_sql(ty, out),
160 OwnedParam::String(v) => v.to_sql(ty, out),
161 OwnedParam::Uuid(v) => v.to_sql(ty, out),
162 OwnedParam::NaiveDateTime(v) => v.to_sql(ty, out),
163 OwnedParam::DateTimeUtc(v) => v.to_sql(ty, out),
164 OwnedParam::NaiveDate(v) => v.to_sql(ty, out),
165 OwnedParam::Decimal(v) => v.to_sql(ty, out),
166 OwnedParam::Json(v) => v.to_sql(ty, out),
167 OwnedParam::OptI16(v) => v.to_sql(ty, out),
168 OwnedParam::OptI32(v) => v.to_sql(ty, out),
169 OwnedParam::OptI64(v) => v.to_sql(ty, out),
170 OwnedParam::OptF32(v) => v.to_sql(ty, out),
171 OwnedParam::OptF64(v) => v.to_sql(ty, out),
172 OwnedParam::OptBool(v) => v.to_sql(ty, out),
173 OwnedParam::OptString(v) => v.to_sql(ty, out),
174 OwnedParam::OptUuid(v) => v.to_sql(ty, out),
175 OwnedParam::OptNaiveDateTime(v) => v.to_sql(ty, out),
176 OwnedParam::OptDateTimeUtc(v) => v.to_sql(ty, out),
177 OwnedParam::OptNaiveDate(v) => v.to_sql(ty, out),
178 OwnedParam::OptDecimal(v) => v.to_sql(ty, out),
179 OwnedParam::OptJson(v) => v.to_sql(ty, out),
180 OwnedParam::VecI16(v) => v.to_sql(ty, out),
181 OwnedParam::VecI32(v) => v.to_sql(ty, out),
182 OwnedParam::VecI64(v) => v.to_sql(ty, out),
183 OwnedParam::VecF32(v) => v.to_sql(ty, out),
184 OwnedParam::VecF64(v) => v.to_sql(ty, out),
185 OwnedParam::VecBool(v) => v.to_sql(ty, out),
186 OwnedParam::VecString(v) => v.to_sql(ty, out),
187 OwnedParam::VecUuid(v) => v.to_sql(ty, out),
188 }
189 }
190
191 fn accepts(ty: &tokio_postgres::types::Type) -> bool {
192 i16::accepts(ty)
193 || i32::accepts(ty)
194 || i64::accepts(ty)
195 || f32::accepts(ty)
196 || f64::accepts(ty)
197 || bool::accepts(ty)
198 || String::accepts(ty)
199 || uuid::Uuid::accepts(ty)
200 || chrono::NaiveDateTime::accepts(ty)
201 || chrono::DateTime::<chrono::Utc>::accepts(ty)
202 || chrono::NaiveDate::accepts(ty)
203 || rust_decimal::Decimal::accepts(ty)
204 || serde_json::Value::accepts(ty)
205 || <Option<i16>>::accepts(ty)
206 || <Option<i32>>::accepts(ty)
207 || <Option<i64>>::accepts(ty)
208 || <Option<f32>>::accepts(ty)
209 || <Option<f64>>::accepts(ty)
210 || <Option<bool>>::accepts(ty)
211 || <Option<String>>::accepts(ty)
212 || <Option<uuid::Uuid>>::accepts(ty)
213 || <Option<chrono::NaiveDateTime>>::accepts(ty)
214 || <Option<chrono::DateTime<chrono::Utc>>>::accepts(ty)
215 || <Option<chrono::NaiveDate>>::accepts(ty)
216 || <Option<rust_decimal::Decimal>>::accepts(ty)
217 || <Option<serde_json::Value>>::accepts(ty)
218 || <Vec<i16>>::accepts(ty)
219 || <Vec<i32>>::accepts(ty)
220 || <Vec<i64>>::accepts(ty)
221 || <Vec<f32>>::accepts(ty)
222 || <Vec<f64>>::accepts(ty)
223 || <Vec<bool>>::accepts(ty)
224 || <Vec<String>>::accepts(ty)
225 || <Vec<uuid::Uuid>>::accepts(ty)
226 }
227
228 tokio_postgres::types::to_sql_checked!();
229}
230
231pub fn to_owned_param(param: &dyn ToSqlParam) -> OwnedParam {
233 let any = param.as_any();
234 if let Some(v) = any.downcast_ref::<i16>() { return OwnedParam::I16(*v); }
236 if let Some(v) = any.downcast_ref::<i32>() { return OwnedParam::I32(*v); }
237 if let Some(v) = any.downcast_ref::<i64>() { return OwnedParam::I64(*v); }
238 if let Some(v) = any.downcast_ref::<f32>() { return OwnedParam::F32(*v); }
239 if let Some(v) = any.downcast_ref::<f64>() { return OwnedParam::F64(*v); }
240 if let Some(v) = any.downcast_ref::<bool>() { return OwnedParam::Bool(*v); }
241 if let Some(v) = any.downcast_ref::<String>() { return OwnedParam::String(v.clone()); }
242 if let Some(v) = any.downcast_ref::<uuid::Uuid>() { return OwnedParam::Uuid(*v); }
243 if let Some(v) = any.downcast_ref::<chrono::NaiveDateTime>() { return OwnedParam::NaiveDateTime(*v); }
244 if let Some(v) = any.downcast_ref::<chrono::DateTime<chrono::Utc>>() { return OwnedParam::DateTimeUtc(*v); }
245 if let Some(v) = any.downcast_ref::<chrono::NaiveDate>() { return OwnedParam::NaiveDate(*v); }
246 if let Some(v) = any.downcast_ref::<rust_decimal::Decimal>() { return OwnedParam::Decimal(*v); }
247 if let Some(v) = any.downcast_ref::<serde_json::Value>() { return OwnedParam::Json(v.clone()); }
248 if let Some(v) = any.downcast_ref::<Option<i16>>() { return OwnedParam::OptI16(*v); }
250 if let Some(v) = any.downcast_ref::<Option<i32>>() { return OwnedParam::OptI32(*v); }
251 if let Some(v) = any.downcast_ref::<Option<i64>>() { return OwnedParam::OptI64(*v); }
252 if let Some(v) = any.downcast_ref::<Option<f32>>() { return OwnedParam::OptF32(*v); }
253 if let Some(v) = any.downcast_ref::<Option<f64>>() { return OwnedParam::OptF64(*v); }
254 if let Some(v) = any.downcast_ref::<Option<bool>>() { return OwnedParam::OptBool(*v); }
255 if let Some(v) = any.downcast_ref::<Option<String>>() { return OwnedParam::OptString(v.clone()); }
256 if let Some(v) = any.downcast_ref::<Option<uuid::Uuid>>() { return OwnedParam::OptUuid(*v); }
257 if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDateTime>>() { return OwnedParam::OptNaiveDateTime(*v); }
258 if let Some(v) = any.downcast_ref::<Option<chrono::DateTime<chrono::Utc>>>() { return OwnedParam::OptDateTimeUtc(*v); }
259 if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDate>>() { return OwnedParam::OptNaiveDate(*v); }
260 if let Some(v) = any.downcast_ref::<Option<rust_decimal::Decimal>>() { return OwnedParam::OptDecimal(*v); }
261 if let Some(v) = any.downcast_ref::<Option<serde_json::Value>>() { return OwnedParam::OptJson(v.clone()); }
262 if let Some(v) = any.downcast_ref::<Vec<i16>>() { return OwnedParam::VecI16(v.clone()); }
264 if let Some(v) = any.downcast_ref::<Vec<i32>>() { return OwnedParam::VecI32(v.clone()); }
265 if let Some(v) = any.downcast_ref::<Vec<i64>>() { return OwnedParam::VecI64(v.clone()); }
266 if let Some(v) = any.downcast_ref::<Vec<f32>>() { return OwnedParam::VecF32(v.clone()); }
267 if let Some(v) = any.downcast_ref::<Vec<f64>>() { return OwnedParam::VecF64(v.clone()); }
268 if let Some(v) = any.downcast_ref::<Vec<bool>>() { return OwnedParam::VecBool(v.clone()); }
269 if let Some(v) = any.downcast_ref::<Vec<String>>() { return OwnedParam::VecString(v.clone()); }
270 if let Some(v) = any.downcast_ref::<Vec<uuid::Uuid>>() { return OwnedParam::VecUuid(v.clone()); }
271 panic!("Unsupported parameter type for tokio-postgres backend");
272}
273
274pub fn convert_params(params: &[&dyn ToSqlParam]) -> Vec<OwnedParam> {
276 params.iter().map(|p| to_owned_param(*p)).collect()
277}
278
/// `tokio_postgres`-backed implementation of this crate's `Client` trait.
///
/// Use [`PgClient::new`] to wrap a client, [`PgClient::inner`] to borrow it and
/// [`PgClient::into_inner`] to take it back.
pub struct PgClient {
    inner: tokio_postgres::Client,
}
285
286impl PgClient {
287 pub fn new(client: tokio_postgres::Client) -> Self {
288 Self { inner: client }
289 }
290
291 pub fn inner(&self) -> &tokio_postgres::Client {
293 &self.inner
294 }
295
296 pub fn into_inner(self) -> tokio_postgres::Client {
298 self.inner
299 }
300}
301
302impl Client for PgClient {
303 type Error = tokio_postgres::Error;
304 type Row = PgRow;
305 type Transaction<'a> = PgTransaction<'a>;
306
307 fn query(
308 &self,
309 sql: &str,
310 params: &[&dyn ToSqlParam],
311 ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
312 let owned_params = convert_params(params);
313 let sql = sql.to_owned();
314 let client = &self.inner;
315 async move {
316 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
317 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
318 let rows = client.query(&sql, ¶m_refs).await?;
319 Ok(rows.into_iter().map(PgRow).collect())
320 }
321 }
322
323 fn execute(
324 &self,
325 sql: &str,
326 params: &[&dyn ToSqlParam],
327 ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
328 let owned_params = convert_params(params);
329 let sql = sql.to_owned();
330 let client = &self.inner;
331 async move {
332 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
333 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
334 client.execute(&sql, ¶m_refs).await
335 }
336 }
337
338 fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Transaction<'_>, Self::Error>> + Send {
339 async move {
340 let tx = self.inner.transaction().await?;
341 Ok(PgTransaction { inner: tx })
342 }
343 }
344}
345
/// A transaction implementing this crate's `Transaction` trait.
///
/// Created by `PgClient::transaction`, by nesting inside another
/// `PgTransaction`, or from an existing `tokio_postgres::Transaction` via
/// [`PgTransaction::new`].
pub struct PgTransaction<'a> {
    inner: tokio_postgres::Transaction<'a>,
}
354
355impl<'a> PgTransaction<'a> {
356 pub fn new(tx: tokio_postgres::Transaction<'a>) -> Self {
357 Self { inner: tx }
358 }
359}
360
361impl Transaction for PgTransaction<'_> {
362 type Error = tokio_postgres::Error;
363 type Row = PgRow;
364 type Nested<'a> = PgTransaction<'a> where Self: 'a;
365
366 fn query(
367 &self,
368 sql: &str,
369 params: &[&dyn ToSqlParam],
370 ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
371 let owned_params = convert_params(params);
372 let sql = sql.to_owned();
373 let client = &self.inner;
374 async move {
375 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
376 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
377 let rows = client.query(&sql, ¶m_refs).await?;
378 Ok(rows.into_iter().map(PgRow).collect())
379 }
380 }
381
382 fn execute(
383 &self,
384 sql: &str,
385 params: &[&dyn ToSqlParam],
386 ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
387 let owned_params = convert_params(params);
388 let sql = sql.to_owned();
389 let client = &self.inner;
390 async move {
391 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
392 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
393 client.execute(&sql, ¶m_refs).await
394 }
395 }
396
397 fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Nested<'_>, Self::Error>> + Send {
398 async move {
399 let tx = self.inner.transaction().await?;
400 Ok(PgTransaction { inner: tx })
401 }
402 }
403
404 fn commit(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
405 async move { self.inner.commit().await }
406 }
407
408 fn rollback(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
409 async move { self.inner.rollback().await }
410 }
411}