Skip to main content

nextsql_backend_rust_runtime/
tokio_postgres_impl.rs

1use crate::{Client, Row, ToSqlParam, Transaction, UpdateField};
2
3// ---- ToSqlParam implementations for primitive types ----
4
/// Generates a trivial [`ToSqlParam`] impl for every listed type.
///
/// Each generated `as_any` hands back the value itself as a type-erased
/// `&dyn Any`, which `to_owned_param` later recovers with `downcast_ref`.
macro_rules! impl_to_sql_param {
    ($($param_ty:ty),*) => {
        $(
            impl ToSqlParam for $param_ty {
                fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
                    let erased: &(dyn std::any::Any + Send + Sync) = self;
                    erased
                }
            }
        )*
    };
}
16
17impl_to_sql_param!(
18    i16, i32, i64, f32, f64, bool, String,
19    uuid::Uuid, chrono::NaiveDateTime, chrono::NaiveDate,
20    rust_decimal::Decimal
21);
22
23impl ToSqlParam for chrono::DateTime<chrono::Utc> {
24    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
25        self
26    }
27}
28
29impl ToSqlParam for serde_json::Value {
30    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
31        self
32    }
33}
34
35impl ToSqlParam for &'static str {
36    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
37        self
38    }
39}
40
41impl<T: ToSqlParam + 'static> ToSqlParam for Option<T> {
42    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
43        self
44    }
45}
46
47impl<T: ToSqlParam + 'static> ToSqlParam for Vec<T> {
48    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
49        self
50    }
51}
52
53impl<T: ToSqlParam + 'static> ToSqlParam for UpdateField<T> {
54    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
55        match self {
56            UpdateField::Set(v) => v.as_any(),
57            UpdateField::Unchanged => panic!("Unchanged fields should not be passed as SQL params"),
58        }
59    }
60}
61
62// ---- Row implementation wrapping tokio_postgres::Row ----
63
64pub struct PgRow(pub tokio_postgres::Row);
65
66impl Row for PgRow {
67    fn get_i16(&self, idx: usize) -> i16 { self.0.get(idx) }
68    fn get_i32(&self, idx: usize) -> i32 { self.0.get(idx) }
69    fn get_i64(&self, idx: usize) -> i64 { self.0.get(idx) }
70    fn get_f32(&self, idx: usize) -> f32 { self.0.get(idx) }
71    fn get_f64(&self, idx: usize) -> f64 { self.0.get(idx) }
72    fn get_string(&self, idx: usize) -> String { self.0.get(idx) }
73    fn get_bool(&self, idx: usize) -> bool { self.0.get(idx) }
74    fn get_uuid(&self, idx: usize) -> uuid::Uuid { self.0.get(idx) }
75    fn get_timestamp(&self, idx: usize) -> chrono::NaiveDateTime { self.0.get(idx) }
76    fn get_timestamptz(&self, idx: usize) -> chrono::DateTime<chrono::Utc> { self.0.get(idx) }
77    fn get_date(&self, idx: usize) -> chrono::NaiveDate { self.0.get(idx) }
78    fn get_decimal(&self, idx: usize) -> rust_decimal::Decimal { self.0.get(idx) }
79    fn get_json(&self, idx: usize) -> serde_json::Value { self.0.get(idx) }
80
81    fn get_opt_i16(&self, idx: usize) -> Option<i16> { self.0.get(idx) }
82    fn get_opt_i32(&self, idx: usize) -> Option<i32> { self.0.get(idx) }
83    fn get_opt_i64(&self, idx: usize) -> Option<i64> { self.0.get(idx) }
84    fn get_opt_f32(&self, idx: usize) -> Option<f32> { self.0.get(idx) }
85    fn get_opt_f64(&self, idx: usize) -> Option<f64> { self.0.get(idx) }
86    fn get_opt_string(&self, idx: usize) -> Option<String> { self.0.get(idx) }
87    fn get_opt_bool(&self, idx: usize) -> Option<bool> { self.0.get(idx) }
88    fn get_opt_uuid(&self, idx: usize) -> Option<uuid::Uuid> { self.0.get(idx) }
89    fn get_opt_timestamp(&self, idx: usize) -> Option<chrono::NaiveDateTime> { self.0.get(idx) }
90    fn get_opt_timestamptz(&self, idx: usize) -> Option<chrono::DateTime<chrono::Utc>> { self.0.get(idx) }
91    fn get_opt_date(&self, idx: usize) -> Option<chrono::NaiveDate> { self.0.get(idx) }
92    fn get_opt_decimal(&self, idx: usize) -> Option<rust_decimal::Decimal> { self.0.get(idx) }
93    fn get_opt_json(&self, idx: usize) -> Option<serde_json::Value> { self.0.get(idx) }
94
95    fn get_vec_i16(&self, idx: usize) -> Vec<i16> { self.0.get(idx) }
96    fn get_vec_i32(&self, idx: usize) -> Vec<i32> { self.0.get(idx) }
97    fn get_vec_i64(&self, idx: usize) -> Vec<i64> { self.0.get(idx) }
98    fn get_vec_f32(&self, idx: usize) -> Vec<f32> { self.0.get(idx) }
99    fn get_vec_f64(&self, idx: usize) -> Vec<f64> { self.0.get(idx) }
100    fn get_vec_string(&self, idx: usize) -> Vec<String> { self.0.get(idx) }
101    fn get_vec_bool(&self, idx: usize) -> Vec<bool> { self.0.get(idx) }
102    fn get_vec_uuid(&self, idx: usize) -> Vec<uuid::Uuid> { self.0.get(idx) }
103}
104
105// ---- Owned parameter enum for crossing async boundaries ----
106
107/// An owned SQL parameter value. Used internally to convert borrowed `&dyn ToSqlParam`
108/// into owned values that can be moved into async futures.
109#[derive(Debug)]
110pub enum OwnedParam {
111    I16(i16),
112    I32(i32),
113    I64(i64),
114    F32(f32),
115    F64(f64),
116    Bool(bool),
117    String(String),
118    Uuid(uuid::Uuid),
119    NaiveDateTime(chrono::NaiveDateTime),
120    DateTimeUtc(chrono::DateTime<chrono::Utc>),
121    NaiveDate(chrono::NaiveDate),
122    Decimal(rust_decimal::Decimal),
123    Json(serde_json::Value),
124    OptI16(Option<i16>),
125    OptI32(Option<i32>),
126    OptI64(Option<i64>),
127    OptF32(Option<f32>),
128    OptF64(Option<f64>),
129    OptBool(Option<bool>),
130    OptString(Option<String>),
131    OptUuid(Option<uuid::Uuid>),
132    OptNaiveDateTime(Option<chrono::NaiveDateTime>),
133    OptDateTimeUtc(Option<chrono::DateTime<chrono::Utc>>),
134    OptNaiveDate(Option<chrono::NaiveDate>),
135    OptDecimal(Option<rust_decimal::Decimal>),
136    OptJson(Option<serde_json::Value>),
137    VecI16(Vec<i16>),
138    VecI32(Vec<i32>),
139    VecI64(Vec<i64>),
140    VecF32(Vec<f32>),
141    VecF64(Vec<f64>),
142    VecBool(Vec<bool>),
143    VecString(Vec<String>),
144    VecUuid(Vec<uuid::Uuid>),
145}
146
147impl tokio_postgres::types::ToSql for OwnedParam {
148    fn to_sql(
149        &self,
150        ty: &tokio_postgres::types::Type,
151        out: &mut bytes::BytesMut,
152    ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
153        match self {
154            OwnedParam::I16(v) => v.to_sql(ty, out),
155            OwnedParam::I32(v) => v.to_sql(ty, out),
156            OwnedParam::I64(v) => v.to_sql(ty, out),
157            OwnedParam::F32(v) => v.to_sql(ty, out),
158            OwnedParam::F64(v) => v.to_sql(ty, out),
159            OwnedParam::Bool(v) => v.to_sql(ty, out),
160            OwnedParam::String(v) => v.to_sql(ty, out),
161            OwnedParam::Uuid(v) => v.to_sql(ty, out),
162            OwnedParam::NaiveDateTime(v) => v.to_sql(ty, out),
163            OwnedParam::DateTimeUtc(v) => v.to_sql(ty, out),
164            OwnedParam::NaiveDate(v) => v.to_sql(ty, out),
165            OwnedParam::Decimal(v) => v.to_sql(ty, out),
166            OwnedParam::Json(v) => v.to_sql(ty, out),
167            OwnedParam::OptI16(v) => v.to_sql(ty, out),
168            OwnedParam::OptI32(v) => v.to_sql(ty, out),
169            OwnedParam::OptI64(v) => v.to_sql(ty, out),
170            OwnedParam::OptF32(v) => v.to_sql(ty, out),
171            OwnedParam::OptF64(v) => v.to_sql(ty, out),
172            OwnedParam::OptBool(v) => v.to_sql(ty, out),
173            OwnedParam::OptString(v) => v.to_sql(ty, out),
174            OwnedParam::OptUuid(v) => v.to_sql(ty, out),
175            OwnedParam::OptNaiveDateTime(v) => v.to_sql(ty, out),
176            OwnedParam::OptDateTimeUtc(v) => v.to_sql(ty, out),
177            OwnedParam::OptNaiveDate(v) => v.to_sql(ty, out),
178            OwnedParam::OptDecimal(v) => v.to_sql(ty, out),
179            OwnedParam::OptJson(v) => v.to_sql(ty, out),
180            OwnedParam::VecI16(v) => v.to_sql(ty, out),
181            OwnedParam::VecI32(v) => v.to_sql(ty, out),
182            OwnedParam::VecI64(v) => v.to_sql(ty, out),
183            OwnedParam::VecF32(v) => v.to_sql(ty, out),
184            OwnedParam::VecF64(v) => v.to_sql(ty, out),
185            OwnedParam::VecBool(v) => v.to_sql(ty, out),
186            OwnedParam::VecString(v) => v.to_sql(ty, out),
187            OwnedParam::VecUuid(v) => v.to_sql(ty, out),
188        }
189    }
190
191    fn accepts(ty: &tokio_postgres::types::Type) -> bool {
192        i16::accepts(ty)
193            || i32::accepts(ty)
194            || i64::accepts(ty)
195            || f32::accepts(ty)
196            || f64::accepts(ty)
197            || bool::accepts(ty)
198            || String::accepts(ty)
199            || uuid::Uuid::accepts(ty)
200            || chrono::NaiveDateTime::accepts(ty)
201            || chrono::DateTime::<chrono::Utc>::accepts(ty)
202            || chrono::NaiveDate::accepts(ty)
203            || rust_decimal::Decimal::accepts(ty)
204            || serde_json::Value::accepts(ty)
205            || <Option<i16>>::accepts(ty)
206            || <Option<i32>>::accepts(ty)
207            || <Option<i64>>::accepts(ty)
208            || <Option<f32>>::accepts(ty)
209            || <Option<f64>>::accepts(ty)
210            || <Option<bool>>::accepts(ty)
211            || <Option<String>>::accepts(ty)
212            || <Option<uuid::Uuid>>::accepts(ty)
213            || <Option<chrono::NaiveDateTime>>::accepts(ty)
214            || <Option<chrono::DateTime<chrono::Utc>>>::accepts(ty)
215            || <Option<chrono::NaiveDate>>::accepts(ty)
216            || <Option<rust_decimal::Decimal>>::accepts(ty)
217            || <Option<serde_json::Value>>::accepts(ty)
218            || <Vec<i16>>::accepts(ty)
219            || <Vec<i32>>::accepts(ty)
220            || <Vec<i64>>::accepts(ty)
221            || <Vec<f32>>::accepts(ty)
222            || <Vec<f64>>::accepts(ty)
223            || <Vec<bool>>::accepts(ty)
224            || <Vec<String>>::accepts(ty)
225            || <Vec<uuid::Uuid>>::accepts(ty)
226    }
227
228    tokio_postgres::types::to_sql_checked!();
229}
230
231/// Convert a `&dyn ToSqlParam` to an owned parameter value by downcasting.
232pub fn to_owned_param(param: &dyn ToSqlParam) -> OwnedParam {
233    let any = param.as_any();
234    // Non-optional types
235    if let Some(v) = any.downcast_ref::<i16>() { return OwnedParam::I16(*v); }
236    if let Some(v) = any.downcast_ref::<i32>() { return OwnedParam::I32(*v); }
237    if let Some(v) = any.downcast_ref::<i64>() { return OwnedParam::I64(*v); }
238    if let Some(v) = any.downcast_ref::<f32>() { return OwnedParam::F32(*v); }
239    if let Some(v) = any.downcast_ref::<f64>() { return OwnedParam::F64(*v); }
240    if let Some(v) = any.downcast_ref::<bool>() { return OwnedParam::Bool(*v); }
241    if let Some(v) = any.downcast_ref::<String>() { return OwnedParam::String(v.clone()); }
242    if let Some(v) = any.downcast_ref::<uuid::Uuid>() { return OwnedParam::Uuid(*v); }
243    if let Some(v) = any.downcast_ref::<chrono::NaiveDateTime>() { return OwnedParam::NaiveDateTime(*v); }
244    if let Some(v) = any.downcast_ref::<chrono::DateTime<chrono::Utc>>() { return OwnedParam::DateTimeUtc(*v); }
245    if let Some(v) = any.downcast_ref::<chrono::NaiveDate>() { return OwnedParam::NaiveDate(*v); }
246    if let Some(v) = any.downcast_ref::<rust_decimal::Decimal>() { return OwnedParam::Decimal(*v); }
247    if let Some(v) = any.downcast_ref::<serde_json::Value>() { return OwnedParam::Json(v.clone()); }
248    // Optional types
249    if let Some(v) = any.downcast_ref::<Option<i16>>() { return OwnedParam::OptI16(*v); }
250    if let Some(v) = any.downcast_ref::<Option<i32>>() { return OwnedParam::OptI32(*v); }
251    if let Some(v) = any.downcast_ref::<Option<i64>>() { return OwnedParam::OptI64(*v); }
252    if let Some(v) = any.downcast_ref::<Option<f32>>() { return OwnedParam::OptF32(*v); }
253    if let Some(v) = any.downcast_ref::<Option<f64>>() { return OwnedParam::OptF64(*v); }
254    if let Some(v) = any.downcast_ref::<Option<bool>>() { return OwnedParam::OptBool(*v); }
255    if let Some(v) = any.downcast_ref::<Option<String>>() { return OwnedParam::OptString(v.clone()); }
256    if let Some(v) = any.downcast_ref::<Option<uuid::Uuid>>() { return OwnedParam::OptUuid(*v); }
257    if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDateTime>>() { return OwnedParam::OptNaiveDateTime(*v); }
258    if let Some(v) = any.downcast_ref::<Option<chrono::DateTime<chrono::Utc>>>() { return OwnedParam::OptDateTimeUtc(*v); }
259    if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDate>>() { return OwnedParam::OptNaiveDate(*v); }
260    if let Some(v) = any.downcast_ref::<Option<rust_decimal::Decimal>>() { return OwnedParam::OptDecimal(*v); }
261    if let Some(v) = any.downcast_ref::<Option<serde_json::Value>>() { return OwnedParam::OptJson(v.clone()); }
262    // Vec types
263    if let Some(v) = any.downcast_ref::<Vec<i16>>() { return OwnedParam::VecI16(v.clone()); }
264    if let Some(v) = any.downcast_ref::<Vec<i32>>() { return OwnedParam::VecI32(v.clone()); }
265    if let Some(v) = any.downcast_ref::<Vec<i64>>() { return OwnedParam::VecI64(v.clone()); }
266    if let Some(v) = any.downcast_ref::<Vec<f32>>() { return OwnedParam::VecF32(v.clone()); }
267    if let Some(v) = any.downcast_ref::<Vec<f64>>() { return OwnedParam::VecF64(v.clone()); }
268    if let Some(v) = any.downcast_ref::<Vec<bool>>() { return OwnedParam::VecBool(v.clone()); }
269    if let Some(v) = any.downcast_ref::<Vec<String>>() { return OwnedParam::VecString(v.clone()); }
270    if let Some(v) = any.downcast_ref::<Vec<uuid::Uuid>>() { return OwnedParam::VecUuid(v.clone()); }
271    panic!("Unsupported parameter type for tokio-postgres backend");
272}
273
274/// Convert a slice of borrowed params to a vec of owned params.
275pub fn convert_params(params: &[&dyn ToSqlParam]) -> Vec<OwnedParam> {
276    params.iter().map(|p| to_owned_param(*p)).collect()
277}
278
279// ---- Client wrapper ----
280
281/// A wrapper around `tokio_postgres::Client` implementing the NextSQL `Client` trait.
282pub struct PgClient {
283    inner: tokio_postgres::Client,
284}
285
286impl PgClient {
287    pub fn new(client: tokio_postgres::Client) -> Self {
288        Self { inner: client }
289    }
290
291    /// Get a reference to the underlying `tokio_postgres::Client`.
292    pub fn inner(&self) -> &tokio_postgres::Client {
293        &self.inner
294    }
295
296    /// Consume this wrapper and return the underlying `tokio_postgres::Client`.
297    pub fn into_inner(self) -> tokio_postgres::Client {
298        self.inner
299    }
300}
301
302impl Client for PgClient {
303    type Error = tokio_postgres::Error;
304    type Row = PgRow;
305    type Transaction<'a> = PgTransaction<'a>;
306
307    fn query(
308        &self,
309        sql: &str,
310        params: &[&dyn ToSqlParam],
311    ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
312        let owned_params = convert_params(params);
313        let sql = sql.to_owned();
314        let client = &self.inner;
315        async move {
316            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
317                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
318            let rows = client.query(&sql, &param_refs).await?;
319            Ok(rows.into_iter().map(PgRow).collect())
320        }
321    }
322
323    fn execute(
324        &self,
325        sql: &str,
326        params: &[&dyn ToSqlParam],
327    ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
328        let owned_params = convert_params(params);
329        let sql = sql.to_owned();
330        let client = &self.inner;
331        async move {
332            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
333                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
334            client.execute(&sql, &param_refs).await
335        }
336    }
337
338    fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Transaction<'_>, Self::Error>> + Send {
339        async move {
340            let tx = self.inner.transaction().await?;
341            Ok(PgTransaction { inner: tx })
342        }
343    }
344}
345
346// ---- Transaction wrapper ----
347
348/// A wrapper around `tokio_postgres::Transaction` implementing the NextSQL `Transaction` trait.
349/// Supports nested transactions (savepoints) and commit/rollback.
350/// Drop without calling `commit()` will rollback the transaction.
351pub struct PgTransaction<'a> {
352    inner: tokio_postgres::Transaction<'a>,
353}
354
355impl<'a> PgTransaction<'a> {
356    pub fn new(tx: tokio_postgres::Transaction<'a>) -> Self {
357        Self { inner: tx }
358    }
359}
360
361impl Transaction for PgTransaction<'_> {
362    type Error = tokio_postgres::Error;
363    type Row = PgRow;
364    type Nested<'a> = PgTransaction<'a> where Self: 'a;
365
366    fn query(
367        &self,
368        sql: &str,
369        params: &[&dyn ToSqlParam],
370    ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
371        let owned_params = convert_params(params);
372        let sql = sql.to_owned();
373        let client = &self.inner;
374        async move {
375            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
376                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
377            let rows = client.query(&sql, &param_refs).await?;
378            Ok(rows.into_iter().map(PgRow).collect())
379        }
380    }
381
382    fn execute(
383        &self,
384        sql: &str,
385        params: &[&dyn ToSqlParam],
386    ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
387        let owned_params = convert_params(params);
388        let sql = sql.to_owned();
389        let client = &self.inner;
390        async move {
391            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
392                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
393            client.execute(&sql, &param_refs).await
394        }
395    }
396
397    fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Nested<'_>, Self::Error>> + Send {
398        async move {
399            let tx = self.inner.transaction().await?;
400            Ok(PgTransaction { inner: tx })
401        }
402    }
403
404    fn commit(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
405        async move { self.inner.commit().await }
406    }
407
408    fn rollback(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
409        async move { self.inner.rollback().await }
410    }
411}