1use crate::{Client, QueryExecutor, Row, ToSqlParam, Transaction};
2
3struct AnyString(String);
10
11impl<'a> tokio_postgres::types::FromSql<'a> for AnyString {
12 fn from_sql(
13 _ty: &tokio_postgres::types::Type,
14 raw: &'a [u8],
15 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
16 Ok(AnyString(String::from_utf8(raw.to_vec())?))
17 }
18
19 fn accepts(_ty: &tokio_postgres::types::Type) -> bool {
20 true
21 }
22}
23
24pub struct PgRow(pub tokio_postgres::Row);
27
28impl Row for PgRow {
29 fn get_i16(&self, idx: usize) -> i16 { self.0.get(idx) }
30 fn get_i32(&self, idx: usize) -> i32 { self.0.get(idx) }
31 fn get_i64(&self, idx: usize) -> i64 { self.0.get(idx) }
32 fn get_f32(&self, idx: usize) -> f32 { self.0.get(idx) }
33 fn get_f64(&self, idx: usize) -> f64 { self.0.get(idx) }
34 fn get_string(&self, idx: usize) -> String { self.0.get::<_, AnyString>(idx).0 }
35 fn get_bool(&self, idx: usize) -> bool { self.0.get(idx) }
36 fn get_uuid(&self, idx: usize) -> uuid::Uuid { self.0.get(idx) }
37 fn get_timestamp(&self, idx: usize) -> chrono::NaiveDateTime { self.0.get(idx) }
38 fn get_timestamptz(&self, idx: usize) -> chrono::DateTime<chrono::Utc> { self.0.get(idx) }
39 fn get_date(&self, idx: usize) -> chrono::NaiveDate { self.0.get(idx) }
40 fn get_decimal(&self, idx: usize) -> rust_decimal::Decimal { self.0.get(idx) }
41 fn get_json(&self, idx: usize) -> serde_json::Value { self.0.get(idx) }
42
43 fn get_opt_i16(&self, idx: usize) -> Option<i16> { self.0.get(idx) }
44 fn get_opt_i32(&self, idx: usize) -> Option<i32> { self.0.get(idx) }
45 fn get_opt_i64(&self, idx: usize) -> Option<i64> { self.0.get(idx) }
46 fn get_opt_f32(&self, idx: usize) -> Option<f32> { self.0.get(idx) }
47 fn get_opt_f64(&self, idx: usize) -> Option<f64> { self.0.get(idx) }
48 fn get_opt_string(&self, idx: usize) -> Option<String> {
49 self.0.get::<_, Option<AnyString>>(idx).map(|s| s.0)
50 }
51 fn get_opt_bool(&self, idx: usize) -> Option<bool> { self.0.get(idx) }
52 fn get_opt_uuid(&self, idx: usize) -> Option<uuid::Uuid> { self.0.get(idx) }
53 fn get_opt_timestamp(&self, idx: usize) -> Option<chrono::NaiveDateTime> { self.0.get(idx) }
54 fn get_opt_timestamptz(&self, idx: usize) -> Option<chrono::DateTime<chrono::Utc>> { self.0.get(idx) }
55 fn get_opt_date(&self, idx: usize) -> Option<chrono::NaiveDate> { self.0.get(idx) }
56 fn get_opt_decimal(&self, idx: usize) -> Option<rust_decimal::Decimal> { self.0.get(idx) }
57 fn get_opt_json(&self, idx: usize) -> Option<serde_json::Value> { self.0.get(idx) }
58
59 fn get_vec_i16(&self, idx: usize) -> Vec<i16> { self.0.get(idx) }
60 fn get_vec_i32(&self, idx: usize) -> Vec<i32> { self.0.get(idx) }
61 fn get_vec_i64(&self, idx: usize) -> Vec<i64> { self.0.get(idx) }
62 fn get_vec_f32(&self, idx: usize) -> Vec<f32> { self.0.get(idx) }
63 fn get_vec_f64(&self, idx: usize) -> Vec<f64> { self.0.get(idx) }
64 fn get_vec_string(&self, idx: usize) -> Vec<String> { self.0.get(idx) }
65 fn get_vec_bool(&self, idx: usize) -> Vec<bool> { self.0.get(idx) }
66 fn get_vec_uuid(&self, idx: usize) -> Vec<uuid::Uuid> { self.0.get(idx) }
67}
68
69#[derive(Debug)]
74pub enum OwnedParam {
75 I16(i16),
76 I32(i32),
77 I64(i64),
78 F32(f32),
79 F64(f64),
80 Bool(bool),
81 String(String),
82 Uuid(uuid::Uuid),
83 NaiveDateTime(chrono::NaiveDateTime),
84 DateTimeUtc(chrono::DateTime<chrono::Utc>),
85 NaiveDate(chrono::NaiveDate),
86 Decimal(rust_decimal::Decimal),
87 Json(serde_json::Value),
88 OptI16(Option<i16>),
89 OptI32(Option<i32>),
90 OptI64(Option<i64>),
91 OptF32(Option<f32>),
92 OptF64(Option<f64>),
93 OptBool(Option<bool>),
94 OptString(Option<String>),
95 OptUuid(Option<uuid::Uuid>),
96 OptNaiveDateTime(Option<chrono::NaiveDateTime>),
97 OptDateTimeUtc(Option<chrono::DateTime<chrono::Utc>>),
98 OptNaiveDate(Option<chrono::NaiveDate>),
99 OptDecimal(Option<rust_decimal::Decimal>),
100 OptJson(Option<serde_json::Value>),
101 VecI16(Vec<i16>),
102 VecI32(Vec<i32>),
103 VecI64(Vec<i64>),
104 VecF32(Vec<f32>),
105 VecF64(Vec<f64>),
106 VecBool(Vec<bool>),
107 VecString(Vec<String>),
108 VecUuid(Vec<uuid::Uuid>),
109}
110
111impl tokio_postgres::types::ToSql for OwnedParam {
112 fn to_sql(
113 &self,
114 ty: &tokio_postgres::types::Type,
115 out: &mut bytes::BytesMut,
116 ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
117 match self {
118 OwnedParam::I16(v) => v.to_sql(ty, out),
119 OwnedParam::I32(v) => v.to_sql(ty, out),
120 OwnedParam::I64(v) => v.to_sql(ty, out),
121 OwnedParam::F32(v) => v.to_sql(ty, out),
122 OwnedParam::F64(v) => v.to_sql(ty, out),
123 OwnedParam::Bool(v) => v.to_sql(ty, out),
124 OwnedParam::String(v) => v.to_sql(ty, out),
125 OwnedParam::Uuid(v) => v.to_sql(ty, out),
126 OwnedParam::NaiveDateTime(v) => v.to_sql(ty, out),
127 OwnedParam::DateTimeUtc(v) => v.to_sql(ty, out),
128 OwnedParam::NaiveDate(v) => v.to_sql(ty, out),
129 OwnedParam::Decimal(v) => v.to_sql(ty, out),
130 OwnedParam::Json(v) => v.to_sql(ty, out),
131 OwnedParam::OptI16(v) => v.to_sql(ty, out),
132 OwnedParam::OptI32(v) => v.to_sql(ty, out),
133 OwnedParam::OptI64(v) => v.to_sql(ty, out),
134 OwnedParam::OptF32(v) => v.to_sql(ty, out),
135 OwnedParam::OptF64(v) => v.to_sql(ty, out),
136 OwnedParam::OptBool(v) => v.to_sql(ty, out),
137 OwnedParam::OptString(v) => v.to_sql(ty, out),
138 OwnedParam::OptUuid(v) => v.to_sql(ty, out),
139 OwnedParam::OptNaiveDateTime(v) => v.to_sql(ty, out),
140 OwnedParam::OptDateTimeUtc(v) => v.to_sql(ty, out),
141 OwnedParam::OptNaiveDate(v) => v.to_sql(ty, out),
142 OwnedParam::OptDecimal(v) => v.to_sql(ty, out),
143 OwnedParam::OptJson(v) => v.to_sql(ty, out),
144 OwnedParam::VecI16(v) => v.to_sql(ty, out),
145 OwnedParam::VecI32(v) => v.to_sql(ty, out),
146 OwnedParam::VecI64(v) => v.to_sql(ty, out),
147 OwnedParam::VecF32(v) => v.to_sql(ty, out),
148 OwnedParam::VecF64(v) => v.to_sql(ty, out),
149 OwnedParam::VecBool(v) => v.to_sql(ty, out),
150 OwnedParam::VecString(v) => v.to_sql(ty, out),
151 OwnedParam::VecUuid(v) => v.to_sql(ty, out),
152 }
153 }
154
155 fn accepts(_ty: &tokio_postgres::types::Type) -> bool {
156 true
161 }
162
163 tokio_postgres::types::to_sql_checked!();
164}
165
166pub fn to_owned_param(param: &dyn ToSqlParam) -> OwnedParam {
168 let any = param.as_any();
169 if let Some(v) = any.downcast_ref::<i16>() { return OwnedParam::I16(*v); }
171 if let Some(v) = any.downcast_ref::<i32>() { return OwnedParam::I32(*v); }
172 if let Some(v) = any.downcast_ref::<i64>() { return OwnedParam::I64(*v); }
173 if let Some(v) = any.downcast_ref::<f32>() { return OwnedParam::F32(*v); }
174 if let Some(v) = any.downcast_ref::<f64>() { return OwnedParam::F64(*v); }
175 if let Some(v) = any.downcast_ref::<bool>() { return OwnedParam::Bool(*v); }
176 if let Some(v) = any.downcast_ref::<String>() { return OwnedParam::String(v.clone()); }
177 if let Some(v) = any.downcast_ref::<uuid::Uuid>() { return OwnedParam::Uuid(*v); }
178 if let Some(v) = any.downcast_ref::<chrono::NaiveDateTime>() { return OwnedParam::NaiveDateTime(*v); }
179 if let Some(v) = any.downcast_ref::<chrono::DateTime<chrono::Utc>>() { return OwnedParam::DateTimeUtc(*v); }
180 if let Some(v) = any.downcast_ref::<chrono::NaiveDate>() { return OwnedParam::NaiveDate(*v); }
181 if let Some(v) = any.downcast_ref::<rust_decimal::Decimal>() { return OwnedParam::Decimal(*v); }
182 if let Some(v) = any.downcast_ref::<serde_json::Value>() { return OwnedParam::Json(v.clone()); }
183 if let Some(v) = any.downcast_ref::<Option<i16>>() { return OwnedParam::OptI16(*v); }
185 if let Some(v) = any.downcast_ref::<Option<i32>>() { return OwnedParam::OptI32(*v); }
186 if let Some(v) = any.downcast_ref::<Option<i64>>() { return OwnedParam::OptI64(*v); }
187 if let Some(v) = any.downcast_ref::<Option<f32>>() { return OwnedParam::OptF32(*v); }
188 if let Some(v) = any.downcast_ref::<Option<f64>>() { return OwnedParam::OptF64(*v); }
189 if let Some(v) = any.downcast_ref::<Option<bool>>() { return OwnedParam::OptBool(*v); }
190 if let Some(v) = any.downcast_ref::<Option<String>>() { return OwnedParam::OptString(v.clone()); }
191 if let Some(v) = any.downcast_ref::<Option<uuid::Uuid>>() { return OwnedParam::OptUuid(*v); }
192 if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDateTime>>() { return OwnedParam::OptNaiveDateTime(*v); }
193 if let Some(v) = any.downcast_ref::<Option<chrono::DateTime<chrono::Utc>>>() { return OwnedParam::OptDateTimeUtc(*v); }
194 if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDate>>() { return OwnedParam::OptNaiveDate(*v); }
195 if let Some(v) = any.downcast_ref::<Option<rust_decimal::Decimal>>() { return OwnedParam::OptDecimal(*v); }
196 if let Some(v) = any.downcast_ref::<Option<serde_json::Value>>() { return OwnedParam::OptJson(v.clone()); }
197 if let Some(v) = any.downcast_ref::<Vec<i16>>() { return OwnedParam::VecI16(v.clone()); }
199 if let Some(v) = any.downcast_ref::<Vec<i32>>() { return OwnedParam::VecI32(v.clone()); }
200 if let Some(v) = any.downcast_ref::<Vec<i64>>() { return OwnedParam::VecI64(v.clone()); }
201 if let Some(v) = any.downcast_ref::<Vec<f32>>() { return OwnedParam::VecF32(v.clone()); }
202 if let Some(v) = any.downcast_ref::<Vec<f64>>() { return OwnedParam::VecF64(v.clone()); }
203 if let Some(v) = any.downcast_ref::<Vec<bool>>() { return OwnedParam::VecBool(v.clone()); }
204 if let Some(v) = any.downcast_ref::<Vec<String>>() { return OwnedParam::VecString(v.clone()); }
205 if let Some(v) = any.downcast_ref::<Vec<uuid::Uuid>>() { return OwnedParam::VecUuid(v.clone()); }
206 panic!("Unsupported parameter type for tokio-postgres backend");
207}
208
209pub fn convert_params(params: &[&dyn ToSqlParam]) -> Vec<OwnedParam> {
211 params.iter().map(|p| to_owned_param(*p)).collect()
212}
213
214pub struct PgClient {
218 inner: tokio_postgres::Client,
219}
220
221impl PgClient {
222 pub fn new(client: tokio_postgres::Client) -> Self {
223 Self { inner: client }
224 }
225
226 pub fn inner(&self) -> &tokio_postgres::Client {
228 &self.inner
229 }
230
231 pub fn into_inner(self) -> tokio_postgres::Client {
233 self.inner
234 }
235}
236
237impl QueryExecutor for PgClient {
238 type Error = tokio_postgres::Error;
239 type Row = PgRow;
240
241 fn query(
242 &self,
243 sql: &str,
244 params: &[&dyn ToSqlParam],
245 ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
246 let owned_params = convert_params(params);
247 let sql = sql.to_owned();
248 let client = &self.inner;
249 async move {
250 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
251 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
252 let rows = client.query(&sql, ¶m_refs).await?;
253 Ok(rows.into_iter().map(PgRow).collect())
254 }
255 }
256
257 fn execute(
258 &self,
259 sql: &str,
260 params: &[&dyn ToSqlParam],
261 ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
262 let owned_params = convert_params(params);
263 let sql = sql.to_owned();
264 let client = &self.inner;
265 async move {
266 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
267 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
268 client.execute(&sql, ¶m_refs).await
269 }
270 }
271}
272
273impl Client for PgClient {
274 type Transaction<'a> = PgTransaction<'a>;
275
276 fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Transaction<'_>, Self::Error>> + Send {
277 async move {
278 let tx = self.inner.transaction().await?;
279 Ok(PgTransaction { inner: tx })
280 }
281 }
282}
283
284pub struct PgTransaction<'a> {
290 inner: tokio_postgres::Transaction<'a>,
291}
292
293impl<'a> PgTransaction<'a> {
294 pub fn new(tx: tokio_postgres::Transaction<'a>) -> Self {
295 Self { inner: tx }
296 }
297}
298
299impl QueryExecutor for PgTransaction<'_> {
300 type Error = tokio_postgres::Error;
301 type Row = PgRow;
302
303 fn query(
304 &self,
305 sql: &str,
306 params: &[&dyn ToSqlParam],
307 ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
308 let owned_params = convert_params(params);
309 let sql = sql.to_owned();
310 let client = &self.inner;
311 async move {
312 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
313 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
314 let rows = client.query(&sql, ¶m_refs).await?;
315 Ok(rows.into_iter().map(PgRow).collect())
316 }
317 }
318
319 fn execute(
320 &self,
321 sql: &str,
322 params: &[&dyn ToSqlParam],
323 ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
324 let owned_params = convert_params(params);
325 let sql = sql.to_owned();
326 let client = &self.inner;
327 async move {
328 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
329 owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
330 client.execute(&sql, ¶m_refs).await
331 }
332 }
333}
334
335impl Transaction for PgTransaction<'_> {
336 type Nested<'a> = PgTransaction<'a> where Self: 'a;
337
338 fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Nested<'_>, Self::Error>> + Send {
339 async move {
340 let tx = self.inner.transaction().await?;
341 Ok(PgTransaction { inner: tx })
342 }
343 }
344
345 fn commit(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
346 async move { self.inner.commit().await }
347 }
348
349 fn rollback(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
350 async move { self.inner.rollback().await }
351 }
352}